
1 Introduction

Alzheimer’s disease (AD) is one of the most pervasive neurodegenerative disorders, causing an increasing morbidity burden that may outstrip diagnosis and management capacity as the population ages. The assessment of AD usually involves the acquisition of structural magnetic resonance imaging (sMRI), since it offers accurate visualization of the anatomy and pathology of the brain. Brain abnormalities (e.g., atrophy, enlargement, malformation) are known to be the most discriminative and reliable biomarkers [1] of AD that can be observed and analyzed through sMRI. However, automatic and reproducible identification of AD remains challenging due to the heterogeneity of sMRI collected from different centers.

Recently, convolutional neural networks (CNNs) have been used for the automatic classification of AD from sMRI. Many methods [2, 3] use a bag of patches selected from the skull-stripped brain region, which ignores global context information that can play a significant role in identifying lesions for accurate inference [4]. Other studies [5,6,7,8] proposed to characterize AD using segmented anatomies (e.g., gray matter or the hippocampus). These methods rely on accurate segmentation of the anatomies, which is usually performed in a multi-stage data processing pipeline with the help of third-party software (e.g., FreeSurfer [9]) driven by a prior template. However, template-driven methods depend on variable image registration accuracy and are highly affected by the anatomical variability between subjects, introducing errors into the characterization of individualized abnormalities. Similarly, methods that use detected landmarks (e.g., [10]) also depend on a template-driven pipeline. Taking advantage of attention mechanisms, some methods [5] proposed to diagnose AD using sMRI images from multiple centers. However, the classification performance is either hardly reproducible or difficult to compare across studies. One of the major reasons is that existing methods are often trained with samples from the same training (source) domain, while testing samples come from an independent new (target) domain with a different feature distribution. In the literature, this situation relates to domain adaptation [11,12,13,14,15,16] or domain generalization [17,18,19]. A widely used solution to this problem is to learn a domain-invariant latent feature space [20]. Unfortunately, there is no guarantee that the target samples’ features will fall into the shared domain-invariant representation of the source domains, and in practice features from new domains typically do not.

In this paper, we propose a novel domain-knowledge-constrained neural network for the diagnosis of AD using sMRI from multiple source domains. We design a new domain-knowledge encoding module, integrated into a ResNet-like architecture for feature learning, that yields a latent feature space carrying both domain-specific and domain-shared information. In addition, we propose to use segmentation-free, resampling-free, patch-free 3D sub-images, which offer global context information and subject-level abnormalities to further refine generalizable and reproducible predictions.

2 Methods

We propose an end-to-end neural network (Fig. 1) for automatic, robust, and reproducible diagnosis of AD using sMRI images, with the aim of identifying and understanding the most discriminative anatomical regions associated with AD. The model operates in three major steps: a) crop the input sMRI image to keep a sub-region (red rectangle) containing relevant anatomical structures (e.g., hippocampus, caudate, ventricles) associated with AD; b) extract features shared by all training sources based on ResNet [21]; c) constrain the feature learning process for better generalization using a domain-knowledge encoding module and a set of label predictors.

Fig. 1. Schematic of the proposed generalizable classification model. The feature extractor is a ResNet18-like 3D network that extracts high-dimensional features from MRI images for classification using 3D convolutions and residual connections. The basic block is the basic component of the feature extractor and consists of two 3D convolutional layers, two BatchNorm layers, a ReLU layer, and a residual connection. The classifier is a multilayer perceptron (MLP), consisting of two linear layers and a ReLU layer. Domain-knowledge encoding captures domain-invariant and domain-specific features and generates weights for the classifiers based on domain similarity. The label predictor comprises multiple mutually independent classifiers, whose predictions are weighted and summed to obtain the final output. (Color figure online)

2.1 Patch-Free 3D Feature Extractor

We first estimate a bounding box around the relevant anatomical objects in the input sMRI. The objects are automatically identified by affine registration, which transforms the reference template to each image in the dataset to estimate labels for that image. We note that the estimated labels are only used to locate the bounding box; they have no effect on the individual’s atrophy, since we pad extra space to ensure the cropped image contains all objects of interest despite registration errors. Then, we crop the input image using the located bounding box to obtain the sub-image used as input to our network. Note that the cropping size is a fixed tuple determined by the maximum bounding box containing the informative anatomical objects associated with AD.
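
As a concrete illustration, the sketch below crops a fixed-size sub-image around a template-derived bounding box center, padding with a constant value wherever the box extends past the volume so that registration error near the border cannot shrink the output. The function name, the center argument, and the NumPy interface are our own illustrative choices; only the crop size of [80, 128, 72] (Sect. 3.2) comes from the paper.

```python
import numpy as np

def crop_around_center(volume, center, crop_size=(80, 128, 72), pad_value=0.0):
    """Crop a fixed-size sub-image around a template-derived center (hypothetical helper)."""
    out = np.full(crop_size, pad_value, dtype=volume.dtype)
    src, dst = [], []
    for ax in range(3):
        start = center[ax] - crop_size[ax] // 2           # desired box start on this axis
        s0 = max(start, 0)                                # clip to the volume bounds
        s1 = min(start + crop_size[ax], volume.shape[ax])
        d0 = s0 - start                                   # offset inside the padded output
        src.append(slice(s0, s1))
        dst.append(slice(d0, d0 + (s1 - s0)))
    out[tuple(dst)] = volume[tuple(src)]                  # copy the overlapping region
    return out
```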

To encode global context information, we propose a patch-free 3D feature extractor for the different source domains, which is expected to learn domain-invariant features without eliminating domain-specific ones. Each domain has a unique label classifier, allowing adjustments for domain differences. Based on ResNet, we design our feature extractor as shown in Fig. 1. Each basic block consists of two convolutional layers, each followed by batch normalization and a LeakyReLU nonlinear activation function. The basic block can be written as:

$$\begin{aligned} X_{l+1}=F(W_i,X_l)+W_sX_l, \end{aligned}$$
(1)

where \(X_l\) and \(X_{l+1}\) are the input and output of the basic block and \(F(W_i,X_l)\) denotes the nonlinear mapping in the basic block. Since the dimensions of \(X_l\) and \(F(W_i,X_l)\) must be the same for summation, we use the linear mapping \(W_s\) to adjust the dimensions of \(X_l\) in the shortcut connection.
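
A minimal PyTorch sketch of the basic block in Eq. (1) is given below, assuming the two-convolution layout of Fig. 1 with batch normalization and LeakyReLU after each convolution; a \(1\times 1\times 1\) shortcut convolution plays the role of \(W_s\). Channel sizes and strides are illustrative.

```python
import torch
import torch.nn as nn

class BasicBlock3D(nn.Module):
    """3D residual basic block (Eq. 1): two 3x3x3 convolutions with BatchNorm
    and LeakyReLU, plus a shortcut; a 1x1x1 convolution W_s adjusts dimensions
    when the stride or channel count changes."""

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(out_ch)
        self.conv2 = nn.Conv3d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(out_ch)
        self.act = nn.LeakyReLU(inplace=True)
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(                 # W_s: match dimensions
                nn.Conv3d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm3d(out_ch))
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = self.act(self.bn1(self.conv1(x)))           # first conv + BN + activation
        out = self.bn2(self.conv2(out))                   # second conv + BN
        return self.act(out + self.shortcut(x))           # Eq. 1: F(W_i, X_l) + W_s X_l
```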

In the proposed method, we use a global average pooling function, which is more suitable for disease classification because the global average pooling operation reflects the gray matter volume in brain regions and preserves the relative positional relationship between different channels of the feature map.

In the output layer, we use a softmax classifier with a cross-entropy loss to measure the discrepancy between the predicted and true labels:

$$\begin{aligned} \mathcal {L}=cross\text {-}entropy(\widehat{Y}_i(X_i\in D_s;\omega ),Y_i) \end{aligned}$$
(2)
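
For concreteness, a sketch of one label predictor and the loss of Eq. (2) follows; Fig. 1 describes each classifier as an MLP with two linear layers and a ReLU, while the feature and hidden dimensions (512 and 128) are assumed values, not taken from the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# One label predictor: an MLP with two linear layers and a ReLU (Fig. 1).
classifier = nn.Sequential(
    nn.Linear(512, 128),   # feat_dim -> hidden_dim (assumed sizes)
    nn.ReLU(inplace=True),
    nn.Linear(128, 2),     # two classes: NC vs. AD
)

features = torch.randn(4, 512)        # pooled features for a batch of 4
labels = torch.tensor([0, 1, 1, 0])   # ground-truth diagnoses
loss = F.cross_entropy(classifier(features), labels)   # Eq. (2)
```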

2.2 Global Average Pooling

Global average pooling solves the problem of excessive image feature dimensionality. If the feature maps of 3D images were directly flattened for classification, the number of classifier parameters would increase significantly, raising the time and space complexity of training. Global average pooling averages each 3D feature map over its spatial dimensions, preserving the relative positional relationship between channels and reducing the resources required for model training.

The dimension change in the global average pooling is \([B,C,D,H,W]\rightarrow [B,C,1,1,1]\), where B denotes the batch-size and C denotes the channel number.

$$\begin{aligned} GAP(\delta )=\frac{1}{D\times H\times W}\sum _{i=1}^{D}{\sum _{j=1}^{H}{\sum _{k=1}^{W}{\delta _{i,j,k}}}} \end{aligned}$$
(3)

where \(\delta \) denotes the image feature extracted by ResNet, and D, H, W denote the three dimensions of the feature.

Since global average pooling introduces no extra parameters, it can prevent over-fitting to some extent. Furthermore, because global average pooling averages out the spatial information, it is more robust to spatial translations of the input.
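
The dimension change of Eq. (3) is a one-liner in PyTorch; the sketch below uses illustrative tensor sizes.

```python
import torch

feat = torch.randn(4, 512, 10, 16, 9)   # [B, C, D, H, W], sizes are illustrative
pooled = feat.mean(dim=(2, 3, 4))       # Eq. (3): averages each channel -> [4, 512]
# Equivalent built-in: torch.nn.AdaptiveAvgPool3d(1)(feat).flatten(1)
```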

2.3 Domain-Knowledge Encoding

The domain-knowledge encoding module is designed to assign relative similarity weights to the source domains for a new sample. The weights reflect the similarity between the testing sample and the source domains, allowing the module to share strength only among similar domains.

Our model uses multiple classifiers to predict from the features produced by the feature extractor. The classifiers are independent of each other. We feed the image features to the different classifiers, generate a weight for each classifier, and sum the classifier predictions according to these weights to obtain the final output:

$$\begin{aligned} \widehat{Y}=\sum _{j=1}^{c\_num} \omega _{ij}\cdot classifier_j(\delta (X\in D_i),\theta _{j}) \end{aligned}$$
(4)

where \(\widehat{Y}\) denotes the prediction result for X, \(c\_num\) denotes the number of classifiers, \(D_i\) denotes the center to which X belongs, \(\delta \) denotes the feature extracted from X, \(classifier_j\) denotes one classifier, and \(\theta _{j}\) are the parameters of \(classifier_j\).

The multiple classifiers capture both the invariant and the specific feature distributions across domains. By jointly training this mixture of classifiers, the model compares the similarity of feature distributions between the training source domains and an unseen target domain, generating weights that integrate the feature distributions of the known domains to fit the distribution of the unknown domain.
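
A sketch of Eq. (4) with \(c\_num=6\) independent MLP classifiers is shown below. The text does not fully specify how the per-sample weights \(\omega_{ij}\) are produced at test time, so the small linear gating layer over the pooled features is our own assumption; during training the weights are supplied externally (see Sect. 3.2).

```python
import torch
import torch.nn as nn

class DomainKnowledgeHead(nn.Module):
    """Hypothetical sketch of Eq. (4): c_num independent MLP classifiers whose
    softmax outputs are combined with per-sample weights."""

    def __init__(self, feat_dim=512, n_classifiers=6, n_classes=2):
        super().__init__()
        self.classifiers = nn.ModuleList(
            nn.Sequential(nn.Linear(feat_dim, 128), nn.ReLU(inplace=True),
                          nn.Linear(128, n_classes))
            for _ in range(n_classifiers)
        )
        # Test-time weight generator: our assumption, not specified in the paper.
        self.gate = nn.Linear(feat_dim, n_classifiers)

    def forward(self, feat, weights=None):
        # Per-classifier probabilities: [B, n_classifiers, n_classes]
        preds = torch.stack([torch.softmax(c(feat), dim=1) for c in self.classifiers], dim=1)
        if weights is None:  # testing: similarity weights over the source domains
            weights = torch.softmax(self.gate(feat), dim=1)   # [B, n_classifiers]
        return (weights.unsqueeze(-1) * preds).sum(dim=1)     # weighted sum (Eq. 4)
```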

3 Experiments and Results

3.1 Data Description

Structural T1-weighted brain MRI data of 809 subjects (468 male, 341 female, age 68.16 ± 8.12 years, range 42–89 years) were acquired from 7 independent in-house centers, as detailed in [5, 22]. In total, 552 subjects (295 normal controls (NC), 257 with AD) were used for leave-center-out training. The remaining 257 subjects with mild cognitive impairment (MCI) were used as an independent dataset for evaluation and compared against clinical diagnosis metrics.

3.2 Implementation Details

We first evaluated the model using leave-center-out cross-validation, where one center was selected for testing at a time and all remaining centers were used for training. Then, we applied the trained model to an independent validation set of unseen images from subjects with MCI. All images were cropped to the same size of [80, 128, 72]. Image features were extracted with \(3\times 3\times 3\) convolutions in the network, with \(2\times 2\times 2\) convolutions of stride 2 replacing max pooling. The extracted features were passed through a global average pooling layer (Sect. 2.2). \(N=6\) independent classifiers were used.

During training, we sorted all training centers and fed the image features from \(site_i\) to all classifiers, setting the weight of \(classifier_{j(j=i)}\) to 1 and the weights of the remaining classifiers to 0. We used cross-entropy to calculate the prediction error and updated the parameters of the feature extractor and \(classifier_j\) by backpropagation. In the testing stage, we fed the image features from the test center to all classifiers and used the weighted average of the predicted probabilities over all classifiers as the final prediction.
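
The training- and testing-time weighting described above might look as follows, reusing the hypothetical DomainKnowledgeHead sketch from Sect. 2.3; extractor, head, and site_idx are illustrative names.

```python
import torch
import torch.nn.functional as F

def train_step(extractor, head, images, labels, site_idx, n_classifiers=6):
    feat = extractor(images)                              # pooled features [B, feat_dim]
    # One-hot weights: classifier_j with j == i gets weight 1, the rest 0.
    weights = F.one_hot(site_idx, n_classifiers).float()
    probs = head(feat, weights=weights)                   # Eq. (4) with one-hot weights
    loss = F.nll_loss(torch.log(probs + 1e-8), labels)    # cross-entropy on probabilities
    loss.backward()                                       # updates extractor and classifier_i only
    return loss

def predict(extractor, head, images):
    with torch.no_grad():
        # weights=None lets the head generate similarity weights for all classifiers.
        return head(extractor(images))                    # weighted-average prediction
```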

We used the SGD algorithm to optimize the model coefficients, with an initial learning rate of 0.001 reduced to one-tenth of its previous value every 50 epochs. The method was implemented using PyTorch 1.1 with Python 3.7. The experiments were run on an Intel Xeon CPU with 16 cores and 43 GB RAM and an NVIDIA A5000 GPU with 24 GB of memory. The code and model are available at https://github.com/Yanjie-Z/DomainKnowledge4AD.
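
The reported optimizer and schedule correspond to the following PyTorch setup; the placeholder model and the epoch budget are assumptions, since neither momentum nor the number of epochs is stated.

```python
import torch
import torch.nn as nn

model = nn.Linear(512, 2)  # placeholder for the full extractor + classifiers
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

for epoch in range(150):   # epoch budget not stated in the paper
    # ... one training epoch over all source centers ...
    scheduler.step()       # lr: 1e-3 -> 1e-4 at epoch 50 -> 1e-5 at epoch 100
```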

Fig. 2. First row: the left panel shows the AUC-ROC curve for each domain under leave-center-out cross-validation, and the right panel investigates the association between the predicted probabilities and a clinical measure (MMSE) in subjects with Alzheimer’s disease (AD), mild cognitive impairment (MCI), and healthy controls (NC). Second row: attention map for a representative sMRI of a subject with AD, illustrating the most discriminative features learned by the proposed approach.

3.3 Performance Evaluation

To evaluate the proposed approach, we fed two different types of input to the conventional 3D-ResNet [21], obtaining two models: 1) ResNet, which uses the original image as input, and 2) Baseline, which uses the bounding box cropping strategy proposed in Sect. 2.1. In addition, we incorporated a patch-free cropping strategy inspired by [4] that crops the middle-half sub-region of the original input sMRI image of the brain and feeds it to ResNet, which we denote ResNet-PF. The prediction performances are compared in Table 1.

Table 1. Comparisons among different methods with leave-center-out cross-validation. Abbreviations: ACC = accuracy, AUC = area under the curve of the receiver operating characteristic, AVG = average performance over centers. ACC in percentage.

Our model achieves an average classification accuracy of 89.25% over all test centers during cross-validation, compared to 85.95% for the baseline (without the domain-knowledge encoding module).

We used AUC-ROC curves to evaluate the classification effectiveness [13, 17, 23] of the model on the test centers; the AUC-ROC curves for the seven centers are compared in Fig. 2.

To evaluate the interpretability of the model, we used Grad-CAM [24] to analyze the regions to which the model is sensitive when discriminating AD. We found that the model focused on the hippocampus during prediction, which confirms the significant correlation between AD and the hippocampus. We also found that the model pays more attention to the hippocampus when discriminating subjects with AD than healthy controls. Figure 3 compares the 3D attention maps between a subject with AD and a healthy subject, demonstrating clearly higher values in the hippocampus region.

Fig. 3. 3D attention maps for a healthy subject (first row) and a subject with AD (second row) in 4 different views (columns). The bottom row shows a visual navigator.

4 Discussion

We proposed a novel reproducible and generalizable neural network to assist the automatic diagnosis of AD, which benefits from domain knowledge and global contextual information through segmentation-free, resampling-free, patch-free sub-images. The model was evaluated with leave-center-out cross-validation and with an independent set of unseen images from subjects with MCI (Fig. 2). It obtains an average accuracy of 89.25%, a loss of 0.39, and an AUC of 0.92, compared with \(85.95\%\), 0.58, and 0.91 using ResNet. We applied the proposed model to images from a new domain (never used during training), demonstrating promising results.

We performed ablation studies to evaluate the proposed method (Table 1); unsurprisingly, the cropped images yield the best performance. Figures 2 and 3 evaluate the explainability of the proposed neural network. The results suggest that the hippocampus and ventricle regions suffer the most in AD, consistent with multi-stage segmentation-based methods [5] and with clinical measures (in terms of MMSE) on an independent dataset (Fig. 2).

Our model and all comparative frameworks tend to predict worse for center 3, probably because it has some subjects with AD who have higher MMSE scores (Fig. 2), making the diagnosis challenging. In contrast, all models provide the best accuracy for center 5. We will further explore possible reasons for this center imbalance in future work. Another limitation of the presented study is the empirical early-stopping strategy during leave-center-out cross-validation, based on the observed loss ranges. In future work, we will also explore a more automated mechanism to increase model robustness for images from more centers.

5 Conclusion

We proposed a novel end-to-end domain-knowledge-constrained neural network for automatic and reproducible diagnosis of AD using sMRI images. We proposed a new domain-knowledge encoding module that learns jointly with a ResNet-like feature extractor to produce domain-specific and domain-shared representations. The network directly takes segmentation-free, patch-free images at the original resolution as input, enabling it to learn global contextual information and subject-level pathological brain dysmorphology features that further refine reproducible predictions. Our experiments demonstrate superior performance and good generalization to completely unseen domains.