1 Introduction

Analysis of neurological processes in the human brain is a challenging task addressed by neuroimaging. Here, one typically obtains high-dimensional stochastic data, which encourages the usage of machine learning algorithms. In recent years, deep learning discriminative models have been actively applied to neuroimaging problems (see [1] for a review). They have yielded state-of-the-art results in classification problems on a variety of benchmark datasets [2,3,4]. One downside of deep learning-based classifiers is that they operate as black boxes [5], meaning that their predictions are often difficult to interpret.

[1] suggested that hybrid generative-discriminative models might help resolve the issue. Such models can learn low-dimensional representations of data where each dimension corresponds to an independent generative factor (i.e. a disentangled representation, see [6] for a review). The discriminative part of the model then forces those factors to capture label information from the data [7]. Interpretability is thus achieved by decoding the meaning of the generative factors related to particular labels [8]. This is especially relevant in neuroimaging, as one can observe how underlying pathologies govern the process of data generation.

This paper follows the above intuition regarding hybrid generative-discriminative models for neuroimaging data, with a particular application to EEG data. Our main contributions are as follows:

  1. We demonstrate how one can apply characteristic capturing variational auto-encoders (CCVAEs) [7] to the interpretable classification of EEG data;

  2. We compare the model to two generative models previously used for EEG data: conditional VAEs and VAEs with downstream classification;

  3. We propose an algorithm for decoding generative factors learned by CCVAEs;

  4. We demonstrate that the learned generative mechanisms associated with pathologies are consistent with evidence from neurobiological studies.

2 Background

In this section, we introduce the relevant background on variational auto-encoders, disentangled factorization and the role of supervision.

Variational Auto-Encoders. Variational auto-encoders (VAEs) [9] learn a model distribution \(p_{\theta }(\textbf{x}, \textbf{z})\) that describes the ground-truth data generation process \(p(\textbf{x}, \textbf{z})\) as first sampling latent variables \(\textbf{z}\) from a prior distribution \(p(\textbf{z})\). An observation \(\textbf{x}\) is then generated from the conditional distribution \(p_{\theta }(\textbf{x} | \textbf{z})\), yielding

$$\begin{aligned} p_{\theta }(\textbf{x}, \textbf{z}) = p_{\theta }(\textbf{x} | \textbf{z}) p(\textbf{z}) \end{aligned}$$
(1)

Here, the conditional distribution is parameterized with neural networks whose learned parameters are denoted by \(\theta \). Defining the latent variables as jointly independent yields a disentangled factorization [10] that separates the generative process into human-interpretable [8] generative mechanisms.
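
As a minimal illustration of Eq. (1), the sketch below pairs a standard-normal prior \(p(\textbf{z})\) with a neural-network decoder parameterizing \(p_{\theta }(\textbf{x}|\textbf{z})\) as a diagonal Gaussian; all layer sizes and names are illustrative rather than taken from our implementation.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal, Independent

latent_dim, data_dim = 5, 61 * 61   # illustrative sizes only

class Decoder(nn.Module):
    """Parameterizes p_theta(x|z) as a diagonal Gaussian."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(),
                                 nn.Linear(256, data_dim))
        self.log_scale = nn.Parameter(torch.zeros(data_dim))

    def forward(self, z):
        return Independent(Normal(self.net(z), self.log_scale.exp()), 1)

# p(x, z) = p_theta(x|z) p(z): sample the prior, then decode
prior = Independent(Normal(torch.zeros(latent_dim), torch.ones(latent_dim)), 1)
decoder = Decoder()
z = prior.sample((8,))          # z ~ p(z)
x = decoder(z).sample()         # x ~ p_theta(x|z)
```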

Supervised Learning. A label variable \(\textbf{y} \sim p(\textbf{y})\) can be interpreted as context that partially governs the generation of an observed variable \(\textbf{x}\). In VAEs, this is reflected in the model's generative factors \(p(\textbf{x}|\textbf{z}, \textbf{y})\), leading to the joint distribution being factorized as follows:

$$\begin{aligned} p_{\theta _1, \theta _2}(\textbf{x}, \textbf{z}, \textbf{y}) = p_{\theta _1}(\textbf{x}|\textbf{z}, \textbf{y}) p_{\theta _2}(\textbf{z}|\textbf{y}) p(\textbf{y}) \end{aligned}$$
(2)

where \(\theta _1, \theta _2\) are parameters of corresponding model distributions. The equality holds due to the chain rule. Incorporating label information into the model allows learning generative factors corresponding to those labels via \(p_{\theta _2}(\textbf{z}|\textbf{y})\).

3 Methods

The proposed framework consists of two steps (see Fig. 1). First, EEG data is mapped stochastically to the latent space via CCVAEs. The latent space is constructed such that each label is related to a single independent generative factor. Second, we perform an intervention analysis to decode the meaning of label-related generative factors. This way, we get an intuition regarding the mechanisms through which labels govern data generation. In our case, we are interested in how different pathologies manifest themselves in functional connectivity matrices.

Fig. 1. Scheme of the proposed approach (A). We receive EEG data as input and learn a stochastic mapping to the latent space with CCVAEs (B) [7]. We further manipulate learned generative factors of data to gain insights regarding neurological mechanisms related to the attribute of interest, e.g. a symptom.

Characteristic Capturing VAEs. We aim to learn a model of a joint distribution over observed EEG data \(\textbf{x}\), labels (e.g. pathology indicators) \(\textbf{y}\) and latent variables \(\textbf{z}\) partially conditioned on \(\textbf{y}\). Let us assume that \(\textbf{x}\) and \(\textbf{y}\) are conditionally independent given \(\textbf{z}\). Then, the generative model (see Eq. 2) can be rewritten as follows:

$$\begin{aligned} p_{\theta _1, \theta _2}(\textbf{x},\textbf{y},\textbf{z}) = p_{\theta _1}(\textbf{x}|\textbf{z}) \, p_{\theta _2}(\textbf{z} | \textbf{y} ) \, p(\textbf{y}) \end{aligned}$$

We further partition the latent space \(\textbf{z}\) such that one partition \(\textbf{z}_c\) encapsulates label-associated characteristics, and the second partition \(\textbf{z}_{\backslash c}\) accounts for shared features of the data (as in vanilla VAEs):

$$\begin{aligned} p_{\theta _2}(\textbf{z} | \textbf{y} ) = p_{\theta _2}(\textbf{z}_c | \textbf{y} ) \cdot p(\textbf{z}_{\backslash c}) \end{aligned}$$
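
A minimal sketch of such a partitioned conditional prior is given below, assuming one latent dimension per binary label and learnable per-class means and scales; the parameterization actually used in our implementation is specified in the supplementary material (Section S.1), so shapes and names here are only illustrative.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal, Independent

n_labels, n_shared = 3, 2   # z_c: one dimension per binary label; the rest is shared

class ConditionalPrior(nn.Module):
    """p(z|y) = p(z_c|y) * p(z_rest): a learnable 1-D Gaussian per (label, class) pair."""
    def __init__(self):
        super().__init__()
        self.mu = nn.Parameter(0.1 * torch.randn(2, n_labels))     # row 0: y=0, row 1: y=1
        self.log_scale = nn.Parameter(torch.zeros(2, n_labels))

    def p_zc(self, y):                      # y: (batch, n_labels) with entries in {0, 1}
        mu = (1 - y) * self.mu[0] + y * self.mu[1]
        scale = ((1 - y) * self.log_scale[0] + y * self.log_scale[1]).exp()
        return Independent(Normal(mu, scale), 1)

    def p_z_rest(self, batch_size):         # standard-normal prior on the shared partition
        return Independent(Normal(torch.zeros(batch_size, n_shared),
                                  torch.ones(batch_size, n_shared)), 1)

# usage: sample the label-related part of the latent space for one label assignment
prior = ConditionalPrior()
y = torch.tensor([[1., 0., 0.]])            # one binary assignment of the three labels
z_c = prior.p_zc(y).sample()                # shape (1, n_labels)
```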

The characteristic partition \(\textbf{z}_c\) is further partitioned so that each label can access only a single latent variable. This guarantees the disentanglement of label information in the latent representations. The intractable posterior \(p(\textbf{z} | \textbf{x},\textbf{y})\), which is conditioned on both observation and label variables, is approximated with the following inference model:

$$\begin{aligned} q_{\phi _1, \phi _2}(\textbf{z}|\textbf{x},\textbf{y}) = \frac{q_{\phi _1}(\textbf{y}|\textbf{z}_c) \, q_{\phi _2}(\textbf{z}|\textbf{x})}{q_{\phi _1, \phi _2}(\textbf{y}|\textbf{x})} \end{aligned}$$

where \(\phi _1, \phi _2\) are parameters of model distributions. The conditional distribution

$$\begin{aligned} q_{\phi _1, \phi _2}(\textbf{y} | \textbf{x}) = \int q_{\phi _1}(\textbf{y} | \textbf{z}_c) \, q_{\phi _2}(\textbf{z}|\textbf{x}) \, d\textbf{z} \end{aligned}$$

reflects that the observation variables \(\textbf{x}\) and label variables \(\textbf{y}\) are connected via the characteristic partition \(\textbf{z}_c\). Label-related information contained in an observation \(\textbf{x}\) is captured by the inference model \(q_{\phi _2}(\textbf{z}|\textbf{x})\). At the same time, the classifier \(q_{\phi _1}(\textbf{y}|\textbf{z}_c)\) forces the label-related latent variables \(\textbf{z}_c\) to capture the characteristics of those labels.
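
The sketch below shows one way to realize the two inference components and a Monte-Carlo estimate of \(q_{\phi _1, \phi _2}(\textbf{y}|\textbf{x})\): a diagonal classifier head makes each label read a single latent dimension. The layer sizes and module names are placeholders, not our actual architecture.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal, Independent, Bernoulli

data_dim, n_labels, n_shared = 61 * 61, 3, 2     # illustrative sizes
latent_dim = n_labels + n_shared                 # z = [z_c, z_rest]

class Encoder(nn.Module):
    """q_phi2(z|x): a diagonal Gaussian over the full latent space."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(data_dim, 256), nn.ReLU(),
                                 nn.Linear(256, 2 * latent_dim))

    def forward(self, x):
        mu, log_scale = self.net(x).chunk(2, dim=-1)
        return Independent(Normal(mu, log_scale.exp()), 1)

class Classifier(nn.Module):
    """q_phi1(y|z_c): each label reads only its own latent dimension."""
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(n_labels))
        self.bias = nn.Parameter(torch.zeros(n_labels))

    def forward(self, z_c):
        return Bernoulli(logits=z_c * self.weight + self.bias)

def log_q_y_given_x(x, y, encoder, classifier, k=64):
    """Monte-Carlo estimate of log q(y|x) = log E_{q(z|x)}[q(y|z_c)]."""
    z = encoder(x).rsample((k,))                                 # (k, batch, latent_dim)
    log_qy = classifier(z[..., :n_labels]).log_prob(y).sum(-1)   # (k, batch)
    return torch.logsumexp(log_qy, dim=0) - torch.log(torch.tensor(float(k)))
```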

As with vanilla VAEs, the model is optimized by maximizing the evidence lower bound [9]. In the case of CCVAEs, this is equivalent to maximizing the following objective (see Appendix B.1 of [7] for the derivation):

$$\begin{aligned} \mathcal {L}(\textbf{x},\textbf{y}) = \mathbb {E}_{q_{\phi _2}(\textbf{z}|\textbf{x})}\Bigg [ \frac{q_{\phi _1}(\textbf{y} | \textbf{z}_c)}{q_{\phi _1, \phi _2}(\textbf{y} | \textbf{x})} \log \frac{p_{\theta _1}(\textbf{x}|\textbf{z}) \, p_{\theta _2}(\textbf{z} | \textbf{y} )}{q_{\phi _1}(\textbf{y} | \textbf{z}_c) \, q_{\phi _2}(\textbf{z}|\textbf{x})}\Bigg ] + \log q_{\phi _1, \phi _2}(\textbf{y} | \textbf{x}) \end{aligned}$$
(3)

The classification term \(\log q_{\phi _1, \phi _2}(\textbf{y}|\textbf{x})\) is essentially a learnable mapping from input data \(\textbf{x}\) to labels \(\textbf{y}\) that passes through the characteristic partition of the latent space \(\textbf{z}_c\). It puts pressure on the partition to learn label-related characteristics from data and simultaneously performs data classification.
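
A single-sample sketch of how the terms of Eq. (3) combine into a training objective is shown below; the distribution arguments are assumed to come from modules such as those sketched above, and gradient-estimation details of the reweighting term follow Appendix B.1 of [7] and are omitted here.

```python
import torch

def ccvae_objective(x, y, z, qz_x, qy_zc, log_qy_x, px_z, pz_y):
    """Single-sample estimate of the objective in Eq. (3).

    z        : one sample from qz_x (its first dimensions form z_c)
    qz_x     : q_phi2(z|x)               px_z : p_theta1(x|z)
    qy_zc    : q_phi1(y|z_c)             pz_y : p_theta2(z|y), the conditional prior
    log_qy_x : estimate of log q_{phi1,phi2}(y|x), e.g. from log_q_y_given_x above
    """
    log_qy_zc = qy_zc.log_prob(y).sum(-1)
    weight = (log_qy_zc - log_qy_x).exp()              # q(y|z_c) / q(y|x)
    log_ratio = (px_z.log_prob(x) + pz_y.log_prob(z)
                 - log_qy_zc - qz_x.log_prob(z))
    return weight * log_ratio + log_qy_x               # maximize; minimize its negative mean
```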

Intervention Analysis. The learned generative model forms the bridge between observations \(\textbf{x}\) and their labels \(\textbf{y}\) via the latent variables \(\textbf{z}_c\). It allows one to analyze the generative factors \(p_{\theta _1, \theta _2}(\textbf{x}|\textbf{z}, \textbf{y})\) of data related to those labels. One can explore the relation via intervention analysis. The algorithm for a single binary label of interest \(\textbf{y}^i\) is as follows. First, one fixes every dimension of the latent space \(\textbf{z}\) except the one \(\textbf{z}_c^i\) that corresponds to the label \(\textbf{y}^i\). Next, the value of \(\textbf{z}_c^i\) is sampled from \(p_{\theta _2}(\textbf{z}_c^i | \textbf{y}^i)\) for each value of \(\textbf{y}^i \in \{0, 1\}\). As a result, one receives two latent representations \(\textbf{z}_0\), \(\textbf{z}_1\) that vary only in a single dimension \(\textbf{z}_c^i\). Those representations are then reconstructed to the observation space \(\textbf{x}_0 \sim p_{\theta _1}(\textbf{x}|\textbf{z} = \textbf{z}_0)\), \(\textbf{x}_1 \sim p_{\theta _1}(\textbf{x}|\textbf{z} = \textbf{z}_1)\). The procedure is repeated N times. As a result, one gets multiple pairs of reconstructions \((\textbf{x}_0, \textbf{x}_1)\) that differ only in the varied generative factor \(\textbf{z}_c^i\). One then calculates the average difference \(\frac{1}{N} \sum _{k=1}^N(\textbf{x}^k_1 - \textbf{x}^k_0)\) over all pairs, and thus observes how the label \(\textbf{y}^i\) manifests itself in the data.
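
A sketch of this procedure is given below, assuming a trained encoder, decoder and conditional prior with the hypothetical interfaces noted in the docstring; the choice of base latent vector (here re-sampled from \(q(\textbf{z}|\textbf{x})\) in every repetition) is one possible reading, and the exact settings we used are given in the supplementary material (Section S.3).

```python
import torch

def intervention_difference(encoder, decoder, cond_prior, x, dim, n_repeats=100):
    """Average reconstruction difference when flipping a binary label y^i via its latent z_c^i.

    encoder(x)        -> q(z|x), a distribution over the full latent vector
    decoder(z)        -> p(x|z), a distribution over observations
    cond_prior(y_val) -> p(z_c^i | y^i = y_val), a 1-D distribution for the chosen label
    dim               : index of the latent dimension tied to the label of interest
    """
    diffs = []
    for _ in range(n_repeats):
        z = encoder(x).sample()                              # fix all other latent dimensions
        z0, z1 = z.clone(), z.clone()
        z0[..., dim] = cond_prior(0).sample(z.shape[:-1])    # y^i = 0
        z1[..., dim] = cond_prior(1).sample(z.shape[:-1])    # y^i = 1
        x0, x1 = decoder(z0).sample(), decoder(z1).sample()  # reconstruct both versions
        diffs.append(x1 - x0)
    return torch.stack(diffs).mean(dim=0)                    # (1/N) * sum_k (x_1^k - x_0^k)
```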

4 Related Works

The fusion of generative and discriminative models with application to neuroimaging data is an active area of research. [11] demonstrate that using learned representations leads to more robust classification performance compared to feed-forward neural networks. [12] introduce VAEs for feature extraction from multichannel EEG data, yielding better accuracy than traditional unsupervised approaches. [13] use stacked VAEs for semi-supervised learning on EEG data. However, the label information is usually encapsulated by multiple latent variables simultaneously. In this case, label characteristics are smeared across the latent space, complicating the analysis of label-related generative factors. This, in turn, limits both the interpretability and explainability of these models. One has to decode and interpret each latent variable and then infer its relation to the label variables, which is not a trivial task.

Two flavours of VAEs that are commonly applied to EEG data are conditional VAEs [13] and VAEs with downstream classification [12]. In both approaches, the latent space is not partitioned with respect to the label variables. Hence, compared to CCVAEs, their general disadvantage is the reduced interpretability of classification, as it is difficult to build a bridge between labels and generative factors.

Conditional VAEs. Conditional VAEs have a graphical model similar to that of CCVAEs. The only difference is that the latent space is not partitioned with respect to labels, i.e. \(\textbf{z} = \textbf{z}_c\). The learnable parameters are optimized by maximizing the objective in Eq. (3). The framework allows conditional sampling, so one can use intervention analysis to decode the meaning of the learned generative factors. Nevertheless, the interpretation is complicated, as a single label variable is connected to every dimension of the latent space.

VAEs + Downstream Classification. The model approximates the joint distribution of observed data and latent variables, factorized as in Eq. (1). The relation between latent variables and labels is built by classifying the latent representation. The model is trained by optimizing the following objective [11]:

$$\begin{aligned} \mathcal {L}(\textbf{x},\textbf{y}) = \mathbb {E}_{q_{\phi }(\textbf{z}|\textbf{x})}\bigg [ \log p_\theta (\textbf{x}|\textbf{z}) - D_{KL}(q_\phi (\textbf{z}|\textbf{x}) \,||\, p(\textbf{z})) - \textrm{BCE}(f_\xi (\textbf{z}), \textbf{y})\bigg ] \end{aligned}$$
(4)

where \(f_\xi : Z \rightarrow Y\) is a learnable classifier with parameters \(\xi \), and BCE is the binary cross-entropy function. Here, information about the label variables is incorporated into the latent space via the pressure applied by the downstream classification task. The model can be seen as a feed-forward deep neural network with additional regularization imposed by the decoder part of the VAE.
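
A sketch of Eq. (4) as a per-batch loss is given below, assuming a Gaussian encoder and a logistic classifier head; the distribution interfaces mirror the earlier sketches and are not taken verbatim from our implementation.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, Independent, kl_divergence

def vae_downstream_loss(x, y, qz_x, px_z, classifier_logits):
    """Negative of the objective in Eq. (4) for one latent sample.

    qz_x              : q_phi(z|x), an Independent(Normal(...), 1)
    px_z              : p_theta(x|z), evaluated for z ~ qz_x.rsample()
    classifier_logits : f_xi(z), raw logits for each binary label
    """
    prior = Independent(Normal(torch.zeros_like(qz_x.mean),
                               torch.ones_like(qz_x.mean)), 1)
    rec = px_z.log_prob(x)                                   # log p_theta(x|z)
    kl = kl_divergence(qz_x, prior)                          # D_KL(q(z|x) || p(z))
    bce = F.binary_cross_entropy_with_logits(
        classifier_logits, y, reduction='none').sum(-1)      # BCE(f_xi(z), y)
    return -(rec - kl - bce).mean()                          # minimize the negative objective
```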

5 Experimental Details

Experimental Study. The study comprised 29 patients suffering from schizophrenia and 52 healthy controls. Of those 29 patients, 14 reported the emergence of auditory verbal hallucinations (AVH), i.e. hearing voices with no external stimuli present. Every participant was right-handed. Six different syllables (/ba/, /da/, /ka/, /ga/, /pa/, /ta/) were presented to each participant for 500 ms, simultaneously to both ears, after a 200 ms silence period. Meanwhile, EEG was recorded with 64 electrodes, with 4 EOG channels used to monitor eye movements. For each subject, the procedure was repeated multiple times (number of trials for AVH: \(68.23 \pm 19.43\); SZ: \(68.76 \pm 14.79\); HC: \(71.19 \pm 12.93\)). At the preprocessing step, the data was band-pass filtered between 20 and 120 Hz according to the protocol described in [14], so that only gamma-band frequencies are preserved. Afterwards, all channels were re-referenced to the common average. Finally, muscle and visual artefacts were identified and removed. For our experiments, we utilized two parts of each recording: the resting part (the first 200 ms, with no syllable presented) and the listening part (the initial 200 ms of syllable presentation). The study of [14] contains detailed data acquisition and preprocessing information.

Experimental Data. For each EEG recording \([\zeta _1, \zeta _2, ..., \zeta _{61}] \), we assessed functional connectivity by calculating a correlation matrix:

$$\begin{aligned} \textbf{x}_{ij} = \frac{ cov(\zeta _i, \zeta _j) }{ \sqrt{ var(\zeta _i) \cdot var(\zeta _j) } } \end{aligned}$$

As a result, functional connectivity matrices play the role of the observed data \(\textbf{x}\). We introduce three binary labels such that \(\textbf{y} = \) \([\textit{listening},\) \(\textit{schizophrenia}\), \(\textit{hallucinations}]\). To create a dataset for training the models, we use the intra-patient paradigm, i.e. data from the same subject can appear simultaneously in the training and test datasets. Thus, data from all subjects are randomly sampled to form these datasets, yielding 9000 training samples and 2000 test samples.
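
Computing the matrix above amounts to a channel-wise Pearson correlation; a minimal numpy sketch with placeholder data follows.

```python
import numpy as np

def connectivity_matrix(eeg):
    """Pearson correlation between all channel pairs.

    eeg : array of shape (n_channels, n_samples), one band-passed EEG segment
    returns : (n_channels, n_channels) functional connectivity matrix x
    """
    # np.corrcoef implements cov(z_i, z_j) / sqrt(var(z_i) * var(z_j)) row-wise
    return np.corrcoef(eeg)

# illustrative usage with placeholder data: 61 channels, arbitrary segment length
segment = np.random.randn(61, 400)
x = connectivity_matrix(segment)
assert x.shape == (61, 61)
```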

Implementation Details. One can find details regarding the parametrization of distributions in the supplementary material (Section S.1). We release the implementation on GitHub. For each framework, the parameters \(\theta _i, \phi _j\) (and \(\xi \) for VAEs) are trained by optimizing the corresponding objective. We use the Adam optimizer with a learning rate of \(10^{-3}\). Training was performed in mini-batches of size 32 for 100 epochs. All models are trained on an NVIDIA Tesla V100 GPU from the Hemera HPC system of HZDR.
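
The training procedure reduces to a standard mini-batch loop with the stated hyper-parameters; in the sketch below, `model` and `loss_fn` are placeholders for any of the three frameworks and the negative of its objective.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def train(model, loss_fn, x_train, y_train, epochs=100, batch_size=32, lr=1e-3):
    """Generic mini-batch training loop with the hyper-parameters stated above."""
    loader = DataLoader(TensorDataset(x_train, y_train),
                        batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        for x, y in loader:
            optimizer.zero_grad()
            loss = loss_fn(model, x, y)     # negative objective of the chosen framework
            loss.backward()
            optimizer.step()
    return model
```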

6 Results and Discussion

We found that high-dimensional latent spaces (dim \(> 32\)) hinder the reproducibility of generative factors learned by CCVAEs. For that reason, we keep the latent space of all models low-dimensional: \(\textbf{z} \in \mathbb {R}^{5}\) (\(\textbf{z}_{c} \in \mathbb {R}^{3}\), \(\textbf{z}_{\backslash c} \in \mathbb {R}^{2}\) for CCVAEs).

Table 1. Comparison of CCVAEs to baseline models in terms of accuracy and disentanglement scores on the test dataset. For each framework, 10 experiments were conducted.

Results. As shown in Table 1, CCVAEs outperform baseline models in both classification performance and disentanglement (see supplementary material S.2 for details). The framework consistently classifies observed data based on its low-dimensional representation, yielding a low standard deviation of accuracy. In addition, it demonstrates a high level of disentanglement, meaning that each label variable is captured by only a single latent dimension. For CCVAEs, generative factors are disentangled in the latent space by design, leading to the highest score; nevertheless, the score is lower than might be expected due to the correlation between the pathology labels.

Fig. 2. Confusion matrices for different methods. The size of the circle indicates the value of the corresponding element. Rows correspond to label variables (L - listening, S - schizophrenia, H - hallucinations) while columns represent latent generative factors.

Fig. 3. Latent space generated by sampling from the inference model \(q_\phi (\textbf{z}|\textbf{x})\) of different methods. For CCVAEs, the axes corresponding to the pathology variables are shown. For baseline methods, 2 randomly selected dimensions are visualized.

Disentangled Latent Space. To demonstrate how hard-wired disentanglement affects the latent space learned by CCVAEs, we construct confusion matrices in the following way. We compute the latent representation of each data point in the test dataset and intervene upon (i.e. randomly change the value of) a single dimension. We then observe how the log-probabilities assigned by a pre-trained classifier change due to the intervention for each label. We calculate this difference for each label-latent pair, yielding a confusion matrix (see Fig. 2). The optimal result would be one non-zero element per row (i.e. per label), meaning that each label corresponds to only a single generative factor. This is the case for CCVAEs, where one can observe a one-to-one correspondence between label variables and the corresponding generative factors. This leads to latent representations that are robust to variations in the data generative factors, as those are independent by design. In contrast, the characteristics of labels are entangled within the latent spaces of the baseline frameworks. Hence, it is difficult to disentangle the influence of a label from other generative factors, which severely hinders interpretability.
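
One possible reading of this procedure is sketched below: encode each test point, randomly perturb one latent dimension at a time, and record how the per-label log-probabilities of a pre-trained classifier change. Whether the classifier acts on reconstructions (as assumed here) or on latent codes, and the exact perturbation distribution, are assumptions rather than our exact protocol.

```python
import torch

@torch.no_grad()
def intervention_confusion_matrix(encoder, decoder, classifier, x_test, n_labels, latent_dim):
    """Change in per-label log-probability when intervening on each latent dimension.

    classifier(x) is assumed to return per-label log-probabilities from a
    pre-trained classifier; the result has shape (n_labels, latent_dim).
    """
    cm = torch.zeros(n_labels, latent_dim)
    z = encoder(x_test).sample()                         # (n_points, latent_dim)
    base = classifier(decoder(z).sample())               # (n_points, n_labels)
    for j in range(latent_dim):
        z_int = z.clone()
        z_int[:, j] = torch.randn_like(z_int[:, j])      # randomly change dimension j
        delta = classifier(decoder(z_int).sample()) - base
        cm[:, j] = delta.abs().mean(dim=0)               # average change per label
    return cm
```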

Posterior Distribution. We further compare the latent spaces learned by each framework. To visualize the latent space, we sample \(\textbf{z} \sim q_\phi (\textbf{z}|\textbf{x})\) for multiple \(\textbf{x}\) from the test dataset for each model. The result is shown in Fig. 3. In the case of CCVAEs, the distribution has three modes corresponding to the subject cohorts in the data (healthy, schizophrenia, schizophrenia accompanied by AVH). The separation is caused by the influence of the conditional prior and the classifier, which aim to separate representations encoding different label combinations. In the case of the baseline frameworks, there is no strict regularisation that preserves label information within a partition of the latent space. As a result, features of data encoded by their latents are shared between cohorts of subjects (hence only one or two modes). We discovered that both baseline methods often fail to jointly learn a low-dimensional representation and classify labels when the pressure of the KL divergence term in the loss objective is high. The problem is partially solved by introducing a scaling factor \(\beta \) for this term [8]. However, reducing the pressure might lead to unreliable reconstructions if the prior \(p(\textbf{z})\) is not sufficiently close to the inference model \(q(\textbf{z}|\textbf{x})\). This is not the case for CCVAEs, which do not require any manual fine-tuning and operate stably with low-dimensional latent spaces.

Fig. 4. Average difference in reconstructions of functional connectivity matrices when intervening on a single label: schizophrenia (left) and hallucinations (right). Connections that are stronger when the disorder is present are shown in red; otherwise, blue. For clarity, we visualize only the 40 connections with the highest absolute value.

Analyzing Pathological Mechanisms. We further use intervention analysis to investigate which connections are affected when intervening upon a single label dimension (Fig. 4; see supplementary material S.3 for computation details). The model associates the emergence of AVH with alterations in frontotemporal brain areas (the highest positive differences), which have been repeatedly observed in prior studies [17, 18]. The salient connections are mainly located in the right hemisphere, which is supported by the fMRI study of [19]. The model also points toward reduced connectivity between the hemispheres. This is consistent with the current hypothesis (see [14] for a review) that connects the emergence of auditory verbal hallucinations with interhemispheric miscommunication during auditory processing. Overall, the model can at least partially reconstruct the neurological mechanism of the symptom as reflected in functional connectivity. To explain the emergence of schizophrenia, the model focuses mainly on the left hemisphere. This is not surprising, since auditory function is left-lateralized in right-handed people [20, 21]. An interesting direction for further studies would be to apply CCVAEs to learn the mechanisms of particular symptoms of the composite disorder (e.g. hallucinations, delusions, etc.).

7 Conclusion

We demonstrated how to apply the framework of characteristic capturing variational auto-encoders to EEG data analysis. The method encapsulates and disentangles the characteristics associated with different pathologies in the latent space. As generative factors are independent by design, one can decode their meaning and discover how those pathologies alter observed data. It leads to improved interpretability coupled with the high classification performance of neural networks. The framework is not limited to functional connectivity analysis or EEG data and can be easily adapted to different neuroimaging modalities.