
1 Introduction

In the presence of a large training dataset that covers all possible data variations, deep neural networks (DNNs) can achieve super-human performance in image recognition and semantic segmentation tasks. However, in medical image segmentation, large annotated training datasets are often scarce. In addition, training and test data are frequently drawn from different distributions; for example, the images may have been acquired with different scanners at different sites, or the demographics of the subjects may differ. This violation of the i.i.d. assumption (i.e., that training and test data are drawn independently from the same distribution) typically causes the performance on the test data to be significantly worse than on the training data.

Domain adaptation (DA) approaches try to alleviate the problem of applying models in new domains with different characteristics. In particular, semi-supervised DA (SSL-DA) methods provide a way to learn structure from unlabeled data in new domains. Among the SSL-DA methods proposed so far, the most popular is adversarial training based domain adaptation (ADA), which relies on generating features that are invariant with respect to a domain discriminator. ADA requires extensive parameter optimization because it depends on a robust discriminator. Moreover, a recent study pointed out flaws in the evaluation of SSL-DA methods [1].

In this paper, as an alternative to ADA, we evaluate a modified knowledge distillation (KD) [2, 3] method for generalizing DNNs to new domains on a common clinical problem. The datasets chosen for evaluation differ not only in the magnetic resonance (MR) scanners used to acquire the images, but also in the demographic makeup of the subjects. Through our evaluation, we show that the proposed KD approach generally achieves better Dice scores than the baseline and ADA when segmenting white matter hyperintensities (WMH) in datasets that are not part of the training data and do not share any attributes with it.

2 Related Work

Among the recent works on DA, several methods rely on a small amount of annotated target data to fine-tune a baseline model [4, 5]. The performance of this approach depends not only on a new, albeit small, set of annotations but also on the choice of that set. In contrast, SSL-DA methods do not use annotations from the new target domain. Adversarial training is a popular SSL-DA method [6,7,8]. Here, networks are trained so that the generated features are agnostic to the data domain with respect to a domain discriminator. A similar solution, ADA, was employed by [9] to adapt networks to be agnostic to domain changes.

Another class of DA methods uses KD to transfer representations between data domains. For instance, [10] proposed using KD to transfer knowledge between different modalities of the same scene. Closely related to our work is [11], where the authors propose omni-supervised learning (OSL) to include unlabeled data in the learning process. There, data distillation uses a teacher model to ensemble the predictions on multiple transformations of unlabeled data and thereby generate new training annotations. The proposed method differs from OSL in two respects: (a) only soft labels are used to train a single student network, with the aim of improving segmentation by learning label similarities from unannotated data; and (b) in contrast to OSL, the data used to train the student include small amounts of data from new domains.

3 Methods

In SSL-DA methods, we assume that the source domain images and their annotations, \((x_s, y_s) \in \mathbf {X}_s\), are drawn from a distribution \(p_s(x_s,y)\). The target domain images \(x_t \in \mathbf {X}_t\) are drawn from a distribution \(p_t(x_t,y)\), for which no annotations are available. We consider classification into K classes. In an ideal scenario, where \(p_s\) and \(p_t\) are sufficiently similar, the goal is to find a feature representation mapping f that maps an input to K scores, where the \(i^{\text {th}}\) score models (up to a constant) the logarithm of the probability that the input belongs to class i. These scores can then be mapped by \(\sigma : \mathbb {R}^{K} \rightarrow \mathbb {R}^{K}\) to probability maps over the classes. SSL-DA first finds a function \(f_s\) that performs well on the source domain and then, based on \(f_s\), finds a new function \(f_t\) that performs well on the target domain. By contrast, vanilla supervised learning would require annotations from both \(\mathbf {X}_s\) and \(\mathbf {X}_t\).

In the popular ADA method, the goal is to minimize the distance between the empirical distributions of \(p_s(f_s(\mathbf {X}_s)|y)\) and \(p_t(f_t(\mathbf {X}_t)|y)\). Here, a discriminator D is a neural network trained to distinguish between the two domains, so it acts as a discrepancy measure; training the feature extractor to fool D brings the two distributions together. Overall, adversarial training involves training a network in a standard supervised manner while requiring that the features f it generates be indistinguishable across domains by the discriminator [6, 9].
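For readers less familiar with adversarial DA, the sketch below illustrates the two competing losses in their common GAN-style form. This generic formulation is an assumption made for illustration and is not necessarily the exact objective of [9] or of our ADA baseline. The discriminator D is taken to output the probability that a feature map comes from the source domain; `bce`, `discriminator_loss`, `confusion_loss`, `generator_loss` and the weight `lam` are hypothetical names.

```python
import math
from typing import Sequence

def bce(p: float, label: int, eps: float = 1e-12) -> float:
    """Binary cross-entropy of a predicted probability p against a 0/1 label."""
    p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(label * math.log(p) + (1 - label) * math.log(1.0 - p))

def discriminator_loss(d_source: Sequence[float], d_target: Sequence[float]) -> float:
    """D is trained to output 1 for source-domain features and 0 for target-domain features."""
    losses = [bce(p, 1) for p in d_source] + [bce(p, 0) for p in d_target]
    return sum(losses) / len(losses)

def confusion_loss(d_target: Sequence[float]) -> float:
    """The feature extractor is rewarded when D labels target features as source (1)."""
    return sum(bce(p, 1) for p in d_target) / len(d_target)

def generator_loss(seg_loss: float, d_target: Sequence[float], lam: float = 0.1) -> float:
    """Supervised segmentation loss on annotated source data plus the weighted
    adversarial term; lam is a hypothetical trade-off weight."""
    return seg_loss + lam * confusion_loss(d_target)
```

In practice the two losses are minimized alternately: D on `discriminator_loss`, the segmentation network on `generator_loss`. Extra hyper-parameters such as `lam` and the discriminator architecture are part of the tuning burden mentioned in the introduction.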

3.1 Knowledge Distillation for Domain Adaptation

KD [2] was originally intended to compress neural networks with a large number of parameters into networks of lower complexity. The objective is to teach a simpler student network to imitate a more complex, trained teacher network through a loss function called the distillation loss. To perform unsupervised domain adaptation, we propose to use this teacher/student learning strategy. Specifically, the data from the source domain are used to train a teacher model in a supervised fashion. The trained teacher is then used to generate posterior probability maps, or soft labels, on the union of source and target data. These posterior probabilities are used instead of the usual hard labels to train the student, or target, model. Note that this approach can take advantage of large amounts of unlabeled data acquired from any number of domains. An attractive feature of the distillation loss is that it replaces one-hot encoded label vectors with soft representations, which allows the student to be optimized over a smoother optimization landscape. Moreover, the smooth representation of labels also allows the learning of label similarities, which is particularly useful for learning boundaries in semantic segmentation tasks. The proposed semi-supervised learning method is formulated below.

Training the Teacher or Source Domain Model: Consider a set of N manually annotated images from a source domain \(\mathbf {X}_s = \{ (x_i,y_i), i=1 \dots N \}\), where \(x_i \in \mathbb {R}^d\) represents a d-dimensional MR scan with \(v=1 \dots V\) voxels, and \(y_i \in [0,1]^K\) with \(\Vert y_i\Vert _1=1\) is its corresponding label. Assuming a set \(F_s\) of functions \(f:\mathbb {R}^d \rightarrow \mathbb {R}^K\), we aim to learn a feature representation \(f_s\) (the teacher model) by minimizing a loss function l according to Eq. (1):

$$\begin{aligned} f_s = \mathop {{{\,\mathrm{arg\,min}\,}}}\limits _{f\in F_s}&\frac{1}{N} \sum _{(x_i,y_i)\in \mathbf {X}_s}l( y_i, \sigma (f(x_i))) \end{aligned}$$
(1)
$$\begin{aligned} {[\sigma (z)]_{k}}&= \frac{\mathrm {e}^{[z]_{k}}}{\sum _{l=1}^{K} \mathrm {e}^{[z]_l}} \end{aligned}$$
(2)

As in standard supervised learning, the teacher network is optimized using the cross-entropy loss (or any other differentiable loss function).
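Eqs. (1) and (2) can be made concrete at the level of a single voxel. The sketch below is a minimal pure-Python illustration rather than our training code: it implements the softmax \(\sigma\) and the cross-entropy loss l, and averages l over a batch as in Eq. (1). Treating each sample as one voxel with K = 2 classes (background, WMH) is an assumption made for brevity, and the function names are hypothetical.

```python
import math
from typing import List, Sequence

def softmax(z: Sequence[float]) -> List[float]:
    """Eq. (2). Subtracting max(z) leaves the output unchanged, because the softmax is
    invariant to adding a constant to every score, but prevents overflow in exp."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(y: Sequence[float], p: Sequence[float], eps: float = 1e-12) -> float:
    """l(y, p) = -sum_k y_k log p_k for a one-hot (or soft) label vector y."""
    return -sum(yk * math.log(max(pk, eps)) for yk, pk in zip(y, p))

def teacher_objective(labels: Sequence[Sequence[float]],
                      logits: Sequence[Sequence[float]]) -> float:
    """Eq. (1): mean of l(y_i, sigma(f(x_i))); logits[i] stands in for f(x_i) at one voxel."""
    losses = [cross_entropy(y, softmax(z)) for y, z in zip(labels, logits)]
    return sum(losses) / len(losses)
```

The max-subtraction in `softmax` relies on the same invariance that lets the scores model log-probabilities only up to a constant.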

Training the Student or Target Model: Even though \(f_s\) is suitable for segmenting images from the source domain \(\mathbf {X}_s\), it may not be suitable for data coming from a different distribution \(\mathbf {X}_t\). Our goal is to find a function \(f_t \in F_t\) that is suitable for segmenting data from \(\mathbf {X}_t\). Assuming we have access to a limited set of unlabeled scans from the target domain, \(\mathbf {X}_t=\{x_i, i= 1 \dots M\}\), we can then create a set

$$\begin{aligned} \mathbf {X}_U = \{(x_i, y_i) \,|\, x_i \in \mathbf {X}_s, y_i=f_s(x_i), 1\le i \le N \} \cup \\ \{(x_i, y_i) \,|\, x_i \in \mathbf {X}_t, y_i=f_s(x_i), 1\le i \le M \} \end{aligned}$$

that may be used to optimize a student using the distillation loss. Through the soft representations in this union dataset, the student is expected to learn a better mapping to the labels than the teacher network. When training the student network, we use probability distributions over the labels as targets, not single classes; this representation reflects the uncertainty of the teacher network's predictions. The function \(f_t\) is found by (approximately) solving

$$\begin{aligned} f_t = \mathop {{{\,\mathrm{arg\,min}\,}}}_{f\in F_t} \frac{1}{(N+M)} \sum _{x_i\in \mathbf {X}_U}l( \sigma (T^{-1}f_s(x_i)), \sigma (f(x_i))), \end{aligned}$$
(3)

Here, \(T > 1\) is the temperature parameter which controls the softness of the class probability prediction given by \(f_s\).
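The construction of \(\mathbf {X}_U\) and the objective in Eq. (3) can be sketched per voxel as follows. This is a minimal pure-Python illustration under stated assumptions, not our implementation: the loss l is assumed to be cross-entropy (as for the teacher), T = 2 as in Sect. 4.2, scans are reduced to one-voxel toy inputs, and `build_union_set`, `distillation_loss`, `student_objective` and the toy networks are hypothetical names. The teacher scores \(f_s(x_i)\) are stored unnormalized in \(\mathbf {X}_U\), so T is applied only when Eq. (3) is evaluated.

```python
import math
from typing import Callable, List, Sequence, Tuple

Scan = Sequence[float]  # hypothetical stand-in for one MR scan x_i
Scores = List[float]    # K scores, e.g. the teacher output f_s(x_i)

def softmax(z: Scores, temperature: float = 1.0) -> List[float]:
    """sigma(z / T); temperature = 1 recovers Eq. (2)."""
    scaled = [v / temperature for v in z]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def build_union_set(teacher: Callable[[Scan], Scores],
                    source_scans: Sequence[Scan],
                    target_scans: Sequence[Scan]) -> List[Tuple[Scan, Scores]]:
    """X_U: the N source and M target scans, each paired with f_s(x_i).
    Manual source annotations are not used."""
    return [(x, teacher(x)) for x in list(source_scans) + list(target_scans)]

def distillation_loss(teacher_scores: Scores, student_scores: Scores,
                      temperature: float = 2.0, eps: float = 1e-12) -> float:
    """One term of Eq. (3), with cross-entropy assumed as l. As in Eq. (3),
    only the teacher scores are divided by the temperature T."""
    target = softmax(teacher_scores, temperature)
    pred = softmax(student_scores)
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, pred))

def student_objective(union_set: List[Tuple[Scan, Scores]],
                      student: Callable[[Scan], Scores],
                      temperature: float = 2.0) -> float:
    """Average of Eq. (3) over X_U, i.e. the 1/(N+M) sum."""
    terms = [distillation_loss(s, student(x), temperature) for x, s in union_set]
    return sum(terms) / len(terms)

# Example with a hypothetical two-class teacher and student on one-voxel "scans".
def toy_teacher(x: Scan) -> Scores:
    return [0.0, 4.0 * x[0]]

def toy_student(x: Scan) -> Scores:
    return [0.0, 2.0 * x[0]]

X_U = build_union_set(toy_teacher, source_scans=[[1.0], [0.5]], target_scans=[[0.25]])
loss = student_objective(X_U, toy_student)
```

With teacher scores (0, 4), the target posterior is about (0.02, 0.98) at T = 1 but about (0.12, 0.88) at T = 2, so the student is also trained on how much probability the teacher places on the competing class. With cross-entropy as l, each term is smallest when the student posterior equals the tempered teacher posterior; `toy_student`, whose scores are exactly the teacher's divided by T, therefore attains the minimum of every term.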

4 Experiments and Results

4.1 Databases

The WMH segmentation challenge (https://wmh.isi.uu.nl/) dataset is a public database that contains T1-weighted and FLAIR scans of 60 subjects from three different clinics, together with manual annotations of WMH of presumed vascular origin. The T1-weighted images were registered to the FLAIR images, since the annotations were performed in FLAIR space. The images were also corrected for bias field inhomogeneities using SPM12. An important feature of this dataset is that the scanners and demographics vary across clinics, as shown in Table 1.

Table 1. Summary of data characteristics in the WMH challenge database

4.2 Experimental Setup

One of the main objectives of this paper is to use semi-supervised learning to perform domain adaptation. We use the WMH challenge dataset to perform cross-clinic experiments on segmenting WMH in FLAIR images. We consider several scenarios, described below, to establish the performance of ADA and KD. Throughout, the Dice overlap is used to evaluate the performance of the algorithms.
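For reference, the Dice overlap between a predicted mask P and a manual annotation G is 2|P ∩ G| / (|P| + |G|). A minimal sketch on flattened binary masks follows; the function name is hypothetical, and returning 1 when both masks are empty is an assumed convention.

```python
from typing import Sequence

def dice(pred: Sequence[int], truth: Sequence[int]) -> float:
    """Dice overlap 2|P ∩ G| / (|P| + |G|) between two binary masks given as
    flat sequences of 0/1 voxel labels (1 = WMH)."""
    if len(pred) != len(truth):
        raise ValueError("masks must have the same number of voxels")
    intersection = sum(1 for p, g in zip(pred, truth) if p and g)
    total = sum(1 for p in pred if p) + sum(1 for g in truth if g)
    if total == 0:
        return 1.0  # assumed convention: two empty masks agree perfectly
    return 2.0 * intersection / total
```

Per-subject scores computed this way can then be averaged over the test subjects to obtain mean Dice overlaps such as those reported in Tables 2 and 3.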

  • Lower bound baseline, L-bound: Here, a baseline DNN model is trained on the source domain images, henceforth referred to as S, to establish a lower bound on performance. It is tested on 20 subjects from a target dataset T.

  • Upper bound baseline, U-bound: Here, a baseline DNN model is trained as in L-bound; however, the training dataset is the union of the images from S and a subset of T (10 subjects, with annotations). The network is evaluated on the remaining 10 subjects in T.

  • Adversarial domain adaptation, ADA: Following [9], we train a DNN model to be invariant to data domains. In contrast to what was proposed in [9], and to be consistent with KD, we train the domain discriminator on the final layer of the baseline. We use a discriminator composed of 4 convolutional layers with 8, 16, 32 and 64 filters, followed by 3 fully connected layers with 64, 128 and 2 neurons. As in U-bound, the training dataset is the union of the images from S and a subset of T (10 subjects, here without annotations). The network is evaluated on the remaining 10 subjects in T.

  • Knowledge distillation, KD: The experimental setup for KD is the same as for ADA. A temperature of 2 is used in the softmax for the distillation loss. The student network is identical to the teacher network, whose architecture is a standard UNet (as in L-bound, U-bound and ADA). It is optimized with the Adam optimizer at a learning rate of \(10^{-4}\), which is gradually decreased after epoch 150, and is trained for 400 epochs.

  • Adaptation on-the-fly: A clinically relevant scenario is adapting to a small set of test images on the fly while keeping the teacher/baseline model fixed. To validate this scenario, we apply ADA and KD to the same 10 unannotated subjects from T that are used for training in the scenarios above, but subject-wise. In other words, a separate adaptation is performed for each subject of T instead of including all of them together.
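The data splits behind these scenarios can be sketched as follows. The subject identifiers and function names are hypothetical, and two details are assumptions rather than facts from the text: that the two folds over which results are averaged swap the adaptation and test halves of the 20 target subjects, and that every on-the-fly run keeps the source scans S alongside its single target subject.

```python
from typing import Dict, List, Sequence

def two_fold_splits(target_subjects: Sequence[str]) -> List[Dict[str, List[str]]]:
    """Halve the target subjects: one half is used for adaptation (annotated for
    U-bound, unannotated for ADA and KD), the other half for testing, then swap."""
    half = len(target_subjects) // 2
    first, second = list(target_subjects[:half]), list(target_subjects[half:])
    return [{"adapt": first, "test": second}, {"adapt": second, "test": first}]

def pooled_training_sets(source: Sequence[str], adapt: Sequence[str]) -> List[List[str]]:
    """Classical scenario: a single adaptation run on S plus all adaptation subjects."""
    return [list(source) + list(adapt)]

def on_the_fly_training_sets(source: Sequence[str], adapt: Sequence[str]) -> List[List[str]]:
    """Adaptation on the fly: the teacher/baseline stays fixed and one separate
    adaptation run is performed per target subject (S assumed to be kept in each run)."""
    return [list(source) + [subject] for subject in adapt]
```

For KD, each of these training sets would be labeled by the fixed teacher, as in the construction of \(\mathbf {X}_U\), before a student is trained on it.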

Table 2. Dice overlaps (with variance). Bold font indicates statistical significance at the \(5\%\) level (p-values computed with a paired-sample t-test; \(0.0002< p < 0.02\)). Only the ADA and KD methods are considered in the statistical comparison.
Table 3. Mean Dice overlaps in the adaptation-on-the-fly scenario. Bold font indicates statistical significance at the \(5\%\) level (p-values computed with a paired-sample t-test; \(0.0003< p < 0.04\)). Only the ADA and KD methods are considered in the statistical comparison.
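The paired-sample t-test used for Tables 2 and 3 compares the per-subject Dice scores of ADA and KD on the same test subjects. A minimal standard-library sketch of its test statistic follows (`paired_t_statistic` is a hypothetical name). The p-value additionally requires the CDF of Student's t distribution with n - 1 degrees of freedom; `scipy.stats.ttest_rel` returns the same statistic together with that p-value.

```python
import math
import statistics
from typing import Sequence

def paired_t_statistic(scores_a: Sequence[float], scores_b: Sequence[float]) -> float:
    """Paired-sample t statistic for per-subject scores of two methods on the same
    subjects (e.g. KD vs. ADA): t = mean(d) / (sd(d) / sqrt(n)), d_i = a_i - b_i,
    with n - 1 degrees of freedom."""
    if len(scores_a) != len(scores_b) or len(scores_a) < 2:
        raise ValueError("need two equally long score lists with at least two subjects")
    diffs = [a - b for a, b in zip(scores_a, scores_b)]
    sd = statistics.stdev(diffs)  # sample standard deviation (n - 1 denominator)
    if sd == 0:
        raise ValueError("differences are constant; t is undefined")
    return statistics.mean(diffs) / (sd / math.sqrt(len(diffs)))
```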

4.3 Results

Various combinations of mismatched (in terms of clinics) training and test data were used. For instance, if the training data are from clinic 1 (Utrecht), the test data are from either clinic 2 (Singapore) or clinic 3 (Amsterdam). We did not test on two different clinics simultaneously, even though this scenario is practical. Table 2 reports mean Dice coefficients (over two folds) for each of the scenarios described in Sect. 4.2, except for adaptation on the fly, which is reported in Table 3. KD outperformed ADA in nearly all scenarios, the exceptions being domain adaptation from the Singapore clinic to the Utrecht clinic and vice versa. For domain adaptation from the Utrecht clinic to the Singapore clinic, ADA was significantly better than KD. In the reverse direction, KD achieved a better mean, but the difference is not statistically significant. In all other scenarios, KD yielded significantly better Dice overlaps than ADA. Note that statistical comparisons are made only between ADA and KD.

In the adaptation-on-the-fly scenario, KD yields significantly better Dice overlaps in the majority of scenarios. The superior performance of ADA persists for domain adaptation from the Utrecht clinic to the Singapore clinic; in the reverse direction, however, KD performs better than ADA.

To illustrate the differences between the KD and ADA segmentations, we plot the segmentations for the Utrecht-to-Amsterdam scenario in Fig. 1. Both methods segment relatively large lesions well; the main difference is evident in segmenting smaller lesions, especially in the deep white matter. Interestingly, the adaptation-on-the-fly and classical scenarios yield nearly the same Dice overlaps, indicating good generalisability and little dependence on the choice of the small dataset from the target domain.

Fig. 1.

Segmentations obtained with different methods trained on the Utrecht dataset and tested on the Amsterdam dataset. The top and bottom rows show segmentations of two different subjects.

5 Discussion

The main objective of this paper was to present domain adaptation from a semi-supervised learning perspective. We have evaluated a modified knowledge distillation approach and compared it to the popular adversarial approach under different clinical scenarios. Overall, the knowledge distillation approach gave better results and is simpler to design than the more architecture-dependent adversarial approaches, which require extensive tuning of the DNN architectures, especially the discriminator, to achieve reasonable performance. In contrast, KD only involves choosing the temperature parameter, which can be selected based solely on performance on the source domain. An interesting outcome is the inferior performance of KD in the Utrecht-to-Singapore scenario. This may be attributed not only to scanner differences but also to differences in demographics, which may have led to an inferior teacher, on which the student network relies. To verify this, we used the network improved by ADA-based domain adaptation as the teacher and trained a student based on it. The mean Dice overlap improved from \(0.65\) to \(0.69\).

In future work, we will consider combining the adversarial approaches with knowledge distillation to improve the generalisability of DNNs across domains without the need for large annotated datasets.