
1 Introduction

Accurate and robust segmentation of organs or lesions from medical images is of great importance for many clinical applications, such as disease diagnosis and treatment planning. With a large amount of labeled data, deep learning has achieved great success in automatic image segmentation [7, 10]. In the medical imaging domain, however, especially for volumetric images, reliable annotations are difficult to obtain because both expert knowledge and time are required. Unlabeled data, on the other hand, are easier to acquire. Therefore, semi-supervised approaches, in which unlabeled data occupy a large portion of the training set, are worth exploring.

Bai et al. [1] introduced a self-training-based method for cardiac MR image segmentation, in which the segmentation predictions for unlabeled data and the network parameters were alternately updated. Xia et al. [14] utilized co-training for pancreas and liver tumor segmentation by exploiting the multi-viewpoint consistency of 3D data. These methods enlisted more available training sources by creating pseudo labels; however, they did not consider the reliability of the pseudo labels, which may lead to meaningless guidance. Other semi-supervised approaches were inspired by the success of self-ensembling methods. For example, Li et al. [5] embedded transformation consistency into the \(\varPi \)-model [3] to strengthen the regularization of pixel-wise predictions. Yu et al. [16] designed an uncertainty-aware mean teacher framework that generates more reliable predictions for the student to learn from. To exploit structural information for prediction, Hang et al. [2] proposed a local and global structure-aware entropy-regularized mean teacher for left atrium segmentation. In general, most teacher-student methods update the teacher's parameters using an exponential moving average (EMA), which is a useful ensemble strategy. However, EMA merely weights the student's parameters at each stage of training, without explicitly evaluating the quality of those parameters. Ideally, the teacher model would update its parameters purposefully through a parameter evaluation strategy, so as to generate more reliable pseudo labels.

In this paper, we design a novel strategy named reciprocal learning for semi-supervised segmentation. Specifically, we make better use of the limited labeled data with a reciprocal learning strategy, so that the teacher model can update its parameters via gradient descent and generate more reliable annotations for the unlabeled set as the number of reciprocal learning steps increases. We evaluate our approach on the pancreas CT dataset and the Atrial Segmentation Challenge dataset with extensive comparisons to existing methods. The results demonstrate that our segmentation network consistently outperforms state-of-the-art methods with respect to the evaluation metrics of Dice Similarity (Dice), Jaccard Index (Jaccard), 95\(\%\) Hausdorff Distance (95 HD) and Average Symmetric Surface Distance (ASD). Our main contributions are threefold:

  • We present a simple yet efficient reciprocal learning strategy for segmentation to reduce labeling effort. Inspired by the idea of learning to learn, we design a feedback mechanism that lets the teacher network generate more reliable pseudo labels by observing how the pseudo labels affect the student. In our implementation, the feedback signal is the student's performance on the labeled set. With this strategy, the teacher can update its parameters autonomously.

  • The proposed reciprocal learning strategy can be utilized directly in any CNN architecture. Specifically, any segmentation network can be used as the backbone, which means there are still opportunities for further enhancements.

  • Experiments on two public datasets show our proposed strategy can further raise semi-supervised segmentation quality compared with existing methods.

2 Methods

Figure 1 illustrates our reciprocal learning framework for semi-supervised segmentation. We deploy a meta-learning concept for teacher model to generate better pseudo labels by observing how pseudo labels would affect the student. Specifically, the teacher and student are trained in parallel: the student learns from pseudo labels generated by the teacher, and the teacher learns from the feedback signal of how well the student performs on the labeled set.

Fig. 1. The schematic illustration of our reciprocal learning framework for semi-supervised segmentation. In this paper, V-Net is used as the backbone. Again, we emphasize that any segmentation network could be used as the backbone in our framework.

2.1 Notations

We denote the labeled set as \((x_l, y_l)\) and the unlabeled set as \(x_u\), where x is the input volume and y is the ground-truth segmentation. Let T and S be the teacher and student models, with corresponding parameters \(\theta _T\) and \(\theta _S\). We denote the teacher network's soft prediction on \(x_u\) as \(T(x_u; \theta _T)\), and likewise for the student.

2.2 Reciprocal Learning Strategy

Figure 1 shows the workflow of our proposed reciprocal learning strategy. First, the teacher model is pre-trained on the labeled set \((x_l, y_l)\) in a supervised manner, using the cross-entropy (CE) loss:

$$\begin{aligned} \mathcal {L}_{pre-train} = {CE}(y_l, T(x_l; \theta _T)). \end{aligned}$$
(1)

Then we use the teacher's predictions on the unlabeled set as pseudo labels \(\widehat{y}_u\) to train the student model. Specifically, Pseudo Labeling (PL) trains the student to minimize the cross-entropy loss on the unlabeled set \(x_u\):

$$\begin{aligned} \widehat{y}_u \sim T(x_u; \theta _T), \end{aligned}$$
(2)
$$\begin{aligned} \theta _S^\text {PL} = \mathop {\arg \min }_{\theta _S}\, {CE}(\widehat{y}_u, S(x_u; \theta _S)). \end{aligned}$$
(3)
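As an illustrative PyTorch sketch of Eqs. (2)-(3) (the function name, the argmax choice for deriving hard pseudo labels from the soft prediction, and the assumed logit shape (N, C, D, H, W) are ours, not the exact implementation):

```python
import torch
import torch.nn.functional as F

def pseudo_label_loss(teacher, student, x_u):
    """Eqs. (2)-(3): the teacher's prediction on unlabeled volumes supplies
    pseudo labels, and the student minimizes cross-entropy against them.
    Both models are assumed to output per-voxel logits of shape (N, C, D, H, W).
    """
    with torch.no_grad():
        # Eq. (2): derive hard pseudo labels from the teacher's soft prediction
        # (argmax here; sampling is another option).
        y_hat_u = teacher(x_u).argmax(dim=1)            # (N, D, H, W)
    # Eq. (3): cross-entropy of the student's prediction against pseudo labels.
    return F.cross_entropy(student(x_u), y_hat_u)
```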

After the student model is updated, it is expected to perform well on the labeled set, i.e. to achieve a low cross-entropy loss \({CE}(y_l, S(x_l; \theta _S^\text {PL}))\). Notice that the optimal student parameters \(\theta _S^\text {PL}\) always depend on the teacher parameters \(\theta _T\) via the pseudo labels (see Eqs. (2) and (3)). We therefore write this dependency as \(\theta _S^\text {PL}(\theta _T)\) and optimize \(\mathcal {L}_{feedback}\) with respect to \(\theta _T\):

$$\begin{aligned} \mathop {\min }_{\theta _T}\quad \mathcal {L}_{feedback}(\theta _S^\text {PL}(\theta _T)) = {CE}(y_l, S(x_l; \theta _S^\text {PL}(\theta _T))). \end{aligned}$$
(4)

For each reciprocal learning step (one update of the student using Eq. (3) and one update of the teacher using Eq. (4)), however, solving Eq. (3) to optimize \(\theta _S\) until full convergence is inefficient, since computing the gradient \(\nabla _{\theta _T}\mathcal {L}_{feedback}(\theta _S^\text {PL}(\theta _T))\) would require unrolling the entire student training process. Instead, we adopt a meta-learning approximation [6] and replace \(\theta _S^\text {PL}\) with a one-step gradient update of \(\theta _S\):

$$\begin{aligned} \theta _S^\text {PL} \approx \theta _S - \eta _S\nabla _{\theta _S}{CE}(\widehat{y}_u, S(x_u; \theta _S)), \end{aligned}$$
(5)

where \(\eta _S\) is the learning rate. In this way, the student model and the teacher model have an alternating optimization:

  (1) Draw a batch of unlabeled data \(x_u\), sample pseudo labels \(\widehat{y}_u \sim T(x_u;\theta _T)\) from the teacher model, and update the student with stochastic gradient descent (SGD):

    $$\begin{aligned} \theta _S' = \theta _S - \eta _S\nabla _{\theta _S}{CE}(\widehat{y}_u, S(x_u; \theta _S)). \end{aligned}$$
    (6)
  (2) Draw a batch of labeled data \((x_l,y_l)\), and reuse the student's update to optimize the teacher with SGD:

    $$\begin{aligned} \theta _T' = \theta _T - \eta _T\nabla _{\theta _T}\mathcal {L}_{feedback}(\theta _S'). \end{aligned}$$
    (7)
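The one-step student update of Eqs. (5)-(6) can be sketched in PyTorch as follows (a minimal illustration under our assumptions; a real implementation would also write \(\theta _S'\) back into the model, e.g. via a functional API):

```python
import torch
import torch.nn.functional as F

def student_one_step(student, x_u, y_hat_u, eta_s=0.01):
    """Eq. (6): one SGD step on the student's pseudo-label loss.

    Returns the updated parameters theta_S' as new tensors, leaving the
    model's stored parameters untouched. The model is assumed to output
    per-voxel logits of shape (N, C, D, H, W).
    """
    loss_u = F.cross_entropy(student(x_u), y_hat_u)
    params = [p for p in student.parameters() if p.requires_grad]
    grads = torch.autograd.grad(loss_u, params)
    # theta_S' = theta_S - eta_S * gradient (the one-step approximation, Eq. (5))
    return [p - eta_s * g for p, g in zip(params, grads)]
```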

Optimizing \(\theta _S\) with Eq. (6) can be done directly via back-propagation. We now derive the update for \(\theta _T\). First, by the chain rule, we have

$$\begin{aligned} \begin{aligned} \frac{\partial \mathcal {L}_{feedback}(\theta _S')}{\partial \theta _T}&= \frac{\partial {CE}(y_l,S(x_l;\theta _S'))}{\partial \theta _T}\\&=\left. \frac{\partial \text {CE} \left( y_{l}, S\big (x_{l}; \theta _S' \big ) \right) }{\partial \theta _S} \right| _{\theta _S = \theta _S'} \cdot \frac{\partial \theta _S'}{\partial \theta _T} \end{aligned} \end{aligned}$$
(8)

We focus on the second term in Eq. (8):

$$\begin{aligned} \begin{aligned} \frac{\partial \theta _S'}{\partial \theta _T}&= \frac{\partial }{\partial \theta _T} \left[ \theta _S - \eta _S\nabla _{\theta _S}{CE}(\widehat{y}_u, S(x_u; \theta _S))\right] \\&= -\,\eta _S \cdot \frac{\partial }{\partial \theta _T}\left( \frac{\partial {CE}(\widehat{y}_u, S(x_u; \theta _S))}{\partial \theta _S} \right) ^\top \end{aligned} \end{aligned}$$
(9)

To simplify notations, we define the gradient

$$\begin{aligned} g_S(\widehat{y}_u) = \left( \frac{\partial {CE}(\widehat{y}_u, S(x_u; \theta _S))}{\partial \theta _S} \right) ^\top \end{aligned}$$
(10)

Since \(g_S(\widehat{y}_u)\) depends on \(\theta _T\) via \(\widehat{y}_u\), we apply the REINFORCE (log-derivative) trick [13] to obtain

$$\begin{aligned} \begin{aligned} \frac{\partial \theta _S'}{\partial \theta _T}&= - \eta _S \cdot \frac{\partial g_S(\widehat{y}_u)}{\partial \theta _T} \\&= - \eta _S \cdot g_S(\widehat{y}_u) \cdot \frac{\partial \log {P\left( \widehat{y}_u | x_u; \theta _T \right) }}{\partial \theta _T} \\&= \eta _S \cdot g_S(\widehat{y}_u) \cdot \frac{\partial {CE}(\widehat{y}_u, T(x_u; \theta _T))}{\partial \theta _T} \end{aligned} \end{aligned}$$
(11)

Finally, we obtain the gradient

$$\begin{aligned} \begin{aligned} \nabla _{\theta _T}\mathcal {L}_{feedback}(\theta _S') =\,&\eta _S \cdot \left( \nabla _{\theta _S'} {CE}(y_l,S(x_l;\theta _S'))\right) ^\top \cdot \\ {}&\nabla _{\theta _S}{CE}(\widehat{y}_u,S(x_u;\theta _S)) \cdot \nabla _{\theta _T}{CE}(\widehat{y}_u,T(x_u;\theta _T)). \end{aligned} \end{aligned}$$
(12)
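In implementation, Eq. (12) reduces to a scalar coefficient, namely \(\eta _S\) times the dot product of the two student gradients, multiplying the ordinary gradient of \({CE}(\widehat{y}_u, T(x_u; \theta _T))\). A hedged sketch (the function name and signature are our illustrative assumptions):

```python
import torch
import torch.nn.functional as F

def teacher_feedback_loss(teacher, x_u, y_hat_u, grad_l_new, grad_u_old, eta_s=0.01):
    """Practical form of Eq. (12). The teacher gradient equals a scalar
        h = eta_S * <grad_l_new, grad_u_old>
    (dot product of the student's labeled-batch gradient at theta_S' and its
    unlabeled-batch gradient at theta_S) times the ordinary gradient of
    CE(y_hat_u, T(x_u; theta_T)), so backpropagating h * CE yields Eq. (12).
    `grad_l_new` / `grad_u_old` are lists of per-parameter gradient tensors."""
    h = eta_s * sum((gl * gu).sum() for gl, gu in zip(grad_l_new, grad_u_old))
    # Detach h: it acts as a constant coefficient for the teacher's gradient.
    return h.detach() * F.cross_entropy(teacher(x_u), y_hat_u)
```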

However, relying solely on the student's performance to optimize the teacher model might lead to overfitting. To overcome this, we also leverage the labeled set to supervise the teacher throughout training. The final update rule of the teacher model can therefore be summarized as \(\theta _T' = \theta _T - \eta _T\nabla _{\theta _T}[\mathcal {L}_{feedback}(\theta _S')+\lambda {CE}(y_l, T(x_l; \theta _T))]\), where \(\lambda \) is a weight balancing the two losses.
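A minimal sketch of this combined teacher objective (with hypothetical tensor arguments, assuming per-voxel logits of shape (N, C, D, H, W) and \(\lambda =1\) as in our experiments):

```python
import torch
import torch.nn.functional as F

def teacher_total_loss(feedback_loss, teacher_logits_l, y_l, lam=1.0):
    """Final teacher objective: the feedback term plus the supervised
    cross-entropy on the labeled set, weighted by lambda."""
    return feedback_loss + lam * F.cross_entropy(teacher_logits_l, y_l)
```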

3 Experiments

3.1 Materials and Pre-processing

To demonstrate the effectiveness of our proposed method, experiments were carried out on two public datasets.

The first dataset is the pancreas dataset [11], obtained using Philips and Siemens MDCT scanners. It includes 82 abdominal contrast-enhanced CT scans, which have resolutions of 512 \(\times \) 512 pixels with varying pixel sizes and slice thicknesses between 1.5 and 2.5 mm. We applied a soft-tissue CT window of [−125, 275] HU, normalized the images to zero mean and unit variance, and then cropped them around the pancreas regions based on the ground truth with enlarged margins (25 voxels). We used 62 scans for training and 20 scans for validation.
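The pre-processing steps above can be sketched as follows (a minimal NumPy illustration under our assumptions about array layout; not the exact pipeline):

```python
import numpy as np

def preprocess_ct(volume, mask, margin=25):
    """Pancreas CT pre-processing: soft-tissue windowing, zero-mean/unit-variance
    normalization, then cropping around the labeled region with an enlarged
    margin. `volume` and `mask` are 3-D numpy arrays of equal shape."""
    vol = np.clip(volume, -125.0, 275.0)              # soft-tissue window [HU]
    vol = (vol - vol.mean()) / (vol.std() + 1e-8)     # zero mean, unit variance
    coords = np.argwhere(mask > 0)                    # foreground voxel indices
    lo = np.maximum(coords.min(axis=0) - margin, 0)   # clamp to volume bounds
    hi = np.minimum(coords.max(axis=0) + margin + 1, vol.shape)
    sl = tuple(slice(l, h) for l, h in zip(lo, hi))
    return vol[sl], mask[sl]
```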

The second dataset is the left atrium dataset [15]. It includes 100 gadolinium-enhanced MR images with an isotropic resolution of 0.625 \(\times \) 0.625 \(\times \) 0.625 mm\(^3\). We cropped the images around the heart regions and normalized them to zero mean and unit variance. We used 80 scans for training and 20 scans for validation.

In this work, we report the performance of all methods trained with 20\(\%\) labeled images and 80\(\%\) unlabeled images, following the typical semi-supervised learning setting.

3.2 Implementation Details

Our proposed method was implemented in the popular PyTorch library, using a TITAN Xp GPU. We employed V-Net [9] as the backbone; notably, any segmentation network can serve as the backbone. We set \(\lambda =1\). The teacher and student models share the same architecture but have independent weights. Both networks were trained with the stochastic gradient descent (SGD) optimizer for 6000 iterations, with an initial learning rate \(\eta _T=\eta _S=0.01\) decayed by a factor of 0.1 every 2500 iterations. To tackle the limited number of samples and the heavy computational cost of 3D data, we randomly cropped 96 \(\times \) 96 \(\times \) 96 (pancreas dataset) and 112 \(\times \) 112 \(\times \) 80 (left atrium dataset) sub-volumes as network input and applied data augmentation during training. At inference, we used only the student model to predict the segmentation and adopted a sliding-window strategy to obtain the final results, with a stride of 10 \(\times \) 10 \(\times \) 10 for the pancreas dataset and 18 \(\times \) 18 \(\times \) 4 for the left atrium dataset.
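The sliding-window inference can be sketched as follows (a simplified NumPy version with overlap averaging; `predict_patch` is a hypothetical callable mapping a patch to per-voxel class probabilities, and the actual implementation details may differ):

```python
import numpy as np

def sliding_window_predict(predict_patch, volume, patch_size, stride, num_classes=2):
    """Run `predict_patch` over overlapping sub-volumes and average the
    overlapping probabilities before taking the argmax. Assumes `volume`
    is a 3-D array at least as large as `patch_size` in every dimension."""
    D, H, W = volume.shape
    pd, ph, pw = patch_size
    probs = np.zeros((num_classes, D, H, W))
    count = np.zeros((D, H, W))
    # Window start positions; the final offset guarantees full coverage.
    zs = list(range(0, D - pd + 1, stride[0])) + [D - pd]
    ys = list(range(0, H - ph + 1, stride[1])) + [H - ph]
    xs = list(range(0, W - pw + 1, stride[2])) + [W - pw]
    for z in zs:
        for y in ys:
            for x in xs:
                patch = volume[z:z+pd, y:y+ph, x:x+pw]
                probs[:, z:z+pd, y:y+ph, x:x+pw] += predict_patch(patch)
                count[z:z+pd, y:y+ph, x:x+pw] += 1
    return (probs / count).argmax(axis=0)
```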

3.3 Segmentation Performance

We compared our method with several state-of-the-art semi-supervised segmentation methods, including the mean teacher self-ensembling model (MT) [12], the uncertainty-aware mean teacher model (UA-MT) [16], the shape-aware adversarial network (SASSNet) [4], uncertainty-aware multi-view co-training (UMCT) [14] and the transformation-consistent self-ensembling model (TCSM) [5]. Note that we used the official code of MT, UA-MT, SASSNet and TCSM, and reimplemented UMCT, whose official code was not released. For a fair comparison, we re-trained all competitors with the same backbone (V-Net) to obtain their best segmentation results on the pancreas and left atrium datasets.

The metrics employed to quantitatively evaluate segmentation are Dice, Jaccard, 95 HD and ASD. A better segmentation has larger Dice and Jaccard values and smaller 95 HD and ASD values.
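The overlap metrics Dice and Jaccard can be computed as follows (a minimal NumPy sketch; the surface metrics 95 HD and ASD require distance transforms and are omitted here):

```python
import numpy as np

def dice_jaccard(pred, gt):
    """Overlap metrics on binary masks: Dice = 2|A∩B|/(|A|+|B|) and
    Jaccard = |A∩B|/|A∪B|, both in [0, 1] with higher being better."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    dice = 2.0 * inter / (pred.sum() + gt.sum() + 1e-8)
    jaccard = inter / (union + 1e-8)
    return dice, jaccard
```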

Fig. 2. 2D visualization of our proposed semi-supervised segmentation method under 20\(\%\) labeled images. The first two rows are the segmentation results of pancreas and the last two rows are the segmentation results of left atrium. Red and blue colors show the ground truths and the predictions, respectively. (Color figure online)

Table 1. Quantitative comparison between our method and other semi-supervised methods on the pancreas CT dataset.

We first evaluated our proposed method on the pancreas dataset. The first two rows of Fig. 2 visualize 12 slices of the pancreas segmentation results. Our method consistently produced segmented boundaries close to the ground truths. Table 1 presents the quantitative comparison with several state-of-the-art semi-supervised segmentation methods. Compared with using only 20\(\%\) annotated images (the first row), all semi-supervised methods achieved better performance, showing that they can effectively utilize unlabeled images. Notably, our method improved the segmentation by 9.76\(\%\) Dice and 12.00\(\%\) Jaccard over the fully supervised baseline. Furthermore, it achieved the best performance among the state-of-the-art semi-supervised methods on all metrics. Compared with other methods, ours utilizes the limited labeled data better through the reciprocal learning strategy, so that the teacher model updates its parameters autonomously and generates more reliable annotations for unlabeled data as the number of reciprocal learning steps increases. The first two rows of Fig. 3 visualize the pancreas segmentation results of different semi-supervised methods in 3D. Our method produced fewer false-positive predictions, especially in the case shown in the first row of Fig. 3.

Fig. 3. Four cases of 3D visualization of different semi-supervised segmentation methods under 20\(\%\) labeled images. The first two rows are the results of pancreas segmentation and the last two rows are the results of left atrium segmentation.

Table 2. Quantitative comparison between our method and other semi-supervised methods on the Left Atrium MRI dataset.

We also evaluated our method on the left atrium dataset, a widely used benchmark for semi-supervised segmentation. The last two rows of Fig. 2 visualize 12 segmented slices. Our results successfully infer the ambiguous boundaries and have a high overlap with the ground truths. A quantitative comparison is shown in Table 2. Compared with using only 20\(\%\) labeled images (the first row), our method improved the segmentation by 5.65\(\%\) Dice and 8.47\(\%\) Jaccard, very close to the results obtained with 100\(\%\) labeled images (the second row). In addition, our method achieved the best performance among the state-of-the-art semi-supervised methods on all evaluation metrics, corroborating that our reciprocal learning strategy fully exploits the limited labeled data. The last two rows of Fig. 3 visualize the left atrium segmentation results of different semi-supervised methods in 3D. Our results were closest to the ground truths, preserved more details and produced fewer false positives, demonstrating the efficacy of the proposed reciprocal learning strategy.

We further conducted an ablation study to demonstrate the efficacy of the proposed reciprocal learning strategy. Specifically, we discarded the strategy by fixing the teacher model after pre-training. The results degraded to 73.82\(\%\)/86.82\(\%\) Dice, 59.38\(\%\)/77.27\(\%\) Jaccard, 4.62/3.69 ASD and 17.78/12.29 95 HD on the pancreas/left atrium datasets, which shows that reciprocal learning contributes substantially to the performance improvement.

4 Conclusion

This paper develops a novel reciprocal learning strategy for semi-supervised segmentation. Our key idea is to fully utilize the limited labeled data by updating the parameters of the teacher and student models in a reciprocal manner. Moreover, our strategy is simple and can be plugged directly into existing state-of-the-art network architectures to effectively enhance their performance. Experiments on two public datasets demonstrate the effectiveness, robustness and generalization of our proposed method. In addition, the proposed reciprocal learning strategy is a general solution with the potential to be applied to other image segmentation tasks.