1 Introduction

Bayesian segmentation of medical images, particularly in the context of brain MRI, is a well-studied problem. Probabilistic models for image segmentation frequently exploit atlas priors, and account for variations in contrast and imaging artifacts such as MR inhomogeneity [19, 21]. Most of the popular neuroimage processing pipelines rely on segmentation algorithms based on these ideas [2, 8, 15]. While these tools achieve high robustness to changes in MRI contrast of the input scan, they are computationally demanding (e.g., 23 min per image using a multi-threaded setup [16]), which limits their deployment at scale and in time-sensitive applications. Therefore, there is a need for computationally efficient methods that are contrast-adaptive, requiring no additional labeled training images to segment a new dataset.

Recently, there has been a surge in the application of deep learning (DL) techniques to medical image segmentation, often based on convolutional neural networks (CNN) that excel at learning contextually important multi-scale features. An advantage of these methods is their computational efficiency at test (segmentation) time, offering the potential to use automatic segmentation in new application areas and on large datasets [11, 17]. Moreover, these algorithms can be combined with atlas priors for increased robustness [5, 12, 13]. However, DL-based techniques are notoriously sensitive to changes in the image intensity data distribution. For example, upgrades to MRI scanners or changes in pulse sequence, field strength, or RF coils can alter contrast properties and dramatically reduce the performance of a CNN-based segmentation model [10]. This issue can be alleviated via domain adaptation or data augmentation, which requires simulating the expected variations. However, even with additional data, these methods only partially close the gap with the fully supervised setting [14]. Furthermore, the dependency on manually annotated datasets means that existing DL approaches are only applicable if enough resources are available to compile the required training data. This is often infeasible, for example in the context of continuously upgrading imaging technologies.

In this paper, we consider the scenario in which we have a general probabilistic atlas prior and a collection of images with no manual delineations. The probabilistic atlas is a volume where each voxel has an associated vector with the prior probabilities of observing the various segmentation labels at that location. Our approach assumes the availability of such an atlas (in brain imaging, they are readily available), and is independent of how it was created. For example, it could have been obtained by averaging a collection of manually annotated volumes of a different imaging modality, or derived from an anatomical template after applying spatial blurring to account for spatial variability.

The main contribution of this paper is the integration of mathematical ideas from the Bayesian segmentation literature with an unsupervised deep learning framework, to achieve fast, contrast-adaptive brain MRI segmentation. Specifically, we assume a probabilistic model, which requires estimation of parameters comprising an atlas deformation and image intensity statistics. The estimation of the atlas warp has traditionally relied on classic deformable registration algorithms [18], which are based on iterative optimization and are therefore computationally expensive. Instead, we leverage recent advances in learning-based registration [3, 4, 20] to efficiently estimate the warp jointly with the intensity parameters. We derive a novel, interpretable loss function from the probabilistic model via Bayesian inference. By integrating DL with Bayesian segmentation, we attain two highly desirable features. First, given a probabilistic atlas, the method is unsupervised and contrast adaptive: training it does not require any ground truth segmentations. Second, segmentation is efficient, requiring approximately 15 s on a GPU.

2 Method

2.1 Segmentation as Bayesian Inference

Let \(\varvec{I}\) represent the intensities of a 3D brain MRI scan, defined over a discrete domain \(\Omega \subset \mathbb {R}^3\). Let \(\varvec{S}\) be a corresponding discrete segmentation into L neuroanatomical labels. Bayesian segmentation relies on Bayes’ rule to derive the posterior probability distribution of the segmentation given the input image. Then, the segmentation \(\hat{\varvec{S}}\) is estimated as the mode of this posterior:

$$\begin{aligned} \hat{\varvec{S}} = \mathop {{\text {arg~max}}}\limits _{\varvec{S}} p(\varvec{S} | \varvec{I}) = \mathop {{\text {arg~max}}}\limits _{\varvec{S}} p(\varvec{I} | \varvec{S}) p(\varvec{S}). \end{aligned}$$
(1)

The posterior distribution \(p(\varvec{S} | \varvec{I})\) depends on two terms: a prior \(p({\varvec{S}})\) and a likelihood \(p(\varvec{I} | \varvec{S})\), in contrast to discriminative approaches which model \(p(\varvec{S} | \varvec{I})\) directly. The prior represents knowledge about the spatial distribution of segmentation labels (i.e., underlying anatomy), and typically has the form of a probabilistic atlas endowed with a deformation model. The likelihood models the relationship between the segmentation and image intensities, including image artifacts such as noise and bias field. Both the prior and likelihood may have associated parameters, which we define as \(\varvec{\theta }_S\) and \(\varvec{\theta }_I\), respectively. The former describes attributes such as label probabilities and an atlas deformation, while the latter typically includes image intensity statistics as a function of label.

The likelihood parameters may be learned from a training dataset, or estimated specifically for each test scan. We build on Bayesian segmentation models that follow the latter approach [2, 16, 19, 21, 22], enabling models to adapt to the intensity characteristics of input scans, making them robust to changes in MRI contrast. Expanding Eq. (1) to include model parameters, which we treat as random variables, yields:

$$\begin{aligned} \hat{\varvec{S}} = \mathop {\text {arg\, max}}\limits _{\varvec{S}} \int _{\varvec{\theta }_S} \int _{\varvec{\theta }_I} p(\varvec{S} | \varvec{\theta }_S, \varvec{\theta }_I, \varvec{I}) p( \varvec{\theta }_S, \varvec{\theta }_I | \varvec{I}) d\varvec{\theta _S} d\varvec{\theta _I}, \end{aligned}$$
(2)

which is intractable. A standard approximation uses point estimates for the parameters. First, one estimates the mode of the parameter posterior distribution:

$$\begin{aligned} \{ \hat{\varvec{\theta }}_S , \hat{\varvec{\theta }}_I \} = \mathop {\text {arg\, max}}\limits _{\{\varvec{\theta }_S,\varvec{\theta }_I\}} p(\varvec{\theta }_S,\varvec{\theta }_I | \varvec{I}) = \mathop {\text {arg\, max}}\limits _{\{\varvec{\theta }_S,\varvec{\theta }_I\}} p(\varvec{\theta }_S) p(\varvec{\theta }_I ) \sum _{\varvec{S}} p(\varvec{I} | \varvec{S}, \varvec{\theta }_I) p(\varvec{S} | \varvec{\theta }_S), \end{aligned}$$
(3)

where we assume independence between the parameters of the prior and of the likelihood. The computation often requires estimating an atlas deformation in \(\varvec{\theta }_S\) and intensity parameters in \(\varvec{\theta }_I\) and is typically achieved with a combination of numerical optimization and the Expectation Maximization (EM) algorithm [6]. Given point estimates, the final segmentation is computed as:

$$\begin{aligned} \hat{\varvec{S}} = \mathop {\text {arg\, max}}\limits _{\varvec{S}} p(\varvec{S} | \hat{\varvec{\theta }}_S, \hat{\varvec{\theta }}_I, \varvec{I}) = \mathop {\text {arg\, max}}\limits _{\varvec{S}} p(\varvec{I} | \varvec{S}, \hat{\varvec{\theta }}_I) p(\varvec{S} | \hat{\varvec{\theta }}_S). \end{aligned}$$
(4)

2.2 Proposed Model

Our model instantiation builds on existing work [2, 16, 19]. The prior is defined by a given probabilistic atlas \(\varvec{A}\), where \(A (l, \varvec{x})\) provides the probability of observing each neuroanatomical label \(l=1,\ldots ,L\) at each location \(\varvec{x}\in \Omega \). The atlas is deformed by a transform \(\varvec{\phi }\), parameterized by a stationary velocity field \(\varvec{v}\) (i.e., \(\varvec{\phi }_v = \exp [\varvec{v}]\), which guarantees that \(\varvec{\phi }_v\) is diffeomorphic [1]). Therefore, the prior is parametrized by \(\varvec{\theta }_S = \varvec{v}\). Assuming independence over voxels:

$$\begin{aligned} p(\varvec{S} | \varvec{\theta }_S ; \varvec{A}) = p (\varvec{S} | \varvec{v}; \varvec{A}) = \prod _{j\in \Omega } A \Big ( S_j , \varvec{\phi }_v(\varvec{x}_j) \Big ), \end{aligned}$$
(5)

where \(S_j\) is the segmentation at voxel j, and \(\varvec{x}_j\) is its spatial location. We discourage strongly varying deformations by penalizing the spatial gradient \(\nabla \varvec{u}_v\) of displacement \(\varvec{u}_v\), where \(\varvec{\phi }_v = Id + \varvec{u}_v\), i.e., \(p(\varvec{\theta }_S ; \lambda ) = p (\varvec{v}; \lambda ) \propto \exp [-\lambda \Vert \nabla \varvec{u}_v \Vert ^2 ]\). The hyperparameter \(\lambda \) controls the strength of the penalty.
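For concreteness, the sketch below evaluates this gradient penalty \(\lambda \Vert \nabla \varvec{u}_v \Vert ^2\) with forward finite differences in NumPy; the function name, spacing convention, and normalization are our own assumptions rather than the paper's exact implementation.

```python
import numpy as np

def grad_penalty(u, lam=10.0):
    """Approximate lambda * ||grad u||^2 for a displacement field u of shape
    (X, Y, Z, 3), using forward finite differences along each spatial axis.
    Illustrative sketch only; spacing/normalization are assumptions."""
    penalty = 0.0
    for axis in range(3):                 # the three spatial axes
        d = np.diff(u, axis=axis)         # forward differences
        penalty += np.sum(d ** 2)
    return lam * penalty

# Example: a small random displacement field on a 16^3 grid
u = 0.01 * np.random.randn(16, 16, 16, 3)
print(grad_penalty(u))
```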

Conditioned on a segmentation, we assume that the observed intensities at different voxel locations are independent samples of Gaussian distributions:

$$\begin{aligned} p(\varvec{I} | \varvec{S},\varvec{\theta }_I) = p(\varvec{I} | \varvec{S},\varvec{\mu },\varvec{\sigma }^2) = \prod _{j\in \Omega } \mathcal {N}(I_j ; \mu _{S_j} , \sigma ^2_{S_j}), \end{aligned}$$
(6)

where \(\mathcal {N}(\cdot ;\mu ,\sigma ^2)\) is the Gaussian distribution, \(I_j\) is the image intensity at voxel j, and the likelihood parameters \(\varvec{\theta }_I = \{\varvec{\mu }, \varvec{\sigma }^2\}\) are L means \(\mu _l\) and variances \(\sigma _l^2\), each associated with a different label l. We complete the model with a flat prior for these parameters: \(p(\varvec{\theta }_I)\propto 1\). The model can be easily extended to the multi-spectral case (i.e., inputs with multiple MRI contrasts) by replacing means and variances by mean vectors and covariance matrices, respectively.
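The per-label likelihood maps implied by Eq. (6) can be evaluated as in the following sketch; array shapes and function names are illustrative assumptions, not the paper's code.

```python
import numpy as np

def likelihood_maps(image, mu, sigma2):
    """Per-label Gaussian likelihoods N(I_j; mu_l, sigma_l^2) from Eq. (6).
    image: (X, Y, Z) intensities; mu, sigma2: length-L arrays of label
    means and variances. Returns an (X, Y, Z, L) array."""
    I = image[..., None]                              # broadcast over labels
    norm = 1.0 / np.sqrt(2.0 * np.pi * sigma2)
    return norm * np.exp(-0.5 * (I - mu) ** 2 / sigma2)
```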

Learning. To avoid the computationally expensive optimization typically required for maximum a posteriori (MAP) estimation in Eq. (3), we propose to train a CNN to estimate model parameters directly from an input scan. Specifically, we design a CNN \(g_{\varvec{\theta }_C} (\varvec{I}, \varvec{A}) = (\varvec{\theta }_{S},\varvec{\theta }_{I}) = (\varvec{v},\varvec{\mu },\varvec{\sigma }^2)\) with global convolutional parameters \(\varvec{\theta }_C\) that takes as input a scan \(\varvec{I}\) and probabilistic atlas \(\varvec{A}\), and outputs model parameters \(\varvec{v},\varvec{\mu },\varvec{\sigma }^2\) for that scan. To learn the network parameters \(\varvec{\theta }_{C}\), we use a pool of N unlabeled scans \(\{I^n\}_{n=1}^N\) to minimize the negative log posterior distribution of the image-specific parameters given the training images:

$$\begin{aligned}&- \sum _{n=1}^N \log p(\varvec{v}^n,\varvec{\mu }^{n},[\varvec{\sigma }^2]^{n} | \varvec{I}^n ; \varvec{A}, \lambda ) \\ =&\sum _{n=1}^N \left\{ - \sum _{j\in \Omega } \log \left[ \sum _{l=1}^L \mathcal {N}(I_{j}^{n} ; \mu _{l}^{n} , [\sigma ^2_{l}]^{n}) A \Big ( l , \varvec{\phi }_{v^n}(\varvec{x}_j) \Big ) \right] + \lambda \Vert \nabla \varvec{u}_{v^n} \Vert ^2 \right\} + \text {const} \nonumber \end{aligned}$$
(7)

The network outputs different parameters \(\varvec{v}\), \(\varvec{\mu }\), and \(\varvec{\sigma }^2\) for each input image \(\varvec{I}\).
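A minimal NumPy sketch of the per-image term of Eq. (7) is given below, assuming the likelihood maps (Eq. 6) and the warped atlas are already computed; in the actual method this loss is implemented as a differentiable layer in the training framework, so the sketch is illustrative only.

```python
import numpy as np

def unsupervised_loss(lik, warped_atlas, u, lam=10.0, eps=1e-12):
    """Per-image negative log posterior of Eq. (7), up to a constant.
    lik:          (X, Y, Z, L) Gaussian likelihood maps (Eq. 6)
    warped_atlas: (X, Y, Z, L) atlas probabilities resampled at phi_v(x_j)
    u:            (X, Y, Z, 3) displacement field, with phi_v = Id + u
    All names are our own, for illustration."""
    data_term = -np.sum(np.log(np.sum(lik * warped_atlas, axis=-1) + eps))
    reg_term = lam * sum(np.sum(np.diff(u, axis=a) ** 2) for a in range(3))
    return data_term + reg_term
```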

We design the CNN \(g_{\varvec{\theta }_C}(\cdot ,\cdot )\) based on a 3D UNet-style architecture [17] and the VoxelMorph implementation [3]. The network consists of downsampling convolutional layers with 32 filters, 3\(\,\times \,\)3\(\,\times \,\)3 kernels, a stride of 2, and LeakyReLU activations, followed by mirrored upsampling layers with skip connections. An additional convolutional layer outputs \(\varvec{v}\), a dense 3D velocity field defined over \(\Omega \); a further pair of convolutional layers followed by global max pooling yields the Gaussian parameters \(\varvec{\mu },\varvec{\sigma }^2\). We compute \(\varvec{\phi }_v = \exp (\varvec{v})\) using a scaling-and-squaring integration layer [1, 4], enabling the computation of the regularization term of the loss. Combining the Gaussian parameters with the input image yields the likelihood maps. These maps, together with the probabilistic atlas \(\varvec{A}\) warped via a spatial transform layer, enable computation of the first term in Eq. (7) (Fig. 1).
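The sketch below illustrates scaling-and-squaring integration of a stationary velocity field using SciPy interpolation; the number of squaring steps and the NumPy/SciPy formulation are our assumptions, since the paper uses a differentiable network layer [1, 4].

```python
import numpy as np
from scipy.ndimage import map_coordinates

def integrate_velocity(v, steps=7):
    """Scaling-and-squaring integration of a stationary velocity field v of
    shape (X, Y, Z, 3), i.e. phi = exp(v). Returns a displacement field u
    with phi = Id + u. Illustrative sketch only."""
    u = v / (2.0 ** steps)                     # small initial displacement
    grid = np.stack(np.meshgrid(*[np.arange(s) for s in v.shape[:3]],
                                indexing='ij'), axis=-1).astype(float)
    for _ in range(steps):
        # compose the map with itself: u_new(x) = u(x) + u(x + u(x))
        coords = (grid + u).transpose(3, 0, 1, 2)
        warped = np.stack([map_coordinates(u[..., d], coords, order=1)
                           for d in range(3)], axis=-1)
        u = u + warped
    return u
```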

Fig. 1. Method overview. The network block \(g_{\varvec{\theta }_C}(\cdot ,\cdot )\) outputs a stationary velocity field \(\varvec{v}\), enabling alignment of the probabilistic atlas to the input volume, and Gaussian likelihood parameters \(\varvec{\mu },\varvec{\sigma }^2\), which yield likelihood maps for each label.

Efficient Segmentation. Given a test scan, the trained network efficiently provides the image-specific parameter point estimates \(\hat{\varvec{v}}\) and \(\hat{\varvec{\theta }}_I\) in a single forward pass. The optimal segmentation is efficiently computed for each voxel j:

$$\begin{aligned} \hat{S}_j = \mathop {\text {arg\, max}}\limits _l \mathcal {N}(I_j ; \hat{\mu }_l , \hat{\sigma }^2_l) A \Big ( l , \phi _{\hat{v}}(\varvec{x}_j) \Big ). \end{aligned}$$
(8)
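The rule of Eq. (8) can be sketched as follows, where the atlas is resampled at the warped coordinates by trilinear interpolation; all names are illustrative assumptions, and the production implementation runs as network layers on the GPU.

```python
import numpy as np
from scipy.ndimage import map_coordinates

def segment(image, atlas, u, mu, sigma2):
    """MAP segmentation of Eq. (8): argmax_l N(I_j; mu_l, sigma_l^2) *
    A(l, phi_v(x_j)). atlas: (X, Y, Z, L) probabilities; u: predicted
    displacement with phi_v = Id + u; mu, sigma2: length-L arrays."""
    grid = np.stack(np.meshgrid(*[np.arange(s) for s in image.shape],
                                indexing='ij'), axis=-1).astype(float)
    coords = (grid + u).transpose(3, 0, 1, 2)
    L = atlas.shape[-1]
    # warp the atlas: evaluate A(l, .) at phi_v(x_j) by trilinear sampling
    warped = np.stack([map_coordinates(atlas[..., l], coords, order=1)
                       for l in range(L)], axis=-1)
    lik = np.exp(-0.5 * (image[..., None] - mu) ** 2 / sigma2) \
          / np.sqrt(2.0 * np.pi * sigma2)
    return np.argmax(lik * warped, axis=-1)
```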

3 Experiments and Results

Data. We evaluate our approach on three different image sets. The first (“multi-site”) includes 8,332 T1-weighted scans from several public datasets. We randomly selected 7,332 scans to train and validate, and the remaining 1,000 were held out for testing. Manual delineations are not available for these scans, but we used automated segmentations produced by FreeSurfer [8] as a silver standard. The second dataset (“T1”) consists of 38 T1-weighted scans, used only for testing, each with 36 manually delineated brain structures [8]. The third dataset (“PD”) consists of eight proton density-weighted (PD) scans, manually segmented with the same protocol [9]. All scans were preprocessed with FreeSurfer, including skull stripping, bias field correction, intensity normalization, affine registration to Talairach space, and resampling to 1 mm\(^3\) isotropic resolution [7].

Experimental Setup. We perform three experiments, one for each dataset. In the first, we fit our network to the T1-weighted training scans of the multi-site dataset, and use the resulting model to segment the 1,000 test scans. Despite the lack of a manual gold standard, this experiment enables assessment of performance on a large, heterogeneous dataset. In the second experiment, we use the model trained in the first to segment scans from the separate T1 dataset. This enables evaluation with manual ground truth on scans from a scanner and pulse sequence not observed during training. In the third experiment, we train a network on the PD dataset, and then use it to segment the 8 PD scans. This is a different scenario from the first two experiments, since we learn to segment the test dataset directly. This experiment enables us to assess the ability of our algorithm to segment a substantially different MRI contrast, and to fit datasets of reduced size. In all experiments, we use our method with the publicly available atlas from [16]. We emphasize that all networks are trained in an unsupervised fashion, and segmentation maps are used only for evaluation.

Baseline. We compare our method to a reimplementation of [19], which solves Eq. (7) with no deformation (i.e., \(\varvec{u}=\varvec{0}\), \(\phi _v = Id\)) via the EM algorithm. Since this model does not include deformation, using the nonrigid version of the atlas would yield low performance; instead, the baseline relies on an affinely aligned version of the same atlas, combined with Gaussian likelihood functions.
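For illustration, a simplified sketch of such EM updates is shown below, with the affinely aligned atlas acting as a fixed voxel-wise prior; it omits details of [19] (e.g., bias field modeling) and is not the exact reimplementation used as the baseline.

```python
import numpy as np

def em_gaussian_atlas(image, atlas, n_iter=20, eps=1e-12):
    """EM for the no-deformation baseline: atlas (X, Y, Z, L) is a fixed
    prior p(S_j = l); Gaussian means/variances are re-estimated from soft
    label responsibilities. Simplified sketch only."""
    I = image.ravel()[:, None]                       # (V, 1) intensities
    prior = atlas.reshape(-1, atlas.shape[-1])       # (V, L) prior probs
    mu = np.linspace(I.min(), I.max(), prior.shape[1])
    sigma2 = np.full(prior.shape[1], I.var())
    for _ in range(n_iter):
        # E-step: posterior responsibilities per voxel and label
        lik = np.exp(-0.5 * (I - mu) ** 2 / sigma2) / np.sqrt(2 * np.pi * sigma2)
        resp = prior * lik
        resp /= resp.sum(axis=1, keepdims=True) + eps
        # M-step: weighted updates of the Gaussian parameters
        w = resp.sum(axis=0) + eps
        mu = (resp * I).sum(axis=0) / w
        sigma2 = (resp * (I - mu) ** 2).sum(axis=0) / w + eps
    return mu, sigma2
```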

Evaluation. We use Dice scores computed on a subset of structures of interest (Fig. 2), and also focus on deep structures such as the hippocampus, which is the target of many neuroimaging studies due to its significance in dementia.

Implementation. We implement our method using Keras with a Tensorflow backend and the ADAM optimizer. We predict the velocity field \(\varvec{v}\) and resulting deformation field \(\varvec{\phi }\) at every second voxel in each dimension, due to memory constraints. We obtain a final dense deformation field via linear interpolation.
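A possible form of this upsampling step is sketched below; the factor-of-two rescaling of the vector components assumes displacements are expressed in voxel units of the respective grid, which is our assumption rather than a stated implementation detail.

```python
import numpy as np
from scipy.ndimage import zoom

def upsample_velocity(v_half):
    """Upsample a field predicted at every second voxel, shape
    (X/2, Y/2, Z/2, 3), to full resolution via linear interpolation.
    Illustrative sketch; scaling convention is an assumption."""
    v_full = zoom(v_half, (2, 2, 2, 1), order=1)   # trilinear, keep channels
    return 2.0 * v_full                            # vectors in full-res voxel units
```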

Fig. 2. Segmentation Statistics. Dice scores for: cerebral cortex (CT) and white matter (WM); lateral ventricle (LV); cerebellar cortex (CC) and white matter (CW); thalamus (TH); caudate (CA); putamen (P); pallidum (PA); brainstem (BS); hippocampus (HP); and amygdala (AM). Scores of contralateral structures are averaged. The number of outliers under the x axis is shown in red (baseline) and blue (ours). (Color figure online)

Fig. 3. Example Results. Coronal slices of two scans (one from each of the T1 and PD datasets), along with the initial and deformed probabilistic atlas, and corresponding segmentations. In the atlas, the color of each pixel is a combination of the colors of different labels, weighted by their probabilities. In the segmentations, we show the contour of the labels in the corresponding colors. We use the FreeSurfer color map [7]. (Color figure online)

For all experiments, we set \(\lambda =10\), the only free parameter, having visually evaluated segmentation results for several validation subjects (held out from training). We also group anatomical labels with similar intensity properties into eleven merged labels, forcing groups of original labels to share Gaussian parameters, increasing robustness [16]. Specifically, we group: contralateral structures (in general), gray matter structures (cerebral gray matter, hippocampus, amygdala, caudate, accumbens), and cerebrospinal fluid structures.
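The sketch below illustrates one way to collapse atlas channels into merged classes that share Gaussian parameters; the label IDs (FreeSurfer-style) and the grouping shown are hypothetical examples for illustration, not the exact grouping of [16].

```python
import numpy as np

# Hypothetical mapping from original label IDs to merged classes that share
# Gaussian parameters (FreeSurfer-style IDs used purely as an example).
label_to_group = {
    0: 0,             # background
    2: 1, 41: 1,      # left/right cerebral white matter
    3: 2, 42: 2,      # left/right cerebral cortex (gray matter group)
    17: 2, 53: 2,     # left/right hippocampus -> gray matter group
    4: 3, 43: 3,      # left/right lateral ventricle (CSF group)
}

def merge_atlas(atlas, labels, n_groups):
    """Collapse an (X, Y, Z, L) atlas over original labels into merged
    classes by summing the probabilities of labels that share parameters."""
    merged = np.zeros(atlas.shape[:-1] + (n_groups,))
    for i, l in enumerate(labels):
        merged[..., label_to_group[l]] += atlas[..., i]
    return merged
```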

Results. Our method requires only 15 s per scan on an NVIDIA Titan Xp GPU. Figure 2 reports segmentation statistics for all experiments. Our method achieves considerably higher Dice scores than the baseline on the multi-site dataset (average over all structures 83.5% vs. 79.0%), especially in deep brain structures such as the hippocampi (81.1% vs. 73.1%). Moreover, it largely reduces the number of outliers with very poor segmentation (e.g., there are over 100 cases with Dice \(< 50\)% in the caudate for the baseline approach, and none for our method). In the T1 dataset, the test intensity distribution is slightly different from that of the training dataset. However, our approach successfully generalizes and outperforms the baseline (average 81.9% vs. 79.4%, hippocampi 79.9% vs. 73.5%). The results of the third experiment illustrate the ability of our method to adapt to contrasts other than T1, even when the data are limited, and to outperform the baseline (average 80.5% vs. 78.3%, hippocampi 76.6% vs. 69.8%).

Figure 3 shows two segmentations from the T1 and PD datasets. In the T1 scan, the atlas successfully deforms to match the large ventricles of the subject, producing more accurate segmentations than the baseline – not only for the ventricles (purple), but also for surrounding structures, e.g., thalami (green). In the PD scan, our method manages to segment all structures including the amygdalae (light blue), which are missed by the baseline.

4 Conclusion

We propose a principled approach for unsupervised segmentation, which enables training a CNN for a dataset without the need for any manually annotated images. The likelihood model may be extended to incorporate more complex functions (such as mixtures of Gaussians) and artifacts such as partial voluming and bias fields. In addition to segmentations, the method produces a dense nonlinear deformation field that is a useful output by itself, e.g., for tensor-based morphometry. Using a large dataset, we demonstrate that the proposed approach achieves state-of-the-art accuracy for unsupervised brain MRI segmentation in different MRI contrasts. Our method runs in under 15 s on a GPU, facilitating deployment on large studies and in time-sensitive applications.