1 Introduction

Multiple Sclerosis (MS) is a severe central nervous system disease with a highly nonlinear disease course where periodic relapses impair the patient’s quality of life. Clinical studies show that relapses co-occur with the appearance of new inflammatory MS lesions in MR images [19, 25], making MR imaging a central element for the clinical management of MS patients. Further, assessing new MS lesions is crucial for disease assessment and therapy monitoring [5, 8]. Unfortunately, prevailing therapies often involve highly active immunomodulatory drugs with potentially severe side effects. This motivates the development of machine learning models that predict the future disease activity of individual patients, so that the best therapy can be selected.

While recent approaches applied convolutional neural networks (CNNs) to learn features directly from MR image space [3, 4, 21, 26, 28], obtaining an effective global representation that characterizes disease status remains challenging. We attribute the difficulty of MS inflammatory disease activity prediction to a set of distinct disease characteristics that are well observable in MR images. First, while lesions have a sufficient signal-to-noise profile in images, their variation in shape, size, and number of occurrences among patients makes the task challenging for existing CNN-based methods that process the whole MRI scan in one go. Second, with advanced age, small areas appearing in the MRI of healthy individuals may resemble MS lesions. As such, it is crucial not only to predict MS disease activity but also to identify the lesions deemed consequential for the final prediction.

To solve this problem, we use concepts from geometric deep learning. Specifically, we propose a two-stage pipeline. First, the lesions in the MRI scans are segmented using a state-of-the-art 3D segmentation algorithm [9], and their image features are extracted with a self-supervised method [14]. Second, the extracted lesions are converted into a patient graph. The lesions act as nodes of the graph, while the edge connectivity is determined using the spatial proximity of the lesions. By this representation, we can solve the MS inflammatory disease activity prediction task as a graph-level classification problem. We argue that formulating the MS inflammatory disease activity prediction in our two-stage pipeline has the following advantages: (1) Graph neural networks can easily handle different numbers of nodes (lesions) and efficiently incorporate their spatial locations. (2) Modern segmentation [9, 16, 23] and representation learning methods [2, 11, 14] are effective tools for lesion detection and allow us to extract pathology-specific features. (3) By operating at the lesion level, it is possible to discover the lesions that contribute most to the eventual prediction, making the decisions more interpretable. Thus, our proposed solution can be a viable methodology for MS inflammatory disease activity prediction to handle the associated challenges.

Contributions. Our contribution is threefold: (1) we are the first to formulate the MS inflammatory disease activity prediction task as a graph classification problem, thereby bringing a new set of methods to a significant clinical problem. (2) We propose a two-stage pipeline that effectively captures inherent MS variations in MRI scans, thus generating an effective global representation. (3) We develop a self-pruning module, which assigns an importance score to each lesion and reduces the task complexity by prioritizing the critical lesions. Additionally, the assigned per-lesion importance score improves our model’s explainability.

2 Methodology

Overview. The objective is to predict MS inflammatory disease activity, i.e., to classify whether new or significantly enlarged inflammatory lesions appear in the follow-up after the initial MRI scans. We denote the dataset as \(D(X, y)\), where X is the set of lesion patches extracted from MR scans, and \(y \in \{0, 1\}\) is the inflammatory disease activity status. For patient i, multiple lesion patches \(\{x_1^i, x_2^i, \ldots , x_n^i\}\) can exist, where n is the total number of lesions. We aim to learn a mapping function f: \(\{x_1^i, x_2^i, \ldots , x_n^i\}\rightarrow \{0, 1\}\). Please note that our formulation differs from existing methods [21, 26, 28], which aim to learn a direct mapping from the MR image to the inflammatory disease activity label. As shown in Fig. 1, our proposed method consists of four distinct components. We describe each component in the following sections.

Fig. 1. Our proposed Multiple Sclerosis (MS) inflammatory disease activity prediction framework. We first detect lesions in the MRI scan using nn-Unet [9]. A crop centered at each detection is extracted and used to learn self-supervised lesion features. Next, we build a graph from these detected lesions, where each lesion becomes a node, with the connections (edges) between lesions (nodes) defined by spatial proximity. This graph is processed using a graph neural network to generate enriched lesion features. Our self-pruning module (SPM) then processes these enriched lesion features to determine an importance score for each lesion. The lowest-scoring lesions are pruned and the highest-scoring lesions are passed to the readout layer to obtain a graph-level feature vector. This graph-level vector is used for the final prediction.

Lesion Detection and Feature Extraction. We focus on the individual lesions instead of processing whole-brain MRI scans. This is important because MS lesions typically comprise less than 1% of voxels in the MRI scan. With this strategy, the graph model can aggregate lesion-level features for an effective patient-specific representation. First, we detect the lesions using a state-of-the-art nn-Unet [9] pre-trained with MR images and their consensus annotations from two neuro-radiologists. Then, for each detected lesion, we extract a small fixed-size patch centered at it. Finally, we learn self-supervised features z for the lesion using a recent Transformer-based approach [14].

Lesion Graph Processor. In the second stage, we generate a patient-specific graph \(G(V, E, Z)\) from the detected lesions. The lesions act as vertices V of this graph and are initialized with the crop-derived features Z. The spatial location s of the lesions is used to determine their connectivity E using a k-Nearest Neighbor (kNN) graph method [27]. Furthermore, the connected edges are weighted based on their spatial proximity. Specifically, given two lesions \(v_i\) and \(v_j\), with spatial locations \(s_{i}\) and \(s_{j}\) respectively, the edge weight is \(w(v_i, v_j) = \exp \left( - \frac{\Vert s_{i} - s_{j}\Vert ^2}{\tau ^2}\right) \), where \(\tau \) is a scalar that controls the contribution of distant lesions. Hence the final graph connectivity can be represented as:

$$\begin{aligned} E_{i,j} = \begin{cases} w(v_i, v_j), & \text{if } j \in N(v_i) \text{ or } i \in N(v_j) \\ 0, & \text{otherwise} \end{cases} \end{aligned}$$
(1)

where \(N(v_i)\) denotes the k nearest neighbors of node \(v_i\), so two lesions are connected whenever either one is among the other's nearest neighbors. Constructing a graph from the lesions is instrumental in two respects: (1) The framework allows us to work with varying numbers of lesions in different patients. Alternatively, sequential models could be employed. However, since the lesions lack a canonical ordering, such models would not achieve an effective global representation [12]. (2) It is possible to incorporate meaningful lesion properties such as spatial proximity and the number of lesions.
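The graph construction of Eq. 1 can be sketched in NumPy as follows. This is a minimal illustration, not the authors' implementation: the function name `build_lesion_graph`, the dense-matrix output, and the assumption that lesion locations arrive as an (n, 3) array of normalized coordinates are all ours; `k` and `tau` correspond to k and \(\tau \) above.

```python
import numpy as np

def build_lesion_graph(coords, k=5, tau=0.01):
    """Return the dense weighted adjacency E of Eq. 1 for one patient.

    coords: (n, 3) lesion locations (assumed to be normalized coordinates).
    """
    coords = np.asarray(coords, dtype=float).reshape(-1, 3)
    n = coords.shape[0]

    # Pairwise squared Euclidean distances between lesion locations.
    diff = coords[:, None, :] - coords[None, :, :]
    sq_dist = np.sum(diff ** 2, axis=-1)

    # Indices of the k nearest neighbors of each lesion, excluding itself.
    k_eff = max(0, min(k, n - 1))
    masked = sq_dist.copy()
    np.fill_diagonal(masked, np.inf)
    neighbors = np.argsort(masked, axis=1)[:, :k_eff]

    # Directed kNN mask, then symmetrize: j in N(v_i) or i in N(v_j).
    knn = np.zeros((n, n), dtype=bool)
    rows = np.repeat(np.arange(n), k_eff)
    knn[rows, neighbors.ravel()] = True
    connected = knn | knn.T

    # Gaussian edge weights on connected pairs, zero elsewhere.
    weights = np.exp(-sq_dist / tau ** 2)
    return np.where(connected, weights, 0.0)
```

Note that the weights depend on the scale of the coordinates relative to `tau`: with a small `tau`, coordinates expressed in millimeters rather than in a normalized range would drive almost all weights to zero.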

Please note that separate graphs are created for individual patients. Thus, MS inflammatory disease activity prediction is formulated as a graph-level classification task. The graph \(G(V, E, Z)\) can be processed using message-passing neural networks (such as GCN [10], GAT [1], EdgeConv [24], GraphSAGE [6], to name a few) to learn enriched lesion features \(\hat{Z}\). These enriched features are passed through to the self-pruning module.
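As an illustration of how such a layer enriches lesion features, one message-passing step in the style of GCN [10] can be sketched as below. The sketch assumes the commonly used symmetric normalization with self-loops, \(\mathrm{ReLU}(D^{-1/2}(E + I)D^{-1/2} Z W)\), applied to the weighted adjacency E of Eq. 1; the function name `gcn_layer` and the weight matrix `W` are hypothetical, and a trained model would learn `W` rather than receive it as an argument.

```python
import numpy as np

def gcn_layer(E, Z, W):
    """One graph-convolution step on a weighted lesion graph.

    E: (n, n) weighted adjacency, Z: (n, d_in) lesion features,
    W: (d_in, d_out) weight matrix (learned in a real model).
    """
    n = E.shape[0]
    A = E + np.eye(n)                 # add self-loops
    deg = A.sum(axis=1)               # weighted degree, at least 1
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Symmetric normalization D^-1/2 A D^-1/2.
    A_norm = A * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    # Aggregate neighbor features, project, and apply ReLU.
    return np.maximum(A_norm @ Z @ W, 0.0)
```

Stacking two such calls, with the output of the first as the input of the second, yields the enriched features \(\hat{Z}\) in this simplified setting.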

Self-pruning Module. The number of lesions can vary substantially among patients, and the segmentation stage may additionally produce false positives. As such, it is crucial to recognize the most relevant lesions for the final prediction. In addition, this brings inherent explainability and makes it easier for a doctor to validate model predictions. To accomplish this, the enriched lesion features \(\hat{Z}\) are passed through a self-pruning module (SPM). The SPM produces a binary mask M with one entry per lesion, which determines whether that lesion contributes to the classification. The SPM uses a learnable projection vector \(\vec {p}\) to compute importance scores (\(\hat{Z} \vec {p}\)/\({||\vec {p}||}\)) for the lesions. These scores are scaled with a sigmoid layer. We retain the high-scoring lesions and discard the rest, which is formulated as:

$$\begin{aligned} M = \operatorname{top-}r\left( \sigma \left( \frac{\hat{Z} \vec {p}}{\Vert \vec {p}\Vert }\right) , r\right) \end{aligned}$$
(2)

where \(\sigma (x) = 1/(1 + e^{-x})\) is the sigmoid function and \(\operatorname{top-}r(\cdot )\) is an operator that selects the fraction r of lesions with the highest importance scores. r is a hyper-parameter in our setup. Since the masking process is part of the forward pass through the model during both training and inference, and not a post hoc modification, we refer to it as self-pruning of nodes. Features of the remaining nodes (\(\hat{Z^\prime } = \hat{Z} \otimes M \)) are passed to the classification head.
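The forward pass of Eq. 2 can be sketched as follows. This is a minimal NumPy illustration under our own assumptions: the function name `self_prune` is hypothetical, the projection vector `p` is passed in rather than learned, and the number of retained lesions is rounded up. The sketch covers only the forward computation; the text above does not specify how gradients reach \(\vec {p}\) through the hard selection.

```python
import math
import numpy as np

def self_prune(Z_hat, p, r=0.5):
    """Return (mask, scores, masked features) for one lesion graph.

    Z_hat: (n, d) enriched lesion features, p: (d,) projection vector,
    r: fraction of lesions to retain.
    """
    n = Z_hat.shape[0]
    # Importance score per lesion: sigmoid of the normalized projection.
    scores = 1.0 / (1.0 + np.exp(-(Z_hat @ p) / np.linalg.norm(p)))
    # Number of lesions to keep; the tiny offset guards against
    # floating-point error in n * r before rounding up.
    n_keep = math.ceil(n * r - 1e-9)
    keep = np.argsort(-scores, kind="stable")[:n_keep]
    # Binary mask M: 1 for retained lesions, 0 for pruned ones.
    mask = np.zeros(n)
    mask[keep] = 1.0
    return mask, scores, Z_hat * mask[:, None]
```

The returned `scores` are the per-lesion importance values that can be inspected for explainability, while the masked features are what the classification head would consume.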

It should be noted that the existence of multiple lesions is a typical characteristic of MS. Hence, a crucial aspect of MS management is that clinicians must identify signs of inflammatory disease activity in MR images to make treatment decisions [19]. Therefore, along with predicting inflammatory disease activity, interpreting the contribution of individual lesions to the prediction is essential. By assigning an importance score to each lesion, the SPM can provide explainability to clinicians at the lesion level, which existing CNN methods cannot.

Classification Head. The classification head consists of a readout layer that aggregates the features of all remaining nodes to produce a single feature vector \(\hat{z^\prime }\) for the entire graph. This graph-level feature is passed through an MLP to obtain the final prediction \(\hat{y}\). We train our model using a binary cross-entropy loss:

$$\begin{aligned} \mathcal {L}_{clf} = - \frac{1}{N} \sum _{i=1}^N {(y_i\log (\hat{y_i}) + (1 - y_i)\log (1 - \hat{y_i}))} \end{aligned}$$
(3)

where N is the number of patients, \(y_i\) is the ground truth inflammatory disease activity information and \(\hat{y_i}\) is the model prediction.
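A small NumPy sketch of Eq. 3, for illustration only. The function name `bce_loss` and the clipping constant `eps` are our assumptions; the clipping is added purely to avoid taking the logarithm of zero.

```python
import numpy as np

def bce_loss(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy over N patients (Eq. 3).

    y_true: ground-truth activity labels in {0, 1},
    y_prob: predicted probabilities in [0, 1].
    """
    y_true = np.asarray(y_true, dtype=float)
    # Clip predictions so that log() stays finite (an added safeguard).
    y_prob = np.clip(np.asarray(y_prob, dtype=float), eps, 1.0 - eps)
    per_patient = y_true * np.log(y_prob) + (1.0 - y_true) * np.log(1.0 - y_prob)
    return float(-np.mean(per_patient))
```

For instance, a model that outputs 0.5 for every patient incurs a loss of \(\log 2\) regardless of the labels, which gives a simple reference point when monitoring training.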

3 Experiments

Datasets and Image Preprocessing. Our approach is evaluated on a cohort of 430 MS patients collected following approval from the local IRB [7]. Patients included in this analysis were diagnosed with relapsing-remitting MS, with a maximum disease duration of three years at the time of the baseline scan. We collect the FLAIR and T1w MR scans for each patient. The scans have a uniform voxel size of 1\(\,\times \,\)1\(\,\times \,\)1 mm\(^3\), were rigidly co-registered to the MNI152 atlas, and were skull-stripped using HD-BET [17]. Three neuro-radiologists independently read longitudinal subtraction imaging, where FLAIR images from two time points were co-registered and subtracted. New and significantly enlarged lesions identified in this way are counted as positive inflammatory disease activity.

The dataset contains MS inflammatory disease activity information for clinically relevant one-year and two-year intervals [20]. At the end of the first year, we have the inflammatory disease activity status of 430 patients, with 303 showing activity and 127 not. Similarly, at the end of two years, we have data available for 347 patients, with 287 showing activity and 60 not. Thus, the dataset is imbalanced in favor of inflammatory disease activity (about 70% positive cases at one year and 83% at two years). This imbalance is a typical property of MS patient cohorts that impairs algorithm development.

Feature Extraction and Training Configuration. We use an nn-Unet [9] for lesion segmentation and detection. Then a uniform crop of size \(24 \times 24 \times 24 \) mm\(^3\) is extracted centered at each lesion. The cropped patches are passed through a transformer-based masked autoencoder [14] to extract self-supervised lesion features. The encoder produces a 768-dimensional feature vector for each patch. We also append normalized lesion coordinates to the encoder output to get the final lesion features.
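The cropping and coordinate-appending steps can be sketched as below. This is a minimal NumPy illustration with several assumptions of ours: the function names are hypothetical, patches that extend past the volume border are zero-padded, the lesion center is given in voxel indices, and coordinates are normalized by dividing by the volume shape (the normalization scheme is not specified above). With 1 mm isotropic voxels, a 24-voxel crop corresponds to 24 mm per side. The encoder itself is not reproduced; `embedding` stands in for its 768-dimensional output.

```python
import numpy as np

def extract_patch(volume, center, size=24):
    """Crop a size^3 patch centered at `center`, zero-padding at borders."""
    half = size // 2
    padded = np.pad(volume, half, mode="constant")
    # After padding by `half`, original index c sits at c + half, so the
    # window [c, c + size) in padded indices is centered on the lesion.
    z, y, x = (int(round(c)) for c in center)
    return padded[z:z + size, y:y + size, x:x + size]

def lesion_feature(embedding, center, volume_shape):
    """Append normalized lesion coordinates to an encoder embedding."""
    # Assumed normalization: voxel index divided by the volume extent.
    coords = np.asarray(center, dtype=float) / np.asarray(volume_shape, dtype=float)
    return np.concatenate([np.asarray(embedding, dtype=float), coords])
```

Under these assumptions, each lesion ends up with a 771-dimensional feature vector (768 encoder dimensions plus three coordinates).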

The lesions are connected using a k-nearest neighbor algorithm with \(k = 5\). Further, these connections (edges) are weighted using \(\tau = 0.01\) (Eq. 1). Two message-passing layers with hidden dimensions of 64 and 8, respectively, process the generated graph to enrich lesion features. Next, the enriched features are passed through the SPM. The SPM uses a learnable projection vector \(\vec {p} \in \mathbb {R}^8\) and sigmoid activation to learn the importance score. Based on this importance score, a mask is produced to select r = 0.5 (i.e., 50%) of the highest-scoring lesions and discard the rest. Next, a sum aggregation is used as the readout function. These aggregated features are passed through two feed-forward layers with a hidden dimension of 8. Finally, the features are passed to a sigmoid function to obtain the final prediction.

The model is trained for 300 epochs using the AdamW optimizer [13] with a weight decay of 0.0001. The base learning rate is 1e-4. The batch size is set to 16. A dropout layer with p = 0.5 is used between the feed-forward blocks. Since the dataset is imbalanced in favor of patients experiencing inflammatory disease activity, we use a balanced batch sampler to load approximately the same number of positive and negative samples in each mini-batch.
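One possible way to realize such a balanced batch sampler is sketched below using only the Python standard library. The mechanism is not described above, so everything here is an assumption: the function name `balanced_batches`, drawing each class with replacement, and defining an epoch as `len(labels) // batch_size` batches.

```python
import random

def balanced_batches(labels, batch_size=16, seed=0):
    """Yield lists of sample indices with equal positives and negatives."""
    rng = random.Random(seed)
    pos = [i for i, y in enumerate(labels) if y == 1]
    neg = [i for i, y in enumerate(labels) if y == 0]
    if not pos or not neg:
        raise ValueError("both classes must be present")
    n_pos = batch_size // 2
    n_neg = batch_size - n_pos
    for _ in range(len(labels) // batch_size):
        # Draw each class with replacement, so the minority class is
        # oversampled and the majority class is subsampled.
        batch = rng.choices(pos, k=n_pos) + rng.choices(neg, k=n_neg)
        rng.shuffle(batch)
        yield batch
```

With the one-year labels (303 positive, 127 negative), every mini-batch of 16 would then contain eight patients from each class, at the cost of repeating negative cases more often within an epoch.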

Evaluation Strategy, Classifier, and Metrics. We report our results on MS inflammatory disease activity prediction for the clinically relevant one- and two-year intervals [20]. The Area Under the Receiver Operating Characteristic Curve (AUC) is used as the evaluation metric. We use 80% of the samples as the training set and 10% as the validation set. The validation set is used to select the best model, which is then applied to the remaining 10% of cases. This procedure is iterated until all cases have been assigned to a test set once (ten-fold cross-validation). The same folds are used for the proposed model and baseline algorithms.
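The fold rotation described above can be sketched as follows. The text does not state how the validation fold is chosen or whether the folds are stratified; this sketch assumes unstratified random folds and uses the fold following the test fold for validation. The function name `ten_fold_splits` is hypothetical.

```python
import random

def ten_fold_splits(n_samples, n_folds=10, seed=0):
    """Yield (train, val, test) index lists; each sample is tested once."""
    idx = list(range(n_samples))
    random.Random(seed).shuffle(idx)
    # Partition the shuffled indices into n_folds disjoint folds.
    folds = [idx[i::n_folds] for i in range(n_folds)]
    for t in range(n_folds):
        v = (t + 1) % n_folds  # assumed choice of validation fold
        train = [i for f, fold in enumerate(folds) if f not in (t, v) for i in fold]
        yield train, folds[v], folds[t]
```

Fixing `seed` keeps the folds identical across runs, which is what allows the proposed model and the baselines to be compared on the same splits.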

4 Results

Quantitative Comparison. Table 1 shows the classification performance over the ten folds for one-year and two-year inflammatory disease activity prediction. The ± indicates the corresponding standard deviation. We compare our method against two existing baselines for MS inflammatory disease activity prediction: a 3D ResNet [28] and a multi-resolution CNN architecture [21]. These methods learn a direct mapping from the MR image to the inflammatory disease activity label. Our graph model outperforms the baseline methods on both one-year (0.67 vs. 0.61 AUC) and two-year (0.66 vs. 0.60 AUC) inflammatory disease activity prediction. In the following, we analyze and discuss each component of our framework.

Table 1. Comparison of our method against the existing CNN-based solutions. We report the AUC score for MS Inflammatory Disease Activity (IDA) prediction at the end of one and two years. Our proposed two-stage solution outperforms the existing baselines, achieving the best AUC score on both prediction tasks.

Ablation Study. In this section, we analyze the importance of different components of our proposed method. We defer the analysis of the lesion feature representation to the appendix owing to space constraints.

The Effectiveness of Graph Structure. Since the lesion feature extractor generates rich lesion features, one may argue that the graph structure is unwarranted. There are two alternatives to using a graph: (i) completely discard the graph structure, use a feed-forward layer to enrich the lesion features further, and aggregate them to perform classification [15] (since this formulation regards the input as a set, we call it the Set-Proc model); (ii) aggregate all the lesion features for a patient using a mean aggregation and process the aggregated feature with traditional machine learning algorithms such as random forest (RF), support vector machine (SVM) with the RBF kernel, and logistic regression (LR). Table 2 compares our model’s performance against these alternatives. Our proposed solution obtains a better AUC than the alternatives, indicating that incorporating the graph structure is beneficial for the eventual prediction.

Table 2. Effectiveness of the graph structure. We compare our method to traditional ML algorithms and a set-based aggregation baseline, both of which discard the graph structure. The incorporation of the graph structure is beneficial for downstream prediction.
Table 3. Importance of spatial proximity. GAT (partially) and TransformerConv (completely) ignore spatial proximity in the graph while GCN, EdgeConv, and GraphSAGE incorporate it. The incorporation of spatial proximity is beneficial for downstream prediction.

The Importance of Encoding Spatial Proximity. The spatial proximity in our model is encoded at two levels. First, lesion connectivity is determined using a kNN graph, and second, we weight the edges based on their distance. Graph convolution layers such as EdgeConv, GCN, and GraphSAGE take the edge weights into account (EdgeConv does so implicitly by taking a difference of lesion features that already contain spatial information).

On the other hand, the GAT model learns an attention weight and ignores the pre-defined edge weights. However, it still computes these coefficients on only the connected nodes. We can go further, completely ignore the distances and instead use a fully connected graph. The TransformerConv [18] on such a graph is equivalent to applying the well-known transformer encoder [22] on the inputs. Table 3 shows that the methods that ignore spatial proximity (TransformerConv) or do not use distance-based weighting (GAT) struggle. On the other hand, EdgeConv, GCN, and GraphSAGE work better. We use GCN in our model owing to its superior performance.

The Contribution of the Self-Pruning Module (SPM). The SPM selects a subset of lesions for the final prediction during the training and evaluation phases. However, the proposed classification method can work without it. In this case, none of the lesions is discarded during the readout operation. Table 4 shows the classification results with and without the SPM. We observe that including SPM leads to better outcomes across different message-passing networks. An explanation could be that the SPM can better handle patient variations (in terms of the total number of lesions) by operating on a subset of lesions (Fig. 3).

Table 4. Comparison of the performance of different message passing layers with and without the self-pruning module. Incorporating the self-pruning module is beneficial for most message passing layers.
Fig. 2. Model performance against different retention ratios r. The best performance is observed for \(r=0.5\).

Analysis of Hyperparameters. The retention ratio r and the number of neighbors k used for building the graph are the two critical hyperparameters in our proposed framework. We discuss the effect of the retention ratio r here and defer discussion about k to the appendix.

Effect of Retention Ratio r. The retention ratio \(r \in (0, 1]\) controls the fraction of lesions retained after the self-pruning module. If its value is set to 1, all lesions are retained for the final prediction, which bypasses the self-pruning module. Any other value implies that we ignore at least a few lesions in the readout layer. Since the number of lesions can vary across graphs, we retain \(\lceil n \cdot r \rceil \) lesions after the self-pruning layer, where n is the number of lesions in the graph. To find the optimal r, we test our model with r between 0.1 and 1.0. The results are summarized in Fig. 2. We set r to 0.5 for both tasks.
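As a worked example of this rounding with \(r = 0.5\), consider two hypothetical patients with n = 7 and n = 1 detected lesions:

$$\begin{aligned} n = 7:&\quad \lceil 7 \cdot 0.5 \rceil = \lceil 3.5 \rceil = 4, \\ n = 1:&\quad \lceil 1 \cdot 0.5 \rceil = \lceil 0.5 \rceil = 1. \end{aligned}$$

Because the count is rounded up and \(r > 0\), a patient with at least one detected lesion always retains at least one lesion for the readout.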

Fig. 3. Lesions selected by the SPM for two-year inflammatory disease activity prediction are highlighted with a green bounding box. We also show the zoomed-in view of the lesion. A concurrent lesion in the scan ignored by the SPM is shown with a blue bounding box. According to their size and location, the selected lesions are most likely to be relevant to the prediction. (Color figure online)

5 Conclusion

Predicting MS inflammatory disease activity is a clinically relevant, albeit challenging, task. In this work, we propose a two-stage graph-based pipeline that surpasses existing CNN-based methods by decoupling lesion detection from the learning of rich semantic lesion features. We also propose a self-pruning module that further improves model generalizability by handling variations in the number of lesions across patients. Most importantly, we frame MS inflammatory disease activity prediction as a graph classification problem. We hope our work provides a new perspective and stimulates further research at the intersection of graph processing and MS inflammatory disease activity prediction.