
1 Introduction

The spine, which serves as the central axis of the skeleton, plays a vital role in protecting essential organs, blood vessels, and nerves [1]. As the population ages, the incidence of spinal disorders has increased significantly. In computer-aided diagnosis and treatment of spine-related diseases, multi-label segmentation of vertebrae and intervertebral discs in volumetric magnetic resonance (MR) images is of critical importance. Accurate segmentation of the spinal region, as depicted in Fig. 1, allows clinicians to assess the structure and health of vertebrae and intervertebral discs, thereby facilitating early detection, diagnosis, and surgical planning for spinal conditions such as deformities, traumas, tumors, and fractures.

Fig. 1.

The task is the multi-label segmentation of vertebrae and intervertebral discs in volumetric MR images, with 10 labels for vertebrae and 9 for intervertebral discs. The labels cover vertebrae in the thoracic (T), lumbar (L), and sacral (S) regions.
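To make the class layout concrete, the following Python sketch enumerates the 20 classes (background, 10 vertebrae T9-S, and 9 IVDs), naming each IVD after the two vertebrae it separates, as in the T9T10-L5S notation of the ablation tables. The integer IDs produced by this ordering are an assumption for illustration only and are not necessarily the IDs used in the MRSpineSeg annotation files.

```python
# Vertebra classes from top (T9) to bottom (S).
VERTEBRAE = ["T9", "T10", "T11", "T12", "L1", "L2", "L3", "L4", "L5", "S"]


def build_label_names():
    """Return class names ordered as: background, vertebrae, then IVDs.

    The ordering (and hence the integer IDs) is a hypothetical choice.
    """
    # Each IVD lies between two adjacent vertebrae, e.g. "T9T10", ..., "L5S".
    ivds = [upper + lower for upper, lower in zip(VERTEBRAE[:-1], VERTEBRAE[1:])]
    return ["BG"] + VERTEBRAE + ivds


LABELS = build_label_names()
LABEL_TO_ID = {name: idx for idx, name in enumerate(LABELS)}
```

With this ordering, `len(LABELS)` is 20, `LABELS[11]` is `"T9T10"`, and the last class is `"L5S"`.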

With the progress of artificial intelligence, contemporary spinal segmentation techniques for medical images are built predominantly on two strategies. 1) Traditional machine learning methods. Bao et al. [2] employed a linear iterative clustering algorithm to obtain superpixel representations of spinal MR images, which enabled the subsequent segmentation of the spinal region. Viji et al. [3] combined a probabilistic boosting tree (PBT) with fuzzy support vector machine segmentation to detect the spinal canal automatically. 2) Deep learning methods. In contrast to conventional methods, deep learning techniques have demonstrated remarkable efficacy in spinal segmentation.

Fig. 2.

Our proposed segmentation network consists of two stages, namely 3D coarse segmentation and 2D refinement segmentation.

In particular, convolutional neural networks (CNNs) [4,5,6,7,8] have been widely adopted and have yielded significant advances in spinal MR image segmentation, with fully convolutional neural networks (FCNNs) [4, 5] and U-Net [6, 7] playing a prominent role. However, the effectiveness of FCNNs is limited by the restricted receptive field of convolutional layers, which impedes the model’s ability to capture long-range spatial correlations. Moreover, despite the increasing diversity of models employed in spinal segmentation, they often overlook the distinctive chain structure of the spine: its holistic architecture, the long-range dependencies between vertebrae, and the structural interdependencies among neighboring vertebrae and intervertebral discs. Furthermore, the substantial computational and memory requirements of these methods limit their adaptability to diverse spinal segmentation scenarios.

Our work presents the following main contributions:

  1. We propose a novel two-stage network architecture for the segmentation of biomedical 3D MR images. It integrates a 3D Transformer for coarse segmentation, which captures long-distance dependencies, with a 2D CNN for fine segmentation, which effectively captures local high-level features.

  2. Combining 3D and 2D networks enables our model to assimilate a broader range of feature information from images of different dimensionality, enhancing its ability to learn diverse representations.

  3. To further improve the segmentation performance of the two-stage network, we introduce graph convolution modules into both the 3D and 2D networks. These modules exploit spatial relationships among spinal structures, leading to improved segmentation outcomes.

Fig. 3.

Overview of the architecture of the 3D coarse segmentation network. The input to the first segmentation stage consists of 3D multi-modal MRI images with 4 channels. The encoded feature representations of the Swin Transformer are passed to a CNN decoder via skip connections at multiple resolutions. The final segmentation output comprises 3 output channels.

2 Methods

2.1 Overall Architecture Design

We present an innovative two-stage methodology for multi-class segmentation. Specifically, we introduce a U-shaped 3D coarse segmentation network built on Transformers for the first stage, followed by a refinement segmentation network based on DeepLabv3+ for the second stage. The 3D coarse segmentation network uses Swin Transformers as the encoder, which is connected to FCNN-based decoders via skip connections, and the decoder generates probability maps for the coarse segmentation task. In the refinement stage, the volumetric MR image and the probability map produced by the 3D coarse segmentation network serve as inputs to the 2D refinement segmentation network, which produces more precise and detailed segmentation results. The two-stage network is tailored to multi-category segmentation of vertebrae and intervertebral discs in volumetric MR images. Figure 2 gives an overview of the network architecture.

2.2 3D Coarse Segmentation Stage

Inspired by the effectiveness of the U-shaped network architecture, we build a U-shaped 3D coarse segmentation network on the Swin Transformer for the coarse segmentation stage. Its structure is illustrated in Fig. 3.

Fig. 4.

Architecture of the 2D refinement segmentation network. Its inputs are a 2D MR sagittal slice and the corresponding coarse probability map produced by the 3D coarse segmentation network.

Our coarse segmentation network follows a contracting-expanding pattern, with a stack of Swin Transformer blocks as the encoder connected to the decoder through skip connections. For an input \(X \in \mathbb{R}^{H \times W \times D \times S}\) and a patch resolution of \((\hat{H},\hat{W},\hat{D})\), each patch has dimension \(\hat{H} \times \hat{W} \times \hat{D} \times S\). A patch partition layer splits the input into a 3D token sequence of size \(\lceil \frac{H}{\hat{H}} \rceil \times \lceil \frac{W}{\hat{W}} \rceil \times \lceil \frac{D}{\hat{D}} \rceil\) and projects each token into an embedding space of dimension C.
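The patch partition and linear embedding can be sketched in NumPy as follows. This is a minimal illustration, not the authors' implementation: it assumes zero padding when a spatial size is not divisible by the patch size (which is where the ceilings come from), and the projection weight `weight` is a plain matrix standing in for the learned embedding layer.

```python
import numpy as np


def patch_partition(volume, patch, weight):
    """Split a (H, W, D, S) volume into non-overlapping patches and embed them.

    volume: array of shape (H, W, D, S)
    patch:  patch resolution (ph, pw, pd)
    weight: projection matrix of shape (ph * pw * pd * S, C)
    Returns tokens of shape (ceil(H/ph), ceil(W/pw), ceil(D/pd), C).
    """
    ph, pw, pd = patch
    H, W, D, S = volume.shape
    # Number of patches per axis, rounded up (assumed zero padding at the far end).
    gh, gw, gd = -(-H // ph), -(-W // pw), -(-D // pd)
    padded = np.zeros((gh * ph, gw * pw, gd * pd, S), dtype=volume.dtype)
    padded[:H, :W, :D] = volume
    # Group voxels into a grid of patches, then flatten each patch to one vector.
    x = padded.reshape(gh, ph, gw, pw, gd, pd, S)
    x = x.transpose(0, 2, 4, 1, 3, 5, 6).reshape(gh, gw, gd, ph * pw * pd * S)
    # Linear projection of every flattened patch into the C-dimensional space.
    return x @ weight
```

For example, a (5, 4, 3, 1) volume with patch size (2, 2, 2) and C = 4 yields a token grid of shape (3, 2, 2, 4).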

To capture token interactions efficiently, self-attention is computed within non-overlapping windows created in the partitioning step. At layer l of the transformer encoder, windows of size \(M \times M \times M\) evenly divide the 3D token sequence into \(\lceil \frac{\hat{H}}{M} \rceil \times \lceil \frac{\hat{W}}{M} \rceil \times \lceil \frac{\hat{D}}{M} \rceil\) regions. In layer l+1, the window partitions are shifted by \((\lfloor \frac{M}{2} \rfloor, \lfloor \frac{M}{2} \rfloor, \lfloor \frac{M}{2} \rfloor)\) voxels. Instead of the conventional global multi-head self-attention (MSA), the Swin Transformer thus restricts self-attention to non-overlapping local windows and alternates regular and shifted window partitions. This keeps computation efficient, while the shifted windows create connections between neighboring windows, so that dependencies can propagate across the entire token sequence as layers are stacked.
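The window partition and the shift between layers l and l+1 can be sketched in NumPy as below. Following the cyclic-shift trick of the Swin Transformer reference implementation, the shift is realized by rolling the token grid by \(\lfloor M/2 \rfloor\) along each spatial axis before partitioning; the attention mask that stops wrapped-around tokens from attending to each other is omitted here, and the sketch assumes each grid size is already a multiple of M.

```python
import numpy as np


def window_partition(x, M):
    """Split a token grid (Hp, Wp, Dp, C) into non-overlapping M x M x M windows.

    Returns an array of shape (num_windows, M**3, C). Assumes each spatial size
    is a multiple of M (pad beforehand otherwise).
    """
    Hp, Wp, Dp, C = x.shape
    x = x.reshape(Hp // M, M, Wp // M, M, Dp // M, M, C)
    x = x.transpose(0, 2, 4, 1, 3, 5, 6)  # window indices first, then in-window offsets
    return x.reshape(-1, M ** 3, C)


def shift_tokens(x, M):
    """Cyclically shift the grid by floor(M/2) tokens along each spatial axis.

    Partitioning the shifted grid with window_partition gives the shifted windows
    of layer l+1, so tokens near a window border in layer l now share a window.
    """
    s = M // 2
    return np.roll(x, shift=(-s, -s, -s), axis=(0, 1, 2))
```

For an 8 x 8 x 8 grid and M = 4, the first shifted window covers the tokens that sat at positions 2 to 5 along every axis in the unshifted grid, straddling the boundaries of the original windows.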

Each Swin Transformer block consists of a window-based multi-head self-attention module (W-MSA, or SW-MSA with shifted windows) followed by a two-layer MLP with a Gaussian Error Linear Unit (GELU) nonlinearity in between. A LayerNorm (LN) layer is applied before each MSA and MLP module, and a residual connection is applied after each module. Consecutive blocks alternate between regular and shifted window partitions, which keeps the computation efficient. Two consecutive Swin Transformer blocks are computed as follows:

$$\begin{aligned} \hat{Z}^{l}=\text{W-MSA}(\text{LN}({Z}^{l-1}))+{Z}^{l-1} \end{aligned}$$
(1)
$$\begin{aligned} {Z}^{l}=\text{MLP}(\text{LN}(\hat{Z}^{l}))+\hat{Z}^{l} \end{aligned}$$
(2)
$$\begin{aligned} \hat{Z}^{l+1}=\text{SW-MSA}(\text{LN}({Z}^{l}))+{Z}^{l} \end{aligned}$$
(3)
$$\begin{aligned} {Z}^{l+1}=\text{MLP}(\text{LN}(\hat{Z}^{l+1}))+\hat{Z}^{l+1} \end{aligned}$$
(4)
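The residual structure of a pair of consecutive blocks can be sketched in NumPy as follows, using the standard Swin Transformer formulation in which the SW-MSA branch takes \(Z^{l}\) as its input. The attention modules are passed in as callables (`w_msa`, `sw_msa`, hypothetical stand-ins for real window attention), and LayerNorm is shown without learnable scale and shift to keep the sketch short.

```python
import numpy as np


def layer_norm(z, eps=1e-5):
    """LayerNorm over the channel axis (no learnable scale/shift in this sketch)."""
    mu = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return (z - mu) / np.sqrt(var + eps)


def gelu(x):
    """tanh approximation of the GELU nonlinearity."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def mlp(z, w1, w2):
    """Two-layer MLP with a GELU nonlinearity between the two linear layers."""
    return gelu(z @ w1) @ w2


def swin_block_pair(z_prev, w_msa, sw_msa, mlp1, mlp2):
    """Apply a W-MSA block followed by an SW-MSA block (pre-norm + residuals).

    z_prev: tokens of shape (N, C); w_msa / sw_msa: callables (N, C) -> (N, C);
    mlp1 / mlp2: (w1, w2) weight pairs for the MLP of each block.
    """
    z_hat = w_msa(layer_norm(z_prev)) + z_prev                # Eq. (1)
    z = mlp(layer_norm(z_hat), *mlp1) + z_hat                  # Eq. (2)
    z_hat_next = sw_msa(layer_norm(z)) + z                     # Eq. (3)
    return mlp(layer_norm(z_hat_next), *mlp2) + z_hat_next     # Eq. (4)
```

A quick sanity check of the residual paths: if both attention callables return zeros and both MLPs have zero weights, the block pair returns its input unchanged.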

where \(\hat{Z}^{l}\) and \({Z}^{l}\) denote the output features of the (S)W-MSA module and the MLP module of block l, respectively. Following previous studies [9, 10], self-attention is computed as:

$$\begin{aligned} Attention(Q,K,V)=SoftMax(\frac{QK^{T}}{\sqrt{d}}+B)V \end{aligned}$$
(5)

where Q, K, and V are the query, key, and value matrices, d is the query/key dimension, \({M}^{3}\) is the number of tokens in an \(M\times M\times M\) window, and \(B \in \mathbb{R}^{M^{3}\times M^{3}}\) is the relative position bias. Since the relative position along each axis lies in \([-M+1, M-1]\), we parameterize a smaller bias matrix \(\hat{B}\in \mathbb{R}^{(2M-1)\times (2M-1)\times (2M-1)}\), and the values in B are taken from \(\hat{B}\).
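Eq. (5) for a single window, including the lookup of B from the smaller table \(\hat{B}\), can be sketched in NumPy as follows. The sketch assumes a cubic M x M x M window, so \(\hat{B}\) is stored flattened with \((2M-1)^3\) entries; the index construction mirrors the relative-position indexing used in Swin-style implementations, but the function names are illustrative only.

```python
import numpy as np


def relative_position_index(M):
    """Map each (query, key) pair in an M x M x M window to an index into B_hat.

    Offsets along each axis lie in [-M+1, M-1] (2M-1 values), so the flattened
    B_hat has (2M-1)**3 entries. Returns an (M**3, M**3) integer array.
    """
    grids = np.meshgrid(np.arange(M), np.arange(M), np.arange(M), indexing="ij")
    coords = np.stack(grids).reshape(3, -1)             # (3, M**3) token coordinates
    rel = coords[:, :, None] - coords[:, None, :]       # (3, M**3, M**3) offsets
    rel = rel + (M - 1)                                 # shift offsets to [0, 2M-2]
    return (rel[0] * (2 * M - 1) + rel[1]) * (2 * M - 1) + rel[2]


def window_attention(q, k, v, b_hat, M):
    """SoftMax(Q K^T / sqrt(d) + B) V for one window.

    q, k: (M**3, d); v: (M**3, d_v); b_hat: flattened table of length (2M-1)**3.
    """
    d = q.shape[-1]
    B = b_hat[relative_position_index(M)]               # gather B from B_hat
    logits = q @ k.T / np.sqrt(d) + B
    logits -= logits.max(axis=-1, keepdims=True)        # numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)      # row-wise softmax
    return weights @ v
```

Because every pair with the same 3D offset reads the same entry of \(\hat{B}\), the bias depends only on relative position, which is what allows the smaller parameterization.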

2.3 2D Refinement Segmentation Stage

In the 2D refinement stage, our design follows the principles of DeepLabv3+ [?]. In the encoder, we apply parallel atrous convolutions at multiple rates, known as Atrous Spatial Pyramid Pooling (ASPP) [11], to encode multi-scale context information. We adopt the Xception backbone with depthwise separable convolutions to improve both the efficiency and the accuracy of the network. On top of this encoder, a simple yet effective decoder module refines the segmentation results. The architecture of the refinement segmentation network is shown in Fig. 4.
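Two pieces of arithmetic behind these design choices can be checked directly: how far an atrous (dilated) kernel reaches, and how many weights a depthwise separable convolution saves. The ASPP rates (6, 12, 18) used below are the DeepLabv3+ defaults at output stride 16 and the 256 channels match the ASPP width in DeepLabv3+; both are illustrative assumptions, since the values used in this network are not stated.

```python
def atrous_kernel_extent(k, rate):
    """Spatial extent covered by a k x k kernel with dilation `rate`."""
    return k + (k - 1) * (rate - 1)


def conv_params(k, c_in, c_out):
    """Weights in a standard k x k convolution (bias ignored)."""
    return k * k * c_in * c_out


def separable_conv_params(k, c_in, c_out):
    """Weights in a depthwise k x k convolution followed by a 1 x 1 pointwise one."""
    return k * k * c_in + c_in * c_out


# Assumed example values: DeepLabv3+-style ASPP rates and 256 channels.
extents = [atrous_kernel_extent(3, r) for r in (6, 12, 18)]
standard = conv_params(3, 256, 256)
separable = separable_conv_params(3, 256, 256)
```

Here a 3 x 3 kernel covers 13, 25, and 37 pixels per side at rates 6, 12, and 18 with only 9 weights per channel pair, and the separable variant needs 67,840 weights instead of 589,824, roughly 8.7 times fewer.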

The 2D refinement segmentation network takes as input a 2D MR sagittal slice together with the corresponding slice of the coarse probability map generated by the 3D coarse segmentation network. The coarse probability map gives the 2D network access to the implicit 3D semantic information of the volume, while the high-resolution MR slice provides fine detail. By combining the semantic features of the spinal structure with this detailed information, the 2D network produces an accurate, fine-grained segmentation.
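The data flow between the two stages can be sketched as follows. `coarse_net` and `refine_net` are hypothetical stand-ins for the 3D and 2D networks, the first array axis is assumed to index sagittal slices, and stacking the slice with its K probability channels (giving 1 + K input channels) is one plausible way to combine the two inputs, not a confirmed detail of the architecture.

```python
import numpy as np


def softmax(logits, axis=0):
    """Numerically stable softmax along `axis`."""
    e = np.exp(logits - logits.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def two_stage_segment(volume, coarse_net, refine_net):
    """Run 3D coarse segmentation, then refine every sagittal slice in 2D.

    volume:     (D, H, W) MR volume; axis 0 indexes sagittal slices (assumption).
    coarse_net: callable mapping (1, D, H, W) -> class logits (K, D, H, W).
    refine_net: callable mapping (1 + K, H, W) -> class logits (K, H, W).
    Returns a (D, H, W) integer label map.
    """
    probs = softmax(coarse_net(volume[None]), axis=0)   # coarse probability map
    labels = np.empty(volume.shape, dtype=np.int64)
    for z in range(volume.shape[0]):
        # Stack the high-resolution slice with its K coarse probability channels.
        x = np.concatenate([volume[z][None], probs[:, z]], axis=0)
        labels[z] = refine_net(x).argmax(axis=0)
    return labels
```

With K = 20 classes, each call to `refine_net` receives a 21-channel slice.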

2.4 Graph Convolution Module

The graph convolution module consists of three consecutive stages: Pooling, Graph Convolutional Network (GCN), and Unpooling. The Pooling stage transforms the pixel-level image representation into a graph representation. The GCN stage applies graph convolutions to this graph to produce node representations enriched with semantic information. Finally, the Unpooling stage maps the semantic graph representation back to the image representation, which is then passed to a convolutional layer for further processing.
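A minimal NumPy sketch of the Pooling, GCN, and Unpooling stages is given below. The pixel-to-node projection is not specified in the text, so it is assumed here to be a soft assignment matrix `assign`; the GCN step uses the symmetric normalization with self-loops from Kipf and Welling, and the hypothetical helper `chain_adjacency` builds a chain graph as one way to encode the chain structure of the spine.

```python
import numpy as np


def chain_adjacency(n):
    """Adjacency of a chain graph: node i is linked to nodes i-1 and i+1."""
    adj = np.zeros((n, n))
    idx = np.arange(n - 1)
    adj[idx, idx + 1] = 1.0
    adj[idx + 1, idx] = 1.0
    return adj


def graph_conv_module(feat, assign, adj, weight):
    """Pooling -> GCN -> Unpooling on one flattened feature map.

    feat:   (P, C) pixel features (P = H * W pixels).
    assign: (P, N) assumed soft assignment of pixels to N graph nodes.
    adj:    (N, N) node adjacency.
    weight: (C, C) GCN weight matrix.
    Returns (P, C) pixel features to be passed on to a convolutional layer.
    """
    nodes = assign.T @ feat                                  # Pooling: pixels -> nodes
    a_hat = adj + np.eye(adj.shape[0])                       # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    a_norm = a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    nodes = np.maximum(a_norm @ nodes @ weight, 0.0)         # GCN layer with ReLU
    return assign @ nodes                                    # Unpooling: nodes -> pixels
```

Passing `chain_adjacency(19)` as `adj` would, for instance, link each of 19 structure nodes only to its immediate neighbors along the spine.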

Table 1. The mean DSC (%) of the proposed method and other methods for each vertebra on the MRSpineSeg Challenge dataset.
Table 2. The mean DSC (%) of the proposed method and other methods for each IVD on the MRSpineSeg Challenge dataset.

3 Experiments

3.1 Dataset

Our proposed method was evaluated on the MRSpineSeg Challenge dataset, which comprises a total of 215 T2-weighted MR volumetric images. During the experiment, 172 images were utilized, and they were partitioned into training, validation, and testing sets in a ratio of 7:2:1. The volumetric images encompassed 10 vertebrae, 9 intervertebral discs (IVDs), and backgrounds, resulting in a total of 20 distinct categories. The original images exhibited varying dimensions, with widths and heights ranging from 512 to 1024, while the number of slices along the coronal axis ranged from 12 to 20.

3.2 Implementation Details

For the 3D network, pre-processing consists of cropping, resizing, padding, and normalization. First, the images are center-cropped along the depth direction to remove the non-spine portion, since half of the image contains no spinal information. The cropped image is then resized to \(18\times 256\times 128\) voxels, with zero filling applied in the depth direction to ensure a uniform size. Finally, the mean and standard deviation are computed over all images, and each voxel is normalized by subtracting the mean and dividing by the standard deviation.
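The pre-processing steps can be sketched in NumPy as follows. The text does not pin down every detail, so several choices here are assumptions: the axis whose central half is kept is a parameter (`crop_axis`, defaulting to the in-plane width axis), in-plane resizing uses nearest-neighbour sampling, zeros are appended after the last slice, and volumes with more than 18 slices are center-cropped in depth. The dataset-level `mean` and `std` are assumed to be precomputed.

```python
import numpy as np

# (depth, height, width) target size taken from the text.
TARGET_SHAPE = (18, 256, 128)


def center_crop(volume, axis, keep):
    """Keep the `keep` central elements of `volume` along `axis`."""
    start = (volume.shape[axis] - keep) // 2
    return np.take(volume, np.arange(start, start + keep), axis=axis)


def resize_nearest(volume, height, width):
    """Nearest-neighbour in-plane resize of a (D, H, W) volume (assumed interpolation)."""
    _, h, w = volume.shape
    rows = np.arange(height) * h // height
    cols = np.arange(width) * w // width
    return volume[:, rows][:, :, cols]


def fit_depth(volume, depth):
    """Center-crop deeper volumes and zero-fill shallower ones to `depth` slices."""
    d = volume.shape[0]
    if d > depth:
        return center_crop(volume, axis=0, keep=depth)
    out = np.zeros((depth,) + volume.shape[1:], dtype=volume.dtype)
    out[:d] = volume  # zeros appended after the last slice (assumed placement)
    return out


def preprocess(volume, mean, std, crop_axis=2):
    """Crop -> resize -> pad -> normalize a (D, H, W) volume.

    mean/std: dataset-level statistics computed over all training images.
    crop_axis: axis whose central half is kept (an assumption).
    """
    volume = center_crop(volume, axis=crop_axis, keep=volume.shape[crop_axis] // 2)
    volume = resize_nearest(volume, TARGET_SHAPE[1], TARGET_SHAPE[2])
    volume = fit_depth(volume, TARGET_SHAPE[0])
    return (volume - mean) / std
```

For a 12 x 512 x 512 volume, the output has shape 18 x 256 x 128; the six zero-filled slices become \(-\text{mean}/\text{std}\) after normalization because padding precedes normalization, as in the order described above.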

Fig. 5.

Segmentation results of our method on sagittal slices showing vertebrae and intervertebral discs from six subjects. BG denotes the background.

Our method was implemented in Python with the PyTorch deep learning framework and trained on an Nvidia RTX 3090 GPU with 24 GB of memory. In the 3D segmentation stage, an MR volume of size \(18\times 256\times 128\) is taken as input and a coarse probability map of size \(20\times 18\times 256\times 128\) is produced, with one channel per class. The Adam [13] optimizer was used with a weight decay of 0.0001. The learning rate was initialized to 0.001 and divided by 5 every 33 epochs. The batch size was set to 2, limited by the available GPU memory.
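The learning-rate schedule described above (start at 0.001, divide by 5 every 33 epochs) is a step decay, which can be written as a small function of the 0-based epoch index:

```python
def learning_rate(epoch, base_lr=1e-3, factor=5.0, step=33):
    """Learning rate at a 0-based epoch: divided by `factor` every `step` epochs."""
    return base_lr / factor ** (epoch // step)
```

This gives 0.001 for epochs 0 to 32, 0.0002 for epochs 33 to 65, and 0.00004 from epoch 66. In PyTorch, the same behavior would be obtained with `torch.optim.Adam(params, lr=1e-3, weight_decay=1e-4)` combined with `torch.optim.lr_scheduler.StepLR(optimizer, step_size=33, gamma=0.2)`, calling `scheduler.step()` once per epoch.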

Table 3. Ablation experiments were conducted on the MRSpineSeg Challenge dataset to assess the effectiveness of each component in segmentation of the ten classes of vertebrae T9-S. The mean DSC (%) was used to validate the components.
Table 4. Ablation experiments were conducted on the MRSpineSeg Challenge dataset to assess the effectiveness of each component in segmentation of the nine classes of IVDs T9T10-L5S. The mean DSC (%) was used to validate the components.

3.3 Evaluation Metrics

To assess segmentation performance, we use the Dice similarity coefficient (DSC), precision, and recall, computed as follows:

$$\begin{aligned} Dice =\frac{2TP}{FP+2TP+FN}. \end{aligned}$$
(6)
$$\begin{aligned} Pre =\frac{TP}{TP + FP}. \end{aligned}$$
(7)
$$\begin{aligned} Recall =\frac{TP}{TP + FN}. \end{aligned}$$
(8)

where TP, FP, and FN denote the numbers of true positives, false positives, and false negatives, respectively.
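Per-class DSC, precision, and recall can be computed from integer label maps as in the NumPy sketch below. Skipping the background class and returning NaN when a denominator is zero are assumptions of this sketch; a mean DSC is then the average of the per-class Dice values over the structures of interest.

```python
import numpy as np


def per_class_metrics(pred, target, num_classes):
    """Return {class: (dice, precision, recall)} for every foreground class.

    pred, target: integer label arrays of the same shape; class 0 is background.
    """
    results = {}
    for c in range(1, num_classes):              # background (class 0) skipped
        p, t = pred == c, target == c
        tp = np.sum(p & t)
        fp = np.sum(p & ~t)
        fn = np.sum(~p & t)
        dice = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else float("nan")
        precision = tp / (tp + fp) if (tp + fp) > 0 else float("nan")
        recall = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
        results[c] = (dice, precision, recall)
    return results
```

For `pred = [0, 1, 1, 2, 2, 0]` and `target = [0, 1, 2, 2, 2, 1]`, class 2 has TP = 2, FP = 0, and FN = 1, giving a Dice of 0.8, a precision of 1.0, and a recall of 2/3.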

3.4 Experiment Results

Table 5 reports the mean recall, mean precision, and Dice similarity coefficient (DSC) achieved by the two-stage segmentation network for vertebrae, intervertebral discs (IVDs), and all 19 spinal structures. Representative segmentation results are shown in Fig. 5.

We compared our spinal segmentation method with several other methods: nnUNet [14], VNet [15], UNETR [16], 3D Graphonomy [12], 3D Deeplabv3 [11] + 2D ResidualUNet [17], and 3D Graphonomy [12] + 2D Deeplabv3 [11]. Segmentation performance was evaluated with three metrics: DSC, precision, and recall. Tables 1 and 2 report the DSC for each vertebra and each intervertebral disc (IVD), respectively. Our network outperformed the other methods and achieved strong results for the seven vertebra classes T12-S (DSC > 87.34%) and the seven IVD classes T11T12-L5S (DSC > 86.03%). These quantitative comparisons indicate the advantage of the proposed method. Fig. 6 further shows segmentation results of the different algorithms on this dataset, providing visual evidence of the improved segmentation achieved by our method.

Table 5. Mean recall, precision, and DSC of our two-stage segmentation network for vertebrae, intervertebral discs (IVDs), and all 19 spinal structures.

Our network performs less well on the T9-T11 vertebrae (DSC \(\le \) 80.75%) and the two uppermost IVDs, T9T10 and T10T11 (DSC \(\le \) 76.92%), for two main reasons. First, the dataset contains very few samples of these vertebrae and IVDs, and most of them are only partially visible. Second, these three vertebrae and two IVDs lie at the top of the image, where the limited receptive field near the image border makes segmentation difficult.

3.5 Ablation Study

To assess the contribution of each component of our architecture, we conducted a series of ablation experiments, whose results are presented in Tables 3 and 4. Six configurations were evaluated, combining the 3D Swin Transformer and 3D GCM in the 3D coarse segmentation stage with the 2D Deeplabv3+ and 2D GCM in the 2D refinement segmentation stage. These experiments show that both the graph convolution module and the two-network strategy improve segmentation performance for vertebrae and IVDs.

Fig. 6.

Visual comparison of the results of different segmentation networks.

3.6 Effect of the Two-Stage Framework

Adding a 2D refinement stage to 3D segmentation considerably improves segmentation performance. Tables 1 and 2 support this: compared with 3D Graphonomy alone, 3D Graphonomy+2D Deeplabv3, which adds a 2D refinement stage, achieved a notably higher average DSC on the eight vertebra classes T11-S and the eight IVD classes T10T11-L5S. The ablation experiments in Tables 3 and 4 yielded similarly encouraging results. The high-resolution images used by the 2D network provide more detailed information, enabling the model to learn richer features. Consequently, combining 3D and 2D networks allows the model to assimilate more contextual information from images of different dimensionality, which leads to a clear improvement in segmentation performance.

3.7 Effect of the GCM

The quantitative results in Tables 3 and 4 show that incorporating the graph convolution module into either the 3D or the 2D network improves segmentation performance. Without explicit structural constraints, the predictions of both the 3D and 2D networks do not necessarily respect the spatial order of the spinal structures, for example the top-to-bottom sequence of vertebrae and discs. By modeling relationships between neighboring structures, the graph convolution module in both networks is better suited to delineating the boundaries and positions of these structures.

4 Conclusion

This paper presents a novel two-stage framework for precise multi-label segmentation of vertebrae and intervertebral discs. The framework combines 3D transformers and 2D convolutional neural networks (CNNs) to obtain accurate and reliable segmentation. In the first stage, a 3D transformer generates coarse probability maps, which serve as the foundation for subsequent processing. In the second stage, each 2D MR sagittal slice and the corresponding coarse probability map from the 3D coarse segmentation network are jointly fed into the 2D network to obtain a refined, more precise segmentation. Graph convolution modules in both the 3D and 2D networks address isolated pixel labels and correct errors in the shape and position of the segmented structures, improving segmentation accuracy. In comparisons with state-of-the-art spinal segmentation methods on the publicly available MRSpineSeg Challenge dataset, the proposed framework achieved superior performance, underscoring its potential for advancing spinal segmentation.