
1 Introduction

Semantic segmentation is a prerequisite for a broad range of medical imaging applications, including disease diagnosis and treatment [13], surgical workflow analysis [6, 9], operating room planning, and surgical outcome prediction [7]. While supervised deep learning approaches have yielded satisfactory performance in semantic segmentation [8, 10], their performance is heavily tied to the distribution of the labeled training dataset. Indeed, a network trained on a dataset acquired with a specific device or configuration can dramatically underperform when evaluated on a different device or under different acquisition conditions. Overcoming this requires new annotations for every device, a demand that is hard to meet, especially for semantic segmentation, and even more so in the medical domain, where expert knowledge is essential.

Driven by the need to overcome this challenge, numerous semi-supervised learning paradigms have sought to alleviate annotation requirements in the target domain. Semi-supervised learning refers to methods that learn abstract representations from an unlabeled dataset and extend the decision boundaries towards a more general or target dataset distribution. These techniques can be categorized into (i) consistency regularization [4, 15–17, 19, 22], (ii) contrastive learning [2, 11], (iii) adversarial learning [22], and (iv) self-training [24–26]. Consistency regularization techniques inject knowledge by penalizing inconsistencies between predictions for identical images that have undergone different distortions, such as transformations or dropout, or that are fed into networks with different initializations [4]. Specifically, the \(\Uppi \) model [15] penalizes differences between the predictions of two transformed versions of each input image to reinforce consistent, augmentation-invariant predictions. Temporal ensembling [15] alleviates the negative effect of noisy predictions by aggregating predictions over consecutive training iterations. Cross-pseudo supervision [4] regularizes the networks by enforcing similar predictions from differently initialized networks.

More recently, deep self-training approaches based on pseudo-labels have emerged as promising techniques for unsupervised domain adaptation. These techniques assume that a trained network can approximate the ground-truth labels for unlabeled images. Since no metric guarantees pseudo-label reliability, several methods have been developed to mitigate the back-propagation of pseudo-label errors. To progressively improve pseudo-labeling performance, reciprocal learning [25] adopts a teacher-student framework in which the student network's performance on the source domain drives the updates of the teacher network's weights. ST++ [24] evaluates the reliability of image-level pseudo-labels based on the consistency of predictions across different network checkpoints. The more reliable half of the images is then used to re-train the network in a first step, and the re-trained network pseudo-labels the whole dataset for a second re-training step. Despite the effectiveness of state-of-the-art pseudo-labeling strategies, we argue that one important aspect has been underexplored: how can a trained network self-assess the reliability of its pixel-level predictions?

To this end, we propose a novel self-training framework with a self-assessment strategy for pseudo-label reliability. The proposed framework uses transformation-invariant, highly confident predictions in the target dataset for self-training. This is achieved by considering an ensemble of high-confidence predictions from transformed versions of identical inputs. To validate the effectiveness of our framework on a variety of tasks, we evaluate it on semantic segmentation in three different imaging modalities: video (cataract surgery), optical coherence tomography (retina), and MRI (prostate), as shown in Fig. 1. We perform comprehensive experiments to validate the performance of the proposed framework, named "Transformation-Invariant Self-Training" (TI-ST). The experimental results indicate that TI-ST significantly improves segmentation performance on unlabeled target datasets compared to numerous state-of-the-art alternatives.

Fig. 1. Example images from the three adopted datasets: (1) cross-device-and-center instrument segmentation in cataract surgery videos (Cat101 vs. CaDIS), (2) cross-device fluid segmentation in OCT (Spectralis vs. Topcon), and (3) cross-institution prostate segmentation in MRI (BMC vs. BIDMC).

2 Methodology

Consider a labeled source dataset \(\mathcal {S}\) with training images \(\mathcal {X_S}\) and corresponding segmentation labels \(\mathcal {Y_S}\), and an unlabeled target dataset \(\mathcal {T}\) containing only target images \(\mathcal {X_T}\). We aim to train a network using \(\mathcal {X_S}\), \(\mathcal {Y_S}\), and \(\mathcal {X_T}\) for semantic segmentation on the target dataset.

We propose to train the model in a self-supervised manner on the images \(\mathcal {X_T}\) by assigning pseudo-labels during training. Typical pseudo-labels are computed from independent predictions on unlabeled images. Instead, our framework adopts a self-assessment strategy that determines the reliability of predictions in an unsupervised fashion. Specifically, we target highly reliable predictions by requiring transformation-invariant confidence. In contrast to self-ensembling strategies that penalize diverging predictions for transformed versions of identical inputs, our goal is to filter out transformation-variant predictions. Indeed, our method reinforces the ensemble of high-confidence predictions from two versions of the same target sample. Our proposed TI-ST framework trains on the source and target domains simultaneously, so as to progressively bridge the inter-domain distribution gap. Figure 2 depicts our TI-ST framework, which we detail in the following sections.

Fig. 2. Overview of the proposed semi-supervised domain adaptation framework based on transformation-invariant self-training (TI-ST). Pseudo-labels ignored during the unsupervised loss computation are shown in turquoise.

2.1 Model

At training time, images from the source dataset are augmented using spatial \(g(\cdot )\) and non-spatial \(f(\cdot )\) transformations and passed through a segmentation network \(N(\cdot )\), which is trained on them using a standard supervised loss. At the same time, images from the target dataset are also passed to the network. Specifically, we feed two versions of each target image to the network: (1) the original target image \(x_\mathcal {T}\), and (2) its non-spatially transformed version, \(\acute{x_\mathcal {T}} = f(x_\mathcal {T})\). Once fed through the network, the corresponding predictions are defined as \(\tilde{y_{\mathcal {T}}} = \sigma (N(x_{\mathcal {T}}))\) and \(\tilde{\acute{y_{\mathcal {T}}}} = \sigma (N(\acute{x_{\mathcal {T}}}))\), where \(\sigma (\cdot )\) is the Softmax operation. We then define a confidence-mask ensemble as

$$\begin{aligned} \mathcal {M}_{cnf} = Cnf(\tilde{{y_{\mathcal {T}}}}) \odot Cnf(\tilde{\acute{y_{\mathcal {T}}}}), \end{aligned}$$
(1)

where \(\odot \) denotes the Hadamard (element-wise) product, and \(Cnf\) is the high-confidence masking function,

$$\begin{aligned} Cnf_{{\in \,(W\times H)}}(y) = {\left\{ \begin{array}{ll} 1, &{} \text {if } {{\,\textrm{max}\,}}_{{C}}(y) > \uptau \\ 0, &{} \text {else. } \end{array}\right. } \end{aligned}$$
(2)

where \(\uptau \in (0.5,1) \) is the confidence threshold, and H, W, and C are the height, width, and number of classes in the output, respectively. Specifically, \(\mathcal {M}_{cnf}\) encodes regions of confident predictions that are invariant to transformations. We can then compute the pseudo-ground-truth mask for each input from the target dataset as

$$\begin{aligned} \hat{\acute{y_{\mathcal {T}}}} = {\left\{ \begin{array}{ll} {{\,\textrm{argmax}\,}}_{\text {C}} (\tilde{\acute{y_{\mathcal {T}}}}), &{} \text {if } \,\mathcal {M}_{cnf} = 1\\ \text {ignore}, &{} \text {else. } \end{array}\right. } \end{aligned}$$
(3)
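To make the pseudo-labeling step concrete, the following PyTorch-style sketch implements Eqs. (1)–(3). It is a minimal illustration rather than the authors' implementation: the network `net`, the non-spatial transform `f_nonspatial`, and the ignore value are placeholders.

```python
import torch

IGNORE_INDEX = 255  # hypothetical label value excluded from the pseudo-supervised loss


def ti_st_pseudo_labels(net, x_t, f_nonspatial, tau=0.85):
    """Transformation-invariant pseudo-labeling, Eqs. (1)-(3).

    net          -- segmentation network N(.), returning B x C x H x W logits
    x_t          -- batch of target images x_T (B x 3 x H x W)
    f_nonspatial -- non-spatial transform f(.) (e.g., color jitter, blur)
    tau          -- confidence threshold in (0.5, 1)
    """
    with torch.no_grad():
        y_t = torch.softmax(net(x_t), dim=1)                     # prediction for x_T
        y_t_aug = torch.softmax(net(f_nonspatial(x_t)), dim=1)   # prediction for f(x_T)

    # Eq. (2): per-pixel high-confidence masks (max over classes exceeds tau)
    cnf = y_t.max(dim=1).values > tau
    cnf_aug = y_t_aug.max(dim=1).values > tau

    # Eq. (1): Hadamard (element-wise) product of the two confidence masks
    m_cnf = cnf & cnf_aug

    # Eq. (3): argmax pseudo-labels where M_cnf = 1, ignored elsewhere
    pseudo = y_t_aug.argmax(dim=1)
    pseudo[~m_cnf] = IGNORE_INDEX
    return pseudo
```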

2.2 Training

To train our model, we simultaneously consider both the source and target samples by minimizing the following loss,

$$\begin{aligned} \mathcal {L}_{overall} = \mathcal {L}_{Sup}( \tilde{y_\mathcal {S}}, y_\mathcal {S}) + \lambda \Big (\mathcal {L}_{Ps}(\tilde{\acute{y_{\mathcal {T}}}}, \hat{\acute{y_{\mathcal {T}}}})\Big ), \end{aligned}$$
(4)

where \(\mathcal {L}_{Sup}\) and \(\mathcal {L}_{Ps}\) denote the supervised and pseudo-supervised loss functions, respectively. We set \(\lambda \) as a time-dependent weighting function that gradually increases the share of the pseudo-supervised loss. Intuitively, our pseudo-supervised loss enforces the pseudo-labels on transformation-invariant, highly confident regions of unlabeled images.
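As an illustration of Eq. (4), the sketch below combines the two loss terms. The supervised term is written as plain cross-entropy for brevity (the experiments use a cross-entropy log-Dice loss, see Sect. 3), and the ignore index matches the placeholder used in the sketch above.

```python
import torch.nn.functional as F


def overall_loss(logits_s, y_s, logits_t_aug, pseudo_t, lam, ignore_index=255):
    """Sketch of Eq. (4): supervised loss on the source batch plus the
    pseudo-supervised loss on the transformed target batch, weighted by lambda."""
    loss_sup = F.cross_entropy(logits_s, y_s)                    # L_Sup on source
    loss_ps = F.cross_entropy(logits_t_aug, pseudo_t,            # L_Ps on target
                              ignore_index=ignore_index)
    return loss_sup + lam * loss_ps
```

In a training step, `pseudo_t` would come from the pseudo-labeling function sketched above, while `logits_t_aug` is a fresh, gradient-tracked forward pass of the network on the same transformed target images.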

Discussion: The quantity and distribution of supervised data are determining factors in a neural network's performance. With large-scale, broadly distributed supervisory data, neural networks converge to an optimal state efficiently. However, when only limited supervisory data are available whose distribution differs from that of the inference dataset, more sophisticated methods that leverage a priori knowledge become essential. The invariance of network predictions with respect to data augmentation is a strong form of such knowledge and can be learned through dataset-dependent augmentations. A trained network is then expected to provide consistent predictions under diverse transformations. Hence, transformation variance of the network's predictions indicates doubt and low confidence. We take advantage of this characteristic to assess the reliability of predictions and filter out unreliable pseudo-labels.

3 Experimental Setup

Datasets: We validate our approach on three cross-device/site datasets for three different modalities:

  • Cataract: instrument segmentation in cataract surgery videos [12, 21]. We use "Cat101" [21] as the source dataset and "CaDIS" [12] as the target dataset.

  • OCT: IRF Fluid segmentation in retinal OCTs [1]. We use the high-quality “Spectralis” dataset as the source and the lower-quality “Topcon” dataset as the target domain.

  • MRI: multi-site prostate segmentation [18]. We sample volumes from “BMC” and “BIDMC” as the source and target domain, respectively.

We follow a four-fold validation strategy for all three cases and report the average results over all folds. The average numbers of labeled training images (from the source domain), unlabeled training images (from the target domain), and test images per fold are (207, 3189, 58) for the Cataract, (391, 569, 115) for the OCT, and (273, 195, 65) for the MRI dataset.

Baseline Methods: We compare the performance of our proposed transformation-invariant self-training (TI-ST) method against seven state-of-the-art semi-supervised learning methods: the \(\Uppi \) model [15], temporal ensembling [15], mean teacher [19], cross pseudo supervision (CPS) [4], reciprocal learning (RL) [25], self-training (ST) [24], and the mutual correction framework (MCF) [23].

Networks and Training Settings: We evaluate our TI-ST framework using two different architectures: (1) DeepLabV3+ [3] with a ResNet50 backbone [14] and (2) scSE [20] with a VGG16 backbone. Both backbones are initialized with ImageNet [5] pre-trained parameters. We use a batch size of four for the Cataract and MRI datasets and a batch size of two for the OCT dataset. For all training strategies, we set the number of epochs to 100. The initial learning rate is set to 0.001 and decayed by a factor of \(\gamma = 0.8\) every two epochs. The input size of the networks is \(512\times 512\) for the Cataract and OCT datasets and \(384\times 384\) for the MRI dataset. As spatial transformations \(g(\cdot )\), we apply cropping and random rotation (up to 30 degrees). The non-spatial transformations \(f(\cdot )\) include color jittering (brightness = 0.7, contrast = 0.7, saturation = 0.7), Gaussian blurring, and random sharpening. The confidence threshold \(\uptau \) for the self-training framework and the proposed TI-ST framework is set to 0.85, except in the ablation studies (see the next section). In Eq. (4), the weighting function \(\lambda \) ramps up from the first epoch along a Gaussian curve, \(\exp [-5(1-\text {current-epoch}/{\text {total-epochs}})^2]\). The pseudo-supervised loss is the cross-entropy loss, and the supervised loss is the cross-entropy log-Dice loss, a weighted sum of the cross-entropy and the logarithm of the soft Dice coefficient. For the TI-ST framework, we only use non-spatial transformations in the self-training branch for simplicity.
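A possible realization of the non-spatial transformations \(f(\cdot )\) and the ramp-up weight described above is sketched below; the exact composition, blur kernel size, and sharpening factor are assumptions, and the ramp-up follows the Gaussian form used in temporal ensembling [15].

```python
import math
import torchvision.transforms as T

# Non-spatial transformations f(.): color jittering, Gaussian blurring, sharpening.
f_nonspatial = T.Compose([
    T.ColorJitter(brightness=0.7, contrast=0.7, saturation=0.7),
    T.GaussianBlur(kernel_size=5),                   # kernel size is an assumption
    T.RandomAdjustSharpness(sharpness_factor=2.0),   # factor is an assumption
])


def ramp_up_weight(epoch, total_epochs):
    """Gaussian ramp-up of the unsupervised loss weight lambda in Eq. (4)."""
    t = min(epoch / total_epochs, 1.0)
    return math.exp(-5.0 * (1.0 - t) ** 2)
```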

Table 1. Quantitative comparisons in Dice score (%) between the proposed method (TI-ST) and alternative methods for DeepLabV3+ [3] (DLV3+) and scSENet [20] on the three datasets. The relative Dice is computed with respect to the supervised baseline. The best results are highlighted.
Fig. 3. Ablation studies on the pseudo-labeling threshold and the size of the labeled dataset.

4 Results

Table 1 compares the performance of our transformation-invariant self-training (TI-ST) approach with alternative methods across the three tasks and two network architectures. According to the quantitative results, TI-ST, RL, ST, and CPS are the best-performing methods. Among these, our proposed TI-ST achieves the highest average relative improvement in Dice score over naive supervised learning (\(16.18\%\) on average). Compared to our main competitor (RL), our proposed TI-ST method is a one-stage framework using a single network, whereas RL is a two-stage framework (requiring a pre-training stage) and uses a teacher-student network pair. Hence, TI-ST is also more efficient than RL in terms of time and computation. Furthermore, the proposed strategy demonstrates the most consistent results when evaluated on different tasks, regardless of the neural network architecture used.

Fig. 4. Ablation study on the performance stability of TI-ST vs. ST across the different experimental segmentation tasks.

Fig. 5. Qualitative comparisons between the performance of TI-ST and four existing methods.

Figure 3(a–b) demonstrates the effect of the pseudo-labeling threshold on TI-ST performance compared with regular ST. We observe that filtering out unreliable pseudo-labels based on transformation variance remarkably boosts pseudo-supervision performance regardless of the threshold. Figure 3(c) compares the performance of the supervised baseline, ST, and TI-ST with respect to the number of source-domain labeled training images. While ST performance saturates as the number of labeled images increases, our TI-ST pushes the decision boundaries towards the target-domain dataset by avoiding training on transformation-variant pseudo-labels. We validate the stability of TI-ST vs. ST with different pseudo-labeling thresholds (0.80 and 0.85) over four training folds in Fig. 4, where TI-ST achieves a higher average improvement relative to supervised learning across tasks and network architectures. This analysis also shows that the performance of ST is sensitive to the pseudo-labeling threshold and generally degrades as the threshold is reduced, since lower thresholds admit more wrong pseudo-labels. In contrast, TI-ST effectively ignores false predictions at lower thresholds and takes advantage of a larger number of correct pseudo-labels. This superior performance is depicted qualitatively in Fig. 5.

5 Conclusion

We proposed a novel self-training framework with a self-assessment strategy for pseudo-label reliability, named "Transformation-Invariant Self-Training" (TI-ST). This method uses transformation-invariant, highly confident predictions in the target dataset by considering an ensemble of high-confidence predictions from transformed versions of identical inputs. We experimentally show the effectiveness of our approach against numerous existing methods across three different source-to-target segmentation tasks and with different model architectures. Beyond this, we show that our approach is resilient to changes in the method's hyperparameters, making it well-suited for different applications.