1 Introduction

Alzheimer’s disease (AD) is a slowly progressing, fatal neurodegenerative disease that mainly affects people over the age of 65. The identification and analysis of AD and its prodromal phase, i.e., mild cognitive impairment (MCI), are therefore essential for early treatment and possible delays in disease progression. Significant advances in neuroimaging have provided opportunities to study brain-related diseases, bringing improvements in the early detection of AD and MCI. Structural magnetic resonance imaging (MRI) is widely used in AD studies because it provides excellent spatial resolution and a non-invasive way to examine a patient’s brain. MRI-based methods help predict the progression of MCI to AD by explicitly characterizing brain atrophy and changes in the volume of brain tissues [1, 2].

Conventional MRI-based methods often suffer from inter-domain data heterogeneity caused by different scanning parameters, because models learned on source domains are directly applied to the to-be-analyzed target domain [3,4,5]. For example, Cheng et al. [3] proposed a domain transfer learning method for MCI conversion prediction, which directly uses the source domain as auxiliary data for the target domain to select a subset of common features. Several deep learning methods have recently been developed to alleviate the issue of domain shift [6,7,8,9,10]. For example, Motiian et al. [6] provided a unified framework that learns an embedding subspace to address the problem of supervised domain adaptation. A deep domain confusion network [7] was developed to map data from both domains into a common feature space and reduce the domain shift, measured by the maximum mean discrepancy (MMD). This method was further extended to a deep adaptation network (DAN) via a multi-layer multi-kernel selection technique [8] and to a joint adaptation network (JAN) [9] for natural image classification. Long et al. [10] presented a conditional adversarial domain adaptation network (CDAN) that integrates adversarial learning and domain adaptation into a unified framework.

Fig. 1.

Illustration of the proposed Unsupervised Conditional consensus Adversarial Network (UCAN), including (1) a feature extraction module with a convolutional (Conv) layer and four types of residual blocks as the backbone, (2) a cycle feature adaptation (CFA) module to harmonize the learned features and classifiers of the source and target domains, and (3) a classification module. The input is a pair of images, i.e., \(x_i^s\) from the labeled source domain and \(x_j^t\) from the unlabeled target domain (with learned features \(f_{i}^s\) and \(f_{j}^t\), respectively). Here, \(g^G=C_s(G_{s\rightarrow t}(f_{i}^s))+C_t(G_{s\rightarrow t}(f_{i}^s))\).

However, existing methods typically suffer from two limitations. (1) They are usually based on the assumption that the source classifier and the target classifier can be shared directly. Unfortunately, such an assumption is too strong and does not always hold in practical applications. (2) Many of them are supervised, relying on fully labeled data in both the source and target domains; hence, they cannot be applied to the more general setting in which labeled data are available only in the source domain and the target domain is unlabeled.

As shown in Fig. 1, we propose an Unsupervised Conditional consensus Adversarial Network (UCAN) for brain disease identification, in which classifiers learned from a labeled source domain can be robustly adapted to a different target domain without any label information. Three major components are included: (1) a feature extraction module for learning discriminative representations from the input MRI, (2) a cycle feature adaptation module to harmonize the learned MRI features and classifiers of the source and target domains, and (3) a classification module. Experimental results on two public datasets with structural MRIs (acquired using different scanning parameters) suggest the efficacy of our method in both AD/MCI identification and MCI conversion prediction.

2 Materials and Method

2.1 Subjects and Structural MR Image Pre-processing

Two datasets from the ADNI database [11] were employed in this work, i.e., ADNI1 and ADNI2. Since several subjects participated in both ADNI1 and ADNI2, we removed these subjects from ADNI2 to ensure that the two datasets are independent. Subjects in these datasets were divided into four categories: (1) AD, (2) normal control (NC), (3) progressive MCI (pMCI), i.e., MCI subjects who would progress to AD within 36 months after baseline, and (4) static MCI (sMCI), i.e., MCI subjects who would not progress to AD. The baseline ADNI1 dataset consists of 1.5 T T1-weighted MR images acquired from 785 subjects, including 231 NC, 246 sMCI, 103 pMCI, and 205 AD subjects. The baseline ADNI2 dataset includes 3.0 T T1-weighted MR images acquired from 721 subjects, including 205 NC, 312 sMCI, 42 pMCI, and 162 AD subjects. A standard pipeline was used to pre-process the structural MR images, including (1) anterior commissure (AC)-posterior commissure (PC) alignment; (2) skull stripping; (3) intensity correction; (4) cerebellum removal; (5) linear alignment to the Colin27 template [12]; (6) re-sampling all MR images to the same size of \(142\times 142\times 178\) voxels (with a spatial resolution of \(1\times 1\times 1\,\mathrm{mm}^3\)) along the coronal, sagittal, and axial directions, respectively; and (7) intensity inhomogeneity correction using the N3 algorithm [13].
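
As a concrete illustration of re-sampling step (6) only, the minimal sketch below maps one scan onto a \(142\times 142\times 178\) grid with 1 mm isotropic spacing; the use of SimpleITK and the file names are our assumptions, not necessarily the tools used in the original pipeline.

```python
import SimpleITK as sitk

# Illustrative re-sampling of a T1-weighted scan (step 6 of the pipeline).
image = sitk.ReadImage("subject_T1.nii.gz")        # hypothetical input file

resampler = sitk.ResampleImageFilter()
resampler.SetOutputSpacing((1.0, 1.0, 1.0))        # 1 x 1 x 1 mm^3 resolution
resampler.SetSize((142, 142, 178))                 # target grid size (voxels)
resampler.SetOutputOrigin(image.GetOrigin())       # keep the original origin
resampler.SetOutputDirection(image.GetDirection()) # keep the original orientation
resampler.SetInterpolator(sitk.sitkLinear)
resampled = resampler.Execute(image)

sitk.WriteImage(resampled, "subject_T1_resampled.nii.gz")
```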

2.2 Unsupervised Conditional Consensus Adversarial Network

We now present our conditional consensus adversarial domain adaptation framework (see Fig. 1). We consider the problem of unsupervised adaptation in this work. Denote \(x_i^s\) (with its label \(y_i^s\)) and \(x_j^t\) (without label) as the i-th and the j-th subjects from the source and target domains, respectively. Given \(n_s\) subjects from the source domain and \(n_t\) subjects from the target domain, we represent the labeled source domain as \(\mathcal D_s= \{(x_i^s,y_i^s)\}_{i=1}^{n_s}\) and the unlabeled target domain as \(\mathcal D_t= \{x_j^t\}_{j=1}^{n_t}\). The goal is to learn a model that can correctly predict the label of subjects from the target domain, based on labeled source data.

Feature Extraction Module: Each input image is first fed into a convolution (Conv) layer with a kernel size of \(7\times 7\) and 64 channels (stride: 2), followed by a max-pooling layer with a \(3\times 3\) filter (stride: 2). Then, we use the ResNet-50 model [14] to extract MRI features of images from both the source and target domains, based on four types of residual blocks (with each block containing 3 Conv layers). The numbers of blocks are 3, 4, 6, and 3 for the four types of residual blocks, respectively, and their parameters (i.e., [filter size, channels]) are \([1\times 1, 64;\ 3\times 3, 64;\ 1\times 1, 256]\), \([1\times 1, 128;\ 3\times 3, 128;\ 1\times 1, 512]\), \([1\times 1, 256;\ 3\times 3, 256;\ 1\times 1, 1024]\), and \([1\times 1, 512;\ 3\times 3, 512;\ 1\times 1, 2048]\), respectively, following the standard ResNet-50 architecture.
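
To make the backbone concrete, the following PyTorch sketch assembles a feature extractor with the same structure (the initial \(7\times 7\) Conv with 64 channels and stride 2, \(3\times 3\) max pooling with stride 2, and the four types of ResNet-50 residual blocks with 3, 4, 6, and 3 blocks); reusing torchvision's ResNet-50 here is our assumption for illustration, not the authors' released code.

```python
import torch
import torch.nn as nn
from torchvision import models

# Sketch of the feature extraction module built from torchvision's ResNet-50
# (an illustrative stand-in for the backbone described in the text).
resnet = models.resnet50(pretrained=True)   # ImageNet pre-trained weights
feature_extractor = nn.Sequential(
    resnet.conv1,    # 7x7 Conv, 64 channels, stride 2
    resnet.bn1,
    resnet.relu,
    resnet.maxpool,  # 3x3 max pooling, stride 2
    resnet.layer1,   # residual block type 1: 3 blocks
    resnet.layer2,   # residual block type 2: 4 blocks
    resnet.layer3,   # residual block type 3: 6 blocks
    resnet.layer4,   # residual block type 4: 3 blocks
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)

slice_batch = torch.randn(2, 3, 256, 256)   # two 256x256 MRI slices (3 channels)
features = feature_extractor(slice_batch)   # learned features f
print(features.shape)                       # torch.Size([2, 2048])
```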

Cycle Feature Adaptation Module: Using the features learned by the feature extraction module (i.e., \(f_i^s\) for \(x_i^s\) and \(f_j^t\) for \(x_j^t\)), we further perform domain adaptation to harmonize data from the source and target domains. Let \(C_s\) and \(C_t\) be the task-specific classification models for the source and target domains, respectively. We begin by learning the source classifier \(C_s\) with the cross-entropy loss as follows:

(1)

where \(L(\cdot ,\cdot )\) is the cross-entropy function.
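
A formulation of Eq. (1) consistent with the description above (our reconstruction; the loss name \(\mathcal{L}_{cls}\) and exact notation are assumptions and may differ from the original) is

% Possible form of the source classification loss in Eq. (1)
\[
\mathcal{L}_{cls}=\mathbb{E}_{(x_i^s,\,y_i^s)\sim \mathcal{D}_s}\, L\big(C_s(f_i^s),\, y_i^s\big).
\]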

As shown in Fig. 1 (b), our proposed cycle feature adaptation module has a symmetric network design, on top of which we add a cycle generative adversarial structure to improve the efficacy of feature adaptation. By mapping the source feature (e.g., \(f_i^s\)) to the target domain via a generator \(G_{s\rightarrow t}\) and mapping the target feature (e.g., \(f_j^t\)) to the source domain via a generator \(G_{t\rightarrow s}\), we aim to remove the low-level differences between features from the two domains, ensuring that the learned model is well-conditioned for target-domain alignment. We assume that translating features in both directions helps perform global domain alignment during adaptation, and we employ the following source domain loss \(\mathcal {L}_{sdom}\) and target domain loss \(\mathcal {L}_{tdom}\):

(2)
(3)

where \(D_s\) and \(D_t\) denote the discriminators for the source and target domains, respectively, \(G_{s\rightarrow t}\) denotes the generator that maps source features to the target domain, and \(G_{t\rightarrow s}\) denotes the generator that maps target features to the source domain.
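
A formulation of Eqs. (2)-(3) consistent with the description above, given as our reconstruction under the assumption of standard generative adversarial losses (the original forms may differ), is

% Possible forms of the source/target domain losses in Eqs. (2)-(3)
\[
\mathcal{L}_{sdom}=\mathbb{E}_{f_i^s}\big[\log D_s(f_i^s)\big]+\mathbb{E}_{f_j^t}\big[\log\big(1-D_s(G_{t\rightarrow s}(f_j^t))\big)\big],
\]
\[
\mathcal{L}_{tdom}=\mathbb{E}_{f_j^t}\big[\log D_t(f_j^t)\big]+\mathbb{E}_{f_i^s}\big[\log\big(1-D_t(G_{s\rightarrow t}(f_i^s))\big)\big].
\]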

To distinguish features from different domains, we first concatenate the feature representation (e.g., \(f_i^s\)) and the classifier prediction (e.g., \(g_i^s\)) from each domain, i.e., \(h_i^s=[f_i^s, g_i^s]\) for the source domain and \(h_j^t=[f_j^t, g_j^t]\) for the target domain. Such a concatenation operation can capture the inherent relationship between task-specific features and classifier predictions.

We then use these discriminative representations, which condition the adversarial learning on the relationship between features and classifier predictions, to encourage the distributions of source and target data to become similar. The adversarial loss based on the discriminative representations (e.g., \(h_i^s\) and \(h_j^t\)) of the source and target domains is formulated as follows:

(4)

where \(D\) is the discriminator that predicts the domain label (i.e., source or target) of the input pair based on \(h_i^s\) and \(h_j^t\).
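
A formulation of Eq. (4) consistent with the description above, given as our reconstruction under the assumption of a standard adversarial objective on the conditioned representations (the name \(\mathcal{L}_{adv}\) is ours), is

% Possible form of the conditional adversarial loss in Eq. (4)
\[
\mathcal{L}_{adv}=\mathbb{E}_{i}\big[\log D(h_i^s)\big]+\mathbb{E}_{j}\big[\log\big(1-D(h_j^t)\big)\big].
\]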

Classification Module: Since data in the target domain are unlabeled, we propose to employ the feature representation generated via \(G_{s\rightarrow t}\), together with its label in the source domain, to improve the robustness of the target classifier. Given a feature vector \(f_{i}^s\), we denote \(g^G=C_s(G_{s\rightarrow t}(f_{i}^s))+C_t(G_{s\rightarrow t}(f_{i}^s))\) as the prediction for the generated feature \(G_{s\rightarrow t}(f_{i}^s)\), produced jointly by the source classifier \(C_s\) and the target classifier \(C_t\). We then develop a task-specific loss to assist the classification tasks in both the source and target domains, formulated as follows:

(5)

through which we can explicitly learn the target classifier \(C_t\), even though there are no labeled data in the target domain.
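
A formulation of Eq. (5) consistent with the description above, given as our reconstruction in which the source label \(y_i^s\) supervises the consensus prediction \(g^G\) (the name \(\mathcal{L}_{task}\) is ours), is

% Possible form of the task-specific loss in Eq. (5)
\[
\mathcal{L}_{task}=\mathbb{E}_{(x_i^s,\,y_i^s)\sim \mathcal{D}_s}\, L\big(g^G,\, y_i^s\big).
\]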

By combining Eqs. (1)–(5), we obtain the objective function of our UCAN model as follows:

(6)
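
Assuming the individual losses are combined linearly with hypothetical trade-off weights \(\lambda_1\), \(\lambda_2\), and \(\lambda_3\) (the actual weighting is not given here), a reconstruction of Eq. (6) consistent with the description is

% Possible form of the overall UCAN objective in Eq. (6); the weights are assumptions
\[
\mathcal{L}_{UCAN}=\mathcal{L}_{cls}+\lambda_1\big(\mathcal{L}_{sdom}+\mathcal{L}_{tdom}\big)+\lambda_2\,\mathcal{L}_{adv}+\lambda_3\,\mathcal{L}_{task}.
\]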

Implementation: The proposed network is implemented in PyTorch. Specifically, in the training stage, we first pre-train the feature extraction module using ResNet-50 [14] on ImageNet (with the source code provided at https://github.com/KaimingHe/deep-residual-networks), and then train the entire network in an end-to-end manner for 500 epochs. The Adam solver is used with a batch size of 36 and a learning rate of \( 3\times 10^{-3} \). In the testing stage, we feed the test MRI into the target-domain branch of the trained UCAN model to predict its class label. For each 3D structural MR image, we extract slices from three views (i.e., coronal-plane, sagittal-plane, and axial-plane views) and feed the slices of each view to the proposed network. The results based on the slices of the three views for each subject are then fused using a majority voting strategy to obtain the final prediction. Each slice is re-sized to \(256\times 256\) with zero-padding before being fed into the network. Besides, we apply both horizontal and vertical flipping to the MRI slices for data augmentation.
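
As a small illustration of the fusion step, the sketch below performs majority voting over the per-view predictions of one subject; the helper name and label coding are hypothetical and not taken from the original implementation.

```python
import numpy as np

def fuse_three_views(coronal_preds, sagittal_preds, axial_preds):
    """Majority-voting fusion of slice-level predictions from three views.

    Each argument is a sequence of predicted class labels for the slices of
    one view of a single subject (names and label coding are illustrative).
    """
    view_votes = []
    for preds in (coronal_preds, sagittal_preds, axial_preds):
        counts = np.bincount(np.asarray(preds))
        view_votes.append(int(counts.argmax()))  # per-view majority label
    # subject-level label: majority vote over the three per-view decisions
    return int(np.bincount(np.asarray(view_votes)).argmax())

# Hypothetical label coding: 0 = NC, 1 = sMCI, 2 = pMCI, 3 = AD
print(fuse_three_views([3, 3, 2], [3, 1, 3], [2, 2, 3]))  # -> 3
```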

3 Experiment

Experimental Setup: The proposed UCAN method was compared with three state-of-the-art deep domain adaptation methods using structural MR images: (1) Deep Adaptation Network (DAN) [8], (2) Joint Adaptation Network (JAN) [9], and (3) Conditional Domain Adaptation Network (CDAN) [10]. To evaluate the efficacy of the proposed cycle feature adaptation (CFA) module (see Fig. 1), we further compare UCAN with its variant without the CFA module (called UCAN-c). Except for the CFA module, UCAN-c and UCAN share the same network architecture. For a fair comparison, both DAN and JAN employ the network architectures and parameters suggested by the authors (see https://github.com/thuml/Xlearn), while CDAN uses the settings provided in the original paper (see https://github.com/thuml/CDAN). Two groups of experiments were performed: (1) brain disease identification (i.e., AD vs. pMCI vs. sMCI vs. NC classification) and (2) MCI conversion prediction (i.e., pMCI vs. sMCI classification). To validate the robustness of each method, a two-fold cross-validation was used. In the \(1^{st}\) fold (i.e., “ADNI1\(\rightarrow \)ADNI2”), we treated ADNI1 as the source domain and ADNI2 as the target domain. In the \(2^{nd}\) fold (i.e., “ADNI2\(\rightarrow \)ADNI1”), we regarded ADNI2 as the source domain and ADNI1 as the target domain. The classification accuracy was employed as the evaluation metric in the task of pMCI vs. sMCI classification, while the accuracy averaged over the four categories was used in the task of brain disease identification.
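
For clarity, the following sketch shows one common way to compute the accuracy averaged over the four categories, i.e., the mean of the per-class accuracies; this interpretation of the metric is our assumption.

```python
import numpy as np

def averaged_accuracy(y_true, y_pred, num_classes=4):
    """Mean of per-class accuracies over the four categories
    (our interpretation of the averaged accuracy described in the text)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    per_class = [np.mean(y_pred[y_true == c] == c) for c in range(num_classes)]
    return float(np.mean(per_class))

# Toy example with a hypothetical label coding (0 = NC, 1 = sMCI, 2 = pMCI, 3 = AD)
print(averaged_accuracy([0, 1, 2, 3, 3], [0, 1, 1, 3, 2]))  # -> 0.625
```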

Fig. 2.

Results of five different methods in both tasks of (a) disease identification and (b) MCI conversion prediction.

Results of Disease Identification: Figure 2(a) reports the results achieved by the five methods in the task of brain disease identification (i.e., AD vs. pMCI vs. sMCI vs. NC classification), from which one may have the following observations. First, our UCAN method achieved the best performance in both folds, compared with the three competing deep learning methods (i.e., DAN, JAN, and CDAN). Second, UCAN consistently outperforms its degenerated variant (i.e., UCAN-c). For instance, using ADNI1 as the source domain (i.e., “ADNI1\(\rightarrow \)ADNI2”), the average accuracy of four-category classification achieved by UCAN is 0.4429 for subjects in the target ADNI2 domain, while UCAN-c only yields an average accuracy of 0.4351. This suggests the effectiveness of the proposed cycle feature adaptation module in alleviating the data distribution shift between the source and target domains. Besides, the overall performance of the five methods in the \(2^{nd}\) fold (i.e., “ADNI2\(\rightarrow \)ADNI1”) is worse than that in the \(1^{st}\) fold (i.e., “ADNI1\(\rightarrow \)ADNI2”). A possible reason is that ADNI2 is a relatively unbalanced dataset (e.g., the number of pMCI subjects is much smaller than that of the other three categories), compared with ADNI1. In such a case, models trained on ADNI2 are less robust than those trained on the more balanced ADNI1 dataset.

Results of Disease Progression Prediction: We also report the results of the five methods in MCI conversion prediction (i.e., pMCI vs. sMCI classification) in Fig. 2(b). From this figure, a trend similar to that in the task of brain disease identification can be found. That is, our UCAN method is consistently superior to the four competing methods in both folds. The underlying reason may be that our method can capture the relationship between feature representations and classifier predictions as complementary information during training. Considering that the task of pMCI vs. sMCI classification involves severe class imbalance in both the source and target domains, these results further demonstrate the robustness of the proposed UCAN method.

Fig. 3.

Results of the proposed UCAN model using MRI slices from three different views in (a) disease identification and (b) MCI conversion prediction.

Discussion and Future Work: In the current work, the input data of UCAN are 2D image slices extracted from three views (i.e., sagittal-plane, axial-plane, and coronal-plane views). We investigate the influence of different views on the performance of our method, with results reported in Fig. 3. From Fig. 3, one can observe that the overall best performance is achieved by using the sagittal-plane view in the \(1^{st}\) fold. In the experiments, we treat the three views equally when generating the prediction for a test subject, by fusing the results of UCAN on the slices of the three views via majority voting. A weighted voting strategy may be more reasonable, which we will study in future work. Besides, we plan to employ 3D (rather than 2D) convolutions in the proposed network to take advantage of the global structural information of 3D MRIs.

4 Conclusion

We have presented an unsupervised conditional consensus adversarial network (UCAN) for deep domain adaptation, which learns a disease classifier from a labeled source domain and adapts it to a different, unlabeled target domain. Specifically, we first design a feature extraction module to learn representations from the input MRI, followed by a cycle feature adaptation module to harmonize the features and classifiers of the source and target domains. Experimental results on 1,506 subjects suggest the efficacy of the proposed method.