1 Introduction

Convolutional networks have revolutionized computer vision with their outstanding feature representation capability. Convolutional encoder-decoder architectures have made substantial progress in position-sensitive tasks such as semantic segmentation [6, 11, 14, 17, 20]. The convolution operation captures texture features by gathering local information from neighboring pixels. To aggregate local filter responses globally, these models stack multiple convolutional layers and expand the receptive field through down-sampling. Despite these advances, the paradigm has two inherent limitations. First, convolution only gathers information from neighboring pixels and lacks the ability to capture long-range (global) dependencies explicitly [5, 25, 26]. Second, the size and shape of convolution kernels are typically fixed, so they cannot adapt to the input content [15].

The Transformer architecture, built on the self-attention mechanism, has proven successful in natural language processing (NLP) [18] owing to its capability of capturing long-range dependencies. Self-attention is a computational primitive that implements pairwise entity interactions with a context aggregation mechanism, enabling it to capture long-range associative features. It allows the network to aggregate relevant features dynamically based on the input content. Preliminary studies with simple forms of self-attention have shown its usefulness in segmentation [4, 16], detection [24] and reconstruction [9].

Fig. 1. (a) The hybrid architecture of the proposed UTNet. The proposed efficient self-attention mechanism and relative positional encoding allow us to apply the Transformer to aggregate global context information from multiple scales in both the encoder and the decoder. (b) Pre-activation residual basic block. (c) The structure of the Transformer encoder block.

Although image-based Transformers are promising, training and deploying Transformer architectures pose several daunting challenges. First, the self-attention mechanism has \(O(n^2)\) time and space complexity with respect to the sequence length, resulting in substantial training and inference overhead. Previous works attempt to reduce the complexity of self-attention [10, 28], but are still far from perfect. Due to this complexity, standard self-attention can only be applied patch-wise, e.g. [3, 27] encode images using flattened \(16\times 16\) image patches as input sequences, or on top of feature maps from a CNN backbone that have already been down-sampled to low resolution [4, 22]. However, for position-sensitive tasks like medical image segmentation, high-resolution features play a vital role, since most mis-segmented areas are located around the boundary of the region of interest. Second, Transformers have no inductive bias for images and do not perform well on small-scale datasets [3]. For example, Transformers benefit from pre-training on a large-scale dataset such as JFT-300M [3], but even with pre-training on ImageNet, the Transformer still performs worse than ResNet [7, 12], let alone on medical image datasets, where far less data is available.

In this paper, we propose a U-shaped hybrid Transformer network, UTNet, which integrates the strengths of convolution and self-attention for medical image segmentation. The major goal is to apply convolutional layers to extract local intensity features, avoiding large-scale pre-training of the Transformer, while using self-attention to capture long-range associative information. We follow the standard design of UNet, but replace the last convolution of the building block at each resolution (except for the highest one) with the proposed Transformer module. To enhance segmentation quality, we seek to apply self-attention to extract detailed long-range relationships on high-resolution feature maps. To this end, we propose an efficient self-attention mechanism, which reduces the overall complexity significantly from \(O(n^2)\) to approximately O(n) in both time and space. Furthermore, we use relative position encoding in the self-attention module to learn content-position relationships in medical images. UTNet demonstrates superior segmentation performance and robustness on a multi-label, multi-vendor cardiac magnetic resonance imaging cohort. Given its design, our framework holds the promise to generalize well to other medical image segmentation tasks.

2 Method

2.1 Revisiting Self-attention Mechanism

The Transformer is built upon the multi-head self-attention (MHSA) module [18], which allows the model to jointly infer attention from different representation subspaces. The results from multiple heads are concatenated and then transformed with a feed-forward network. In this study, we use four heads; the multi-head dimension is omitted from the following formulation and from the figures for simplicity. Consider an input feature map \(X\in \mathcal {R}^{C\times H\times W}\), where H, W are the spatial height and width and C is the number of channels. Three \(1\times 1\) convolutions are used to project X to the query, key, and value embeddings \(\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathcal {R}^{d\times H\times W}\), where d is the embedding dimension of each head. \(\mathbf{Q}, \mathbf{K}, \mathbf{V}\) are then flattened and transposed into sequences of size \(n\times d\), where \(n=HW\). The output of the self-attention is a scaled dot-product:

$$\begin{aligned} \mathrm{Attention}(\mathbf {Q, K, V})=\underbrace{\mathrm{softmax}(\frac{\mathbf {QK}^{\mathsf {T}}}{\sqrt{d}})}_{P}\mathbf {V} \end{aligned}$$
(1)

Note that \(P\in \mathcal {R}^{n\times n}\) is called the context aggregation matrix, or similarity matrix. Specifically, the i-th query's context aggregation matrix is \(P_i=\mathrm{softmax}(\frac{\mathbf {q}_i\mathbf {K}^\mathsf {T}}{\sqrt{d}})\), \(P_i\in \mathcal {R}^{1\times n}\), which computes the normalized pair-wise dot product between \(q_i\) and each element of the keys. The context aggregation matrix is then used as weights to gather context information from the values. In this way, self-attention intrinsically has a global receptive field and excels at capturing long-range dependencies. Also, the context aggregation matrix adapts to the input content for better feature aggregation. However, the dot product of \(n\times d\) matrices leads to \(O(n^2d)\) complexity. Typically, n is much larger than d when the resolution of the feature map is large, so the sequence length dominates the self-attention computation and makes it infeasible to apply self-attention to high-resolution feature maps, e.g. \(n=256\) for \(16\times 16\) feature maps, and \(n=16384\) for \(128\times 128\) feature maps.
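
For concreteness, the following is a minimal single-head PyTorch sketch of the standard self-attention in Eq. 1 (function and variable names are ours, not from the paper); the \(n\times n\) matrix P is what makes the operation quadratic in the sequence length.

```python
import torch
import torch.nn.functional as F

def standard_self_attention(x, w_q, w_k, w_v):
    """Single-head self-attention over a flattened feature map.

    x: (n, C) sequence of n = H*W pixel features.
    w_q, w_k, w_v: (C, d) projection matrices (1x1 convolutions in the paper).
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v            # each (n, d)
    d = q.shape[-1]
    # P is n x n: this is the O(n^2) time and memory bottleneck.
    p = F.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim=-1)
    return p @ v                                    # (n, d)

# A 128x128 feature map gives n = 16384, so P alone holds
# 16384^2 (about 2.7e8) entries per head, roughly 1 GB in float32.
x = torch.randn(16 * 16, 32)                        # toy 16x16 map, C = 32
w_q, w_k, w_v = (torch.randn(32, 32) for _ in range(3))
out = standard_self_attention(x, w_q, w_k, w_v)     # (256, 32)
```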

Fig. 2. The proposed efficient multi-head self-attention (MHSA). (a) The MHSA used in the Transformer encoder. (b) The MHSA used in the Transformer decoder. They share similar concepts, but (b) takes two inputs: the high-resolution features from the skip connections of the encoder and the low-resolution features from the decoder.

2.2 Efficient Self-attention Mechanism

Since images are highly structured data, most pixels in high-resolution feature maps within a local footprint share similar features, except in boundary regions. Therefore, pair-wise attention computed among all pixels is highly inefficient and redundant. From a theoretical perspective, self-attention is essentially low-rank for long sequences [21], which indicates that most information is concentrated in the largest singular values. Inspired by this finding, we propose an efficient self-attention mechanism for our task, shown in Fig. 2.

The main idea is to use two projections to map the key and value, \(\mathbf {K}, \mathbf{V}\in \mathcal {R}^{n\times d}\), into low-dimensional embeddings \(\overline{\mathbf {K}}, \overline{\mathbf {V}}\in \mathcal {R}^{k\times d}\), where \(k=hw\ll n\), and h and w are the reduced height and width of the feature map after sub-sampling. The proposed efficient self-attention is now:

$$\begin{aligned} \mathrm{Attention}(\mathbf{Q} , \overline{\mathbf {K}}, \overline{\mathbf {V}})=\underbrace{\mathrm{softmax}(\frac{\mathbf {Q}\overline{\mathbf {K}}^{\mathsf {T}}}{\sqrt{d}})}_{\overline{P}:n\times k}\underbrace{\overline{\mathbf{V }}}_{k\times d} \end{aligned}$$
(2)

By doing so, the computational complexity is reduced to O(nkd). Notably, the projection to the low-dimensional embedding can be any down-sampling operation, such as average/max pooling or strided convolution. In our implementation, we use a \(1\times 1\) convolution followed by bilinear interpolation to down-sample the feature map, and the reduced size is set to 8.
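
A minimal single-head PyTorch sketch of the efficient self-attention in Eq. 2 is given below; the module and argument names are ours, and the multi-head bookkeeping and relative position terms are omitted. K and V are down-sampled with a \(1\times 1\) convolution followed by bilinear interpolation, as described above, so the attention matrix is only \(n\times k\).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class EfficientSelfAttention(nn.Module):
    """Single-head sketch of the efficient self-attention (Eq. 2)."""

    def __init__(self, in_ch, dim, reduced_size=8):
        super().__init__()
        self.dim = dim
        self.reduced_size = reduced_size              # target h = w = 8
        self.to_q = nn.Conv2d(in_ch, dim, 1)
        self.to_k = nn.Conv2d(in_ch, dim, 1)
        self.to_v = nn.Conv2d(in_ch, dim, 1)

    def forward(self, x):                             # x: (B, C, H, W)
        B, _, H, W = x.shape
        q = self.to_q(x).flatten(2).transpose(1, 2)   # (B, n, d), n = H*W
        # Project K, V onto a low-resolution grid so that k = h*w << n.
        k = F.interpolate(self.to_k(x), size=(self.reduced_size,) * 2,
                          mode='bilinear', align_corners=False)
        v = F.interpolate(self.to_v(x), size=(self.reduced_size,) * 2,
                          mode='bilinear', align_corners=False)
        k = k.flatten(2).transpose(1, 2)              # (B, k, d)
        v = v.flatten(2).transpose(1, 2)              # (B, k, d)
        # Context aggregation matrix is n x k instead of n x n.
        attn = F.softmax(q @ k.transpose(1, 2) / self.dim ** 0.5, dim=-1)
        out = attn @ v                                # (B, n, d)
        return out.transpose(1, 2).reshape(B, self.dim, H, W)

attn = EfficientSelfAttention(in_ch=64, dim=32)
y = attn(torch.randn(2, 64, 128, 128))               # (2, 32, 128, 128)
```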

2.3 Relative Positional Encoding

The standard self-attention module entirely discards positional information and is permutation-equivariant [1], making it ineffective for modeling highly structured image content. The sinusoidal embedding used in previous works [13] lacks the translation-equivariance property enjoyed by convolutional layers. Therefore, we use 2-dimensional relative position encoding that adds relative height and width information [1]. The pair-wise attention logit before the softmax between pixel \(i=(i_x,i_y)\) and pixel \(j=(j_x,j_y)\), using relative position encoding, is:

$$\begin{aligned} l_{i,j}=\frac{q_i^\mathsf {T}}{\sqrt{d}}(k_j+r_{j_x-i_x}^{W}+r_{j_y-i_y}^{H}) \end{aligned}$$
(3)

where \(q_i\) is the query vector of pixel i, \(k_j\) is the key vector of pixel j, and \(r_{j_x-i_x}^{W}\) and \(r_{j_y-i_y}^{H}\) are learnable embeddings for the relative width \(j_x-i_x\) and relative height \(j_y-i_y\), respectively. Similar to the efficient self-attention, the relative width and height are computed after the low-dimensional projection. The efficient self-attention including relative position encoding is:

$$\begin{aligned} \mathrm{Attention}(\mathbf {Q}, \overline{\mathbf {K}}, \overline{\mathbf {V}})=\underbrace{\mathrm{softmax}(\frac{\mathbf {Q}\overline{\mathbf {K}}^{\mathsf {T}}+\mathbf {S}_H^{rel}+\mathbf {S}_W^{rel}}{\sqrt{d}})}_{\overline{P}:n\times k}\underbrace{\overline{\mathbf {V}}}_{k\times d} \end{aligned}$$
(4)

where \(\mathbf{S}_H^{rel}, \mathbf{S}_W^{rel}\in \mathcal {R}^{HW\times hw}\) are matrices of relative position logits along the height and width dimensions that satisfy \(\mathbf{S}_H^{rel}[i,j]=q_i^{\mathsf {T}}r^H_{j_y-i_y}\) and \(\mathbf{S}_W^{rel}[i,j]=q_i^{\mathsf {T}}r^W_{j_x-i_x}\).
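
The sketch below illustrates how these relative position terms can be computed in PyTorch. For readability, the query and key grids here share the same \(H\times W\) size, whereas in UTNet the keys live on the down-sampled \(h\times w\) grid; the class and variable names are ours, and the initialization scale is an assumption.

```python
import torch
import torch.nn as nn

class RelativePositionBias(nn.Module):
    """Sketch of the 2-D relative position terms S_H^rel and S_W^rel.

    Query and key grids share the same H x W size here for clarity; in UTNet
    the keys are on the down-sampled h x w grid.
    """

    def __init__(self, dim, height, width):
        super().__init__()
        self.height, self.width = height, width
        # One learnable vector per possible relative offset along each axis.
        self.rel_h = nn.Parameter(torch.randn(2 * height - 1, dim) * 0.02)
        self.rel_w = nn.Parameter(torch.randn(2 * width - 1, dim) * 0.02)

    def forward(self, q):                              # q: (B, H*W, d)
        B, _, d = q.shape
        H, W = self.height, self.width
        q = q.reshape(B, H, W, d)

        # Offset tables: entry [i, j] holds the embedding for offset j - i.
        idx_h = torch.arange(H, device=q.device)
        idx_w = torch.arange(W, device=q.device)
        rh = self.rel_h[idx_h[None, :] - idx_h[:, None] + H - 1]   # (H, H, d)
        rw = self.rel_w[idx_w[None, :] - idx_w[:, None] + W - 1]   # (W, W, d)

        # S_H[i, j] = q_i . r^H_{j_y - i_y},  S_W[i, j] = q_i . r^W_{j_x - i_x}
        s_h = torch.einsum('byxd,yud->byxu', q, rh)    # (B, H, W, H)
        s_w = torch.einsum('byxd,xvd->byxv', q, rw)    # (B, H, W, W)

        # Broadcast into the full bias added to QK^T before scaling/softmax.
        bias = s_h[..., :, None] + s_w[..., None, :]   # (B, H, W, H, W)
        return bias.reshape(B, H * W, H * W)

bias = RelativePositionBias(dim=32, height=16, width=16)(torch.randn(2, 256, 32))
# bias: (2, 256, 256), added to QK^T before dividing by sqrt(d) and softmax.
```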

2.4 Network Architecture

Figure 1 highlights the architecture of UTNet. We seek to combine the strengths of convolution and self-attention. The hybrid architecture can leverage the inductive bias of images captured by convolution to avoid large-scale pre-training, as well as the capability of the Transformer to capture long-range relationships. Because mis-segmented regions are usually located at the boundary of the region of interest, high-resolution context information can play a vital role in segmentation. As a result, our focus is placed on the proposed self-attention module, which makes it feasible to handle large feature maps efficiently. Instead of naively stacking the self-attention module on top of the feature maps from the CNN backbone, we apply the Transformer module at each level of the encoder and decoder to collect long-range dependencies at multiple scales. Note that we do not apply the Transformer at the original resolution, as adding Transformer modules in the very shallow layers of the network does not help in our experiments but introduces additional computation. A possible reason is that the shallow layers of the network focus more on detailed textures, where gathering global context may not be informative. The building blocks of UTNet are shown in Fig. 1 (b) and (c): the residual basic block and the Transformer block. For both blocks, we use the pre-activation setting for the identity mapping in the shortcut. This identity mapping has proven effective in vision [8] and NLP tasks [19].
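
As a rough illustration, the pre-activation Transformer block of Fig. 1 (c) can be wired around the efficient self-attention sketched in Sect. 2.2 as follows; the normalization choice, the feed-forward expansion ratio, and all names are our assumptions rather than the exact implementation.

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Rough sketch of the pre-activation Transformer block in Fig. 1 (c).

    `attn` is any attention module mapping (B, C, H, W) -> (B, C, H, W), e.g.
    the EfficientSelfAttention sketched in Sect. 2.2. The normalization choice
    and feed-forward expansion ratio are assumptions, not the exact recipe.
    """

    def __init__(self, channels, attn, expansion=4):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(channels)
        self.attn = attn
        self.norm2 = nn.BatchNorm2d(channels)
        self.ffn = nn.Sequential(
            nn.Conv2d(channels, channels * expansion, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels * expansion, channels, 1),
        )

    def forward(self, x):                              # x: (B, C, H, W)
        x = x + self.attn(self.norm1(x))               # pre-activation shortcut
        x = x + self.ffn(self.norm2(x))
        return x

block = TransformerBlock(64, EfficientSelfAttention(in_ch=64, dim=64))
y = block(torch.randn(2, 64, 64, 64))                  # (2, 64, 64, 64)
```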

3 Experiment

3.1 Experiment Setup

We systematically evaluate UTNet on the multi-label, multi-vendor cardiac magnetic resonance imaging (MRI) challenge cohort [2], which involves segmenting the left ventricle (LV), right ventricle (RV), and left ventricular myocardium (MYO). The training set contains 150 annotated images from two MRI vendors (75 images from each): A (Siemens) and B (Philips). The testing set contains 200 images from four MRI vendors (50 images from each): A (Siemens), B (Philips), C (GE), and D (Canon), where vendors C and D are completely absent from the training set (we discard the unlabeled data). The MRI scans from different vendors differ markedly in appearance, allowing us to measure model robustness and compare with other models under different settings. Specifically, we perform two experiments to highlight the performance and robustness of UTNet. First, we report primary results with training and testing data both from vendor A. Second, we further measure cross-vendor robustness. This setting is more challenging, since the training and testing data come from different vendors. We report the Dice score and Hausdorff distance of each model to compare performance.

3.2 Implementation Detail

For data preprocessing, we resample the in-plane spacing to \(1.2\times 1.2\) mm while keeping the spacing along the z-axis unchanged. We train all models from scratch for 150 epochs. We use an exponential learning rate scheduler with a base learning rate of 0.05. We use the SGD optimizer with a batch size of 16 on one GPU; momentum and weight decay are set to 0.9 and \(1e-4\), respectively. Data augmentation is applied on the fly during training, including random rotation, scaling, translation, additive noise, and gamma transformation. All images are randomly cropped to \(256\times 256\) before entering the models. We use a combination of Dice loss and cross-entropy loss to train all networks.
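
A common way to implement such a combined objective is sketched below in PyTorch; the smoothing constant, the equal weighting of the two terms, and the class name are our assumptions, not necessarily the exact formulation used in our experiments.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceCELoss(nn.Module):
    """Combined soft Dice + cross-entropy loss for multi-class segmentation."""

    def __init__(self, smooth=1e-5, ce_weight=1.0, dice_weight=1.0):
        super().__init__()
        self.smooth = smooth
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight

    def forward(self, logits, target):
        # logits: (B, num_classes, H, W); target: (B, H, W) with class indices.
        ce = F.cross_entropy(logits, target)

        num_classes = logits.shape[1]
        probs = F.softmax(logits, dim=1)
        one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
        dims = (0, 2, 3)                                # sum over batch and space
        intersection = (probs * one_hot).sum(dims)
        cardinality = probs.sum(dims) + one_hot.sum(dims)
        dice = (2 * intersection + self.smooth) / (cardinality + self.smooth)
        dice_loss = 1 - dice.mean()                     # average over classes

        return self.ce_weight * ce + self.dice_weight * dice_loss

criterion = DiceCELoss()
loss = criterion(torch.randn(2, 4, 256, 256), torch.randint(0, 4, (2, 256, 256)))
```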

Table 1. Segmentation performance in terms of Dice score and efficiency comparison. All models are trained and tested using data from vendor A. The Hausdorff distance results are reported in the supplementary material.
Fig. 3. Ablation study. (a) Effect of different self-attention positions. (b) Effect of the reduced size and projection type of the efficient self-attention. (c) Effect of the Transformer encoder, Transformer decoder, and the relative positional encoding.

3.3 Segmentation Results

We compare the performance of UTNet with multiple state-of-the-art segmentation models. UNet [14] builds on fully convolutional networks with a U-shaped architecture to capture context information. ResUNet is similar to UNet in architecture, but uses residual blocks as its building blocks. CBAM [23] uses two sequential convolutional modules to infer channel and spatial attention that refine intermediate feature maps adaptively. The dual attention network [4] uses two kinds of self-attention to model the semantic inter-dependencies in the spatial and channel dimensions, respectively. For better comparison, we implement CBAM and dual attention on the ResUNet backbone. Dual attention is only applied to the feature maps after 4 down-samplings due to its quadratic complexity.

As seen in Table 1, UTNet demonstrates leading performance on all segmentation targets (LV, MYO and RV). By introducing residual connections, ResUNet improves slightly over the original UNet. Since the spatial and channel attention of CBAM are inferred from convolutional layers, it still suffers from a limited receptive field; thus CBAM yields only a limited improvement over ResUNet. We also observe that the dual-attention approach performs almost the same as ResUNet, as its quadratic complexity prevents it from processing higher-resolution feature maps to fix errors along the segmentation boundary. Meanwhile, UTNet has fewer parameters than the dual-attention approach and can capture global context information from high-resolution feature maps.

Table 2. Robustness comparison, measured by Dice score. All models are trained on data from vendors A and B, and are tested on data from vendors A, B, C, and D. The numbers in brackets for C and D indicate the performance drop compared with the average of A and B.
Fig. 4. Hard-case visualization on unseen testing data from vendors C and D. The first two rows and the bottom two rows present the results and a zoomed-in view for vendor C and D, respectively. The outline indicates the ground-truth annotation. Best viewed in color with LV (green), MYO (yellow), and RV (red). The test case from vendor C is blurred due to motion artifacts, while the test case from vendor D is noisy and has low contrast at the boundary. Only UTNet provides consistent segmentation, demonstrating its robustness. More visualizations of segmentation outcomes are presented in the supplementary material. (Color figure online)

Ablation Study. Figure 3 (a) shows the performance for different self-attention positions. The number on the x-axis indicates the levels where self-attention is placed, e.g., ‘34’ means the levels that have been down-sampled 3 and 4 times. As higher-resolution levels are added, self-attention gathers more fine-grained detail and performance increases. However, the curve saturates when self-attention is added at the original resolution. We attribute this to the very shallow layers focusing more on local texture, where global context information is no longer informative. Figure 3 (b) shows the results for reduced sizes of 4, 8, and 16 in the efficient self-attention. A reduced size of 8 gives the best performance, and interpolation-based down-sampling is slightly better than max-pooling. Figure 3 (c) shows the effect of the Transformer encoder, the Transformer decoder, and the relative positional encoding using the optimal hyper-parameters from (a) and (b). The combination of the Transformer encoder and decoder gives the best performance. The relative positional encoding also plays a vital role, as removing it causes a large performance drop.

For a head-to-head comparison with standard self-attention in terms of space and time complexity, we further apply dual attention at four resolutions (1, 2, 3, 4, the same as UTNet), and use the same input size and batch size (\(256\times 256\times 16\)) to measure inference time and memory consumption. UTNet holds a clear advantage over dual attention with its quadratic complexity: 3.8 GB vs. 36.9 GB of GPU memory, and 0.146 s vs. 0.243 s of inference time.

Robustness Analysis. Table 2 shows results of training models on data from vendors A and B and then testing them on vendors A, B, C, and D, respectively. On vendors C and D, competing approaches suffer from the vendor shift, while UTNet retains competitive performance. This observation can probably be attributed to the self-attention applied at multiple levels of the feature maps and the content-position attention, allowing UTNet to focus on global context information instead of only local textures. Figure 4 further shows that UTNet produces the most consistent boundaries, while the other three methods fail to capture subtle boundary characteristics, especially for the RV and MYO regions in cardiac MRI.

4 Conclusion

We have proposed a U-shaped hybrid Transformer network (UTNet) that merges the advances of convolutional layers and the self-attention mechanism for medical image segmentation. Our hybrid layer design allows Transformer modules to be incorporated into convolutional networks without the need for pre-training. The proposed efficient self-attention allows us to apply it at different levels of the network, in both encoder and decoder, to better capture long-range dependencies. We believe that this design will help richly parameterized Transformer models become more accessible in medical vision applications. Also, the ability to handle long sequences efficiently opens up new possibilities for applying UTNet to more downstream medical image tasks.