
1 Introduction

Manually labeling sufficient medical data with pixel-level accuracy is time-consuming, expensive, and often requires domain-specific knowledge. To reduce the cost of labeled data, semi-supervised learning (SSL) is a promising way to train models with weaker forms of supervision given a large amount of unlabeled data. Existing SSL methods include adversarial training [12, 28, 32, 33, 37], deep co-training [21, 38], mean teacher schemes [23, 36], multi-task learning [4, 11, 16, 31], and contrastive learning [3, 9, 29, 30, 34, 35].

Fig. 1. Examples of two benchmarks (i.e., ACDC and LiTS) showing the large variation in class distribution.

Among the aforementioned methods, contrastive learning [5, 8] has recently prevailed as a way for DNNs to learn rich visual representations from unlabeled data. The predominant promise of label-free learning is to capture the similar semantic relationships and anatomical structure between neighboring pixels from massive unannotated data. However, moving to realistic clinical scenarios exposes the following shortcomings. First, different medical images share similar anatomical structures, but prior methods follow standard contrastive learning [5, 8] in comparing positive and negative pairs with binary supervision. That naturally leads to false negatives in representation learning [10, 24], which hurts segmentation performance. Second, the underlying class distribution of medical image data is highly imbalanced, as illustrated in Fig. 1. It is well known that such an imbalanced distribution severely hurts segmentation quality [14], and may result in blurry contours and misclassification of minority classes due to their low occurrence frequencies [39]. That naturally raises the question of whether contrastive learning can still work well in such imbalanced scenarios.

In this work, we present a principled framework called Anatomical-aware ConTrastive dIstillatiON (ACTION) for multi-class medical image segmentation. In contrast to prior work [3, 9, 35], which directly treats two image samples with similar anatomical features as a negative pair, the key innovation in ACTION is to actively learn more balanced representations by dynamically selecting samples that are semantically similar to the queries, and contrasting the model’s anatomical-level features with the target model’s in imbalanced and unlabeled clinical scenarios. Specifically, we introduce two strategies to improve overall segmentation quality: (1) we believe that not all negative samples are equally negative, and thus propose relaxed contrastive learning using soft labeling on the negatives. In other words, we randomly sample a set of images as anchor points to ensure diversity among the sampled examples. The teacher model then predicts the underlying probability distribution over neighboring samples by computing anatomical similarities between the query and the anchor points in the memory bank, and the student model tries to learn from the teacher model. Such a strategy regularizes the student by mimicking the teacher’s neighborhood anatomical similarity, improving the quality of the anatomical features; (2) to create strong contrastive views on anatomical features, we introduce AnCo, a new contrastive loss defined at the anatomical level, which samples a set of pixel-level representations as queries, pulls them closer to the mean feature of all representations in their class (positive keys), and pushes them apart from representations of other classes (negative keys). To reduce the high memory footprint and computational complexity, we use active sampling to dynamically select a sparse set of queries and keys during training. We apply ACTION to two benchmark datasets under different unlabeled settings. Our experiments show that ACTION dramatically outperforms state-of-the-art SSL methods. We believe that ACTION can serve as a strong baseline for related medical image analysis tasks in the future.

Fig. 2. Overview of the ACTION framework, which includes three stages: (1) global contrastive distillation pre-training, as used in existing works, (2) our proposed local contrastive distillation pre-training, and (3) our proposed anatomical contrast fine-tuning.

2 Method

Framework Overview. The workflow of our proposed ACTION is illustrated in Fig. 2. By default, ACTION is built on the BYOL pipeline [7], which was originally designed for image classification tasks; for a fair comparison, we also follow the settings in [3], such as using 2D U-Net [22] as the backbone and non-linear projection heads H. The main differences between our proposed ACTION and [3, 9] are as follows: (1) the addition of a predictor \(g(\cdot )\) to the student network to avoid collapsed solutions; (2) the use of a slow-moving average of the student network as the teacher network for more semantically compact representations; (3) the use of output probabilities rather than logits, which effectively and semantically constrains the distance between anatomical features from imbalanced data (i.e., multi-class label imbalance cases); (4) contrasting the query image features with other random image features at the global and local level, rather than with only two augmented versions of the same image features; and (5) a novel unsupervised anatomical contrastive loss that provides additional supervision on hard pixels.

Let (X, Y) be a training dataset including N labeled image slices and M unlabeled image slices, with training images \(X=\{x_i\}_{i=1}^{N+M}\) and C-class segmentation labels \(Y=\{y_i\}_{i=1}^{N}\). Our backbone \(F(\cdot )\) (2D U-Net) consists of an encoder network \(E(\cdot )\) and a decoder network \(D(\cdot )\). The training procedure of ACTION includes three stages: (i) global contrastive distillation pre-training, (ii) local contrastive distillation pre-training, and (iii) anatomical contrast fine-tuning. In the first two stages, we use global contrastive distillation to train E on unlabeled data to learn global-level features, and local contrastive distillation to train E and D on labeled and unlabeled data to learn local-level features.
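To make the notation concrete, below is a minimal PyTorch sketch (not the authors' released code) of how the backbone \(F(\cdot )=D(E(\cdot ))\) and a global projection head could be wired; the module names, global average pooling, and MLP layout are illustrative assumptions, with the 512-dimensional embedding size taken from the implementation details in Sect. 3.

```python
# Illustrative sketch only: a 2D U-Net backbone F = D(E(.)) with a global
# projection head H^g on top of the encoder. Names and shapes are assumptions.
import torch
import torch.nn as nn

class ProjectionHead(nn.Module):
    """Non-linear MLP head mapping pooled features to a 512-d embedding."""
    def __init__(self, in_dim, hidden_dim=512, out_dim=512):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        # x: (B, C) pooled encoder features
        return self.mlp(x)

class SegmentationModel(nn.Module):
    """Backbone F(.) = D(E(.)); the projection head is used only in pre-training."""
    def __init__(self, encoder, decoder, feat_dim):
        super().__init__()
        self.encoder = encoder                  # E(.)
        self.decoder = decoder                  # D(.)
        self.proj_g = ProjectionHead(feat_dim)  # H^g(.)

    def forward(self, x):
        return self.decoder(self.encoder(x))    # segmentation logits

    def global_embed(self, x):
        feat = self.encoder(x)                   # (B, C, h', w')
        pooled = feat.mean(dim=(2, 3))           # global average pooling
        return self.proj_g(pooled)               # (B, 512)
```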

Global Contrastive Distillation Pre-training. We follow a setting similar to [24]. Given an input query image \(q \in \{x_i\}_{i=N+1}^{N+M}\) with spatial size \(h\times w\), we first apply two different augmentations to obtain \(q_{t}\) and \(q_{s}\), and randomly sample a set of augmented images \(\{x_j\}_{j=1}^{n}\) from the unlabeled image slices \(\{x_i\}_{i=N+1}^{N+M}\). We believe that such relaxation enables the model to capture richer semantic relationships and anatomical features from neighboring images instead of learning only from a different version of the same query image. We then feed \(\{x_j\}_{j=1}^{n}\) to the teacher encoder \(E_t\) followed by the non-linear projection head \(H_{t}^{g}\) to generate their projection embeddings \(\{H_{t}^{g}(E_{t}(x_j))\}_{j=1}^{n}\) as anchor points, and also feed \(q_{t}\) and \(q_{s}\) to the teacher and student (i.e., E and H), creating \(z_t = H_{t}^{g}(E_{t}(q_{t}))\) and \(z_s=H_{s}^{g}(E_{s}(q_{s}))\). Here we utilize the probabilities after SoftMax instead of the feature embeddings:

$$\begin{aligned} p_t(j) = -\text {log} \frac{\text {exp}\big (\text {sim}\big (z_{t}, a_{j}\big )/\tau _t\big )}{\sum _{i=1}^n \text {exp}\big (\text {sim}\big (z_{t}, a_{i}\big )/\tau _t\big )}, \end{aligned}$$
(1)

where \(\tau _t\) is a temperature hyperparameter of the teacher, and \(\text {sim}(\cdot ,\cdot )\) is the cosine similarity. Then, inspired by [7] and in order to avoid collapsed solutions in the unsupervised scenario, we use a shallow multi-layer perceptron (MLP) predictor \(H_{p}^{g}(\cdot )\) to obtain the prediction \(z_{s}^{*}=H_{p}^{g}(z_{s})\). Of note, \(\{a_{i}\}_{i=1}^{n}\), \(z_{t}\), \(z_{s}\), and \(z_{s}^{*}\) denote the embeddings generated from the set of randomly chosen augmented images, the teacher’s projection embeddings, the student’s projection embeddings, and the student’s prediction embeddings, respectively, in either Stage-i or ii. Therefore, we can calculate the similarity distance between the student’s prediction and the anchor embeddings by converting them to a probability distribution:

$$\begin{aligned} p_s(j) = -\text {log} \frac{\text {exp}\big (\text {sim}\big (z_{s}^{*}, a_{j}\big )/\tau _s\big )}{\sum _{i=1}^n \text {exp}\big (\text {sim}\big (z_{s}^{*}, a_{i}\big )/\tau _s\big )}, \end{aligned}$$
(2)

where \(\tau _s\) refers to a temperature hyperparameter of the student. The unsupervised contrastive loss is computed as follows:

$$\begin{aligned} \mathcal {L}_{\text {contrast}} = \text {KL}(p_t || p_s). \end{aligned}$$
(3)
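For clarity, the following is a hedged sketch of Eqs. 1–3, reading \(p_t\) and \(p_s\) as softened similarity distributions over the n anchor embeddings (following the SoftMax-probability view stated above) and the loss as a KL divergence from the teacher distribution to the student distribution; the batched tensor shapes are assumptions.

```python
# Minimal sketch of the global contrastive distillation loss (Eqs. 1-3).
import torch
import torch.nn.functional as F

def contrastive_distillation_loss(z_t, z_s_pred, anchors, tau_t=0.01, tau_s=0.1):
    """
    z_t:      (B, d) teacher projection embeddings of the query
    z_s_pred: (B, d) student prediction embeddings z_s^* = H_p^g(z_s)
    anchors:  (n, d) projection embeddings {a_i} of the randomly sampled images
    """
    # Cosine similarity = dot product of L2-normalized embeddings
    z_t = F.normalize(z_t, dim=1)
    z_s_pred = F.normalize(z_s_pred, dim=1)
    anchors = F.normalize(anchors, dim=1)

    sim_t = z_t @ anchors.t() / tau_t        # (B, n), teacher similarities
    sim_s = z_s_pred @ anchors.t() / tau_s   # (B, n), student similarities

    p_t = F.softmax(sim_t, dim=1)            # teacher distribution (Eq. 1)
    log_p_s = F.log_softmax(sim_s, dim=1)    # student log-distribution (Eq. 2)

    # KL(p_t || p_s), Eq. 3; the teacher is detached so only the student receives gradients
    return F.kl_div(log_p_s, p_t.detach(), reduction="batchmean")
```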

Local Contrastive Distillation Pre-training. After training the teacher’s and student’s encoders to learn global-level image features, we attach the decoders and tune the entire models to perform pixel-level contrastive learning in a semi-supervised manner. The distinction in training strategy between ours and [9] lies in Stage-ii and iii: [9] only uses labeled data in training, while we use both labeled and unlabeled data. Since the training procedure of Stage-ii is similar to Stage-i, we describe it only briefly here, as illustrated in Fig. 2. For the labeled data, we train our model by minimizing the supervised loss (a linear combination of cross-entropy loss and Dice loss) in Stage-ii and Stage-iii. For the unlabeled input images q and \(\{x_j\}_{j=1}^{n}\), we first apply two different augmentations to q, creating two versions \([q_{t}^{l},q_{s}^{l}]\), and feed them to \(F_{t}\) and \(F_{s}\); their output features \([f_{t},f_{s}]\) are then fed into \(H_{t}^{l}\) and \(H_{s}^{l}\). The student’s projection embedding is subsequently fed into \(H_{p}^{l}\) to obtain the student’s prediction embedding, enforcing similarity between the teacher and the student under the same loss as Eq. 3. We also include the randomly selected images when enforcing this similarity because, intuitively, it may be beneficial to ensure diversity in the set of sampled examples. It is important to note that ACTION re-uses the well-trained weights of the models \(F_t\) and \(F_s\) as initialization for Stage-iii.
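As a concrete illustration of the supervised objective mentioned above, a linear combination of cross-entropy and Dice loss could be written as follows; the equal weighting and the smoothing constant are assumptions for exposition, not values reported in this paper.

```python
# Hedged sketch of the supervised loss (cross-entropy + Dice) used on labeled slices.
import torch
import torch.nn.functional as F

def dice_loss(logits, target, num_classes, eps=1e-5):
    """Soft multi-class Dice loss over a batch of (B, C, H, W) logits."""
    probs = F.softmax(logits, dim=1)
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    inter = (probs * one_hot).sum(dim=(0, 2, 3))
    union = probs.sum(dim=(0, 2, 3)) + one_hot.sum(dim=(0, 2, 3))
    return 1.0 - ((2 * inter + eps) / (union + eps)).mean()

def supervised_loss(logits, target, num_classes, w_dice=0.5):
    """Linear combination of pixel-wise cross-entropy and Dice loss."""
    ce = F.cross_entropy(logits, target)
    return (1 - w_dice) * ce + w_dice * dice_loss(logits, target, num_classes)
```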

Anatomical Contrast Fine-Tuning. Broadly speaking, in medical images the same tissue types may share similar anatomical information across different patients, while different tissue types often show different class, appearance, and spatial distributions, which amounts to a complicated form of imbalance and uncertainty in real clinical data, as shown in Fig. 1. This motivates us to efficiently incorporate more useful features so that the representations can be more balanced and better discriminated in such multi-class label-imbalanced scenarios. Inspired by [15], we propose AnCo, a new unsupervised contrastive loss designed at the anatomical level. Specifically, we attach an additional representation decoder head \(H_{r}\) to the student network, parallel to the segmentation head, to decode the multi-layer hidden features: it first applies multiple up-sampling layers to output dense features with the same spatial resolution as the query image, and then maps them into m-dimensional query, positive key, and negative key embeddings: \(r_q, r_k^{+}, r_k^{-}\). The AnCo loss is then defined as:

$$\begin{aligned} \mathcal {L}_\text {anco} = \sum _{c\in \mathcal {C}} \sum _{r_q \sim \mathcal {R}^c_q} -\log \frac{\exp (r_q \cdot r_k^{c, +} / \tau _{an})}{\exp (r_q \cdot r_k^{c, +}/ \tau _{an}) + \sum _{r_k^{-}\sim \mathcal {R}^c_k} \exp (r_q \cdot r_k^{-}/ \tau _{an})}, \end{aligned}$$
(4)

where \(\mathcal {C}\) is the set of all classes present in a mini-batch, and \(\tau _{an}\) denotes a temperature hyperparameter for the AnCo loss. \(\mathcal {R}_q^c\) and \(r_k^{c, +}\) are the set of query embeddings in class c and the positive key embedding (the mean representation of class c), respectively. \(\mathcal {R}_k^c\) is the set of negative key embeddings that are not in class c. Let \(\mathcal {P}\) be the set of all pixel coordinates at the same resolution as \(x_{i}\); these queries and keys are then defined as:

$$\begin{aligned} \mathcal {R}_q^c\!=\!\!\bigcup _{[m, n]\in \mathcal {P}}\!\!\mathbbm {1}(y_{[m,n]}\!=\!c)\, r_{[m,n]},\, \mathcal {R}_k^c\!=\!\!\bigcup _{[m, n]\in \mathcal {P}}\!\!\mathbbm {1}(y_{[m,n]}\!\ne \!c)\, r_{[m,n]},\, r_k^{c, +}\!\!=\!\frac{1}{| \mathcal {R}_q^c |}\sum _{r_q \in \mathcal {R}_q^c} r_q. \end{aligned}$$
(5)
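A minimal sketch of how Eqs. 4–5 could be computed for one mini-batch is given below: for each class c, queries are the pixel embeddings labeled c, the positive key is their mean, and negative keys are embeddings of other classes. The flattened pixel layout and the fixed sampling caps are illustrative assumptions; the actual query/key selection used in ACTION is the active scheme described next.

```python
# Hedged sketch of the AnCo loss (Eqs. 4-5) with simple random subsampling.
import torch
import torch.nn.functional as F

def anco_loss(reps, labels, tau_an=0.5, n_q=256, n_k=512):
    """
    reps:   (P, m) pixel-level representation embeddings r (P = pixels in the batch)
    labels: (P,)   per-pixel class labels (pseudo-labels for unlabeled slices)
    """
    loss, terms = reps.new_zeros(()), 0
    for c in labels.unique():
        q_mask = labels == c
        k_mask = ~q_mask
        if q_mask.sum() == 0 or k_mask.sum() == 0:
            continue
        # Positive key: mean representation of class c (Eq. 5)
        pos_key = reps[q_mask].mean(dim=0, keepdim=True)                      # (1, m)
        # Subsample queries and negative keys to bound memory
        q_idx = torch.randperm(int(q_mask.sum()), device=reps.device)[:n_q]
        k_idx = torch.randperm(int(k_mask.sum()), device=reps.device)[:n_k]
        queries = reps[q_mask][q_idx]                                          # (Nq, m)
        neg_keys = reps[k_mask][k_idx]                                         # (Nk, m)

        pos_logit = queries @ pos_key.t() / tau_an                             # (Nq, 1)
        neg_logits = queries @ neg_keys.t() / tau_an                           # (Nq, Nk)
        logits = torch.cat([pos_logit, neg_logits], dim=1)
        # InfoNCE with the positive key in column 0 (Eq. 4)
        target = torch.zeros(queries.size(0), dtype=torch.long, device=reps.device)
        loss = loss + F.cross_entropy(logits, target)
        terms += 1
    return loss / max(terms, 1)
```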

In addition, we note that contrastive learning usually benefits from a large collection of positive and negative pairs, but this is usually bounded by the size of GPU memory. Therefore, we introduce two novel active hard sampling methods. To address uncertainty on the most challenging pixels among all available classes (i.e., those with close anatomical or semantic relationships), we non-uniformly sample negative keys based on the relative similarity between the query class and each negative key class. For each mini-batch, we build a graph G to measure the pair-wise class relationships, and dynamically update G:

$$\begin{aligned} G[p, q] = \left( r_k^{p, +} \cdot r_k^{q, +}\right) ,\quad \forall p,q \in \mathcal {C}, \text { and } p\ne q, \end{aligned}$$
(6)

where \(G\in \mathbb {R}^{|\mathcal {C}| \times |\mathcal {C}|}\). Note that the pair-wise relationships alone do not specify how many samples to allocate to each negative class. Thus, to learn a more accurate decision boundary, we first apply a SoftMax function to normalize the pair-wise relationships between each query class c and all negative classes n, yielding a distribution: \(\exp (G[c, v])/ \sum _{n\in \mathcal {C}, n\ne c} \exp (G[c, n])\). We then adaptively sample negative keys from each class v according to this distribution to help learn the corresponding query class c. To alleviate the imbalance issue, we sample hard queries based on a defined confidence threshold, to better discriminate the rare classes. The easy and hard queries are computed as follows:

$$\begin{aligned} \mathcal {R}_q^{c,\, easy} = \bigcup _{r_q \in \mathcal {R}^c_q} \mathbbm {1}(\hat{y}_q > \theta _s)r_q,\quad \mathcal {R}_q^{c,\, hard} = \bigcup _{r_q \in \mathcal {R}^c_q} \mathbbm {1}(\hat{y}_q \le \theta _s)r_q, \end{aligned}$$
(7)

where \(\hat{y}_q\) is the predicted confidence of label c corresponding to \(r_q\) after the SoftMax function, and \(\theta _s\) is the user-defined confidence threshold.
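The following is a hedged sketch of these two sampling steps: building the pair-wise class graph G (Eq. 6), turning each row into a SoftMax sampling distribution over negative classes, and splitting queries into easy and hard sets by the confidence threshold (Eq. 7). The function names and dictionary-based graph are illustrative choices.

```python
# Illustrative sketch of the active hard sampling steps (Eqs. 6-7).
import torch

def build_class_graph(class_means):
    """class_means: dict {class_id: (m,) mean embedding r_k^{c,+}} for one mini-batch."""
    classes = sorted(class_means)
    G = {}
    for p in classes:
        for q in classes:
            if p != q:
                G[(p, q)] = torch.dot(class_means[p], class_means[q])  # Eq. 6
    return G

def negative_sampling_distribution(G, query_class, classes):
    """SoftMax over G[c, n] for all negative classes n != c, giving per-class sampling weights."""
    negs = [n for n in classes if n != query_class]
    logits = torch.stack([G[(query_class, n)] for n in negs])
    return negs, torch.softmax(logits, dim=0)

def split_queries(queries, confidences, theta_s=0.97):
    """Eq. 7: easy queries exceed the confidence threshold, hard queries do not."""
    easy = queries[confidences > theta_s]
    hard = queries[confidences <= theta_s]
    return easy, hard
```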

3 Experiments

Experimental Setup. We experiment on two benchmark datasets: ACDC 2017 dataset [1] and MICCAI 2017 Liver Tumor Segmentation Challenge (LiTS) [2].

The ACDC dataset includes 200 cardiac cine MRI scans from 100 patients with annotations including three segmentation classes (i.e., left ventricle (LV), myocardium (Myo), and right ventricle (RV)). Following [16, 27], we use 140, 20, and 60 scans for training, validation, and testing, respectively.

The LiTS dataset includes 131 contrast-enhanced 3D abdominal CT volumes with annotations of two segmentation classes (i.e., liver and tumor). Following [13], we use the first 100 volumes for training and the remaining 31 for testing. For pre-processing, we follow the setting in [3] to normalize the intensity of each 3D scan and resample all 2D slices and the corresponding segmentation maps to a fixed spatial resolution (i.e., 256 \(\times \) 256 pixels). To quantitatively assess the performance of our proposed method, we report two popular metrics for 3D segmentation results: the Dice coefficient (DSC) and the Average Surface Distance (ASD).

Table 1. Comparison of segmentation performance (DSC[%]/ASD[voxel]) on ACDC under two unlabeled settings (3 or 7 labeled). The best results are indicated in bold.
Fig. 3. Visualization of segmentation results on ACDC with 3 labeled data. As shown, ACTION consistently produces sharper object boundaries and more accurate predictions than all other methods. Different structure categories are shown in different colors. (Color figure online)

Table 2. Comparison of segmentation performance (DSC [%]/ASD [voxel]) on LiTS under two unlabeled settings (5% or 10% labeled ratio). The best results are in bold.
Fig. 4. Visualization of segmentation results on LiTS with 5% labeled ratio. As shown, ACTION achieves consistently sharp and accurate object boundaries compared to other SSL methods. Different structure categories are shown in different colors. (Color figure online)

Implementation Details. All our models are implemented in PyTorch [19]. We train all methods with the SGD optimizer (learning rate = 0.01, momentum = 0.9, weight decay = 0.0001, batch size = 6). All models are trained on two NVIDIA GeForce RTX 3090 GPUs. Stage-i and ii are trained for 100 epochs each, and Stage-iii for 200 epochs. We set the teacher and student temperatures to \(\tau _t= 0.01\) and \(\tau _s=0.1\). The teacher is updated using the rule \( \theta _t \leftarrow m \theta _t + (1-m) \theta _s\), where \(\theta \) refers to the model’s parameters and the momentum hyperparameter m is 0.99. The memory bank size is 36. We follow the standard augmentation strategies in [7]. In Stage-i, we train \(E_{s}\), \(E_{t}\), \(H_{t}^{g}\), \(H_{s}^{g}\), and \(H_{p}^{g}\) on the unlabeled data with the global-level \(\mathcal {L}_{\text {contrast}}\) in Eq. 3. We follow [9] in using MLPs as the heads, and the setting of the predictor is similar to [7], with a feature dimension of 512. In Stage-ii, we train \(F_{s}\), \(F_{t}\), \(H_{t}^{l}\), \(H_{s}^{l}\), and \(H_{p}^{l}\) on the labeled and unlabeled data, using the supervised loss [36] on labeled data and the local-level \(\mathcal {L}_{\text {contrast}}\) in Eq. 3 on unlabeled data. Given the logits output \(\hat{y} \in \mathbb {R}^{C\times h\times w}\), we use a \(1\,\times \,1\) convolutional layer to project all pixels into the latent space with a feature dimension of 512, and the output feature dimension of G is also 512. In Stage-iii, we train \(F_{s}\), \(F_{t}\), \(H_{t}\), \(H_{s}\), and \(H_{r}\) on the labeled and unlabeled data, using the supervised segmentation loss on labeled data, and the unsupervised cross-entropy loss (on pseudo-labels generated by a confidence threshold \(\theta _s\)) and \(\mathcal {L}_{\text {anco}}\) in Eq. 4 on unlabeled data. We adaptively sample 256 query samples and 512 key samples for each mini-batch, and the student temperature and confidence threshold in this stage are set to \(\tau _{s}=0.5\) and \(\theta _s=0.97\), respectively. Of note, the projection heads, the predictor, and the representation decoder head are only utilized during training and are removed during inference.
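As a small illustration of the momentum update rule quoted above, an exponential moving average of the student parameters could be implemented as follows (a sketch, not the authors' code):

```python
# EMA teacher update: theta_t <- m * theta_t + (1 - m) * theta_s, with m = 0.99.
# The teacher receives no gradients; only the student is optimized.
import torch

@torch.no_grad()
def update_teacher(teacher, student, m=0.99):
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.data.mul_(m).add_(p_s.data, alpha=1 - m)
```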

Main Results. We compare our proposed method to previous state-of-the-art SSL methods using 2D U-Net [22] as the backbone, including U-Net trained with full/limited supervision (UNet-F/UNet-L), EM [26], CCT [18], DAN [37], URPC [17], DCT [21], ICT [25], MT [23], UAMT [36], CPS [6], SCS [9], and GCL [3]. Table 1 shows the evaluation results on the ACDC dataset under two unlabeled settings (3 or 7 labeled cases). ACTION substantially improves results in both settings, greatly outperforming the previous state-of-the-art SSL methods. Specifically, ACTION trained on 3 labeled cases dramatically improves the previous best averaged Dice score from 73.6% to 87.5%, and even matches previous SSL methods using 7 labeled cases. When using 7 labeled cases, ACTION further pushes the state-of-the-art result to 89.7% Dice. We observe that the gains are more pronounced on two categories (i.e., RV and Myo), where ACTION achieves 89.8% and 86.7% Dice, performing competitively with or even better than the supervised baseline (89.2% and 86.7%). As shown in Fig. 3, we can see the clear advantage of ACTION: the boundaries of different regions, such as the RV and Myo regions, are clearly sharper and more accurate. Table 2 shows the evaluation results on the LiTS dataset under two unlabeled settings (5% or 10% labeled ratio). Under both labeled settings, ACTION outperforms all state-of-the-art methods by a significant margin. As shown in Fig. 4, ACTION achieves consistently sharp and accurate object boundaries compared to other SSL methods.

Table 3. Ablation on (a) model components: w/o Random Sampled Images (RSI); w/o Local Contrastive Distillation (Stage-ii); w/o Anatomical Contrast Fine-tuning (Stage-iii); and (b) loss formulation: w/o \(\mathcal {L}_\textrm{anco}\); w/o \(\mathcal {L}_\textrm{unsup}\); compared to the Vanilla baseline and our proposed ACTION. Note that \(\mathcal {L}_\textrm{unsup}\) denotes the cross-entropy loss (on pseudo-labels generated by a confidence threshold \(\theta _s\)) used together with \(\mathcal {L}_\textrm{anco}\) in Stage-iii.

Ablation on Different Components. We investigate the impact of different components in ACTION. All reported results in this section are based on the ACDC dataset under the 3-labeled setting. Table 3 shows the ablation results for our model. Given our choice of architecture, we first consider a naïve baseline (BYOL) without any random sampled images (RSI), Stage-ii, or Stage-iii, denoted by (1) Vanilla. Then, we consider a wide range of settings for improved representation learning: (2) incorporating other random sampled images; (3) no Stage-ii; (4) no other random sampled images and no Stage-ii; (5) no Stage-iii; since Stage-iii includes two losses, (6) no \(\mathcal {L}_\textrm{anco}\); (7) no \(\mathcal {L}_\textrm{unsup}\); and (8) our proposed ACTION. As shown in Table 3, ACTION generally performs better than the other evaluated baselines, and removing any single component of ACTION comes at the cost of performance degradation. The intuitions behind this are as follows: (1) incorporating other random sampled images enforces diversity in the sampled data, preventing redundant anatomically and semantically similar samples; (2) using Stage-ii improves performance by taking local context into account; (3) using Stage-iii enables a robust segmentation model to learn better representations with few human annotations. Using all of the above components together confers a significant advantage in representation learning, and the ablation further illustrates the benefit of each component.

Table 4. Ablation on augmentation strategies.

Ablation on Different Augmentations. We investigate the impact of using weak or strong augmentations for ACTION on the ACDC dataset under the 3-labeled setting. We summarize the effects of different data augmentation strategies in Table 4. We apply weak augmentation (rotation, cropping, and flipping) to the teacher’s input, and strong augmentation (rotation, cropping, flipping, random contrast, and brightness changes [20]) to the student’s input. Empirically, we find that performance is optimal when the weak augmentation strategy is applied to the teacher network and the strong augmentation strategy to the student network; an illustrative augmentation pair is sketched below.
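Below is an illustrative weak/strong augmentation pair in the spirit of Table 4, written with torchvision transforms; the exact magnitudes and crop scales are assumptions rather than the paper's settings.

```python
# Illustrative weak (teacher) and strong (student) augmentation pipelines.
import torchvision.transforms as T

weak_aug = T.Compose([
    T.RandomRotation(degrees=20),
    T.RandomResizedCrop(256, scale=(0.8, 1.0)),
    T.RandomHorizontalFlip(),
])

strong_aug = T.Compose([
    T.RandomRotation(degrees=20),
    T.RandomResizedCrop(256, scale=(0.8, 1.0)),
    T.RandomHorizontalFlip(),
    # Additional intensity perturbations for the student input
    T.ColorJitter(brightness=0.4, contrast=0.4),
])
```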

4 Conclusion and Limitations

In this work, we have presented ACTION, a novel anatomical-aware contrastive distillation framework with active sampling, designed specifically for medical image segmentation. Our method is motivated by two observations: not all negative samples are equally negative, and the underlying class distribution of medical image data is highly imbalanced while most data remain unlabeled. Through extensive experiments across two benchmark datasets and multiple unlabeled settings, we show that ACTION significantly improves segmentation performance with minimal additional memory requirements, outperforming the previous state-of-the-art by a large margin. For future work, we plan to explore more advanced contrastive learning approaches to further improve performance when medical data are unlabeled and imbalanced.