1 Introduction

Many machine learning methods assume that the data of the source domain (training dataset) and the target domain (testing dataset) are independent and identically distributed. When the source and target distributions differ, domain adaptation for cross-domain classification becomes an effective solution [9]. Common approaches to cross-domain classification exploit similarity [19], the global data structure [16, 31] and feature alignment [18]. In addition, discrepancy measures such as maximum mean discrepancy (MMD) [27] and multiple kernel MMD [15] are widely applied in domain adaptation and have inspired many recognition models [6, 8, 28, 35]. Fine-tuning pre-trained models [20, 23] and domain adversarial neural networks [5, 32] are two further effective techniques for cross-domain classification. Domain adaptation is therefore critical for improving the generalization ability of neural networks [7].

When target-domain labels are unavailable in cross-domain classification, unsupervised domain adaptation is particularly important [2]. For instance, CycleGAN [33] and its variants [14, 17, 29] can translate images from the target domain to the source domain [13, 34] so that the two distributions become closer for classification [24]. However, several issues still need further consideration. (i) Class-imbalanced datasets [11]: class imbalance is common in medical datasets and may cause over-fitting during training, which degrades performance. (ii) Lack of labeled medical images [13]: this limits the use of supervised learning. Fortunately, transfer learning [7, 15] and unsupervised learning [2, 18, 26] can be employed to deal with these challenges. (iii) Separate training of different modules: current methods train the knowledge transfer model and the classification model separately [4]. That is, the training process is divided into two independent stages, which ignores knowledge transfer and information interaction between the modules.

To address the above issues, we propose a coherent cooperative learning (CCL) framework based on transfer learning for unsupervised cross-domain classification. CCL consists of a proposed Wasserstein CycleGAN (WCycleGAN) for image translation and two classifiers for prediction. First, by training WCycleGAN with the original images from the source and target domains, we obtain a class-balanced dataset that is used to fine-tune the two classifiers, which are convolutional neural networks (CNNs) pre-trained on ImageNet. In particular, the target-domain classifier uses the proposed cooperative mechanism and the MMD criterion to achieve unsupervised cross-domain classification. Finally, we feed both the probe image and its translated image generated by WCycleGAN into the two classifiers and obtain the final prediction by a voting strategy.

The proposed method makes three contributions. (i) WCycleGAN and the two classifiers are iteratively updated in a unified process. (ii) The proposed collaborative training makes the different modules complement each other. (iii) Knowledge transfer occurs in three ways in CCL: image translation by WCycleGAN, transfer learning by fine-tuning CNNs pre-trained on ImageNet to identify pneumonia, and parameter passing during fine-tuning.

2 Method

The proposed CCL consists of a WCycleGAN G and two classifiers \( C_{t} \) and \( C_{s} \), as shown in Fig. 1. The details are introduced in the following sections.

Fig. 1. The training (a) and testing (b) diagrams of the proposed method.

2.1 Wasserstein CycleGAN in the Proposed Method

WCycleGAN is designed for data augmentation and domain translation. Inspired by Wasserstein GAN [1], Wasserstein GAN with gradient penalty [10] and CycleGAN [33], we develop the proposed WCycleGAN. The inputs from the source domain and the target domain are denoted by the labeled set \(\mathcal {X} = \{x_{i}\}_{i=1}^{n}\) and the unlabeled set \( \mathcal {Y} = \{y_{i}\}_{i=1}^{n} \), respectively, where n is the number of images in a batch.

To address class imbalance, we use G to construct a strictly class-balanced dataset \( \mathcal {X}_{B} = \{x_{i}\}_{i=1}^{m} \), where m is the total number of images and every class contains the same number of images. As shown in Fig. 1 (a), to construct \( \mathcal {X}_{B} \), we first obtain a balanced dataset \( \overline{\mathcal {X}}\) by weighted sampling on \( \mathcal {X}\). We then supplement \( \overline{\mathcal {X}} \) with intermediate images, which G generates while computing its losses. As a result, some images may be sampled more than once for \( \mathcal {X}_{B} \). Next, we describe how the intermediate images are obtained during the training of G.
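The weighted-sampling step that produces \( \overline{\mathcal {X}} \) can be illustrated with a short sketch. It assumes inverse-class-frequency weights drawn with replacement (the same idea as PyTorch's `WeightedRandomSampler`); the function names are hypothetical, and supplementing the result with intermediate images from G is not shown.

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    """Weight each sample by 1 / (size of its class), so every class has equal total weight."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


def weighted_balanced_indices(labels, num_samples, seed=0):
    """Draw sample indices with replacement, proportional to inverse class frequency.

    In expectation every class is drawn equally often; because sampling is
    with replacement, the same image can be selected several times.
    """
    rng = random.Random(seed)
    weights = inverse_frequency_weights(labels)
    return rng.choices(range(len(labels)), weights=weights, k=num_samples)


# Example: a 90/10 imbalanced label list becomes roughly 50/50 after sampling.
labels = [0] * 90 + [1] * 10
idx = weighted_balanced_indices(labels, num_samples=10_000)
share_of_minority = sum(labels[i] for i in idx) / len(idx)
```

Sampling with replacement is what makes exact duplicates possible, which matches the remark above that some images may appear more than once in \( \mathcal {X}_{B} \).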

Let \( P_{\mathcal {X}} \) and \( P_{\mathcal {Y}} \) denote the data distributions of the source domain and the target domain, respectively. WCycleGAN has four components: two discriminators \( D_{s} \) and \( D_{t} \) that distinguish the domain of images, and two generators that translate images, i.e., \( G_{s2t} \): \( \mathcal {X} \rightarrow \mathcal {Y} \) and \( G_{t2s} \): \( \mathcal {Y} \rightarrow \mathcal {X} \).

The adversarial loss constrains the generators and discriminators:

$$\begin{aligned} \mathcal {L}_{adv} = \mathop {\mathbb {E}}\limits _{x \sim P_{\mathcal {X}}} [-D_{t}(G_{s2t}(x))] + \mathop {\mathbb {E}}\limits _{y \sim P_{\mathcal {Y}}} [-D_{s}(G_{t2s}(y))] \end{aligned}$$
(1)

To make \( G_{s2t} \) and \( G_{t2s} \) cycle-consistent with each other, we compute the cycle consistency loss:

$$\begin{aligned} \mathcal {L}_{cyc} = \mathop {\mathbb {E}}\limits _{x \sim P_{\mathcal {X}}} [\left\| G_{t2s}(G_{s2t}(x)) - x \right\| _{1}] + \mathop {\mathbb {E}}\limits _{y \sim P_{\mathcal {Y}}} [\left\| G_{s2t}(G_{t2s}(y)) - y \right\| _{1}] \end{aligned}$$
(2)

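To further preserve the generation capacity of the generators, we also introduce the identity mapping loss [26], which encourages each generator to leave images that already belong to its output domain unchanged: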

$$\begin{aligned} \mathcal {L}_{idt} = \mathop {\mathbb {E}}\limits _{x \sim P_{\mathcal {X}}} [\left\| G_{t2s}(x) - x \right\| _{1}] + \mathop {\mathbb {E}}\limits _{y \sim P_{\mathcal {Y}}} [\left\| G_{s2t}(y) - y \right\| _{1}] \end{aligned}$$
(3)

Note that \( G_{t2s} \) and \( G_{s2t} \) are the generators shown in Fig. 1 (a), from which we take \( G_{t2s}(G_{s2t}(x)) \) in Eq. (2) and \( G_{t2s}(x) \) in Eq. (3) as the intermediate images for data augmentation. The overall loss of the generators is:

$$\begin{aligned} \mathcal {L}_{gen} = \mathcal {L}_{adv} + \mathcal {L}_{cyc} + \mathcal {L}_{idt} \end{aligned}$$
(4)
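The generator objective in Eqs. (1) to (4) can be sketched for a single batch as follows. This is a minimal, framework-free illustration under stated assumptions: samples are plain Python lists, the generators and discriminators are arbitrary callables standing in for networks, and no gradients are computed. The adversarial term scores \( G_{s2t}(x) \) with \( D_{t} \) and \( G_{t2s}(y) \) with \( D_{s} \), the pairing implied by the critic loss in Eq. (5), and the function also returns the intermediate images \( G_{t2s}(G_{s2t}(x)) \) and \( G_{t2s}(x) \). All names are hypothetical.

```python
def l1_distance(a, b):
    """||a - b||_1 for two equal-length vectors."""
    return sum(abs(p - q) for p, q in zip(a, b))


def mean(values):
    """Batch average, standing in for the expectations in Eqs. (1)-(3)."""
    values = list(values)
    return sum(values) / len(values)


def generator_losses(xs, ys, g_s2t, g_t2s, d_s, d_t):
    """Return (L_gen, intermediate images) for one batch (hypothetical helper).

    xs: source-domain samples, ys: target-domain samples (lists of vectors).
    """
    fake_t = [g_s2t(x) for x in xs]  # source -> target, judged by D_t
    fake_s = [g_t2s(y) for y in ys]  # target -> source, judged by D_s
    l_adv = mean(-d_t(f) for f in fake_t) + mean(-d_s(f) for f in fake_s)

    rec_x = [g_t2s(f) for f in fake_t]  # G_t2s(G_s2t(x)): cycle back to source
    rec_y = [g_s2t(f) for f in fake_s]  # G_s2t(G_t2s(y)): cycle back to target
    l_cyc = (mean(l1_distance(r, x) for r, x in zip(rec_x, xs))
             + mean(l1_distance(r, y) for r, y in zip(rec_y, ys)))

    idt_x = [g_t2s(x) for x in xs]  # G_t2s should leave source images unchanged
    idt_y = [g_s2t(y) for y in ys]  # G_s2t should leave target images unchanged
    l_idt = (mean(l1_distance(i, x) for i, x in zip(idt_x, xs))
             + mean(l1_distance(i, y) for i, y in zip(idt_y, ys)))

    # rec_x and idt_x are derived from labeled source images, so they can
    # inherit those labels when used for data augmentation.
    return l_adv + l_cyc + l_idt, rec_x + idt_x
```

With real networks the same terms would be computed on tensors and back-propagated; the loop form is only meant to make each term of Eq. (4) explicit.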

Taking \( D_{s} \) as an example to illustrate the objective of the discriminators, its critic loss is:

$$\begin{aligned} \mathcal {L}_{crt}(D_{s}, G_{t2s}, \mathcal {X}, \mathcal {Y}) = \mathop {\mathbb {E}}\limits _{y \sim P_{\mathcal {Y}}} [D_{s}(G_{t2s}(y))] - \mathop {\mathbb {E}}\limits _{x \sim P_{\mathcal {X}}} [D_{s}(x)] \end{aligned}$$
(5)

On this basis, a gradient penalty term is added to constrain the norm of the discriminator's gradient with respect to its input:

$$\begin{aligned} \mathcal {L}_{gp}(D_{s}, \mathcal {X}) = \mathop {\mathbb {E}}\limits _{\hat{x} \sim P_{\hat{\mathcal {X}}}} \left[ \left( \left\| \nabla _{\hat{x}} D_{s}(\hat{x})\right\| _{2} -1 \right) ^{2} \right] \end{aligned}$$
(6)

where \( \hat{x} = \epsilon x + (1-\epsilon )G_{s2t}(x) \) is a random interpolation between the source sample x and the generated sample \( G_{s2t}(x) \), and \( \epsilon \sim U[0, 1] \) is a random number. Therefore, the overall loss of \( D_{s} \) is:

$$\begin{aligned} \mathcal {L}_{dis}(D_{s}, G_{t2s}, \mathcal {X}, \mathcal {Y}) = \lambda _{1}\mathcal {L}_{crt}(D_{s}, G_{t2s}, \mathcal {X}, \mathcal {Y}) + \lambda _{2} \mathcal {L}_{gp}(D_{s}, \mathcal {X}) \end{aligned}$$
(7)

where \( \lambda _{1} \) and \( \lambda _{2} \) are hyper-parameters. Similarly, \( \mathcal {L}_{dis}(D_{t}, G_{s2t}, \mathcal {Y}, \mathcal {X}) \) is the loss of \( D_{t} \). Finally, the overall objective of WCycleGAN is:

$$\begin{aligned} \mathcal {L}_{total} = \mathcal {L}_{gen} + \mathcal {L}_{dis}(D_{s}, G_{t2s}, \mathcal {X}, \mathcal {Y})+ \mathcal {L}_{dis}(D_{t}, G_{s2t}, \mathcal {Y}, \mathcal {X}) \end{aligned}$$
(8)
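Eqs. (5) to (7) for one discriminator can be sketched as below. To stay dependency-free, the gradient in Eq. (6) is estimated by central finite differences instead of automatic differentiation (in PyTorch one would typically use `torch.autograd.grad` with `create_graph=True`). The critic is any scalar-valued callable, `fake` is whichever generated batch is interpolated with `real` (left to the caller), and the defaults \( \lambda _{1} = 0.5 \), \( \lambda _{2} = 10.0 \) follow Sect. 3.1. Function names are hypothetical.

```python
import math
import random


def numerical_gradient(f, x, h=1e-5):
    """Central-difference gradient of a scalar function f at vector x (stand-in for autograd)."""
    grad = []
    for i in range(len(x)):
        plus = list(x)
        minus = list(x)
        plus[i] += h
        minus[i] -= h
        grad.append((f(plus) - f(minus)) / (2 * h))
    return grad


def gradient_penalty(critic, real, fake, rng):
    """Eq. (6): mean of (||grad critic(x_hat)||_2 - 1)^2 over random interpolates x_hat."""
    penalties = []
    for r, g in zip(real, fake):
        eps = rng.uniform(0.0, 1.0)  # eps ~ U[0, 1], one draw per sample
        x_hat = [eps * a + (1 - eps) * b for a, b in zip(r, g)]
        norm = math.sqrt(sum(c * c for c in numerical_gradient(critic, x_hat)))
        penalties.append((norm - 1.0) ** 2)
    return sum(penalties) / len(penalties)


def discriminator_loss(critic, real, fake, lam1=0.5, lam2=10.0, seed=0):
    """Eq. (7): lam1 * critic loss (Eq. 5) + lam2 * gradient penalty (Eq. 6)."""
    rng = random.Random(seed)
    # Eq. (5): mean score on generated samples minus mean score on real samples.
    l_crt = sum(critic(g) for g in fake) / len(fake) - sum(critic(r) for r in real) / len(real)
    return lam1 * l_crt + lam2 * gradient_penalty(critic, real, fake, rng)
```

For a linear critic whose weight vector has unit norm the penalty vanishes, which is the 1-Lipschitz behaviour the penalty is designed to encourage.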

The WCycleGAN G plays two roles in CCL: it helps form \( \mathcal {X}_{B} \), and it translates images to a given domain.

2.2 Cooperative Training of Classifiers Based on Transfer Learning

The classifiers \( C_{s} \) and \( C_{t} \) are trained collaboratively on one batch in each iteration and are designed for the classification tasks of the source domain and the target domain, respectively.

Supervised Classifier \(\boldsymbol{C}_{s}\) is obtained by fine-tuning a CNN pre-trained on ImageNet. It classifies images of the source domain and, under the conditions described below, assists the training of \( C_{t} \).

In the current iteration, after constructing \( \mathcal {X}_{B} \) with G and feeding \( \mathcal {X}_{B} \) into \( C_{s} \), we obtain the feature maps \( \mathcal {F} = \{f_{i}\}_{i=1}^{m}\in \mathbb {R}^{m\times (D\times K) } \) before the final softmax layer, where D is the dimension of images and K is the number of classes. Each \( x_i \in \mathcal {X}_{B} \) carries the source-domain label \( Label(x_i) \), so \( f_{i,Label(x_i)} \) is the D-dimensional feature map of \( x_i \) for its labeled class. The objective function of \( C_{s} \) combines the cross entropy loss \( \mathcal {L}_{S}\) and the center loss \(\mathcal {L}_{C} \) [30], balanced by a hyper-parameter \( \lambda _{3} \):

$$\begin{aligned} \mathcal {L}_{1} = \mathcal {L}_{S} + \lambda _{3}\mathcal {L}_{C} = -\sum _{i=1}^m \log \left( \frac{e^{{f}_{i,Label(x_i)}}}{\sum _{k=1}^K e^{{f}_{i,k}}} \right) + \frac{\lambda _{3}}{2} \sum _{i=1}^m \left\| {f}_{i,Label(x_i)} - {c}_{l_{i}} \right\| _{2}^{2} \end{aligned}$$
(9)

where \( {c}_{l_{i}} \) is the feature center of class \( l_{i} \), and \( l_{i} = Label(x_i) \) is the label of the i-th image.
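A plain-Python sketch of Eq. (9) follows. It assumes the standard center-loss setup of [30], in which the cross entropy is computed on K class scores and the center loss on a separate embedding vector; this split, the function name, and the fixed class centers (their update rule is not shown) are assumptions made for illustration.

```python
import math


def cross_entropy_center_loss(logits, features, labels, centers, lam3=1.0):
    """Eq. (9): L_1 = L_S + lam3 * L_C, summed over the batch.

    logits[i]:   K pre-softmax scores of sample i
    features[i]: embedding of sample i used by the center loss (assumed D-dimensional)
    labels[i]:   class index l_i
    centers[k]:  feature center c_k of class k (kept fixed here)
    """
    l_s = 0.0
    l_c = 0.0
    for z, f, y in zip(logits, features, labels):
        top = max(z)  # subtract the max for a numerically stable log-sum-exp
        log_norm = top + math.log(sum(math.exp(v - top) for v in z))
        l_s += log_norm - z[y]  # -log softmax(z)[y]
        l_c += sum((a - b) ** 2 for a, b in zip(f, centers[y]))
    return l_s + lam3 * 0.5 * l_c  # the 1/2 factor matches lambda_3 / 2 in Eq. (9)
```

When every embedding sits exactly on its class center, only the cross-entropy term remains.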

Unsupervised Classifier \( \boldsymbol{C}_{t} \) is also obtained by fine-tuning a CNN pre-trained on ImageNet; it requires no target labels and classifies images of the target domain. In each training iteration of \( C_{t} \), we exploit the labeled \(\mathcal {X}\) and \(\mathcal {X}_{B}\) and the unlabeled \(\mathcal {Y}\).

To achieve effective unsupervised domain adaptation, \( C_{t} \) needs to receive parameters and images from the labeled source domain. During iterative training, as shown in Fig. 1 (a), we design a cooperative mechanism in which \( C_{s} \) passes its parameters to \( C_{t} \) under two conditions: (i) the training/validation accuracy of \( C_{s} \) has not reached the threshold \(\tau \); (ii) the predictions of \( C_{t} \) on a batch all belong to one class (over-fitting).
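The decision rule of the cooperative mechanism can be sketched as follows. This sketch assumes that parameters are passed when either condition holds and that both classifiers share the same architecture; parameters are modelled as plain dictionaries, whereas in PyTorch the copy would typically be `c_t.load_state_dict(c_s.state_dict())`. Function names are hypothetical.

```python
def should_pass_parameters(cs_accuracy, ct_batch_predictions, tau=0.8):
    """Return True if C_s should overwrite C_t's parameters in this iteration.

    Assumption: passing is triggered when either condition holds:
    (i) C_s accuracy is still below tau, or
    (ii) every prediction of C_t on the batch is the same class (collapse / over-fitting).
    """
    below_threshold = cs_accuracy < tau
    collapsed = len(set(ct_batch_predictions)) == 1
    return below_threshold or collapsed


def cooperative_update(cs_params, ct_params, cs_accuracy, ct_batch_predictions, tau=0.8):
    """Return C_t's parameters for the next step: a copy of C_s's when passing is triggered."""
    if should_pass_parameters(cs_accuracy, ct_batch_predictions, tau):
        return dict(cs_params)
    return ct_params  # otherwise C_t keeps its own parameters and trains with Eq. (11)
```

The default `tau=0.8` matches the threshold reported in Sect. 3.1.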

When these conditions are not met in the current iteration, \( C_{t} \) is updated by its own training, in which the MMD [27] criterion is used to minimize the distribution distance between \( \mathcal {X} \) and \( \mathcal {Y} \). The squared MMD distance is:

$$\begin{aligned} d(\mathcal {X},\ \mathcal {Y})^{2} = \left\| \frac{1}{n} \left( \sum _{{{x}}_{i} \in \mathcal {X}} \phi ({x}_{i}) - \sum _{{{y}}_{i} \in \mathcal {Y}} \phi ({y}_{i}) \right) \right\| _{H }^{2} \end{aligned}$$
(10)

where \( \phi (\cdot ) \) is the feature map associated with the Gaussian kernel, and the subscript \( H \) indicates that the norm is measured in the reproducing kernel Hilbert space into which \( \phi (\cdot ) \) maps the data.

Combining Eq. (9) and Eq. (10) with a hyper-parameter \( \lambda _{4} \), the objective of \( C_{t} \) is:

$$\begin{aligned} \mathcal {L}_{2} = \mathcal {L}_{1} + \lambda _{4}d(\mathcal {X},\ \mathcal {Y})^{2} \end{aligned}$$
(11)
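The squared MMD of Eq. (10) can be evaluated without forming \( \phi (\cdot ) \) explicitly: expanding the squared norm and using \( \langle \phi (a), \phi (b)\rangle _{H} = k(a, b) \) gives sums of kernel values. The sketch below assumes a Gaussian kernel with a fixed bandwidth `sigma` (the text does not specify it) and inputs given as feature vectors, for example CNN features of \( \mathcal {X} \) and \( \mathcal {Y} \); it also assembles Eq. (11). Names are hypothetical.

```python
import math


def gaussian_kernel(a, b, sigma=1.0):
    """k(a, b) = exp(-||a - b||^2 / (2 sigma^2))."""
    sq = sum((p - q) ** 2 for p, q in zip(a, b))
    return math.exp(-sq / (2.0 * sigma ** 2))


def mmd_squared(xs, ys, sigma=1.0):
    """Eq. (10) via the kernel trick for two batches of equal size n:
    (1 / n^2) * (sum k(x, x') + sum k(y, y') - 2 * sum k(x, y)).
    """
    n = len(xs)
    if len(ys) != n:
        raise ValueError("Eq. (10) assumes both batches contain n samples")
    k_xx = sum(gaussian_kernel(a, b, sigma) for a in xs for b in xs)
    k_yy = sum(gaussian_kernel(a, b, sigma) for a in ys for b in ys)
    k_xy = sum(gaussian_kernel(a, b, sigma) for a in xs for b in ys)
    return (k_xx + k_yy - 2.0 * k_xy) / n ** 2


def ct_objective(l1, xs, ys, lam4=2.0, sigma=1.0):
    """Eq. (11): L_2 = L_1 + lam4 * d(X, Y)^2, with lam4 = 2.0 as in Sect. 3.1."""
    return l1 + lam4 * mmd_squared(xs, ys, sigma)
```

Identical batches give zero discrepancy, and the value grows as the two batches drift apart, which is what the \( \lambda _{4} \) term penalizes.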

2.3 Prediction

After the iterative training above, \( C_{s} \) and \( C_{t} \) can classify images of their corresponding domains, and G can translate the original images to a given domain. The output of each classifier is the predicted probability of each class. To obtain a more reliable prediction, we fuse the predictions with the voting strategy in Fig. 1 (b).

For a probe y from the target domain, G produces the translated image \( y' = G_{t2s}(y) \), whose distribution is as consistent as possible with that of \( \mathcal {X} \). We then use both y and \( y' \) for prediction: \( C_{s} \) classifies \( y' \) and \( C_{t} \) classifies y. The label of y is \( Label(y)= l_{k} \), where \( l_{k} \) denotes the label of the k-th class, and the index k is computed by:

$$\begin{aligned} k = \begin{cases} \mathrm {mode}\left( \mathop {\arg \max }_{j} C^{j}_{s}(y'),\ \mathop {\arg \max }_{j} C^{j}_{t}(y) \right), & \text {if the mode is unique} \\ \mathop {\arg \max }_{j} \left( C^{j}_{s}(y') + C^{j}_{t}(y) \right), & \text {otherwise} \end{cases} \end{aligned}$$
(12)

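where \( \mathrm {mode}(\cdot ) \) in Eq. (12) returns the most frequent value. With two classifiers, the mode is unique only when both predict the same class; otherwise the summed probabilities decide.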
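The voting rule of Eq. (12) can be sketched as follows; the helper names are hypothetical, and the inputs are the per-class probabilities produced by \( C_{s} \) on \( y' \) and by \( C_{t} \) on y.

```python
from collections import Counter


def argmax(probs):
    """Index of the largest probability (first one on ties)."""
    return max(range(len(probs)), key=probs.__getitem__)


def vote(prob_s_translated, prob_t_original):
    """Eq. (12): majority vote over the two classifiers, else argmax of summed probabilities.

    prob_s_translated: class probabilities of C_s on y' = G_t2s(y)
    prob_t_original:   class probabilities of C_t on the probe y
    """
    votes = Counter([argmax(prob_s_translated), argmax(prob_t_original)]).most_common()
    if len(votes) == 1 or votes[0][1] > votes[1][1]:  # the mode is unique
        return votes[0][0]
    summed = [a + b for a, b in zip(prob_s_translated, prob_t_original)]
    return argmax(summed)
```

For example, `vote([0.6, 0.4], [0.1, 0.9])` has a tied vote (class 0 versus class 1), so the summed probabilities `[0.7, 1.3]` select class 1.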

3 Experiments

3.1 Databases and Settings

Databases. We use three databases in the experiments, summarized in Table 1. The Chest X-Ray database (Footnote 1) is divided into a training dataset and a testing dataset [12]. The Single lesion (Footnote 2) and Multiple lesions (Footnote 3) databases are the training datasets of two open lesion recognition competitions. We name them by the number of lesions: each image contains at most one lesion (Single lesion) or at least one lesion (Multiple lesions).

Table 1. The number of images and examples of (a) Chest X-Ray database; (b) Single lesion database; (c) Multiple lesions database. Their data distributions vary due to different imaging equipment/standards.

Comparative Methods. Kermany et al. [12], Ayan et al. [3] and Gu et al. [7] (a two-step progressive transfer learning technique) are transfer learning methods for medical image classification. In addition, we compare with the domain-adversarial training of neural networks (DANN) [5] and [7]-GAN (an adversarial learning technique with CycleGAN) for cross-domain classification.

Settings. All experiments are carried out on an Intel\(^\circledR \) Xeon\(^\circledR \) Gold 6230, and CCL is implemented with PyTorch 1.6.0. We set \( \lambda _{1} = 0.5 \) [33] and \( \lambda _{2} = 10.0 \) [10] in Eq. (7), \( \lambda _{3} = 1.0 \) [30] in Eq. (9) and \( \lambda _{4} = 2.0 \) [27] in Eq. (11). During training, the SGD optimizer uses an initial learning rate of 0.001 and a weight decay of 0.0005, and the threshold \( \tau \) of the cooperative mechanism is 0.8 (see Supplement Sect. 1 for the related experiments). The batch size is \( n = 32 \); when training WCycleGAN, each batch is split into n mini-batches of one image each. A simple pre-trained WCycleGAN can be loaded to speed up convergence, and training completes within 10 epochs, with the best model saved every 50 batches. Visualizations of the training process that verify the performance of CCL are given in Supplement Sect. 2.
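For reproducibility, the settings above can be collected in a single configuration object. This is only a bookkeeping sketch: the class and field names are hypothetical, the values are those stated in the text, and in PyTorch the optimizer would be created with something like `torch.optim.SGD(params, lr=s.learning_rate, weight_decay=s.weight_decay)`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CCLSettings:
    """Hyper-parameters reported in Sect. 3.1 (field names are illustrative)."""

    lambda1: float = 0.5            # critic-loss weight, Eq. (7)
    lambda2: float = 10.0           # gradient-penalty weight, Eq. (7)
    lambda3: float = 1.0            # center-loss weight, Eq. (9)
    lambda4: float = 2.0            # MMD weight, Eq. (11)
    learning_rate: float = 1e-3     # initial SGD learning rate
    weight_decay: float = 5e-4      # SGD weight decay
    tau: float = 0.8                # accuracy threshold of the cooperative mechanism
    batch_size: int = 32            # n
    wcyclegan_minibatch_size: int = 1  # one image per mini-batch when training WCycleGAN
    max_epochs: int = 10            # training completes within 10 epochs
    checkpoint_every_batches: int = 50  # best model saved every 50 batches
```

Freezing the dataclass keeps a run's settings immutable once created.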

Evaluation Metrics. The evaluation metrics are accuracy, precision, recall and F1 score. We treat the normal diagnosis as the positive class.
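Because precision, recall and F1 depend on which class is treated as positive, the following sketch makes that choice an explicit argument. The label strings `"normal"` and `"pneumonia"` are hypothetical placeholders.

```python
def binary_metrics(y_true, y_pred, positive="normal"):
    """Accuracy, precision, recall and F1 with `positive` as the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
```

Swapping `positive` changes precision and recall but not accuracy, which is why the positive class is stated explicitly.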

3.2 Performance of the Supervised Classifier

Table 2. Results of supervised classification on Chest X-Ray.

We first verify the performance of the supervised classifier \( C_{s} \) on Chest X-Ray. Based on experimental results with diverse backbones, we choose MobileNetV2 (fine-tuning the last 6 layers) [21], InceptionV3 (fine-tuning the classification layer) [25] and Vgg16 (fine-tuning the last 5 layers) [22] as backbones. As Table 2 shows, \( C_{s} \) outperforms the other methods with every backbone, which indicates that \( C_{s} \) classifies images within the same domain very well.

3.3 Evaluation of Unsupervised Cross-Domain Classification

Following [18], we adopt the common leave-one-domain-out strategy and use the three databases in pairs to test the performance of CCL on cross-domain classification. We report the average result over 3 runs. To ensure comparability and reproducibility, we use the existing dataset splits for training/validation/testing instead of cross-validation, which would change the composition of the datasets. All samples of the source domain are used for training, while in the target domain the training and testing sets of Chest X-Ray are used for validation and testing, respectively, to determine the number of iterations needed for convergence. This number of iterations is then applied as the stopping criterion when testing on the other target-domain datasets.

As shown in Table 3, CCL achieves the best accuracy and F1 score on all tasks. Different backbones have their own advantages on different tasks: (i) when training with a large dataset and limited device memory, a lightweight network such as MobileNetV2 is preferable; (ii) more complex CNNs are better at preventing over-fitting to some extent.

Table 3. Results of unsupervised cross-domain classification.

3.4 Visualization and Ablation Experiments

Figure 2 visualizes the translation of original images. Figure 2 (b) and (c) show the images translated to the other domain by CycleGAN and WCycleGAN, respectively, from which it is clear that WCycleGAN has better generation capacity than CycleGAN (e.g., the edges of the ribs are as clear as in the images of the source domain). Figure 2 (d) and (e) are the intermediate images of Eq. (2) and Eq. (3), respectively, which are also used for data augmentation; they are similar to Fig. 2 (a) but not identical.

We take Vgg16 as the backbone for the ablation experiments: method\( ^\# \)1 without passing the balanced dataset, method\( ^\# \)2 without passing parameters, and method\( ^\# \)3 without generating the balanced dataset, each removing the corresponding operation in Fig. 1. As Fig. 3 shows, method\( ^\# \)3 causes severe over-fitting, and methods\( ^\# \)1 and \( ^\# \)2 are also greatly affected. Hence, all of these operations in CCL are essential and beneficial.

Fig. 2. The visual image translation between Chest X-Ray (row 1) and Multiple lesions (row 2).

Fig. 3. The results of ablation experiments.

4 Discussion and Conclusion

In this paper, we present an effective framework named CCL, based on transfer learning, for unsupervised cross-domain classification. The class-balanced dataset of CCL helps avoid over-fitting. In addition, by combining transfer learning and unsupervised learning, the proposed method addresses the problem of insufficient labels in medical data. During training and testing, WCycleGAN and the two classifiers complement each other through cooperative learning, and the backbones of the classifiers can be flexibly replaced to obtain competitive results. Experiments on three pneumonia databases indicate that the proposed method achieves promising performance in unsupervised cross-domain classification.