1 Introduction

Transformers [89] are receiving growing attention in the computer vision community. On the one hand, the transformer encoder, with multi-head self-attention as the central component, demonstrates great potential for building powerful network architectures in various visual recognition tasks [32, 70, 93]. On the other hand, the transformer decoder, with multi-head cross-attention at its core, provides a brand-new approach to tackling complex visual recognition problems in an end-to-end manner, dispensing with hand-designed heuristics.

Recently, the pioneering work DETR [10] introduces the first end-to-end object detection system with transformers. In this framework, the pixel features are first extracted by a convolutional neural network [58], followed by several transformer encoders for feature enhancement, which capture long-range interactions between pixels. Afterwards, a set of learnable positional embeddings, named object queries, is responsible for interacting with pixel features and aggregating information through several interleaved cross-attention and self-attention modules. In the end, the object queries, decoded by a Feed-Forward Network (FFN), directly correspond to the final bounding box predictions. Along the same direction, MaX-DeepLab [92] demonstrates the success of transformers in the challenging panoptic segmentation task [55], where prior arts [21, 54, 100] usually adopt complicated pipelines involving hand-designed heuristics. The essence of this framework lies in converting the object queries into mask embedding vectors [49, 87, 97], which are multiplied with the pixel features to yield a set of mask predictions.

The end-to-end transformer-based frameworks have been successfully applied to multiple computer vision tasks with the help of transformer decoders, especially the cross-attention modules. However, the working mechanism behind the scenes remains unclear. Cross-attention, which originates from the Natural Language Processing (NLP) community, was designed for language problems, such as neural machine translation [4, 86], where both the input and output sequences share a similarly short length. This implicit assumption becomes problematic for certain vision problems, where the cross-attention is performed between object queries and spatially flattened pixel features of exceedingly large length. Concretely, only a small number of object queries is typically employed (e.g., 128 queries), while the input images can contain thousands of pixels for the vision tasks of detection and segmentation. Each object query needs to learn to highlight the most distinguishable features among the abundant pixels during the cross-attention learning process, which subsequently leads to slow training convergence and thus inferior performance [37, 112].

In this work, we make a crucial observation that the cross-attention scheme actually bears a strong similarity to the traditional k-means clustering [72] by regarding the object queries as cluster centers with learnable embedding vectors. Our examination of the similarity inspires us to propose the novel \(\textbf{k}\)-means Mask Xformer (kMaX-DeepLab), which rethinks the relationship between pixel features and object queries, and redesigns the cross-attention from the perspective of k-means clustering. Specifically, when updating the cluster centers (i.e., object queries), our kMaX-DeepLab performs a different operation. Instead of performing softmax on the large spatial dimension (image height times width) as in the original Mask Transformer’s cross-attention [92], our kMaX-DeepLab performs argmax along the cluster center dimension, similar to the k-means pixel-cluster assignment step (with a hard assignment). We then update cluster centers by aggregating the pixel features based on the pixel-cluster assignment (computed by their feature affinity), similar to the k-means center-update step. In spite of being conceptually simple, the modification has a striking impact: on COCO val set [66], using the standard ResNet-50 [41] as backbone, our kMaX-DeepLab demonstrates a significant improvement of 5.2% PQ over the original cross-attention scheme at a negligible cost of extra parameters and FLOPs. When comparing to state-of-the-art methods, our kMaX-DeepLab with the simple ResNet-50 backbone already outperforms MaX-DeepLab [92] with MaX-L [92] backbone by 1.9% PQ, while requiring 7.9 and 22.0 times fewer parameters and FLOPs, respectively. Our kMaX-DeepLab with ResNet-50 also outperforms MaskFormer [24] with the strong ImageNet-22K pretrained Swin-L [70] backbone, and runs 4.4 times faster. Finally, our kMaX-DeepLab, using the modern ConvNeXt-L [71] as backbone, sets a new state-of-the-art performance on the COCO val set [66] with 58.0% PQ. It also outperforms other state-of-the-art methods on the Cityscapes val set [28], achieving 68.4% PQ, 83.5% mIoU, 44.0% AP, without using any test-time augmentation or extra dataset pretraining [66, 75].

2 Related Works

Transformers.   Transformer [89] and its variants [2, 8, 26, 39, 57, 74, 94, 106] have advanced the state-of-the-art in natural language processing tasks [30, 31, 82] by capturing relations across modalities [4] or in a single context [25, 89]. In computer vision, transformer encoders or self-attention modules are either combined with Convolutional Neural Networks (CNNs) [9, 96] or used as standalone backbones [32, 44, 70, 80, 93]. Both approaches have boosted various vision tasks, such as image classification [7, 19, 32, 44, 64, 70, 80, 93, 101, 105], image generation [42, 77], object detection [10, 43, 80, 83, 96, 112], video recognition [3, 19, 33, 96], semantic segmentation [11, 17, 35, 46, 99, 108, 109, 111, 113], and panoptic segmentation [93].

Mask Transformers for Segmentation.   Besides their usage as backbones, transformers are also adopted as task decoders for image segmentation. MaX-DeepLab [92] proposed Mask Xformers (MaX) for end-to-end panoptic segmentation. Mask transformers predict class-labeled object masks and are trained by Hungarian matching between the predicted masks and ground truth masks. The essential component of mask transformers is the conversion of object queries into mask embedding vectors [49, 87, 97], which are employed to generate the predicted masks. Both Segmenter [85] and MaskFormer [24] applied mask transformers to semantic segmentation. K-Net [107] proposed dynamic kernels for generating the masks. CMT-DeepLab [104] proposed to improve the cross-attention with an additional clustering update term. Panoptic Segformer [65] strengthened the mask transformer with deformable attention [112], while Mask2Former [23] further boosted the performance with masked cross-attention along with a series of technical improvements, including a cascaded transformer decoder, deformable attention [112], and uncertainty-based pointly supervision [56]. These mask transformer methods generally outperform box-based methods [54] that decompose panoptic segmentation into multiple surrogate tasks (e.g., predicting masks for each detected object bounding box [40], followed by fusing the instance segments (‘thing’) and semantic segments (‘stuff’) [14] with merging modules [60, 62, 67, 78, 100, 103]). Moreover, mask transformers have shown great success in video segmentation problems [20, 52, 61].

Clustering Methods for Segmentation.   Traditional image segmentation methods [1, 72, 110] typically cluster image intensities into a set of masks or superpixels with gradual growing or refinement. However, it is challenging for these traditional methods to capture high-level semantics. Modern clustering-based methods usually operate on semantic segments [13, 15, 18] and group ‘thing’ pixels into instance segments with various representations, such as instance center regression [22, 50, 63, 76, 88, 93, 102], Watershed transform [5, 90], Hough-voting [6, 59, 91], or pixel affinity [36, 47, 51, 69, 84].

Recently, CMT-DeepLab [104] discussed the similarity between mask transformers and clustering algorithms. However, they only used the clustering update as a complementary term in the cross-attention. In this work, we further discover the underlying similarity between mask transformers and the k-means clustering algorithm, resulting in a simple yet effective k-means mask transformer.

3 Method

In this section, we first overview the mask-transformer-based segmentation framework presented by MaX-DeepLab [92]. We then revisit the transformer cross-attention [89] and the k-means clustering algorithm [72], and reveal their underlying similarity. Afterwards, we introduce the proposed \(\textbf{k}\)-means Mask Xformer (kMaX-DeepLab), which redesigns the cross-attention from a clustering perspective. Despite its simplicity, kMaX-DeepLab significantly improves the segmentation performance.

3.1 Mask-Transformer-Based Segmentation Framework

Transformers [89] have been effectively deployed to segmentation tasks. Without loss of generality, we consider panoptic segmentation [55] in the following problem formulation, which can be easily generalized to other segmentation tasks.

Problem Statement.   Panoptic segmentation aims to segment the image \(\textbf{I} \in \mathbb {R}^{H \times W \times 3}\) into a set of non-overlapping masks with associated semantic labels:

$$\begin{aligned} \{y_i\}_{i=1}^K = \{(m_i, c_i)\}_{i=1}^K \,. \end{aligned}$$
(1)

The K ground truth masks \(m_i \in {\{0,1\}}^{H \times W}\) do not overlap with each other, i.e., \(\sum _{i=1}^{K} m_i \le 1^{H \times W}\), and \(c_i\) denotes the ground truth class label of mask \(m_i\).
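
To make the notation concrete, the following minimal sketch constructs a toy ground truth in the format of Eq. (1) and checks the non-overlap constraint; all shapes and class ids are made up for illustration and are not from the paper.

```python
# A toy ground truth in the format of Eq. (1): K binary masks that do not
# overlap, each paired with a semantic class label. All shapes and class ids
# below are made up for illustration.
import torch

H, W, K = 4, 4, 2
masks = torch.zeros(K, H, W, dtype=torch.long)
masks[0, :, :2] = 1                       # first mask covers the left half
masks[1, :, 2:] = 1                       # second mask covers the right half
classes = torch.tensor([7, 21])           # one class label c_i per mask m_i

assert (masks.sum(dim=0) <= 1).all()      # masks are mutually non-overlapping
```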

Starting from DETR [10] and MaX-DeepLab [92], approaches to panoptic segmentation shift to a new end-to-end paradigm, where the prediction directly matches the format of ground-truth with N masks (N is a fixed number and \(N\ge K\)) and their semantic classes:

$$\begin{aligned} \{\hat{y_i}\}_{i=1}^N = \{(\hat{m_i}, \hat{p}_{i}(c))\}_{i=1}^N, \end{aligned}$$
(2)

where \(\hat{p}_{i}(c)\) denotes the semantic class prediction confidence for the corresponding mask, which includes ‘thing’ classes, ‘stuff’ classes, and the void class \(\varnothing \).

The N masks are predicted based on the N object queries, which aggregate information from the pixel features through a transformer decoder, consisting of self-attention and cross-attention modules.

The object queries, updated by multiple transformer decoders, are employed as mask embedding vectors [49, 87, 97], which will multiply with the pixel features to yield the final prediction \(\textbf{Z} \in \mathbb {R}^{HW \times N}\) that consists of N masks. That is,

$$\begin{aligned} \textbf{Z}&= {\mathop {\hbox {softmax}}\limits _{N}}(\textbf{F} \times \textbf{C}^{\textrm{T}}), \end{aligned}$$
(3)

where \(\textbf{F} \in \mathbb {R}^{HW \times D}\) and \(\textbf{C} \in \mathbb {R}^{N \times D}\) refer to the final pixel features and object queries, respectively. D is the channel dimension of pixel features and object queries. We use the subscript N to indicate the axis along which the softmax is performed.
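
As a concrete illustration, the sketch below implements Eq. (3) with random tensors; it is a minimal, illustrative reading of the formula with assumed shapes, not the official implementation.

```python
# A minimal sketch of Eq. (3): multiplying pixel features F with object queries
# (mask embeddings) C, followed by a softmax over the N queries. Shapes are
# assumed for illustration.
import torch

H, W, N, D = 64, 64, 128, 256
F = torch.randn(H * W, D)                 # final pixel features, [HW, D]
C = torch.randn(N, D)                     # object queries / mask embeddings, [N, D]

mask_logits = F @ C.T                     # pixel-query affinity, [HW, N]
Z = mask_logits.softmax(dim=-1)           # softmax along the N (query) axis
panoptic = Z.argmax(dim=-1).view(H, W)    # per-pixel mask id at inference
```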

3.2 Relationship Between Cross-Attention and k-means Clustering

Although the transformer-based segmentation frameworks successfully connect object queries and mask predictions in an end-to-end manner, the essential problem becomes how to transform the object queries, starting from learnable embeddings (randomly initialized), into meaningful mask embedding vectors.

Cross-Attention.   The cross-attention modules are used to aggregate affiliated pixel features to update object queries. Formally, we have

$$\begin{aligned} \hat{\textbf{C}}&= \textbf{C} + {\mathop {\hbox {softmax}}\limits _{HW}}(\textbf{Q}^{c} \times (\textbf{K}^{p})^{\textrm{T}}) \times \textbf{V}^{p}, \end{aligned}$$
(4)

where \(\textbf{C} \in \mathbb {R}^{N \times D}\) refers to N object queries with D channels, and \(\hat{\textbf{C}}\) denotes the updated object queries. We use the subscript HW to represent the axis for the softmax along the spatial dimension, and the superscripts p and c to indicate features projected from the pixel features and object queries, respectively. \(\textbf{Q}^c \in \mathbb {R}^{N \times D}, \textbf{K}^p \in \mathbb {R}^{HW \times D}, \textbf{V}^p \in \mathbb {R}^{HW \times D} \) stand for the linearly projected features for query, key, and value. For simplicity, we ignore the multi-head mechanism and the feed-forward network (FFN) in the equation.
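
The following sketch spells out the single-head cross-attention update of Eq. (4); the projection layers are illustrative stand-ins, and the multi-head mechanism and FFN are omitted, as in the equation.

```python
# A minimal single-head sketch of the cross-attention update in Eq. (4).
# q_proj/k_proj/v_proj are illustrative linear projections; multi-head
# attention and the FFN are omitted as in the equation.
import torch
import torch.nn as nn

HW, N, D = 64 * 64, 128, 256
q_proj, k_proj, v_proj = nn.Linear(D, D), nn.Linear(D, D), nn.Linear(D, D)

C = torch.randn(N, D)                     # object queries, [N, D]
P = torch.randn(HW, D)                    # flattened pixel features, [HW, D]

Qc = q_proj(C)                            # query projection of object queries
Kp, Vp = k_proj(P), v_proj(P)             # key/value projections of pixel features

attn = (Qc @ Kp.T).softmax(dim=-1)        # softmax along the HW spatial axis, [N, HW]
C_hat = C + attn @ Vp                     # residual update of the object queries
```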

As shown in Eq. (4), when updating the object queries, a softmax function is applied along the image resolution (HW), which typically contains thousands of pixels for the task of segmentation. Given the huge number of pixels, it can take many training iterations to learn the attention map, which starts from a uniform distribution at the beginning (as the queries are randomly initialized). Each object query has a difficult time identifying the most distinguishable features among the abundant pixels in the early stage of training. This behavior is very different from the application of transformers to natural language processing tasks, e.g., neural machine translation [4, 86], where the input and output sequences share a similarly short length. Vision tasks, especially segmentation problems, therefore present another challenge for efficiently learning the cross-attention.

Discussion.   Similar to cross-attention, self-attention also needs to perform a softmax along the image resolution. Therefore, learning the attention map for self-attention may also take many training iterations. An efficient alternative, such as axial attention [93] or local attention [70], is usually applied to high-resolution feature maps and thus alleviates the problem, while a solution for cross-attention remains an open research question.

k-Means Clustering.   In Eq. (4), the cross-attention computes the affinity between object queries and pixels (i.e., \(\textbf{Q}^{c} \times (\textbf{K}^{p})^{\textrm{T}}\)), which is converted to the attention map through the spatial-wise softmax (operated along the image resolution). The attention map is then used to retrieve (and weight accordingly) affiliated pixel features to update the object queries. Surprisingly, we observe that the whole process is actually similar to the classic k-means clustering algorithm [72], which works as follows:

$$\begin{aligned} \textbf{A}&= {\mathop {\hbox {argmax}}\limits _{N}}(\textbf{C} \times \textbf{P}^{\textrm{T}}),\end{aligned}$$
(5)
$$\begin{aligned} \hat{\textbf{C}}&= \textbf{A} \times \textbf{P}, \end{aligned}$$
(6)

where \(\textbf{C}\in \mathbb {R}^{N \times D}\), \(\textbf{P}\in \mathbb {R}^{HW \times D}\), and \(\textbf{A}\in \mathbb {R}^{N \times HW}\) stand for cluster centers, pixel features, and clustering assignments, respectively.
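
For reference, one iteration of Eqs. (5) and (6) can be sketched as follows; unlike the compact matrix form above, the sketch normalizes each center by its cluster size for numerical sanity, which is a common practical choice rather than something stated in the equations.

```python
# One k-means iteration following Eqs. (5)-(6): cluster-wise argmax assignment,
# then a center update by aggregating the assigned pixel features. The division
# by cluster size is an added normalization not shown in the compact matrix form.
import torch
import torch.nn.functional as F

HW, N, D = 64 * 64, 128, 256
P = torch.randn(HW, D)                             # pixel features, [HW, D]
C = torch.randn(N, D)                              # cluster centers, [N, D]

assign = (C @ P.T).argmax(dim=0)                   # nearest center per pixel, [HW]
A = F.one_hot(assign, num_classes=N).T.float()     # one-hot assignment matrix, [N, HW]
C_hat = (A @ P) / A.sum(dim=-1, keepdim=True).clamp(min=1)  # updated centers, [N, D]
```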

Comparing Eq. (4), Eq. (5), and Eq. (6), we notice that the k-means clustering algorithm is parameter-free, and thus no linear projection is needed for the query, key, and value. Additionally, the cluster centers are not updated in a residual manner. Most importantly, k-means adopts a cluster-wise argmax (i.e., argmax operated along the cluster dimension) instead of the spatial-wise softmax when converting the affinity to the attention map (i.e., the weights to retrieve and update features).

This observation motivates us to reformulate the cross-attention in vision problems, especially image segmentation. From a clustering perspective, image segmentation is equivalent to grouping pixels into different clusters, where each cluster corresponds to a predicted mask. The cross-attention mechanism, which also attempts to group pixels into different object queries, however employs a spatial-wise softmax instead of the cluster-wise argmax used in k-means. Given the success of k-means, we hypothesize that the cluster-wise argmax is a more suitable operation than the spatial-wise softmax for pixel clustering, since the cluster-wise argmax performs a hard assignment and efficiently reduces the operation targets from thousands of pixels (HW) to just a few cluster centers (N), which (as we will empirically show) speeds up training convergence and leads to better performance.

3.3 k-means Mask Transformer

Herein, we first introduce the crucial component of the proposed k-means Mask Transformer, i.e., k-means cross-attention. We then present its meta architecture and model instantiation.

k-means Cross-Attention.   The proposed k-means cross-attention reformulates the cross-attention in a manner similar to k-means clustering:

$$\begin{aligned} \hat{\textbf{C}}&= \textbf{C} + {\mathop {\hbox {argmax}}\limits _{N}}(\textbf{Q}^{c} \times (\textbf{K}^{p})^{\textrm{T}}) \times \textbf{V}^{p}. \end{aligned}$$
(7)

Comparing Eq. (4) and Eq. (7), the spatial-wise softmax is now replaced by the cluster-wise argmax. As shown in Fig. 1, with such a simple yet effective change, a typical transformer decoder can be converted into a kMaX decoder. Unlike the original cross-attention, the proposed k-means cross-attention adopts a different operation (i.e., cluster-wise argmax) to compute the attention map, and does not require the multi-head mechanism [89]. However, the cluster-wise argmax, as a hard assignment that aggregates pixel features for the cluster center update, is not a differentiable operation, posing a challenge during training. We explored several methods (e.g., Gumbel-Softmax [48]) and discovered that a simple deep supervision scheme turns out to be the most effective. In particular, in our formulation, the affinity logits between pixel features and cluster centers directly correspond to the softmax logits of the segmentation masks (i.e., \(\textbf{Q}^{c} \times (\textbf{K}^{p})^{\textrm{T}}\) in Eq. (7) corresponds to \(\textbf{F} \times \textbf{C}^{\textrm{T}}\) in Eq. (3)), since the cluster centers aim to group pixels of similar affinity together to form the predicted segmentation masks. This formulation allows us to add deep supervision to every kMaX decoder in order to train the parameters in the k-means cross-attention module.
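
A minimal module-level sketch of the k-means cross-attention is given below. It is illustrative PyTorch code with assumed module and variable names (the official implementation lives in the DeepLab2 code base); the raw affinity logits are returned alongside the updated centers so that deep supervision can be applied, since the hard argmax itself carries no gradient.

```python
# A minimal sketch of the k-means cross-attention in Eq. (7). The cluster-wise
# argmax is non-differentiable, so the affinity logits (which double as mask
# logits) are returned as well and receive deep supervision at every kMaX
# decoder. Names and structure are illustrative, not the official API.
import torch
import torch.nn as nn
import torch.nn.functional as F

class KMeansCrossAttention(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

    def forward(self, centers: torch.Tensor, pixels: torch.Tensor):
        # centers: [N, D] cluster centers (object queries); pixels: [HW, D].
        logits = self.q_proj(centers) @ self.k_proj(pixels).T       # affinity, [N, HW]
        assign = F.one_hot(logits.argmax(dim=0),                    # cluster-wise argmax
                           num_classes=centers.shape[0]).T.float()  # hard assignment, [N, HW]
        centers = centers + assign @ self.v_proj(pixels)            # residual center update
        return centers, logits                                      # logits receive deep supervision

# Example usage with assumed sizes:
C_hat, mask_logits = KMeansCrossAttention(256)(torch.randn(128, 256), torch.randn(64 * 64, 256))
```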

Fig. 1. To convert a typical transformer decoder into our kMaX decoder, we simply replace the original cross-attention with our k-means cross-attention (i.e., the only simple change, cluster-wise argmax, is highlighted in red) (Color figure online)

Fig. 2. The meta architecture of k-means Mask Transformer consists of three components: pixel encoder, enhanced pixel decoder, and kMaX decoder. The pixel encoder is any network backbone. The enhanced pixel decoder includes transformer encoders to enhance the pixel features, and upsampling layers to generate higher resolution features. The series of kMaX decoders transform cluster centers into (1) mask embedding vectors, which multiply with the pixel features to generate the predicted masks, and (2) class predictions for each mask.

Meta Architecture.   Figure 2 shows the meta architecture of our proposed kMaX-DeepLab, which contains three main components: pixel encoder, enhanced pixel decoder, and kMaX decoder. The pixel encoder extracts the pixel features either by a CNN [41] or a transformer [70] backbone, while the enhanced pixel decoder is responsible for recovering the feature map resolution as well as enhancing the pixel features via transformer encoders [89] or axial attention [93]. Finally, the kMaX decoder transforms the object queries (i.e., cluster centers) into mask embedding vectors from the k-means clustering perspective.

Model Instantiation.   We build kMaX-DeepLab on top of MaX-DeepLab [92] using the official code base [98]. We divide the whole model into two paths: the pixel path and the cluster path, which are responsible for extracting pixel features and cluster centers, respectively. Figure 3 details our kMaX-DeepLab instantiation with two example backbones.

Pixel Path.   The pixel path consists of a pixel encoder and an enhanced pixel decoder. The pixel encoder is an ImageNet-pretrained [81] backbone, such as ResNet [41], MaX-S [92] (i.e., ResNet-50 with axial attention [93]), and ConvNeXt [71]. Our enhanced pixel decoder consists of several axial attention blocks [93] and bottleneck blocks [41].

Cluster Path.   The cluster path contains six kMaX decoders in total, which are evenly distributed among feature maps of different spatial resolutions. Specifically, we deploy two kMaX decoders each for pixel features at output strides 32, 16, and 8.
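
The wiring of the cluster path could look roughly like the sketch below. The decoder here is a trivial stand-in just to show how six decoders are distributed over three feature scales; the real kMaX decoder performs the k-means cross-attention of Eq. (7) together with self-attention and an FFN.

```python
# A rough sketch of the cluster path layout: six kMaX decoders, two each on
# pixel features at output stride 32, 16, and 8. KMaXDecoderStub is a trivial
# placeholder; it is not the real decoder block.
import torch
import torch.nn as nn

class KMaXDecoderStub(nn.Module):
    def forward(self, centers, pixels):
        return centers                              # placeholder for the real update

decoders = nn.ModuleList(KMaXDecoderStub() for _ in range(6))
strides = [32, 32, 16, 16, 8, 8]                    # two decoders per output stride

def cluster_path(centers, pixel_features):
    # pixel_features: dict mapping output stride -> flattened features [HW_s, D]
    for decoder, stride in zip(decoders, strides):
        centers = decoder(centers, pixel_features[stride])
    return centers

# Example with assumed feature sizes for a 1024x1024 input and D = 256:
feats = {s: torch.randn((1024 // s) ** 2, 256) for s in (32, 16, 8)}
centers = cluster_path(torch.randn(128, 256), feats)
```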

Loss Functions.   Our training loss functions mostly follow the setting of MaX-DeepLab [92]. We adopt the same PQ-style loss, auxiliary semantic loss, mask-id cross-entropy loss, and pixel-wise instance discrimination loss [104].

Fig. 3. An illustration of kMaX-DeepLab with ResNet-50 and MaX-S as backbones. The hidden dimension of the FFN is 256. The design of kMaX-DeepLab generalizes to different backbones by simply updating the pixel encoder (marked in dark blue). The enhanced pixel decoder and kMaX decoder are colored in light blue and yellow, respectively (Color figure online)

4 Experimental Results

In this section, we first provide our implementation details. We report our main results on COCO [66] and Cityscapes [28]. We also provide visualizations to better understand the clustering process of the proposed kMaX-DeepLab. The ablation studies are provided in the appendix.

4.1 Implementation Details

The meta architecture of the proposed kMaX-DeepLab contains three main components: the pixel encoder, enhanced pixel decoder, and kMaX decoder, as shown in Fig. 2. We provide the implementation details of each component below.

Pixel Encoder.   The pixel encoder extracts pixel features given an image. To verify the generality of kMaX-DeepLab across different pixel encoders, we experiment with ResNet-50 [41], MaX-S [92] (i.e., ResNet-50 with axial attention [93] in the 3rd and 4th stages), and ConvNeXt [71].

Enhanced Pixel Decoder.   The enhanced pixel decoder recovers the feature map resolution and enriches pixel features via self-attention. As shown in Fig. 3, we adopt one axial block with 2048 channels at output stride 32, and five axial blocks with 1024 channels at output stride 16. An axial block is a bottleneck block [41] in which the \(3 \times 3\) convolution is replaced by axial attention [93]. We use one bottleneck block each at output strides 8 and 4. We note that the axial blocks play the same role (i.e., feature enhancement) as the transformer encoders in other works [10, 24, 104], and we keep the total number of axial blocks at six for a fair comparison to previous works.
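
For quick reference, the block layout of the enhanced pixel decoder described above can be summarized as plain data; the channel widths at output strides 8 and 4 are not specified in the text and are left as placeholders.

```python
# A compact summary of the enhanced pixel decoder layout; purely descriptive
# data, not network code. Channels at output strides 8 and 4 are not given in
# the text, so they are left as None.
ENHANCED_PIXEL_DECODER = [
    # (output stride, block type, channels, number of blocks)
    (32, "axial_block", 2048, 1),
    (16, "axial_block", 1024, 5),
    (8, "bottleneck_block", None, 1),
    (4, "bottleneck_block", None, 1),
]
```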

Cluster Path.   As shown in Fig. 3, we deploy six kMaX decoders, where two decoders each are placed on the pixel features (enhanced by the pixel decoder) at output strides 32, 16, and 8, respectively. Our design uses six transformer decoders, aligning with previous works [10, 24, 104], though some recent works [23, 65] adopt more transformer decoders to achieve stronger performance.

Training and Testing.   We mainly follow MaX-DeepLab [92] for the training settings. The ImageNet-pretrained [81] backbone uses a learning rate multiplier of 0.1. For regularization and augmentation, we adopt drop path [45], random color jittering [29], and panoptic copy-paste augmentation, which extends instance copy-paste augmentation [34, 38] by augmenting both ‘thing’ and ‘stuff’ classes. The AdamW [53, 73] optimizer is used with weight decay 0.05. The k-means cross-attention adopts the cluster-wise argmax, which aligns the formulation of the attention map with the segmentation result. It therefore allows us to directly apply deep supervision on the attention maps. These auxiliary losses attached to each kMaX decoder have the same loss weight of 1.0 as the final prediction, and the Hungarian matching result based on the final prediction is used to assign supervision to all auxiliary outputs. During inference, we adopt the same mask-wise merging scheme used in [24, 65, 104, 107] to obtain the final segmentation results.

COCO Dataset.   Unless otherwise specified, we train all models with batch size 64 on 32 TPU cores for 150k iterations (around 81 epochs). The first 5k steps serve as the warm-up stage, where the learning rate linearly increases from 0 to \(5\times 10^{-4}\). The input images are resized and padded to \(1281\times 1281\). Following MaX-DeepLab [92], the loss weights for the PQ-style loss, auxiliary semantic loss, mask-id cross-entropy loss, and instance discrimination loss are 3.0, 1.0, 0.3, and 1.0, respectively. The number of cluster centers (i.e., object queries) is 128, and the final feature map resolution has output stride 4, as in MaX-DeepLab [92].
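
For convenience, the COCO training hyper-parameters listed above (together with the optimizer settings from the previous paragraph) are collected below; this is a plain summary dictionary, not a drop-in configuration for the official code base.

```python
# COCO training hyper-parameters as stated in the text, gathered into one
# dictionary for readability. Key names are our own, not the official config keys.
COCO_TRAIN_CONFIG = {
    "batch_size": 64,
    "iterations": 150_000,            # roughly 81 epochs
    "warmup_steps": 5_000,
    "peak_learning_rate": 5e-4,
    "backbone_lr_multiplier": 0.1,
    "optimizer": "AdamW",
    "weight_decay": 0.05,
    "crop_size": (1281, 1281),
    "num_cluster_centers": 128,
    "final_output_stride": 4,
    "loss_weights": {
        "pq_style": 3.0,
        "aux_semantic": 1.0,
        "mask_id_cross_entropy": 0.3,
        "instance_discrimination": 1.0,
    },
}
```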

We have also experimented with doubling the number of object queries to 256 for kMaX-DeepLab with ConvNeXt-L, which however leads to a performance loss. Empirically, we adopt a drop query regularization, where we randomly drop half of the object queries (i.e., 128) during each training iteration, and all queries (i.e., 256) are used during inference. With the proposed drop query regularization, doubling the number of object queries to 256 consistently brings 0.1% PQ improvement under the large model regime.
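
A minimal sketch of the drop query regularization is shown below: during training, half of the object queries are randomly dropped at each iteration, while all queries are used at inference. The exact sampling scheme in the paper may differ; the function below is only illustrative.

```python
# A minimal sketch of drop query regularization: keep a random half of the
# object queries during training, all of them at inference. Illustrative only.
import torch

def drop_queries(queries: torch.Tensor, training: bool) -> torch.Tensor:
    # queries: [N, D] object queries (cluster centers), e.g. N = 256.
    if not training:
        return queries
    keep = torch.randperm(queries.shape[0])[: queries.shape[0] // 2]
    return queries[keep]
```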

Cityscapes Dataset.   We train all models with batch size 32 on 32 TPU cores for 60k iterations. The first 5k steps serve as the warm-up stage, where the learning rate linearly increases from 0 to \(3\times 10^{-4}\). The inputs are padded to \(1025\times 2049\). The loss weights for the PQ-style loss, auxiliary semantic loss, mask-id cross-entropy loss, and instance discrimination loss are 3.0, 1.0, 0.3, and 1.0, respectively. We use 256 cluster centers, and add an additional bottleneck block in the pixel decoder to produce features with output stride 2.

4.2 Main Results

Our main results on the COCO [66] and Cityscapes [28] val set are summarized in Table 1 and Table 2, respectively.

Table 1. COCO val set results. Our FLOPs and FPS are evaluated with the input size \(1200\times 800\) and a Tesla V100-SXM2 GPU. \(\dagger \): ImageNet-22K pretraining. \(\star \): Using 256 object queries with drop query regularization. \(\ddagger \): Using COCO unlabeled set

COCO val Set.   In Table 1, we compare our kMaX-DeepLab with other transformer-based panoptic segmentation methods on the COCO val set. Notably, with a simple ResNet-50 backbone, kMaX-DeepLab already achieves 53.0% PQ, surpassing most prior arts with stronger backbones. Specifically, kMaX-DeepLab outperforms MaskFormer [24] and K-Net [107], both also with the ResNet-50 backbone, by large margins of 6.5% and 5.9% PQ, respectively, while maintaining a similar level of computational cost. Our kMaX-DeepLab with ResNet-50 even surpasses the largest variant of MaX-DeepLab [92] by 1.9% PQ (while using 7.9\(\times \) fewer parameters and 22.0\(\times \) fewer FLOPs), and the largest variant of MaskFormer by 0.3% PQ (while using 3.7\(\times \) fewer parameters and 4.7\(\times \) fewer FLOPs). With the stronger MaX-S [92] backbone, kMaX-DeepLab boosts the performance to 56.2% PQ, outperforming MaX-DeepLab with the same backbone by 7.8% PQ. Our kMaX-DeepLab with the MaX-S backbone also improves over the previous state-of-the-art K-Net with Swin-L [70] by 1.6% PQ. To further push the envelope, we adopt the modern CNN backbone ConvNeXt [71] and set new state-of-the-art results of 57.2% PQ with ConvNeXt-B and 58.0% PQ with ConvNeXt-L, outperforming K-Net with Swin-L by a significant margin of 3.4% PQ.

When compared to more recent works (CMT-DeepLab [104], Panoptic SegFormer [65], and Mask2Former [23]), kMaX-DeepLab still shows strong performance without advanced modules such as deformable attention [112], a cascaded transformer decoder [23], or uncertainty-based pointly supervision [56]. As different backbones are utilized by each method (e.g., PVTv2 [95], Swin [70], and ConvNeXt [71]), we start with a fair comparison using the ResNet-50 backbone. Our kMaX-DeepLab with ResNet-50 outperforms CMT-DeepLab, Panoptic SegFormer, and Mask2Former by large margins of 4.5%, 3.4%, and 1.1% PQ, respectively. Additionally, our model runs almost 3\(\times \) faster than them (since kMaX-DeepLab enjoys a simple design without deformable attention). When employing stronger backbones, kMaX-DeepLab with ConvNeXt-B outperforms CMT-DeepLab with Axial-R104, Panoptic SegFormer with PVTv2-B5, and Mask2Former with Swin-B (window size 12) by 3.1%, 1.8%, and 0.8% PQ, respectively, while all models have a similar level of cost (parameters and FLOPs). When scaling up to the largest backbone for each method, kMaX-DeepLab outperforms CMT-DeepLab and Panoptic SegFormer significantly, by 2.7% and 2.2% PQ. Although we already perform better than Mask2Former with Swin-L (window size 12), we notice that kMaX-DeepLab benefits much less than Mask2Former when scaling up from the base model to the large model (+0.7% for kMaX-DeepLab but +1.4% for Mask2Former), indicating kMaX-DeepLab’s strong representation ability and that it may overfit on the COCO train set with the largest backbone. Therefore, we additionally perform a simple experiment to alleviate the over-fitting issue by generating pseudo labels [12] on the COCO unlabeled set. Adding the pseudo labels to the training data slightly improves kMaX-DeepLab, yielding a PQ score of 58.1% (the drop query regularization is not used here, and the number of object queries remains 128).

Cityscapes val Set.   In Table 2, we compare our kMaX-DeepLab with other state-of-the-art methods on the Cityscapes val set. Our reported PQ, AP, and mIoU results use the same panoptic model to provide a comprehensive comparison. Notably, kMaX-DeepLab with the ResNet-50 backbone already surpasses most baselines, while being more efficient. For example, kMaX-DeepLab with ResNet-50 achieves 1.3% higher PQ than Panoptic-DeepLab [22] (with the Xception-71 [27] backbone) while reducing the computational cost (FLOPs) by 20%. Moreover, it achieves a similar performance to Axial-DeepLab-XL [93], while using 3.1\(\times \) fewer parameters and 5.6\(\times \) fewer FLOPs. kMaX-DeepLab achieves even higher performance with stronger backbones. Specifically, with the MaX-S backbone, it performs on par with the previous state-of-the-art Panoptic-DeepLab with the SWideRNet [16] backbone, while using 7.2\(\times \) fewer parameters and 17.2\(\times \) fewer FLOPs. Additionally, even though it is trained only with panoptic annotations, our kMaX-DeepLab also shows superior performance in instance segmentation (AP) and semantic segmentation (mIoU). Finally, we provide a comparison with the recent work Mask2Former [23], where the advantage of our kMaX-DeepLab becomes even more significant. Using the ResNet-50 backbone for a fair comparison, kMaX-DeepLab achieves 2.2% PQ, 1.2% AP, and 2.2% mIoU higher performance than Mask2Former. For other backbone variants of a similar size, kMaX-DeepLab with ConvNeXt-B is 1.9% PQ higher than Mask2Former with Swin-B (window size 12). Notably, kMaX-DeepLab with ConvNeXt-B already obtains a PQ score that is 1.4% higher than Mask2Former with their best backbone. With ConvNeXt-L as the backbone, kMaX-DeepLab sets a new state-of-the-art record of 68.4% PQ without any test-time augmentation or COCO [66]/Mapillary Vistas [75] pretraining.

Table 2. Cityscapes val set results. We only consider methods without extra data [66, 75] and test-time augmentation for a fair comparison. We evaluate FLOPs and FPS with the input size \(1025\times 2049\) and a Tesla V100-SXM2 GPU. Our instance (AP) and semantic (mIoU) results are based on the same panoptic model (i.e., no task-specific fine-tuning). \(\dagger \): ImageNet-22K pretraining
Fig. 4. Visualization of kMaX-DeepLab (ResNet-50) pixel-cluster assignments at each kMaX decoder stage, along with the final panoptic prediction. In the cluster assignment visualization, pixels with the same color are assigned to the same cluster, and their features will be aggregated to update the corresponding cluster centers

Visualizations.   In Fig. 4, we provide a visualization of the pixel-cluster assignments at each kMaX decoder and the final prediction, to better understand the working mechanism behind kMaX-DeepLab. Another benefit of kMaX-DeepLab is that, thanks to the cluster-wise argmax, the visualizations can be directly drawn as segmentation masks, since the pixel-cluster assignments are mutually exclusive. Noticeably, the major clustering updates happen in the first three stages, which already update the cluster centers well and generate reasonable clustering results, while the following stages mainly focus on refining details. This coincides with our observation that three kMaX decoders are sufficient to produce good results. Besides, we observe that the first clustering assignment tends to produce over-segmentation effects, where many clusters are activated and then combined or pruned in the later stages. Moreover, although many fragments exist in the first round of clustering, it already distinguishes different semantics surprisingly well; in particular, some persons are already well clustered, which indicates that the initial clustering is based not only on texture or location but also on the underlying semantics. Another visualization is shown in Fig. 5, where we observe that kMaX-DeepLab captures an instance in a part-to-whole manner. More experimental results (e.g., ablation studies, test set results) and visualizations are available in the appendix.

Fig. 5. Visualization of kMaX-DeepLab (ResNet-50) pixel-cluster assignments at each kMaX decoder stage, along with the final panoptic prediction. kMaX-DeepLab recognizes objects starting from their parts and growing to the whole shape during the clustering process. For example, the elephant’s top head, body, and nose are separately clustered at the beginning, and they are gradually merged in the following stages

5 Conclusion

In this work, we have presented a novel end-to-end framework, called k-means Mask Transformer (kMaX-DeepLab), for segmentation tasks. kMaX-DeepLab rethinks the relationship between pixel features and object queries from the clustering perspective. Consequently, it simplifies the mask-transformer model by replacing the multi-head cross-attention with the proposed single-head k-means clustering. We have tailored the transformer-based model for segmentation tasks by establishing the link between the traditional k-means clustering algorithm and cross-attention. We hope our work will inspire the community to develop more vision-specific transformer models.