1 Introduction

Automated and accurate segmentation of brain tumors plays an essential role in clinical assessment and diagnosis. Magnetic Resonance Imaging (MRI) is a common neuroimaging technique for the quantitative evaluation of brain tumors in clinical practice, where multiple imaging modalities are acquired, i.e., T1-weighted (T1), contrast-enhanced T1-weighted (T1c), T2-weighted (T2), and Fluid Attenuated Inversion Recovery (FLAIR) images. Each modality provides a distinctive contrast of brain structure and pathology. Joint learning from multimodal images is therefore essential and can significantly boost segmentation performance. Many methods have been explored to effectively fuse multimodal MRIs for brain tumor segmentation, for example, by concatenating multimodal images along the channel dimension as the input or by fusing features in the latent space [17, 23]. However, in clinical practice, it is not always possible to acquire a complete set of MRIs due to data corruption, varying scanning protocols, and unsuitable patient conditions. In this situation, most existing multimodal methods fail to deal with incomplete imaging modalities and suffer a severe degradation in segmentation performance. Consequently, a multimodal method that remains robust with one or more missing modalities is highly desirable for flexible and practical clinical application.

Incomplete multimodal learning, also known as hetero-modal learning [8], aims at designing methods that are robust to any subset of available modalities at inference. A straightforward strategy for incomplete multimodal learning of brain tumor segmentation is to synthesize the missing modalities with generative models [18]. Another stream of methods explores knowledge distillation from complete modalities to incomplete ones [2, 10, 21]. Although promising results are obtained, such methods have to train and deploy a specific model for each subset of missing modalities, which is complicated and burdensome in clinical application. Zhang et al. [22] proposed an ensemble of single-modal models with adaptive fusion to achieve multimodal segmentation. However, it only works when one or all modalities are available. Meanwhile, all these methods require complete modalities during training.

More recent methods focus on learning a unified model, instead of a set of distilled networks, for incomplete multimodal segmentation [8, 16]. For example, HeMIS [8] learns an embedding of multimodal information by computing the mean and variance across features from any number of available modalities. U-HVED [4] further introduces a multimodal variational auto-encoder that benefits incomplete multimodal segmentation by generating the missing modalities. Other methods exploit feature disentanglement [1] and attention mechanisms [3] for robust multimodal brain tumor segmentation. The Fully Convolutional Network (FCN) [11, 15] has achieved great success in medical image segmentation and is widely used for feature extraction in the methods mentioned above. Despite its excellent performance, the inductive bias of convolution, i.e., locality, makes it difficult for FCNs to explicitly build long-range dependencies. In incomplete multimodal learning of brain tumor segmentation, features extracted with limited receptive fields tend to be biased when dealing with varying modalities. In contrast, a modality-invariant embedding that captures the global semantics of the tumor region across different modalities may contribute to more robust segmentation, especially when one or more modalities are missing.

The Transformer was originally proposed to model long-range dependencies for sequence-to-sequence tasks [19] and has also shown state-of-the-art performance on various computer vision tasks [5]. Concurrent works [7, 14, 20] exploit the Transformer for brain tumor segmentation from the perspective of the backbone network. However, a dedicated Transformer for multimodal modeling in brain tumor segmentation has not yet been carefully explored, let alone for incomplete multimodal segmentation.

Fig. 1. Overview of the proposed mmFormer, which is composed of four hybrid modality-specific encoders, a modality-correlated encoder, and a convolutional decoder. Auxiliary regularizers are introduced in both the encoder and the decoder. The skip connections between the convolutional encoder and decoder are hidden for clear display.

This paper exploits the Transformer to build a unified model for incomplete multimodal learning of brain tumor segmentation. We propose Multimodal Medical Transformer (mmFormer), which leverages hybrid modality-specific encoders and a modality-correlated encoder to build long-range dependencies both within and across different modalities. With modality-invariant representations extracted by explicitly building and aligning global correlations between different modalities, the proposed mmFormer demonstrates superior robustness for incomplete multimodal learning of brain tumor segmentation. Meanwhile, auxiliary regularizers are introduced into mmFormer to encourage both the encoder and the decoder to learn discriminative features even when a certain number of modalities are missing. We validate mmFormer on the task of multimodal brain tumor segmentation on the BraTS 2018 dataset [12]. The proposed method outperforms state-of-the-art methods in the average Dice metric over all settings of missing modalities, with, in particular, an average improvement of 19.07% in Dice for enhancing tumor segmentation when only one modality is available. To the best of our knowledge, this is the first attempt to employ the Transformer for incomplete multimodal learning of brain tumor segmentation.

2 Method

In this paper, we propose mmFormer for incomplete multimodal learning of brain tumor segmentation. We adopt an encoder-decoder architecture to construct mmFormer, including a hybrid modality-specific encoder for each modality, a modality-correlated encoder, and a convolutional decoder. Besides, auxiliary regularizers are introduced in both the encoder and the decoder. An overview of mmFormer is illustrated in Fig. 1. We elaborate on the details of each component in the following.

2.1 Hybrid Modality-Specific Encoder

The hybrid modality-specific encoder aims to extract both local and global context information within a specific modality by bridging a convolutional encoder and an intra-modal Transformer. We denote the complete set of modalities by \(M=\{FLAIR, T1c, T1, T2\}\). Given an input \(\textbf{X}_m \in \mathbb {R}^{1\times D\times H\times W}\) of size \(D\times H\times W\), \(m\in M\), we first utilize the convolutional encoder to generate compact feature maps with local context and then leverage the intra-modal Transformer to model long-range dependencies in a global space.

Convolutional Encoder. The convolutional encoder is constructed by stacking convolutional blocks, similar to the encoder part of U-Net [15]. The feature maps with the local context within each modality produced by the convolutional encoder \(\mathcal {F}^{conv}_m\) can be formulated as

$$\begin{aligned} \begin{aligned} \textbf{F}^{local}_m = \mathcal {F}^{conv}_m(\textbf{X}_m; \theta ^{conv}_m) \end{aligned} \end{aligned}$$
(1)

where \(\textbf{F}^{local}_m \in \mathbb {R}^{C\times \frac{D}{2^{l-1}}\times \frac{H}{2^{l-1}}\times \frac{W}{2^{l-1}}}\), C is the channel dimension, and l is the number of stages in the encoder. Concretely, we build a five-stage encoder, where each stage consists of two convolutional blocks. Each block contains cascaded group normalization, ReLU, and a convolutional layer with a kernel size of 3, while the first convolutional block in the first stage only contains a convolutional layer. Between two consecutive stages, a convolutional layer with a stride of 2 is employed to downsample the feature maps. The number of filters at each stage of the encoder is 16, 32, 64, 128, and 256, respectively.
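For concreteness, the following is a minimal PyTorch sketch of one modality-specific convolutional encoder following the description above (five stages, two blocks per stage, filters 16 to 256, stride-2 downsampling between stages). The exact block arrangement, the number of normalization groups, and the module names are our assumptions for illustration, not the released implementation.

```python
import torch.nn as nn


class ConvBlock(nn.Module):
    """GroupNorm -> ReLU -> 3x3x3 convolution, as described for the convolutional encoder."""

    def __init__(self, in_ch, out_ch, norm_groups=8):
        super().__init__()
        self.block = nn.Sequential(
            nn.GroupNorm(norm_groups, in_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.block(x)


class ModalityEncoder(nn.Module):
    """Five-stage encoder for a single modality (filters 16, 32, 64, 128, 256).

    A stride-2 convolution between consecutive stages halves the spatial resolution,
    giving an overall downsampling factor of 2^(l-1) for l = 5 stages.
    """

    def __init__(self, in_ch=1, filters=(16, 32, 64, 128, 256)):
        super().__init__()
        stages = [nn.Sequential(
            # First block of the first stage is a plain convolution.
            nn.Conv3d(in_ch, filters[0], kernel_size=3, padding=1),
            ConvBlock(filters[0], filters[0]),
        )]
        for i in range(1, len(filters)):
            stages.append(nn.Sequential(
                # Stride-2 convolution downsamples between consecutive stages.
                nn.Conv3d(filters[i - 1], filters[i], kernel_size=3, stride=2, padding=1),
                ConvBlock(filters[i], filters[i]),
                ConvBlock(filters[i], filters[i]),
            ))
        self.stages = nn.ModuleList(stages)

    def forward(self, x):
        skips = []
        for stage in self.stages:
            x = stage(x)
            skips.append(x)  # kept for the skip connections to the decoder
        return x, skips      # x is F_m^local at the bottleneck resolution
```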

Intra-modal Transformer. Limited by the intrinsic locality of convolution, the convolutional encoder cannot effectively build long-range dependencies within each modality. Therefore, we exploit the intra-modal Transformer for explicit long-range contextual modeling. The intra-modal Transformer consists of a tokenizer, Multi-head Self-Attention (MSA), and a Feed-Forward Network (FFN).

As the Transformer processes embeddings in a sequence-to-sequence manner, the local feature maps \(\textbf{F}^{local}_m\) produced by the convolutional encoder are first flattened into a 1D sequence and transformed into the token space by a linear projection. However, the flattening operation inevitably collapses the spatial information, which is critical to image segmentation. To address this issue, we introduce a learnable position embedding \(\textbf{P}_m\) to supplement the flattened features via element-wise summation, which is formulated as

$$\begin{aligned} \begin{aligned} \textbf{F}^{token}_m = \textbf{F}^{local}_m\textbf{W}_m + \textbf{P}_m, \end{aligned} \end{aligned}$$
(2)

where \(\textbf{F}^{token}_m \in \mathbb {R}^{C' \times \frac{DHW}{2^{3(l-1)}}}\) denotes the token and \(\textbf{W}_m\) denotes the weights of the linear projection. The MSA builds relationships within each modality by attending over all possible locations in the feature map, which is formulated as

$$\begin{aligned} \begin{aligned} head^i_m = Attention(\textbf{Q}^i_m,\textbf{K}^i_m,\textbf{V}^i_m) = softmax(\frac{\textbf{Q}^i_m\textbf{K}^{i\textrm{T}}_m}{\sqrt{d_k}})\textbf{V}^i_m, \end{aligned} \end{aligned}$$
(3)
$$\begin{aligned} \begin{aligned} MSA_m = [head^1_m, ..., head^N_m]\textbf{W}^o_m, \end{aligned} \end{aligned}$$
(4)

where \(\textbf{Q}^i_m=LN(\textbf{F}^{token}_m)\textbf{W}^{Qi}_m\), \(\textbf{K}^i_m=LN(\textbf{F}^{token}_m)\textbf{W}^{Ki}_m\), \(\textbf{V}^i_m=LN(\textbf{F}^{token}_m)\textbf{W}^{Vi}_m\), \(LN(\cdot )\) is layer normalization, \(d_k\) is the dimension of \(\textbf{K}_m\), \(N=8\) is the number of attention heads, and \([\cdot , \cdot ]\) is the concatenation operation. The FFN is a two-layer perceptron with GELU [9] activation. The feature maps with global context within each modality produced by the intra-modal Transformer are defined as

$$\begin{aligned} \begin{aligned} \textbf{F}^{global}_m = FFN_m(LN(z)) + z, z = MSA_m(LN(\textbf{F}^{token}_m)) + \textbf{F}^{token}_m, \end{aligned} \end{aligned}$$
(5)

where \(\textbf{F}^{global}_m \in \mathbb {R}^{C' \times \frac{DHW}{2^{3(l-1)}}}\).
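Assuming a standard pre-norm Transformer layer, Eqs. (2)–(5) can be sketched in PyTorch as below. The class name, the fixed token count, and the FFN expansion ratio are illustrative assumptions, and torch.nn.MultiheadAttention (PyTorch >= 1.9 for batch_first) stands in for the MSA of Eqs. (3)–(4).

```python
import torch
import torch.nn as nn


class IntraModalTransformer(nn.Module):
    """Pre-norm Transformer layer over the flattened per-modality feature map (Eqs. 2-5).

    n_tokens = D*H*W / 2^(3(l-1)) must be fixed at construction time because the
    position embedding P_m is learnable.
    """

    def __init__(self, in_ch, token_dim, n_tokens, n_heads=8, ffn_ratio=4):
        super().__init__()
        self.proj = nn.Linear(in_ch, token_dim)                       # W_m in Eq. (2)
        self.pos = nn.Parameter(torch.zeros(1, n_tokens, token_dim))  # P_m in Eq. (2)
        self.norm1 = nn.LayerNorm(token_dim)
        self.attn = nn.MultiheadAttention(token_dim, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(token_dim)
        self.ffn = nn.Sequential(                                     # two-layer FFN with GELU
            nn.Linear(token_dim, ffn_ratio * token_dim),
            nn.GELU(),
            nn.Linear(ffn_ratio * token_dim, token_dim),
        )

    def forward(self, f_local):
        # f_local: (B, C, D', H', W') from the convolutional encoder.
        tokens = f_local.flatten(2).transpose(1, 2)                 # (B, D'H'W', C)
        tokens = self.proj(tokens) + self.pos                       # Eq. (2)
        x = self.norm1(tokens)
        z = self.attn(x, x, x, need_weights=False)[0] + tokens      # MSA + residual, Eq. (5)
        out = self.ffn(self.norm2(z)) + z                           # FFN + residual, Eq. (5)
        return out                                                  # F_m^global, (B, n_tokens, C')
```

For example, with the 128^3 input crops used in the experiments and the five-stage encoder above, the bottleneck feature map is 8 x 8 x 8, so n_tokens = 512.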

2.2 Modality-Correlated Encoder

The modality-correlated encoder is designed to build long-range correlations across modalities and extract modality-invariant features with global semantics of the tumor region. It is implemented as an inter-modal Transformer.

Inter-modal Transformer. In contrast to the intra-modal Transformer, the inter-modal Transformer combines the embeddings from all modality-specific encoders by concatenation as the input multimodal token, which is defined as

$$\begin{aligned} \begin{aligned} \textbf{F}^{token} = [\delta _{FLAIR}\textbf{F}^{global}_{FLAIR}, \delta _{T1c}\textbf{F}^{global}_{T1c}, \delta _{T1}\textbf{F}^{global}_{T1}, \delta _{T2}\textbf{F}^{global}_{T2}]\textbf{W} + \textbf{P}, \end{aligned} \end{aligned}$$
(6)

where \(\delta _m \in \{0, 1\}\) is a Bernoulli indicator that grants robustness when building long-range dependencies between different modalities even when some modalities are missing. This modality-level dropout is randomly conducted during training by setting \(\delta _m\) to 0. In case of missing modalities, the corresponding part of the multimodal token is replaced by a zero vector. Subsequently, the token is processed by MSA and FFN to obtain modality-invariant features across modalities, which is formulated as

$$\begin{aligned} \begin{aligned} \textbf{F}^{global} = FFN(LN(z)) + z, z = MSA(LN(\textbf{F}^{token})) + \textbf{F}^{token}, \end{aligned} \end{aligned}$$
(7)

where \(\textbf{F}^{global} \in \mathbb {R}^{C' \times \frac{DHW}{2^{3(l-1)}}}\).
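A minimal sketch of the modality-level dropout of Eq. (6) is given below. The per-modality dropout probability, the concatenation axis (the token dimension), and the function name are assumptions not specified in the text; the concatenated token would then pass through a pre-norm MSA/FFN layer identical in form to the intra-modal one sketched above, as in Eq. (7).

```python
import torch


def modality_dropout_concat(tokens, training=True, present=None, p_drop=0.5):
    """Apply the Bernoulli indicators delta_m of Eq. (6) and concatenate the tokens.

    tokens:  dict {modality: (B, n_tokens, C') tensor} from the intra-modal Transformers.
    present: at inference, the set of available modalities; missing ones are zeroed.
    During training, each modality is dropped independently with probability p_drop
    (an assumed value), so the model learns to cope with arbitrary missing subsets.
    """
    fused = []
    for name, tok in tokens.items():
        if training:
            delta = float(torch.rand(()) > p_drop)   # delta_m ~ Bernoulli
        else:
            delta = float(present is None or name in present)
        fused.append(tok * delta)                    # zero vector stands in for a missing modality
    # Concatenate along the token (sequence) dimension; the linear projection W and the
    # learnable position embedding P of Eq. (6) would follow in the full model.
    return torch.cat(fused, dim=1)
```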

2.3 Convolutional Decoder

The convolutional decoder is designed to progressively restore the spatial resolution from the high-level latent space to the original mask space. The output sequence \(\textbf{F}^{global}\) of the modality-correlated encoder is reshaped into feature maps matching the size before flattening. The convolutional decoder has an architecture symmetric to that of the convolutional encoder, similar to U-Net [15]. Besides, skip connections between the encoder and the decoder are added to preserve low-level details for better segmentation. The features from the convolutional encoders of different modalities at a specific level are concatenated and forwarded as skip features to the convolutional decoder.

2.4 Auxiliary Regularizer

Conventional multimodal learning models tend to recognize brain tumors by relying on the most discriminative modalities [1, 3]. Such models are likely to face severe degradation when these discriminative modalities are missing. Therefore, it is critical to encourage each convolutional encoder to segment brain tumors even without the assistance of other modalities. To this end, the outputs of the convolutional encoders are upsampled by a shared-weight decoder to segment tumors from each modality separately. The shared-weight decoder has the same architecture as the convolutional decoder. Besides, we also introduce auxiliary regularizers in the convolutional decoder to force the decoder to generate accurate segmentation even when certain modalities are missing. This is achieved by interpolating the feature maps at each stage of the convolutional decoder to segment tumors via deep supervision [6]. The Dice loss [13] is employed as the regularizer. Combining the training loss on the network's output with the auxiliary regularizers, the overall loss function is defined as

$$\begin{aligned} \begin{aligned} \mathcal {L}=1-Dice=1-\frac{2 \sum _{c=1}^{C} \sum _{i=1}^{N_c} g_{i}^{c} p_{i}^{c}}{\sum _{c=1}^{C} \sum _{i=1}^{N_c} g_{i}^{c 2}+\sum _{c=1}^{C} \sum _{i=1}^{N_c} p_{i}^{c 2}}, \end{aligned} \end{aligned}$$
(8)
$$\begin{aligned} \begin{aligned} \mathcal {L}_{\text{ total } } = \sum _{i\in M}\mathcal {L}^{encoder}_i + \sum ^{l-1}_{i=1}\mathcal {L}^{decoder}_i + \mathcal {L}^{output}, \end{aligned} \end{aligned}$$
(9)

where C is the number of segmentation classes, \(N_c\) is the number of voxels of class c, \(g_i^c\) is a binary indicator of whether class c is the correct label for voxel i, \(p_i^c\) is the corresponding predicted probability, \(M=\{FLAIR, T1c, T1, T2\}\), and l is the number of stages in the convolutional decoder.
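A minimal PyTorch sketch of Eqs. (8)–(9), assuming softmax probabilities, one-hot targets, and the argument layout shown (the function names, the smoothing term, and the loss weighting of one per term are illustrative assumptions):

```python
def soft_dice_loss(probs, target_onehot, eps=1e-5):
    """Multi-class soft Dice loss of Eq. (8).

    probs, target_onehot: (B, C, D, H, W) tensors with C segmentation classes.
    eps is a small smoothing term added for numerical stability.
    """
    dims = (0, 2, 3, 4)                                        # sum over batch and voxels
    intersect = (probs * target_onehot).sum(dims)              # per-class numerators
    denom = (probs ** 2).sum(dims) + (target_onehot ** 2).sum(dims)
    dice = 2.0 * intersect.sum() / (denom.sum() + eps)         # sum over classes as in Eq. (8)
    return 1.0 - dice


def total_loss(encoder_logits, decoder_logits, output_logits, target_onehot):
    """Eq. (9): auxiliary encoder/decoder regularizers plus the main output loss.

    encoder_logits: one prediction per modality from the shared-weight decoder
                    (the L^encoder_i terms).
    decoder_logits: l-1 deeply supervised predictions, upsampled to full resolution
                    (the L^decoder_i terms).
    """
    loss = soft_dice_loss(output_logits.softmax(1), target_onehot)
    for logits in encoder_logits + decoder_logits:
        loss = loss + soft_dice_loss(logits.softmax(1), target_onehot)
    return loss
```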

Table 1. Results of the proposed method and state-of-the-art unified models, i.e., HeMIS [8] and U-HVED [4], on the BraTS 2018 dataset [12]. The Dice similarity coefficient (DSC) [%] is employed for evaluation under every combination of available modalities. \(\bullet \) and \(\circ \) denote available and missing modalities, respectively.

3 Experiments and Results

Dataset and Implementation. The experiments are conducted on the BraTS 2018 dataset [12], which consists of 285 multi-contrast MRI scans with four modalities: T1, T1c, T2, and FLAIR. Different subregions of brain tumors are combined into three nested subregions: whole tumor, tumor core, and enhancing tumor. All volumes have been co-registered to the same anatomical template and interpolated to the same resolution by the organizers. The Dice Similarity Coefficient (DSC), as defined in Eq. (8), is employed for evaluation. The framework is implemented with PyTorch 1.7 on four NVIDIA Tesla V100 GPUs. The input size is \(128 \times 128 \times 128\) voxels and the batch size is 1. Random flip, crop, and intensity shifts are employed for data augmentation. mmFormer has 106M parameters and 748G FLOPs. The network is trained with the Adam optimizer with an initial learning rate of 0.0002 for 1000 epochs. Training takes about 25 h and 17 GB of memory on each GPU.

Performance of Incomplete Multimodal Segmentation. We evaluate the robustness of our method to incomplete multimodal segmentation. The absence of a modality is implemented by setting \(\delta _i, i\in \{FLAIR, T1c, T1, T2\}\), to zero to drop the specific modalities at inference, as sketched below. We compare our method with two representative models using a shared latent space, i.e., HeMIS [8] and U-HVED [4]. For a fair comparison, we use the same data split as [21] and directly reference the reported results. In Table 1, our method significantly outperforms HeMIS and U-HVED on the segmentation of enhancing tumor and tumor core for all 15 possible combinations of available modalities, and on the segmentation of the whole tumor for 12 out of 15. In Table 2, we show that the average improvement obtained by mmFormer grows as the number of missing modalities increases. Meanwhile, mmFormer gains more improvement when the target region is more difficult to segment. These results demonstrate the effectiveness of mmFormer for incomplete multimodal learning of brain tumor segmentation. Figure 2 shows that even with only one modality available, mmFormer can achieve reasonable segmentation of brain tumors.
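For reference, the 15 evaluation settings of Table 1 can be enumerated as in the sketch below; the indicator dictionary and loop structure are illustrative assumptions about how a trained model exposing the \(\delta _m\) indicators would be evaluated.

```python
from itertools import combinations

MODALITIES = ("FLAIR", "T1c", "T1", "T2")

# All 15 non-empty subsets of available modalities evaluated in Table 1.
subsets = [set(c) for r in range(1, len(MODALITIES) + 1)
           for c in combinations(MODALITIES, r)]

for present in subsets:
    # delta_m = 1 for available modalities and 0 otherwise; the zeroed tokens
    # emulate the missing modalities at inference (Sect. 2.2).
    delta = {m: float(m in present) for m in MODALITIES}
    # ... run mmFormer with these indicators and record the DSC per subset ...
```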

Fig. 2. Segmentation results of mmFormer with various available modalities.

We also compare mmFormer with ACN [21]. ACN relies on knowledge distillation for incomplete multimodal brain tumor segmentation. With N modalities in total, ACN has to be trained \(2^N-2\) times to distill \(2^N-2\) student models covering all conditions of missing modalities, while our mmFormer is learned only once as a unified model. Specifically, ACN is trained for 672 h and requires 144M parameters for 1 teacher and 14 student models, whereas mmFormer requires only 25 h and 106M parameters. Nevertheless, the average DSC of mmFormer for enhancing tumor, tumor core, and whole tumor (59.85, 72.97, and 82.94, respectively) is still close to that of ACN (61.21, 77.62, and 85.92, respectively).

Performance of Complete Multimodal Segmentation. We compare our method with a recent Transformer-based method, i.e., TransBTS [20], for multimodal brain tumor segmentation with full modalities. We reproduce the results with the official repository. TransBTS obtains DSC of 72.66%, 72.69%, and 79.99% on enhancing tumor, tumor core, and whole tumor, respectively. Our mmFormer outperforms TransBTS on all tumor subregions with DSC of 77.61%, 85.78%, and 89.64%, demonstrating the effectiveness of mmFormer even for complete multimodal brain tumor segmentation.

Ablation Study. We investigate the effectiveness of the intra-modal Transformer, the inter-modal Transformer, and the auxiliary regularizer as three critical components of our method. We analyze the effectiveness of each component by excluding it from mmFormer. In Table 3, we compare the performance of the three resulting variants with that of mmFormer in DSC, averaged over the 15 possible combinations of input modalities. The results show that the intra-modal Transformer, the inter-modal Transformer, and the auxiliary regularizer each bring performance improvements across all tumor subregions.

Table 2. Average improvements of mmFormer upon HeMIS [8] and U-HVED [4] with different numbers of missing modalities evaluated by DSC [%].
Table 3. Ablation study of critical components of mmFormer.

4 Conclusion

We proposed a Transformer-based method for incomplete multimodal learning of brain tumor segmentation. The proposed mmFormer bridges the Transformer and CNN to build long-range dependencies both within and across different MRI modalities for a modality-invariant representation. We validated our method on brain tumor segmentation under various combinations of missing modalities, and it outperformed state-of-the-art methods on the BraTS benchmark. Our method gains larger improvements when more modalities are missing and/or the target regions are more difficult to segment.