
1 Introduction

Developing automatic, accurate, and robust medical image segmentation methods has been one of the principal problems in medical imaging, as it is essential for computer-aided diagnosis and image-guided surgery systems. Segmentation of organs or lesions from a medical scan helps clinicians make an accurate diagnosis, plan the surgical procedure, and propose treatment strategies. Following the popularity of deep convolutional neural networks (ConvNets) in computer vision, ConvNets were quickly adopted for medical image segmentation. Networks like U-Net [15], V-Net [13], 3D U-Net [3], Res-UNet [25], Dense-UNet [11], Y-Net [12], U-Net++ [28], KiU-Net [19, 20] and U-Net3+ [7] have been proposed specifically for performing image and volumetric segmentation for various medical imaging modalities. These methods achieve impressive performance on many difficult datasets, proving the effectiveness of ConvNets in learning discriminative features to segment the organ or lesion from a medical scan.

ConvNets are currently the basic building blocks of most methods proposed for image segmentation. However, they lack the ability to model long-range dependencies present in an image. More precisely, in ConvNets each convolutional kernel attends to only a local subset of pixels in the whole image, forcing the network to focus on local patterns rather than the global context. There have been works that model long-range dependencies for ConvNets using image pyramids [26], atrous convolutions [2] and attention mechanisms [8]. However, there is still scope for improvement in modeling long-range dependencies, as the majority of previous methods do not focus on this aspect for medical image segmentation tasks.

Fig. 1. (a) Input ultrasound image of an in vivo preterm neonatal brain ventricle. Predictions by (b) U-Net, (c) Res-UNet, (d) MedT, and (e) ground truth. The red box highlights a region that is misclassified by the ConvNet-based methods due to the lack of learned long-range dependencies. The ground truth here was segmented by an expert clinician: although the image shows some bleeding inside the ventricle area, the bleeding does not correspond to the segmented region. This information is correctly captured by the transformer-based model. (Color figure online)

To first understand why long-range dependencies matter for medical images, we visualize an example ultrasound scan of a preterm neonate and segmentation predictions of the brain ventricles from the scan in Fig. 1. For a network to provide an efficient segmentation, it should be able to understand which pixels correspond to the mask and which to the background. As the background of the image is scattered, learning long-range dependencies between the pixels corresponding to the background can help the network avoid misclassifying a background pixel as the mask, reducing false positives (considering 0 as background and 1 as the segmentation mask). Similarly, whenever the segmentation mask is large, learning long-range dependencies between the pixels corresponding to the mask is also helpful in making efficient predictions. In Fig. 1(b) and (c), we can see that the convolutional networks misclassify the background as a brain ventricle while the proposed transformer-based method does not make that mistake. This happens because our proposed method learns long-range dependencies between the pixel regions and the background.

In many natural language processing (NLP) applications, transformers [4] have been shown to encode long-range dependencies. This is due to the self-attention mechanism, which finds the dependencies within a given sequential input. Following their popularity in NLP applications, transformers have very recently been adopted for computer vision applications [5, 18]. With regard to transformers for segmentation tasks, Axial-Deeplab [22] utilized the axial attention module [6], which factorizes 2D self-attention into two 1D self-attentions, and introduced a position-sensitive axial attention design for segmentation. In the Segmentation Transformer (SETR) [27], a transformer was used as the encoder, taking a sequence of image patches as input, and a ConvNet was used as the decoder, resulting in a powerful segmentation model. In medical image segmentation, transformer-based models have not been explored much. The closest works are the ones that use attention mechanisms to boost performance [14, 24]. However, the encoder and decoder of these networks still have convolutional layers as their main building blocks.

It has been observed that transformer-based models work well only when they are trained on large-scale datasets [5]. This becomes problematic when adopting transformers for medical imaging tasks, as the number of labeled images available for training in any medical dataset is relatively scarce, and the labeling process is expensive and requires expert knowledge. Specifically, training with fewer images makes it difficult to learn the positional encoding for the images. To this end, we propose a gated position-sensitive axial attention mechanism in which we introduce four gates that control the amount of information the positional embeddings supply to the keys, queries, and values. These gates are learnable parameters, which allows the proposed mechanism to be applied to a dataset of any size. Depending on the size of the dataset, these gates learn whether the number of images is sufficient to learn a proper positional embedding. Based on whether the information learned by the positional embedding is useful, the gate parameters either converge to 0 or to some higher value. Furthermore, we propose a Local-Global (LoGo) training strategy, where we use a shallow global branch and a deep local branch that operates on patches of the medical image. This strategy improves segmentation performance as we not only operate on the entire image but also focus on the finer details present in the local patches. Finally, we propose the Medical Transformer (MedT), which uses our gated position-sensitive axial attention as its building block and adopts our LoGo training strategy.

In summary, this paper (1) proposes a gated position-sensitive axial attention mechanism that works well even on smaller datasets, (2) introduces an effective Local-Global (LoGo) training methodology for transformers, (3) proposes the Medical Transformer (MedT), which is built upon the above two concepts proposed specifically for medical image segmentation, and (4) successfully improves performance for medical image segmentation tasks over convolutional networks and fully attention-based architectures on three different datasets.

Fig. 2. (a) The main architecture of MedT, which uses the LoGo strategy for training. (b) The gated axial transformer layer used in MedT. (c) The gated axial attention layer, which is the basic building block of both the height and width gated multi-head attention blocks found in the gated axial transformer layer.

2 Medical Transformer (MedT)

2.1 Self-attention Overview

Let us consider an input feature map \(x \in \mathbb {R}^{C_{in} \times H \times W}\) with height H, width W and channels \(C_{in}\). The output \(y \in \mathbb {R}^{C_{out} \times H \times W}\) of a self-attention layer is computed from projections of the input using the following equation:

$$\begin{aligned} y_{ij} \ = \ \sum _{h=1}^{H} \sum _{w=1}^{W} {\text {softmax}} \left( q_{ij}^{T} k_{hw}\right) v_{hw}, \end{aligned}$$
(1)

where queries \(q=W_Q x\), keys \(k=W_K x\) and values \(v=W_V x\) are all projections computed from the input x. Here, \(q_{ij}, k_{ij}, v_{ij}\) denote the query, key and value at an arbitrary location \(i \in \{1, \dots , H\}\) and \(j \in \{1, \dots , W\}\), respectively. The projection matrices \(W_Q, W_K, W_V \in \mathbb {R}^{C_{in} \times C_{out}}\) are learnable. As shown in Eq. 1, the values v are pooled based on global affinities calculated using \({\text {softmax}} (q^T k)\). Hence, unlike convolutions, the self-attention mechanism is able to capture non-local information from the entire feature map. However, computing such affinities is computationally very expensive, and with increased feature map size it often becomes infeasible to use self-attention in vision model architectures. Moreover, unlike a convolutional layer, a self-attention layer does not utilize any positional information while computing the non-local context. Positional information is often useful in vision models to capture the structure of an object.
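For illustration, a minimal sketch of the global self-attention of Eq. 1 is shown below as a hypothetical single-head PyTorch module (not the authors' implementation); 1x1 convolutions play the role of the projection matrices \(W_Q, W_K, W_V\).

```python
# A minimal sketch of the global 2D self-attention in Eq. 1 (hypothetical
# single-head implementation for illustration; no positional information).
import torch
import torch.nn as nn


class GlobalSelfAttention2d(nn.Module):
    def __init__(self, c_in: int, c_out: int):
        super().__init__()
        # 1x1 convolutions act as the projection matrices W_Q, W_K, W_V.
        self.to_q = nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)
        self.to_k = nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)
        self.to_v = nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, _, h, w = x.shape
        # Project and flatten the spatial dimensions: (B, C_out, H*W).
        q = self.to_q(x).flatten(2)
        k = self.to_k(x).flatten(2)
        v = self.to_v(x).flatten(2)
        # Affinities q^T k over all H*W positions: (B, H*W, H*W).
        attn = torch.softmax(torch.einsum("bci,bcj->bij", q, k), dim=-1)
        # Pool the values with the affinities and restore the spatial layout.
        y = torch.einsum("bij,bcj->bci", attn, v)
        return y.reshape(b, -1, h, w)


# The (H*W) x (H*W) affinity matrix is what makes this quadratic in the
# number of pixels, and hence expensive for large feature maps.
x = torch.randn(1, 16, 32, 32)
print(GlobalSelfAttention2d(16, 16)(x).shape)  # torch.Size([1, 16, 32, 32])
```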

Axial-Attention. To overcome the computational complexity of calculating the affinities, self-attention is decomposed into two self-attention modules. The first module performs self-attention along the feature map height axis and the second one operates along the width axis. This is referred to as axial attention [6]. Axial attention applied consecutively along the height and width axes effectively models the original self-attention mechanism with much better computational efficiency. To add positional bias while computing affinities through the self-attention mechanism, a position bias term is added to make the affinities sensitive to positional information [16]. This bias term is often referred to as the relative positional encoding. These positional encodings are typically learned during training and have been shown to have the capacity to encode the spatial structure of the image. Wang et al. [22] combined the axial-attention mechanism and positional encodings to propose an attention-based model for image segmentation. Additionally, unlike previous attention models which utilize relative positional encodings only for queries, Wang et al. [22] proposed to use them for all of the queries, keys and values. This additional position bias in the query, key and value is shown to capture long-range interactions with precise positional information [22]. For any given input feature map x, the updated self-attention mechanism with positional encodings along the width axis can be written as:

$$\begin{aligned} y_{ij} \ = \ \sum _{w=1}^{W} {\text {softmax}} \left( q_{ij}^{T} k_{iw} + q_{ij}^{T} r^q_{iw} + k_{iw}^{T} r^k_{iw} \right) (v_{iw} + r^v_{iw}), \end{aligned}$$
(2)

where the formulation in Eq. 2 follows the attention model proposed in [22] and \(r^q, r^k, r^v \in \mathbb {R}^{W \times W}\) for the width-wise axial attention model. Note that Eq. 2 describes the axial attention applied along the width axis of the tensor. A similar formulation is also used to apply axial attention along the height axis and together they form a single self-attention model that is computationally efficient.
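The following is a hedged sketch of the width-axis axial attention of Eq. 2, again as hypothetical PyTorch code rather than the reference implementation; it is single-head and, as a simplification, gives the positional encodings a channel dimension instead of the scalar \(W \times W\) tables used in the notation above.

```python
# A minimal sketch of width-axis axial attention with relative positional
# encodings r^q, r^k, r^v (Eq. 2). Hypothetical, simplified code.
import torch
import torch.nn as nn


class AxialAttentionWidth(nn.Module):
    def __init__(self, c_in: int, c_out: int, width: int):
        super().__init__()
        self.to_q = nn.Conv2d(c_in, c_out, 1, bias=False)
        self.to_k = nn.Conv2d(c_in, c_out, 1, bias=False)
        self.to_v = nn.Conv2d(c_in, c_out, 1, bias=False)
        # Learned positional encodings indexed by pairs of width positions;
        # a channel dimension is added here as a simplification.
        self.r_q = nn.Parameter(torch.randn(c_out, width, width) * 0.02)
        self.r_k = nn.Parameter(torch.randn(c_out, width, width) * 0.02)
        self.r_v = nn.Parameter(torch.randn(c_out, width, width) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        # Content term q^T k restricted to positions in the same row.
        logits = torch.einsum("bcij,bciw->bijw", q, k)
        # Positional terms q^T r^q and k^T r^k from Eq. 2.
        logits = logits + torch.einsum("bcij,cjw->bijw", q, self.r_q)
        logits = logits + torch.einsum("bciw,cjw->bijw", k, self.r_k)
        attn = torch.softmax(logits, dim=-1)
        # Aggregate the values together with their positional encodings r^v.
        y = torch.einsum("bijw,bciw->bcij", attn, v)
        y = y + torch.einsum("bijw,cjw->bcij", attn, self.r_v)
        return y


x = torch.randn(1, 16, 32, 32)
print(AxialAttentionWidth(16, 16, width=32)(x).shape)  # [1, 16, 32, 32]
```

The attention is computed only across the W positions of each row, so the cost grows linearly in W per position rather than quadratically in H*W; the height-axis module is symmetric.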

2.2 Gated Axial-Attention

We discussed the benefits of using the axial-attention mechanism proposed in [22] for visual recognition. Specifically, the axial attention proposed in [22] computes non-local context with good computational efficiency, encodes positional bias into the mechanism, and enables the encoding of long-range interactions within an input feature map. However, their model is evaluated on large-scale segmentation datasets, where it is easier for the axial attention to learn the positional bias at the key, query and value. We argue that for experiments with small-scale datasets, which is often the case in medical image segmentation, the positional bias is difficult to learn and hence will not always be accurate in encoding long-range interactions. In cases where the learned relative positional encodings are not accurate enough, adding them to the respective key, query and value tensors would reduce performance. Hence, we propose a modified axial-attention block that can control the influence the positional bias exerts on the encoding of non-local context. With the proposed modification, the self-attention mechanism applied along the width axis can be formally written as:

$$\begin{aligned} y_{ij} \ = \ \sum _{w=1}^{W} {\text {softmax}} \left( q_{ij}^{T} k_{iw} + G_Q q_{ij}^{T} r^q_{iw} + G_K k_{iw}^{T} r^k_{iw}\right) ( G_{V1} v_{iw} + G_{V2} r^v_{iw}), \end{aligned}$$
(3)

where the self-attention formula closely follows Eq. 2 with an added gating mechanism. Here, \(G_Q, G_K, G_{V1}, G_{V2} \in \mathbb {R}\) are learnable parameters which together form a gating mechanism that controls the influence the learned relative positional encodings have on encoding non-local context. Typically, if a relative positional encoding is learned accurately, the gating mechanism assigns it a high weight compared to the encodings which are not learned accurately. Figure 2(c) illustrates the feed-forward pass of a typical gated axial attention layer.
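A corresponding sketch of the gated variant in Eq. 3 is shown below; it differs from the previous sketch only in the four learnable scalar gates \(G_Q, G_K, G_{V1}, G_{V2}\). The initial gate values are an assumption for illustration, not the values used in the paper.

```python
# A hypothetical sketch of gated axial attention along the width axis (Eq. 3),
# not the released implementation.
import torch
import torch.nn as nn


class GatedAxialAttentionWidth(nn.Module):
    def __init__(self, c_in: int, c_out: int, width: int):
        super().__init__()
        self.to_q = nn.Conv2d(c_in, c_out, 1, bias=False)
        self.to_k = nn.Conv2d(c_in, c_out, 1, bias=False)
        self.to_v = nn.Conv2d(c_in, c_out, 1, bias=False)
        self.r_q = nn.Parameter(torch.randn(c_out, width, width) * 0.02)
        self.r_k = nn.Parameter(torch.randn(c_out, width, width) * 0.02)
        self.r_v = nn.Parameter(torch.randn(c_out, width, width) * 0.02)
        # Learnable scalar gates controlling how much the positional
        # encodings contribute to the logits and the output (initial values
        # are illustrative assumptions).
        self.g_q = nn.Parameter(torch.tensor(0.1))
        self.g_k = nn.Parameter(torch.tensor(0.1))
        self.g_v1 = nn.Parameter(torch.tensor(1.0))
        self.g_v2 = nn.Parameter(torch.tensor(0.1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        # Content affinities plus gated positional terms (Eq. 3).
        logits = torch.einsum("bcij,bciw->bijw", q, k)
        logits = logits + self.g_q * torch.einsum("bcij,cjw->bijw", q, self.r_q)
        logits = logits + self.g_k * torch.einsum("bciw,cjw->bijw", k, self.r_k)
        attn = torch.softmax(logits, dim=-1)
        # Gated aggregation of the values and the value positional encodings.
        y = self.g_v1 * torch.einsum("bijw,bciw->bcij", attn, v)
        y = y + self.g_v2 * torch.einsum("bijw,cjw->bcij", attn, self.r_v)
        return y
```

If the gates on the positional terms converge toward zero during training, the layer falls back to content-only axial attention, which is the intended behavior when the dataset is too small to learn reliable positional encodings.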

2.3 Local-Global Training

It is evident that a transformer operating on patches is faster, but patch-wise training alone is not sufficient for tasks like medical image segmentation, as it restricts the network from learning any information or dependencies between pixels in different patches. To improve the overall understanding of the image, we propose to use two branches in the network: a global branch which works on the original resolution of the image, and a local branch which operates on patches of the image. In the global branch, we reduce the number of gated axial transformer layers, as we observe that the first few blocks of the proposed transformer model are sufficient to model long-range dependencies. In the local branch, we create 16 patches of size \(I/4 \times I/4\), where I is the dimension of the original image. Each patch is fed forward through the local branch and the resulting feature maps are re-sampled based on their locations to reconstruct a full-resolution feature map, as sketched below. The output feature maps of the two branches are then added and passed through a \(1 \times 1\) convolution layer to produce the output segmentation mask. This strategy improves performance as the global branch focuses on high-level information while the local branch focuses on finer details. The proposed Medical Transformer (MedT) uses the gated axial attention layer as its basic building block and uses the LoGo strategy for training. It is illustrated in Fig. 2(a). More details on the architecture and an ablation study with regard to the architecture can be found in the supplementary file.
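A simplified sketch of the LoGo forward pass follows. The branch modules are placeholders (in MedT the global and local branches are gated axial transformer encoder-decoders), and the module names and the reassembly loop are illustrative assumptions.

```python
# A hypothetical sketch of the LoGo forward pass: a global branch on the full
# image plus a local branch on a 4 x 4 grid of patches, fused by addition and
# a 1x1 convolution.
import torch
import torch.nn as nn


class LoGoSegmenter(nn.Module):
    def __init__(self, global_branch: nn.Module, local_branch: nn.Module,
                 feat_channels: int, num_classes: int = 1, grid: int = 4):
        super().__init__()
        self.global_branch = global_branch   # shallow, full-resolution branch
        self.local_branch = local_branch     # deeper branch run on patches
        self.grid = grid                     # 4 x 4 = 16 patches of size I/4
        self.head = nn.Conv2d(feat_channels, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        ph, pw = h // self.grid, w // self.grid
        # Global branch sees the whole image at its original resolution.
        g_feat = self.global_branch(x)
        # Local branch runs on each patch; outputs are placed back at the
        # patch locations to rebuild a full-resolution feature map.
        l_feat = torch.zeros_like(g_feat)
        for i in range(self.grid):
            for j in range(self.grid):
                patch = x[:, :, i * ph:(i + 1) * ph, j * pw:(j + 1) * pw]
                l_feat[:, :, i * ph:(i + 1) * ph, j * pw:(j + 1) * pw] = \
                    self.local_branch(patch)
        # Add both feature maps and predict the mask with a 1x1 convolution.
        return self.head(g_feat + l_feat)


# Placeholder branches that keep the spatial size unchanged.
global_branch = nn.Conv2d(3, 8, 3, padding=1)
local_branch = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                             nn.Conv2d(8, 8, 3, padding=1))
model = LoGoSegmenter(global_branch, local_branch, feat_channels=8)
print(model(torch.randn(2, 3, 128, 128)).shape)  # torch.Size([2, 1, 128, 128])
```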

3 Experiments and Results

3.1 Dataset Details

We use Brain anatomy segmentation (ultrasound) [21, 23], Gland segmentation (microscopic) [17] and MoNuSeg (microscopic) [9, 10] datasets for evaluating our method. More details about the datasets can be found in the supplementary.

3.2 Implementation Details

We use the binary cross-entropy (CE) loss between the prediction and the ground truth to train our network, which can be written as:

$$\begin{aligned} \mathcal {L}_{CE}(p,\hat{p}) = - \frac{1}{wh} \sum _{x=0}^{w-1}\sum _{y=0}^{h-1} \left( p(x,y) \log (\hat{p}(x,y)) + (1-p(x,y))\log (1-\hat{p}(x,y))\right) \end{aligned}$$

where w and h are the dimensions of the image, p(x, y) denotes the ground-truth label of the pixel at location (x, y), and \(\hat{p}(x,y)\) denotes the output prediction at that location. The training details are provided in the supplementary document.
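A small hedged sketch of this loss is shown below, assuming the network outputs are already sigmoid probabilities; in practice it matches PyTorch's built-in binary cross-entropy.

```python
# Pixel-wise binary cross-entropy as in the equation above (illustrative).
import torch


def bce_loss(p: torch.Tensor, p_hat: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # p, p_hat: tensors of shape (B, 1, H, W); the mean over all pixels
    # realizes the 1/(wh) normalization of the equation.
    p_hat = p_hat.clamp(eps, 1.0 - eps)  # numerical stability
    return -(p * torch.log(p_hat) + (1.0 - p) * torch.log(1.0 - p_hat)).mean()


target = (torch.rand(2, 1, 64, 64) > 0.5).float()
pred = torch.rand(2, 1, 64, 64)
print(bce_loss(target, pred).item())
```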

For baseline comparisons, we run experiments on both convolutional and transformer-based methods. For convolutional baselines, we compare with the fully convolutional network (FCN) [1], U-Net [15], U-Net++ [28] and Res-UNet [25]. For the transformer-based baseline, we use an Axial-Attention U-Net with residual connections inspired by [22]. For our proposed method, we experiment with each individual contribution. In the gated axial attention network, we use the axial attention U-Net with all of its axial attention layers replaced by the proposed gated axial attention layers. In LoGo, we perform local-global training for the axial attention U-Net without using the gated axial attention layers. In MedT, we use gated axial attention as the basic building block of the global branch and axial attention without positional encoding for the local branch.

3.3 Results

Table 1. Quantitative comparison of the proposed methods with convolutional and transformer-based baselines in terms of F1 and IoU scores.

For quantitative analysis, we use F1 and IoU scores for comparison. The quantitative results are tabulated in Table 1. It can be noted that for datasets with relatively more images, like Brain US, the fully attention-based (transformer) baseline performs better than the convolutional baselines. For the GlaS and MoNuSeg datasets, the convolutional baselines perform better than the fully attention-based baseline, as it is difficult to train fully attention-based models with less data [5]. The proposed methods overcome this issue: gated axial attention and LoGo both individually perform better than the other methods. Our final architecture, MedT, performs better than gated axial attention, LoGo and all the previous methods. The improvements over the fully attention-based baseline are 0.92%, 4.76% and 2.72% for the Brain US, GlaS and MoNuSeg datasets, respectively. Improvements over the best convolutional baseline are 1.32%, 2.19% and 0.06%. All of these values are in terms of F1 scores. For the ablation study, we use the Brain US data for all our experiments. The results are tabulated in Table 2.

Table 2. Ablation study
Fig. 3. Qualitative results on sample test images from the Brain US, GlaS and MoNuSeg datasets. The red box highlights regions where MedT performs better than the methods in comparison by making better use of long-range dependencies. (Color figure online)

Furthermore, we visualize the predictions from U-Net [15], Res-UNet [25], Axial Attention U-Net [22] and our proposed method MedT in Fig. 3. It can be seen that the predictions of MedT capture long-range dependencies well. For example, in the second row of Fig. 3, we can observe that the small segmentation mask highlighted by the red box goes undetected by all the convolutional baselines. However, as the fully attention-based models encode long-range dependencies, they learn to segment this region well thanks to the encoded global context. In the first and fourth rows, the other methods make false predictions at the highlighted regions, as those pixels are in close proximity to the segmentation mask. As our method takes into account pixel-wise dependencies that are encoded with the gating mechanism, it learns those dependencies better than the axial attention U-Net. This makes our predictions more precise, as they do not misclassify pixels near the segmentation mask.

4 Conclusion

In this work, we explored the use of transformer-based architectures for medical image segmentation. Specifically, we proposed a gated axial attention layer which is used as the building block for multi-head attention models. We also proposed a LoGo training strategy that trains the network on both the full-resolution image and its patches. The global branch helps learn global context features by modeling long-range dependencies, whereas the local branch focuses on finer features by operating on patches. Using these, we proposed MedT (Medical Transformer), which has gated axial attention as the main building block of its encoder and uses the LoGo strategy for training. Unlike other transformer-based models, the proposed method does not require pre-training on large-scale datasets. Finally, we conducted extensive experiments on three datasets, where MedT achieves better performance than ConvNets and other related transformer-based architectures.