1 Introduction

Fig. 1. The framework of the semi-supervised uncertainty-aware mean teacher transformer network for medical image segmentation

Medical image semantic segmentation is an essential computer vision task with a wide range of applications including robotic surgery, clinical diagnosis, and image alignment. The goal of image semantic segmentation is to classify each pixel of an input image as belonging to a Region Of Interest (ROI) or to the background. Various deep-learning-based methods have been widely studied in the medical imaging community. The encoder-decoder style of Convolutional Neural Network (CNN), e.g. U-Net [14], is one of the most commonly used segmentation techniques, and many researchers have combined 3D convolution, atrous convolution, residual learning, and attention mechanisms with U-Net for a wide range of medical imaging tasks, resulting in a family of U-Net variants such as 3D UNet, ResUNet, DenseUNet, and Attention-UNet for MRI, ultrasound, and CT segmentation [3, 5, 11, 20, 21]. Three main concerns are yet to be further studied: a) the success of deep learning methods relies on a large amount of high-quality annotated data, which is costly, time-consuming, and difficult to access, especially in the clinical domain; b) semantic feature information cannot be sufficiently condensed and transferred through traditional deep CNN layers or down/up-sampling operations; c) the limited receptive fields of CNNs cannot model long-range feature information. To tackle these challenges, Transformers [18], which use a pure self-attention architecture to model long-range dependencies in natural language processing without CNNs, are currently being studied in the computer vision community. In a similar vein, we propose a semi-supervised ViT network with an uncertainty estimation scheme for medical image semantic segmentation.

We first present a semi-supervised framework that effectively leverages unlabeled data by encouraging consistent predictions for the same input under different perturbations. Following Mean Teacher [17], which overcomes the limitations of Temporal Ensembling [7], the framework consists of a student model and a teacher model, where the student model updates its parameters via gradient descent and the teacher model is updated as an exponential moving average of the student weights. The whole training process minimizes the supervised segmentation loss between the student's machine segmentation (MS) and the ground truth (GT), and the semi-supervised consistency loss between the teacher's MS and the student's MS. Secondly, inspired by uncertainty estimation [6, 23], we utilize Monte Carlo Dropout [6] to estimate the uncertainty with cross-entropy, thus enabling the student-teacher pair to gradually learn from properly filtered, reliable, and valuable feature information. Then, to tackle the loss of semantic feature information transferred through multiple CNN layers and pooling operations, we introduce a pure self-attention-based ViT [4] as the semantic segmentation backbone. The segmentation performance benefits from a context model originating in Natural Language Processing [18], which is also helpful in computer vision, especially in pixel-level classification tasks [8]. Finally, the evaluation results demonstrate our method's promising performance against other state-of-the-art semi-supervised methods. Ablation studies covering the proposed ViT against different CNN-based backbones, several approaches to filtering the uncertainty map, and different assumed ratios of labeled data provided for training are also explored.

2 Methodology

In the task of semi-supervised learning, \(\mathbf {L}\), \(\mathbf {U}\), \(\mathbf {T}\) denote the labeled training set, the unlabeled training set, and the testing set, respectively. We denote a batch of labeled data as \((\mathbf {X}, \mathbf {Y}_\mathrm{gt}) \in \mathbf {L}, (\mathbf {X}, \mathbf {Y}_\mathrm{gt}) \in \mathbf {T}\), and a batch of raw data only as \((\mathbf {X})\in \mathbf {U}\) in the unlabeled set, where \(\mathbf {X} \in \mathbb {R}^ {h {\times } w} \) represents a 2D image. \(\mathbf {Y}_\mathrm{t}, \mathbf {Y}_\mathrm{s}\) are the dense maps predicted by the teacher ViT \(f_\mathrm{t}: \mathbf {X}\mapsto {\mathbf {Y}_\mathrm{t}}\) and the student ViT \( f_\mathrm{s}: \mathbf {X}\mapsto {\mathbf {Y}_\mathrm{s}}\), respectively. \(\mathcal {L}_\mathrm{s}:(\mathbf {Y}_\mathrm{s}, \mathbf {Y}_\mathrm{gt})\mapsto {\mathbb {R}}\) and \(\mathcal {L}_\mathrm{c}:(\mathbf {Y}_\mathrm{s}, \mathbf {Y}_\mathrm{t})\mapsto {\mathbb {R}}\) represent the supervised segmentation loss and the semi-supervised consistency loss. In general, training updates the parameters of the student ViT \( f_\mathrm{s}\) to minimize the combined loss \(\mathcal {L}\), which is detailed in Eq. 1. An Exponential Moving Average (EMA) [17] is utilized to update the parameters of the teacher ViT \(f_\mathrm{t}\) from the student ViT \( f_\mathrm{s}\) in each training iteration. An uncertainty estimation scheme is applied in \(\mathcal {L}_\mathrm{c}\) that enables \(f_\mathrm{t}\) to properly guide the training of \(f_\mathrm{s}\) with only the certain part of its inference. The proposed framework is sketched in Fig. 1. Details of the framework, including the semi-supervised mean teacher with the uncertainty estimation scheme and the segmentation ViT, are discussed in Sects. 2.1 and 2.2.

2.1 Semi-supervised Learning Framework

Inspired by temporal ensembling [7], mean teachers [17], and uncertainty-aware self-ensembling [23], we propose a semi-supervised mean teacher framework with an uncertainty estimation scheme for medical image semantic segmentation. The framework is designed to effectively leverage unlabeled data by encouraging consistent predictions under different perturbations. In each training iteration, the student ViT \( f_\mathrm{s}\) is updated with gradient descent to minimize the combined loss \(\mathcal {L} = \mathcal {L}_\mathrm{s} + \lambda \mathcal {L}_\mathrm{c}\), which is detailed in Eq. 1. The weight \(\lambda \) for \(\mathcal {L}_\mathrm{c}\) is calculated with a consistency ramp-up schedule, because this enables both \( f_\mathrm{s}\) and \(f_\mathrm{t}\) to properly make consistent predictions and allows the whole framework to put progressively more focus on the unlabeled data [7] during training. At the end of each training iteration, EMA is utilized to update the parameters of \(f_\mathrm{t}\); a series of studies [17] has shown that the prediction of \(f_\mathrm{t}\) is more likely to be correct than that of \(f_\mathrm{s}\).
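To make the update rule concrete, the following is a minimal PyTorch sketch of the EMA teacher update and a sigmoid-shaped consistency ramp-up in the spirit of [7, 17]; the decay `alpha`, the ramp-up length, and `max_weight` are assumed values not specified in the text.

```python
import math
import torch

@torch.no_grad()
def ema_update(teacher, student, alpha=0.99):
    # theta_t <- alpha * theta_t + (1 - alpha) * theta_s after each student
    # gradient step; the teacher is never updated by backpropagation.
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(alpha).add_(p_s, alpha=1.0 - alpha)

def consistency_weight(step, rampup_steps, max_weight=0.1):
    # Sigmoid-shaped ramp-up of lambda as in [7]: the consistency term is
    # weak early in training and grows toward max_weight.
    phase = 1.0 - min(step, rampup_steps) / rampup_steps
    return max_weight * math.exp(-5.0 * phase ** 2)
```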

To further improve semi-supervised performance, \(f_\mathrm{t}\) should guide \( f_\mathrm{s}\) to learn feature information via the semi-supervised consistency loss \(\mathcal {L}_\mathrm{c}\); that is, only regions with confident and reliable inference should be used to calculate \(\mathcal {L}_\mathrm{c}\). We hereby propose an uncertainty-aware scheme such that \(f_\mathrm{s}\) is optimized with \(\mathcal {L}_\mathrm{c}\) only on confidently and reliably inferred regions. Uncertainty estimation for the inference of each pixel, and the approach for filtering certain/uncertain inferences, are introduced as follows. Uncertainty estimation is based on Monte Carlo Dropout [6] applied to \(f_\mathrm{t}\), using 8 stochastic forward passes with dropout and Gaussian noise added to the input. In the semantic segmentation task, each pixel is classified with a probability \(\mathbf {p}\) of belonging to the ROI, calculated as \(\mathbf {p} = \frac{1}{T} \sum _{t} \mathbf {p}'_{t} \) when dropout is utilized, where \(\mathbf {p}'_{t}\) is the probability from the t-th stochastic pass. The cross-entropy of the prediction is selected as the metric \(\mathbf {U}\) to estimate the uncertainty of targets [23], calculated as \(\mathbf {U} = -\sum \mathbf {p}\log \mathbf {p}\). Therefore, only the regions of reliable targets provided by \(f_\mathrm{t}\) (including both ROI and background), filtered by a threshold \(\tau \), are used to train \(f_\mathrm{s}\) with the consistency semi-supervision loss \(\mathcal {L}_\mathrm{c}\), which is detailed in Eq. 2. The supervised segmentation loss \(\mathcal {L}_\mathrm{s}\) is detailed in Eq. 3.

$$\begin{aligned} \mathcal {L} = \alpha \mathcal {L}_\mathrm{s}(f_\mathrm{s}(\mathbf {X}),\mathbf {Y}_\mathrm{gt}) + \lambda \mathcal {L}_\mathrm{c}(f_\mathrm{t}(\mathbf {X}),f_\mathrm{s}(\mathbf {X})) \end{aligned}$$
(1)
$$\begin{aligned} \mathcal {L}_\mathrm{c}(f_\mathrm{t}(\mathbf {X}),f_\mathrm{s}(\mathbf {X})) = \frac{\Vert \mathcal {I} (\mathbf {U}< \tau ) \odot (f_\mathrm{t}(\mathbf {X})- f_\mathrm{s}(\mathbf {X}))^2\Vert _1}{2\Vert \mathcal {I} (\mathbf {U} < \tau )\Vert _1+\epsilon } \end{aligned}$$
(2)
$$\begin{aligned} \mathcal {L}_{s}(f_\mathrm{s}(\mathbf {X}),\mathbf {Y}_\mathrm{gt}) = \frac{1}{2}( \mathrm{CrossEntropy}(f_\mathrm{s}(\mathbf {X}),\mathbf {Y}_\mathrm{gt}) + \mathrm{Dice}(f_\mathrm{s}(\mathbf {X}),\mathbf {Y}_\mathrm{gt}) ) \end{aligned}$$
(3)

where \(\epsilon =10^{-6}\) and \(\tau \) is the uncertainty threshold, which is modified in each training iteration with a ramp-up approach. In this way, less data is removed as training progresses, enabling the student model to gradually learn from certain to less certain feature information. \(\lambda \) is the weighting factor for \(\mathcal {L}_\mathrm{c}\), which is also modified in each training iteration so that the whole framework moves its focus from minimizing \(\mathcal {L}_\mathrm{s}\) to minimizing \(\mathcal {L}_\mathrm{c}\) over the course of training [23].
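The following is a minimal PyTorch sketch of Eqs. 1–3 and the Monte Carlo Dropout uncertainty estimate, assuming a single-channel (binary ROI/background) output with \(\alpha = 1\); the Gaussian noise level `noise_std` is an assumed value.

```python
import torch
import torch.nn.functional as F

def mc_uncertainty(teacher, x, T=8, noise_std=0.1):
    # T stochastic forward passes with dropout active and Gaussian noise
    # added to the input; p is the mean ROI probability and U is the
    # pixel-wise entropy U = -sum p log p.
    teacher.train()  # keep dropout layers active
    with torch.no_grad():
        p = torch.stack([
            torch.sigmoid(teacher(x + noise_std * torch.randn_like(x)))
            for _ in range(T)
        ]).mean(dim=0)
    u = -(p * torch.log(p + 1e-6) + (1 - p) * torch.log(1 - p + 1e-6))
    return p, u

def consistency_loss(p_t, p_s, u, tau, eps=1e-6):
    # Eq. 2: squared error restricted to pixels whose teacher
    # uncertainty is below the threshold tau.
    mask = (u < tau).float()
    return (mask * (p_t - p_s) ** 2).sum() / (2.0 * mask.sum() + eps)

def supervised_loss(logits, y):
    # Eq. 3: average of cross-entropy and Dice losses on labeled data.
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, y)
    dice = 1.0 - (2.0 * (p * y).sum() + 1e-6) / (p.sum() + y.sum() + 1e-6)
    return 0.5 * (ce + dice)
```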

2.2 Segmentation Transformer

Semantic feature information is essential in semantic segmentation. The image features, however, become blurred after multiple layers of CNN encoding. In U-Net, copy-and-crop connections between the encoder and decoder allow sufficient semantic feature information to be transferred past the CNN layers, which accounts for its dominant position in segmentation [14]. The boundary of the ROI, especially edge-response information, can be lost after CNN and pooling layers, which is harmful to performance [25]. In this section, we introduce a pure self-attention-based vision transformer without CNNs for semantic segmentation, aiming to achieve sufficient global image context modeling. The model is inspired by Transformer [18], Vision Transformer [4], DETR [2], and Segmenter [16]. The settings of the ViT encoder and the ViT mask decoder are discussed in this section, and the technical hyper-parameter settings are detailed in Sect. 3.2.

As shown in Fig. 1, a sequence of patches \(\mathbf {X}' = [x'_{1}\cdots x'_{N}]^\top \in \mathbb {R}^ {N \times P^{2}}\) is extracted from a medical image \(\mathbf {X} \in \mathbb {R}^{h {\times } w}\), where P is the patch size and \(N = \frac{h\times w}{P^{2}}\) is the number of patches per input image. Each patch is flattened into a 1D vector and projected with a linear patch embedding \(\mathbf {E} \in \mathbb {R}^{D \times P^{2}}\), yielding \(\mathbf {X}_0 = [\mathbf {E}x'_{1}\cdots \mathbf {E}x'_{N}]^\top \in \mathbb {R}^{N \times D}\). Positional embeddings \(pos = [pos_{1}\cdots pos_{N}]^\top \in \mathbb {R}^{N \times D}\) are added to retain positional information, and the final input sequence of tokens for the encoder is \(\mathbf {Z}_{0} = \mathbf {X}_{0} + pos\). Each transformer encoder layer consists of a multi-headed self-attention (\(\mathrm{MSA}\)) block followed by a point-wise \(\mathrm{MLP}\) block of two layers. Residual connections and layer normalization (\(\mathrm{LN}\)) are applied in each block. The \(\mathrm{MSA}\) and \(\mathrm{MLP}\) blocks for feature learning are given in Eqs. 4 and 5, where \(i \in \{1,\dots ,L\} \) and L is the number of encoder layers. The self-attention mechanism is composed of three point-wise linear layers mapping tokens to intermediate representations: queries \(\mathbf {Q}\), keys \(\mathbf {K}\), and values \(\mathbf {V}\), as introduced in Eq. 6. In this way, the transformer encoder maps the input sequence \(\mathbf {Z}_0=[z_{0,1}\cdots {z}_{0,N}]\) with positions to \(\mathbf {Z}_L=[z_{L,1}\cdots z_{L,N}]\). All these settings follow [4]. Thus, much richer semantic feature information is fully exploited in the encoder.

$$\begin{aligned} \mathbf {A}_{i-1} = \mathrm{MSA}(\mathrm{LN} (\mathbf {Z}_{i-1})) + \mathbf {Z}_{i-1} \end{aligned}$$
(4)
$$\begin{aligned} \mathbf {Z}_{i} = \mathrm{MLP}(\mathrm{LN}(\mathbf {A}_{i-1})) + \mathbf {A}_{i-1} \end{aligned}$$
(5)

where MSA is calculated by:

$$\begin{aligned} \mathrm{MSA}(\mathbf {Z}') = \mathrm{softmax}\left( \frac{\mathbf {Q}\mathbf {K}^\top }{\sqrt{D}}\right) \mathbf {V}, \end{aligned}$$
(6)

and the \(\mathbf {Q},\mathbf {K},\mathbf {V}\) are given by:

$$\begin{aligned} \mathbf {Q}=\mathrm{Linear}_\mathrm{Q}(\mathbf {Z}'), \mathbf {K}=\mathrm{Linear}_\mathrm{K}(\mathbf {Z}'), \mathbf {V}=\mathrm{Linear}_\mathrm{V}(\mathbf {Z}') \end{aligned}$$
(7)
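
As a concrete illustration, here is a minimal single-head PyTorch sketch of patch extraction and one encoder layer implementing Eqs. 4–7; the actual model uses 6 attention heads and L = 12 layers (Sect. 3.2), and `mlp_dim` is an assumed hidden width.

```python
import torch
import torch.nn as nn

def patchify(x, p):
    # Split a batch of images (B, 1, H, W) into N = HW / p^2 flattened
    # patches of shape (B, N, p * p).
    b = x.shape[0]
    return x.unfold(2, p, p).unfold(3, p, p).reshape(b, -1, p * p)

class EncoderBlock(nn.Module):
    # One pre-norm transformer layer: Eq. 4 (self-attention) followed by
    # Eq. 5 (two-layer MLP); single-head here for brevity.
    def __init__(self, dim, mlp_dim):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.GELU(),
                                 nn.Linear(mlp_dim, dim))

    def forward(self, z):                                   # z: (B, N, D)
        h = self.ln1(z)
        q, k, v = self.q(h), self.k(h), self.v(h)           # Eq. 7
        attn = torch.softmax(q @ k.transpose(-2, -1) / h.shape[-1] ** 0.5,
                             dim=-1)                        # Eq. 6
        a = attn @ v + z                                    # Eq. 4
        return self.mlp(self.ln2(a)) + a                    # Eq. 5
```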

The sequence \(\mathbf {Z}_{L}\) is then decoded to a dense map \(\mathbf {S} \in \mathbb {R}^{h {\times } w {\times } k}\) as the segmentation result via a transformer mask decoder, where k is the number of classes. The decoder maps the patch-level encodings from the encoder to pixel-level probabilities of the dense map through upsampling [16]. The learnable class embedding \( cls \) is processed together with \(\mathbf {Z}_{L}\) in the mask decoder, which has the same architecture as the transformer encoder with M layers. The output patch sequence is then reshaped to a 2D mask and bilinearly upsampled to the original image size as the prediction result. In the transformer mask decoder, the class embeddings and the patch sequence are jointly processed, and the semantic segmentation mask is finally inferred.
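
A sketch of a Segmenter-style [16] mask decoder is shown below, reusing `EncoderBlock` from the previous sketch; obtaining the masks as patch-class scalar products follows [16] and is an assumption about the exact implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskDecoder(nn.Module):
    # Segmenter-style [16] mask decoder: k learnable class tokens are
    # processed jointly with the patch tokens by M encoder-style layers;
    # masks come from patch-class scalar products, reshaped and upsampled.
    def __init__(self, dim, mlp_dim, n_cls, M=2):
        super().__init__()
        self.n_cls = n_cls
        self.cls_emb = nn.Parameter(torch.randn(1, n_cls, dim))
        self.layers = nn.ModuleList(EncoderBlock(dim, mlp_dim)
                                    for _ in range(M))

    def forward(self, z_l, grid_hw, out_hw):                # z_l: (B, N, D)
        b = z_l.shape[0]
        z = torch.cat([z_l, self.cls_emb.expand(b, -1, -1)], dim=1)
        for layer in self.layers:
            z = layer(z)
        patches, cls = z[:, :-self.n_cls], z[:, -self.n_cls:]
        masks = patches @ cls.transpose(1, 2)               # (B, N, k)
        h, w = grid_hw                                      # N = h * w patches
        masks = masks.transpose(1, 2).reshape(b, self.n_cls, h, w)
        return F.interpolate(masks, size=out_hw, mode="bilinear",
                             align_corners=False)
```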

3 Experiments

3.1 Datasets

In this experiment, an MRI cardiac segmentation dataset is selected from the Automated Cardiac Diagnosis MICCAI Challenge 2017 [1]. It consists of 100 patients divided into 5 evenly distributed subgroups: normal, myocardial infarction, dilated cardiomyopathy, hypertrophic cardiomyopathy, and abnormal right ventricle. We use 44,025 \(232{\times }256\) images from the 100 patients. All images are resized to \(256{\times }256\). 20% of the images are selected as the testing set, and the rest are used for training. The assumed ratio of labeled data to the training set is 10% for the direct comparison experiments with similarity and difference measures against other semi-supervised methods and other segmentation backbones, as well as for the ablation studies, and 1%, 2%, 3%, 5%, 10%, 15%, and 20% for the direct IOU comparison against other semi-supervised methods.

3.2 Training Details

Our code was developed under Ubuntu 20.04 in Python 3.8.8 using PyTorch 1.10 [12] and CUDA 11.3, on four Nvidia GeForce RTX 3090 GPUs with 24 GB memory each and an Intel Core i9-10900K CPU. All baseline algorithms are taken directly from [10], and the segmentation ViT is based on [16], built from [15] and the TIMM library [22]. The runtime averaged around 3.5 h, including data transfer, model training, inference, and evaluation. All semi-supervised methods are trained with the same settings: training for 30,000 iterations followed by direct testing, a batch size of 24, and the SGD optimizer with an initial learning rate of 0.01, momentum of 0.9, and weight decay of 0.0001. After multiple experiments, we arrived at a hyper-parameter setting for the segmentation ViT that achieves the best results with limited computational resources (6 GB of GPU memory): a patch size of \(16 {\times } 16\), 6 multi-attention heads, L = 12 encoder layers, the same normalization method as Transformer [18], and M = 2 decoder layers.
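
Putting the pieces together, the following sketch shows one training iteration under the settings above, reusing the functions from the earlier sketches; `student`, `teacher`, `loader`, `tau_schedule`, and the split of each batch into labeled and unlabeled parts are assumed details not fixed by the text.

```python
import torch

optimizer = torch.optim.SGD(student.parameters(), lr=0.01,
                            momentum=0.9, weight_decay=1e-4)

for step, (x_l, y_l, x_u) in enumerate(loader):    # labeled + unlabeled batch
    x = torch.cat([x_l, x_u])
    logits = student(x)
    p_t, u = mc_uncertainty(teacher, x)            # teacher probabilities + U
    # Eq. 1 with alpha = 1: supervised loss on labeled images plus
    # ramped-up, uncertainty-masked consistency loss on the whole batch.
    loss = (supervised_loss(logits[:x_l.size(0)], y_l)
            + consistency_weight(step, rampup_steps=10_000)
            * consistency_loss(p_t, torch.sigmoid(logits), u,
                               tau_schedule(step)))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_update(teacher, student)                   # EMA teacher update (Sect. 2.1)
```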

3.3 Evaluation

Our proposed semi-supervised method is compared with mean teachers [17], deep adversarial networks [24], adversarial entropy minimization for domain adaptation [19], the uncertainty-aware self-ensembling model [23], and deep co-training [13] as semi-supervised baselines, all with U-Net [14] as the backbone. The direct comparison experiments are conducted with a variety of evaluation metrics, including similarity measures (the higher, the better): Dice, IOU, Accuracy, Precision, Recall/Sensitivity, and Specificity. We also investigate difference measures (the lower, the better): Relative Volume Difference (RVD), Hausdorff Distance (HD), and Average Symmetric Surface Distance (ASSD).
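
For reference, here is a minimal sketch of two of the similarity measures computed on binary masks (the remaining metrics follow the same pattern); the smoothing constant `eps` is an assumed detail.

```python
import numpy as np

def dice_iou(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-6):
    # pred, gt: boolean ROI masks of equal shape.
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    dice = (2 * inter + eps) / (pred.sum() + gt.sum() + eps)
    iou = (inter + eps) / (union + eps)
    return dice, iou
```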

Fig. 2. Example raw images and inference results on the testing set (Color figure online)

3.4 Results

Figure 2 illustrates some example raw images and MS against GT, where three colors and black represent True Positive, False Positive, False Negative, and True Negative pixels, respectively (the color coding is shown in the figure). Example raw images with uncertainty maps, and the masks of certain regions at three different training stages, are illustrated in the Appendix. The best results are shown in bold, and quantitative results are detailed in Table 1 and Table 2. The evaluation results demonstrate the proposed method's promising performance against other semi-supervised methods. Figure 3 gives a systematic overview of how the IOU varies when 1%, 2%, 3%, 5%, 10%, 15%, and 20% of the training set is labeled. More detailed quantitative analysis for the different assumed ratios of labeled data is given in the Appendix.

Table 1. Direct comparison with similarity measures on the cardiac MRI testing set (the higher, the better)
Table 2. Direct comparison with difference measures on the cardiac MRI testing set (the lower, the better)
Fig. 3. The IOU performance on the test set with different ratios of labeled/total training data

3.5 Ablation Study

In order to analyze the effects of each proposed contribution and their combinations, extensive ablation experiments have been conducted. Table 3 marks the use of the mandatory mean teacher for semi-supervision, demonstrating how the removal of uncertainty estimation compromises the overall performance. The model is tested with U-Net [14], E-Net [12], and the proposed segmentation ViT as backbones. Further experiments under the assumption of fully supervised learning, marked accordingly in Table 3, are also conducted. Our proposed ViT with the uncertainty estimation scheme shows promising performance, especially in IOU and Sensitivity, in both the semi-supervised and the fully supervised setting. Extended experiments on the setting of the threshold \(\tau \) and the weight \(\lambda \) of \(\mathcal {L}_\mathrm{c}\) during training are illustrated in the Appendix.

Table 3. Ablation studies on contributions of architecture and modules (the higher, the better)

4 Conclusion

Our semi-supervised uncertainty-aware segmentation framework successfully uses a ViT for medical image semantic segmentation via a mean teacher framework. Experimental results on the public MRI dataset demonstrate our method's promising performance compared against existing supervised and semi-supervised methods. In the future, multi-task learning and multi-view learning, which can potentially improve semi-supervised learning performance, will be further studied.

Table 4. The IOU results under different assumed ratios of labeled/total data on the MRI cardiac test set (the higher, the better)