
1 Introduction

Although it has been more than a century since Alois Alzheimer first described the clinical characteristics of Alzheimer’s Disease (AD) [3, 4], the disease still eludes early detection. AD is the leading cause of dementia, accounting for 60%-80% of all cases worldwide [20, 34], and the number of patients affected is growing. In 2015, Alzheimer’s Disease International (ADI) reported that over 46 million people were estimated to have dementia worldwide and that this number was expected to increase to 131.5 million by 2050 [29]. Since AD is a progressive disease, computer-assisted early identification of the disease may enable early medical treatment to slow its progression.

Both methods that require intensive expert input for feature collection, such as Morphometry [13], and more automated solutions based on deep learning [5, 8, 24] have been used in the literature on computer-assisted diagnosis of AD. These automated detection methods usually classify patients as belonging to one of three stages: Normal (patients exhibiting no signs of dementia and no memory complaints), Mild Cognitive Impairment (MCI) (an intermediate state in which a patient’s cognitive decline is greater than expected for their age but does not interfere with the activities of their daily life), and full AD.

A participant’s progression from one of these stages to the next, however, can take more than five years [30]. This means that when automated disease classification systems based on these three levels are used, patients nearing the severe stage may not receive the required treatment because they are classified as belonging to the pre-severe stage. This is illustrated in Fig. 1(a) for five participants from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) study [28]. Under the typical classification approach (see Fig. 1(b)), for example, even though participant five is only a year away from progressing to the severe AD stage, they would be classified as MCI in the year 2009. To address this issue, we focus on the clinical question “how far away is a progressive MCI patient on their trajectory to AD?” To do this, we propose an ordinal categorization of brain images based on participants’ level of progression from MCI to AD, as shown in Fig. 1(c). Our approach adds ordinal labels to MRI scans of patients with progressive MCI indicating how many years they are from progressing to AD. We construct a dataset of 444 MRI scans from 288 participants with these labels and share a replication script.

Fig. 1. (a) Progression levels of five sample MCI participants, where each dot represents an MRI image taken during an examination year. (b) The typical approach of organizing images for identification of progressive MCI or classification of MCI and AD. The images in the lower light orange section are categorized as MCI when preparing a training dataset (this includes images from patients nearing progression to AD, i.e. near progression level 1.0). The images in the upper red section are categorized as AD. (c) Our approach to organizing brain images, illustrated using a Viridis color map. Images are assigned ordinal progression levels \(\in [0.1, 0.9]\) based on their distance in years from progressing to the AD stage. (Color figure online)

In addition to constructing the dataset, we develop a computer-assisted approach to identifying a participant’s (or, more specifically, their MRI image’s) progression level. Accurately identifying how far a patient is from progressing to full AD is of paramount importance, as this information may enable earlier intervention with medical treatments [2]. Rather than using simple ordinal classification techniques, we use Siamese networks due to their ability to handle the class imbalance in the employed dataset [21, 33]. We use both a standard Siamese network architecture and a novel Weighted Siamese network with a new loss function tailored to predicting an input MRI image’s likelihood of progression. Furthermore, we complement the results of our Siamese network based method with interpretations of the embedding space using an auxiliary model explanation technique, t-distributed Stochastic Neighbor Embedding (t-SNE) [26]. t-SNE condenses the high dimensional embedding spaces learned by a Siamese network into interpretable two or three dimensional spaces [9].

The main contributions of this paper are:

  1. We provide a novel approach that interpolates ordinal categories between the existing MCI and AD categories of the ADNI dataset based on participants’ progression levels.

  2. We apply the first Siamese network approach to predicting interpolated progression levels of MCI patients.

  3. We propose a simple and novel variant of the triplet loss for Siamese networks tailored to identifying progression levels of MCI patients.

  4. Our experiments demonstrate that our version of the triplet loss is better at predicting progression level than the traditional triplet loss. Code is shared online (see Footnote 1).

2 Related Work

Before the emergence of deep learning, and in the absence of relevant large datasets, computer-assisted identification of AD relied on computations requiring expensive expert involvement, such as Morphometry [13]. However, the release of longitudinal datasets such as ADNI [28] inspired research on automated solutions that employ machine learning and deep learning methods for the identification of AD.

Most of the approaches proposed for AD diagnosis perform a classification among three recognized stages of the disease: Normal, MCI, and AD [23]. Some examples in the literature distinguish between all three of the categories [36], while others distinguish between just two: Normal and MCI [19], Normal and AD [32], or MCI and AD [31].

Patients at the MCI stage, especially elderly patients, have an increased risk of progressing to AD [30]. For example, in the Canadian Cohort Study of Cognitive Impairment and Related Dementia [17], 49 out of a cohort of 146 MCI patients progressed to AD over a two-year follow-up. In general, while healthy adult controls progress to AD at a maximum annual rate of 2%, MCI patients progress at a rate of 10%-25% [15]. This necessitates research on identifying MCI subjects at risk of progressing to AD. Over a longitudinal study period, participants diagnosed with MCI can be divided into two categories: (1) progressive MCI, participants who were diagnosed with MCI at some stage during the study but were later diagnosed with AD, and (2) stable MCI, patients who remained at the MCI stage throughout the study period [16]. This categorization excludes MCI participants who reverted to healthy, even though they have also been reported to have a chance of progressing to AD [30]. There are examples in the literature of using machine learning techniques such as random forests [27] and CNNs [16] to classify between stable and progressive MCI. While some approaches apply feature extraction prior to model training to build relatively simple models [23, 35], others deploy 3D brain images with 3D CNNs to reduce false positives [5, 25].

The brain image classification task can also be transformed into ordinal classification to build regression models. For example, four categories of AD (healthy, stable MCI, progressive MCI, and AD) were used as ordinal labels to build a multivariate ordinal regressor from MRI images in [14]. However, the output of such models gives no indication of the likelihood of a patient progressing from one stage to another, nor does it provide predictions for interpolated inter-category progression levels. Albright et al. (2019) [2] used longitudinal clinical data, including ADAS13 (the 13-item Alzheimer’s Disease Assessment Scale) and the Mini-Mental State Examination (MMSE), to train multi-layer perceptrons and recurrent networks for AD progression prediction. This work, however, uses no imaging data, and it has been shown that brain images play a key role in improving diagnostic accuracy for Alzheimer’s disease [18].

Siamese networks, which use a distance-based similarity training approach [10, 12], have found applications in areas such as object tracking [7] and anomaly detection [6]. Although it does not focus on AD detection, we found [22] to be the closest approach to our proposed method in the literature. Li et al. [22] report that a Siamese network’s distance output can be translated into predicted disease positions on a severity scale. Although this approach treats the distance output as a severity scale without any prior training on disease severity, deals only with existing disease stages, and does not interpolate ordinal categories between them, it does suggest that Siamese networks are a promising approach for predicting ordinal progression levels.

3 Approach

In this section, we describe the datasets (and how they are processed), the model architectures, and the model training and evaluation techniques used in our experiments, as well as our proposed triplet loss for Siamese networks.

3.1 Dataset Preparation

The data used in the experiments described here was obtained from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) database. ADNI was launched in 2003, led by Principal Investigator Michael W. Weiner, MD (https://adni.loni.usc.edu). For up-to-date information, see https://adni.loni.usc.edu/.

Table 1. Image distribution across progression levels, \(\rho \).

From the ADNI dataset we identified MRI brain images of 1310 participants who were diagnosed with MCI or AD: 288 participants had progressive MCI, 545 had stable MCI, and the rest had AD. We used the MRI images of the 288 progressive MCI participants to train and evaluate our models. We labeled the progressive MCI participants based on their progression levels towards AD, \(\rho \in [0.1, 1.0]\) with a step size of 0.1, where, for a single participant P, \(min(\rho )=0.1\) represents the first time P was diagnosed with MCI and \(max(\rho )=1.0\) represents P’s transition to the AD stage. This transforms the binary MCI and AD labels into 10 ordinal labels. An example of the data organization based on progression level is plotted in Fig. 1. The distribution of the constructed ordinal progression levels is shown in Table 1 (where \(\rho =1.0\) represents AD), where the imbalance between the different labels is clear. Within the ADNI dataset, the maximum number of MRI scans a progressive MCI participant had before progressing to AD (\(\rho =1.0\)) is 9, which means that the smallest observed \(\rho \) is 0.2. We took advantage of Siamese networks’ robustness to class imbalance to circumvent the imbalance in the ordinal labels. By sub-sampling from the majority classes, we selected 444 3D MRI images (shape = 160\(\,\times \,\)192\(\,\times \,\)192) for the negative, anchor, and positive datasets (each holding 148 images) required when training a Siamese network using a triplet loss. We used 80% of the images for training and the rest for testing. AD images were randomly separated into the anchor and positive datasets. We ensured that there was no participant overlap between sets when splitting the data into training and testing datasets, and between the anchor and positive datasets.
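To make the labeling concrete, the following Python sketch illustrates one plausible implementation of this step, assuming one scan per examination year and \(\rho \) decreasing by 0.1 per year before conversion; the record fields (participant, year, ad_year) are hypothetical stand-ins for the actual ADNI metadata, and this is not the released replication script.

```python
# Minimal sketch of the ordinal labeling step. rho = 1.0 at the scan where
# a participant is first diagnosed with AD and decreases by 0.1 for each
# year a scan precedes that conversion (hypothetical field names).
from collections import defaultdict

def assign_progression_levels(scans):
    """scans: iterable of dicts with keys 'participant', 'year', and
    'ad_year' (year of first AD diagnosis). Returns (scan, rho) pairs."""
    by_participant = defaultdict(list)
    for scan in scans:
        by_participant[scan["participant"]].append(scan)

    labeled = []
    for history in by_participant.values():
        for scan in sorted(history, key=lambda s: s["year"]):
            years_to_ad = scan["ad_year"] - scan["year"]
            if 0 <= years_to_ad <= 8:  # smallest observed rho is 0.2
                labeled.append((scan, round(1.0 - 0.1 * years_to_ad, 1)))
    return labeled
```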

3.2 Weighted Siamese Network

Fig. 2. Weighted Siamese network. ResNet-50 here refers to the base layers of the ResNet-50 architecture, trained from scratch to extract image embeddings, i.e. excluding the fully connected classifier layers. While we used 3D MRI images for model training and evaluation, the sagittal plane is shown here only for visualization purposes.

Siamese networks are usually trained using a triplet loss or one of its variants. While a traditional triplet loss teaches a network that a negative instance should be at a larger distance from the anchor than a positive instance, we propose a Weighted triplet loss that teaches the network that instances which would all traditionally be considered negative are not all at the same distance from the anchor: their distance depends on their progression level, \(\rho \). So that lower progression levels are placed at a larger distance from an anchor instance, we transform \(\rho \) into a weighting coefficient \(\alpha = 1.9 - \rho \), excluding \(\rho =1.0\), as shown in Fig. 3. The architecture of our proposed Weighted Siamese network is shown in Fig. 2.

Fig. 3. Transforming the progression level, \(\rho \), to the weighting coefficient \(\alpha \).

We used two different loss functions to train our Siamese networks. The first is the traditional triplet loss, which we refer to as the Unweighted Siamese loss:

$$\begin{aligned} L_{u} = max(d_{ap} - d_{an} + margin, 0) \end{aligned}$$
(1)

where \(margin=1.0\), \(d_{ap}\) is the Euclidean distance between anchor and positive embeddings, and \(d_{an}\) is the distance between the embeddings of anchor and negative instances.

The second is our newly proposed Weighted triplet loss (the Weighted Siamese), which applies a coefficient \(\alpha \in [1.0, 1.8]\) to \(d_{an}\) in \(L_{u}\):

$$\begin{aligned} L_{w} = max(d_{ap} - \alpha d_{an} + margin, 0) \end{aligned}$$
(2)
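
Both losses are straightforward to express in code. The following TensorFlow sketch mirrors Eqs. 1 and 2; it assumes the anchor, positive, and negative embeddings arrive as batched tensors and that \(\rho \) is supplied per triplet, and it is an illustration rather than our released implementation.

```python
# Sketch of the Unweighted (Eq. 1) and Weighted (Eq. 2) triplet losses.
import tensorflow as tf

def euclidean_distance(a, b):
    # Per-example Euclidean distance between [batch, dim] embeddings.
    return tf.norm(a - b, axis=-1)

def unweighted_triplet_loss(anchor, positive, negative, margin=1.0):
    d_ap = euclidean_distance(anchor, positive)
    d_an = euclidean_distance(anchor, negative)
    return tf.reduce_mean(tf.maximum(d_ap - d_an + margin, 0.0))

def weighted_triplet_loss(anchor, positive, negative, rho, margin=1.0):
    # alpha = 1.9 - rho in [1.0, 1.8]: the lower the progression level,
    # the further the negative is pushed from the anchor.
    alpha = 1.9 - rho
    d_ap = euclidean_distance(anchor, positive)
    d_an = euclidean_distance(anchor, negative)
    return tf.reduce_mean(tf.maximum(d_ap - alpha * d_an + margin, 0.0))
```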

3.3 Training and Evaluation

We implemented all of our experiments using TensorFlow [1] and Keras [11]. After comparing the performance of different architectures and feature embedding sizes, we chose to train a 3D ResNet-50 model from scratch, adding three fully connected layers of 64, 32, and 8 nodes with ReLU activations and taking the last layer of size 8 as the embedding space. We used an Adam optimizer with a decaying learning rate of 1e-3. We trained the model with five different seeds for 150 epochs, which took an average of 122 min per training run on an NVIDIA RTX A5000 graphics card.
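
As a rough sketch of this setup: Keras does not ship a 3D ResNet-50, so build_resnet50_3d below is a hypothetical stand-in for the from-scratch backbone, and the pooling before the head as well as the exact decay schedule are our assumptions; only the 64-32-8 fully connected head follows the description above.

```python
# Sketch of the embedding model head and optimizer settings.
import tensorflow as tf
from tensorflow.keras import layers

def build_embedding_model(input_shape=(160, 192, 192, 1)):
    base = build_resnet50_3d(input_shape)  # hypothetical 3D backbone
    x = layers.GlobalAveragePooling3D()(base.output)  # assumed pooling
    x = layers.Dense(64, activation="relu")(x)
    x = layers.Dense(32, activation="relu")(x)
    embedding = layers.Dense(8, activation="relu")(x)  # 8-D embedding space
    return tf.keras.Model(base.input, embedding)

# "Decaying learning rate of 1e-3" read as an exponential schedule starting
# at 1e-3; decay_steps and decay_rate here are assumed values.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3, decay_steps=1000, decay_rate=0.96)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
```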

For model evaluation on the training and testing datasets, we use both the Unweighted and Weighted Siamese losses as well as the Mean Absolute Error (MAE) and Root Mean Squared Error (RMSE). MAE and RMSE are presented in Eqs. 3 and 4, respectively, for a test set of size N, where \(y_{i}\) and \(Y_{i}\) hold the predicted and ground truth values for instance i. We turn the distance outputs of the Siamese networks into predictions y by discretizing them into equally spaced bins, where the number of bins equals the number of progression levels.

$$\begin{aligned} MAE = \frac{1}{N} \sum _{i=1}^N| y_{i} - Y_{i} | \end{aligned}$$
(3)
$$\begin{aligned} RMSE = \sqrt{\frac{1}{N} \sum _{i=1}^N (y_{i} - Y_{i})^2} \end{aligned}$$
(4)
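
A minimal NumPy sketch of this evaluation step follows; it assumes equally spaced bin edges over the observed distance range and that larger anchor-negative distances correspond to lower progression levels (the direction implied by the weighting), both of which are our reading rather than details fixed above.

```python
# Sketch: discretize Siamese distance outputs into one bin per progression
# level, then compute MAE (Eq. 3) and RMSE (Eq. 4).
import numpy as np

def discretize(distances, n_levels=9):
    edges = np.linspace(distances.min(), distances.max(), n_levels + 1)
    bins = np.digitize(distances, edges[1:-1])  # bin indices 0..n_levels-1
    return 0.9 - 0.1 * bins  # assumed: largest distance -> level 0.1

def mae(y, Y):
    return np.mean(np.abs(y - Y))

def rmse(y, Y):
    return np.sqrt(np.mean((y - Y) ** 2))
```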

We make use of t-SNE to explain the 8 dimensional embedding space learned by the Weighted Siamese network by condensing it into two dimensions. For presentation purposes, and in order to fit the t-SNE well, we drop underrepresented progression levels: progression level 0.2 was dropped from the training dataset, while progression levels 0.2, 0.3, and 0.5 were removed from the testing dataset. The t-SNE was fitted over 1000 iterations using the Euclidean distance metric, with a perplexity of 32 and 8 for the training and testing datasets, respectively.
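
With scikit-learn, this step might look as follows; the init and random_state settings are our assumptions, and recent scikit-learn releases rename n_iter to max_iter.

```python
# Sketch of the t-SNE projection of the 8-D embeddings, using the
# hyperparameters reported above (1000 iterations, Euclidean metric).
from sklearn.manifold import TSNE

def project_embeddings(embeddings, perplexity):
    tsne = TSNE(n_components=2, perplexity=perplexity, n_iter=1000,
                metric="euclidean", init="pca", random_state=0)
    return tsne.fit_transform(embeddings)  # shape: [n_samples, 2]

# e.g. project_embeddings(train_embeddings, perplexity=32) for training
# instances and perplexity=8 for test instances.
```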

4 Results and Discussion

In this section, we present the training and testing losses, the MAE and RMSE evaluation metrics, a plot comparing predicted and ground-truth progression levels, and an interpretation of the results.

Fig. 4. Training and testing losses of the Weighted and Unweighted Siamese models. The first 40 epochs are cropped out for easier visualization. Bars represent standard errors over five runs.

Table 2. Average MAE and RMSE over five runs.

Training and testing losses over five runs of model training for both the Unweighted and Weighted Siamese networks are shown in Fig. 4. While the average training and testing losses of the Weighted Siamese network are 2.92 and 2.79, those of the Unweighted Siamese are 10.02 and 17.53, respectively. We observed that the Unweighted Siamese network struggled to learn the progression levels across all the ordinal categories, whereas our proposed Weighted loss was better at fitting all the levels. We attribute this to the effect of adding a weighting factor based on \(\rho \).

A plot of predicted vs. ground truth MCI to AD progression levels is presented in Fig. 5. Our proposed Weighted Siamese network outperforms the Unweighted Siamese network at predicting progression levels (Fig. 5 and Table 2).

Fig. 5. Predicted progression levels of test MRI images against ground truth levels.

We observed that the simple modification of scaling the distance between the embeddings of the anchor and negative instances by a function of the progression level brought a considerable performance gain in separating the interpolated categories between MCI and AD.

In Fig. 5, although the Weighted Siamese outperforms the Unweighted Siamese, it also tends to classify input test images with lower progression levels as if they were at higher progression levels. This means that brain images of patients who are far from progressing to AD would be identified as being close to progressing. While it is important to correctly identify these low risk patients, we believe it is better to report patients at lower risk as high risk and refer them for expert input than to classify high risk patients as low risk.

An interpretation of the results of the proposed Weighted Siamese method using t-SNE is displayed in Fig. 6. The clustering of the embeddings of input instances according to their progression levels, especially the separation between the low-risk and high-risk progression levels, assures us that the results reflect the ground truth disease levels.

Fig. 6. Visualization of the t-SNE of the embedding spaces of training and test instances.

5 Conclusion

As in other image-based computer-assisted diagnosis research, the AD identification literature is heavily populated by disease stage classification. However, an interesting extra step can be taken to identify how far an input brain image is from progressing to a more severe stage of AD. We present a novel approach of interpolating ordinal categories between the MCI and AD categories to prepare a training dataset. In addition, we propose and implement a new Weighted loss term for Siamese networks tailored to such a dataset. Our experiments show that our proposed approach surpasses the performance of a model trained using the standard Unweighted loss term, and we show how the predicted levels translate to the ground truth progression levels by applying a model interpretability technique to the embedding space. We believe our approach could easily be transferred to other areas of medical image classification involving progressive diseases.

The diagnoses used in our study are bounded by the timeline of the ADNI study: a participant recorded as having MCI in one examination year and progressing to AD some year(s) later could have had MCI before joining the ADNI study, so their progression to AD might have taken longer than what we have noted. Future work should consider this limitation.