Abstract
The Transformer architecture has been successful in a number of natural language processing tasks. However, its applications to medical vision remain largely unexplored. In this study, we present UTNet, a simple yet powerful hybrid Transformer architecture that integrates self-attention into a convolutional neural network to enhance medical image segmentation. UTNet applies self-attention modules in both the encoder and decoder to capture long-range dependency at different scales with minimal overhead. To this end, we propose an efficient self-attention mechanism, along with relative position encoding, that reduces the complexity of the self-attention operation significantly from \(O(n^2)\) to approximately O(n). A new self-attention decoder is also proposed to recover fine-grained details from the skip connections in the encoder. Our approach addresses the dilemma that Transformers require huge amounts of data to learn vision inductive bias. Our hybrid layer design allows the Transformer to be initialized within convolutional networks without a need for pre-training. We have evaluated UTNet on a multi-label, multi-vendor cardiac magnetic resonance imaging cohort. UTNet demonstrates superior segmentation performance and robustness compared with state-of-the-art approaches, holding the promise to generalize well to other medical image segmentation tasks.
1 Introduction
Convolutional networks have revolutionized the computer vision field with their outstanding feature representation capability. Convolutional encoder-decoder architectures have made substantial progress in position-sensitive tasks, such as semantic segmentation [6, 11, 14, 17, 20]. The convolution operation captures texture features by gathering local information from neighborhood pixels. To aggregate the local filter responses globally, these models stack multiple convolutional layers and expand the receptive field through down-sampling. Despite these advances, this paradigm has two inherent limitations. First, convolution only gathers information from neighborhood pixels and lacks the ability to capture long-range (global) dependency explicitly [5, 25, 26]. Second, the size and shape of convolution kernels are typically fixed, so they cannot adapt to the input content [15].
The Transformer architecture, built on the self-attention mechanism, has been successful in natural language processing (NLP) [18] owing to its capability of capturing long-range dependency. Self-attention is a computational primitive that implements pairwise entity interactions with a context aggregation mechanism, giving it the ability to capture long-range associative features. It allows the network to aggregate relevant features dynamically based on the input content. Preliminary studies with simple forms of self-attention have shown its usefulness in segmentation [4, 16], detection [24] and reconstruction [9].
Although image-based Transformers are promising, training and deploying Transformer architectures poses several daunting challenges. First, the self-attention mechanism has \(O(n^2)\) time and space complexity with respect to sequence length, resulting in substantial training and inference overhead. Previous works attempt to reduce the complexity of self-attention [10, 28], but remain far from perfect. Due to the time complexity, standard self-attention can only be applied patch-wise, e.g. [3, 27] encode images using \(16\times 16\) flattened image patches as input sequences, or on top of feature maps from a CNN backbone, which are already down-sampled to low resolution [4, 22]. However, for position-sensitive tasks like medical image segmentation, high-resolution features play a vital role, since most mis-segmented areas are located around the boundary of the region-of-interest. Second, Transformers have no inductive bias for images and do not perform well on small-scale datasets [3]. For example, Transformers can benefit from pre-training on a large-scale dataset such as JFT-300M [3], but even with pre-training on ImageNet, the Transformer is still worse than ResNet [7, 12], not to mention medical image datasets, where far less data is available.
In this paper, we propose a U-shaped hybrid Transformer network, UTNet, integrating the strengths of convolution and self-attention for medical image segmentation. The major goal is to apply convolution layers to extract local intensity features, avoiding large-scale pre-training of the Transformer, while using self-attention to capture long-range associative information. We follow the standard design of UNet, but replace the last convolution of the building block at each resolution (except for the highest one) with the proposed Transformer module. Towards enhanced segmentation quality, we seek to apply self-attention to extract detailed long-range relationships on high-resolution feature maps. To this end, we propose an efficient self-attention mechanism, which reduces the overall complexity significantly from \(O(n^2)\) to approximately O(n) in both time and space. Furthermore, we use relative position encoding in the self-attention module to learn content-position relationships in medical images. UTNet demonstrates superior segmentation performance and robustness on a multi-label, multi-vendor cardiac magnetic resonance imaging cohort. Given its design, our framework holds the promise to generalize well to other medical image segmentation tasks.
2 Method
2.1 Revisiting Self-attention Mechanism
The Transformer is built upon the multi-head self-attention (MHSA) module [18], which allows the model to jointly infer attention from different representation subspaces. The results from multiple heads are concatenated and then transformed with a feed-forward network. In this study, we use 4 heads; the multi-head dimension is omitted for simplicity in the following formulation and in the figure. Consider an input feature map \(X\in \mathcal {R}^{C\times H\times W}\), where H, W are the spatial height and width and C is the number of channels. Three \(1\times 1\) convolutions are used to project X to the query, key and value embeddings \(\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathcal {R}^{d\times H\times W}\), where d is the embedding dimension of each head. \(\mathbf{Q}, \mathbf{K}, \mathbf{V}\) are then flattened and transposed into sequences of size \(n\times d\), where \(n=HW\). The output of the self-attention is a scaled dot-product:

\[\mathrm{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V})=\mathrm{softmax}\Big(\frac{\mathbf{Q}\mathbf{K}^{\mathsf {T}}}{\sqrt{d}}\Big)\mathbf{V}=P\mathbf{V}\]
Note that \(P\in \mathcal {R}^{n\times n}\) is called the context aggregating matrix, or similarity matrix. Specifically, the i-th query's context aggregating matrix is \(P_i=\mathrm{softmax}(\frac{\mathbf {q}_i\mathbf {K}^\mathsf {T}}{\sqrt{d}})\), \(P_i\in \mathcal {R}^{1\times n}\), which computes the normalized pair-wise dot product between \(q_i\) and each element of the keys. The context aggregating matrix is then used as the weights to gather context information from the values. In this way, self-attention intrinsically has a global receptive field and is good at capturing long-range dependency. Also, the context aggregating matrix adapts to the input content for better feature aggregation. However, the dot-product of \(n\times d\) matrices leads to \(O(n^2d)\) complexity. Typically, n is much larger than d when the resolution of the feature map is large, so the sequence length dominates the self-attention computation and makes it infeasible to apply self-attention to high-resolution feature maps, e.g. \(n=256\) for \(16\times 16\) feature maps, but \(n=16384\) for \(128\times 128\) feature maps.
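For concreteness, the standard scaled dot-product above can be sketched in NumPy (a minimal single-head sketch, not the authors' implementation; the projection matrices stand in for the \(1\times 1\) convolutions, and all names are illustrative):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, Wq, Wk, Wv):
    # x: (n, c) flattened feature map, n = H*W pixels.
    # Wq, Wk, Wv: (c, d) projections (the 1x1 convolutions in the paper).
    q, k, v = x @ Wq, x @ Wk, x @ Wv              # each (n, d)
    p = softmax(q @ k.T / np.sqrt(k.shape[-1]))   # (n, n) context matrix P
    return p @ v                                  # (n, d)

# Toy example: a 16x16 feature map gives n = 256.
rng = np.random.default_rng(0)
n, c, d = 256, 32, 8
x = rng.standard_normal((n, c))
Wq, Wk, Wv = [rng.standard_normal((c, d)) for _ in range(3)]
out = self_attention(x, Wq, Wk, Wv)
print(out.shape)  # the (n, n) matrix P is the O(n^2 d) bottleneck
```

The explicit \(n\times n\) matrix makes the quadratic cost visible: doubling the feature-map side length quadruples n and multiplies the attention cost by sixteen.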
2.2 Efficient Self-attention Mechanism
As images are highly structured data, most pixels in high-resolution feature maps within a local footprint share similar features, except in boundary regions. Therefore, pair-wise attention computation among all pixels is highly inefficient and redundant. From a theoretical perspective, self-attention is essentially low-rank for long sequences [21], which indicates that most information is concentrated in the largest singular values. Inspired by this finding, we propose an efficient self-attention mechanism for our task, as shown in Fig. 2.
The main idea is to project the key and value, \(\mathbf {K}, \mathbf{V}\in \mathcal {R}^{n\times d}\), into low-dimensional embeddings \(\overline{\mathbf {K}}, \overline{\mathbf {V}}\in \mathcal {R}^{k\times d}\), where \(k=hw\ll n\), and h and w are the reduced size of the feature map after sub-sampling. The proposed efficient self-attention is then:

\[\mathrm{Attention}(\mathbf{Q},\overline{\mathbf{K}},\overline{\mathbf{V}})=\mathrm{softmax}\Big(\frac{\mathbf{Q}\overline{\mathbf{K}}^{\mathsf {T}}}{\sqrt{d}}\Big)\overline{\mathbf{V}}=\overline{P}\,\overline{\mathbf{V}},\qquad \overline{P}\in \mathcal {R}^{n\times k}\]
By doing so, the computational complexity is reduced to O(nkd). Notably, the projection to the low-dimensional embedding can be any down-sampling operation, such as average/max pooling or strided convolution. In our implementation, we use a \(1\times 1\) convolution followed by bilinear interpolation to down-sample the feature map, and the reduced size is 8.
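A minimal NumPy sketch of the idea follows. It is illustrative only: nearest-neighbour sub-sampling stands in for the paper's \(1\times 1\) convolution plus bilinear interpolation, and relative position encoding is omitted.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def downsample(seq, H, W, h, w):
    # Nearest-neighbour stand-in for the paper's 1x1 conv + bilinear
    # interpolation: (H*W, d) -> (h*w, d).
    grid = seq.reshape(H, W, -1)
    ys = np.arange(h) * H // h
    xs = np.arange(w) * W // w
    return grid[np.ix_(ys, xs)].reshape(h * w, -1)

def efficient_self_attention(q, k, v, H, W, h=8, w=8):
    # q, k, v: (n, d) with n = H*W.  Keys/values are sub-sampled to
    # length hw << n, so the context matrix is (n, hw) instead of
    # (n, n): O(n*k*d) overall, with k = h*w.
    k_bar = downsample(k, H, W, h, w)
    v_bar = downsample(v, H, W, h, w)
    p_bar = softmax(q @ k_bar.T / np.sqrt(q.shape[-1]))
    return p_bar @ v_bar

rng = np.random.default_rng(0)
H, W, d = 128, 128, 8
q, k, v = (rng.standard_normal((H * W, d)) for _ in range(3))
out = efficient_self_attention(q, k, v, H, W)
print(out.shape)  # (16384, 8), via a 16384 x 64 context matrix
```

Note that the queries remain at full resolution, so the output still has one row per pixel; only the keys and values are compressed.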
2.3 Relative Positional Encoding
The standard self-attention module discards position information entirely and is permutation equivariant [1], making it ineffective for modeling image content that is highly structured. The sinusoidal embedding used in previous works [13] lacks the translation-equivariance property of convolutional layers. Therefore, we use 2-dimensional relative position encoding, adding relative height and width information [1]. The pair-wise attention logit before the softmax, using relative position encoding between pixel \(i=(i_x,i_y)\) and pixel \(j=(j_x,j_y)\), is:

\[l_{i,j}=\frac{q_i^{\mathsf {T}}}{\sqrt{d}}\big(k_j+r^{W}_{j_x-i_x}+r^{H}_{j_y-i_y}\big)\]
where \(q_i\) is the query vector of pixel i, \(k_j\) is the key vector of pixel j, and \(r_{j_x-i_x}^{W}\) and \(r_{j_y-i_y}^{H}\) are learnable embeddings for the relative width \(j_x-i_x\) and relative height \(j_y-i_y\), respectively. As in the efficient self-attention, the relative width and height are computed after the low-dimensional projection. The efficient self-attention including relative position encoding is:

\[\mathrm{Attention}=\mathrm{softmax}\Big(\frac{\mathbf{Q}\overline{\mathbf{K}}^{\mathsf {T}}+\mathbf{S}_H^{rel}+\mathbf{S}_W^{rel}}{\sqrt{d}}\Big)\overline{\mathbf{V}}\]
where \(\mathbf{S} _H^{rel}, \mathbf{S} _W^{rel}\in \mathcal {R}^{HW\times hw}\) are matrices of relative position logits along the height and width dimensions, satisfying \(\mathbf{S} _H^{rel}[i,j]=q_i^{\mathsf {T}}r^H_{j_y-i_y}\) and \(\mathbf{S} _W^{rel}[i,j]=q_i^{\mathsf {T}}r^W_{j_x-i_x}\).
2.4 Network Architecture
Figure 1 highlights the architecture of UTNet. We seek to combine the strengths of convolution and self-attention. The hybrid architecture can thus leverage the image inductive bias of convolution to avoid large-scale pre-training, as well as the capability of the Transformer to capture long-range relationships. Because mis-segmented regions are usually located at the boundary of the region-of-interest, high-resolution context information can play a vital role in segmentation. As a result, our focus is placed on the proposed self-attention module, which makes it feasible to handle large feature maps efficiently. Instead of naively integrating the self-attention module on top of the feature maps from the CNN backbone, we apply the Transformer module at each level of the encoder and decoder to collect long-range dependency at multiple scales. Note that we do not apply the Transformer at the original resolution, as adding the Transformer module in the very shallow layers of the network does not help in our experiments but introduces additional computation. A possible reason is that the shallow layers of the network focus more on detailed textures, where gathering global context is not informative. The building blocks of UTNet are shown in Fig. 1 (b) and (c): the residual basic block and the Transformer block. For both blocks, we use the pre-activation setting with identity mapping in the shortcut. This identity mapping has been proven effective in vision [8] and NLP tasks [19].
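The pre-activation ordering (normalization, then activation, then weight, with the shortcut carrying the identity untouched) can be sketched as below. This is a schematic, not the authors' code: matrix multiplies stand in for the block's convolutions or Transformer sub-layer, and a simple per-feature normalization stands in for the block's normalization layer.

```python
import numpy as np

def norm(x, eps=1e-5):
    # Stand-in for the block's normalization layer.
    return (x - x.mean(axis=-1, keepdims=True)) / (x.std(axis=-1, keepdims=True) + eps)

def relu(x):
    return np.maximum(x, 0.0)

def pre_act_residual_block(x, W1, W2):
    # Pre-activation: norm -> activation -> weight, twice; the identity
    # shortcut is added at the end without passing through any layer.
    h = relu(norm(x)) @ W1
    h = relu(norm(h)) @ W2
    return x + h

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 32))
W1, W2 = rng.standard_normal((32, 32)), rng.standard_normal((32, 32))
print(pre_act_residual_block(x, W1, W2).shape)  # (16, 32)
```

Keeping the shortcut as a pure identity is the point of the pre-activation design: gradients flow through the addition unimpeded, which is what makes the hybrid blocks trainable from scratch.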
3 Experiment
3.1 Experiment Setup
We systematically evaluate UTNet on the multi-label, multi-vendor cardiac magnetic resonance imaging (MRI) challenge cohort [2], which involves segmentation of the left ventricle (LV), right ventricle (RV), and left ventricular myocardium (MYO). The training set contains 150 annotated images from two MRI vendors (75 images each): A: Siemens and B: Philips. The testing set contains 200 images from 4 MRI vendors (50 images each): A: Siemens, B: Philips, C: GE, and D: Canon, where vendors C and D are completely absent from the training set (we discard the unlabeled data). MRI scans from different vendors differ markedly in appearance, allowing us to measure model robustness and compare with other models under different settings. Specifically, we perform two experiments to highlight the performance and robustness of UTNet. First, we report primary results with training and testing data both from vendor A. Second, we further measure the cross-vendor robustness of the models. This setting is more challenging since the training and testing data come from independent vendors. We report the Dice score and Hausdorff distance of each model to compare performance.
3.2 Implementation Detail
For data preprocessing, we resample the in-plane spacing to \(1.2\times 1.2\) mm, while keeping the spacing along the z-axis unchanged. We train all models from scratch for 150 epochs. We use an exponential learning-rate scheduler with a base learning rate of 0.05. We use the SGD optimizer with a batch size of 16 on one GPU; momentum and weight decay are set to 0.9 and \(1e-4\), respectively. Data augmentation is applied on the fly during model training, including random rotation, scaling, translation, additive noise and gamma transformation. All images are randomly cropped to \(256\times 256\) before entering the models. We use a combination of Dice loss and cross-entropy loss to train all networks.
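The combined loss can be sketched as below (a minimal NumPy sketch; equal weighting of the two terms is an assumption, as the paper does not state the mixing ratio, and the smoothing constant `eps` is illustrative):

```python
import numpy as np

def dice_ce_loss(probs, onehot, eps=1e-5):
    # Combined cross-entropy + (1 - Dice) over n pixels and C classes.
    # probs:  (n, C) softmax outputs; onehot: (n, C) ground truth.
    ce = -np.mean(np.sum(onehot * np.log(probs + eps), axis=1))
    inter = np.sum(probs * onehot, axis=0)                    # per class
    dice = np.mean((2 * inter + eps) /
                   (probs.sum(axis=0) + onehot.sum(axis=0) + eps))
    return ce + (1.0 - dice)

# Toy check: 4 pixels, 3 classes.
onehot = np.eye(3)[np.array([0, 1, 2, 0])]
print(dice_ce_loss(onehot.astype(float), onehot))   # near 0 when perfect
print(dice_ce_loss(np.full((4, 3), 1 / 3), onehot))  # larger when uniform
```

The cross-entropy term drives per-pixel accuracy, while the Dice term counteracts the class imbalance between small structures (e.g. MYO) and background.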
3.3 Segmentation Results
We compare the performance of UTNet with multiple state-of-the-art segmentation models. UNet [14] builds on top of the fully convolutional networks with a U-shaped architecture to capture context information. The ResUNet is similar to UNet in architecture, but it uses residual blocks as the building block. CBAM [23] uses two sequential convolutional modules to infer channel and spatial attention to refine intermediate feature maps adaptively. Dual attention network [4] uses two kinds of self-attention to model the semantic inter-dependencies in spatial and channel dimensions, respectively. We have implemented CBAM and dual attention in ResUNet backbone for better comparison. The dual attention is only applied in the feature maps after 4 down-samplings due to its quadratic complexity.
As seen in Table 1, UTNet demonstrates leading performance on all segmentation targets (LV, MYO and RV). By introducing residual connections, ResUNet improves slightly over the original UNet. Because the spatial and channel attention of CBAM are inferred from convolutional layers, it still suffers from a limited receptive field; thus CBAM brings only a limited improvement over ResUNet. We also observe that the dual-attention approach performs almost the same as ResUNet: its quadratic complexity prevents it from processing higher-resolution feature maps to fix errors at the segmentation boundary. Meanwhile, UTNet has fewer parameters than the dual-attention approach and can capture global context information from high-resolution feature maps.
Ablation Study. Figure 3 (a) shows the performance for different self-attention positions. The number on the x-axis indicates the levels where self-attention is placed, e.g., '34' means the levels after 3 and 4 down-samplings. As the level goes up, self-attention can gather more fine-grained detail, with increased performance. However, the curve saturates when self-attention is added at the original resolution. We attribute this to the very shallow layers being more focused on local texture, where global context information is no longer informative. Figure 3 (b) shows the results for efficient self-attention with reduced sizes of 4, 8 and 16. A reduced size of 8 gives the best performance, and interpolation down-sampling is slightly better than max-pooling. Figure 3 (c) shows the effect of the Transformer encoder, decoder, and relative positional encoding, using the optimal hyper-parameters from (a) and (b). The combination of the Transformer encoder and decoder gives the best performance. The relative positional encoding also plays a vital role, as removing it causes a large performance drop.
For a head-to-head comparison with standard self-attention in space and time complexity, we further apply dual attention at four resolutions (1, 2, 3, 4, the same as UTNet), using the same input size and batch size (\(256\times 256\times 16\)) to test inference time and memory consumption. UTNet has a clear advantage over dual attention with its quadratic complexity: GPU memory 3.8 GB vs 36.9 GB, and time 0.146 s vs 0.243 s.
Robustness Analysis. Table 2 shows the results of training models on data from vendors A and B, and then testing them on vendors A, B, C, and D, respectively. On vendors C and D, competing approaches suffer from the vendor shift, while UTNet retains competitive performance. This can probably be attributed to the self-attention applied at multiple levels of the feature maps and the content-position attention, which allow UTNet to focus on global context information rather than only local textures. Figure 4 further shows that UTNet produces the most consistent boundaries, while the other three methods fail to capture subtle boundary characteristics, especially for the RV and MYO regions in cardiac MRI.
4 Conclusion
We have proposed a U-shaped hybrid Transformer network (UTNet) that merges the strengths of convolutional layers and the self-attention mechanism for medical image segmentation. Our hybrid layer design allows the Transformer to be initialized within convolutional networks without the need for pre-training. The proposed efficient self-attention allows us to apply attention at different levels of the network, in both encoder and decoder, to better capture long-range dependencies. We believe that this design will help richly-parameterized Transformer models become more accessible in medical vision applications. Also, the ability to handle long sequences efficiently opens up new possibilities for using UTNet on more downstream medical image tasks.
References
Bello, I., Zoph, B., Vaswani, A., Shlens, J., Le, Q.V.: Attention augmented convolutional networks. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 3286–3295 (2019)
Campello, V.M., Palomares, J.F.R., Guala, A., Marakas, M., Friedrich, M., Lekadir, K.: Multi-Centre, Multi-Vendor & Multi-Disease Cardiac Image Segmentation Challenge (March 2020)
Dosovitskiy, A., et al.: An image is worth 16x16 words: transformers for image recognition at scale. arXiv preprint arXiv:2010.11929 (2020)
Fu, J., et al.: Dual attention network for scene segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 3146–3154 (2019)
Gao, Y., et al.: Focusnetv 2: imbalanced large and small organ segmentation with adversarial shape constraint for head and neck CT images. Med. Image Anal. 67, 101831 (2021)
Gao, Y., Liu, C., Zhao, L.: Multi-resolution path CNN with deep supervision for intervertebral disc localization and segmentation. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11765, pp. 309–317. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32245-8_35
He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
He, K., Zhang, X., Ren, S., Sun, J.: Identity mappings in deep residual networks. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9908, pp. 630–645. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46493-0_38
Huang, Q., Yang, D., Wu, P., Qu, H., Yi, J., Metaxas, D.: MRI reconstruction via cascaded channel-wise attention network. In: 2019 IEEE 16th International Symposium on Biomedical Imaging (ISBI 2019), pp. 1622–1626. IEEE (2019)
Huang, Z., Wang, X., Huang, L., Huang, C., Wei, Y., Liu, W.: CCNet: criss-cross attention for semantic segmentation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 603–612 (2019)
Isensee, F., Jaeger, P.F., Kohl, S.A., Petersen, J., Maier-Hein, K.H.: nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nat. Methods 18(2), 203–211 (2021)
Kolesnikov, A., et al.: Big transfer (BiT): general visual representation learning. arXiv preprint arXiv:1912.11370 (2019)
Parmar, N., et al.: Image transformer. In: International Conference on Machine Learning, pp. 4055–4064. PMLR (2018)
Ronneberger, O., Fischer, P., Brox, T.: U-Net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W.M., Frangi, A.F. (eds.) MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4_28
Schlemper, J., et al.: Attention gated networks: learning to leverage salient regions in medical images. Med. Image Anal. 53, 197–207 (2019)
Sinha, A., Dolz, J.: Multi-scale self-guided attention for medical image segmentation. IEEE J. Biomed. Health Inform. 25(1), 121–130 (2020)
Tajbakhsh, N., Jeyaseelan, L., Li, Q., Chiang, J.N., Wu, Z., Ding, X.: Embracing imperfect datasets: a review of deep learning solutions for medical image segmentation. Med. Image Anal. 63, 101693 (2020)
Vaswani, A., et al.: Attention is all you need. In: NIPS (2017)
Wang, Q., Li, B., Xiao, T., Zhu, J., Li, C., Wong, D.F., Chao, L.S.: Learning deep transformer models for machine translation. arXiv preprint arXiv:1906.01787 (2019)
Wang, S., et al.: Central focused convolutional neural networks: developing a data-driven model for lung nodule segmentation. Med. Image Anal. 40, 172–183 (2017)
Wang, S., Li, B., Khabsa, M., Fang, H., Ma, H.: Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768 (2020)
Wang, X., Girshick, R., Gupta, A., He, K.: Non-local neural networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7794–7803 (2018)
Woo, S., Park, J., Lee, J.Y., Kweon, I.S.: CBAM: convolutional block attention module. In: Proceedings of the European Conference on Computer Vision (ECCV), pp. 3–19 (2018)
Yi, J., Wu, P., Jiang, M., Huang, Q., Hoeppner, D.J., Metaxas, D.N.: Attentive neural cell instance segmentation. Med. Image Anal. 55, 228–240 (2019). https://doi.org/10.1016/j.media.2019.05.004
Yu, F., Koltun, V.: Multi-scale context aggregation by dilated convolutions. arXiv preprint arXiv:1511.07122 (2015)
Zhao, H., Shi, J., Qi, X., Wang, X., Jia, J.: Pyramid scene parsing network. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2881–2890 (2017)
Zheng, S., et al.: Rethinking semantic segmentation from a sequence-to-sequence perspective with transformers. arXiv preprint arXiv:2012.15840 (2020)
Zhu, Z., Xu, M., Bai, S., Huang, T., Bai, X.: Asymmetric non-local neural networks for semantic segmentation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 593–602 (2019)
Acknowledgement
This research was supported in part by NSF: IIS-1703883, NSF IUCRC CNS-1747778, funding from SenseBrain, CCF-1733843, IIS-1763523, IIS-1849238, MURI-Z8424104-440149, and NIH: 1R01HL127661-01 and R01HL127661-05, and in part by the Centre for Perceptual and Interactive Intelligence (CPII) Limited, Hong Kong SAR.
© 2021 Springer Nature Switzerland AG
Gao, Y., Zhou, M., Metaxas, D.N. (2021). UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation. In: de Bruijne, M., et al. Medical Image Computing and Computer Assisted Intervention – MICCAI 2021. MICCAI 2021. Lecture Notes in Computer Science(), vol 12903. Springer, Cham. https://doi.org/10.1007/978-3-030-87199-4_6