
1 Introduction

Estimating progression models from the analysis of time-dependent data is a challenging task that helps to uncover latent dynamics. For the study of neurodegenerative diseases, longitudinal databases have been assembled in which a set of biomarkers (medical images, cognitive scores and covariates) is gathered for individuals across time. Understanding the temporal evolution of these biomarkers is of crucial importance for early diagnosis and the design of drug trials, especially for imaging biomarkers, which can reveal a silent prodromal phase.

In this context, several approaches have been proposed to model the progression of scalar measurements, such as clinical scores or volumes of brain structures [12, 18, 29], or of series of measurements across brain regions forming a network [4, 21]. These approaches require the prior segmentation of the images and the extraction of the measurements from them. Providing progression models for high-dimensional structured data without prior processing is still a challenging task. The difficulty is to provide a low-dimensional representation of the data in which each patient's trajectory admits a continuous parametrization over time. Such a representation should allow sampling at any time point, be resilient to irregularly spaced instances, and disentangle temporal alterations from inter-patient variability.

1.1 Related Work

When dealing with high-dimensional data, it is often assumed that the data can be encoded into a low-dimensional manifold where the distribution of the data is simple. Deep generative neural networks such as Variational Auto Encoders (VAEs) [19] allow finding such embeddings. Several approaches have explored longitudinal modeling of images within the context of dimensionality reduction.

Recurrent Neural Networks (RNNs) provide a straightforward way to extract information from sequential data. Convolutional networks with a recurrent structure have been used for diagnosis prediction from MRI [11] or PET [23] scans in Alzheimer's Disease (AD). The main caveat of these approaches is that the recurrent structure is highly sensitive to the temporal spacing between instances, which is troublesome in the context of disease modeling, where visits are often missing or irregularly spaced.

Mixed-effects models provide an explicit description of the progression of each patient, allowing sampling at any timepoint. Through a time reparametrization, all patients are aligned on a common pathological timeline, and individual trajectories are parametrized as small variations (random effects) around a reference trajectory (fixed effects) that can be seen as the average scenario. Although they are now a standard tool in longitudinal modeling [21, 28, 29], mixed-effects models have rarely been used for images within the context of dimension reduction. In [24], an RNN outputs the parameters of a mixed-effects model that describes patients' trajectories across time as straight lines in the latent space of a VAE.

Self-supervised methods have been proposed to alleviate the need for labels, in our case the age of the patients at each visit. In [9], the encoder of a VAE learns a latent time variable and a latent spatial variable to disentangle the temporal progression from the patient's intrinsic characteristics. Similarly, in [30], the encoder is penalized with a cosine loss that imposes a direction in the latent space corresponding to an equivalent of time. Both methods learn a temporal progression that does not rely on the clinical age of the patients, which offers potential for unlabeled data, at the cost of the interpretability of the abstract timeline and of the ability to sample at any given timepoint.

Longitudinal VAE architectures have been proposed to endow the latent space with a temporal structure. Namely, Gaussian Process VAEs (GPVAEs) [7] replace the standard prior on the latent space with a more general one, in the form of a Gaussian Process (GP) that depends on the age of the patients [13] as well as on a series of covariates [2, 26]. This approach raises the question of how to parametrize the GP, and does not provide an expected trajectory for each patient.

Diffeomorphic methods provide progression models for images. The main approaches are based on the geodesic regression framework [3, 25] and learn a deformation map that models the effect of time on the images of a given subject. While they provide high-resolution predictions, these methods have limited predictive abilities further in time when compared to mixed-effects models, which aggregate information from all the subjects at different stages of the disease [6].

1.2 Contributions

In this context, we propose to endow the latent space of a VAE with a linear mixed-effects longitudinal model. Whereas in [24] the networks predict the random effects from visits grouped by patient, we propose to enrich a regular VAE, which maps each individual visit to a latent representation, with an additional longitudinal latent model that describes the progression of these representations over time. We propose a novel Markov Chain Monte Carlo (MCMC) procedure to jointly estimate the VAE and the structure of its representation manifold. To sum up the contributions, we:

  1. use the entire 3D scan without segmentation or parcellation to study relations across brain regions in an unsupervised manner,
  2. proceed to dimension reduction using a convolutional VAE with the added constraint that latent representations must comply with the structure of a generative statistical model of the trajectories,
  3. provide a progression model that disentangles temporal changes from changes due to inter-patient variability, and allows sampling patients' trajectories at any timepoint, to infer missing data or predict future progression,
  4. demonstrate this method on a synthetic data set and on both MRI and PET scans from the Alzheimer's Disease Neuroimaging Initiative (ADNI), recovering known patterns of normal and pathological brain aging.

2 Methodology

2.1 Representation Learning with VAEs

Auto Encoders are a standard tool for non-linear dimensionality reduction. They consist of an encoder network \(q_\phi \) that maps high-dimensional data \(x\in \mathcal {X}\) to \(z\in \mathcal {Z}\), a smaller space referred to as the latent space, and a decoder network \(p_\theta : z\in \mathcal {Z}\mapsto \hat{x}\in \mathcal {X}\). VAEs [19] offer a more regularized way to populate the latent space. Both encoder and decoder networks output variational distributions \(q_\phi (z|x)\) and \(p_\theta (x|z)\), chosen to be multivariate Gaussian distributions. Adding a prior \(p(z)\) on \(\mathcal {Z}\), usually the unit Gaussian distribution \(\mathcal {N}(0,\textrm{I})\), allows deriving a tractable Evidence Lower BOund (ELBO) on the log-likelihood. For a Gaussian decoder with fixed variance, the negative ELBO reduces, up to constants and scaling, to the loss \(\mathcal {L}_{recon} + \beta \mathcal {L}_{KL}\), where \(\mathcal {L}_{recon}\) is the \(\ell _2\) reconstruction error, \(\mathcal {L}_{KL}\) is the Kullback-Leibler (KL) divergence between the approximate posterior and the prior on the latent space, and \(\beta \) balances reconstruction error and latent space regularity [16].
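As a minimal sketch of this loss, assuming the encoder outputs a mean `mu` and a log-variance `logvar` per latent coordinate (our naming, not taken from the paper's code), the KL term to \(\mathcal {N}(0,\textrm{I})\) has the standard closed form \(\tfrac{1}{2}\sum_k (\sigma_k^2 + \mu_k^2 - 1 - \log\sigma_k^2)\). Plain Python lists stand in for tensors; this illustrates the formula and is not the repository's implementation.

```python
import math


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) )."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, logvar))


def beta_vae_loss(x, x_hat, mu, logvar, beta):
    """L_recon + beta * L_KL for one flattened image and its decoded mode."""
    recon = sum((a - b) ** 2 for a, b in zip(x, x_hat))  # squared l2 error
    return recon + beta * kl_to_standard_normal(mu, logvar)


# A perfect reconstruction whose posterior equals the prior costs nothing.
print(beta_vae_loss([0.2, 0.8], [0.2, 0.8], [0.0, 0.0], [0.0, 0.0], beta=2.0))  # → 0.0
```

Larger values of `beta` push the posterior means toward 0 and the variances toward 1, at the expense of reconstruction.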

2.2 Longitudinal Statistical Model

In this section, we propose a temporal latent variables model that encodes disease progression in the low-dimensional space \(\mathcal {Z}\). Given a family of observations from N patients \(\{ x_{i,j} \}_{1\le i \le N}\), each observed at ages \(t_{i,j}\) for \(1\le j \le n_i\) visits, and their latent representations \(\{ z_{i,j} \}\), we define a statistical generative model with

$$z _{i,j} = p_0 + \left[ e^{\xi _i}(t_{i,j} - \tau _i) \right] v_0 + w_i + \varepsilon _{i,j} $$

where \(e^{\xi _i}\) and \(\tau _i\), respectively the acceleration factor and the onset age of patient i, define an affine time warp that aligns all patients on a common pathological timeline, and \(w_i\in \mathcal {Z}\) is the space shift that encodes inter-subject variability, such as morphological variations across regions that are independent of the progression. These three parameters position the individual trajectory with respect to the typical progression estimated at the population level, and form the random effects of the model \(\psi _r\). Vectors \(w_i\) and \(v_0\) need to be orthogonal in order to uniquely identify temporal and spatial variability.
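The text states that \(w_i\) and \(v_0\) must be orthogonal but not how this is enforced. One simple option, assumed here for illustration only, is to project \(w_i\) onto the orthogonal complement of \(v_0\) before evaluating the trajectory. The sketch below (plain Python, hypothetical function names) computes \(\eta_i(t) = p_0 + e^{\xi_i}(t-\tau_i)\,v_0 + w_i\) this way:

```python
import math


def project_orthogonal(w, v0):
    """Remove from w its component along v0, so that <w, v0> = 0."""
    norm2 = sum(v * v for v in v0)
    coef = sum(a * b for a, b in zip(w, v0)) / norm2
    return [a - coef * b for a, b in zip(w, v0)]


def expected_position(t, p0, v0, xi, tau, w):
    """eta_i(t) = p0 + exp(xi) * (t - tau) * v0 + w, with w made orthogonal to v0."""
    w_perp = project_orthogonal(w, v0)
    scale = math.exp(xi) * (t - tau)  # reparametrized (pathological) time
    return [p + scale * v + s for p, v, s in zip(p0, v0, w_perp)]


p0, v0 = [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]
print(expected_position(75.0, p0, v0, xi=0.0, tau=73.0, w=[0.5, -1.0, 2.0]))  # → [2.0, -1.0, 2.0]
```

The component of `w` along `v0` (here 0.5) is discarded, since it would otherwise be indistinguishable from a shift of the onset age.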

We choose Gaussian priors for the noise \(\varepsilon _{i,j}\sim \mathcal {N}(0, \sigma _\varepsilon ^2)\) and the random effects \(\tau _i \sim \mathcal {N}(t_0, \sigma _\tau ^2)\), \(\xi _i \sim \mathcal {N}(0, \sigma _\xi ^2)\) and \(w_i \sim \mathcal {N}(0, \textrm{I})\). The parameters \(p_0\in \mathcal {Z}\), \(v_0\in \mathcal {Z}\) and \(t_0\in \mathbb {R}\) are respectively a reference position, velocity and time, and describe the average trajectory. Together with the variances \(\sigma _\varepsilon ^2, \sigma _\tau ^2, \sigma _\xi ^2\), they form the fixed effects of the model \(\psi _f\). We denote \(\psi = (\psi _r, \psi _f)\).
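To see how these priors generate data, the following sketch samples one patient's latent trajectory. It assumes \(p_0 = 0\) and \(v_0 = (1,0,\cdots,0)\), as fixed in Sect. 2.3, and, as a simplifying assumption of ours, draws the space shift only in the coordinates orthogonal to \(v_0\). The function name and the values in the usage line are hypothetical.

```python
import math
import random


def sample_patient(ages, t0, sigma_tau, sigma_xi, sigma_eps, dim, rng):
    """Draw one patient's latent points z_ij from the model, with p0 = 0 and v0 = e_1."""
    tau = rng.gauss(t0, sigma_tau)   # onset age
    xi = rng.gauss(0.0, sigma_xi)    # log-acceleration factor
    # Space shift: first coordinate set to 0 so that w is orthogonal to v0 = e_1.
    w = [0.0] + [rng.gauss(0.0, 1.0) for _ in range(dim - 1)]
    zs = []
    for t in ages:
        eta = [math.exp(xi) * (t - tau)] + w[1:]                 # expected position
        zs.append([e + rng.gauss(0.0, sigma_eps) for e in eta])  # add isotropic noise
    return tau, xi, w, zs


rng = random.Random(0)
tau, xi, w, zs = sample_patient([70.0, 71.5, 74.0], t0=72.0, sigma_tau=3.0,
                                sigma_xi=0.5, sigma_eps=0.1, dim=4, rng=rng)
print(len(zs), len(zs[0]))  # → 3 4
```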

2.3 Longitudinal VAE

We combine dimension reduction using a regular \(\beta \)-VAE and the aforementioned statistical model to add a temporal structure to the latent space. To do so, we consider a composite loss that accounts for both the VAE loss and the goodness-of-fit of the mixed-effect model:

$$\begin{aligned} \mathcal {L} = \mathcal {L}_{recon} + \beta \mathcal {L}_{KL} + \gamma \mathcal {L}_{align}\;\; \textrm{where} \;\, \left\{ \begin{array}{ll} \mathcal {L}_{recon} &{}= \mathop {\sum }\nolimits _{i,j} ||x_{i,j} - \hat{x}_{i,j}||^2 \\ \mathcal {L}_{KL} &{}= \mathop {\sum }\nolimits _{i,j} KL(q_\phi (z|x_{i,j})||\mathcal {N}(0,\textrm{I})) \\ \mathcal {L}_{align} &{}= \mathop {\sum }\nolimits _{i,j} ||z_{i,j} - \eta ^i_\psi (t_{i,j})||^2 \end{array}\right. \end{aligned}$$

where \(z_{i,j}\) and \(\hat{x}_{i,j}\) are the modes of \(q_\phi (z|x_{i,j})\) and \(p_\theta (x|z_{i,j})\) respectively, \(\eta ^i_\psi (t_{i,j}) = p_0 + \left[ e^{\xi _i}(t_{i,j} - \tau _i)\right] v_0 + w_i \) is the expected position of the latent representation according to the longitudinal model, and \(\gamma \) weights the penalty for latent representations that are not aligned with the linear model. Since the loss is invariant to rotations in \(\mathcal {Z}\), we set \(p_0=0\) and \(v_0=(1,0,\cdots ,0)\).
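With \(p_0 = 0\) and \(v_0 = (1,0,\cdots,0)\), the expected position reduces to the vector whose first coordinate is \(e^{\xi_i}(t-\tau_i)\) and whose remaining coordinates are those of \(w_i\), which makes \(\mathcal{L}_{align}\) cheap to evaluate. A minimal sketch, assuming nested per-patient lists (a hypothetical data layout) and a space shift whose first coordinate is zero:

```python
import math


def align_loss(latents, times, taus, xis, ws):
    """L_align = sum_{i,j} ||z_ij - eta_i(t_ij)||^2 with p0 = 0 and v0 = e_1.

    latents[i][j] is the encoder mode z_ij (a list), times[i][j] the age t_ij,
    and ws[i] the space shift of patient i (first coordinate 0, orthogonal to v0).
    """
    total = 0.0
    for zs_i, ts_i, tau, xi, w in zip(latents, times, taus, xis, ws):
        for z, t in zip(zs_i, ts_i):
            eta = [math.exp(xi) * (t - tau)] + list(w[1:])
            total += sum((a - b) ** 2 for a, b in zip(z, eta))
    return total


def longitudinal_vae_loss(recon, kl, align, beta, gamma):
    """Composite objective L = L_recon + beta * L_KL + gamma * L_align."""
    return recon + beta * kl + gamma * align
```

Latent points lying exactly on their patient's line contribute nothing, so \(\gamma\) only acts on the deviation from the mixed-effects model.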

Fig. 1. Images \(\{x_{i,j}\}\) are encoded into \(\mathcal {Z}\) such that the \(\{z_{i,j}\}\) are close to the estimated latent trajectories. Individual trajectories (straight lines) are parametrized with \(w_i\), \(\tau _i\) and \(e^{\xi _i}\) as variations around the reference trajectory (orange arrow).

Since \(\mathcal {L}_{align}\) is an \(\ell _2\) loss in the latent space, it can be seen, up to a factor and an additive constant, as the negative log-likelihood of a Gaussian prior \(z_{i,j} \sim \mathcal {N}(\eta _\psi ^i(t_{i,j}), \textrm{I})\) in the latent space. This prior defines an elementary Gaussian Process, which relates our approach to the use of GP priors in the latent space of VAEs to model longitudinal data [13, 26]. Besides, \(\mathcal {X}\) can be seen as a Riemannian manifold whose metric is given by the pushforward of the Euclidean metric of \(\mathcal {Z}\) through the decoder, such that trajectories in \(\mathcal {X}\) are geodesics, in accordance with the Riemannian modeling of longitudinal data [5, 14, 24, 28, 29]. The metric on \(\mathcal {X}\) thus allows recovering non-linear dynamics, which biomarkers often exhibit. Our approach therefore bridges the gap between deep learning approaches to longitudinal data and a natural generalization of well-studied disease progression models to images.
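To make the Gaussian reading of \(\mathcal {L}_{align}\) explicit, write \(d = \dim (\mathcal {Z})\). The negative log-density of the unit-covariance Gaussian at one visit is

$$ -\log \mathcal {N}\big (z_{i,j};\, \eta ^i_\psi (t_{i,j}),\, \textrm{I}\big ) = \tfrac{1}{2}\,\big \Vert z_{i,j} - \eta ^i_\psi (t_{i,j})\big \Vert ^2 + \tfrac{d}{2}\log (2\pi ), $$

so that summing over all patients and visits gives

$$ \sum _{i,j} -\log \mathcal {N}\big (z_{i,j};\, \eta ^i_\psi (t_{i,j}),\, \textrm{I}\big ) = \tfrac{1}{2}\,\mathcal {L}_{align} + \tfrac{d}{2}\log (2\pi ) \sum _{i} n_i . $$

Minimizing \(\gamma \mathcal {L}_{align}\) is therefore equivalent, up to the factor \(2\gamma \) and an additive constant that does not depend on \(\psi \) or on the networks, to maximizing this Gaussian likelihood.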

Network Implementation and Estimation. Both the encoder and the decoder are vanilla convolutional neural networks (4 blocks of strided convolution/BatchNorm/ReLU, with transposed convolutions in the decoder) with a dense layer towards the latent space, as described in Fig. 1. The implementation is in PyTorch and available at https://github.com/bsauty/longitudinal-VAEs.
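The kernel sizes and padding of the convolutions are not given here. As a sketch under an assumption of ours (kernel size 4, stride 2 and padding 1 in every encoder layer), the standard output-size formula \(\lfloor (n + 2p - k)/s \rfloor + 1\) shows how each volume shrinks before the dense layer:

```python
def conv_out(n, kernel=4, stride=2, padding=1):
    """Output length of a convolution along one axis (standard formula)."""
    return (n + 2 * padding - kernel) // stride + 1


def encoder_feature_shape(shape, n_layers=4, **conv_kwargs):
    """Spatial shape after n_layers strided convolutions (hypothetical hyperparameters)."""
    for _ in range(n_layers):
        shape = tuple(conv_out(n, **conv_kwargs) for n in shape)
    return shape


print(encoder_feature_shape((80, 96, 80)))  # → (5, 6, 5)
print(encoder_feature_shape((64, 64)))      # → (4, 4)
```

Under these assumptions, each layer halves every axis, so with C channels after the last block the dense layer would map 5·6·5·C = 150C features of an MRI or PET volume to the latent parameters; the channel count is not specified in the text.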

The difficulty lies in the joint estimation of \((\theta , \phi )\) and \(\psi \), which are co-dependent since \(\mathcal {L}_{align}\) depends on \(\eta _\psi ^i\) and \(\psi \) depends on the encoded representations \(z = q_\phi (x)\). The longitudinal statistical model belongs to a family of geometric models studied in [20, 29]. Given \(\{z_{i,j}\}\), we can perform a Maximum a Posteriori estimation of \(\psi \) with the MCMC-SAEM procedure, in which the expectation step of the EM algorithm is replaced by a simulation step (MCMC sampling of the random effects) followed by a stochastic approximation. See [1, 22] for details. Given the target trajectories \(\{\eta ^i_\psi \}\), the weights of both networks of the VAE are optimized by backpropagation of \(\mathcal {L}\) with a stochastic optimizer over randomized mini-batches. Both estimation schemes are iterative, so we designed a Monte Carlo estimator for \((\phi , \theta , \psi )\), presented in Algorithm 1, that alternates between them.
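The statistical half of this alternation can be illustrated on its own. The following toy sketch is not the authors' implementation: it runs MCMC-SAEM only on the coordinate along \(v_0 = e_1\), where \(w_i\) vanishes, using a random-walk Metropolis-Hastings move on \((\tau_i, \xi_i)\), a stochastic approximation of the sufficient statistics with decreasing step sizes, and closed-form maximization updates of the fixed effects. Initial values, proposal scales and iteration counts are hypothetical, and the VAE step of Algorithm 1 (a gradient update of \((\phi, \theta)\) on \(\mathcal{L}\) with \(\psi\) frozen) is left out.

```python
import math
import random


def log_post(tau, xi, times, zs, t0, var_tau, var_xi, var_eps):
    """Unnormalized log p(tau_i, xi_i | z_i; psi_f) on the v0 coordinate (w_i is 0 there)."""
    resid = sum((z - math.exp(xi) * (t - tau)) ** 2 for t, z in zip(times, zs))
    return (-0.5 * resid / var_eps - 0.5 * (tau - t0) ** 2 / var_tau
            - 0.5 * xi ** 2 / var_xi)


def mcmc_saem(data, n_iter=400, n_burn=150, seed=0):
    """Toy MCMC-SAEM for (t0, var_tau, var_xi, var_eps); data = [(times, zs), ...]."""
    rng = random.Random(seed)
    t0, var_tau, var_xi, var_eps = 70.0, 25.0, 1.0, 1.0  # hypothetical initialization
    taus, xis = [t0] * len(data), [0.0] * len(data)
    n_obs = sum(len(times) for times, _ in data)
    stats = [0.0, 0.0, 0.0, 0.0]
    for k in range(n_iter):
        # Simulation step: one random-walk Metropolis-Hastings move per patient.
        for i, (times, zs) in enumerate(data):
            tau_new = taus[i] + rng.gauss(0.0, 0.5)
            xi_new = xis[i] + rng.gauss(0.0, 0.05)
            log_ratio = (log_post(tau_new, xi_new, times, zs, t0, var_tau, var_xi, var_eps)
                         - log_post(taus[i], xis[i], times, zs, t0, var_tau, var_xi, var_eps))
            if math.log(1.0 - rng.random()) < log_ratio:
                taus[i], xis[i] = tau_new, xi_new
        if k < n_burn:
            continue  # let the chains burn in before updating psi_f
        # Stochastic approximation of the sufficient statistics.
        resid = sum((z - math.exp(xi) * (t - tau)) ** 2
                    for (times, zs), tau, xi in zip(data, taus, xis)
                    for t, z in zip(times, zs))
        new = [sum(taus) / len(data), sum(a * a for a in taus) / len(data),
               sum(b * b for b in xis) / len(data), resid / n_obs]
        step = 1.0 / (k - n_burn + 1)
        stats = [s + step * (n - s) for s, n in zip(stats, new)]
        # Maximization step: closed-form updates of the fixed effects.
        t0 = stats[0]
        var_tau = max(stats[1] - stats[0] ** 2, 1e-6)
        var_xi, var_eps = max(stats[2], 1e-6), max(stats[3], 1e-6)
    return t0, var_tau, var_xi, var_eps
```

In the full scheme, the latent points passed as `data` would be re-encoded by the VAE between SAEM sweeps, which is what makes the two estimations co-dependent.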

Once the model is calibrated on a training set, we freeze the VAE parameters \((\phi , \theta )\) and the fixed effects \(\psi _f\), and learn the individual parameters \(\psi _r\) of new subjects by gradient descent on the negative log-likelihood, which personalizes the model.
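A minimal sketch of this personalization step, restricted to \((\tau_i, \xi_i)\) on the coordinate along \(v_0 = e_1\) (the space shift is omitted). As an assumption of ours, the Gaussian priors of Sect. 2.2 are added to the likelihood as penalty terms, so the sketch computes a MAP estimate; the learning rate and number of steps are hypothetical.

```python
import math


def personalize(times, zs, t0, var_tau, var_xi, var_eps, lr=2e-3, n_steps=5000):
    """Gradient descent for (tau_i, xi_i) of a new subject, with psi_f frozen."""
    tau, xi = t0, 0.0  # start from the population average
    for _ in range(n_steps):
        g_tau = (tau - t0) / var_tau  # gradient of the prior penalty on tau
        g_xi = xi / var_xi            # gradient of the prior penalty on xi
        for t, z in zip(times, zs):
            r = z - math.exp(xi) * (t - tau)  # residual on the v0 coordinate
            g_tau += r * math.exp(xi) / var_eps              # d(r^2/2)/d tau = r * exp(xi)
            g_xi -= r * math.exp(xi) * (t - tau) / var_eps   # d(r^2/2)/d xi = -r * exp(xi) * (t - tau)
        tau -= lr * g_tau
        xi -= lr * g_xi
    return tau, xi


# Noise-free subject generated with tau = 73 and xi = 0.1.
times = [72.0 + j for j in range(6)]
zs = [math.exp(0.1) * (t - 73.0) for t in times]
print(personalize(times, zs, t0=72.0, var_tau=4.0, var_xi=0.09, var_eps=0.25))
```

The estimate lands close to the generating values, slightly shrunk toward the population average by the priors.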

Hyperparameters are dim(\(\mathcal {Z}\)), which should be small enough for the mixed-effects model to be interpretable but large enough to reach good reconstruction; \(\beta \), which should limit overfitting without impairing reconstruction quality; and \(\gamma \), which should not be too large either, to avoid losing contextual information in \(\mathcal {Z}\). These parameters were set by grid search. Lastly, the MCMC-SAEM is computationally inexpensive compared to backpropagation, so the memory footprint and runtime are similar to those of training a regular VAE. All training was performed on a Quadro RTX 4000 GPU with 8 GB of memory.

3 Experiments and Results

3.1 Results on Synthetic Experiments

We first validated our approach on a synthetic data set of 64\(\,\times \,\)64 images of silhouettes [10]. Over time, the silhouette raises its left arm. Different silhouettes are generated by varying the relative position of the three other limbs, all of them raising their left arm over time. The motion is modulated by varying the time stamp at which it starts and its pace, using an affine reparametrization of the time stamps \(t_{i,j}\) with a Gaussian log-acceleration factor \(\xi _i\) and onset age \(\tau _i\). This data set contains \(N=1,000\) subjects with \(n=10\) visits each, sampled at random time points.
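The affine time warp used to generate the silhouettes can be sketched as follows. The distributions of \(\xi_i\), \(\tau_i\) and of the visit ages below are hypothetical placeholders, not the settings of [10]. The point is that \(t \mapsto e^{\xi_i}(t-\tau_i)\) is increasing, so each subject keeps its visit order while being shifted and stretched onto a common timeline.

```python
import math
import random


def reparametrize(ages, xi, tau):
    """Affine time warp e^xi * (t - tau) applied to one subject's time stamps."""
    return [math.exp(xi) * (t - tau) for t in ages]


def simulate_time_stamps(n_subjects=1000, n_visits=10, seed=0):
    """Random visit times and their warped versions (distribution parameters are hypothetical)."""
    rng = random.Random(seed)
    subjects = []
    for _ in range(n_subjects):
        xi, tau = rng.gauss(0.0, 0.5), rng.gauss(70.0, 2.0)
        ages = sorted(rng.uniform(60.0, 80.0) for _ in range(n_visits))
        subjects.append((ages, reparametrize(ages, xi, tau)))
    return subjects
```

The warped times would then drive the arm-raising motion, while the other limb positions play the role of the space shift.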

Fig. 2. Synthetic experiment. (a) The average trajectory over time (left to right) on the first row, followed by its translation in the directions \(w_1=(0,1,0,0)\) and \(w_2=(0,0,1,0)\) of the latent space (second and third rows). (b) The gradient of the image at \(p_0 = 0\) in the 4 directions of the latent space \(v_0, w_1, w_2, w_3\). (c) Data of a test subject (first row), its reconstructed trajectory (second row) and the ground truth (third row).

We choose dim\((\mathcal {Z})=4\) to evaluate the ability of our model to isolate temporal changes (motion of the left arm) from the independent spatial generative factors (the position of the other 3 limbs). Results are displayed in Fig. 2. The 5-fold reconstruction mean squared errors (MSE, \(\times 10^{-3}\)) for train/test images are \(7.88\pm .22/7.93\pm .29\), showing little over-fitting. The prediction error for missing data, when training on a half-pruned data set, is \(8.1\pm .78\), which shows good prediction capabilities. A thorough benchmark of six former approaches on this data set was provided in [9], displaying MSE similar to ours. Although a couple of approaches [9, 30] also disentangle time from space, ours is the only one to yield the true generative factors: the direction of progression encodes the motion of the left arm, and the 3 spatial directions orthogonal to it encode leg spreading, leg translation and right arm position, respectively.

3.2 Results on 3D MRI and PET Scans

We then applied the method to 3D T1w MRI and FDG-PET scans from the public ADNI database (http://adni.loni.usc.edu). For MRI, we selected two cohorts: patients with a confirmed AD diagnosis at one or more visits (\(N=783\) patients for a total of \(N_{tot}=\,\)3,685 images) and patients who were Cognitively Normal (CN) at all visits, for modeling normal aging (\(N=886\) and \(N_{tot} = 3,205\)). We considered PET data for AD patients only (\(N=570\) and \(N_{tot}=1,463\)). Images are registered using the T1-linear and PET-linear pipelines of the Clinica software [27] and resampled to a resolution of 80\(\,\times \,\)96\(\,\times \,\)80 voxels.

Fig. 3. (a) Sagittal, coronal and axial views of the population trajectory over pathological time (left to right) for the AD cohort. Enlargement of the ventricles and atrophy of the cortex and the hippocampus are visible. Red squares around the hippocampus are positioned at the 5 stages of the Scheltens scale used in clinical practice to evaluate AD progression. (b) Coronal view of the estimated normal aging scenario, with matched reparametrized age distribution. Atrophy is also visible, but to a smaller extent. As is common in atlasing methods, these images average anatomical details from different subjects to provide a population trajectory, and are thus not as sharp as real images.

Fig. 4. Sagittal, coronal and axial views of the average trajectory for FDG-PET scans, showing decreased levels of metabolism across brain regions.

We set dim\((\mathcal {Z})=16\) for both modalities, as it is the smallest dimension that captured the reported dynamics with satisfying resolution. For MRI, the train/test reconstruction MSE and the imputation error on the half-pruned data set (\(\times 10^{-3}\)) for the AD model are \(14.15\pm .12/15.33\pm .23\) and \(18.65\pm .76\) respectively, again showing little over-fitting and good prediction abilities. In Fig. 3, the reference trajectory for AD patients reveals the structural alterations that are typical of AD progression. The control trajectory displays alterations more in line with normal aging.

We tested differences in individual parameters between sub-groups using the Mann-Whitney U test within a 5-fold cross-validation. In the AD model, the average onset age occurred earlier for women than for men (72.2±.4 vs. 73.7±.6 years, \(p<3\cdot 10^{-7}\scriptstyle \pm 5\cdot 10^{-8}\)). APOE-\(\varepsilon 4\) mutation carriers also experience an earlier onset than non-carriers (71.8±.2 vs. 73.1±.4, \(p<3.3\cdot 10^{-2}\scriptstyle \pm 6\cdot 10^{-3}\)) and a greater pace of progression (0.1\(\pm 3\cdot 10^{-2}\) vs. −0.08\(\pm 2\cdot 10^{-2}\), \(p<1.4\cdot 10^{-4}\scriptstyle \pm 6\cdot 10^{-3}\)). The normal aging model shows an earlier onset for men (71.2±.4 vs. 73.7±.6, \(p<3\cdot 10^{-10}\scriptstyle \pm 6\cdot 10^{-11}\)). These results are in line with current knowledge of AD progression [15, 17] and normal aging [8]. For PET scans, the 5-fold train/test reconstruction MSE (\(\times 10^{-2}\)) are \(4.71\pm .32/ 5.10\pm .23\) (Fig. 4).
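The text does not detail how the test was computed. As a self-contained sketch, here is a two-sided Mann-Whitney U test with average ranks for ties and the normal approximation without tie correction (a simplification of ours; exact or tie-corrected variants are preferable for small or heavily tied samples). Within one fold, it could compare, for instance, the estimated onset ages \(\tau_i\) of two sub-groups.

```python
import math


def mann_whitney_u(x, y):
    """Two-sided Mann-Whitney U test, normal approximation (no tie correction)."""
    pooled = sorted((v, g) for g, values in enumerate((x, y)) for v in values)
    ranks = [0.0] * len(pooled)
    i = 0
    while i < len(pooled):  # assign average ranks to runs of tied values
        j = i
        while j + 1 < len(pooled) and pooled[j + 1][0] == pooled[i][0]:
            j += 1
        for k in range(i, j + 1):
            ranks[k] = (i + j) / 2.0 + 1.0
        i = j + 1
    n1, n2 = len(x), len(y)
    r1 = sum(r for r, (_, g) in zip(ranks, pooled) if g == 0)  # rank sum of x
    u1 = r1 - n1 * (n1 + 1) / 2.0
    mean_u = n1 * n2 / 2.0
    sd_u = math.sqrt(n1 * n2 * (n1 + n2 + 1) / 12.0)
    z = (u1 - mean_u) / sd_u
    p = math.erfc(abs(z) / math.sqrt(2.0))  # two-sided p-value
    return u1, p


print(mann_whitney_u([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]))
```

Because the test only uses ranks, it makes no normality assumption about the distribution of the individual parameters.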

4 Conclusion

We proposed a generative Variational Autoencoder architecture that maps longitudinal data to a low-dimensional Euclidean space, in which a linear spatio-temporal structure is learned to accurately disentangle the effects of time and inter-patient variability, while providing interpretable individual parameters (onset age and acceleration factor). This is the first approach to provide a progression model for 3D MRI or PET scans, and it relies on vanilla deep learning architectures that only require tuning the balance between the loss terms. We showed that it bridges the gap between former approaches to longitudinal images, namely GP-VAEs, and Riemannian disease progression models. Applied to MRI and PET data, the method retrieves known patterns of normal and pathological brain aging without the need to extract specific biomarkers. This not only saves time but also makes the approach independent of any prior choice of biomarkers.

Current work focuses on linking this progression model of brain alterations with cognitive decline, and exploring disease sub-types in the latent space.