
1 Introduction

Substantial evidence suggests that many major psychiatric and neurological disorders are associated with aberrations in the network structure of the brain [5, 7]. With the availability of modern neuroimaging modalities such as diffusion tensor (DTI) and functional (fMRI) imaging, there is currently an exciting potential for researchers to identify connectivity-based biomarkers of disease states. Since brain networks are known to exhibit complex interactions, multivariate pattern analysis (MVPA) methods are particularly suitable here, as they aim to identify the site of the pathology by examining the data as a whole, accounting for the correlations among the network features.

In this work, we are interested in applying MVPA methods to diffusion-based structural connectomes (SCs) to identify the patterns of structural dysconnectivity induced by a brain disorder. However, due to the high dimensionality of SCs, standard MVPA methods such as the support vector machine (SVM) become prone to overfitting and thus tend to generalize poorly to test data. Even when generalizability is achieved, the SVM lacks clinical interpretability since it returns a dense, high-dimensional weight vector. One way to address this is to add an L1 regularizer to the SVM objective for feature selection [6], but this approach is known to perform poorly when the features are highly correlated. Dimensionality reduction thus becomes critical for improving both classification performance and interpretability. Well-established dimensionality reduction methods in neuroimaging include principal and independent component analysis (PCA and ICA). However, these approaches do not preserve the non-negativity of the SCs, and hence return global representations of the brain network that are highly overlapping and lack interpretability, since negative structural connections are biologically ill-defined.

Non-negative matrix factorization (NMF) [9] is a relatively recent method that addresses this problem by incorporating non-negativity as a constraint. This constraint leads to a more localized “parts-based” representation where the data is decomposed into purely additive combinations of non-negative basis components. For our work, the bases can be interpreted as data-driven subnetworks, and the corresponding coefficients provide a low-dimensional representation of the SC that can be used in a classifier.

However, despite its success, NMF possesses several limitations. First, NMF does not guarantee that the basis components are local and parts-based, i.e., the subnetworks may be global representations that are overlapping and redundant. Moreover, standard NMF and many of its variants are unsupervised, and thus ignore discriminative structure that may signify important group differences. Finally, NMF assumes that the data are sampled from a Euclidean space, and thus does not account for the intrinsic manifold structure underlying the data. While this last issue was addressed in a recent work by Ghanbari et al. [7] under a graph-embedding framework, their method is also unsupervised and thus ignores label information. Conversely, although supervised subnetwork detection frameworks have been introduced in some recent works [2, 8], these methods do not account for the manifold structure underlying the data.

To overcome these limitations, in this paper we introduce a novel supervised NMF framework for identifying an orthogonal set of subnetworks that is interpretable and emphasizes group differences in structural connectivity. The method also respects the intrinsic geometric structure in the data through manifold regularization [7, 10], which encourages subnetwork representations to be smooth with respect to the data manifold. To solve the proposed objective function, we introduce an optimization algorithm based on the alternating direction method (ADM), which has recently been demonstrated to solve NMF with superior performance over other state-of-the-art algorithms [12]. The proposed framework was evaluated on a TBI dataset, and the results demonstrate the interpretability and the discriminative capacity of the subnetworks.

2 Method

Projective NMF.

Let \( \varvec{X} = \left[ {\varvec{x}_{1} , \cdots ,\varvec{x}_{n} } \right] \) denote a set of training SCs \( \varvec{x}_{i} \in {\mathbb{R}}_{ + }^{p} \), \( i = 1, \cdots ,n \), and let \( \varvec{y} = \left[ {y_{1} , \cdots ,y_{n} } \right]^{T} \) denote the corresponding labels, where \( y_{i} \in \left\{ { \pm 1} \right\} \) indicates the label of subject \( i \). An SC is a vector representation of the brain network obtained via tractography, where each vector element represents the strength of the structural connection between a distinct pair of brain regions (see Sect. 3 for details). Given a target dimension \( r \ll p \), NMF learns a decomposition of the form \( \varvec{X} \approx \varvec{WH} \) by minimizing the Frobenius norm error \( \left\| {\varvec{X} - \varvec{WH}} \right\|_{F}^{2} \), where \( \varvec{W} = [\varvec{w}_{1} , \cdots ,\varvec{w}_{r} ] \in {\mathbb{R}}_{ + }^{p \times r} \) is the basis matrix and \( \varvec{H} = [\varvec{h}_{1} , \cdots ,\varvec{h}_{n} ] \in {\mathbb{R}}_{ + }^{r \times n} \) is the coefficient matrix. In the context of our work, the columns of \( \varvec{W} \) are connectivity bases that represent subnetworks.

Following [10], we assume that \( \varvec{H} \) is obtained from a linear projection of \( \varvec{X} \), i.e., \( \varvec{H} = \varvec{PX} \), where \( \varvec{P} \in {\mathbb{R}}_{ + }^{r \times p} \) is a nonnegative projection matrix that embeds the data onto the intrinsic subspace. Under this assumption, the objective function for NMF becomes

$$ \mathop {min}\limits_{{\varvec{W},\varvec{P} \ge 0}} \frac{1}{2}\left\| {\varvec{X} - \varvec{W}(\varvec{PX})} \right\|_{F}^{2} . $$
(1)

A key advantage of this projective NMF is that once an optimal projection \( \varvec{P}^{*} \) is learned by solving (1), the trained model readily generalizes to unseen data. That is, given a new test sample \( \varvec{x}^{*} \), we can immediately obtain its low-dimensional representation as \( \varvec{h}^{*} = \varvec{P}^{*} \varvec{x}^{*} \). This property is particularly important for running cross-validation (CV), since held-out samples can be embedded without re-fitting the factorization.
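To make this concrete, here is a minimal NumPy sketch of the projective embedding, with random stand-in matrices in place of the learned factors and real SCs:

```python
import numpy as np

rng = np.random.default_rng(0)
p, n, r = 3655, 111, 5                 # SC dimension, training scans, target dimension

X = rng.random((p, n))                 # non-negative training SCs (stand-in data)
P_star = rng.random((r, p))            # learned non-negative projection (stand-in)
W = rng.random((p, r))                 # learned subnetwork bases (stand-in)

H = P_star @ X                         # low-dimensional training representations
err = 0.5 * np.linalg.norm(X - W @ H, "fro") ** 2   # objective (1)

x_new = rng.random(p)                  # an unseen test SC
h_new = P_star @ x_new                 # embedded with no re-fitting, as used in CV
```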

Orthogonal NMF with Manifold Regularization and Label Information.

Despite the merits of projective NMF, it has three key deficiencies. First, it is often reported that NMF does not necessarily return meaningful parts-based decompositions on some datasets. Second, although many real-world data lie on a low-dimensional manifold, NMF assumes that the data are sampled from a Euclidean space, neglecting the intrinsic geometric structure of the data. Third, traditional NMF models are unsupervised and thus ignore the discriminative information available from the label groups.

In light of these limitations, we propose to include the following terms in our model:

  1. Orthogonality constraint: \( \varvec{F}_{1} \left( \varvec{W} \right) = I_{\Omega }(\varvec{W}) \), where \( \Omega := \left\{ \varvec{W} \in {\mathbb{R}}^{p \times r} \,|\, \varvec{W}^{T} \varvec{W} = \varvec{I}_{r} \right\} \) and \( I_{C} ( \cdot ) \) is the indicator function of a set \( C \): \( I_{C} \left( \varvec{W} \right) = 0 \) if \( \varvec{W} \in C \) and \( I_{C} \left( \varvec{W} \right) = \infty \) otherwise.

  2. Manifold regularization: \( \varvec{F}_{2} \left( \varvec{P} \right) = \frac{1}{2}\mathop \sum \limits_{i = 1}^{n} \mathop \sum \limits_{j = 1}^{n} \left\| \varvec{Px}_{i} - \varvec{Px}_{j} \right\|_{2}^{2} S_{ij} . \)

  3. Classification error: \( \varvec{F}_{3} \left( {\varvec{P},\varvec{\beta},b} \right) = \left\| {\varvec{y} - \left( {\varvec{PX}} \right)^{T}\varvec{\beta} - b1_{n} } \right\|_{2}^{2} \), where \( \varvec{\beta} \in {\mathbb{R}}^{r} \) and \( b \in {\mathbb{R}} \) define a hyperplane in the intrinsic subspace, and \( 1_{n} \in {\mathbb{R}}^{n} \) is a vector of all ones.

The \( \varvec{F}_{1} \) term constrains the basis matrix to reside within the set \( \Omega \), which is the set of orthogonal matrices known as the Stiefel manifold [11]. Since \( \varvec{W} \) is non-negative, orthogonality implies that the bases representing the subnetworks are non-overlapping, which enhances interpretability and eliminates redundancy.

The \( \varvec{F}_{2} \) term ensures smoothness of the low-dimensional representation with respect to the manifold structure encoded in the affinity matrix \( \varvec{S} \in {\mathbb{R}}^{n \times n} \). Intuitively, this regularizer preserves the intrinsic geometric structure of the data by encouraging the representations \( \varvec{Px}_{i} \) and \( \varvec{Px}_{j} \) to be close whenever \( S_{i,j} \) is large, i.e., whenever subjects \( i \) and \( j \) are similar under some notion of similarity. This regularizer can also be expressed in terms of the trace operator: \( \varvec{F}_{2} \left( \varvec{P} \right) = {\text{Tr}}\left( {(\varvec{PX})\varvec{L}(\varvec{PX})^{T} } \right) \), where \( \varvec{L} \in {\mathbb{R}}^{n \times n} \) is the graph Laplacian defined by \( \varvec{L} = \varvec{D} - \varvec{S} \), and \( \varvec{D} \) is the diagonal degree matrix with \( D_{i,i} = \sum_{j = 1}^{n} S_{i,j} \) for all \( i \). While the type of inter-subject relationship that can be encoded via the affinity matrix \( \varvec{S} \) is general, in this work we take advantage of the clinical scores used to evaluate patients and create a “disease-severity graph” to capture the disease-induced variation in the SCs. Specifically, we assign a higher value to \( S_{i,j} \) if subjects \( i \) and \( j \) share similar severity scores.
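The equivalence between the pairwise and trace forms is easy to verify numerically. Below is a minimal NumPy sketch (random stand-in data, not from the paper) that builds the Laplacian and checks the identity:

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, r = 20, 50, 3

X, P = rng.random((p, n)), rng.random((r, p))
S = rng.random((n, n)); S = (S + S.T) / 2          # symmetric affinity matrix
L = np.diag(S.sum(axis=1)) - S                     # graph Laplacian L = D - S

H = P @ X
# Pairwise form: (1/2) * sum_ij S_ij * ||Px_i - Px_j||^2
pairwise = 0.5 * sum(S[i, j] * np.sum((H[:, i] - H[:, j]) ** 2)
                     for i in range(n) for j in range(n))
trace = np.trace(H @ L @ H.T)                      # Tr((PX) L (PX)^T)
assert np.isclose(pairwise, trace)                 # the two forms agree
```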

Finally, the classification error term \( \varvec{F}_{3} \) enhances the discriminative power of NMF by encouraging the label groups in the low-dimensional embedding \( \varvec{PX} \) to be separated by a hyperplane \( \varvec{\beta} \) (for clarity, the intercept term \( b \) is dropped from our presentation hereafter). Thus, our proposed NMF model seeks to identify subnetwork bases that are not only reconstructive of the data but also discriminative of the label groups (note that the squared error is used here so that the ADM algorithm admits a closed-form solution).

Integrating the above constraint terms into the projective NMF Eq. (1) gives us our final objective function (\( \lambda_{1} ,\lambda_{2} \ge 0 \) below are regularization parameters):

$$ \mathop {\min }\limits_{{\varvec{W},\varvec{P} \ge 0,\varvec{\beta}}} \left\| {\varvec{X} - \varvec{W}\left( {\varvec{PX}} \right)} \right\|_{F}^{2} + \lambda_{1} {\text{Tr}}\left( {(\varvec{PX})\varvec{L}(\varvec{PX})^{T} } \right) + \lambda_{2} \left\| {\varvec{y} - \left( {\varvec{PX}} \right)^{T}\varvec{\beta}} \right\|_{2}^{2} + I_{\Omega } \left( \varvec{W} \right). $$
(2)

ADM Algorithm.

We now introduce an optimization algorithm based on ADM [12] for solving the proposed objective. Before applying ADM, we first convert objective function (2) into the following equivalent constrained form by introducing the auxiliary variables \( \{ \varvec{H},\tilde{\varvec{H}},\tilde{\varvec{W}}_{1} ,\tilde{\varvec{W}}_{2} ,\tilde{\varvec{P}}\} \) (a technique called variable splitting):

$$ \mathop {\min }\limits_{\substack{\varvec{W},\varvec{P},\varvec{H},\varvec{\beta}, \\ \tilde{\varvec{P}},\tilde{\varvec{H}},\tilde{\varvec{W}}_{1},\tilde{\varvec{W}}_{2}}} \left\| {\varvec{X} - \varvec{WH}} \right\|_{F}^{2} + \lambda_{1} {\text{Tr}}\left( {\tilde{\varvec{H}}\varvec{L}\tilde{\varvec{H}}^{T} } \right) + \lambda_{2} \left\| {\varvec{y} - \varvec{H}^{T}\varvec{\beta}} \right\|_{2}^{2} + I_{ + } \left( {\tilde{\varvec{W}}_{1} } \right) + I_{\Omega } \left( {\tilde{\varvec{W}}_{2} } \right) + I_{ + } \left( {\tilde{\varvec{P}}} \right) $$
$$ \text{such that}\quad \varvec{H} = \varvec{PX},\; \varvec{W} = \tilde{\varvec{W}}_{1} ,\; \varvec{W} = \tilde{\varvec{W}}_{2} ,\; \varvec{P} = \tilde{\varvec{P}},\; \varvec{H} = \tilde{\varvec{H}}, $$

where \( I_{ + } ( \cdot ) \) denotes the indicator function of the non-negative orthant. Although the auxiliary variables introduced from variable splitting may appear redundant, this strategy is commonly used in ADM frameworks (see [12] for example), as it allows the ADM subproblems to be solved in closed form. In the context of our work, the augmented Lagrangian (AL) function for the above constrained problem is given by:

$$ \begin{aligned} {\mathcal{L}}_{\text{AL}} & \left( {\varvec{W},\varvec{P},\varvec{\beta},\tilde{\varvec{P}},\varvec{H},\tilde{\varvec{H}},\tilde{\varvec{W}}_{1} ,\tilde{\varvec{W}}_{2} ,\varvec{\varLambda}_{{\tilde{\varvec{W}}_{1} }} ,\varvec{\varLambda}_{{\tilde{\varvec{W}}_{2} }} ,\varvec{\varLambda}_{{\tilde{\varvec{P}}}} ,\varvec{\varLambda}_{\varvec{H}} ,\varvec{\varLambda}_{{\tilde{\varvec{H}}}} } \right) = \left\| {\varvec{X} - \varvec{WH}} \right\|_{F}^{2} \\ & + \lambda_{1} {\text{Tr}}\left( {\tilde{\varvec{H}}\varvec{L}\tilde{\varvec{H}}^{T} } \right) + \lambda_{2} \left\| {\varvec{y} - \varvec{H}^{T}\varvec{\beta}} \right\|_{2}^{2} + I_{ + } \left( {\tilde{\varvec{W}}_{1} } \right) + I_{\Omega } \left( {\tilde{\varvec{W}}_{2} } \right) + I_{ + } \left( {\tilde{\varvec{P}}} \right) \\ & + \left\langle {\varvec{\varLambda}_{{\tilde{\varvec{W}}_{1} }} ,\varvec{W} - \tilde{\varvec{W}}_{1} } \right\rangle + \left\langle {\varvec{\varLambda}_{{\tilde{\varvec{W}}_{2} }} ,\varvec{W} - \tilde{\varvec{W}}_{2} } \right\rangle + \left\langle {\varvec{\varLambda}_{{\tilde{\varvec{P}}}} ,\varvec{P} - \tilde{\varvec{P}}} \right\rangle + \left\langle {\varvec{\varLambda}_{\varvec{H}} ,\varvec{H} - \varvec{PX}} \right\rangle + \left\langle {\varvec{\varLambda}_{{\tilde{\varvec{H}}}} ,\varvec{H} - \tilde{\varvec{H}}} \right\rangle \\ & + \frac{\rho}{2}\left\{ {\left\| {\varvec{W} - \tilde{\varvec{W}}_{1} } \right\|_{F}^{2} + \left\| {\varvec{W} - \tilde{\varvec{W}}_{2} } \right\|_{F}^{2} + \left\| {\varvec{P} - \tilde{\varvec{P}}} \right\|_{F}^{2} + \left\| {\varvec{H} - \varvec{PX}} \right\|_{F}^{2} + \left\| {\varvec{H} - \tilde{\varvec{H}}} \right\|_{F}^{2} } \right\}, \\ \end{aligned} $$

where \( \left\{ {\varvec{W},\varvec{P},\varvec{\beta},\tilde{\varvec{W}}_{1} ,\tilde{\varvec{W}}_{2} ,\tilde{\varvec{P}},\varvec{H},\tilde{\varvec{H}}} \right\} \) and \( \left\{ {\varvec{\varLambda}_{{\tilde{\varvec{W}}_{1} }} ,\varvec{\varLambda}_{{\tilde{\varvec{W}}_{2} }} ,\varvec{\varLambda}_{{\tilde{\varvec{P}}}} ,\varvec{\varLambda}_{\varvec{H}} ,\varvec{\varLambda}_{{\tilde{\varvec{H}}}} } \right\} \) are the primal and dual variables, \( \rho > 0 \) is the AL penalty parameter, and \( \langle \cdot , \cdot \rangle \) denotes the trace inner product. The ADM algorithm is derived by alternately minimizing \( {\mathcal{L}}_{\text{AL}} \) with respect to each primal variable while holding the others fixed, followed by a gradient ascent step on the dual variables. The overall ADM algorithm can be summarized as follows:

Repeat until convergence after variable initialization:

Primal updates (1): \( \varvec{P} \leftarrow \arg \min_{\varvec{P}} {\mathcal{L}}_{\text{AL}} \), \( \varvec{W} \leftarrow \arg \min_{\varvec{W}} {\mathcal{L}}_{\text{AL}} \), \( \varvec{H} \leftarrow \arg \min_{\varvec{H}} {\mathcal{L}}_{\text{AL}} \), \( \varvec{\beta} \leftarrow \arg \min_{\varvec{\beta}} {\mathcal{L}}_{\text{AL}} \).

Primal updates (2): \( \tilde{\varvec{P}} \leftarrow \arg \min_{\tilde{\varvec{P}}} {\mathcal{L}}_{\text{AL}} \), \( \tilde{\varvec{W}}_{1} \leftarrow \arg \min_{\tilde{\varvec{W}}_{1}} {\mathcal{L}}_{\text{AL}} \), \( \tilde{\varvec{W}}_{2} \leftarrow \arg \min_{\tilde{\varvec{W}}_{2}} {\mathcal{L}}_{\text{AL}} \), \( \tilde{\varvec{H}} \leftarrow \arg \min_{\tilde{\varvec{H}}} {\mathcal{L}}_{\text{AL}} \).

Dual updates: \( \varvec{\varLambda}_{{\tilde{\varvec{P}}}} \leftarrow \varvec{\varLambda}_{{\tilde{\varvec{P}}}} + \rho (\varvec{P} - \tilde{\varvec{P}}) \), \( \varvec{\varLambda}_{{\tilde{\varvec{W}}_{1} }} \leftarrow \varvec{\varLambda}_{{\tilde{\varvec{W}}_{1} }} + \rho (\varvec{W} - \tilde{\varvec{W}}_{1} ) \), \( \varvec{\varLambda}_{{\tilde{\varvec{W}}_{2} }} \leftarrow \varvec{\varLambda}_{{\tilde{\varvec{W}}_{2} }} + \rho (\varvec{W} - \tilde{\varvec{W}}_{2} ) \), \( \varvec{\varLambda}_{\varvec{H}} \leftarrow \varvec{\varLambda}_{\varvec{H}} + \rho \left( {\varvec{H} - \varvec{PX}} \right) \), \( \varvec{\varLambda}_{{\tilde{\varvec{H}}}} \leftarrow \varvec{\varLambda}_{{\tilde{\varvec{H}}}} + \rho (\varvec{H} - \tilde{\varvec{H}}) \).

The primal updates above can all be carried out efficiently in closed form:

\( \varvec{P} \leftarrow \left( {\varvec{HX}^{T} + \tilde{\varvec{P}} + \left[ {\varvec{\varLambda}_{\varvec{H}} \varvec{X}^{T} - \varvec{\varLambda}_{{\tilde{\varvec{P}}}} } \right]/\rho } \right)\left( {\varvec{XX}^{T} + \varvec{I}_{p} } \right)^{ - 1} \)

\( \varvec{W} \leftarrow \left( {\varvec{XH}^{T} + \rho \left[ {\tilde{\varvec{W}}_{1} + \tilde{\varvec{W}}_{2} } \right] - \varvec{\varLambda}_{{\tilde{\varvec{W}}_{1} }} - \varvec{\varLambda}_{{\tilde{\varvec{W}}_{2} }} } \right)\left( {\varvec{HH}^{T} + 2\rho \varvec{I}_{r} } \right)^{ - 1} \)

\( \varvec{H} \leftarrow \left( {\varvec{W}^{T} \varvec{W} + 2\rho \varvec{I}_{r} + \lambda_{2} \varvec{\beta \beta }^{T} } \right)^{ - 1} \left( {\varvec{W}^{T} \varvec{X} + \rho \varvec{PX} - \varvec{\varLambda}_{\varvec{H}} + \lambda_{2} \varvec{\beta y}^{T} } \right) \)

\( \varvec{\beta} \leftarrow \left( {\varvec{HH}^{T} } \right)^{ - 1} \varvec{Hy} \)

\( \tilde{\varvec{P}} \leftarrow \max \left( {0,\varvec{P} + \varvec{\varLambda}_{{\tilde{\varvec{P}}}} /\rho } \right) \)

\( \tilde{\varvec{W}}_{1} \leftarrow \max \left( {0,\varvec{W} + \varvec{\varLambda}_{{\tilde{\varvec{W}}_{1} }} /\rho } \right) \)

\( \tilde{\varvec{W}}_{2} \leftarrow {\text{Proj}}_{\Omega } \left( {\varvec{W} + \varvec{\varLambda}_{{\tilde{\varvec{W}}_{2} }} /\rho } \right) \)

\( \tilde{\varvec{H}} \leftarrow \left( {\rho \varvec{H} + \varvec{\varLambda}_{{\tilde{\varvec{H}}}} } \right)\left( {\lambda_{1} \varvec{L} + \rho \varvec{I}_{n} } \right)^{ - 1} \)
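As a concrete illustration, the following NumPy sketch performs the \( \varvec{P} \)-update above on random stand-in variables, solving the linear system rather than forming the inverse explicitly (the dimensions here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
p, n, r, rho = 50, 20, 3, 1000.0

X, H = rng.random((p, n)), rng.random((r, n))
P_t = rng.random((r, p))                               # current tilde-P
Lam_H, Lam_P = rng.random((r, n)), rng.random((r, p))  # dual variables

# P <- (H X^T + tilde-P + [Lam_H X^T - Lam_P]/rho) (X X^T + I_p)^{-1}
rhs = H @ X.T + P_t + (Lam_H @ X.T - Lam_P) / rho
A = X @ X.T + np.eye(p)                                # symmetric positive definite
P = np.linalg.solve(A, rhs.T).T                        # right-multiplication by A^{-1}
```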

Note that \( {\text{Proj}}_{\Omega } \left( \cdot \right) \) in the \( \tilde{\varvec{W}}_{2} \) update denotes the Euclidean projection of a matrix onto the Stiefel manifold. Letting \( \varvec{A} \in {\mathbb{R}}^{p \times r} \, (r \le p) \) denote a rank-\( r \) matrix, this projection is given by:

$$ {\text{Proj}}_{\Omega } \left( \varvec{A} \right) = \mathop {\arg \min }\limits_{{\varvec{Q} \in \Omega }} \left\| {\varvec{A} - \varvec{Q}} \right\|_{F}^{2} = \varvec{U}\left[ {\begin{array}{*{20}c} {\varvec{I}_{r} } \\ 0 \\ \end{array} } \right]\varvec{V}^{H} $$
(3)

Here \( \varvec{U}{\varvec{\Sigma}}\varvec{V}^{H} \) denotes the SVD of \( \varvec{A} \) and \( 0 \in {\mathbb{R}}^{(p - r) \times r} \) is a matrix of all zeros; solution (3) is unique as long as \( \varvec{A} \) has full column rank (see Proposition 7 in [11]).
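In practice, Eq. (3) reduces to computing a thin SVD and replacing the singular values by ones; a minimal NumPy sketch:

```python
import numpy as np

def proj_stiefel(A):
    """Euclidean projection of a full-column-rank matrix onto the Stiefel manifold, Eq. (3)."""
    U, _, Vh = np.linalg.svd(A, full_matrices=False)  # thin SVD: A = U diag(s) Vh
    return U @ Vh                                     # equivalent to U [I_r; 0] V^H

Q = proj_stiefel(np.random.default_rng(0).random((10, 4)))
assert np.allclose(Q.T @ Q, np.eye(4))                # columns are orthonormal
```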

3 Experiments and Conclusions

Dataset.

We apply our method to a TBI dataset consisting of 34 TBI patients and 32 age-matched controls. While the control subjects were scanned only once, the TBI patients were scanned and evaluated at three different time points: 3, 6, and 12 months post-injury. Of the 34 TBI patients, 18 had all three time points, 9 had two, and 7 had only one. The functional outcome of patients was evaluated using the Glasgow Outcome Scale Extended (GOSE) and the Disability Rating Scale (DRS), which are commonly used in TBI. GOSE ranges from 1 (dead) to 8 (good recovery), whereas DRS ranges from 0 (normal) to 29 (extreme vegetative state). In total, the dataset comprises 111 scans, 32 labeled control and 79 labeled TBI. All scans are accompanied by 11 clinical scores intended to assess the cognitive functioning of the subject.

Creating the SCs.

DTI data were acquired for each subject (Siemens 3T TrioTim, 8-channel head coil, single-shot spin echo sequence, TR/TE = 6500/84 ms, b = 1000 s/mm\(^{2}\), 30 gradient directions). 86 ROIs from the Desikan atlas were extracted to represent the nodes of the structural network. Probabilistic tractography [3] was performed from each of these regions with 100 streamline fibers sampled per voxel, resulting in an 86 × 86 matrix of weighted connectivity values, where each element represents the conditional probability of a pathway between two regions, normalized by the active surface area of the seed ROI. Finally, the 86 × 86 connectivity matrix of each subject was vectorized into its \( p = 3655 \) lower-triangular elements, resulting in \( \varvec{x} \in {\mathbb{R}}_{ + }^{p} \) representing the SC.
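Concretely, the vectorization step keeps only the strictly lower-triangular entries of the symmetric connectivity matrix; a NumPy sketch with a random stand-in matrix:

```python
import numpy as np

C = np.random.default_rng(0).random((86, 86))  # stand-in 86 x 86 connectivity matrix
C = (C + C.T) / 2                              # symmetrize: connectivity is undirected
rows, cols = np.tril_indices(86, k=-1)         # strictly lower-triangular indices
x = C[rows, cols]                              # the SC vector for one subject
assert x.size == 86 * 85 // 2                  # p = 3655
```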

Implementation Details.

We applied our method to the SCs computed from the TBI dataset to obtain the subnetwork bases and their corresponding NMF coefficients; here we let \( y = +1 \) indicate TBI and \( y = -1 \) indicate control. The disease-severity graph was created using the GOSE/DRS functional outcome indices as follows. First, we constructed a symmetrized k-nearest-neighbor (k-NN) graph with k = 5, where the distance between scans \( i \) and \( j \) was measured as \( d_{i,j} = \left( {\text{GOSE}}_{i} - {\text{GOSE}}_{j} \right)^{2} + \left( {\text{DRS}}_{i} - {\text{DRS}}_{j} \right)^{2} \). Then a binary affinity graph was created by setting \( S_{i,j} \) to 1 if and only if scans \( i \) and \( j \) were connected in the k-NN graph and did not belong to the same subject (to avoid connecting scans of the same TBI patient across multiple visits); control scans were left unconnected.
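A sketch of this construction using scikit-learn's kneighbors_graph is given below; the scan-to-subject array subject_id is a hypothetical variable introduced for illustration, and only the 79 TBI scans appear since controls are left unconnected:

```python
import numpy as np
from sklearn.neighbors import kneighbors_graph

rng = np.random.default_rng(0)
scores = rng.random((79, 2)) * [7, 29] + [1, 0]   # stand-in (GOSE, DRS) per TBI scan
subject_id = rng.integers(0, 34, 79)              # hypothetical scan-to-subject map

# k-NN graph (k = 5) under Euclidean distance on (GOSE, DRS); squared distances
# give the same neighbor ordering as d_ij above.
A = kneighbors_graph(scores, n_neighbors=5, mode="connectivity").toarray()
A = np.maximum(A, A.T)                            # symmetrize the k-NN graph

same = subject_id[:, None] == subject_id[None, :]
S = np.where(same, 0.0, A)                        # drop edges within one patient
np.fill_diagonal(S, 0.0)
```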

We identified r = 5 subnetwork bases using this affinity graph, with the regularization parameters set to \( \lambda_{1} = \lambda_{2} = 0.25 \), around which the model was stable (classification performance degraded when the parameters were set to \( \lambda_{1} = \lambda_{2} = 0 \), a setup equivalent to traditional NMF). To initialize the ADM variables, we used the strategy introduced in [4] to deterministically initialize \( \varvec{W} \) and \( \varvec{H} \), and set all other variables to zero for replicability. The AL penalty parameter was set to \( \rho = 1000 \) based on empirical test runs, and the ADM algorithm was terminated when the relative change in the objective function value (Eq. 2) between successive iterations fell below \( 10^{ - 4} \) and the following primal residual condition was met:

$$ \hbox{max} \left( {\frac{{\left\| {\varvec{W} - \tilde{\varvec{W}}_{1} } \right\|_{F} }}{{\left\| \varvec{W} \right\|_{F} }}, \frac{{\left\| {\varvec{W} - \tilde{\varvec{W}}_{2} } \right\|_{F} }}{{\left\| \varvec{W} \right\|_{F} }},\frac{{\left\| {\varvec{H} - \varvec{PX}} \right\|_{F} }}{{\left\| \varvec{H} \right\|_{F} }},\frac{{\left\| {\varvec{H} - \tilde{\varvec{H}}} \right\|_{F} }}{{\left\| \varvec{H} \right\|_{F} }},\frac{{\left\| {\varvec{P} - \tilde{\varvec{P}}} \right\|_{F} }}{{\left\| \varvec{P} \right\|_{F} }} } \right) < 10^{ - 4} . $$
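This stopping rule translates directly into code; a small NumPy sketch of the residual check (variable names are illustrative):

```python
import numpy as np

def primal_residual_met(W, W1_t, W2_t, H, H_t, P, P_t, X, tol=1e-4):
    """Relative primal residual check used (with the objective-change test) to stop ADM."""
    fro = np.linalg.norm                 # default matrix norm is Frobenius
    res = max(fro(W - W1_t) / fro(W),
              fro(W - W2_t) / fro(W),
              fro(H - P @ X) / fro(H),
              fro(H - H_t) / fro(H),
              fro(P - P_t) / fro(P))
    return res < tol
```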

To remove features that are likely non-biological, we applied feature selection using the aforementioned 11 clinical scores. Specifically, we first correlated each individual SC feature with each clinical score to obtain 11 separate p-value rankings (rank 1 = smallest), and summed these rankings to obtain a rank-sum value for each feature. We then selected the 1000 features with the smallest rank-sums and standardized them via linear scaling to the range [0, 1]. These feature selection and standardization procedures were conducted within the CV folds to avoid biasing the classification performance.
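A sketch of the rank-sum selection logic is given below (NumPy/SciPy, random stand-in data; in our pipeline this runs inside each CV fold on training scans only):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
feats = rng.random((111, 3655))       # scans x SC features (stand-in data)
scores = rng.random((111, 11))        # the 11 clinical scores (stand-in data)

# p-value of the correlation between every feature and every clinical score
pvals = np.array([[stats.pearsonr(feats[:, f], scores[:, c])[1]
                   for c in range(scores.shape[1])]
                  for f in range(feats.shape[1])])
ranks = pvals.argsort(axis=0).argsort(axis=0) + 1   # rank 1 = smallest p-value
rank_sum = ranks.sum(axis=1)                        # one rank-sum per feature
keep = np.argsort(rank_sum)[:1000]                  # 1000 smallest rank-sums

sel = feats[:, keep]
sel = (sel - sel.min(0)) / (sel.max(0) - sel.min(0))  # linear scaling to [0, 1]
```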

We compared the performance of the following classifiers (implemented using Liblinear [6]). The first three methods were applied to the 1000 features selected using the above procedure: (1) L1-loss, L2-regularized SVM (SVM); (2) L2-loss, L1-regularized SVM (SVM + L1); (3) L1-regularized logistic regression (LogReg + L1); and (4) L1-loss, L2-regularized SVM applied to the projected NMF coefficients from our method. A weighted loss function was used for all classifiers, where the weight assigned to each label class is inversely proportional to the class frequency. Since subjects have multiple timepoints, classification accuracy was assessed using a leave-one-subject-out CV (LOSO-CV) procedure, where all scans from a test subject are left out of training. Finally, the hyperparameter C, common to all classifiers, was tuned via an internal LOSO-CV over the range \( C \in \{ 2^{ - 10} , 2^{ - 9} , \cdots , 2^{10} \} \).
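The nested LOSO-CV protocol can be sketched with scikit-learn as follows; LinearSVC with hinge loss stands in for the Liblinear L1-loss, L2-regularized SVM, and the data, labels, and scan-to-subject mapping are random stand-ins:

```python
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut, GridSearchCV
from sklearn.svm import LinearSVC

rng = np.random.default_rng(0)
X = rng.random((111, 5))                              # projected NMF coefficients (stand-in)
y = np.r_[np.ones(79), -np.ones(32)]                  # +1 = TBI, -1 = control
groups = np.r_[rng.integers(0, 34, 79), np.arange(34, 66)]  # stand-in scan-to-subject map

Cs = {"C": [2.0 ** k for k in range(-10, 11)]}
logo = LeaveOneGroupOut()
correct = 0
for tr, te in logo.split(X, y, groups):               # outer LOSO-CV over subjects
    inner = GridSearchCV(LinearSVC(loss="hinge", class_weight="balanced"),
                         Cs, cv=LeaveOneGroupOut())   # internal LOSO-CV tunes C
    inner.fit(X[tr], y[tr], groups=groups[tr])
    correct += (inner.predict(X[te]) == y[te]).sum()
print("LOSO-CV accuracy:", correct / len(y))
```

Here class_weight="balanced" implements the weighting inversely proportional to class frequency described above.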

3.1 Experimental Results and Conclusions

Classification Results.

Table 1 reports the LOSO-CV classification results for the different methods, showing overall accuracy, specificity (one minus the type I error rate), sensitivity (one minus the type II error rate), and the balanced score rate (BSR), i.e., the mean of specificity and sensitivity. The results show that the proposed subnetwork features yield a noticeable improvement over the SC features in their original form, achieving an accuracy of 82.0% and a BSR of 81.8%. The SVM achieves the next best performance, but the model is hard to interpret since all 1000 edge features contribute to the classifier. Finally, despite the weighted loss function, the sparsity-promoting L1-regularized classifiers suffer from low sensitivity, likely caused by the label imbalance in the data as well as the correlated structure among the features (a setting where L1 regularization is known to perform poorly).

Table 1. Classification results from leave-one-subject-out cross-validation.

Effect of Manifold Regularization.

We next assessed whether the manifold regularizer with the disease-severity graph successfully preserved the inter-patient relationships in terms of the GOSE/DRS functional outcome indices. To do this, we computed Spearman's rank correlation between the subnetwork basis coefficients and the GOSE/DRS indices over the 79 TBI scans. The results reported in Table 2 reveal that all basis coefficients attain statistically significant correlations, consistently positive with GOSE and consistently negative with DRS. This result indicates that subjects with similar levels of disease severity share similar representations in the embedding space, demonstrating the impact of manifold regularization.
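This analysis amounts to a per-basis Spearman correlation; a SciPy sketch with random stand-in coefficients and outcome scores:

```python
import numpy as np
from scipy.stats import spearmanr

rng = np.random.default_rng(0)
H_tbi = rng.random((5, 79))          # r = 5 basis coefficients for the 79 TBI scans (stand-in)
gose = rng.integers(1, 9, 79)        # stand-in GOSE scores (1-8)
drs = rng.integers(0, 30, 79)        # stand-in DRS scores (0-29)

for k in range(5):
    rho_g, p_g = spearmanr(H_tbi[k], gose)
    rho_d, p_d = spearmanr(H_tbi[k], drs)
    print(f"basis {k+1}: GOSE rho={rho_g:+.2f} (p={p_g:.3f}), "
          f"DRS rho={rho_d:+.2f} (p={p_d:.3f})")
```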

Table 2. Spearman’s correlation coefficients and corresponding p values between the r = 5 subnetwork basis coefficients and DRS/GOSE severity scores among TBI patients.

Subnetwork Visualization.

Given the high predictive capacity of the subnetwork coefficients, we next examine their corresponding subnetwork bases \( \varvec{W} = \left[ {\varvec{w}_{1} , \cdots ,\varvec{w}_{5} } \right] \) to assess the pathological impact TBI may have on structural connectivity. For visualization and interpretation, we retrained the proposed NMF model on the entire dataset and learned an SVM hyperplane \( \varvec{\beta} \in {\mathbb{R}}^{5} \) in the corresponding embedding space. The resulting subnetworks are rendered in 3-D brain space in Fig. 1 (figures generated using the Python module Nilearn [1]); the color of each edge represents the sign of the corresponding hyperplane coefficient in \( \varvec{\beta} \), with red indicating contribution towards TBI (positive) and blue indicating contribution towards control (negative). From the figure, we can see that the network structure of the first basis exhibits strong bilateral symmetry, with notable inter-hemispheric connections between the cerebellar, precuneus, and cingulate regions. The second subnetwork basis comprises dense inter-hemispheric connections among the subcortical regions, with the sign indicating that these edges tend to be weaker among TBI patients. On the other hand, subnetwork bases 3–5 represent connections that contribute towards the TBI class. Overall, the subnetworks exhibit a diffuse connectivity pattern that spans the cortex, suggesting that damage from TBI results in widespread disturbance of the brain network. Interestingly, the first two bases exhibit rich connectivity patterns within the subcortical and medial posterior regions, which are frequently reported to be vulnerable in TBI.

Fig. 1.

The subnetwork bases obtained with \( r = 5 \). The edge color represents the sign of the corresponding entry of the hyperplane vector \( \varvec{\beta} \in {\mathbb{R}}^{r} \) (blue = negative/control, red = positive/TBI).

Conclusions.

We have presented a supervised NMF framework for extracting a disjoint set of subnetworks that are interpretable and highlight group differences in structural connectivity. The method also preserves the manifold structure of the data encoded by an affinity graph, thereby respecting the intrinsic geometry of the data. Experiments on a TBI dataset show that the subnetworks identified by our method can not only be used to reliably discriminate TBI from controls, but also correlate tightly with TBI outcome indices, indicating that subjects with similar levels of TBI severity share similar subnetwork representations, an effect of the manifold regularization.