Keywords

1 Introduction

Machine learning calibration measures the agreement between the predicted probability estimates and the true correctness likelihood [8]. Proper calibration is vital for high-risk applications. Modern deep neural networks (DNNs) achieve impressive accuracy at poor calibration [8]. Incorrectly calibrated DNNs are unreliable on out-of-distribution data and don’t know when they are likely to be incorrect. This discrepancy leaves them vulnerable in critical decision-making such as self-driving cars, surgical robots, and disease subtyping On the other hand, well-calibrated models are less certain when incorrect and comparably certain when correct. Their reliable confidence establishes trustworthiness.

We hypothesize that domain-aware model calibration that leverages the semantic confusability and hierarchical similarity among class labels can yield well-calibrated and higher-performing models. To test this hypothesis, we have chosen medical image segmentation because it is fundamental in medical image analysis. Overly-confident tissue boundaries can introduce significant errors in brain volume estimations [4]. Head image segmentation is prone to errors due to fine tissue boundaries, tissue imbalance, and low contrast. These challenges can make open-source software fall short on patient sub-populations [3, 12, 17]. Errors in head segmentation can lead to downstream errors in clinical pipelines, like in estimating parameters for non-invasive brain stimulation [2, 11].

Hence, we address uncertainty in medical image segmentation by introducing DOMINO, a framework that leverages domain information among class labels to calibrate DNNs. Unlike prior works that push class means to be orthogonal [15], we assume some class labels are naturally similar. The choice of the loss function is important to calibration because loss drives how a model learns [16]. Medical image segmentation still largely relies on standard losses [1]. We extend these approaches with domain-aware loss regularization to improve model calibration. We study two regularization schemes that are based on confusion matrices (CM) and hierarchical classes (HC). The former imposes a penalty based on class confusability when using a standard network on a held-out data subset. The latter groups labels into hierarchical classes based on common tissue properties.

2 Domain-Aware Model Calibration

2.1 U-Net Transformers (UNETR) Model

We employ UNETR [9] as our base model due to its superior segmentation performance. UNETR utilizes a U-Net architecture with a transformer encoder. This approach combats the relative locality of convolutional layers in fully convolutional networks (FCNs). Transformers have revolutionized Natural Language Processing due to superior long-range learning [18]. Transformers encode images as sequences of one-dimensional patch embeddings. Self-attention modules learn weighted sums from hidden layers. Hence, UNETR reformulates 3D image segmentation as sequence-to-sequence predictions. Skip connections pass the transformer’s global context to a traditional FCN decoder. The decoder concatenates local information with the global multi-scale information from the encoder. This paper refers to un-regularized UNETR as UNETR-Base.

2.2 Domain-Aware Loss Regularization

Concept. Our penalty addresses a deficit with cross-entropy (CE) loss in uncertainty. CE loss maximizes the output of the ground truth label. Due to this, the network increases the true label logit more than the incorrect label logits. The resulting networks are overly confident in their predictions. Meanwhile, the non-selected classes’ softmax outputs do not represent the true likelihood. Our work introduces more meaningful uncertainty by penalizing incorrect classes. Specifically, we assume that some classes are more similar to others. Network presentation often pushes class means to all be orthogonal to one another [15]. Such networks assume that all classes are equally separable. This assumption fights the natural similarities between certain classes. Thus, we hypothesize that a network can learn better class representation by taking advantage of class similarities, rather than fighting them. Our methods apply to classification and segmentation. This treats segmentation as pixel-wise classification [13].

Derivation. Our regularization term adds to any loss function as follows:

$$\begin{aligned} \mathcal {L}(y, \hat{y}) + \beta (y')(W)(\hat{y}) \end{aligned}$$
(1)

where \(\mathcal {L}\) is a suitable loss function (we use DiceCE which is a combination of Dice score and cross-entropy), y is the one-hot encoded true label, and ŷ is the softmax output. \(\beta \) can take on any value between zero and one. W represents a generic regularization term of size \(N\times N\), where N is the number of classes. The diagonals are zero, whereas the off-diagonals represent the penalties for confusing classes. We propose two domain-aware approaches to design W as below.

Confusion Matrix (UNETR-CM). Confusion matrix-based calibration utilizes the natural confusability among class labels using a non-calibrated DNN. First, we train UNETR-base without regularization on the training set. Then, we evaluate the trained model on a held-out validation set to generate a confusion matrix for all classes. The loss regularization is computed as below:

$$\begin{aligned} W_{ij} = S \cdot \frac{I_i - C_{ij}}{Ii} \end{aligned}$$
(2)

Here, i and j represent the row and column indices, respectively. C is the confusion matrix generated when UNETR-Base is applied on a held-out validation set and normalized by class prevalence. \(W_{ij}\) represents any given matrix entity. \(I_i\) is \(i^{th}\) row of the identity matrix. Thus \(W_{ii}=0\) so there is no penalty for the correct class. Finally, S is a scaling factor to make the regularization weights more significant. We set \(S=3\) based on empirical experiments; however, jointly varying \(\beta \) and S can change the balance of the loss function. Low values for both result in no regularization; too high and it begins to affect model accuracy. The correct values for these hyperparameters will depend on the model and dataset.

Table 1. Hierarchical class groupings. \(^{*}\)Eyes are considered to fall within CSF and soft tissue due to have aqueous and fibrous components.

Hierarchical Class (UNETR-HC). Here, we regularize using hierarchical relationships between semantic labels. Hierarchical groups are more likely to have similar properties than inter-group classes. Hence, confusion within groups can facilitate more informed and safer mistakes when wrong. Table 1 shows the hierarchy for our head segmentation. We define the matrix penalty in Fig. 1b by considering which classes are subsets of the same super-class. In Fig. 1b, each row represents the penalties for confusing the given class with any other class. The maximum penalty is 3, and penalties are manually lowered within the groups of Table 1. The eye class is considered close to two groupings. This matrix penalty is more subjective than UNETR-CM, but it incorporates domain knowledge.

Fig. 1.
figure 1

Computed matrix penalties (W terms) for both experiments

3 Experiments and Results

3.1 Dataset

This study uses data from a Phase III clinical trial on cognitive training and non-invasive brain stimulation for cognitive improvements. The study recruited participants between 65–89 years old and with age-related cognitive decline. The trial was approved by all relevant Institutional Review Boards. Structural T1-weighted magnetic resonance images (MRIs) were obtained using a 32-channel, receive-only head coil from a 3-T Siemens MAGNETOM Prisma MRI scanner. MPRAGE sequence parameters: repetition time \(=\) 1800 ms; echo time \(=\) 2.26 ms; flip angle \(=\) 8\(^\circ \); field of view \(=\) \(256 \times 256 \times 256\) mm; voxel size \(=\) 1 mm\(^3\).

Ground Truth. Trained staff segmented the T1 MRIs into 11 tissues using semi-automated segmentation. These 11 tissues included muscle, fat, skin, cortical bone, cancellous bone, major artery (blood), air, cerebrospinal fluid (CSF), eyes, grey matter (GM), and white matter (WM). Semi-automated segmentation consists of automated segmentation followed by manual correction. First, base segmentations for WM, GM, and bone were obtained using Headreco, while air was generated in the Statistical Parametric Mapping toolbox (SPM12). Next, these automatic outputs were manually corrected using ScanIP Simpleware™ (version 2018.12, Synopsys, Inc., Mountain View, USA). Bone was separated into cancellous and cortical tissue using thresholding and morphology. Blood, skin, fat, muscle, and eyes (sclera and lens) were manually segmented in Simpleware. CSF was generated by subtracting the other ten tissues from the entire head. The resulting 11 tissue masks served as the ground truths for learned segmentation.

Implementation Details. We implement UNETR using the Medical Open Network for Artificial Intelligence (MONAI-0.8) in Pytorch 1.10.0 [6]. We split our 113 MRIs into 93 training/10 validation/10 testing. Each DNN required 1 GPU, 4 CPUs, and 30 GB of memory. Each model was trained for 25,000 iterations with evaluation at 500 intervals. The models were trained on \(256 \times 256 \times 256\) images with batch sizes of 2 images. We trained our models with Adam optimization using stochastic gradient descent. UNETR segmentation results took 3 s per head. Headreco takes roughly 20 min per head.

3.2 Evaluation Metrics

We employ the following metrics on the 11-class and 6-class segmentation tasks.

Dice. represent the overlap of two binary masks [5]: \(Dice = \frac{2|Y \cap \hat{Y}|}{|Y| + |\hat{Y}|}\) where Y and \(\hat{Y}\) represent the ground truth mask and generated mask for a given tissue, respectively. A perfect overlap between these two generates a Dice score of 1, whereas a 0 represents no mask overlap.

Hausdorff Distance (Hausdorff). calculates the average distances between the closest points in two data subsets [7, 10]. Hausdorff distances are generally more robust than Dice in respect to the precise boundaries.

$$\begin{aligned} H(Y,\hat{Y}) = max(h(Y,\hat{Y}), h(\hat{Y},Y)) \end{aligned}$$
(3)
$$\begin{aligned} h(Y,\hat{Y}) = \max _{y \in Y}(\min _{\hat{y} \in \hat{Y}}(d(y,\hat{y}))),\quad h(\hat{Y},Y) = \max _{\hat{y} \in \hat{Y}}(\min _{y \in Y}(d(\hat{y},y))) \end{aligned}$$
(4)

where y represents a point in Y and \(\hat{y}\) represents a point in \(\hat{Y}\). \(H(Y,\hat{Y})\) is the overall modified Hausdorff distance, whereas \(h(Y,\hat{Y})\) and \(h(\hat{Y},Y)\) are directed Hausdorff distances. \(d(y,\hat{y})\) and \(d(\hat{y},y)\) are Euclidean distances. Smaller the Hausdorff distance indicates better segmentation.

Top-N Accuracy. Top-N accuracy measures how often your true class falls within your top N highest softmax outputs. This metric reflects meaning in the outputs that were not the selected class. For instance, higher Top-2 and Top-3 predictions can show that a well-calibrated makes reasonable mistakes that are supported by the data, rather than random misclassifications.

Calibration Curves. show the relationship between the predicted probability estimates and the true correctness likelihood. These plots are meant for binary classification, so for segmentation one class “positive” is compared to the rest “negative”. The prevalence of positive classes is compared to predicted certainty for that class. Perfect calibration is a straight line from the origin to (1,1).

3.3 Calibrated Models Outperform UNETR-Base on 11-Classes

Qualitative Analysis. Figure 2 shows that UNETR-HC best captures the fine detail of the boundary between GM and CSF. This observation is noticeable in the upper left and upper right “grooves” in the light blue (CSF) color. UNETR-HC attempts to tract out these regions and label them as CSF, whereas the UNETR-Base and UNETR-CM assign more of these pixels as GM. This boundary is a major challenge in automatic segmentation due to partial volume effects.

Table 2. Top-N accuracy on 11 classes

Quantitative Comparison. Figure 3 and Table 2 show the Dice, Hausdorff, and Top-N. UNETR-CM performs best in Dice and Top-N accuracy, whereas UNETR-CM and UNETR-HC outperform UNETR-Base in Hausdorff. Hence, UNETR-CM classifies the most pixels correctly, whereas both models capture tissue boundaries.

Table 3. Top-N accuracy on 6 classes

3.4 Calibrated UNETR Outperforms or Performs Comparably to Headreco in 6-Class Segmentation

Qualitative Analysis. We compare 6-classes because the current field standard in head segmentation (e.g., Headreco) provides different tissues than our method. For example, Headreco [14] uses 8 tissues and SPM uses 6 tissues. Thus, we had to combine tissues into groups for a fair comparison. We combine DOMINO classes that are subsets of Headreco classes; for example, cancellous and cortical bone are both labeled as bone. Figure 4 shows the results for our models and Headreco. Differences are highlighted with white rectangles. Our methods show comparable or superior performance to Headreco across all tissue types.

Fig. 2.
figure 2

Sample image slice for 11-tissue segmentation. The red squares show that UNETR-HC captures the GM - CSF boundary better than other methods (Color figure online)

Fig. 3.
figure 3

(a) Dice scores and (b) Hausdorff distances in 11-class segmentation.

Fig. 4.
figure 4

Sample image slice for 6-tissue segmentation. The white squares highlight important regions where our methods outperformed Headreco

Fig. 5.
figure 5

(a) Dice scores and (b) Hausdorff distances in 6-class segmentation.

Fig. 6.
figure 6

Calibration curves for 6-class problem.

Quantitative Comparison. Figure 5 and Table 3 show the Dice, Hausdorff, and top-1/2/3 accuracy on 6-classes. Calibrated UNETR is comparable to Headreco in WM, GM, and CSF; our models outperform Headreco in Air, Bone, and Soft tissue. UNETR-HC’s Hausdorff shows that the regularization can improve 6-class segmentation without retraining. UNETR-CM performs the best in Top-1/2/3 accuracy. Figure 6 shows that DOMINO achieves better calibration than UNETR-Base. All algorithms are approximately evenly calibrated on GM and air. Our methods are better calibrated than Headreco on WM, CSF, bone, and soft tissue.

4 Conclusions

There is often a trade-off between performance and calibration. This work proposes a novel domain-aware calibration method that improves model calibration, top-N accuracy, and segmentation metrics. The calibrated models perform well on full class and reduced class tasks without retraining. This highly-flexible approach can be applied to widespread medical segmentation. Further, model calibration can help improve cross-talk between automated algorithms and manual labelers. Finally, our calibration can be applied to classification tasks in medical image diagnosis. We will release DOMINO to the community to support open science research.