1 Introduction

Diffusion magnetic resonance imaging (dMRI) [1] uniquely enables mapping of the brain’s white matter fiber tracts via tractography [2], to study the brain’s connections in health and disease [9]. Tractography of a single brain can generate hundreds of thousands of streamlines (fibers), which are not immediately useful to clinicians or researchers. Therefore, tractography parcellation, i.e. dividing the massive number of tractography fibers into multiple subdivisions, is needed.

One widely used tractography parcellation strategy, white matter fiber clustering (WMFC), groups fiber streamlines with similar geometric trajectory into clusters [22]. WMFC is useful in applications such as disease classification [42], anatomical tract identification [33] and neurosurgical brain mapping [29]. In general, WMFC first computes pairwise fiber geometric similarities, then applies a computational clustering method to group similar fibers into clusters [7, 31, 43]. Existing WMFC methods show good performance, but key challenges remain. First, it is computationally expensive to compute pairwise fiber geometric similarities. Second, the computation of fiber similarity is sensitive to the order of points along the fibers, even though a fiber can equivalently start from either end [7]. Third, false positive fibers are prevalent in tractography; thus outlier fiber removal is needed to filter undesired fibers from the clustering result [16, 18]. Fourth, it is a challenge for WMFC to use all available information to improve cluster anatomical quality: most methods use either fiber spatial coordinate information [7, 43] or anatomical information about brain regions that fibers pass through [27]. Fifth, WMFC methods should ideally consider inter-subject correspondence of fiber clusters, which is essential for group-wise analysis [21]. To achieve this goal, some studies perform WMFC across subjects (to form an atlas) and predict clusters of new subjects with correspondence to the atlas [23, 38, 39], while other approaches first perform within-subject WMFC then match (or cluster) the fiber clusters across subjects [7, 10, 13, 27].

In computer vision, clustering has been extensively studied as an unsupervised learning task [3, 11, 28, 34, 37], which requires a data feature representation and similarity computation between the features for cluster assignment. Autoencoder-based approaches are widely used for unsupervised clustering [11, 28, 34]. The Deep Embedding Clustering (DEC) framework performs simultaneous embedding of input data and cluster assignments in an end-to-end way [34]. Deep Convolutional Embedded Clustering (DCEC) is an extension of DEC to the image clustering task [11]. In addition to autoencoder approaches, [3] and [37] also achieved joint embedding learning and cluster assignment by alternating between feature learning and traditional clustering, which is time consuming.

Self-supervised learning is a promising subclass of unsupervised learning that shows advanced performance in many applications [15, 25]. It aims to learn high-level features without requiring manual annotations. This is achieved by designing pretext tasks, such as predicting context [5] or image rotation [8], and giving the network pseudo annotations generated from the input itself. The high-level representations learned from the pretext task can then be transferred to downstream tasks such as clustering. Therefore, besides the classical autoencoder network, the self-supervised learning framework can also be a promising approach to learn deep embeddings of inputs.

Considering the advances of deep neural networks in feature extraction, deep learning is a promising direction for WMFC. In related work, multiple deep learning methods have been proposed for white matter tractography segmentation [12, 32, 35, 40]. In [12, 32, 40], known fiber labels are provided for training. One proposed method [35] has shown the potential of unsupervised deep learning for fiber clustering; however, the anatomical utility of this approach was not tested, as results were limited to a maximum of 11 clusters in the whole brain. The goal of our study is to propose an anatomically meaningful unsupervised deep learning framework, Deep Fiber Clustering (DFC), for fast and effective white matter fiber clustering. The paper has four contributions. First, we propose a novel deep learning pipeline that adopts self-supervised learning for deep embedding and achieves joint representation learning and cluster assignment. Second, anatomical information is incorporated into the neural network to improve cluster anatomical coherence. Third, outliers are removed by rejecting fibers with low soft label assignment probabilities. Our approach automatically creates a multi-subject fiber cluster atlas that is applied for white matter parcellation of new subjects. Finally, our approach has demonstrated superior performance and efficiency via evaluations on a large-scale dataset.

Fig. 1. Overview of our DFC framework. A self-supervised learning strategy is adopted with the pretext task of pairwise fiber distance prediction. In the pretraining stage, a pair of FiberMaps are encoded as embeddings with Siamese Networks, and prediction loss (\(L_p\)) is calculated based on the difference between embedding distance and fiber distance. In the clustering stage, a clustering layer is connected to the embedding layer and generates soft label assignment (as shown in the dashed box). A KL divergence loss (\(L_c\)) and the prediction loss are combined to optimize the neural network.

2 Methods

As shown in Fig. 1, our training pipeline includes two stages, pretraining and clustering. In the pretraining stage (Sect. 2.1), a CNN is trained in a self-supervised way with a designed pretext task to obtain deep embeddings. After that, \( k \)-means clustering is performed on the embeddings to get initial clusters, which is performed only once during training. In the clustering stage (Sect. 2.2), the clustering results are fine-tuned in a self-learning manner and cluster centroids are automatically optimized as parameters of the network. During network inference, when the model is applied to a new subject, cluster assignments are obtained from the network directly in an end-to-end way without any \( k \)-means clustering.

In this work, we adopt the FiberMap fiber representation [40], which was found to be effective for tractography segmentation in supervised learning. One benefit of using FiberMap is that it is a 2D multi-channel feature descriptor (analogous to an RGB image); thus it can be effectively processed by CNNs.

2.1 Self-supervised Deep Embedding

We propose a novel self-supervised learning strategy for learning deep fiber embeddings. The goal is to obtain embeddings whose pairwise distances are similar to the fiber distances in the brain space, enabling subsequent WMFC in the embedding space. (We note that a DCEC model with a convolutional autoencoder could be adopted here for unsupervised WMFC, but as we show in the Results, this straightforward approach is sensitive to fiber point ordering.) To learn the embeddings, a pretext task is first designed to predict the distance between a pair of input fibers. Specifically, the input to the network is the FiberMaps of a fiber pair and a pseudo annotation of the fiber pair distance. For the pairwise fiber distance, we use the minimum average direct-flip (MDF) distance, which has been widely and successfully used in WMFC [7, 43]. The MDF computation considers both the direct and the flipped order of points along the fibers and takes the minimum; thus, the fiber distance is not affected if a fiber point sequence is flipped. A Siamese Network [4], a neural network that encodes different inputs and computes comparable outputs with shared weights, is then adopted to learn embeddings of an input FiberMap pair and output the Euclidean distance between the embeddings. The distance prediction loss \(L_p\) is the mean squared error between the embedding distances and the fiber-distance pseudo annotations. By using fiber distances as pseudo annotations, the network is guided to generate similar embeddings for close fibers, even those with flipped point orders.
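The pretext task can be illustrated with a minimal sketch in plain Python. It assumes that both fibers are sampled with the same number of points and computes the MDF distance as the smaller of the mean point-wise distances in direct and flipped order. The function names (`mdf_distance`, `prediction_loss`) and the toy fibers are hypothetical and chosen only for illustration; the embeddings here are hand-written tuples rather than outputs of the Siamese network.

```python
import math


def mean_pointwise_distance(a, b):
    """Average Euclidean distance between corresponding points of two fibers."""
    return sum(math.dist(p, q) for p, q in zip(a, b)) / len(a)


def mdf_distance(a, b):
    """Minimum average direct-flip distance between two equally sampled fibers."""
    if len(a) != len(b):
        raise ValueError("fibers must have the same number of points")
    direct = mean_pointwise_distance(a, b)
    flipped = mean_pointwise_distance(a, b[::-1])
    return min(direct, flipped)


def prediction_loss(embedding_pairs, fiber_pairs):
    """Mean squared error between embedding distances and MDF pseudo annotations."""
    errors = [
        (math.dist(za, zb) - mdf_distance(fa, fb)) ** 2
        for (za, zb), (fa, fb) in zip(embedding_pairs, fiber_pairs)
    ]
    return sum(errors) / len(errors)


# Toy data (assumed): a straight fiber and a copy shifted by 1 along y.
fiber = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0)]
shifted = [(x, y + 1.0, z) for x, y, z in fiber]

d_same_order = mdf_distance(fiber, shifted)
d_flipped_order = mdf_distance(fiber, shifted[::-1])
```

In this toy case `d_same_order` and `d_flipped_order` are equal, which is the property that lets the pseudo annotation guide the network towards similar embeddings for a fiber and its flipped copy.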

2.2 Clustering Layer and Clustering Loss

Here we adopt the DCEC model design [11]. In the clustering stage, a clustering layer is designed to encapsulate cluster centroids as its trainable weights and compute a soft assignment label \(q_{ij}\) using Student’s t-distribution [17, 34]:

$$\begin{aligned} q_{ij} = \frac{(1+\left\| z_i-\mu _j \right\| ^2)^{-1}}{\sum _{j'} (1+\left\| z_i-\mu _{j'} \right\| ^2)^{-1}} \end{aligned}$$
(1)

where \(z_i\) is the embedding of fiber i and \(\mu _j\) is the centroid of cluster j. \(q_{ij}\) is the probability of assigning fiber i to cluster j. The network is trained in a self-training manner and its clustering loss \(L_c\) is defined as a KL divergence loss [34]: \(L_c = KL(P\,\Vert \,Q) = \sum _{i} \sum _{j} p_{ij}\log \frac{p_{ij}}{q_{ij}}\), where \(p_{ij} = \frac{q^2_{ij}/\sum _{i} q_{ij}}{\sum _{j'} (q^2_{ij'}/\sum _{i} q_{ij'})}\). The distance prediction loss is retained in this stage, and the total loss is \(L=L_p + \lambda L_c\), where \(\lambda \) is the weight of \(L_c\). During inference, a fiber i is assigned to the cluster with the maximum \(q_{ij}\).
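The soft assignment of Eq. (1), the target distribution \(p_{ij}\) and the KL clustering loss can be sketched as follows. This is a plain-Python illustration under stated assumptions, not the network implementation: embeddings and centroids are small hand-written tuples, and the function names are hypothetical.

```python
import math


def soft_assignments(embeddings, centroids):
    """Eq. (1): Student's t kernel, normalized over clusters for each fiber."""
    q = []
    for z in embeddings:
        kernel = [1.0 / (1.0 + math.dist(z, mu) ** 2) for mu in centroids]
        total = sum(kernel)
        q.append([k / total for k in kernel])
    return q


def target_distribution(q):
    """Targets p_ij proportional to q_ij^2 / f_j, with f_j = sum_i q_ij."""
    n_clusters = len(q[0])
    freq = [sum(row[j] for row in q) for j in range(n_clusters)]
    p = []
    for row in q:
        weights = [row[j] ** 2 / freq[j] for j in range(n_clusters)]
        total = sum(weights)
        p.append([w / total for w in weights])
    return p


def kl_clustering_loss(p, q):
    """L_c = KL(P || Q), summed over fibers and clusters."""
    return sum(
        pij * math.log(pij / qij)
        for p_row, q_row in zip(p, q)
        for pij, qij in zip(p_row, q_row)
        if pij > 0.0
    )


# Toy 2D embeddings and two centroids (assumed values).
embeddings = [(0.0, 0.0), (0.2, 0.1), (3.0, 3.0)]
centroids = [(0.0, 0.0), (3.0, 3.0)]

q = soft_assignments(embeddings, centroids)
p = target_distribution(q)
loss_c = kl_clustering_loss(p, q)
# Hard labels at inference: the cluster with the maximum q_ij.
labels = [max(range(len(row)), key=row.__getitem__) for row in q]
```

The target distribution sharpens confident assignments (each \(p_{ij}\) of the winning cluster is larger than the corresponding \(q_{ij}\) in this toy case), which is what drives the self-training.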

2.3 Incorporation of Anatomical Information and Outlier Removal

We extend our proposed self-learning framework described above to enable two important tasks in WMFC, i.e., inclusion of additional anatomical information for anatomical coherence and removal or filtering of false positive outlier fibers. For the first task, we propose to incorporate FreeSurfer parcellation [6] information during the clustering stage. We design a new soft label assignment probability definition, which is used to calculate the loss and extends Eq. (1) so that fibers within a cluster are further encouraged to pass through the same brain regions:

$$\begin{aligned} q_{ij} = \frac{(1+\left\| z_i-\mu _j \right\| ^2 \cdot (1-D_{ij}))^{-1}}{\sum _{j'} (1+\left\| z_i-\mu _{j'} \right\| ^2 \cdot (1-D_{ij'}))^{-1}} \end{aligned}$$
(2)

where \(D_{ij}\) is the Dice score between the set of FreeSurfer regions of fiber i and the set of FreeSurfer regions of cluster j. We use the Tract Anatomical Profile (TAP) proposed in [43] to define the set of FreeSurfer regions commonly intersected by the fibers in a cluster (at least 40% of fibers, as in [43]). During training, the TAP is initially calculated from the clusters generated by \( k \)-means and is updated iteratively with new predictions during the training process. During inference, soft label assignments are calculated with Eq. (2) and fibers are assigned to the cluster with maximum \(q_{ij}\), referred to as \(q_m\).
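A small sketch of Eq. (2) may help to see how the Dice term acts. It assumes that squared embedding distances are already available, that region sets are Python sets of hypothetical labels, and that the Dice score of two empty sets is taken to be 1 (a convention chosen here, not stated in the text).

```python
def dice(a, b):
    """Dice overlap between two sets of region labels."""
    if not a and not b:
        return 1.0  # assumed convention for two empty sets
    return 2.0 * len(a & b) / (len(a) + len(b))


def anatomical_soft_assignments(sq_dists, fiber_regions, cluster_taps):
    """Eq. (2): squared embedding distances are scaled by (1 - Dice)."""
    q = []
    for d_row, regions in zip(sq_dists, fiber_regions):
        kernel = [
            1.0 / (1.0 + d * (1.0 - dice(regions, tap)))
            for d, tap in zip(d_row, cluster_taps)
        ]
        total = sum(kernel)
        q.append([k / total for k in kernel])
    return q


# One fiber that is equally far (squared distance 4) from two centroids.
sq_dists = [[4.0, 4.0]]
fiber_regions = [{"region_A", "region_B"}]                 # hypothetical labels
cluster_taps = [{"region_A", "region_B"}, {"region_C"}]    # hypothetical TAPs

q = anatomical_soft_assignments(sq_dists, fiber_regions, cluster_taps)
```

With Eq. (1) this fiber would receive probability 0.5 for each cluster. With Eq. (2) the full region overlap with the first cluster removes its distance penalty, so `q[0]` becomes 5/6 versus 1/6.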

For outlier removal, we remove fibers using the maximum label assignment probability \(q_m\), considering that fibers with higher \(q_m\) tend to have more confidence of belonging to the corresponding cluster and are less likely to be outliers. Therefore, we remove outliers by setting a threshold h on the \(q_m\) values of fibers, meaning that fibers with \(q_m < h\) will be rejected from the final clusters.
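The rejection rule can be sketched in a few lines. The function name and the toy probabilities are hypothetical, and the threshold used here is illustrative for three clusters only.

```python
def assign_and_filter(q, threshold):
    """Return kept (fiber index, cluster index) pairs and rejected fiber indices."""
    kept, rejected = [], []
    for i, row in enumerate(q):
        label = max(range(len(row)), key=row.__getitem__)
        q_m = row[label]  # maximum soft label assignment probability
        if q_m < threshold:
            rejected.append(i)
        else:
            kept.append((i, label))
    return kept, rejected


# Toy soft assignments for three fibers and three clusters (assumed values).
q = [
    [0.90, 0.05, 0.05],
    [0.34, 0.33, 0.33],  # no cluster is clearly preferred
    [0.10, 0.20, 0.70],
]
kept, rejected = assign_and_filter(q, threshold=0.5)
```

A suitable threshold depends on the number of clusters: with n clusters a uniform assignment gives \(q_m = 1/n\), so for 800 clusters the uniform value is 0.00125 and h must be chosen on that scale.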

2.4 Implementation Details

As shown in Fig. 1, our model architecture includes three convolutional layers of sizes 5 \(\times \) 5 \(\times \) 32, 5 \(\times \) 5 \(\times \) 64 and 3 \(\times \) 3 \(\times \) 128, respectively, to extract feature maps. These feature maps are flattened to a vector, followed by a fully connected layer to compute embeddings with a dimension of 10 (as suggested in [11]). In the pretraining and clustering stages, the network is trained for 25000 iterations with a learning rate of 0.0001 and another 4000 iterations with a learning rate of 0.00001, which is sufficient to achieve training convergence. Admax [14] is used for optimization in both stages. All experiments are performed on an NVIDIA RTX 2080Ti GPU using PyTorch (v1.7.1) [26]. The weight of the clustering loss \(\lambda \) is set to 0.1, as suggested in [11]. We set the threshold h for outlier removal to 0.015 to reject fibers with extremely low cluster assignment probabilities.

3 Experiments and Results

3.1 Dataset

In our experiments, we used a dataset of 200 healthy adults from the Human Connectome Project [30]. 100 subjects were used for training, 50 for validation and 50 for testing. Tractography data were generated using a two-tensor unscented Kalman filter (UKF) method [19], and tractography co-registration was performed using an affine followed by a nonrigid registration [24]. Fibers longer than 40 mm were retained to avoid any bias towards implausible short fibers. For each training subject, 10,000 fibers were randomly selected, generating a training dataset of 1 million samples. For testing and validation, all whole-brain tractography fibers were used (around 500,000 per subject). Fibers were downsampled to 14 points [40] to obtain the FiberMap input to the neural network. We performed diffusion MRI tractography and visualization in 3D Slicer (www.slicer.org) via the SlicerDMRI project (http://dmri.slicer.org) [20, 41].
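The text does not state how fibers are downsampled to 14 points. One common choice, assumed here purely for illustration, is to resample each streamline at equal arc-length intervals with linear interpolation. The function name `resample_fiber` is hypothetical, and the sketch does not build the FiberMap itself.

```python
import math


def resample_fiber(points, n_points=14):
    """Resample a polyline to n_points equally spaced along its arc length."""
    if len(points) < 2 or n_points < 2:
        raise ValueError("need at least two input points and two output points")
    # Cumulative arc length at each input point.
    cumulative = [0.0]
    for p, q in zip(points, points[1:]):
        cumulative.append(cumulative[-1] + math.dist(p, q))
    total = cumulative[-1]
    resampled = []
    segment = 0
    for k in range(n_points):
        target = total * k / (n_points - 1)
        # Advance to the segment that contains the target arc length.
        while segment < len(points) - 2 and cumulative[segment + 1] < target:
            segment += 1
        start, end = points[segment], points[segment + 1]
        length = cumulative[segment + 1] - cumulative[segment]
        t = 0.0 if length == 0.0 else (target - cumulative[segment]) / length
        resampled.append(tuple(a + t * (b - a) for a, b in zip(start, end)))
    return resampled


# Toy streamline with an L shape (assumed coordinates in mm).
streamline = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (1.0, 1.0, 0.0)]
three_points = resample_fiber(streamline, n_points=3)
fourteen_points = resample_fiber(streamline)
```

The endpoints of the original streamline are preserved, and every resampled fiber has the same number of points, which is what point-wise distances such as MDF require.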

3.2 Evaluation Metrics

Three evaluation metrics were adopted to quantify performance of our proposed method and enable comparisons among approaches. The first one is the Davies–Bouldin (DB) index [36], which is computed as:

$$\begin{aligned} DB = \frac{1}{n} \sum _{i=1}^{n} \max _{j\ne i} \left( \frac{\alpha _i+\alpha _j}{d(c_i,c_j)} \right) \end{aligned}$$
(3)

where n is the number of clusters, \(\alpha _i\) and \(\alpha _j\) are the mean pairwise intra-cluster fiber distances of clusters i and j, and \(d(c_i,c_j)\) is the inter-cluster fiber distance between the centroids \(c_i\) and \(c_j\) of clusters i and j [31]. A smaller DB score indicates a better separation between clusters. The second metric is White Matter Parcellation Generalization (WMPG) [43], which represents the percentage of clusters successfully detected across the testing subjects. In our work, clusters with over 10 fibers are considered to be successfully detected [43]. The last metric is Tract Anatomical Profile Coherence (TAPC) [43], which measures whether the fibers within a cluster c commonly pass through the same brain anatomical regions:

$$\begin{aligned} \mathrm{TAPC}(c) = \frac{1}{\mathrm{NF}(c)} \sum _{f=1}^{\mathrm{NF}(c)} \mathrm{Dice}\left( \mathrm{TAP}(f), \mathrm{TAP}_{\mathrm{atlas}}(c)\right) \end{aligned}$$
(4)

Higher TAPC scores indicate better anatomical coherence.
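The DB index of Eq. (3) and the TAPC of Eq. (4) can be sketched as below. The sketch assumes that the intra-cluster distances and the centroid distance matrix are precomputed, that NF(c) denotes the number of fibers in cluster c, and that TAPs are represented as sets of hypothetical region labels. All function names are illustrative.

```python
def db_index(intra, centroid_dist):
    """Eq. (3): intra[i] is alpha_i, centroid_dist[i][j] is d(c_i, c_j)."""
    n = len(intra)
    worst = [
        max(
            (intra[i] + intra[j]) / centroid_dist[i][j]
            for j in range(n)
            if j != i
        )
        for i in range(n)
    ]
    return sum(worst) / n


def dice(a, b):
    """Dice overlap between two sets of region labels."""
    if not a and not b:
        return 1.0  # assumed convention for two empty sets
    return 2.0 * len(a & b) / (len(a) + len(b))


def tapc(fiber_taps, atlas_tap):
    """Eq. (4): mean Dice between each fiber's regions and the atlas cluster TAP."""
    return sum(dice(t, atlas_tap) for t in fiber_taps) / len(fiber_taps)


# Toy values for three clusters (assumed).
intra = [1.0, 1.0, 2.0]
centroid_dist = [
    [0.0, 4.0, 8.0],
    [4.0, 0.0, 8.0],
    [8.0, 8.0, 0.0],
]
db = db_index(intra, centroid_dist)

# Toy cluster with two fibers and hypothetical region labels.
cluster_tapc = tapc([{"A", "B"}, {"A"}], {"A", "B"})
```

For each cluster the DB index keeps only the worst (largest) ratio over all other clusters, so a single poorly separated neighbor is enough to raise a cluster's contribution.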

3.3 Evaluation Results

Comparison with State-of-the-Art Methods. We compare our proposed approach with two open-source state-of-the-art WMFC algorithms, WhiteMatterAnalysis (WMA) [43] and QuickBundles (QB) [7]. WMA is an atlas-based WMFC method that shows high performance and strong correspondence across subjects. QB is a widely used WMFC method that performs clustering within each subject and achieves group correspondence with post-processing steps. We use the open-source software packages WMA v0.3.0 and Dipy v1.3.0 with their default settings. For all experiments, we perform WMFC into 800 clusters (which has been suggested to be a good whole brain tractography parcellation scale [43]). Dipy does not accept an input number of clusters; therefore, we tuned parameters in each subject to obtain a number of clusters as close as possible to 800 (greater than or equal to 800). All results are reported using data from the 50 test subjects. The WMPG and TAPC metrics require corresponding clusters across all subjects; these are automatically generated by WMA and our proposed DFC method. For QB, correspondence is achieved by matching cluster centroids from all subjects to those of one selected subject (with exactly 800 clusters) according to the fiber distances between centroids, as suggested by the QB developers [7].

Table 1. Quantitative comparison results. SOTA: state of the art.
Fig. 2. Visualization of example clusters generated from DFC, WMA, and QB in one subject. Similar clusters were identified across methods for visualization.

As shown in Table 1, our DFC method exhibits the best overall performance. For the DB index, QB obtained a slightly lower value than DFC, likely because intra-cluster distances are lower when clustering is performed within each subject, since the resulting clusters do not describe anatomical variability across subjects. Compared to the atlas-based WMA, the DB index of our method is clearly smaller, indicating more compact and/or better separated clusters. As for WMPG, both our method and WMA successfully detected over 99\(\%\) of clusters, while the WMPG score of QB is around 80\(\%\), indicating poor correspondence across subjects. DFC obtained the highest TAPC among the three methods owing to the incorporation of anatomical information, indicating the best anatomical coherence of clusters. Figure 2 gives a visual illustration of the clusters obtained by each method.

Fig. 3. Illustration of corresponding clusters from DFC and DCEC. Colors represent order of points along a fiber with red for starting point and blue for ending point. (Color figure online)

To evaluate the efficiency of the approaches, the inference time for one subject is also recorded and shown in Table 1. All methods were tested on a computer equipped with a 2.1 GHz Intel Xeon E5 CPU (8 DIMMs; 32 GB Memory). For fair comparison, DFC was set to run on CPU instead of GPU. The results show that our method is much faster than WMA and slightly faster than QB.

Comparison with the Baseline Method. We also compared our proposed method with the DCEC baseline model. The results in Table 1 show a large improvement of the DB index of our method compared to DCEC, because DCEC separately clusters fibers with close positions but flipped point orders. As shown in Fig. 3, spatially close fibers with different point orders are split into two clusters in DCEC, while our proposed DFC method groups them together.

Ablation Study. We performed an ablation study to investigate how different factors influence the performance of our method. Three models were evaluated: DFC\(_{no-fs-ro}\) (DFC without FreeSurfer information and outlier removal), DFC\(_{no-ro}\) (DFC without outlier removal but with FreeSurfer information) and DFC\(_{proposed}\), as shown in Table 1. Adding FreeSurfer information to the model changes the DB index and WMPG metrics little, while the TAPC score improves clearly. With outlier removal, the DB index and TAPC improve clearly, while WMPG decreases slightly, which is inevitable given the reduced number of fibers (although it remains high). These results demonstrate the effectiveness of our designed modules. As shown in Fig. 4, outlier fibers have visibly low soft label assignment probabilities and are therefore removed.

Fig. 4. Illustration of outlier removal process. Left: cluster before outlier removal; Middle: fiber soft label assignment probability (rainbow coloring with red representing 0); Right: cluster after outlier removal.

4 Conclusion

In this paper, we present a novel unsupervised deep learning framework for dMRI tractography WMFC. We adopt a self-supervised learning strategy to enable joint deep embedding and cluster assignment. Our method addresses several key challenges in WMFC, including handling the flipped order of points along fibers, incorporating anatomical brain segmentation information, filtering false positive fibers and establishing inter-subject correspondence of fiber clusters. Our results show advantages in both clustering performance and efficiency over state-of-the-art algorithms. Further research could be conducted to improve the framework, such as designing more complex network architectures, incorporating additional sources of anatomical information and balancing anatomical and fiber geometry information for clustering.