1 Introduction

The interpretability of deep learning models is a particular concern for applications related to human health, such as analyzing longitudinal brain MRIs. Rather than relying on interpretation via post-hoc analysis [6, 14], some methods strive for an interpretable latent representation [9]. One example is self-organizing maps (SOM) [5], which cluster the latent space so that the SOM representations (i.e., the ‘representatives’ of the clusters) can be arranged in a discrete (typically 2D) grid while preserving high-dimensional relationships between clusters. Embedded in unsupervised deep learning models, SOMs have been used to generate interpretable representations of low-resolution natural images [3, 8].

Intriguing as this sounds, we found that applying these models to (longitudinal) 3D brain MRIs was unstable during training and resulted in uninformative SOMs. The models get stuck in local minima, so that only a few SOM representations are updated during backpropagation. The issue has been less severe in prior applications [3, 8], as their latent spaces are of much lower dimension than the one required here: our task needs a high-dimensional latent space to accurately encode the fine-grained anatomical details in brain MRIs [12, 17]. To ensure that all SOM representations can be updated during backpropagation, we propose a soft weighting scheme that not only updates the closest SOM representation for a given MRI but also updates all other SOM representations based on their distance to the closest SOM representation [3, 8]. Moreover, our model relies on a stop-gradient operator [16], which sets the gradient of the latent representation to zero so that this part of the loss only updates the SOM representations. This is especially crucial at the beginning of training, when the (randomly initialized) SOM representations are not yet good representatives of their clusters. Finally, the latent representations of the MRIs are updated via a commitment loss, which encourages the latent representation of an MRI sample to be close to its nearest SOM representation. In practice, these three components ensure stability during the self-supervised training of the SOM on high-dimensional latent spaces.

To generate SOMs that are informative to neuroscientists, we extend SOMs to the longitudinal setting such that the latent space and the corresponding SOM grid encode brain aging. Inspired by [12], we encode pairs of MRIs from the same longitudinal sequence (i.e., same subject) as a trajectory and encourage the latent space to form a smooth trajectory (vector) field. We enforce smoothness by computing, for each SOM cluster, a reference trajectory that represents the average aging of that cluster with respect to the training set. The reference trajectories are updated via an exponential moving average (EMA) that, in each iteration, aggregates the average trajectory of a cluster within the current training batch (i.e., the batch-wise average trajectory). In doing so, the model ensures longitudinal consistency, as the (subject-specific) trajectories of a cluster are maximally aligned with the reference trajectory of that cluster.

We evaluate our method, named Longitudinally-consistent Self-Organized Representation learning (LSOR), on a longitudinal T1-weighted MRI dataset of 632 subjects from ADNI, encoding the brain aging of Normal Controls (NC) and of patients diagnosed with static Mild Cognitive Impairment (sMCI), progressive Mild Cognitive Impairment (pMCI), and Alzheimer’s Disease (AD). LSOR clusters the latent representations of all MRIs into 32 SOM representations. The resulting 4-by-8 SOM grid is organized by both chronological age and cognitive measures that are indicators of brain age. Note that this organization relies solely on longitudinal MRIs, i.e., without using any tabular data such as age, cognitive measures, or diagnosis. To visualize aging effects on the grid, we compute (post-hoc) a 2D similarity grid for each MRI that stores the similarity scores between the latent representation of that MRI and all SOM representations. As the SOM grid is an encoding of brain aging, the similarity grid indicates the likelihood of placing the MRI within the “spectrum” of aging. Given all MRIs of a longitudinal sequence, the change across the corresponding similarity grids over time represents the brain aging process of that individual. Furthermore, we infer brain aging at the group level by first computing the average similarity grid for each age group and then visualizing the differences between those average similarity grids across age groups. On the downstream tasks of classification (sMCI vs. pMCI) and regression (i.e., estimating the Alzheimer’s Disease Assessment Scale-Cognitive Subscale (ADAS-Cog) score for all subjects), our latent representations of the MRIs are associated with comparable or higher accuracy scores than representations learned by other state-of-the-art self-supervised methods.

Fig. 1.

Overview of the latent space derived from LSOR. All trajectories (\(\varDelta z\)) form a trajectory field (blue box) modeling brain aging. SOM representations in \(\mathcal {G}\) (orange star) are organized as a 2D grid (orange grid). As shown in the black box, reference trajectories \(\varDelta \mathcal {G}\) (collection of all \(\varDelta g\), green arrow) are iteratively updated by EMA using the aggregated trajectory \(\varDelta h\) (purple arrow) across all trajectories of the corresponding SOM cluster within a training batch. (Color figure online)

2 Method

As shown in Fig. 1, the longitudinal 3D MRIs of a subject are encoded as a series of trajectories (blue vectors) in the latent space. Following [12, 17], we consider a pair of longitudinal MRIs (corresponding to a blue vector) as a training sample. Specifically, let \(\mathcal {S}\) denote the set of image pairs of the training cohort, where the MRIs \(x^u\) and \(x^v\) of a longitudinal pair \((x^u, x^v)\) are from the same subject and \(x^v\) was acquired \(\varDelta t\) years after \(x^u\). For simplicity, \(\times \) refers to u or v when a function is applied separately to both time points. The MRIs are mapped to the latent space by an encoder F, i.e., \(z^\times :=F(x^\times )\). In the latent space, the trajectory of the pair is denoted as \(\varDelta z := (z^v - z^u) / \varDelta t\), which represents the morphological change per year. Finally, a decoder H reconstructs the input MRI \(x^\times \) from the latent representation \(z^\times \), i.e., \(\tilde{x}^\times :=H(z^\times )\). Next, we describe LSOR, which generates interpretable SOM representations, and the post-hoc analysis for deriving similarity grids.

2.1 LSOR

Following [3, 8], SOM representations are organized in an \(N_r\) by \(N_c\) grid (denoted as SOM grid) \(\mathcal {G}=\{g_{i,j}\}_{i=1,j=1}^{N_r,N_c}\), where \(g_{i,j}\) denotes the SOM representation in the i-th row and j-th column. This easy-to-visualize grid preserves the high-dimensional relationships between the clusters, as shown by the orange lines in Fig. 1. Given the latent representation \(z^\times \), its closest SOM representation is denoted as \(g_{\epsilon ^\times }\), where \(\epsilon ^\times := argmin_{(i,j)} \parallel z^\times - g_{i,j} \parallel _2\) is its 2D grid index in \(\mathcal {G}\) and \(\parallel \cdot \parallel _2\) is the Euclidean norm. This SOM representation is also used by the decoder to reconstruct the input MRI, i.e., \(\tilde{x}^\times _g=H(g_{\epsilon ^\times })\). Accordingly, the reconstruction loss encourages both the latent representation \(z^\times \) and its closest SOM representation \(g_{\epsilon ^\times }\) to be descriptive of the input MRI \(x^\times \), i.e.,

$$\begin{aligned} L_{recon} := \mathbb {E}_{(x^u, x^v) \sim \mathcal {S}} \left( \sum _{\times \in \{u,v\} } \left( \parallel x^\times - \tilde{x}^\times \parallel _2^2 + \parallel x^\times - \tilde{x}^\times _g \parallel _2^2 \right) \right) , \end{aligned} \tag{1}$$

where \(\mathbb {E}\) denotes the expected value. The remainder of this section describes the three novel components of our SOM representation.
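The two basic operations defined so far, the latent trajectory \(\varDelta z\) of a pair and the nearest-SOM lookup \(\epsilon ^\times \), can be illustrated with a minimal pure-Python sketch. It assumes latent codes are plain lists of floats and the SOM grid is a nested list of rows; the function names `trajectory` and `nearest_som_index` are hypothetical, and grid indices are 0-based rather than the 1-based indices used in the text.

```python
import math


def trajectory(z_u, z_v, delta_t):
    """Delta z := (z^v - z^u) / Delta t, the per-year change between two latent codes."""
    return [(b - a) / delta_t for a, b in zip(z_u, z_v)]


def nearest_som_index(z, grid):
    """epsilon := argmin_(i,j) ||z - g_ij||_2 over an N_r x N_c grid (0-based indices)."""
    best_idx, best_dist = None, math.inf
    for i, row in enumerate(grid):
        for j, g in enumerate(row):
            d = math.dist(z, g)  # Euclidean distance (Python >= 3.8)
            if d < best_dist:
                best_idx, best_dist = (i, j), d
    return best_idx


# Toy 2 x 2 grid of 2D "SOM representations" (illustrative values only).
grid = [[[0.0, 0.0], [1.0, 0.0]],
        [[0.0, 1.0], [1.0, 1.0]]]
z_u, z_v = [0.9, 0.2], [0.8, 1.2]
print(nearest_som_index(z_u, grid))  # (0, 1)
print(trajectory(z_u, z_v, delta_t=2.0))
```

In a real model these loops would be vectorized over the batch and the grid, but the argmin semantics are the same.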

Explicitly Regularizing Closeness. Though \(L_{recon}\) implicitly encourages close proximity between \(z^\times \) and \(g_{\epsilon ^\times }\), it does not directly optimize their distance, as the nearest-neighbor assignment \(\epsilon ^\times \) (an argmin) is not differentiable. Therefore, we introduce an additional ‘commitment’ loss that explicitly promotes closeness between them:

$$\begin{aligned} L_{commit} := \mathbb {E}_{(x^u, x^v) \sim \mathcal {S}} \left( \parallel z^u - g_{\epsilon ^u} \parallel _2^2 + \parallel z^v - g_{\epsilon ^v} \parallel _2^2 \right) . \end{aligned}$$

Soft Weighting Scheme. In addition to updating \(z^\times \)’s closest SOM representation \(g_{\epsilon ^\times }\), we also update all other SOM representations \(g_{i,j}\) by introducing a soft weighting scheme as proposed in [10]. Specifically, we design a weight \(w^\times _{i,j}\) that controls how much \(g_{i,j}\) should be updated with respect to \(z^\times \) based on its proximity to the grid location \(\epsilon ^\times \) of \(g_{\epsilon ^\times }\), i.e.,

$$\begin{aligned} w_{i,j}^\times := \delta \left( e^{-\frac{\parallel \epsilon ^\times - (i,j)\parallel _1^2}{2\tau }}\right) , \end{aligned} \tag{2}$$

where \(\delta (w):=\frac{w}{\sum _{i,j} w_{i,j}}\) normalizes the weights to sum to one so that their scale stays constant during training, and \(\tau >0\) is a scaling hyperparameter. We then design the following loss \(L_{som}\) so that SOM representations close to \(\epsilon ^\times \) on the grid are also close to \(z^\times \) in the latent space (measured by the Euclidean distance \(\parallel z^\times - g_{i,j} \parallel _2\)):

$$\begin{aligned} L_{som} := \mathbb {E}_{(x^u, x^v) \sim \mathcal {S}} \left( \sum _{g_{i, j} \in \mathcal {G}}\left( w^u_{i,j} \cdot \parallel z^u - g_{i,j} \parallel _2^2 \,+\, w^v_{i,j} \cdot \parallel z^v - g_{i,j} \parallel _2^2 \right) \right) . \end{aligned} \tag{3}$$
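A minimal sketch of the soft weights in Eq. 2 and of one time point's summand of Eq. 3, again in pure Python with hypothetical function names (`soft_weights`, `som_term`). It follows Eq. 2 literally, i.e., it uses the squared L1 grid distance to \(\epsilon ^\times \), and it omits the expectation over pairs, which in practice would be a mean over the batch.

```python
import math


def soft_weights(eps, n_rows, n_cols, tau):
    """Eq. 2: kernel on the squared L1 grid distance to eps, normalized by delta(.) to sum to 1."""
    raw = [[math.exp(-((abs(eps[0] - i) + abs(eps[1] - j)) ** 2) / (2 * tau))
            for j in range(n_cols)] for i in range(n_rows)]
    total = sum(sum(row) for row in raw)
    return [[w / total for w in row] for row in raw]


def som_term(z, grid, eps, tau):
    """One time point's summand of Eq. 3: sum_ij w_ij * ||z - g_ij||_2^2."""
    w = soft_weights(eps, len(grid), len(grid[0]), tau)
    return sum(w[i][j] * math.dist(z, g) ** 2
               for i, row in enumerate(grid) for j, g in enumerate(row))


# 4 x 8 grid as in the implementation details; the weight peaks at eps.
w = soft_weights((1, 2), n_rows=4, n_cols=8, tau=1.0)
```

For one pair, the contribution to \(L_{som}\) would be `som_term(z_u, grid, eps_u, tau) + som_term(z_v, grid, eps_v, tau)`; a smaller \(\tau \) concentrates the weights on grid neighbors of \(\epsilon ^\times \).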

To improve robustness, we make two further changes to Eq. 3. First, we account for the SOM representations transitioning from random initialization to meaningful cluster centers that preserve the high-dimensional relationships within the 2D SOM grid. We do so by decreasing \(\tau \) in Eq. 2 with each iteration, so that the weights gradually concentrate on SOM representations closer to \(g_{\epsilon ^\times }\) as training proceeds: \(\tau (t) := N_r \cdot N_c \cdot \tau _{max} \left( \frac{\tau _{min}}{\tau _{max}} \right) ^{t/T} \), where \(\tau _{min}\) and \(\tau _{max}\) are the minimum and maximum standard deviation of the Gaussian kernel, t is the current iteration, and T is the maximum number of iterations.
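The annealing schedule for \(\tau \) can be written directly from the formula above. The sketch plugs in the values reported in the implementation details (\(N_r=4\), \(N_c=8\), \(\tau _{min}=0.1\), \(\tau _{max}=1.0\)); the total iteration count `T` and the name `tau_schedule` are hypothetical placeholders.

```python
def tau_schedule(t, T, n_rows, n_cols, tau_min, tau_max):
    """tau(t) := N_r * N_c * tau_max * (tau_min / tau_max) ** (t / T)."""
    return n_rows * n_cols * tau_max * (tau_min / tau_max) ** (t / T)


T = 1000  # hypothetical total number of iterations
for t in (0, T // 2, T):
    # prints "0 32.0", "500 10.119", "1000 3.2"
    print(t, round(tau_schedule(t, T, 4, 8, 0.1, 1.0), 3))
```

With these values \(\tau \) decays geometrically from 32 to 3.2, so the denominator \(2\tau \) in Eq. 2 shrinks by a factor of 10 over training.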

The second change to Eq. 3 is to apply the stop-gradient operator \(sg[\cdot ]\) [16] to \(z^\times \), which sets the gradients of \(z^\times \) to 0 during the backward pass. The stop-gradient operator prevents the undesirable scenario in which \(z^\times \) is pulled towards a naive solution, i.e., different MRI samples are mapped to the same weighted average of all SOM representations. The risk of converging to this naive solution is especially high in the early stages of training, when the SOM representations are randomly initialized and may not accurately represent the clusters.
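To make the effect of \(sg[\cdot ]\) concrete, the sketch below performs one hand-derived gradient-descent step on the single-sample SOM term \(\sum _{i,j} w_{i,j} \parallel sg[z] - g_{i,j} \parallel _2^2\). It is an illustration under simplifying assumptions (plain SGD, fixed weights, hypothetical name `som_sgd_step`), not the authors' training code; in autodiff frameworks the operator is typically realized with `torch.Tensor.detach()` or `jax.lax.stop_gradient`.

```python
def som_sgd_step(z, grid, weights, lr):
    """One plain gradient-descent step on sum_ij w_ij * ||sg[z] - g_ij||_2^2.

    dL/dg_ij = -2 * w_ij * (z - g_ij), so each g_ij moves towards z in
    proportion to its weight. Because of sg[.], z receives no gradient and
    is returned unchanged.
    """
    new_grid = [[[g_k + lr * 2.0 * weights[i][j] * (z_k - g_k)
                  for g_k, z_k in zip(g, z)]
                 for j, g in enumerate(row)]
                for i, row in enumerate(grid)]
    return list(z), new_grid


z = [1.0, 0.0]
grid = [[[0.0, 0.0], [2.0, 0.0]]]
weights = [[0.5, 0.5]]
z_new, grid_new = som_sgd_step(z, grid, weights, lr=0.5)
print(z_new)     # [1.0, 0.0]  (unchanged)
print(grid_new)  # [[[0.5, 0.0], [1.5, 0.0]]]
```

Without the stop-gradient, the same loss would also push `z` towards the weighted average of the grid, which is the collapse described above.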

Longitudinal Consistency Regularization. We derive a SOM grid related to brain aging by generating an age-stratified latent space. Specifically, the latent space is defined by a smooth trajectory field (Fig. 1, blue box) characterizing the morphological changes associated with brain aging. The smoothness is based on the assumption that MRIs with similar appearances (i.e., close latent representations) should have similar trajectories. It is enforced by modeling the similarity between each subject-specific trajectory \(\varDelta z\) and a reference trajectory that represents the average trajectory of its cluster. Specifically, let \(\varDelta g_{i,j}\) be the reference trajectory (Fig. 1, green arrow) associated with \(g_{i,j}\); then the reference trajectories of all clusters \(\mathcal {G}_{\varDelta }=\{ \varDelta g_{i,j} \}_{i=1,j=1}^{N_r,N_c}\) represent the average aging of the SOM clusters with respect to the training set. As all subject-specific trajectories change during training, it is computationally infeasible to keep track of \(\mathcal {G}_{\varDelta }\) on the whole training set. We instead propose to compute an exponential moving average (EMA) (Fig. 1, black box), which iteratively aggregates the batch-wise average trajectories into \(\mathcal {G}_{\varDelta }\):

$$\begin{aligned} \varDelta g_{i,j}&\leftarrow {\left\{ \begin{array}{ll} \varDelta h_{i,j} &{} t=0 \\ \varDelta g_{i,j} &{} t>0 \hbox { and } |\varOmega _{i,j}| = 0 \\ \alpha \cdot \varDelta g_{i,j} + (1-\alpha ) \cdot \varDelta h_{i,j} &{} t>0 \hbox { and } |\varOmega _{i,j}| > 0\\ \end{array}\right. }\\ \hbox {with } \varDelta h_{i,j}&:= \frac{1}{|\varOmega _{i,j}|} \sum _{k=1}^{N_{bs}} \mathbbm {1}[\epsilon ^u_k=(i,j)] \cdot \varDelta z_k \hbox { and } |\varOmega _{i,j}|:=\sum _{k=1}^{N_{bs}} \mathbbm {1}[\epsilon ^u_k=(i,j)]. \end{aligned}$$

Here, \(\alpha \) is the EMA keep rate, k denotes the index of the sample pair, \(N_{bs}\) is the batch size, \(\mathbbm {1}[\cdot ]\) is the indicator function, and \(|\varOmega _{i,j}|\) denotes the number of sample pairs with \(\epsilon ^u=(i,j)\) within a batch. In each iteration, \(\varDelta h_{i,j}\) (Fig. 1, purple arrow) thus represents the batch-wise average of the subject-specific trajectories of sample pairs with \(\epsilon ^u = (i,j)\). By being updated iteratively, \(\mathcal {G}_{\varDelta }\) approximates the average trajectories derived from the entire training set. Lastly, inspired by [11, 12], the longitudinal consistency regularization is formulated as

$$\begin{aligned} L_{dir} := \mathbb {E}_{(x^u, x^v) \sim \mathcal {S}} \left( 1 - cos(\theta [\varDelta z, sg[\varDelta g_{\epsilon ^u}]])\right) , \end{aligned}$$

where \(\theta [\cdot , \cdot ]\) denotes the angle between two vectors. Since \(\varDelta g\) is updated by EMA rather than by backpropagation, the stop-gradient operator is again applied so that the gradient of \(L_{dir}\) is computed only with respect to \(\varDelta z\).
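The EMA update of the reference trajectories and the direction loss \(L_{dir}\) can be sketched as follows. Assumptions beyond the equations above: reference trajectories are stored in a dictionary keyed by 0-based grid index, a cluster that receives no sample at \(t=0\) is initialized the first time it appears (the equation leaves this case open), and the function names are hypothetical. Because `direction_loss` treats `dg` as a constant, it mirrors \(sg[\varDelta g_{\epsilon ^u}]\); it also assumes both vectors are non-zero.

```python
import math


def ema_update(ref, eps_u_batch, dz_batch, alpha, first_iter):
    """Update reference trajectories ref[(i, j)] = Delta g_ij from one batch."""
    sums, counts = {}, {}
    for idx, dz in zip(eps_u_batch, dz_batch):
        acc = sums.setdefault(idx, [0.0] * len(dz))
        for k, v in enumerate(dz):
            acc[k] += v
        counts[idx] = counts.get(idx, 0) + 1
    for idx, acc in sums.items():
        dh = [v / counts[idx] for v in acc]  # batch-wise average Delta h_ij
        if first_iter or idx not in ref:
            ref[idx] = dh  # t = 0 case (or first appearance of a cluster: assumption)
        else:
            ref[idx] = [alpha * g + (1.0 - alpha) * h for g, h in zip(ref[idx], dh)]
    # Clusters with |Omega_ij| = 0 in this batch are not visited and keep their value.
    return ref


def direction_loss(dz, dg):
    """1 - cos(theta[dz, dg]); dg is treated as a constant, mirroring sg[Delta g]."""
    dot = sum(a * b for a, b in zip(dz, dg))
    return 1.0 - dot / (math.hypot(*dz) * math.hypot(*dg))


ref = {}
ema_update(ref, [(0, 0), (0, 0), (1, 2)], [[1.0, 0.0], [3.0, 0.0], [0.0, 1.0]],
           alpha=0.99, first_iter=True)
print(ref)  # {(0, 0): [2.0, 0.0], (1, 2): [0.0, 1.0]}
print(direction_loss([1.0, 0.0], ref[(0, 0)]))  # 0.0 (perfectly aligned)
```

With \(\alpha =0.99\), each batch moves a reference trajectory only 1% of the way towards the current batch average, which smooths out the noise of small per-cluster batch counts.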

Objective Function. The complete objective function is the weighted combination of the preceding losses with weighting parameters \(\lambda _{commit}\), \(\lambda _{som}\), and \(\lambda _{dir}\):

$$\begin{aligned} L := L_{recon} + \lambda _{commit} \cdot L_{commit} + \lambda _{som} \cdot L_{som} + \lambda _{dir} \cdot L_{dir} \end{aligned}$$

The objective function encourages a smooth trajectory field of aging in the latent space while maintaining interpretable SOM representations for analyzing brain age in a purely self-supervised fashion.

2.2 SOM Similarity Grid

During inference, a (2D) similarity grid \(\rho \) is computed from the closeness between the latent representation z of an MRI sample and the SOM representations:

$$\begin{aligned} \rho := softmax(-\parallel z - \mathcal {G} \parallel _2^2 / \gamma ) \hbox { with } \gamma := std(\parallel z - \mathcal {G} \parallel _2^2) \end{aligned}$$

Here, std denotes the standard deviation of the squared distances between z and all SOM representations. As the SOM grid is learned to be associated with brain age (e.g., representing aging from left to right), the similarity grid essentially encodes a “likelihood function” of the brain age in z. Given all MRIs of a longitudinal sequence, the change across the corresponding similarity grids over time represents the brain aging process of that individual. Furthermore, brain aging at the group level is captured by first computing the average similarity grid for each age group and then visualizing the differences between those average similarity grids across age groups.
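A sketch of the similarity grid \(\rho \) and of the group-level averaging described above. It assumes std is the population standard deviation (the text does not specify which) and that not all squared distances are equal, since otherwise \(\gamma = 0\); the function names `similarity_grid` and `average_grid` are hypothetical.

```python
import math
import statistics


def similarity_grid(z, grid):
    """rho := softmax(-||z - G||^2 / gamma), gamma := std of the squared distances."""
    d2 = [[math.dist(z, g) ** 2 for g in row] for row in grid]
    flat = [v for row in d2 for v in row]
    gamma = statistics.pstdev(flat)  # population std; assumes gamma > 0
    logits = [[-v / gamma for v in row] for row in d2]
    m = max(v for row in logits for v in row)  # subtract the max for numerical stability
    exps = [[math.exp(v - m) for v in row] for row in logits]
    total = sum(sum(row) for row in exps)
    return [[e / total for e in row] for row in exps]


def average_grid(grids):
    """Element-wise mean of several similarity grids (e.g., all MRIs of one age group)."""
    n = len(grids)
    return [[sum(g[i][j] for g in grids) / n for j in range(len(grids[0][0]))]
            for i in range(len(grids[0]))]


grid = [[[0.0, 0.0], [1.0, 0.0]],
        [[0.0, 2.0], [3.0, 0.0]]]
rho = similarity_grid([0.0, 0.0], grid)  # largest entry at the closest SOM representation
```

A group-level difference map is then the element-wise difference of two `average_grid` outputs, one per age group.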

3 Experiments

3.1 Experimental Setting

Dataset. We evaluated the proposed method on the longitudinal T1-weighted MRIs of all 632 subjects (at least two visits per subject, 2389 MRIs in total) from ADNI-1 [13]. The dataset consists of 185 NC (age: 75.57 ± 5.06 years), 193 subjects diagnosed with sMCI (age: 75.63 ± 6.62 years), 135 subjects diagnosed with pMCI (age: 75.91 ± 5.35 years), and 119 subjects with AD (age: 75.17 ± 7.57 years). There was no significant age difference between the NC and AD cohorts (p = 0.55, two-sample t-test), nor between the sMCI and pMCI cohorts (p = 0.75). All MRIs were preprocessed by a pipeline including denoising, bias field correction, skull stripping, affine registration to a template, rescaling to a 64 \(\times \) 64 \(\times \) 64 volume, and transforming image intensities to z-scores.

Implementation Details. Let C\(_k\) denote a Convolution(kernel size of \(3\times 3\times 3\), Conv\(_k\))-BatchNorm-LeakyReLU(slope of 0.2)-MaxPool(kernel size of 2) block with k filters, and CD\(_k\) a Convolution-BatchNorm-LeakyReLU-Upsample block. The architecture was designed as C\(_{16}\)-C\(_{32}\)-C\(_{64}\)-C\(_{16}\)-Conv\(_{16}\)-CD\(_{64}\)-CD\(_{32}\)-CD\(_{16}\)-CD\(_{16}\)-Conv\(_{1}\), which results in a latent space of 1024 dimensions. In practice, training the SOM from random initialization was difficult in this high-dimensional space, so we first pre-trained the model with only \(L_{recon}\) for 10 epochs and initialized the SOM representations by applying k-means to the latent representations of all training samples under this pre-trained model. Then, the network was further trained for 40 epochs with regularization weights set to \(\lambda _{recon}=1.0\), \(\lambda _{commit}=0.5\), \(\lambda _{som}=1.0\), \(\lambda _{dir}=0.2\). The Adam optimizer was used with a learning rate of \(5 \times 10^{-4}\) and a weight decay of \(10^{-5}\). \(\tau _{min}\) and \(\tau _{max}\) in \(L_{som}\) were set to 0.1 and 1.0, respectively. An EMA keep rate of \(\alpha =0.99\) was used to update the reference trajectories. A batch size of \(N_{bs}=64\) and a SOM grid size of \(N_r=4, N_c=8\) were used.
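The stated latent size of 1024 follows from simple arithmetic, sketched below under the assumption (not stated in the text) that each \(3\times 3\times 3\) convolution is padded to preserve spatial size, so that only the MaxPool in each C\(_k\) block changes it. The function name `latent_dim` is hypothetical.

```python
def latent_dim(input_side=64, n_pool_blocks=4, last_filters=16):
    """Flattened size of the encoder output C16-C32-C64-C16 for a cubic input.

    Assumes padded 3x3x3 convolutions, so each C_k block only halves every
    spatial axis through its MaxPool(2).
    """
    side = input_side // 2 ** n_pool_blocks  # 64 -> 32 -> 16 -> 8 -> 4
    return last_filters * side ** 3


print(latent_dim())  # 1024 = 16 channels * 4 * 4 * 4
```

The four CD\(_k\) blocks in the decoder then upsample 4 back to 64 along each axis.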

Fig. 2.

The color at each SOM representation encodes the average value of (a) chronological age, (b) % of AD and pMCI, and (c) ADAS-Cog score across the training samples of that cluster; (d) for the last row of the grid, the average MRI of the 20 samples whose latent representations are closest to the corresponding SOM representation. (Color figure online)

Evaluation. We performed five-fold cross-validation (folds split by subject), using 10% of the training subjects for validation. The training data was augmented by flipping brain hemispheres and by random rotation and translation. To quantify the interpretability of the SOM grid, we correlated the coordinates of the SOM grid with quantitative measures related to brain age, i.e., chronological age, the percentage of subjects with severe cognitive decline, and the Alzheimer’s Disease Assessment Scale-Cognitive Subscale (ADAS-Cog). We illustrated the interpretability with respect to brain aging by visualizing the changes in the SOM similarity grids over time. We further visualized the trajectory vector field along with the SOM representations by projecting the 1024-dimensional representations onto the first two principal components of the SOM representations. Lastly, we quantitatively evaluated the quality of the representations by applying them to the downstream tasks of classifying sMCI vs. pMCI and predicting ADAS-Cog. We measured classification accuracy via balanced accuracy (BACC) and the Area Under the ROC Curve (AUC), and prediction accuracy via R\(^2\) and root-mean-square error (RMSE). The classifier and predictor were multi-layer perceptrons containing two fully connected layers of dimensions 1024 and 64 with a LeakyReLU activation. We compared the accuracy metrics to those of models with the same architecture whose encoders were pre-trained by other representation learning methods, including unsupervised methods (AE, VAE [4]), a self-supervised method (SimCLR [1]), a longitudinal self-supervised method (LSSL [17]), and longitudinal neighborhood embedding (LNE [12]). All compared methods used the same experimental setup (e.g., encoder-decoder, learning rate, batch size, epochs, etc.), and the method-specific hyperparameters followed [12].

Fig. 3.

The average similarity grid \(\rho \) over subjects of a specific age and diagnosis (NC vs. AD). Each grid encodes the likelihood of the average brain age of the corresponding sub-cohort. Cog denotes the average ADAS-Cog score.

3.2 Results

Interpretability of SOM Embeddings. Fig. 2 shows the stratification of brain age over the SOM grid \(\mathcal {G}\). For each grid entry, we show the average value of chronological age (Fig. 2(a)), % of AD & pMCI (Fig. 2(b)), and ADAS-Cog score (Fig. 2(c)) over the samples of that cluster. We observed a trend towards older brain age (yellow) from the upper left to the lower right, corresponding to older chronological age and worse cognitive status. The SOM grid index strongly correlated with these three factors (distance correlation of 0.92, 0.94, and 0.91, respectively). Figure 2(d) shows the average brain over the 20 input images whose representations are closest to each SOM representation in the last row of the grid (see Supplement Fig. S1 for all rows). From left to right, the ventricles enlarge and the brain atrophies, which are hallmarks of brain aging.
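The text reports distance correlations but does not state the estimator or how grid coordinates were encoded. The sketch below uses the standard (biased, V-statistic) sample distance correlation of Székely et al. and treats each sample's (i, j) grid coordinate as a 2D vector; both choices, the function names, and the toy data are assumptions for illustration only.

```python
import math


def _double_centered(points):
    """Double-centered pairwise Euclidean distance matrix A_ij."""
    n = len(points)
    d = [[math.dist(p, q) for q in points] for p in points]
    row_mean = [sum(r) / n for r in d]  # d is symmetric, so column means equal row means
    grand = sum(row_mean) / n
    return [[d[i][j] - row_mean[i] - row_mean[j] + grand for j in range(n)]
            for i in range(n)]


def distance_correlation(x, y):
    """Sample distance correlation (V-statistic form) between paired samples.

    x and y are equal-length lists whose items are sequences (vectors).
    """
    n = len(x)
    a, b = _double_centered(x), _double_centered(y)
    dcov2 = sum(a[i][j] * b[i][j] for i in range(n) for j in range(n)) / n ** 2
    dvar_x = sum(v * v for row in a for v in row) / n ** 2
    dvar_y = sum(v * v for row in b for v in row) / n ** 2
    if dvar_x * dvar_y == 0.0:
        return 0.0  # a constant variable has zero distance variance
    return math.sqrt(max(dcov2, 0.0) / math.sqrt(dvar_x * dvar_y))


# Hypothetical usage: grid coordinates of each sample's closest SOM representation
# vs. a scalar factor such as chronological age (toy values, not study data).
coords = [(0, 0), (0, 3), (1, 5), (3, 7)]
ages = [[68.0], [72.5], [77.0], [83.0]]
print(round(distance_correlation(coords, ages), 3))
```

Unlike Pearson correlation, distance correlation accepts a multivariate coordinate such as (i, j) directly and is zero only under independence.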

Interpretability of Similarity Grid. When visualizing the average similarity grid \(\rho \) of the NC and AD cohorts at each age range (Fig. 3), we observed that the region of higher similarity (yellow) gradually shifts to the right with age in both NC and AD (see Supplement Fig. S2 for the sMCI and pMCI cohorts). However, the shift is faster for AD, which aligns with AD literature reporting that AD is linked to accelerated brain aging [15]. Furthermore, the subject-level aging effects shown in Supplement Fig. S3 reveal that the proposed visualization can capture subtle morphological changes caused by brain aging.

Interpretability of Trajectory Vector Field. Fig. 4 plots the 2D PCA projection of the latent space, which shows a smooth trajectory field (gray arrows) and reference trajectories \(\mathcal {G}_{\varDelta }\) (blue arrows) representing brain aging. This projection also preserves the 2D grid structure (orange) of the SOM representations, suggesting that aging is the dominant source of variation in the latent space.

Downstream Tasks. To evaluate the quality of the learned representations, we froze the encoders trained by each method (without fine-tuning) and used their representations for the downstream tasks (Table 1). On sMCI vs. pMCI classification (Table 1 (left)), the proposed method achieved a BACC of 69.8 and an AUC of 72.4, an accuracy comparable (\(p > 0.05\), DeLong’s test) to that of LSSL [17] and LNE [12], two state-of-the-art self-supervised methods for this task. On ADAS-Cog score regression, the proposed method obtained the best accuracy, with an R\(^2\) of 0.32 and an RMSE of 6.31. Note that accurately predicting the ADAS-Cog score is very challenging due to its large range (between 0 and 70) and its subjectiveness, which results in large variability across exams [2]; even larger RMSEs have been reported for this task [7]. Furthermore, our representations were learned without supervision and the encoder was kept frozen, so fine-tuning the encoder may further improve prediction accuracy.

Fig. 4.

2D PCA of LSOR’s latent space. Light gray arrows represent \(\varDelta z\). The orange grid represents the relationships between SOM representations, and blue arrows represent the associated reference trajectories \(\varDelta \mathcal {G}\). (Color figure online)

Table 1. Supervised downstream tasks using the learned representations z (without fine-tuning the encoder). LSOR achieved comparable or higher accuracy scores than other state-of-the-art self-supervised and unsupervised methods.

4 Conclusion

In this work, we proposed LSOR, the first SOM-based learning framework for longitudinal MRIs that is self-supervised and interpretable. By incorporating a soft SOM regularization, the training of the SOM was stable in the high-dimensional latent space of MRIs. By regularizing the latent space based on longitudinal consistency as defined by longitudinal MRIs, the latent space formed a smooth trajectory field capturing brain aging, as shown by the resulting SOM grid. The interpretability of the representations was confirmed by the correlation between the SOM grid and cognitive measures, and by the SOM similarity grid. When evaluated on the downstream tasks of sMCI vs. pMCI classification and ADAS-Cog prediction, LSOR was comparable to or better than representations learned by other state-of-the-art self-supervised and unsupervised methods. In conclusion, LSOR generates, purely from MRIs, a latent space that is highly interpretable with respect to brain age, as well as valuable representations for downstream tasks.