1 Introduction

Deep Neural Networks (DNNs) have been successfully used in many real-world applications (Amodei et al., 2016; He et al., 2016). However, training DNNs relies on large-scale, high-quality datasets, which has become a core problem in practice. First, large-scale data annotation is highly expensive, so only a small amount of labels can be accessed (Li et al., 2019; van Engelen & Hoos, 2020). Second, collected data usually follows a long-tailed distribution (He & Garcia, 2009; Liu et al., 2019; Wei & Li, 2019; Wei et al., 2021b, 2022), where only some classes (the majority class) have sufficient training samples while the other classes (the minority class) own only a few samples, as shown in Fig. 1a.

To utilize unlabeled data, Semi-Supervised Learning (SSL) has emerged as an appealing solution (Miyato et al., 2018; Tarvainen & Valpola, 2017; Berthelot et al., 2019b; Sohn et al., 2020; Berthelot et al., 2019a; Zhou et al., 2021; Guo et al., 2020). It builds a learner under model assumptions on the data distribution and exploits unlabeled samples by selecting confident pseudo-labels. However, existing SSL methods have been shown to produce pseudo-labels biased towards the majority class (Kim et al., 2020), leading to undesirable performance.

Recently, Long-Tailed Semi-Supervised Learning (LTSSL) has been proposed to improve the performance of SSL models on long-tailed data. The main ideas of existing LTSSL methods (Kim et al., 2020; Wei et al., 2021a; Lee et al., 2021) are two-fold. One is to improve the quality of pseudo-labels from the SSL perspective. The other is to employ class-balanced sampling or post-hoc classifier adjustment to alleviate class imbalance from the long-tail perspective. These methods can improve the performance of conventional SSL models, but the improvements come at the cost of high computational overhead or information loss due to undersampling of the data.

How to utilize all data efficiently and effectively is the core challenge of LTSSL and the focus of this paper. To this end, we propose a new method called TRAS (TRAnsfer and Share), which has two key ingredients. Figure 1b showcases the effectiveness of TRAS on the minority class.

Fig. 1

a Long-tailed distribution of the training set under the main setting of CIFAR-10-LT. b Minority-class accuracy (%) on CIFAR-10-LT under class imbalance ratios 50, 100, and 150 with 20% of labels available. The proposed TRAS substantially improves minority-class accuracy

First, we compensate for minority-class training by generating a more balanced pseudo-label distribution. Under the guidance of the pseudo-label distribution, DNNs can mine the interactions between classes and obtain richer information for minority classes. The idea of learning from label distributions has been explored in previous literature, such as label distribution learning (Geng, 2016; Gao et al., 2017; Wang & Geng, 2019) and knowledge distillation (Xiang et al., 2020; He et al., 2021; Iscen et al., 2021), but it remains underexplored in LTSSL. Knowledge distillation from a well-trained teacher model is a common approach to generating label distributions. However, such a teacher model is not always available in SSL because of the limited long-tailed labeled data and the high computational overhead. Alternatively, we employ a conventional SSL model, which typically has low accuracy on the minority class. After applying our proposed logit transformation, this conventional SSL model is able to guide the learning of the student model. The transformation is particularly designed to enhance the minority-class supervisory signals without introducing extra computational cost. By training the student model to imitate the enhanced supervisory signals, the minority class receives significantly more attention.

Second, we propose to merge the training of the teacher and student models into a single procedure to reduce the computational cost. To this end, we use a double-branch neural network with a shared feature extractor and two classifiers that produce the predictions of the teacher and the student. The network is then trained end-to-end with a joint objective over the two classifiers. Besides reducing the training cost and simplifying the approach, we empirically find that the two classifiers together help improve representation learning and learn clear classification boundaries between classes.

Our main contributions are summarized as follows:

  1. A new LTSSL method, TRAS, is proposed, which significantly improves minority-class training without introducing extra training cost.

  2. TRAS transfers the pseudo-label distribution from a vanilla SSL network (the teacher) to another network (the student) via a new logit transformation, instead of constructing a sophisticated LTSSL teacher model.

  3. TRAS reveals the importance of the balancedness of the pseudo-label distribution in transfer for LTSSL.

  4. TRAS merges the training of the teacher and student models by sharing the feature extractor, which simplifies the training procedure and benefits representation learning.

  5. TRAS achieves state-of-the-art performance in various experiments. In particular, it improves minority-class accuracy by about 7%.

2 Related work

2.1 Semi-supervised learning

Existing SSL methods aim to use unlabeled data to improve generalization. For this purpose, consistency regularization and entropy minimization have become the most frequently used techniques and demonstrate considerable performance improvements. Specifically, Mean-Teacher (Tarvainen & Valpola, 2017) imposes consistency regularization between the predictions of the current model and a self-ensembled model obtained via an exponential moving average. Virtual Adversarial Training (VAT) (Miyato et al., 2018) encourages the model to minimize the discrepancy between its predictions for unlabeled data before and after applying an adversarial perturbation. MixMatch (Berthelot et al., 2019b) minimizes the entropy of model predictions by sharpening the pseudo-label distribution. ReMixMatch (Berthelot et al., 2019a) improves MixMatch with an additional distribution alignment regularizer and augmentation anchoring. FixMatch (Sohn et al., 2020) merges consistency regularization and entropy minimization by regularizing the predictions for weakly and strongly augmented unlabeled data. However, the above methods assume that both labeled and unlabeled data are class-balanced, leading to poor performance on the minority class when applied to long-tailed datasets.

2.2 Long-tailed semi-supervised learning

To deal with long-tailed datasets, several LTSSL methods have been proposed. In a nutshell, existing methods aim to select pseudo-labels that are not only confident but also more class-balanced, so as to improve generalization for minority classes. For instance, DARP (Kim et al., 2020) estimates the underlying class distribution of unlabeled data and uses it to regularize the distribution of pseudo-labels, which requires solving a convex optimization problem. CReST (Wei et al., 2021a) uses class-aware confidence thresholds to select more pseudo-labels for the minority class. Recently, ABC (Lee et al., 2021) trains an auxiliary balanced classifier on top of a conventional SSL model via class-balanced undersampling. However, these approaches either suffer from high computational cost or lose supervisory information. In this work, we propose a new algorithm, TRAS, which fully utilizes both labeled and unlabeled data through efficient pseudo-label distribution transfer and greatly improves minority-class performance.

3 Method: TRAS

We now introduce the problem setting in Sect. 3.1 and develop our proposed method TRAS, which consists of two key ingredients described in Sects. 3.2 and 3.3. Figure 2 shows the framework of the proposed TRAS.

3.1 Problem setting

Let \({\mathcal {X}}=\{({{\varvec{x}}}_i,y_i)\}_{i=1}^N\) be a labeled dataset, where \({{\varvec{x}}}_i \in {\mathbb {R}}^d\) is a training example and \(y_i \in [L]\) is the corresponding label. We also have an unlabeled dataset \({\mathcal {U}} = \{{{\varvec{u}}}_i\}_{i=1}^M\), where \({{\varvec{u}}}_i \in {\mathbb {R}}^d\) is an unlabeled data point. Following ABC (Lee et al., 2021), we assume that the class distributions of \({\mathcal {X}}\) and \({\mathcal {U}}\) are identical. We denote the number of labeled data points of class l as \(N_l\) (so that \(\textstyle \sum _{l=1}^L N_l=N\)), assuming that all classes are sorted by cardinality in descending order, \(N_1 \ge N_2 \ge \cdots \ge N_L\). In LTSSL, we define the fraction of labeled data as \(\beta =\frac{N}{N+M}\) and the class imbalance ratio as \(\gamma =\frac{N_1}{N_L}\). We divide the classes into the majority class and the minority class according to their frequencies in the training data. Our goal is to learn a model that generalizes well on both the majority class and the minority class.
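For concreteness, the following minimal sketch (with hypothetical per-class counts, not the exact splits used in our experiments) computes the imbalance ratio \(\gamma\) and the labeled fraction \(\beta\) from per-class counts:

```python
import numpy as np

# Hypothetical per-class labeled counts for a 10-class long-tailed split
# (illustrative only; roughly an exponential decay with head-to-tail ratio 100).
labeled_counts = np.array([1000, 599, 359, 215, 129, 77, 46, 28, 17, 10])
unlabeled_counts = 4 * labeled_counts    # same class distribution as the labeled data

N, M = labeled_counts.sum(), unlabeled_counts.sum()
gamma = labeled_counts[0] / labeled_counts[-1]   # imbalance ratio N_1 / N_L
beta = N / (N + M)                               # fraction of labeled data
print(f"gamma = {gamma:.0f}, beta = {beta:.2f}")  # gamma = 100, beta = 0.20
```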

Our proposed method, TRAS, consists of a shared feature extractor and two classifiers, which produce the teacher prediction \(P^T(y\mid {{\varvec{x}}})\) and the student prediction \(P^S(y\mid {{\varvec{x}}})\). There are two key ingredients to TRAS: (1) learn through imitation, in which the student model imitates the adjusted output of the teacher model, and (2) transfer via sharing weights. In the following, we present the technical details of these two ingredients.

Fig. 2

The TRAS method in diagrammatic form

3.2 Ingredient #1: learn through imitation

Given labeled data, a typical approach is to train a classifier f by optimizing the softmax cross-entropy:

$$\begin{aligned} \ell _{\text {CE}}(y, f({{\varvec{x}}}))=-\log \frac{e^{f_{y}({{\varvec{x}}})}}{\sum _{y^{\prime } \in [L]} e^{f_{y^{\prime }}({{\varvec{x}}})}}. \end{aligned}$$
(1)

In LTSSL, however, the distribution of labeled data is heavily class-imbalanced, such that the learned classifier would be biased towards the majority class. To improve the training of the minority class, we propose to use the distribution-aware cross-entropy loss:

$$\begin{aligned} \ell _{\text {DA-CE}}(y, f({{\varvec{x}}}))=-\log \frac{e^{f_{y}({{\varvec{x}}})+\tau \cdot \log \pi _{y}}}{\sum _{y^{\prime } \in [L]} e^{f_{y^{\prime }}({{\varvec{x}}})+\tau \cdot \log \pi _{y^{\prime }}}}, \end{aligned}$$
(2)

where \(\pi _y\) is an estimate of the class prior \({\mathbb {P}}(y)\) and \(\tau >0\) is a scaling parameter. Minimizing \(\ell _{\text {DA-CE}}\) encourages large margins between the true label and the negative labels. Using a distribution-aware cross-entropy is not a new idea in the long-tailed learning literature, e.g., Logit Adjustment (Menon et al., 2020) and Balanced Softmax (Ren et al., 2020). Interestingly, existing methods show that the scaling parameter \(\tau\) plays an important role in model training, yet it is usually treated as a constant, e.g., \(\tau =1\). In the following, we present a new instance-dependent logit scaling method.
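As a reference, a minimal PyTorch-style sketch of the distribution-aware cross-entropy in Eq. (2) could look as follows; the prior `pi`, the example counts, and the function name are illustrative placeholders rather than the exact implementation:

```python
import torch
import torch.nn.functional as F

def da_cross_entropy(logits, targets, pi, tau=1.0):
    """Distribution-aware cross-entropy (Eq. 2).

    logits:  (batch, L) raw classifier outputs f(x)
    targets: (batch,)   integer class labels y
    pi:      (L,)       estimated class prior P(y)
    tau:     scalar scaling parameter
    """
    adjusted = logits + tau * torch.log(pi).unsqueeze(0)   # f_y(x) + tau * log(pi_y)
    return F.cross_entropy(adjusted, targets)

# Usage with a prior estimated from (hypothetical) labeled-class counts:
counts = torch.tensor([1000., 599., 359., 215., 129., 77., 46., 28., 17., 10.])
pi = counts / counts.sum()
loss = da_cross_entropy(torch.randn(8, 10), torch.randint(0, 10, (8,)), pi, tau=1.0)
```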

In addition to labeled data, we have access to a large amount of unlabeled data to help improve generalization. In LTSSL, the underlying distribution of unlabeled data is also long-tailed, and conventional SSL methods show impaired performance on the minority class. This paper proposes to train the model using a pseudo-label distribution rather than biased one-hot pseudo-labels. Intuitively, the label distribution offers more supervisory signals and can benefit minority-class training. We generate the pseudo-label distribution by first training a vanilla SSL model as the teacher, and then training a student model to imitate the output distribution of the teacher. We opt for minimizing their Kullback-Leibler (KL) divergence:

$$\begin{aligned} \ell _{\text {KL}}\left( {\tilde{{{\varvec{y}}}}}^{T} , {\tilde{{{\varvec{y}}}}}^{S}\right) = \sum _{l=1}^{L} {\tilde{y}}^{T}_{l} \log \frac{{\tilde{y}}^{T}_{l}}{{\tilde{y}}^{S}_{l}}, \end{aligned}$$
(3)

where \({\tilde{{{\varvec{y}}}}}^T\) and \({\tilde{{{\varvec{y}}}}}^S\) are the output probabilities of the teacher and student models respectively, which carry the implicit information of the label distribution.
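Equation (3) can be computed directly; a minimal sketch (assuming teacher and student probability vectors as inputs, with a small epsilon for numerical stability) is:

```python
import torch

def kl_divergence(p_teacher, p_student, eps=1e-8):
    """KL(teacher || student) over the class dimension (Eq. 3).

    p_teacher, p_student: (batch, L) probability vectors (rows sum to 1).
    Returns a per-sample divergence of shape (batch,).
    """
    return (p_teacher * (torch.log(p_teacher + eps)
                         - torch.log(p_student + eps))).sum(dim=1)
```

Note the argument order: the teacher distribution plays the role of the target, mirroring \(\ell _{\text {KL}}({\tilde{{{\varvec{y}}}}}^{T}, {\tilde{{{\varvec{y}}}}}^{S})\).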

Note that the teacher model is trained via a conventional SSL algorithm and the produced pseudo-label distribution is still biased towards the majority class. To further enhance supervisory signals for the minority class, we present a new logit transformation to adjust the output of the teacher model. Specifically, for sample \({{\varvec{x}}}\), we transform its pseudo-label distribution as follows:

$$\begin{aligned} {\tilde{{{\varvec{y}}}}}^T = \text {softmax}\left( \phi \left( {{\varvec{z}}}^T\right) \right) = \text {softmax} \left( {{\varvec{z}}}^T -\tau \left( {\hat{y}}\right) \cdot \log \varvec{\pi } \right) , \end{aligned}$$
(4)

where \({{\varvec{z}}}^T\) denotes the output logits and \({\hat{y}}\) is the pseudo-label of \({{\varvec{x}}}\). In this way, the pseudo-label distribution of unlabeled data becomes more balanced. The generated label distribution is illustrated in Fig. 3.

Notably, different from previous works that treat \(\tau\) as a constant to scale the output logits, we use \(\tau\) as a function of the pseudo-label. Concretely, given the pseudo-label \({\hat{y}}\), we define \(\tau ({\hat{y}}) =A \cdot \alpha _{{\hat{y}}} + B\), where \(\varvec{\alpha } = \text {softmax} (-\log \varvec{\pi })\), so that \(\alpha _{{\hat{y}}}\) is larger for rarer classes, and A and B are constants. The motivation is that adjusting the pseudo-label distribution to over-compensate the minority class can be harmful to the majority class. The \({\hat{y}}\)-dependent logit transformation alleviates this problem by flattening the label distribution of predicted minority-class samples more aggressively than that of other samples. In experiments, we simply set \(A=B=2\). As shown in Fig. 3, applying the proposed logit transformation generates a more balanced pseudo-label distribution and thereby improves the training of the minority class.
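A minimal sketch of this transformation (Eq. 4 with the instance-dependent \(\tau ({\hat{y}})\) above; the prior `pi` and the constants A, B are placeholders) could read:

```python
import torch
import torch.nn.functional as F

def transform_teacher_logits(z_teacher, pi, A=2.0, B=2.0):
    """Balance the teacher's pseudo-label distribution (Eq. 4).

    z_teacher: (batch, L) teacher logits
    pi:        (L,)       estimated class prior
    Returns (batch, L) transformed probabilities used as soft targets.
    """
    log_pi = torch.log(pi)                    # (L,)
    alpha = F.softmax(-log_pi, dim=0)         # larger for rarer classes
    y_hat = z_teacher.argmax(dim=1)           # pseudo-labels of the batch
    tau = A * alpha[y_hat] + B                # instance-dependent scaling, (batch,)
    adjusted = z_teacher - tau.unsqueeze(1) * log_pi.unsqueeze(0)
    return F.softmax(adjusted, dim=1)
```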

Fig. 3

Comparison of ground-truth label distribution and our generated pseudo-label distribution on CIFAR-100-LT dataset under class imbalance ratio 20 with 40% of labels available

Putting together the objectives for labeled and unlabeled data, we minimize the loss function for TRAS as follows:

$$\begin{aligned} {\mathcal {L}}_{{{\textsc {TRAS}}}} = \underbrace{\sum _{i=1}^{N} \ell _{\text {DA-CE}}\left( y_i, {{\varvec{z}}}^S_i\right) }_{\text {supervised\ loss}} + \underbrace{\sum _{j=1}^{M} {\mathbb {I}}\left( \max \left( {\tilde{{{\varvec{y}}}}}^S_j\right) \ge t\right) \ell _{\text {KL}}\left( {\tilde{{{\varvec{y}}}}}^T_j, {\tilde{{{\varvec{y}}}}}^S_j\right) }_{\text {unsupervised\ loss}}. \end{aligned}$$
(5)

Here, \({\mathbb {I}}(\cdot )\) denotes the indicator function and t is the confidence threshold; we adopt the common setting \(t = 0.95\) to select confident predictions from the student.

In this way, the pseudo-label distribution naturally captures the implicit relations between labels. After the logit transformation, the distribution encodes more informative supervisory signals for the minority class, so the student can alleviate the data scarcity of minority classes.
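Putting the two terms together, a schematic computation of Eq. (5), reusing the `da_cross_entropy` and `kl_divergence` sketches above, might look as follows; the batch-level normalization of the unsupervised term is an implementation choice, since Eq. (5) is written as a plain sum:

```python
def tras_loss(logits_student_lab, labels, probs_student_unlab, probs_teacher_unlab,
              pi, tau=1.0, threshold=0.95):
    """Schematic TRAS objective (Eq. 5): supervised DA-CE + masked KL imitation."""
    # Supervised part on labeled data (Eq. 2), via the da_cross_entropy sketch above.
    sup = da_cross_entropy(logits_student_lab, labels, pi, tau)

    # Unsupervised part: imitate the (transformed) teacher distribution only
    # where the student is confident; normalized by the number of selected
    # samples (an implementation choice -- Eq. 5 is written as a plain sum).
    mask = (probs_student_unlab.max(dim=1).values >= threshold).float()
    kl = kl_divergence(probs_teacher_unlab, probs_student_unlab)  # per-sample KL
    unsup = (mask * kl).sum() / mask.sum().clamp(min=1.0)

    return sup + unsup
```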

3.3 Ingredient #2: transfer via sharing weights

Learning through imitating the teacher model can significantly compensate for the training of the minority class; however, it requires training two separate DNNs sequentially, which is computationally expensive in SSL.

To reduce the time consumption and simplify the approach, we propose to merge the training of the teacher and student models into a single training procedure. In other words, the teacher and student share the feature extractor network. We further partition the parameter space into three disjoint subsets: (1) the feature extractor \(\psi ({{\varvec{x}}})\); (2) the teacher classifier \(f^T(\psi ({{\varvec{x}}}))\), whose prediction is \({\tilde{{{\varvec{y}}}}}^T\); and (3) the student classifier \(f^S(\psi ({{\varvec{x}}}))\), whose prediction is \({\tilde{{{\varvec{y}}}}}^S\). Subsequently, let us define:

$$\begin{aligned} {{\varvec{z}}}^T = \text {stop\_gradient}\left( f^T\left( \psi ({{\varvec{x}}})\right) \right) , \end{aligned}$$
(6)

which gives the output logits of the teacher model while blocking gradients from the student's imitation loss from updating the teacher branch. Recall that the function \(\phi (\cdot )\) is the logit transformation of \({{\varvec{z}}}^T\) introduced above; we then consider:

$$\begin{aligned} {\mathcal {L}}_{{{\textsc {TRAS}}}} = \underbrace{\sum _{i=1}^{N} \ell _{\text {DA-CE}}\left( y_i, {{\varvec{z}}}^S_i\right) }_{\text {supervised\ loss}} + \underbrace{\sum _{j=1}^{M} {\mathbb {I}}\left( \max \left( {\tilde{{{\varvec{y}}}}}^S_j\right) \ge t\right) \ell _{\text {KL}}\left( \text {softmax} \left( \phi \left( {{\varvec{z}}}^T_j\right) \right) , {\tilde{{{\varvec{y}}}}}^S_j\right) }_{\text {unsupervised\ loss}}, \end{aligned}$$
(7)

as the joint objective. Since the teacher and student share a single feature extractor, TRAS only adds a linear classifier to the conventional SSL model, which incurs negligible extra training cost.
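A minimal sketch of this double-branch design (a shared backbone with two linear classifiers, detaching the teacher logits for the imitation target as in Eq. 6; the backbone is a placeholder) is given below:

```python
import torch.nn as nn

class TwoBranchNet(nn.Module):
    """Shared feature extractor psi with separate teacher and student classifiers."""

    def __init__(self, backbone: nn.Module, feat_dim: int, num_classes: int):
        super().__init__()
        self.backbone = backbone                              # shared psi(.)
        self.teacher_head = nn.Linear(feat_dim, num_classes)  # f^T
        self.student_head = nn.Linear(feat_dim, num_classes)  # f^S

    def forward(self, x):
        feat = self.backbone(x)
        return self.teacher_head(feat), self.student_head(feat)

# Usage (with a hypothetical backbone): the imitation target uses detached
# teacher logits as in Eq. (6), while the SSL loss keeps the teacher branch trainable.
# z_t, z_s = model(x)
# soft_target = transform_teacher_logits(z_t.detach(), pi)  # see sketch above
```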

Let \({\mathcal {L}}_{\text {SSL}}\) denote the loss of a conventional SSL method; the total loss that TRAS optimizes is:

$$\begin{aligned} {\mathcal {L}}_{\text {Total}} = {\mathcal {L}}_{{{\textsc {TRAS}}}} + {\mathcal {L}}_{\text {SSL}}. \end{aligned}$$
(8)

Particularly, if FixMatch is employed as the teacher model, \({\mathcal {L}}_{\text {SSL}}\) consists of a cross-entropy loss on labeled data and a consistency regularization on unlabeled data. Specifically, we have:

$$\begin{aligned} {\mathcal {L}}_{\text {SSL}} = \underbrace{\sum _{i=1}^{N} \ell _{\text {CE}}\left( y_i, {{\varvec{z}}}^T_i\right) }_{\text {supervised\ loss}} + \underbrace{\sum _{j=1}^{M} {\mathbb {I}}\left( \max \left( \text {softmax}\left( {{\varvec{z}}}^T_j\right) \right) \ge t\right) \ell _{\text {CE}}\left( {\hat{y}}_j, {\tilde{{{\varvec{z}}}}}^T_j\right) }_{\text {unsupervised\ loss}}, \end{aligned}$$
(9)

where \({{\varvec{z}}}^T\) and \({\tilde{{{\varvec{z}}}}}^T\) are the output logits for the weakly and strongly augmented views, and \({\hat{y}}=\arg \max _l z^T_l\) is the pseudo-label for unlabeled data. At inference time, we use the student classifier \(f^S\) to predict the label.
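To make the full objective concrete, the following schematic training-step sketch combines the FixMatch-style teacher loss (Eq. 9) with the TRAS loss (Eq. 7) as in Eq. (8). It reuses the sketches above; the pairing of weakly/strongly augmented views with each branch and the loss reductions are assumptions, not the exact implementation:

```python
import torch.nn.functional as F

def total_loss_step(model, x_lab, y_lab, u_weak, u_strong, pi, threshold=0.95):
    """Schematic loss for Eq. (8): L_TRAS + L_SSL (FixMatch-style teacher)."""
    z_t_lab, z_s_lab = model(x_lab)
    z_t_w, z_s_w = model(u_weak)        # weakly augmented unlabeled batch
    z_t_s, _ = model(u_strong)          # strongly augmented unlabeled batch

    # FixMatch-style teacher loss (Eq. 9): pseudo-label from the weak view,
    # consistency on the strong view, thresholded on the max probability.
    probs_t_w = F.softmax(z_t_w, dim=1)
    conf, y_hat = probs_t_w.max(dim=1)
    mask_t = (conf >= threshold).float()
    l_ssl = F.cross_entropy(z_t_lab, y_lab) + \
            (mask_t * F.cross_entropy(z_t_s, y_hat, reduction='none')).mean()

    # TRAS loss (Eq. 7): student imitates the transformed, detached teacher output.
    soft_target = transform_teacher_logits(z_t_w.detach(), pi)   # see sketch above
    probs_s_w = F.softmax(z_s_w, dim=1)
    mask_s = (probs_s_w.max(dim=1).values >= threshold).float()
    l_tras = da_cross_entropy(z_s_lab, y_lab, pi) + \
             (mask_s * kl_divergence(soft_target, probs_s_w)).mean()

    return l_tras + l_ssl               # Eq. (8)
```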

3.4 Connection to previous work

The basic idea of TRAS is to transfer the knowledge distribution from a vanilla teacher model to a student model so that the student generalizes well on the minority class. This technique is related to knowledge distillation, which has been explored in several recent long-tailed learning works. For instance, LFME (Xiang et al., 2020) trains the student model by distilling multiple teachers trained on less imbalanced datasets. DiVE (He et al., 2021) shows that flattening the output distribution of the teacher model with a constant temperature parameter can help the learning of minority classes. CBD (Iscen et al., 2021) distills features from the teacher to the student and shows that this improves the learned representation of the minority class. Last but not least, xERM (Zhu et al., 2022) obtains an unbiased model by properly adjusting the weights between the empirical loss and the knowledge distillation loss.

In contrast to previous works that address supervised long-tailed learning, this paper studies semi-supervised long-tailed learning, where the amount of labeled data is much more limited. Moreover, previous works need to train teacher models with well-established long-tailed learning methods, whereas TRAS only needs a vanilla SSL model as the teacher. Additionally, these methods rely on multi-stage training procedures, while our method is simpler and can be trained end-to-end.

4 Experiments

We conduct experiments on long-tailed versions of CIFAR-10, CIFAR-100, and SVHN, comparing with state-of-the-art LTSSL methods. We then perform hyper-parameter sensitivity and ablation studies to better understand the proposed TRAS.

4.1 Experimental setup

4.1.1 Datasets

We conduct experiments on the common benchmarks long-tailed CIFAR-10 (CIFAR-10-LT), long-tailed CIFAR-100 (CIFAR-100-LT), and long-tailed SVHN (SVHN-LT) to evaluate our method. For the imbalanced SSL settings, we randomly resample the datasets so that the class distributions of labeled and unlabeled samples are consistent. Given the class imbalance ratio \(\gamma =\frac{N_1}{N_L}\), the number of labeled data points of class l is \(N_l= N_1 \cdot \gamma ^{-\frac{l-1}{L-1}}\), and the unlabeled counts \(M_l\) follow the same rule. Specifically, we set \(N_1+M_1=5000\) and \(L = 10\) for CIFAR-10-LT and SVHN-LT, and \(N_1+M_1=500\) and \(L = 100\) for CIFAR-100-LT.
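For illustration, a short sketch that derives the per-class sample counts under this exponential decay rule (assuming rounding to integers) is:

```python
import numpy as np

def long_tailed_counts(n_head: int, num_classes: int, gamma: float) -> np.ndarray:
    """Per-class counts N_l = N_1 * gamma^(-(l-1)/(L-1)) for l = 1..L."""
    l = np.arange(num_classes)                       # corresponds to l-1 = 0..L-1
    return np.round(n_head * gamma ** (-l / (num_classes - 1))).astype(int)

# CIFAR-10-LT style split with gamma = 100 and 20% labels:
# N_1 = 1000 labeled and M_1 = 4000 unlabeled head-class samples.
labeled = long_tailed_counts(n_head=1000, num_classes=10, gamma=100)     # 1000 ... 10
unlabeled = long_tailed_counts(n_head=4000, num_classes=10, gamma=100)   # 4000 ... 40
```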

Following previous work (Lee et al., 2021), we evaluate classification performance with imbalance ratio \(\gamma\) = 100 and 150 for CIFAR-10-LT and SVHN-LT, and \(\gamma\) = 20 and 30 for CIFAR-100-LT. The ratio of labeled data \(\beta\) is 10%, 20%, and 30% for CIFAR-10-LT and SVHN-LT, and 20%, 40%, and 50% for CIFAR-100-LT. Since the test set is balanced, we adopt three main metrics to validate the proposed method: overall accuracy, minority-class accuracy, and the Geometric Mean score (GM) (Branco et al., 2016) of class-wise sensitivity.
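For reference, a minimal sketch of the GM metric (the geometric mean of per-class recall, assuming integer label arrays) is:

```python
import numpy as np

def geometric_mean_recall(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int) -> float:
    """Geometric mean (GM) of class-wise recall (sensitivity)."""
    recalls = []
    for c in range(num_classes):
        mask = (y_true == c)
        recalls.append((y_pred[mask] == c).mean() if mask.any() else 0.0)
    return float(np.prod(recalls) ** (1.0 / num_classes))
```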

4.1.2 Setup

We implement our method with FixMatch over a Wide ResNet-28-2 backbone (Zagoruyko & Komodakis, 2016). Our method is compared with the supervised baseline, long-tailed supervised learning methods, and long-tailed semi-supervised learning methods, namely (a) Vanilla; (b) VAT (Miyato et al., 2018) and FixMatch (Sohn et al., 2020); (c) BALMS (Ren et al., 2020) and classifier Re-Training (cRT) (Kang et al., 2020); (d) DARP (Kim et al., 2020), CReST (Wei et al., 2021a), and ABC (Lee et al., 2021). We set the hyper-parameters by following FixMatch and train the neural networks for 500 epochs with 500 mini-batches per epoch and a batch size of 64, using the Adam optimizer (Kingma & Ba, 2015). The learning rate is 0.002 with a decay rate of 0.999. We start optimizing TRAS after training FixMatch for 10 epochs. For all experiments, we report the mean and standard deviation of test accuracy over multiple runs.

4.2 Experimental results

Table 1 reports the performance of the compared algorithms under the main setting. Results of related methods are taken from ABC (Lee et al., 2021). It can be seen that our method achieves the best performance, and the improvement on the minority class is substantial. Conventional SSL methods such as VAT and FixMatch perform unsatisfactorily on the minority class because the pseudo-labels of unlabeled data are produced by a biased model, which hinders the learning of minority classes. Our method significantly improves minority-class performance by using knowledge transfer to generate a balanced label distribution, which conveys more implicit information than the one-hot pseudo-labels used in most previous LTSSL works. Moreover, our standard deviation is lower than that of other LTSSL methods, showing the superior stability of TRAS.

Table 1 Overall accuracy(%)/minority-class accuracy(%) under the main setting

To further validate the effectiveness of our method, we report the performance under various settings. The results on CIFAR-10-LT, SVHN-LT, and CIFAR-100-LT are reported in Tables 2, 3 and 4. TRAS outperforms the other methods in all cases with respect to both overall accuracy and minority-class accuracy. In particular, TRAS achieves about 10%, 5%, and 7% improvements in minority-class accuracy on the three datasets, respectively. Moreover, TRAS is more robust to class imbalance: as the imbalance ratio increases, the performance of existing methods deteriorates severely, while the accuracy of our method drops only slightly.

Table 2 Overall accuracy(%)/minority-class accuracy(%) for CIFAR-10-LT. Two imbalance ratios \(\gamma\) and three labeled data ratios \(\beta\) are evaluated
Table 3 Overall accuracy(%)/minority-class accuracy(%) on SVHN-LT. Two imbalance ratios \(\gamma\) and three labeled data ratios \(\beta\) are evaluated
Table 4 Overall accuracy(%)/minority-class accuracy(%) on CIFAR-100-LT. Two imbalance ratios \(\gamma\) and three labeled data ratios \(\beta\) are evaluated

To evaluate whether TRAS yields balanced predictions across classes, we measure its performance using the Geometric Mean (GM) of class-wise accuracy. The results in Table 5 demonstrate that the proposed algorithm yields the best and most balanced performance across all classes. Additionally, TRAS achieves a more pronounced improvement on the larger dataset (CIFAR-100-LT).

Table 5 Results of GM(%) under the main setting

Additionally, we evaluate TRAS on the more practical and challenging ImageNet127 dataset, which was introduced for LTSSL by CReST (Wei et al., 2021a). It is a naturally imbalanced dataset with imbalance ratio 286, obtained by grouping the 1000 classes of ImageNet into 127 classes based on the WordNet hierarchy. Due to limited resources, we are not able to conduct experiments on ImageNet127 at full resolution. Instead, we follow CoSSL (Fan et al., 2022), which down-samples the original ImageNet127 images to \(32 \times 32\) or \(64 \times 64\) pixels using the box method from the Pillow library. We randomly select 10% of the training samples as labeled data. Since the test set is also imbalanced, averaged class recall is used as a balanced metric. We compare our method with the recent state-of-the-art approach CoSSL (Fan et al., 2022). From Table 6, we can see that TRAS achieves the best performance for both image sizes 32 and 64.

Table 6 Averaged class recall (%) on ImageNet-127

4.3 How does pseudo-label distribution impact the performance?

Recall that we use the hyper-parameters A and B to control the pseudo-label distribution through \(\tau ({\hat{y}}) =A \cdot \alpha _{{\hat{y}}} + B\); we now analyze their influence on performance. The results are reported in Fig. 4. We find that B impacts performance much more than A, which matches our intuition because \(\alpha _{{\hat{y}}} < 1\). Setting \(A \in \{1,2,3\}\) achieves comparable results, whereas \(B = 2\) yields better performance than other values in our experiments. When \(A > 3\) and \(B > 3\), test accuracy deteriorates severely because of the heavy bias towards minority classes in the pseudo-label distribution.

Fig. 4

The impact of values of A and B on CIFAR-10-LT under class imbalance ratio 100 with 20% of labels available. a Overall accuracy(%) of TRAS; b Top-1 accuracy(%) of pseudo-labels of the teacher after transformation; c Top-5 accuracy(%) of pseudo-labels of the teacher after transformation

Interestingly, we find the balancedness of pseudo-label distribution matters much more than the accuracy of pseudo-labels in transfer. As B increases, the top-5 accuracy of pseudo-labels is impaired while the overall accuracy remains competitive. This indicates that class imbalance hurts the performance more than inaccurate pseudo-labels in our approach.

To better understand this phenomenon, we investigate the impact of the logit scaling parameters A and B on the quality of pseudo-labels for head, torso, and tail classes separately. As illustrated in Fig. 5, \(A=0, B=0\) yields high pseudo-label precision. However, as shown in Fig. 6, it gives the worst recall on the tail classes. Since \(A=0, B=0\) means that the pseudo-labels come from the conventional SSL model, which is biased towards the head classes, transferring their distribution to the student does not help the training of tail classes, as shown in Fig. 7.

In contrast, setting \(A=2, B=2\) achieves the best overall and tail-class accuracy, as reported in Figs. 4a and 7. Notably, it produces high recall yet low precision for tail classes (Figs. 5 and 6). This observation confirms that the balancedness of the pseudo-label distribution deserves more attention than the accuracy of pseudo-labels in knowledge transfer.

Fig. 5

Comparison of pseudo-label precision by varying the values of A and B on CIFAR-10-LT under class imbalance ratio 100 with 20% of labels available. The x-axis is the number of epochs, and the y-axis is the precision. Classes are divided into head ({0, 1, 2}), torso ({3, 4, 5, 6}) and tail ({7, 8, 9})

Fig. 6

Comparison of pseudo-label recall by varying the value of A and B on CIFAR-10-LT under class imbalance ratio 100 with 20% of labels. The x-axis is the number of epochs, and the y-axis is the recall

Fig. 7

Comparison of test accuracy by varying the values of A and B on CIFAR-10-LT under class imbalance ratio 100 with 20% of labels. The x-axis is the number of epochs, and the y-axis is the test accuracy

4.4 Better understanding of TRAS

We analyze TRAS from the representation and classification perspectives on CIFAR-10-LT under the main setting. First, we compare the representations learned by ABC and by our TRAS via t-distributed stochastic neighbor embedding (t-SNE) (Van der Maaten & Hinton, 2008) in Fig. 8. TRAS produces clearer classification boundaries than ABC, demonstrating that TRAS better distinguishes between classes thanks to improved representation learning.

Fig. 8

Results of t-SNE for ABC (Lee et al., 2021) and our TRAS

Further, to analyze the classification results, we compare the confusion matrices of the predictions on the test set in Fig. 9. Each row represents the ground-truth label and each column represents the prediction by ABC or by our TRAS. The value in the i-th row and j-th column is the percentage of samples from the i-th class that are predicted as the j-th class. The results show that TRAS performs better than ABC on the minority class. We also observe that TRAS occasionally misclassifies some majority-class samples as minority-class ones.

Fig. 9

Confusion matrices of the prediction on the test set of CIFAR-10-LT

Ablation studies. We conduct ablation studies on the key components of our approach under the main setting.

First, we apply a logit transformation with \(A=0, B=1\) to the teacher model's predictions on unlabeled data to improve the teacher's performance. Removing this logit transformation reduces the overall accuracy and minority-class accuracy under the main setting to 83.66% (−0.64%) and 77.54% (−4.66%) on CIFAR-10-LT, respectively.

Second, we replace the distribution-aware cross-entropy on labeled data with the common cross-entropy loss, leading to 83.41% (−0.89%) overall accuracy and 80.28% (−1.92%) minority-class accuracy. The marginal decline indicates that most of the benefit stems from the learning-through-imitation component, verifying its effectiveness.

Finally, we remove the sample mask on the student model's unlabeled data, so that all unlabeled data are used to imitate the teacher. Removing the sample mask decreases performance slightly, to 83.33% (−0.97%) overall accuracy and 79.60% (−2.59%) minority-class accuracy. This demonstrates the benefit of selecting more accurate pseudo-labels for the student model.

4.5 Comparison with two-stage training

We further compare TRAS with two-stage training under the main setting in Table 7. In the two-stage approach, we first train a FixMatch model as the teacher and then train a student model guided by this teacher with our TRAS. We find that TRAS not only saves training cost but is also trained better than the two-stage approach. This agrees with our expectation: in two-stage training, the student puts more emphasis on the minority class, while fitting the pseudo-label distribution does not necessarily improve feature learning. In contrast, the two branches of TRAS share the feature learning backbone, which improves the backbone and both classifiers simultaneously.

To further show that both branches of TRAS improve feature learning, we stop the gradients of the student branch from propagating into the feature learning backbone. In this way, only the teacher model trains the feature extractor, which resembles two-stage training. The results show that TRAS performs better, indicating that the student further enhances feature learning by sharing the backbone with the teacher model.

Table 7 Performance comparison of overall accuracy(%)/minority-class accuracy(%) with two-stage training and TRAS-

5 Conclusion

We introduce TRAS, a new method for LTSSL. TRAS (1) learns from a more class-balanced label distribution to improve minority-class generalization and (2) partitions the parameter space so that the transformed knowledge learned by a conventional SSL model can be transferred via weight sharing. Extensive experiments on CIFAR-10-LT, SVHN-LT, and CIFAR-100-LT show that TRAS outperforms state-of-the-art methods on the minority class by a large margin. In future work, it would be interesting to extend TRAS to more established SSL methods.