1 Introduction

In recent years, convolutional neural networks (CNNs) have become the dominant approach to medical image segmentation [14]. A wide variety of CNN models, training procedures, and loss functions have been developed within the BRATS [16] and ISLES [15] competitions. The most common way to measure the performance of a new method is to use voxel-wise segmentation metrics, e.g. the Dice Score [2]. However, in the case of multiple lesions per image, clinical tasks also require analyzing an algorithm in terms of detection quality. For instance, all tumors, including the smallest ones, should be found and delineated in brain stereotactic radiosurgery or in lung cancer screening. But since the Dice Score is a voxel-wise metric, it does not differentiate between missing several true positive voxels in a large lesion and missing them in a small one.

Learning a model in the presence of extremely small targets is challenging, especially in 3D medical image segmentation tasks. The total fraction of voxels containing a lesion is about \(0.1\%\) in the case of lung nodules and about \(1\%\) in the case of multiple brain metastases. Moreover, many medical image segmentation tasks also suffer from size imbalance: large lesions can be up to 50 times bigger than small ones (see the typical lesion diameter distributions in Fig. 2).

Several approaches have been suggested to tackle the problem of target imbalance. The main idea is to add weights to a loss function so that each class (lesion vs. non-lesion, or different lesion types in a multi-class problem) is equally represented. This is implemented, for example, in Weighted Cross-Entropy [18] and Generalized Dice Loss [19]. The shortcoming of this approach is that it accounts only for the lesion type, not the lesion size (see Fig. 1). Besides, most research focuses on delineation quality and lacks an investigation of detection performance. Ideal segmentation implies perfect detection; however, due to the substantial differences between large and small lesions, an almost perfect delineation can still have poor detection quality. Here we address this problem by weighting the loss function with respect to target sizes.

Our contribution is twofold:

  • We propose a loss function reweighting strategy that balances lesions of different sizes. We call our approach inverse weighting, since the generated weights are inversely proportional to the lesion size.

  • We evaluate the effect of the most popular segmentation loss functions on segmentation quality and on the network's ability to detect lesions of different sizes. On a series of medical image segmentation tasks, we show how our approach improves detection quality, especially for small lesions (Fig. 3), while preserving delineation performance.

Fig. 1. The effect of inverse weighting. No reweighting applied (left), class balancing via weighted cross-entropy (center), inverse weighting (right). Weights for every tumor are calculated using the formulas in Table 1 and placed near the tumors.

2 Related Work

A large number of neural network architectures, improved training procedures, and loss functions have been proposed in recent years. We extensively investigate the behavior of loss functions while keeping the rest of the deep learning pipeline at the state-of-the-art level, without diving into its details.

Binary Cross-Entropy (BCE) is the standard loss function commonly used for segmentation tasks. It does not handle class imbalance or differently sized objects and thus often yields poor results. The authors of [13] suggested Focal Loss as an extension of BCE for highly class-imbalanced detection tasks, and it is widely used in segmentation tasks as well [8]. Focal Loss does not apply any explicit reweighting but automatically focuses the network's attention on difficult examples. Dice Loss [17] has recently become one of the state-of-the-art losses for medical image segmentation. Its authors claim that Dice Loss establishes the right balance between classes without assigning any weights. But in tasks with multiple targets, large objects overshadow small ones, hence the network tends to miss small lesions. Recent work [8] proposed the Asymmetric Similarity Loss (ASL) based on the \(F_{\beta }\) score. ASL extends Dice Loss (the special case with \(\beta = 1\)) and allows training a network with a better balance between precision and recall, but it shares the same drawback as Dice Loss: larger objects overshadow smaller ones. The authors of [5] proposed the Sensitivity-Specificity loss, which we leave out of consideration: it performs worse than Dice Loss on a 3D medical image segmentation task in [19] and relies on an idea similar to ASL.

Several approaches reweight BCE and Dice Loss to improve network performance in medical image segmentation tasks. The authors of [18] use the Weighted Cross-Entropy (WCE) loss, and [19] suggests the Generalized Dice Loss (GDL) to tackle class imbalance. Both approaches share the same idea of reweighting the corresponding losses with weights inverse to the class sizes (see Table 1). Our approach simultaneously addresses the class imbalance problem and the imbalance between differently sized objects. A deeper modification of the Cross-Entropy loss to handle class imbalance is evaluated in [11], but with a quite different goal: overfitting on small datasets. The authors of [21] suggest a combination of Cross-Entropy and logarithmic Dice Loss, which depends heavily on hyperparameters, to solve a multiclass (19 classes) segmentation problem. In our work, we show an improvement for both of these losses independently.

We focus our attention on the most relevant loss functions and their explicitly reweighted modifications. Below we detail how our method is applied to state-of-the-art losses and compare it with WCE and GDL.

3 Method

We find that all models tend to miss small targets when trained with BCE or Focal Loss. We attribute this poor performance to the inability of these losses to equally represent differently sized targets. Dice Loss and ASL have the same drawback: large targets overshadow small ones. Moreover, the existing reweighted losses handle only the imbalance between classes, not between lesion sizes. We aim to close this gap and propose a simple methodology to reweight loss functions so that all targets contribute equally, i.e. small targets receive greater weights.

During the training stage, we generate a tensor of weights for every incoming patch. To form such a tensor we split the corresponding ground-truth patch into \(K+1\) connected components \(L_0,\ldots , L_K\), where \(L_0\) is the non-lesion component (background) and K is the number of lesions in the current patch. Next, we assign to every component a weight inverse to the component's volume:

$$\begin{aligned} w_j = \dfrac{\sum _{k=0}^K |L_k|}{\left( K + 1 \right) \cdot |L_j|}, \end{aligned}$$
(1)

where \(w_j\) is the weight assigned to every voxel inside the corresponding component \(L_j\). The constant in the denominator ensures that the sum of the weights equals the sum of a unit tensor of the same size (see derivation details in Supplementary Materials). We call this method inverse weighting (iw). Note how our approach assigns greater weights to smaller tumors (Fig. 1). At this point, we can modify any of the discussed loss functions with our reweighting. Since WCE and GDL already explicitly reweight state-of-the-art losses, we do not apply reweighting twice. The corresponding modifications for BCE, Focal Loss, Dice Loss, and ASL are shown in Table 1.
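To make the weighting concrete, the following is a minimal sketch of Eq. (1) for a single binary ground-truth patch. It is an illustration, not the implementation from our repository: the function name is ours, and we use scipy.ndimage.label for the connected components.

```python
import numpy as np
from scipy import ndimage


def inverse_weights(gt_patch: np.ndarray) -> np.ndarray:
    """Voxel-wise inverse weights (Eq. 1) for a binary ground-truth patch (sketch)."""
    labels, num_lesions = ndimage.label(gt_patch)  # background is labeled 0, lesions 1..K
    total = gt_patch.size                          # sum over all component volumes |L_k|
    weights = np.empty(gt_patch.shape, dtype=np.float32)
    for j in range(num_lesions + 1):               # j = 0 corresponds to the background L_0
        component = labels == j
        # Eq. (1): w_j = (sum_k |L_k|) / ((K + 1) * |L_j|)
        weights[component] = total / ((num_lesions + 1) * component.sum())
    return weights
```

By construction, the weights sum to the patch volume, so the overall scale of the loss is preserved while small components receive proportionally larger per-voxel weights.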

Table 1. Loss functions and their modifications. Here \(y_i\) denotes the \(i^{th}\) element of the ground truth binary mask, \(p_i\) is the corresponding predicted probability, and \(\mathbf {w_i}\) is the proposed inverse weight.
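As an illustration of how such a weight tensor could enter a loss, below is a hedged PyTorch sketch of one natural reading of the inversely weighted BCE row in Table 1; the function name is ours, and the exact implementations of all modified losses are in the repository.

```python
import torch
import torch.nn.functional as F


def iw_bce(logits: torch.Tensor, target: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """BCE where each voxel's loss term is rescaled by its inverse weight (sketch).

    `target` is the float binary mask; `weights` is the tensor from inverse_weights().
    """
    # The `weight` argument of binary_cross_entropy_with_logits rescales the loss
    # element-wise, which is how the per-voxel weight tensor is meant to act here.
    return F.binary_cross_entropy_with_logits(logits, target, weight=weights)
```

We would expect the same tensor to rescale the voxel-wise terms of Focal Loss and the sums inside Dice Loss and ASL analogously.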

4 Experiments

4.1 Data

We report our results on three datasets: two publicly available datasets of 3D CT images, LUNA16 [10] with cancerous lung nodules and LiTS [4] with liver tumors, and one private dataset of MR images with multiple brain metastases.

LUNA16 includes 816 annotated chest scans from the LIDC/IDRI database [1] (we have excluded 72 cases with nodules located outside of the lung masks). For every image, we clip intensities between \(-1000\) and 300 Hounsfield units (HU), and then set the voxels outside the given binary lung masks to \(-1000\). The ground truth mask was formed by averaging the four given annotations.

Metastases (private dataset) includes 1952 unique patients with T1-weighted MRI of the head. We apply no preprocessing steps to these images.

LiTS includes 131 annotated abdominal CT scans. For every image, we clip intensities between \(-300\) and 300 HU and then apply the given binary liver mask in the same way as for the LUNA16 data.

Before passing images through the network, we scale them so that voxel intensities lie between 0 and 1.
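The CT preprocessing described above can be summarized with a short sketch. The window limits follow the text; the function name and the choice of fill value outside the organ mask are our assumptions.

```python
import numpy as np


def preprocess_ct(image: np.ndarray, organ_mask: np.ndarray,
                  hu_min: float, hu_max: float) -> np.ndarray:
    """Clip HU values, blank voxels outside the organ mask, scale to [0, 1] (sketch)."""
    image = np.clip(image, hu_min, hu_max)
    image = np.where(organ_mask > 0, image, hu_min)  # e.g. -1000 HU outside the lung mask
    return (image - hu_min) / (hu_max - hu_min)


# LUNA16: preprocess_ct(ct, lung_mask, -1000, 300)
# LiTS:   preprocess_ct(ct, liver_mask, -300, 300)
```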

Fig. 2. Lesion diameter distributions. Metastases under 5 mm, lung nodules under 10 mm and liver tumors under 12 mm are considered small, according to the clinical recommendations [3, 12].

We use a train-validation setup to compare different architectures and hyperparameters for the loss functions. The merged combination of training and validation data is then used to train the chosen methods, and we report final results on a previously unseen hold-out set. LUNA16 is provided as 10 approximately equal subsets [10]; we use the first 6 for training (534 images), the next 2 for validation (178 images), and the last pair as hold-out (174 images). We divide Metastases into training (1250 images), validation (402 images), and hold-out (300 images). LiTS is also provided as 2 subsets, so we use the first for training (104 images) and the second as hold-out (27 images). We do not split off a validation part of LiTS, since this dataset is used only once for the final results reporting.

4.2 Architecture and Training

For all our experiments we consistently use a single CNN model: a slightly modified 3D U-Net [6]. The architecture, implemented in the PyTorch framework, is available in our repository along with a schematic image. Following the suggestion of [9], we do not focus our attention on fine-tuning the CNN model.

In all scenarios we train the model for 100 epochs, starting with a learning rate of \(10^{-2}\) and reducing it to \(10^{-3}\) at epoch 80. Each epoch consists of 100 iterations of stochastic gradient descent with Nesterov momentum (0.9). At every iteration we sample patches of size \(128 \times 128 \times 128\) with a batch size of two. With probability 0.5 we sample a patch so that it contains at least one voxel with a lesion; otherwise we sample it uniformly. Training takes about 26 hours on a 24 GB NVIDIA Tesla M40 GPU.
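A rough sketch of this 50/50 patch sampling scheme is given below, under the assumption that patches are centered on the sampled voxel; the function and parameter names are ours, not from our repository.

```python
import numpy as np


def sample_patch_center(gt_mask: np.ndarray, patch_size=(128, 128, 128),
                        rng=np.random.default_rng()):
    """Pick a patch center: a lesion voxel with probability 0.5, otherwise uniform (sketch)."""
    lesion_voxels = np.argwhere(gt_mask > 0)
    if len(lesion_voxels) > 0 and rng.random() < 0.5:
        center = lesion_voxels[rng.integers(len(lesion_voxels))]  # guarantees a lesion voxel
    else:
        center = np.array([rng.integers(s) for s in gt_mask.shape])
    # Clamp the center so the patch fits inside the volume
    # (assuming the volume is at least the patch size along every axis).
    half = np.array(patch_size) // 2
    return np.clip(center, half, np.array(gt_mask.shape) - half)
```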

Note that only two of the considered loss functions have hyperparameters: ASL (\(\beta \)) and Focal Loss (\(\gamma , \alpha \)). We use ASL with \(\beta = 1.5\), as originally recommended in [8]. For Focal Loss we also use \(\gamma = 2\), as originally recommended in [13], but set \(\alpha \) to 0.75, chosen on the validation set.

4.3 Metric

The Dice Score has a particular drawback when measuring delineation quality in tasks with multiple lesions per image: a big lesion overshadows small ones. We therefore use the object Dice Score, the average Dice Score over the uniquely found lesions, which does not shift towards larger lesions. Note that we exclude missed lesions from this analysis, hence the delineation quality is independent of the detection quality.
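A minimal sketch of the object Dice Score computation is shown below, assuming a simple overlap-based matching of predicted connected components to ground-truth lesions; it is an illustration rather than the exact evaluation code.

```python
import numpy as np
from scipy import ndimage


def object_dice(pred: np.ndarray, gt: np.ndarray) -> float:
    """Average Dice over found ground-truth lesions; missed lesions are skipped (sketch)."""
    gt_labels, num_gt = ndimage.label(gt)
    pred_labels, _ = ndimage.label(pred)
    scores = []
    for j in range(1, num_gt + 1):
        lesion = gt_labels == j
        touching = np.unique(pred_labels[lesion])
        touching = touching[touching > 0]
        if touching.size == 0:            # missed lesion: excluded from delineation analysis
            continue
        pred_lesion = np.isin(pred_labels, touching)
        inter = np.logical_and(pred_lesion, lesion).sum()
        scores.append(2 * inter / (pred_lesion.sum() + lesion.sum()))
    return float(np.mean(scores)) if scores else float("nan")
```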

To measure detection quality we suggest Free-response Receiver Operating Characteristic (FROC) curve analysis, which handles multiple targets and multiple False Positive (FP) responses per case efficiently [7]. A FROC curve measures the sensitivity to detected objects instead of voxel-wise sensitivity, and therefore does not suffer from the overshadowing of small lesions. It summarizes the model's efficiency as a trade-off between the fraction of lesions detected (Recall) and the average number of FPs per image, but it gives only a visual representation of the experimental results. To compare the performance of different methods we extract a single value from the curves. The authors of [20] suggested using the average object-wise Recall over the predefined FP values (1/8, 1/4, 1/2, 1, 2, 4, 8), which is also the main metric of the LUNA16 challenge [10]. This metric gives the average fraction of detected lesions per case, which is highly interpretable in terms of detection quality.
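This scalar summary can be computed, for example, by interpolating the Recall at the predefined FP-per-image levels; the construction of the FROC curve itself is omitted in this sketch.

```python
import numpy as np

FP_LEVELS = [1 / 8, 1 / 4, 1 / 2, 1, 2, 4, 8]


def average_recall(fp_per_image: np.ndarray, recall: np.ndarray) -> float:
    """Average object-wise Recall at the predefined FP-per-image levels (sketch).

    fp_per_image must be sorted in increasing order for np.interp.
    """
    return float(np.mean(np.interp(FP_LEVELS, fp_per_image, recall)))
```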

To calculate the confidence intervals for the FROC curves and for the average Recall we use bootstrapping: we sample 80% of the test patients and build a curve on each of 100 iterations. The average Recall is calculated for every bootstrapped curve, and we report the mean value along with the standard deviation.
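A minimal bootstrap sketch follows; `build_froc` is a placeholder for constructing a FROC curve on a subset of patients, and drawing the 80% subsample without replacement is our assumption.

```python
import numpy as np

FP_LEVELS = [1 / 8, 1 / 4, 1 / 2, 1, 2, 4, 8]


def bootstrap_average_recall(patients, build_froc, n_iter=100, frac=0.8, seed=0):
    """Mean and std of the average Recall over bootstrapped FROC curves (sketch)."""
    rng = np.random.default_rng(seed)
    values = []
    for _ in range(n_iter):
        idx = rng.choice(len(patients), size=int(frac * len(patients)), replace=False)
        fp, recall = build_froc([patients[i] for i in idx])  # FROC on the 80% subsample
        values.append(np.mean(np.interp(FP_LEVELS, fp, recall)))
    return float(np.mean(values)), float(np.std(values))
```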

4.4 Results and Discussion

Our main result is a considerable improvement of the average object-wise Recall for all four chosen loss functions on all three datasets (Fig. 3). We also report our metrics separately for three groups of lesion sizes and show a solid improvement in small-lesion detection quality, which matches our method's motivation. However, the comparison with WCE deserves a more detailed discussion.

Fig. 3. The impact of inversely weighted loss functions in terms of average Recall and object-wise Dice Score. We show performance on three approximately equal subsets (1/3 each) of lesions divided by their size. Small and medium groups correspond to the clinical recommendations for small lesions (see Fig. 2).

Images from LUNA16 contain 1.3 nodules per scan on average, while Metastases and LiTS have 4.8 and 6.9 tumors per scan respectively. This means that LUNA16 is hardly an appropriate dataset to benefit from our method, since the majority of training patches contain only one lesion. With one lesion per patch, the task reduces to a pure class imbalance problem, and WCE outperforms the other methods in terms of average Recall. Nevertheless, we show that inverse weighting also handles the class imbalance task at a competitive level, solidly improving BCE and Focal Loss performance. Finally, even the slight improvement in the detection quality of WCE comes with a dramatic loss in delineation quality on the other two datasets, which is crucial for clinical tasks.

GDL fails to surpass the inversely weighted loss functions in almost all scenarios. Overall, we find ASL and Dice Loss, along with GDL and their inversely weighted modifications, to be highly stable during training. Moreover, Dice-like loss functions substantially outperform BCE-like losses both in terms of detection and delineation quality. We believe such behaviour comes from two properties of Dice Loss. Firstly, it is designed to optimize the Dice Score metric, and one can clearly see the dominance of Dice-like losses in terms of the object Dice Score (Fig. 3 and Table 2). Secondly, it partially solves the class imbalance problem, but only in cases with exactly one object per patch. The latter is again clearly demonstrated on LUNA16, which, as argued in the previous paragraph, is mostly a class imbalance problem: ASL and Dice Loss already achieve high object Dice Scores and average Recall values on LUNA16, with only minor changes after their reweighting.

However, the loss functions modified with inverse weighting show a noticeable decrease in delineation quality on the LiTS data. We consider this a side effect of the greatly increased object-wise Recall: the modified losses find more difficult cases, hence the joint object Dice Score can decrease.

Besides the separate performance per lesion size, we also include more detailed results for all lesions in the hold-out sets (Table 2). We give a visual representation of the experimental results in terms of detection quality via FROC analysis (see Supplementary Materials).

Table 2. Results for all considered loss functions along with the proposed method, inverse weighting (“\(+\)” with iw, “−” without iw). The numbers in brackets are standard deviations.

5 Conclusion

We propose a universal approach to loss function reweighting that can be used with almost any state-of-the-art loss function. Our experiments demonstrate an improvement in the network's ability to detect lesions for Cross-Entropy, Focal Loss, Dice Loss, and Asymmetric Similarity Loss on three medical tasks with multiple targets per case. Moreover, we believe the method can also improve quality within more complex multi-stage pipelines or with other CNN architectures, which is the goal of our future research.