
1 Introduction

There are over 120 types of brain tumors that affect the human brain [27]. As we enter the era of Artificial Intelligence (AI) for healthcare, AI-based intervention for the diagnosis and surgical pre-assessment of tumors is on the verge of becoming a necessity rather than a luxury. Detailed characterization of brain tumors with techniques such as volumetric analysis is useful for studying their progression and assisting in pre-surgical planning [17]. In addition to surgical applications, characterization of delineated tumors can be directly utilized for the prediction of life expectancy [32]. Brain tumor segmentation is at the forefront of all such applications.

Brain tumors are categorized into primary and secondary tumor types. Primary brain tumors originate from brain cells, while secondary tumors metastasize to the brain from other organs. The most common primary brain tumors are gliomas, which arise from brain glial cells and are categorized into low-grade (LGG) and high-grade (HGG) subtypes. High-grade gliomas are an aggressive type of malignant brain tumor that grows rapidly, typically requires surgery and radiotherapy, and has a poor survival prognosis [40]. As a reliable diagnostic tool, Magnetic Resonance Imaging (MRI) plays a vital role in the monitoring and surgical planning of brain tumors. Typically, several complementary 3D MRI modalities, such as T1, T1 with contrast agent (T1c), T2 and Fluid-Attenuated Inversion Recovery (FLAIR), are required to emphasize different tissue properties and areas of tumor spread. For instance, gadolinium as a contrast agent emphasizes hyperactive tumor sub-regions in the T1c MRI modality [15].

Furthermore, automated medical image segmentation techniques [18] have shown prominence in providing accurate and reproducible solutions for brain tumor delineation. Recently, deep learning-based brain tumor segmentation techniques [19, 20, 30, 31] have achieved state-of-the-art performance in various benchmarks [2, 7, 34]. These advances are mainly due to the powerful feature extraction capabilities of Convolutional Neural Networks (CNNs). However, the limited kernel size of CNN-based techniques restricts their capability of learning long-range dependencies, which are critical for the accurate segmentation of tumors that appear in various shapes and sizes. Although several efforts [10, 23] have tried to address this limitation by increasing the receptive field of the convolutional kernels, the effective receptive field remains limited to local regions.

Recently, transformer-based models have shown prominence in various domains such as natural language processing and computer vision [13, 14, 37]. In computer vision, Vision Transformers (ViTs) [14] have demonstrated state-of-the-art performance on various benchmarks. Specifically, the self-attention module in ViT-based models allows for modeling long-range information through pairwise interactions between token embeddings, leading to more effective local and global contextual representations [33]. In addition, ViTs have achieved success in effectively learning pretext tasks for self-supervised pre-training in various applications [8, 9, 35]. In medical image analysis, UNETR [16] is the first methodology that utilizes a ViT as its encoder without relying on a CNN-based feature extractor. Other approaches [38, 39] have attempted to leverage the power of ViTs as stand-alone blocks in architectures that otherwise consist of CNN-based components. However, UNETR has shown better performance in terms of both accuracy and efficiency in different medical image segmentation tasks [16].

Recently, Swin transformers [24, 25] have been proposed as a hierarchical vision transformer that computes self-attention in an efficient shifted window partitioning scheme. As a result, Swin transformers are suitable for various downstream tasks wherein the extracted multi-scale features can be leveraged for further processing. In this work, we propose a novel architecture termed Swin UNEt TRansformers (Swin UNETR), which utilizes a U-shaped network with a Swin transformer as the encoder and connects it to a CNN-based decoder at different resolutions via skip connections. We validate the effectiveness of our approach for the task of multi-modal 3D brain tumor segmentation in the 2021 edition of the Multi-modal Brain Tumor Segmentation Challenge (BraTS). Our model is one of the top-ranking methods in the validation phase and has demonstrated competitive performance in the testing phase.

2 Related Work

In the previous BraTS challenges, ensembles of U-Net shaped architectures have achieved promising results for multi-modal brain tumor segmentation. Kamnitsas et al. [21] proposed a robust segmentation model by aggregating the outputs of various CNN-based models such as 3D U-Net [12], 3D FCN [26] and DeepMedic [22]. Subsequently, Myronenko [30] introduced SegResNet, which utilizes a residual encoder-decoder architecture in which an auxiliary branch reconstructs the input data with a variational auto-encoder as a surrogate task. Zhou et al. [42] proposed to use an ensemble of different CNN-based networks that takes multi-scale contextual information into account through an attention block. Jiang et al. [20] used a two-stage cascaded approach consisting of U-Net models, wherein the first stage computes a coarse segmentation prediction that is then refined by the second stage. Furthermore, Isensee et al. [19] proposed the nnU-Net model and demonstrated that a generic U-Net architecture with minor modifications is enough to achieve competitive performance in multiple BraTS challenges.

Transformer-based models have recently gained a lot of traction in computer vision [14, 24, 41] and medical image analysis [11, 16]. Chen et al. [11] introduced a 2D U-Net architecture that benefits from a ViT in the bottleneck of the network. Wang et al. [38] extended this approach to 3D brain tumor segmentation. In addition, Xie et al. [39] proposed a ViT-based model with deformable transformer layers between its CNN-based encoder and decoder that processes the extracted features at different resolutions. Different from these approaches, Hatamizadeh et al. [16] proposed the UNETR architecture, in which a ViT-based encoder that directly utilizes 3D input patches is connected to a CNN-based decoder. UNETR has shown promising results for brain tumor segmentation on the MSD dataset [1]. Unlike the UNETR model, our proposed Swin UNETR architecture uses a Swin transformer encoder, which extracts feature representations at several resolutions with a shifted windowing mechanism for computing the self-attention. We demonstrate that Swin transformers [24] have a greater capability of learning multi-scale contextual representations and modeling long-range dependencies than ViT-based approaches with a fixed resolution.

Fig. 1. Overview of the Swin UNETR architecture. The input to our model is 3D multi-modal MRI images with 4 channels. Swin UNETR creates non-overlapping patches of the input data and uses a patch partition layer to create windows of a desired size for computing the self-attention. The encoded feature representations from the Swin transformer are fed to a CNN-based decoder via skip connections at multiple resolutions. The final segmentation output consists of 3 channels corresponding to the ET, WT and TC sub-regions.

3 Swin UNETR

3.1 Encoder

We illustrate the architecture of Swin UNETR in Fig. 1. The input to the Swin UNETR model is a 3D volume \(\mathcal {X} \in \mathbb {R}^{H\times {W}\times {D}\times {S}}\) that is split into tokens with a patch resolution of \((H^{\prime },W^{\prime },D^{\prime })\) and dimension of \(H^{\prime } \times W^{\prime }\times D^{\prime }\times S\). We first utilize a patch partition layer to create a sequence of 3D tokens with dimension of \(\left\lceil \frac{H}{H^{\prime }}\right\rceil \times \left\lceil \frac{W}{W^{\prime }}\right\rceil \times \left\lceil \frac{D}{D^{\prime }}\right\rceil \) and project them into an embedding space with dimension C. The self-attention is computed within non-overlapping windows that are created in the partitioning stage for efficient token interaction modeling. Figure 2 shows the shifted windowing mechanism for subsequent layers. Specifically, we utilize windows of size \(M\times M\times M\) to evenly partition a 3D token into \(\left\lceil \frac{H^{\prime }}{M}\right\rceil \times \left\lceil \frac{W^{\prime }}{M}\right\rceil \times \left\lceil \frac{D^{\prime }}{M}\right\rceil \) regions at a given layer l of the transformer encoder. Subsequently, in layer \(l+1\), the partitioned window regions are shifted by \(\left( \left\lfloor \frac{M}{2}\right\rfloor ,\left\lfloor \frac{M}{2}\right\rfloor ,\left\lfloor \frac{M}{2}\right\rfloor \right) \) voxels. For consecutive layers l and \(l+1\) of the encoder, the outputs are calculated as

$$\begin{aligned} \begin{array}{l} \hat{{z}}^{l}=\text {W-MSA}(\text {LN}({z}^{l-1}))+{z}^{l-1} \\ {z}^{l}=\text {MLP}(\text {LN}(\hat{{z}}^{l}))+\hat{{z}}^{l} \\ \hat{{z}}^{l+1}=\text {SW-MSA}(\text {LN}({z}^{l}))+{z}^{l} \\ {z}^{l+1}=\text {MLP}(\text {LN}(\hat{{z}}^{l+1}))+\hat{{z}}^{l+1}. \end{array} \end{aligned}$$
(1)

Here, \(\text {W-MSA}\) and \(\text {SW-MSA}\) denote window-based multi-head self-attention using regular and shifted window partitioning configurations, respectively; \(\hat{{z}}^{l}\) and \(\hat{{z}}^{l+1}\) denote the outputs of \(\text {W-MSA}\) and \(\text {SW-MSA}\); \(\text {LN}\) and \(\text {MLP}\) denote layer normalization and Multi-Layer Perceptron, respectively. For efficient computation of the shifted window mechanism, we leverage a 3D cyclic-shifting [24] and compute self-attention according to

$$\begin{aligned} \text {Attention}(Q, K, V) = \text {Softmax}\left( \frac{QK^{\top }}{\sqrt{d}}\right) V. \end{aligned}$$
(2)

Here, Q, K and V denote the queries, keys and values respectively, and d represents the dimension of the queries and keys.
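
To make the window-based attention concrete, the following is a minimal, single-head PyTorch sketch of W-MSA/SW-MSA over a 3D token grid using Eq. (2) and 3D cyclic shifting. It is only an illustration under simplifying assumptions: a single attention head, no relative position bias, and no attention mask across the wrapped boundary after the cyclic shift; the tensor layout and helper names are ours, not the released implementation.

```python
import torch
import torch.nn as nn


def window_partition_3d(x, M):
    """Split a (B, H, W, D, C) token grid into non-overlapping MxMxM windows."""
    B, H, W, D, C = x.shape
    x = x.view(B, H // M, M, W // M, M, D // M, M, C)
    # -> (B * num_windows, M*M*M, C)
    return x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(-1, M * M * M, C)


def window_attention(tokens, w_qkv, w_out):
    """Scaled dot-product self-attention inside each window (Eq. 2), single head."""
    q, k, v = w_qkv(tokens).chunk(3, dim=-1)
    d = q.shape[-1]
    attn = torch.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim=-1)
    return w_out(attn @ v)


def sw_msa(x, M, w_qkv, w_out, shifted=False):
    """W-MSA (shifted=False) or SW-MSA (shifted=True) over a (B, H, W, D, C) grid."""
    if shifted:  # 3D cyclic shift by floor(M/2) voxels along each spatial axis
        x = torch.roll(x, shifts=(-(M // 2),) * 3, dims=(1, 2, 3))
    B, H, W, D, C = x.shape
    out = window_attention(window_partition_3d(x, M), w_qkv, w_out)
    # Reverse the window partition back to the (B, H, W, D, C) grid.
    out = out.view(B, H // M, W // M, D // M, M, M, M, C)
    out = out.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(B, H, W, D, C)
    if shifted:  # undo the cyclic shift
        out = torch.roll(out, shifts=(M // 2,) * 3, dims=(1, 2, 3))
    return out


# Example: an 8x8x8 token grid with C = 48 and window size M = 4 (as in Fig. 2).
C, M = 48, 4
x = torch.randn(1, 8, 8, 8, C)
w_qkv, w_out = nn.Linear(C, 3 * C), nn.Linear(C, C)
y = sw_msa(x, M, w_qkv, w_out, shifted=True)  # the SW-MSA branch of Eq. (1)
```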

The Swin UNETR encoder has a patch size of \(2 \times 2 \times 2\) and a feature dimension of \(2\times 2\times 2\times 4 =32\), taking into account the multi-modal MRI images with 4 channels. The size of the embedding space C is set to 48 in our encoder. Furthermore, the Swin UNETR encoder has 4 stages which comprise 2 transformer blocks at each stage. Hence, the total number of layers in the encoder is \(L=8\). In stage 1, a linear embedding layer is utilized to create \(\frac{H}{2} \times \frac{W}{2} \times \frac{D}{2}\) 3D tokens. To maintain the hierarchical structure of the encoder, a patch merging layer is utilized to decrease the resolution of the feature representations by a factor of 2 at the end of each stage. Specifically, the patch merging layer groups patches with resolution \(2 \times 2 \times 2\) and concatenates them, resulting in an 8C-dimensional feature embedding, which is subsequently reduced to 2C with a linear layer. Stage 2, stage 3 and stage 4, with resolutions of \(\frac{H}{4} \times \frac{W}{4} \times \frac{D}{4}\), \(\frac{H}{8} \times \frac{W}{8} \times \frac{D}{8}\) and \(\frac{H}{16} \times \frac{W}{16} \times \frac{D}{16}\) respectively, follow the same network design.
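
The patch merging operation described above can be sketched as follows; this is our own minimal PyTorch illustration, assuming a channel-last (B, H, W, D, C) feature layout rather than the exact layout of the released code.

```python
import torch
import torch.nn as nn


class PatchMerging3D(nn.Module):
    """Halve each spatial dimension by concatenating 2x2x2 neighborhoods (C -> 8C),
    then project the merged features down to 2C with a linear layer."""

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(8 * dim)
        self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)

    def forward(self, x):                              # x: (B, H, W, D, C)
        B, H, W, D, C = x.shape
        x = x.view(B, H // 2, 2, W // 2, 2, D // 2, 2, C)
        x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(B, H // 2, W // 2, D // 2, 8 * C)
        return self.reduction(self.norm(x))            # (B, H/2, W/2, D/2, 2C)
```

Under this configuration (C = 48 and a doubling of the channel dimension at each merge), the four stage outputs carry 48, 96, 192 and 384 channels, respectively.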

3.2 Decoder

Swin UNETR has a U-shaped network design in which the extracted feature representations of the encoder are used in the decoder via skip connections at each resolution. At each stage i (\(i \in \{0,1,2,3,4\}\)) in the encoder and at the bottleneck (\(i=5\)), the output feature representations are reshaped to size \(\frac{H}{2^{i}} \times \frac{W}{2^{i}} \times \frac{D}{2^{i}}\) and fed into a residual block comprising two \(3 \times 3 \times 3\) convolutional layers that are normalized by instance normalization [36] layers. Subsequently, the resolution of the feature maps is increased by a factor of 2 using a deconvolutional layer, and the outputs are concatenated with the outputs of the previous stage. The concatenated features are then fed into another residual block as previously described. The final segmentation outputs are computed by using a \(1\times 1\times 1\) convolutional layer and a sigmoid activation function.
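
A condensed PyTorch sketch of one decoder stage and the segmentation head, following the description above, is shown below; the channel counts, activation choice (LeakyReLU) and module names are illustrative assumptions rather than the exact released implementation.

```python
import torch
import torch.nn as nn


class ResBlock3D(nn.Module):
    """Residual block of two 3x3x3 convolutions with instance normalization."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1),
            nn.InstanceNorm3d(out_ch), nn.LeakyReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, 3, padding=1),
            nn.InstanceNorm3d(out_ch),
        )
        self.skip = nn.Conv3d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        self.act = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        return self.act(self.conv(x) + self.skip(x))


class DecoderStage3D(nn.Module):
    """Residual block -> 2x transposed-conv upsampling -> concat with encoder skip -> residual block."""

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.res_in = ResBlock3D(in_ch, in_ch)
        self.up = nn.ConvTranspose3d(in_ch, out_ch, kernel_size=2, stride=2)
        self.res_out = ResBlock3D(out_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        x = self.up(self.res_in(x))            # double the spatial resolution
        return self.res_out(torch.cat([x, skip], dim=1))


# Final segmentation head: 1x1x1 convolution followed by a sigmoid over the
# three overlapping output channels (ET, WT, TC); 48 input channels is an
# illustrative choice matching the embedding size C.
seg_head = nn.Sequential(nn.Conv3d(48, 3, kernel_size=1), nn.Sigmoid())
```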

Fig. 2. Overview of the shifted windowing mechanism. Note that \(8 \times 8 \times 8\) 3D tokens and a \(4 \times 4 \times 4\) window size are illustrated.

3.3 Loss Function

We use the soft Dice loss function [29] which is computed in a voxel-wise manner as

$$\begin{aligned} \begin{aligned} \mathcal {L}(G,Y)&= 1-\frac{2}{J}\sum _{j=1}^{J}\frac{\sum _{i=1}^{I} G_{i,j}Y_{i,j} }{\sum _{i=1}^{I}G^{2}_{i,j}+ \sum _{i=1}^{I}Y^{2}_{i,j}}. \end{aligned} \end{aligned}$$
(3)

where I denotes the number of voxels; J denotes the number of classes; \(Y_{i,j}\) and \(G_{i,j}\) denote the output probability and the one-hot encoded ground truth for class j at voxel i, respectively.
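
A minimal PyTorch sketch of Eq. (3) is given below, assuming sigmoid probabilities and multi-channel (one-hot) ground truth of shape (B, J, H, W, D); folding the batch into the voxel sum and adding a small epsilon for numerical stability are our own simplifications.

```python
import torch


def soft_dice_loss(y: torch.Tensor, g: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Soft Dice loss of Eq. (3): y = predicted probabilities, g = one-hot ground truth."""
    dims = (0, 2, 3, 4)                                   # sum over batch and voxels i
    intersection = (g * y).sum(dims)                      # per-class overlap
    denominator = (g * g).sum(dims) + (y * y).sum(dims)   # squared-term denominator
    return 1.0 - (2.0 * intersection / (denominator + eps)).mean()  # average over classes j
```

MONAI's DiceLoss (with squared_pred=True and sigmoid=True) provides an equivalent, batched implementation of the same formulation.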

Table 1. Swin UNETR configurations.
Fig. 3. Typical segmentation examples with predicted labels overlaid on T1, T1c, T2 and FLAIR MRI axial slices in each row. The first two rows depict \(\sim \)75th percentile performance based on the Dice score, rows 3 and 4 depict \(\sim \)50th percentile performance, and the last two rows are at the \(\sim \)25th percentile. The image intensities are shown on a gray color scale. The blue, red and green colors correspond to the TC, ET and WT sub-regions respectively. Note that all samples have been selected from the BraTS 2021 validation set. (Color figure online)

3.4 Implementation Details

Swin UNETR is implemented using PyTorch and MONAI and trained on a DGX-1 cluster with 8 NVIDIA V100 GPUs. Table 1 details the configuration of the Swin UNETR architecture, its number of parameters and FLOPs. The learning rate is set to 0.0008. We normalize all input images to zero mean and unit standard deviation based on non-zero voxels. Random patches of \(128\times 128\times 128\) are cropped from the 3D image volumes during training. We apply a random axis mirror flip with a probability of 0.5 for all 3 axes. Additionally, we apply data augmentation transforms of random per-channel intensity shift in the range \((-0.1,0.1)\) and random intensity scaling in the range (0.9, 1.1) to the input image channels. The batch size per GPU is set to 1. All models are trained for a total of 800 epochs with a linear warmup and a cosine annealing learning rate scheduler. For inference, we use a sliding window approach with an overlap of 0.7 between neighboring windows.
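
The pre-processing, augmentation and sliding-window settings above can be expressed with MONAI dictionary transforms roughly as follows; the exact transform composition, the probabilities for the intensity transforms, and the "image"/"label" keys are our assumptions for illustration, not the authors' released training script.

```python
from monai import transforms
from monai.inferers import sliding_window_inference

train_transform = transforms.Compose([
    # Zero mean / unit std per channel, computed over non-zero voxels only.
    transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    # Random 128^3 patches cropped from the 3D volumes.
    transforms.RandSpatialCropd(keys=["image", "label"], roi_size=(128, 128, 128), random_size=False),
    # Random mirror flips with probability 0.5 along each of the 3 axes.
    transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    # Random intensity scaling in (0.9, 1.1) and intensity shift in (-0.1, 0.1).
    transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
    transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
])


def infer(volume, model):
    """Sliding-window inference with 128^3 windows and 0.7 overlap between
    neighboring windows; `model` maps (B, 4, 128, 128, 128) inputs to
    (B, 3, 128, 128, 128) per-channel probabilities."""
    return sliding_window_inference(
        inputs=volume, roi_size=(128, 128, 128), sw_batch_size=1,
        predictor=model, overlap=0.7,
    )
```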

3.5 Dataset and Model Ensembling

The BraTS challenge aims to evaluate state-of-the-art methods for the semantic segmentation of brain tumors by providing a 3D MRI dataset with voxel-wise ground truth labels that are annotated by physicians [3,4,5,6, 28]. The BraTS 2021 training dataset includes 1251 subjects, each with four 3D MRI modalities: a) native T1-weighted (T1), b) post-contrast T1-weighted (T1Gd), c) T2-weighted (T2), and d) T2 Fluid-Attenuated Inversion Recovery (T2-FLAIR), which are rigidly aligned, resampled to a \(1\times 1\times 1\) mm isotropic resolution and skull-stripped. The input image size is \(240\times 240\times 155\). The data were collected from multiple institutions using various MRI scanners. Annotations include three tumor sub-regions: the enhancing tumor, the peritumoral edema, and the necrotic and non-enhancing tumor core. The annotations were combined into three nested sub-regions: Whole Tumor (WT), Tumor Core (TC) and Enhancing Tumor (ET). Figure 3 illustrates typical segmentation outputs for all semantic classes. During the challenge, two additional datasets without ground truth labels were provided for the validation and testing phases, and participants were required to upload their segmentation masks to the organizers' server for evaluation. The validation dataset, which is designed for intermediate model evaluations, consists of 219 cases. Additional information regarding the testing dataset was not provided to the participants.
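
As a concrete sketch, converting a BraTS label map into the three nested, overlapping sub-region channels can be done as follows; this assumes the conventional BraTS label encoding (1 = necrotic/non-enhancing tumor core, 2 = peritumoral edema, 4 = enhancing tumor) and the ET/WT/TC channel order used in Fig. 1.

```python
import numpy as np


def brats_labels_to_channels(label: np.ndarray) -> np.ndarray:
    """Map a (H, W, D) BraTS label volume to a (3, H, W, D) binary volume
    with the nested ET, WT and TC sub-region channels."""
    et = label == 4                                    # Enhancing Tumor
    tc = np.logical_or(label == 1, label == 4)         # Tumor Core = NCR/NET + ET
    wt = np.logical_or(tc, label == 2)                 # Whole Tumor = TC + edema
    return np.stack([et, wt, tc]).astype(np.float32)
```

MONAI also ships a comparable dictionary transform for this conversion that can be dropped into the training pipeline.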

Our models were trained on the BraTS 2021 dataset with 1251 and 219 cases in the training and validation sets, respectively. The semantic segmentation labels of the validation cases are not publicly available, and performance benchmarks were obtained by making submissions to the official server of the BraTS 2021 challenge. We used a five-fold cross-validation scheme with an 80:20 ratio and did not use any additional data. The final result was obtained with an ensemble of 10 Swin UNETR models to improve the performance and achieve a better consensus across all predictions. The ensemble models were obtained from two separate five-fold cross-validation training runs.
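
A minimal sketch of the ensembling step is shown below, assuming each model already ends with a sigmoid (Sect. 3.2) so that its sliding-window output is a per-channel probability map; the 0.5 binarization threshold is our own illustrative choice.

```python
import torch


@torch.no_grad()
def ensemble_predict(volume, models, infer_fn):
    """Average the probability maps of the ensemble members and binarize."""
    probs = torch.stack([infer_fn(volume, m) for m in models]).mean(dim=0)
    return probs > 0.5  # (B, 3, H, W, D) boolean masks for ET, WT and TC
```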

4 Results and Discussion

We have compared the performance of Swin UNETR on our internal cross-validation split against the winning methodologies of previous years, namely SegResNet [30], nnU-Net [19] and TransBTS [38]. The latter is a ViT-based approach tailored for the semantic segmentation of brain tumors.

Evaluation results across all five folds are presented in Table 2. The proposed Swin UNETR model outperforms all competing approaches in all 5 folds and on average for all semantic classes (i.e., ET, WT and TC). Specifically, Swin UNETR outperforms the closest competing approach by \(0.7\%\), \(0.6\%\) and \(0.4\%\) for the ET, WT and TC classes respectively, and by \(0.5\%\) on average across all classes and folds. The superior performance of Swin UNETR in comparison to other top-performing models for brain tumor segmentation is mainly due to its capability of learning multi-scale contextual information in its hierarchical encoder via self-attention modules and effectively modeling long-range dependencies.

Moreover, we observe that nnU-Net and SegResNet achieve competitive benchmarks in these experiments, with nnU-Net demonstrating slightly better performance. On the other hand, TransBTS, which is a ViT-based methodology, performs sub-optimally in comparison to the other models. The sub-optimal performance of TransBTS could be attributed to its inefficient architecture, in which the ViT is only utilized in the bottleneck as a standalone attention module without any connections to the decoder at different resolutions.

Table 2. Five-fold cross-validation benchmarks in terms of mean Dice score values. ET, WT and TC denote Enhancing Tumor, Whole Tumor and Tumor Core respectively.
Table 3. BraTS 2021 validation dataset benchmarks in terms of mean Dice score and Hausdorff distance values. ET, WT and TC denote Enhancing Tumor, Whole Tumor and Tumor Core respectively.

The segmentation performance of Swin UNETR on the BraTS 2021 validation set is presented in Table 3. According to the official challenge results, our benchmarks (Team: NVOptNet) are among the top-ranking methodologies across more than 2000 submissions during the validation phase, hence being the first transformer-based model to place competitively in the BraTS challenges. In addition, the segmentation outputs of Swin UNETR for several cases from the validation set are illustrated in Fig. 3. Consistent with the quantitative benchmarks, the segmentation outputs are well-delineated for all three sub-regions.

Furthermore, the segmentation performance of Swin UNETR on the BraTS 2021 testing set is reported in Table 4. We observe that the segmentation performance for ET and WT is very similar to the validation benchmarks, whereas the segmentation performance for TC decreases by \(0.9\%\).

Table 4. BraTS 2021 testing dataset benchmarks in terms of mean Dice score and Hausdorff distance values. ET, WT and TC denote Enhancing Tumor, Whole Tumor and Tumor Core respectively.

5 Conclusion

In this paper, we introduced Swin UNETR, a novel architecture for the semantic segmentation of brain tumors using multi-modal MRI images. Our proposed model has a U-shaped network design and uses a Swin transformer as the encoder and a CNN-based decoder that is connected to the encoder via skip connections at different resolutions. We have validated the effectiveness of our approach in the BraTS 2021 challenge. Our model ranks among the top-performing approaches in the validation phase and demonstrates competitive performance in the testing phase. We believe that Swin UNETR could be the foundation of a new class of transformer-based models with hierarchical encoders for the task of brain tumor segmentation.