1 Introduction

Deep learning has shown outstanding results in medical image analysis problems [16, 19, 20, 21, 28, 29]. However, this performance usually depends on the availability of labelled datasets, which are expensive to obtain given that the labelling process requires expert doctors. This limitation motivates the study of semi-supervised learning (SSL) methods that train models with a small set of labelled data and a large set of unlabelled data.

The current state-of-the-art (SOTA) SSL methods are based on pseudo-labelling [18, 26], consistency-enforcing approaches [2, 17, 27], self-supervised and semi-supervised learning (S\(^4\)L) [5, 34], and graph-based label propagation [1]. Pseudo-labelling is an intuitive SSL technique, where confident predictions from the model are transformed into pseudo-labels for the unlabelled data, which are then used to re-train the model [18]. Consistency-enforcing regularisation is based on training for a consistent output under model [22, 27] or input data [2, 17] perturbations. S\(^4\)L methods are based on self-supervised pre-training [4, 12], followed by supervised fine-tuning using few labelled samples [5, 34]. Graph-based methods rely on label propagation on graphs [1]. Recently, Yang et al. [33] suggested that self-supervised pre-training provides better feature representations than consistency-enforcing approaches in SSL. However, previous S\(^4\)L approaches use only the labelled data in the fine-tuning stage, missing useful training information present in the unlabelled data. Furthermore, self-supervised pre-training [4, 12] tends to use a limited number of samples to represent each class, but Cai et al. [3] recently showed that better representations can be obtained with an infinite number of samples. Also, recent research [26] suggests that the student-teacher framework, such as the mean teacher [27], works better in multi-label semi-supervised tasks than other SSL methods. We speculate that this is because other methods are usually designed to work with the softmax activation, which only handles multi-class problems, whereas the mean teacher [27] does not have this constraint and can handle multi-label problems.

In this paper, we propose the self-supervised mean-teacher for semi-supervised (S\(^2\)MTS\(^2\)) learning approach that combines S\(^4\)L [5, 34] with consistency-enforcing learning based on the mean-teacher algorithm [27]. The main contribution of our method is the self-supervised mean-teacher pre-training with joint contrastive learning [3]. To the best of our knowledge, this is the first approach in our field to train the mean-teacher model with self-supervised learning. This model is then fine-tuned with semi-supervised learning using the exponential moving average teacher framework [27]. We evaluate our proposed method on the thorax disease multi-label datasets Chest X-ray14 [32] and CheXpert [15], and on the multi-class skin condition dataset ISIC2018 [8, 30]. We show that our method outperforms the SOTA semi-supervised learning methods [1, 11, 22, 31]. Moreover, in the ablation study we investigate the contribution of each component of our framework to the overall model.

Fig. 1.

Description of the proposed self-supervised mean-teacher for semi-supervised (S\(^2\)MTS\(^2\)) learning. The main contribution of the paper resides in the top part of the figure, with the self-supervised mean-teacher pre-training based on joint contrastive learning, which uses an infinite number of pairs of positive query and key features sampled from the unlabelled images to minimise \(\ell _p(.)\) in (1). This model is then fine-tuned with the exponential moving average teacher in a semi-supervised learning framework that uses both labelled and unlabelled sets to minimise \(\ell _{cls}(.)\) and \(\ell _{con}(.)\) in (2).

2 Related Works

SSL is a research topic that is gaining attention from the medical image analysis community due to the expensive image annotation process [7] and the growing number of large-scale datasets available in the field [32]. The current SOTA SSL methods are based on consistency-enforcing approaches that leverage the unlabelled data to regularise the model prediction consistency [17, 27]. Other related papers [9] extend the mean teacher [27] to encourage consistency between the predictions of the student and teacher models for atrium and brain lesion segmentation. The SOTA SSL method on chest X-ray images [22] exploits the consistency in the relations between labelled and unlabelled data. None of these methods explores a self-supervised consistency-enforcing method to pre-train an SSL model, as we propose.

Self-supervised learning methods [4, 12] are also being widely investigated in SSL because they can provide good representations [5, 34]. However, these methods ignore the large amount of unlabelled data during the fine-tuning stage of SSL, which may lead to a poor generalisation process. An important point in self-supervised learning is how to define the classes to be learned. In general, each class is composed of a single pair of augmented images from the same image, and many pairs of augmentations from different images [3, 4, 5, 12]. The use of a single pair of images to form a class has been criticised by Cai et al. [3], who proposed joint contrastive learning (JCL), an efficient way to form a class with an infinite number of augmented images from the same image that leverages the statistical dependency between different augmentations.

3 Method

In this section, we introduce our two-stage learning framework in detail (see Fig. 1). We assume that we have a small labelled dataset, denoted by \(\mathcal {D}_L = \{ (\mathbf {x}_i,\mathbf {y}_i)\}_{i=1}^{|\mathcal {D}_L|}\), where the image is represented by \(\mathbf {x} \in \mathcal {X} \subset \mathbb {R}^{H \times W \times C}\) and the class by \(\mathbf {y} \in \{0,1\}^{|\mathcal {Y}|}\), with \(\mathcal {Y}\) representing the label set. We consider a multi-label problem and thus \(\sum _{c=1}^{|\mathcal {Y}|}\mathbf {y}_i(c) \in [0,|\mathcal {Y}|]\). The unlabelled dataset is defined by \(\mathcal {D}_U = \{ \mathbf {x}_i\}_{i=1}^{|\mathcal {D}_U|}\), with \(|\mathcal {D}_L| \ll |\mathcal {D}_U|\).

Our model consists of a student and a teacher model [27], denoted by parameters \(\theta ,\theta ' \in \mathbf {\Theta }\), respectively, which parameterise the classifier \(f_{\theta }:\mathcal {X} \rightarrow [0,1]^{|\mathcal {Y}|}\). This classifier can be decomposed as \(f_{\theta } = h_{\theta _1}\circ g_{\theta _2}\), with \(g_{\theta _2}:\mathcal {X} \rightarrow \mathcal {Z}\) and \(h_{\theta _1}:\mathcal {Z} \rightarrow [0,1]^{|\mathcal {Y}|}\). The first stage (top of Fig. 1) of the training consists of self-supervised learning that uses the images from \(\mathcal {D}_L\) and \(\mathcal {D}_U\), denoted by \(\mathcal {D}^{\mathcal {X}} = \{ \mathbf {x}_i | \mathbf {x}_i \in \mathcal {D}^{\mathcal {X}}_L \bigcup \mathcal {D}_U\}_{i=1}^{|\mathcal {D}^{\mathcal {X}}|}\), with \(\mathcal {D}^{\mathcal {X}}_L\) representing the images from the set \(\mathcal {D}_L\), where our method minimises the joint contrastive learning loss [3] defined in (1). This means that during this first stage, we only learn the parameters for \(g_{\theta _2}\). The second stage (bottom of Fig. 1) fine-tunes this pre-trained student-teacher model using the semi-supervised consistency loss defined in (2). Below we provide details on the losses and training.
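A minimal sketch of this decomposition, assuming a torchvision DenseNet121 backbone for \(g_{\theta _2}\) (the class and attribute names below are illustrative, not taken from the authors' released code):

```python
import torch
import torch.nn as nn
from torchvision import models

class StudentClassifier(nn.Module):
    """f_theta = h_theta1 o g_theta2, with a multi-label sigmoid output."""
    def __init__(self, num_classes: int = 14):
        super().__init__()
        backbone = models.densenet121(weights=None)
        feat_dim = backbone.classifier.in_features   # 1024 for DenseNet121
        backbone.classifier = nn.Identity()          # keep only the feature extractor
        self.g = backbone                            # g_theta2: X -> Z
        self.h = nn.Linear(feat_dim, num_classes)    # h_theta1: Z -> [0,1]^|Y|

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.h(self.g(x)))
```

During the first stage only `self.g` (together with the projection head mentioned in Sect. 4.2) is trained; `self.h` is learned in the fine-tuning stage, and the teacher is a copy of this network whose parameters are updated by EMA rather than by gradient descent.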

3.1 Joint Contrastive Learning to Self-supervise the Mean-Teacher Pre-training

The self-supervised pre-training of the mean teacher using joint contrastive learning (JCL) [3], presented in this section, is the main technical contribution of this paper. The teacher and student process an input image to return the keys \(\mathbf {k}\in \mathcal {Z}\) and the queries \(\mathbf {q}\in \mathcal {Z}\), with \(\mathbf {k}=g_{\theta '_2}(\mathbf {x})\) and \(\mathbf {q}=g_{\theta _2}(\mathbf {x})\). We also assume that we have a set of augmentation functions, e.g., random crop and resize, rotation, and Gaussian blur, denoted by \(\mathcal {A}=\{ a_l:\mathcal {X} \rightarrow \mathcal {X} \}_{l=1}^{|\mathcal {A}|}\). JCL then minimises the following loss [3]:

$$\begin{aligned} { \ell _{p}(\mathcal {D}^{\mathcal {X}},\theta _2,\theta '_2)= -\frac{1}{|\mathcal {D}^{\mathcal {X}}|} \frac{1}{M} \sum _{i=1}^{|\mathcal {D}^{\mathcal {X}}|} \sum _{m=1}^{M} \left[ \log \frac{\exp \left[ \frac{1}{\tau } \mathbf {q}_i^{\top } \mathbf {k}_{i,m}^{+} \right] }{\exp \left[ \frac{1}{\tau } \mathbf {q}_i^{\top } \mathbf {k}_{i,m}^{+} \right] + \sum _{j=1}^K \exp \left[ \frac{1}{\tau } \mathbf {q}_i^{\top } \mathbf {k}_{i,j}^{-} \right] } \right] ,} \end{aligned}$$
(1)

where \(\tau \) is the temperature hyper-parameter, the query \(\mathbf {q}_i = g_{\theta _2}(a(\mathbf {x}_i))\), with \(a \in \mathcal {A}\), the positive key \(\mathbf {k}_{i,m}^{+} \sim p(\mathbf {k}_i^{+})\), with \(p(\mathbf {k}_i^{+}) = \mathcal {N}(\mu _{\mathbf {k}_i},\varSigma _{\mathbf {k}_i})\) and \(\mathbf {k}_i = g_{\theta '_2}(a(\mathbf {x}_i))\) (i.e., a sample from the data augmentation distribution for \(\mathbf {x}_i\)), and the negative keys \(\mathbf {k}_{i,j}^{-}\) for the query \(\mathbf {q}_i\) are drawn from \(\{ \mu _{\mathbf {k}_j} \}_{j \in \{1,\ldots ,|\mathcal {D}^{\mathcal {X}}|\}, j \ne i }\). In (1), M denotes the number of positive keys, and Cai et al. [3] describe a loss that minimises a bound to (1) for \(M \rightarrow \infty \) – below, the minimisation of \(\ell _p(.)\) in (1) is realised by the minimisation of this bound. As defined above, the generative model \(p(\mathbf {k}_i^{+})\) is denoted by the Gaussian \(\mathcal {N}(\mu _{\mathbf {k}_i},\varSigma _{\mathbf {k}_i})\), where the mean \(\mu _{\mathbf {k}_i}\) and covariance \(\varSigma _{\mathbf {k}_i}\) are estimated from a set of keys \(\{ \mathbf {k}^{+}_{i,l} = g_{\theta '_2}(a_l(\mathbf {x}_i)) \}_{a_l \in \mathcal {A}}\) formed by different views of \(\mathbf {x}_i\). The set of negative keys \(\{ \mu _{\mathbf {k}_j} \}_{j \ne i}\) is stored in a memory queue [12] that is updated in a first-in-first-out manner, where the key means in \(\{ \mu _{\mathbf {k}_i} \}_{i=1}^{|\mathcal {D}^{\mathcal {X}}|} \) are inserted into the memory queue to replace the oldest key means from previous training iterations. The memory queue is designed to increase the number of negative samples without sacrificing computational efficiency.
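To make (1) more concrete, the sketch below gives a direct Monte-Carlo reading of the loss, assuming a diagonal covariance for \(p(\mathbf {k}_i^{+})\) and a fixed tensor of queued negative key means; the function and variable names are illustrative, and Cai et al. [3] in fact minimise a closed-form bound of (1) for \(M \rightarrow \infty \) rather than sampling positive keys explicitly.

```python
import torch
import torch.nn.functional as F

def jcl_loss_mc(student_g, teacher_g, x, views, neg_queue, M=8, tau=0.2):
    """Monte-Carlo sketch of the loss in (1).

    views:     list of augmentation callables a_l in A (at least two views)
    neg_queue: (K, d) tensor of stored negative key means mu_{k_j}
    """
    q = F.normalize(student_g(views[0](x)), dim=1)                     # queries (B, d)
    with torch.no_grad():                                              # keys come from the EMA teacher
        keys = torch.stack([F.normalize(teacher_g(a(x)), dim=1)
                            for a in views], dim=1)                    # (B, V, d)
    mu, std = keys.mean(dim=1), keys.std(dim=1)                        # per-image Gaussian statistics
    targets = torch.zeros(q.size(0), dtype=torch.long, device=q.device)
    loss = 0.0
    for _ in range(M):                                                 # sample M positive keys
        k_pos = F.normalize(mu + std * torch.randn_like(std), dim=1)
        l_pos = (q * k_pos).sum(dim=1, keepdim=True) / tau             # (B, 1)
        l_neg = q @ neg_queue.t() / tau                                # (B, K)
        loss = loss + F.cross_entropy(torch.cat([l_pos, l_neg], 1), targets)
    return loss / M
```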

The training of the student-teacher model [12, 27, 35] is achieved by updating the student parameters using the loss in (1), as in \(\theta _2(t) = \theta _2(t-1) - \nabla _{\theta _2}\ell _{p}(\mathcal {D}^{\mathcal {X}},\theta _2,\theta '_2)\), where t is the training iteration. The teacher model parameters are updated with an exponential moving average (EMA), \(\theta '_2(t) = \alpha \theta '_2(t-1) + (1-\alpha )\theta _2(t) \), with \(\alpha \in [0,1]\). For this pre-training stage, we notice that training for more epochs always improves the model regularisation, given that it is difficult to overfit the training set with the loss in (1). Hence, we select the last-epoch student model \(g_{\theta _2}(.)\) to initialise the fine-tuning stage, defined below in Sec. 3.2.
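A short sketch of this EMA update, assuming the student and teacher share the same architecture (the value of \(\alpha \) below is only a typical choice, not necessarily the one used in the paper):

```python
import torch

@torch.no_grad()
def ema_update(student: torch.nn.Module, teacher: torch.nn.Module, alpha: float = 0.999):
    # theta'(t) = alpha * theta'(t-1) + (1 - alpha) * theta(t)
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(alpha).add_(p_s, alpha=1.0 - alpha)
```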

3.2 Fine-Tuning the Mean Teacher

To fine-tune the mean teacher, we follow the approach in [12, 27], using the following loss to train the student model:

$$\begin{aligned} \ell _{t}(\mathcal {D}_L,\mathcal {D}_U,\theta ,\theta ') = \frac{1}{|\mathcal {D}_L|}\sum _{(\mathbf {x}_i,\mathbf {y}_i) \in \mathcal {D}_L} \ell _{cls}(\mathbf {y}_i,f_{\theta }(\mathbf {x}_i)) + \frac{1}{|\mathcal {D}|}\sum _{\mathbf {x}_i \in \mathcal {D}} \ell _{con}(f_{\theta }(\mathbf {x}_i),f_{\theta '}(\mathbf {x}_i) ), \end{aligned}$$
(2)

where \(\ell _{cls}(\mathbf {y}_i,f_{\theta }(\mathbf {x}_i))=-\mathbf {y}_i^{\top }\log (f_{\theta }(\mathbf {x}_i))\), \(\ell _{con}(f_{\theta }(\mathbf {x}_i),f_{\theta '}(\mathbf {x}_i) ) = \Vert f_{\theta }(\mathbf {x}_i) - f_{\theta '}(\mathbf {x}_i) \Vert ^2\), and \(\mathcal {D}=\mathcal {D}_U \bigcup \mathcal {D}^{\mathcal {X}}_L\). The training of the student-teacher model [12, 27, 35] is achieved by updating the student parameters using the loss in (2), as in \(\theta (t) = \theta (t-1) - \nabla _{\theta }\ell _{t}(\mathcal {D}_L,\mathcal {D}_U,\theta ,\theta ')\), where t is the training iteration. The teacher model parameters are updated with an exponential moving average (EMA), \(\theta '(t) = \alpha \theta '(t-1) + (1-\alpha )\theta (t) \), with \(\alpha \in [0,1]\). After finishing the fine-tuning stage, we select the teacher model \(f_{\theta '}(.)\) to estimate the multi-label classification for test images.
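A sketch of the fine-tuning objective in (2), written with the usual binary cross-entropy for the multi-label case and reusing the same EMA rule shown above for the teacher update; the batch construction and names are illustrative:

```python
import torch
import torch.nn.functional as F

def finetune_loss(student, teacher, x_lab, y_lab, x_all):
    """l_cls on the labelled batch + l_con on labelled and unlabelled images."""
    l_cls = F.binary_cross_entropy(student(x_lab), y_lab.float())   # supervised term
    with torch.no_grad():
        p_teacher = teacher(x_all)                                   # no gradient through the teacher
    l_con = F.mse_loss(student(x_all), p_teacher)                    # consistency term
    return l_cls + l_con
```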

4 Experiment

4.1 Dataset Setup

We use Chest X-ray14 [32], CheXpert [15] and ISIC2018 [8, 30] datasets to evaluate our method.

Chest X-ray14 contains 112,120 chest X-ray images from 30,805 different patients. There are 14 different labels (each label represents a disease) in the dataset, and each patient can have multiple diseases at the same time, forming a multi-label classification problem. To compare with previous papers [1, 22], we adopt the official train/test data split. For the self-supervised pre-training of the mean teacher, we use all the unlabelled images (86k samples) from the training set. For the semi-supervised fine-tuning of the mean teacher, we follow [1, 22] and experiment with training sets containing different proportions of labelled data (2%, 5%, 10%, 15%, 20%). We report the classification results on the official test set (26,000 samples) using the area under the ROC curve (AUC).

CheXpert contains around 220,000 images with 14 different diseases and, similarly to Chest X-ray14, each patient can have multiple diseases at the same time. For pre-processing, we remove all lateral-view images and treat uncertain labels as negative labels. We follow the semi-supervised setup from [11] and experiment with 100/200/300/400/500 labelled samples per class. We report results on the official test set using AUC.
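A hedged sketch of this pre-processing, assuming the standard CheXpert train.csv layout (a "Frontal/Lateral" column and \(-1\) encoding uncertain labels); the file path and column positions are assumptions about the released CSV, not details stated in the paper:

```python
import pandas as pd

df = pd.read_csv("CheXpert-v1.0/train.csv")
df = df[df["Frontal/Lateral"] == "Frontal"]                       # drop lateral-view images
label_cols = df.columns[5:]                                       # the 14 observation columns (assumed layout)
df[label_cols] = df[label_cols].fillna(0.0).replace(-1.0, 0.0)    # uncertain/blank -> negative
```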

ISIC2018 is a multi-class skin condition dataset that contains 10,015 images with seven different labels. Each image is associated with one of the seven labels, forming a multi-class classification problem. We follow the train/test split of [22] for a fair comparison, where the training set contains 20% labelled and 80% unlabelled samples. We report the AUC, Sensitivity, and F1 score results to compare with baselines.

4.2 Implementation Details

For all datasets, we use DenseNet121 [14] as our backbone model. For self-supervised pre-training, we follow [5] and replace the two-layer multi-layer perceptron (MLP) projection head with a three-layer MLP. For dataset pre-processing, we resize Chest X-ray14 images to 512 \(\times \) 512 for faster processing, and CheXpert and ISIC2018 images to 224 \(\times \) 224 for a fair comparison with baselines. We use the data augmentation proposed in [4], consisting of random resize and crop, random rotation, and random horizontal flipping; we exclude random grayscale because X-ray images are already in grayscale. The batch size is 128 for Chest X-ray14 and 256 for CheXpert and ISIC2018, and the learning rate is 0.05. For the fine-tuning stage, we use a batch size of 32 with 16 labelled and 16 unlabelled samples. The fine-tuning takes 30 epochs, with the learning rate decayed by 0.1 at epochs 15 and 25 for all datasets. We use the SGD optimiser with 0.9 momentum for the pre-training stage, and the Adam optimiser for the fine-tuning stage. The code is written in Pytorch [24]. We use 4 Nvidia V100 GPUs for the self-supervised stage and 1 Nvidia RTX 2080Ti for fine-tuning.
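For reference, a minimal torchvision version of this augmentation set (the crop size, rotation range, and blur parameters below are illustrative placeholders, not the exact values used in the paper):

```python
from torchvision import transforms

augment = transforms.Compose([
    transforms.RandomResizedCrop(512),                         # 512x512 for Chest X-ray14, 224 otherwise
    transforms.RandomRotation(degrees=10),
    transforms.RandomHorizontalFlip(),
    transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
    transforms.ToTensor(),                                     # no RandomGrayscale: X-rays are already grayscale
])
```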

4.3 Experimental Results

We evaluate our approach on the official test set of Chest X-ray14 using different percentages of labelled training data (i.e., 2\(\%\), 5\(\%\), 10\(\%\), 15\(\%\), 20\(\%\)), as shown in Table 1. The set of labelled data used for each percentage above follows the same strategy as previous works [1, 22]. Our S\(^2\)MTS\(^2\) achieves the SOTA AUC results for all label percentages. Our model surpasses the previous SOTA SRC-MT [22] by a large margin of 8.7% and 6.8% AUC for the 2% and 5% labelled set cases, respectively, while using a backbone architecture of lower complexity (DenseNet121 instead of the DenseNet169 of [22]). Using the same DenseNet121 backbone, GraphXnet [1] fails to classify precisely for the 2% and 5% labelled set cases; our method surpasses GraphXnet by more than 20% AUC in both cases. Furthermore, we achieve the SOTA results of the field for the 10%, 15% and 20% labelled set cases, outperforming all previous semi-supervised methods [1, 22]. It is worth noting that our model trained with 5% of the labelled set achieves better results than SRC-MT with 15% of the labelled set. We also compare with a recently proposed self-supervised pre-training method, MoCo V2 [6], adapted to our semi-supervised task and followed by the fine-tuning stage using different percentages of labelled data. Our method outperforms MoCo V2 by almost 10% AUC when using 2% of the labelled set, and by almost 3% AUC for 10% of the labelled set. Our result for the 20% labelled set reaches 81.06% AUC, which is comparable to the supervised learning approaches – 81.20% from MoCo V2 (DenseNet121) and 81.75% from SRC-MT (DenseNet169) using 100% of the labelled samples. This result indicates the effectiveness of our proposed S\(^2\)MTS\(^2\) in SSL benchmark problems.

We also show the class-level performance using 20% of the labelled data and compare with other SOTA methods in Table 2. We compare with the previous baselines, namely the original mean teacher (MT) with DenseNet169, SRC-MT with DenseNet169, MoCo V2, and GraphXNet with DenseNet121. We also train a baseline DenseNet121 model with 20% labelled data using an ImageNet pre-trained model. Our method achieves the best results on nine classes, surpassing the original MT [27] and its extension SRC-MT [22] by a large margin, demonstrating the effectiveness of our self-supervised learning.

Furthermore, we compare our approach on the fully-supervised Chest X-ray14 benchmark in Table 3. To the best of our knowledge, Hermoza et al. [13] hold the SOTA supervised classification result with a complex method (relying on the weakly-supervised localisation of lesions), reporting a mean AUC of 82.1% (over the 14 classes), while ours reaches a mean AUC of 82.5%. Hence, our model, using the whole labelled set, achieves the SOTA performance on 8 classes and an average that surpasses the previous supervised methods by a minimum of 0.4% and a maximum of 8% AUC.

The results on the CheXpert and ISIC2018 datasets are shown in Tables 4 and 5, respectively. In particular, for CheXpert in Table 4, we compare our method with LatentMixing [11], and our result is better in all cases. For ISIC2018 in Table 5, using the test set from SRC-MT [22], our method outperforms all baselines (Supervised, MT, and SRC-MT) for all measures.

Table 1. Mean AUC result over the 14 disease classes of Chest X-ray14 for different labelled set training percentages. * indicates the methods that use DenseNet169 as the backbone architecture.
Table 2. Class-level AUC comparison between our S\(^2\)MTS\(^2\) and other semi-supervised SOTA approaches trained with 20% of labelled data on Chest X-ray14. * denotes the methods that use DenseNet169 as the backbone.
Table 3. Class-level AUC comparison between our S\(^2\)MTS\(^2\) and other supervised SOTA approaches trained with 100% of labelled data on Chest X-ray14.
Table 4. Mean AUC result (over the 14 disease classes) on CheXpert for different numbers of training samples per class.
Table 5. AUC, Sensitivity and F1 result on ISIC2018 using 20% of labelled training samples.

4.4 Ablation Study

We study the impact of the different components of our proposed S\(^2\)MTS\(^2\) in Table 6 using Chest X-ray14. Using the proposed self-supervised learning with just the student model, our model achieves at least 72.95% mean AUC across the various percentages of labelled training data. Adding the JCL component improves this baseline by around 1% mean AUC at each training percentage. Adding the mean teacher boosts the result by 1.5% to 2% mean AUC at each training percentage. The combination of all three proposed components achieves SOTA performance on the semi-supervised task.

Table 6. Ablation study of our method with different components on Chest X-ray14. "Self-supervised" indicates traditional self-supervised learning with the contrastive loss of [12], "JCL" replaces the contrastive loss with (1), and "MT" stands for fine-tuning with student-teacher learning instead of fine-tuning on only the labelled samples.

5 Conclusion

In this paper, we presented a novel semi-supervised framework, the Self-supervised Mean Teacher for Semi-supervised (S\(^2\)MTS\(^2\)) learning. The main contribution of S\(^2\)MTS\(^2\) is the self-supervised mean-teacher pre-training based on joint contrastive learning [3], which uses an infinite number of pairs of positive query and key features. This model is then fine-tuned with the exponential moving average teacher framework. S\(^2\)MTS\(^2\) is validated on the thorax disease multi-label classification problem from the datasets Chest X-ray14 [32] and CheXpert [15], and on the multi-class classification problem from the skin condition dataset ISIC2018 [8, 30]. The experiments show that our method outperforms the previous SOTA semi-supervised learning methods by a large margin on all benchmarks containing varying percentages of labelled data. We also show that the method holds the SOTA results on Chest X-ray14 [32] even for the fully-supervised problem. The ablation study shows the importance of the three main components of the method, namely self-supervised learning, JCL, and the mean-teacher model. We will investigate the performance of our method on other semi-supervised medical imaging benchmarks in the future.