1 Introduction

Accurate segmentation of the ventricles and myocardium is fundamental to the diagnosis and treatment of myocardial infarction (MI) [17]. Cardiac MRI sequences are commonly used for MI diagnosis: T2-weighted MRI detects damaged and ischemic areas, balanced Steady-State Free Precession (bSSFP) MRI clearly shows the boundaries of heart structures, and late gadolinium enhancement (LGE) MRI highlights infarcted myocardium with distinctive brightness compared to healthy tissue [16]. Manual segmentation is time-consuming, so automatic segmentation is of significant clinical value. Recently, deep learning networks have become a powerful tool for semantic segmentation of heart structures [12, 13]. Naturally, ventricle and myocardium segmentation results can be improved by combining the complementary information from the T2-weighted and bSSFP MRI sequences [16]. To save labeling time, sometimes only the T2-weighted and bSSFP MRI sequences and their corresponding labels are available. However, a well-trained segmentation model may underperform when tested on data from a different modality, a degradation caused by domain shift (as shown in Fig. 1). Fine-tuning on target-domain data is a simple but effective way to alleviate the performance drop, yet it still requires massive data collection and an enormous annotation workload, which is impractical in many real-world medical scenarios. For this reason, constructing a general segmentation model suitable for various modalities is promising yet still challenging.

Fig. 1. Performance drops due to domain shift. (a) Original T2-weighted MRI (Source1). (b) Original bSSFP MRI (Source2). (c) Original LGE MRI (Target). (d) LGE MRI annotation (Label). (e) The segmentation results of LGE MRI using a model trained on T2-weighted and bSSFP MRI data (T-noDA). (f) The segmentation results of LGE MRI using our model trained on T2-weighted and bSSFP MRI data (T-DA). The yellow region denotes the right ventricle, the green region the left ventricle, and the blue region the myocardium. (Color figure online)

Unsupervised Domain Adaptation (UDA) methods have shown compelling results in reducing the dataset shift across distinct domains. Prior efforts on this problem aimed to match the source and target data distributions in order to learn a domain-invariant representation. For example, Maximum Mean Discrepancy (MMD) was introduced to minimize the distance between source and target feature distributions in a Reproducing Kernel Hilbert Space (RKHS) [11]. CycleGAN [15] tackled the image-to-image translation task in a fully unsupervised manner and is thus capable of reducing the domain shift at the pixel level. AdaptSegNet [10] solved the unsupervised cross-domain segmentation problem by leveraging domain adversarial training. In the context of medical imaging, [3] developed a UDA framework based on adversarial networks for lung segmentation on chest X-rays. [8] improved the UDA framework with a Siamese architecture for Gleason grading of histopathology tissue. [5] proposed a domain critic module and a domain adaptation module for the unsupervised cross-modality adaptation problem. These approaches, which are based on domain adversarial training, require empirical feature selection. [2] proposed the synergistic fusion of adaptations from both the image and feature perspectives for heart structure segmentation. However, this approach, being based on image-to-image adaptation, cannot be directly applied to multi-source domain adaptation problems due to the presence of multiple domain shifts between the different source domains.

In this paper, we propose a domain alignment method for the UDA problem, which helps the trained model segment the ventricles and myocardium accurately in the target domain without requiring target labels. Firstly, to reduce the domain shift in image appearance, we apply a histogram matching operation to all the data. Secondly, we introduce domain adversarial training in the output space, which directly aligns the predicted segmentation results across different domains. Finally, we further propose a group-wise feature recalibration module (GFRM) that improves the domain adversarial training by integrating multi-level features without manual selection, progressively aligning the source and target feature distributions. The proposed method is extensively evaluated on the Multi-sequence Cardiac MR Segmentation (MS-CMRSeg) Challenge 2019 dataset, which includes bSSFP, LGE and T2-weighted MRI sequences.
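For concreteness, the following is a minimal sketch of the histogram matching step; the use of scikit-image's `match_histograms` and the choice of a single bSSFP slice as the shared intensity reference are our assumptions, since the paper does not specify the implementation.

```python
# Minimal sketch of the histogram matching pre-processing step.
# Assumptions: scikit-image provides the matching, and one bSSFP slice
# serves as the shared intensity reference for all sequences.
import numpy as np
from skimage.exposure import match_histograms

def harmonize(slices, reference):
    """Match the intensity histogram of each 2D slice to `reference`."""
    return [match_histograms(s.astype(np.float32), reference.astype(np.float32))
            for s in slices]

# Usage: pick one reference slice, then map every T2-weighted, bSSFP and
# LGE slice onto its histogram before training.
# reference   = bssfp_slices[0]
# t2_matched  = harmonize(t2_slices, reference)
# lge_matched = harmonize(lge_slices, reference)
```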

Fig. 2. Schematic view of our proposed framework.

2 Method

Figure 2 overviews our segmentation method for the ventricles and myocardium in MRI sequences. We use a modified 2D attention U-Net with a pyramid pooling module as our segmentation backbone [7, 14]. To reduce the distance over the feature and output spaces across different domains, feature-level and mask-level discriminators are adopted. Moreover, the group-wise feature recalibration module (GFRM) is introduced to transfer multi-level feature information. The details of these modules are shown in Fig. 3.

2.1 Network Architecture

Segmentation Network. It is essential to build upon a good baseline model to achieve high-quality segmentation results. Our segmentation network follows the spirit of the attention U-Net architecture [7]. In the encoder, we keep the convolution layers as in the original setting and perform three max-pooling operations in total. Dilated convolution is adopted after the third max-pooling operation to capture a large receptive field and alleviate the loss of structural information. Inspired by [14], a pyramid pooling module is introduced to generate multi-scale features that accommodate the variance in heart size across patients. In the decoder, we perform three deconvolution operations in total. To further refine the segmentation, an attention gate (the black dot in Fig. 3(a)) is utilized to learn to focus on the ventricle and myocardium structures. In the attention gate, the features from the encoder (the blue rectangle in Fig. 3(a)) and the decoder (the gray rectangle in Fig. 3(a)) are first squeezed along the channel direction with separate \(3\times 3\) convolution layers and then added together. After that, we squeeze the result to a single-channel attention map with a \(1\times 1\) convolution layer and generate the final feature maps by element-wise multiplication. Finally, we use a \(1\times 1\) convolution layer with four output channels followed by a sigmoid activation to generate the probability maps. To save computational resources, we share the same network parameters between the source and target domains.
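A minimal PyTorch sketch of the attention gate described above is given below; the intermediate channel width, the ReLU between the two squeezes, and the choice to reweight the encoder (skip) features rather than the decoder features are our assumptions.

```python
import torch
import torch.nn as nn

class AttentionGate(nn.Module):
    """Sketch of the attention gate: encoder and decoder features are each
    squeezed by a 3x3 convolution, summed, reduced to a single-channel
    attention map by a 1x1 convolution, and used to reweight the skip
    features (channel widths here are illustrative assumptions)."""
    def __init__(self, enc_ch, dec_ch, inter_ch=64):
        super().__init__()
        self.squeeze_enc = nn.Conv2d(enc_ch, inter_ch, kernel_size=3, padding=1)
        self.squeeze_dec = nn.Conv2d(dec_ch, inter_ch, kernel_size=3, padding=1)
        self.to_map = nn.Conv2d(inter_ch, 1, kernel_size=1)

    def forward(self, enc_feat, dec_feat):
        # squeeze both feature sets along the channel direction, then add
        mixed = torch.relu(self.squeeze_enc(enc_feat) + self.squeeze_dec(dec_feat))
        # single-channel structure attention map
        attn = torch.sigmoid(self.to_map(mixed))
        # reweight the encoder (skip) features with the attention map
        return enc_feat * attn
```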

Fig. 3. Architecture of the sub-networks in our framework. (Color figure online)

Group-wise Feature Recalibration Module. Before the group-wise feature recalibration operation, the different-sized features from the segmentation network above are upsampled to a common resolution and concatenated. The resulting features are sent to the GFRM. Our GFRM follows the spirit of [9]. Differently from that method, we divide the features into four groups corresponding to the segmentation categories, so as to focus on specific heart structures, and recalibrate the features in each group (as shown in Fig. 3(b)). The GFRM consists of two parts: a channel attention part and a spatial attention part. In the channel attention part, we first squeeze the global spatial information with global average pooling and a fully connected layer, and then generate the channel-wise attention features by element-wise multiplication. In the spatial attention part, we first squeeze the channel information with a \(1\times 1\) convolution layer, and then obtain the spatial-wise attention features in the same multiplicative manner. The features from the channel attention part are added to those from the spatial attention part to produce the group-wise recalibrated features. Finally, the features from all groups are concatenated to generate the final recalibrated features.
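The following PyTorch sketch illustrates this structure, with four groups and per-group channel and spatial attention; the reduction ratio of the fully connected squeeze and the exact layer ordering are our assumptions.

```python
import torch
import torch.nn as nn

class GFRM(nn.Module):
    """Sketch of group-wise feature recalibration: the input channels are
    split into four groups (one per segmentation category), each group is
    recalibrated by channel attention plus spatial attention, and the
    groups are concatenated back together."""
    def __init__(self, channels, groups=4, reduction=4):
        super().__init__()
        assert channels % groups == 0
        gc = channels // groups  # channels per group
        self.groups = groups
        # channel attention: global average pooling + FC (as 1x1 convs) + sigmoid
        self.channel_attn = nn.ModuleList(
            nn.Sequential(nn.AdaptiveAvgPool2d(1),
                          nn.Conv2d(gc, gc // reduction, 1),
                          nn.ReLU(inplace=True),
                          nn.Conv2d(gc // reduction, gc, 1),
                          nn.Sigmoid())
            for _ in range(groups))
        # spatial attention: 1x1 conv squeezing channels + sigmoid
        self.spatial_attn = nn.ModuleList(
            nn.Sequential(nn.Conv2d(gc, 1, 1), nn.Sigmoid())
            for _ in range(groups))

    def forward(self, x):
        out = []
        for f, ca, sa in zip(torch.chunk(x, self.groups, dim=1),
                             self.channel_attn, self.spatial_attn):
            out.append(f * ca(f) + f * sa(f))  # recalibrated group features
        return torch.cat(out, dim=1)
```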

Discriminator. The feature-level and mask-level discriminators operate on the multi-level features from the GFRM and on the predicted masks, respectively. We use PatchGAN as our discriminator [6]. The network consists of three convolution layers with a stride of 2 followed by two convolution layers with a stride of 1. The kernel size of all convolution layers is \(4\times 4\) and the corresponding channel numbers are 64, 128, 256, 256 and 1. Except for the last layer, each convolution layer is followed by a leaky ReLU parameterized by 0.2.
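This layer configuration translates directly into the following sketch; the padding value is an assumption, as the text does not state it.

```python
import torch.nn as nn

def make_patchgan(in_ch):
    """PatchGAN discriminator as specified: five 4x4 convolutions with
    channels 64, 128, 256, 256, 1; the first three use stride 2, the last
    two stride 1; LeakyReLU(0.2) follows every layer but the final one."""
    chs = [in_ch, 64, 128, 256, 256, 1]
    strides = [2, 2, 2, 1, 1]
    layers = []
    for i, stride in enumerate(strides):
        layers.append(nn.Conv2d(chs[i], chs[i + 1], kernel_size=4,
                                stride=stride, padding=1))
        if i < len(strides) - 1:  # no activation after the last layer
            layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)

# mask_discriminator = make_patchgan(in_ch=4)  # four probability maps
```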

2.2 Hybrid Loss Function for Source Data

Since labels are available for the source domain, we train the segmentation network with a hybrid loss. With our unbalanced training data, the vanilla cross-entropy loss alone leads to low accuracy, so we add the Jaccard loss [1] to the loss function. The training objective for source data is

$$\begin{aligned} \mathcal L_{ce}^{s}=-\mathbb {E}_{x_{s}\sim S}(\sum _{i=1}^{N_{s}}\sum _{c=1}^{C}y_{s,i,c}\log G(x_{s,i};\varTheta _{g})) \end{aligned}$$
(1)
$$\begin{aligned} \mathcal L_{jac}^{s}=-\mathbb {E}_{x_{s}\sim S}(\sum _{i=1}^{N_{s}}\sum _{c=1}^{C}\frac{y_{s,i,c} G(x_{s,i};\varTheta _{g})}{y_{s,i,c}+G(x_{s,i};\varTheta _{g})-y_{s,i,c}G(x_{s,i};\varTheta _{g})}) \end{aligned}$$
(2)

S represents the source domain; for each source image \(x_{s}\), there is one corresponding annotation \(y_{s}\); \(N_s\) is the number of source images; \(\mathbb {E}_{x_{s}\sim S}\) means that all \(x_{s}\) are drawn from S; C is the number of categories; G is the segmentation network and \(\varTheta _{g}\) its parameters; \(y_{s,i,c}\) and \(G(x_{s,i};\varTheta _{g})\) denote the annotation and prediction vectors, respectively. With the cross-entropy loss alone, the imbalance of the training data drives gradient descent in an inappropriate direction toward a poor local optimum, especially in the early training stage. The Jaccard loss effectively helps to avoid this local optimum thanks to its better perceptual quality and scale invariance [1].
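A runnable sketch of Eqs. (1)-(2) is shown below; it aggregates the Jaccard ratio over pixels (a common soft-Jaccard variant), adds a small epsilon for numerical stability, and defaults to the weights \(\lambda_{ce}=\lambda_{jac}=0.5\) reported in Sect. 3 — all of which are implementation assumptions.

```python
import torch

def hybrid_loss(pred, target, w_ce=0.5, w_jac=0.5, eps=1e-7):
    """Hybrid source loss of Eqs. (1)-(2). `pred` holds per-class
    probabilities (B, C, H, W) from the sigmoid output; `target` is the
    one-hot annotation of the same shape."""
    # Eq. (1): cross-entropy over classes, averaged over pixels and batch
    ce = -(target * torch.log(pred.clamp(min=eps))).sum(dim=1).mean()
    # Eq. (2): negative soft Jaccard, aggregated per class over all pixels
    inter = (target * pred).sum(dim=(0, 2, 3))
    union = (target + pred - target * pred).sum(dim=(0, 2, 3))
    jac = -(inter / (union + eps)).mean()
    return w_ce * ce + w_jac * jac
```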

2.3 Adversarial Learning for Target Data

In the target domain, due to the lack of annotations, we leverage adversarial learning to train the segmentation network by minimizing the discrepancy between the source and target domains. Domain adaptation over both the feature and output spaces has proved effective for heart structure segmentation [4]. In our framework, we employ two discriminators. In [4], the features fed to the feature-domain discriminator are selected empirically. To overcome this limitation, we propose the GFRM to leverage the full feature spectrum and automatically select the prominent features in the feature space. In the segmentation network, each feature scale generates one output feature map of the same dimension via convolution and upsampling operations. These feature maps are further processed by the GFRM to highlight the prominent features and suppress the irrelevant ones. The combined feature maps are then fed to the feature discriminator for adversarial learning, where the losses are defined as

$$\begin{aligned} \begin{aligned} \mathcal L_{adv_{D_f}}=&-\mathbb {E}_{x_{s}\sim S}\log D_f(R(G(x_{s};\varTheta _{g});\varTheta _{r});\varTheta _{d_f})\\&-\mathbb {E}_{x_{t}\sim T}\log (1-D_f(R(G(x_{t};\varTheta _{g});\varTheta _{r});\varTheta _{d_f})) \end{aligned} \end{aligned}$$
(3)
$$\begin{aligned} \mathcal L_{adv_{G_f}}=-\mathbb {E}_{x_{t}\sim T}\log D_f(R(G(x_{t};\varTheta _{g});\varTheta _{r});\varTheta _{d_f}) \end{aligned}$$
(4)

T represents the target domain; \(x_{t}\) is target data; \(\mathbb {E}_{x_{t}\sim T}\) means that all \(x_{t}\) are drawn from T; R is the GFRM and \(\varTheta _{r}\) its parameters; \(D_f\) is the feature discriminator and \(\varTheta _{d_f}\) its parameters.

In the output space, the segmentation results of the target domain should be similar to those of the source domain. To achieve this, we employ the same adversarial learning technique in the output space, where the losses are defined as

$$\begin{aligned} \begin{aligned} \mathcal L_{adv_{D_m}}=&-\mathbb {E}_{x_{s}\sim S}\log D_m(G(x_{s};\varTheta _{g});\varTheta _{d_m})\\&-\mathbb {E}_{x_{t}\sim T}\log (1-D_m(G(x_{t};\varTheta _{g});\varTheta _{d_m})) \end{aligned} \end{aligned}$$
(5)
$$\begin{aligned} \mathcal L_{adv_{G_m}}=-\mathbb {E}_{x_{t}\sim T}\log D_m(G(x_{t};\varTheta _{g});\varTheta _{d_m}) \end{aligned}$$
(6)

where \(D_m\) is the mask discriminator; \(\varTheta _{d_m}\) is the parameters of \(D_m\).

Combining the aforementioned losses, the full objective function is

$$\begin{aligned} \begin{aligned} \mathcal L_{FULL}=&\lambda _{ce}\mathcal L_{ce}^{s}+\lambda _{jac}\mathcal L_{jac}^{s}+\lambda _{D_f}\mathcal L_{adv_{D_f}}\\&+\,\lambda _{G_f}\mathcal L_{adv_{G_f}}+\lambda _{D_m}\mathcal L_{adv_{D_m}}+\lambda _{G_m}\mathcal L_{adv_{G_m}} \end{aligned} \end{aligned}$$
(7)
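To make the alternating optimization of Eq. (7) concrete, here is a sketch of one training step; the module and optimizer names, the assumption that the segmenter returns (multi-level features, probability maps), the BCE-with-logits form of the adversarial terms, and the reuse of `hybrid_loss` from the Sect. 2.2 sketch are all our assumptions rather than the authors' exact code.

```python
import torch
import torch.nn.functional as F

def bce(logits, is_real):
    """Adversarial BCE on raw discriminator logits."""
    target = torch.ones_like(logits) if is_real else torch.zeros_like(logits)
    return F.binary_cross_entropy_with_logits(logits, target)

def train_step(seg, gfrm, d_f, d_m, opt_seg, opt_d, x_s, y_s, x_t, lam):
    """One alternating step of Eq. (7); `lam` holds the six weights."""
    # --- update segmentation network + GFRM (Eqs. 1, 2, 4, 6) ---
    opt_seg.zero_grad()
    f_s, p_s = seg(x_s)
    f_t, p_t = seg(x_t)
    sup = hybrid_loss(p_s, y_s, lam['ce'], lam['jac'])  # Sect. 2.2 sketch
    adv_g = (lam['G_f'] * bce(d_f(gfrm(f_t)), is_real=True) +
             lam['G_m'] * bce(d_m(p_t), is_real=True))  # fool both discriminators
    (sup + adv_g).backward()
    opt_seg.step()

    # --- update the two discriminators (Eqs. 3, 5) ---
    opt_d.zero_grad()
    with torch.no_grad():  # freeze the segmenter while training D
        f_s, p_s = seg(x_s)
        f_t, p_t = seg(x_t)
        r_s, r_t = gfrm(f_s), gfrm(f_t)
    d_loss = (lam['D_f'] * (bce(d_f(r_s), True) + bce(d_f(r_t), False)) +
              lam['D_m'] * (bce(d_m(p_s), True) + bce(d_m(p_t), False)))
    d_loss.backward()
    opt_d.step()
```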

3 Experiment

Dataset. The proposed method is validated on the MS-CMRSeg Challenge 2019 dataset, which covers 45 patients. Each patient's data contains bSSFP, T2-weighted and LGE MRI sequences; within one patient, the slice numbers and annotations of the three modalities differ. We combine the labeled bSSFP and T2-weighted MRI sequences as the source data and use the unlabeled LGE MRI sequences as the target data. Experienced experts manually annotated the left ventricle (LV), right ventricle (RV) and myocardium (Myo) as ground truth. We pre-process the data for domain adaptation: each slice is resized and center-cropped to \(400\times 400\). To eliminate the inconsistency in appearance, we perform the histogram matching operation on both the source and target data, as shown in Fig. 4.

Fig. 4. Visual comparison for the histogram matching operation: (a) T2-weighted MRI. (b) T2-weighted MRI after histogram matching. (c) bSSFP MRI. (d) bSSFP MRI after histogram matching. (e) LGE MRI. (f) LGE MRI after histogram matching.

Implementation Details. We implement the whole network in PyTorch on a standard PC with a single NVIDIA 1080Ti GPU. To train the segmentation network, we use the Stochastic Gradient Descent (SGD) optimizer with Nesterov acceleration, a momentum of 0.9 and a weight decay of \(1\times 10^{-4}\). The initial learning rate is set to 0.01 and decreased to 0.001 after 80 epochs. To train both the feature and mask discriminators, we use the Adam optimizer with a fixed learning rate of 0.0002 and a weight decay of \(5\times 10^{-5}\). We train for 150 epochs in total with a mini-batch size of 8, and set \(\lambda _{ce}\), \(\lambda _{jac}\), \(\lambda _{G_f}\), \(\lambda _{D_f}\), \(\lambda _{G_m}\) and \(\lambda _{D_m}\) to 0.5, 0.5, 0.05, 1.0, 0.005 and 1.0, respectively. Training takes only 5 h to converge.
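These optimizer settings correspond to the following sketch; grouping both discriminators under a single Adam optimizer and using `StepLR` for the learning-rate drop are our assumptions.

```python
import torch

def make_optimizers(seg_net, gfrm, d_f, d_m):
    """Optimizers as stated in the text: SGD with Nesterov momentum for the
    segmenter (lr 0.01 -> 0.001 after 80 epochs), Adam with a fixed
    learning rate of 0.0002 for the discriminators."""
    opt_seg = torch.optim.SGD(
        list(seg_net.parameters()) + list(gfrm.parameters()),
        lr=0.01, momentum=0.9, nesterov=True, weight_decay=1e-4)
    opt_d = torch.optim.Adam(
        list(d_f.parameters()) + list(d_m.parameters()),
        lr=2e-4, weight_decay=5e-5)
    # drop the segmenter's learning rate by 10x at epoch 80
    sched = torch.optim.lr_scheduler.StepLR(opt_seg, step_size=80, gamma=0.1)
    return opt_seg, opt_d, sched
```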

Fig. 5. Visual comparison of the LV, RV, and Myo segmentation results in the ablation study. (a) Original image from source domain. (b) Annotation. (c) S2T. (d) S2T+HM. (e) S2T+HM+MDA. (f) S2T+HM+MDA+FDA. (g) S2T+HM+MDA+FDA+GFRM.

Quantitative and Qualitative Analysis. To verify the effectiveness of the proposed method, we adopt the Dice coefficient (DSC) and the Jaccard coefficient (Jac) as evaluation metrics. We first train the segmentation network on the source data and then test it on the target data (S2T). The results in Table 1 show that the mean Dice of S2T is very low. Our method improves DSC by about \(36.09\%\) and Jac by about \(38.38\%\) over S2T, which indicates that it can alleviate the dataset shift across different domains.
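For reference, a standard computation of the two metrics on binary masks looks as follows (a common formulation; the challenge's exact evaluation code may differ).

```python
import numpy as np

def dice_jaccard(pred, gt, eps=1e-7):
    """Dice (DSC) and Jaccard (Jac) coefficients for one binary structure."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    dsc = 2.0 * inter / (pred.sum() + gt.sum() + eps)
    jac = inter / (np.logical_or(pred, gt).sum() + eps)
    return dsc, jac
```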

In addition, we examine the effect of the histogram matching operation (HM), mask-level adversarial learning (MDA), feature-level adversarial learning (FDA) and the GFRM on the performance in the target domain. The ablation study in Table 1 shows that each proposed module achieves a better performance than S2T, and Fig. 5 demonstrates that each module contributes to alleviating the domain misalignment.

Table 1. Quantitative evaluation of our proposed methods

4 Conclusion

In this paper, we proposed an unsupervised domain alignment method for left ventricle (LV), right ventricle (RV) and myocardium (Myo) segmentation from different cardiac MR sequences. We first introduced a segmentation network with a hybrid segmentation loss to generate accurate predictions. We then alleviated the dataset shift across domains by leveraging adversarial learning in both the feature and output spaces. The proposed GFRM enforces fine-grained semantic-level feature alignment, matching features from different domains that share the same class label. Experiments show that the proposed method achieves competitive results.