1 Introduction

Semi-supervised learning (SSL) (Zhu and Goldberg 2009; Sohn et al. 2020) aims to improve model performance by leveraging both abundant unlabeled data and limited labeled data. SSL algorithms exploit the latent patterns underlying unlabeled data, which reduces the need for large amounts of annotation (Sohn et al. 2020). Most previous SSL studies rely heavily on the pseudo-labeling strategy (Lee 2013; Sohn et al. 2020), which generates a hard label for each unlabeled sample and trains the deep model on these pseudo-labels.

For pseudo-labeling methods (Lee 2013; Sohn et al. 2020; Zhang et al. 2021; Xu et al. 2021), it is essential to set a proper threshold for selecting reliable pseudo-labels for unlabeled data. For example, FixMatch (Sohn et al. 2020) selects high-confidence pseudo-labels via a fixed threshold (e.g., 0.95 for CIFAR (Krizhevsky and Hinton 2009) and 0.65 for ImageNet (Deng et al. 2009)). However, as reported by Xu et al. (2021), keeping the threshold fixed throughout training can reduce learning efficiency and raise the error rate of pseudo-labels, especially in the early learning stage.

To address this issue, subsequent works (Xu et al. 2021; Guo and Li 2022; Zhang et al. 2021; Saito et al. 2021) have proposed generating the threshold dynamically to make SSL more robust. For instance, Xu et al. (2021) converted the fixed confidence threshold into a loss threshold and selected the unlabeled data whose loss values (evaluated on pseudo-labels) fall below it. These selected data are then incorporated into the training set, while the loss threshold gradually decreases over training iterations. Zhang et al. (2021) leveraged the idea of curriculum learning (Bengio et al. 2009) to account for the learning status of each class, flexibly adjusting per-class thresholds at each time step via a preset function.
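To make the notion of a hand-designed ("preset") threshold concrete, the sketch below contrasts three selection rules: a fixed confidence threshold as in FixMatch, a loss threshold that decays over training in the spirit of Dash, and per-class thresholds scaled by each class's learning status in the spirit of FlexMatch. The decay factor `gamma`, the initial loss threshold `rho0`, and the linear per-class scaling are illustrative assumptions, not the exact formulas of those methods, and all function names are hypothetical.

```python
def fixed_mask(confidences, tau=0.95):
    """FixMatch-style rule: keep a sample if its max probability exceeds tau."""
    return [c > tau for c in confidences]


def decayed_loss_mask(losses, step, rho0=1.0, gamma=1.1):
    """Dash-like rule (geometric decay is an assumption): keep a sample if its
    pseudo-label loss is below a threshold that shrinks as training proceeds."""
    rho = rho0 * gamma ** (-step)
    return [loss < rho for loss in losses]


def class_level_thresholds(confident_counts, tau=0.95):
    """FlexMatch-like rule (linear scaling is an assumption): classes with fewer
    confident predictions so far receive a proportionally lower threshold."""
    top = max(confident_counts) or 1  # avoid division by zero
    return [tau * n / top for n in confident_counts]


print(fixed_mask([0.97, 0.50]))               # [True, False]
print(decayed_loss_mask([0.5, 2.0], step=0))  # [True, False]
print(class_level_thresholds([100, 50, 0]))
```

Each rule depends on hand-chosen constants (`tau`, `rho0`, `gamma`) and a hand-chosen functional form, which is exactly the kind of pre-specified function and extra hyper-parameters discussed in the following paragraph.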

Despite the decent performance of the pseudo-labeling methods mentioned above, they share two common drawbacks. First, they (Xu et al. 2021; Guo and Li 2022; Zhang et al. 2021) rely on manually pre-specified functions to adjust the threshold, which tends to be infeasible when little is known about the underlying dataset or when the label conditions are too complicated. Second, these methods (Xu et al. 2021; Guo and Li 2022; Zhang et al. 2021) usually involve at least two hyper-parameters, which require a costly cross-validation phase and thus suffer from scalability issues (Franceschi et al. 2018) in real-world applications.

To address these two drawbacks, this paper presents a simple yet effective strategy that automatically learns a sample-aware confidence threshold for each unlabeled example. In contrast with previous works, our method does not rely on prior knowledge to pre-define a threshold-adjustment function, and it introduces only one hyper-parameter. Moreover, to the best of our knowledge, we are the first to introduce instance-level thresholds, motivated by the observation that a deep model has different learning capabilities for different categories and even for different examples. Figure 1a shows a practical example. Intuitively, instance-level thresholds are better suited to generating accurate pseudo-labels for unlabeled instances, which in turn facilitates the deep model's learning.

Fig. 1

a An example illustrating that a deep model has different learning capabilities for different examples of the class tiger. b Review of the pseudo-labeling training framework and comparison of the fixed confidence threshold of FixMatch (Sohn et al. 2020) with three improved strategies. Compared to the decayed and class-level thresholds in Dash (Xu et al. 2021) and FlexMatch (Zhang et al. 2021), our method designs a meta-net that generates a more refined confidence threshold for each unlabeled example (i.e., sample-level thresholds)

Specifically, we leverage the idea of meta-learning (Finn et al. 2017) to construct a lightweight meta-net (e.g., a three-layer MLP) that explicitly models the instance-level thresholds, yielding a threshold for every unlabeled example. By the universal approximation theorem (Hornik et al. 1989) for multilayer feedforward networks, our meta-net can be regarded as a generalized version of the pre-defined functions mentioned above (Zhang et al. 2021; Xu et al. 2021). Our framework thus contains a classifier network and a meta-net whose training forms a nested optimization problem. We solve it with a bi-level strategy: 1) Inner loop: generate instance-level thresholds for all unlabeled instances and use the selected hard pseudo-labels to train the classifier network; 2) Outer loop: update all parameters of the meta-net on a small set of meta-data constructed from the labeled data.

An appealing feature of this formulation is that the inner loop can be viewed as a mapping from the sample-threshold space (parameterized by the meta-net) into the classifier parameter space, while the outer loop optimizes the thresholds. Since the pseudo-labeling framework explicitly contains the non-differentiable indicator function \(\mathbbm {1}(\cdot )\), we approximate it with a surrogate function, which makes the bi-level optimization problem tractable. In Fig. 1b, we compare our method with vanilla FixMatch (Sohn et al. 2020) and two improved methods (Xu et al. 2021; Zhang et al. 2021), highlighting its merits: it avoids preset functions and requires no prior knowledge.

Our contributions can be summarized as follows:

  • We propose a simple yet effective training framework (named Meta-Threshold, Meta-T) based on bi-level optimization for threshold-based SSL, which enjoys the following benefits: 1) Meta-T learns the thresholds of unlabeled samples automatically through bi-level optimization, avoiding conventional threshold-based methods' reliance on strong prior knowledge about the data. 2) Meta-T includes only one extra hyper-parameter, i.e., the learning rate of the meta-net, which is insensitive and thus does not require complex cross-validation.

  • We introduce a surrogate function to replace the indicator function. Further, we prove the convergence of our framework and show that it reaches an \(\epsilon\)-stationary point within \(\mathcal {O}(1/\epsilon ^2)\) iterations.

  • We integrate the proposed Meta-T with curriculum learning, yielding Green Meta-T, which significantly reduces the training cost of our learning algorithm with only a slight loss of accuracy.

  • Our method can be applied to solve both the conventional and imbalanced SSL tasks, exhibiting great potential in real-world applications.

2 Related work

Deep Semi-Supervised Learning As a common learning paradigm, deep SSL has shown remarkable performance in leveraging large amounts of unlabeled data to train deep models. Current deep SSL methods can be roughly divided into three categories: consistency-based methods, pseudo-labeling methods, and hybrid methods. The key idea of consistency-based methods is to force the model's outputs on original and perturbed unlabeled data to remain the same (Laine and Aila 2016; Tarvainen and Valpola 2017; Xie et al. 2020). Pseudo-labeling methods, also called self-learning in previous works, are an iterative mechanism that uses the limited labeled data to train a model that predicts labels for the unlabeled data; the generated labels are then used to train the model (Lee 2013). Hybrid approaches (Sohn et al. 2020; Zhang et al. 2021; Xu et al. 2021) typically combine the two ideas above with strong augmentation strategies (e.g., RandAugment (Cubuk et al. 2020) and CTAugment (Berthelot et al. 2019)).

Imbalanced Semi-Supervised Learning To improve the universality of SSL algorithms, some works (Kim et al. 2020; Wei et al. 2021; Guo and Li 2022) turn their attention to more challenging settings such as SSL under a class-imbalanced label distribution. DARP (Kim et al. 2020) designed a distribution-aligning procedure that modifies biased pseudo-labels to match the true class distribution. However, this requires prior knowledge of the data distribution, which is hard to obtain in real applications; DARP therefore estimates the class distribution from a confusion matrix between labeled and unlabeled data. CReST (Wei et al. 2021) is based on a typical self-training strategy that adaptively adds pseudo-labeled data to the training set according to label frequency.

Meta-Learning Also known as “learning to learn”, meta-learning has been widely applied to several weakly-supervised tasks, such as learning with noisy labels (Shu et al. 2019; Sun et al. 2022), out-of-distribution learning (Guo et al. 2020), and semi-supervised learning (Wang et al. 2020; Xiao et al. 2021). In the SSL field, some works use meta-learning to learn a set of parameters. For example, Wang et al. (2020) proposed a framework that learns sample weights for all unlabeled data, assigning high weights to more reliable pseudo-labels. Xiao et al. (2021) proposed to learn soft labels for unlabeled data and designed a first-order update strategy for the bi-level framework.

Relations Two works, L2RW (Ren et al. 2018) and MW-Net (Shu et al. 2019), employed bi-level optimization to efficiently learn a set of hyper-parameters. Our work differs from them in three critical aspects.

  1. Problem setting: L2RW and MW-Net (Ren et al. 2018; Shu et al. 2019) focus on improving the generalization performance of deep models when learning with noisy labels, while our work aims to enhance the quality of the pseudo-labels generated for unlabeled data in semi-supervised learning.

  2. Methodology: L2RW and MW-Net (Ren et al. 2018; Shu et al. 2019) learn a set of sample weights for training (label-corrupted) samples and then minimize the training loss multiplied by the corresponding weights, while our framework generates thresholds that are used to select high-reliability pseudo-labels rather than directly participating in the model's training. Besides, our method follows the pseudo-labeling framework and therefore faces the non-differentiability of the indicator function, which we resolve with a surrogate function. Finally, we combine the bi-level training framework with curriculum learning, significantly reducing the cost of the bi-level strategy.

  3. Theory: We introduce a surrogate function to replace the indicator function and provide a convergence guarantee for our learning algorithm when an upper bound of the surrogate function is given. We also give a brief analysis of the training costs of both Meta-T and Green Meta-T.

3 Preliminaries

Problem setting. In a C-class classification task, we have a set of training data that contains N labeled examples \(D^l = \{(\textbf{x}_1^l,\textbf{y}_1^l), \cdot \cdot \cdot ,(\textbf{x}_N^l,\textbf{y}_N^l)\}\) and M unlabeled examples \(D^u = \{\textbf{x}_1, \cdot \cdot \cdot ,\textbf{x}_M\}\), where \(\textbf{x} \in \mathcal {X} \subseteq \mathbb {R}^d\) denotes a d-dimensional input feature vector and \(\textbf{y} \in \mathcal {Y}\) is its one-hot label. Given a deep model f with learnable parameters \(\textbf{w}\) and a classification loss function \(H(\cdot )\) (e.g., cross-entropy loss), the training objective in typical supervised learning is \(L_s = \mathbb {E}_{(\textbf{x},\textbf{y}) \sim D^l}H(f(\textbf{x}),\textbf{y})\). SSL algorithms additionally exploit the unlabeled data to achieve higher performance, and their training objective can be summarised as \(L_s + \lambda _u L_u\), where \(L_u\) is constructed on \(D^u\) and the trade-off coefficient satisfies \(\lambda _u > 0\).

3.1 Confidence thresholds in semi-supervised learning

Owing to its simplicity and strong performance, we use FixMatch (Sohn et al. 2020) as an example to illustrate how confidence thresholds are used in pseudo-labeling methods.

The core ideas of FixMatch are the confidence threshold and strong augmentation. To train the classifier on unlabeled data, FixMatch first computes a pseudo-label from the weakly-augmented version of each image. For each unlabeled example \(\textbf{x}_m \in D^u\), the prediction of the classification network is \(p_m = f(\mathcal {A}^w(\textbf{x}_m); \textbf{w})\), where \(\mathcal {A}^w\) denotes weak augmentation, and the pseudo-label is \(\hat{\textbf{y}}_m = \arg \max (p_m)\). Because it is obtained by \(\arg \max\), \(\hat{\textbf{y}}_m\) is a one-hot (hard) label. The training loss of \(\textbf{x}_m\) is then

$$\begin{aligned} \ell _{\textbf{x}_m} = \mathbbm {1}(\max (p_m) > \tau ) \cdot H(\hat{\textbf{y}}_m, f(\mathcal {A}^s(\textbf{x}_m); \textbf{w})), \end{aligned}$$
(1)

where \(\mathbbm {1}(\cdot )\) is the indicator function that selects high-confidence pseudo-labels, \(\mathcal {A}^s\) denotes strong augmentation, and \(\tau\) is a fixed constant. The training objective over all unlabeled data is then \(L_{u} = \frac{1}{M}\sum _{\textbf{x}_m \in D^{u}} \ell _{\textbf{x}_m}\).
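A minimal pure-Python sketch of Eq. (1) and of \(L_u\), assuming `weak_probs` and `strong_probs` hold the classifier's softmax probabilities on the weakly and strongly augmented views (so `max` is the confidence) and that the hard pseudo-label is represented by its class index. The helper `ssl_objective` shows how \(L_u\) enters \(L_s + \lambda_u L_u\); all names are hypothetical.

```python
import math


def cross_entropy(label_idx, probs):
    """H(y_hat, p) for a one-hot target given by its class index."""
    return -math.log(probs[label_idx])


def fixmatch_unlabeled_loss(weak_probs, strong_probs, tau=0.95):
    """L_u = (1/M) * sum_m 1(max(p_m) > tau) * H(y_hat_m, f(A^s(x_m)))."""
    total = 0.0
    for p_weak, p_strong in zip(weak_probs, strong_probs):
        conf = max(p_weak)
        y_hat = p_weak.index(conf)            # hard pseudo-label (argmax)
        mask = 1.0 if conf > tau else 0.0     # indicator function
        total += mask * cross_entropy(y_hat, p_strong)
    return total / len(weak_probs)            # average over all M samples


def ssl_objective(loss_s, loss_u, lam_u=1.0):
    """Overall SSL objective L_s + lambda_u * L_u with lambda_u > 0."""
    return loss_s + lam_u * loss_u


weak = [[0.97, 0.03], [0.60, 0.40]]    # only the first sample passes tau = 0.95
strong = [[0.50, 0.50], [0.90, 0.10]]
loss_u = fixmatch_unlabeled_loss(weak, strong)
print(loss_u)                          # equals log(2) / 2
```

Samples that fail the threshold contribute zero loss but still count in the \(1/M\) normalization, matching the definition of \(L_u\) above.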

As mentioned before, several related works (Zhang et al. 2021; Xu et al. 2021) replace the fixed constant \(\tau\) to improve the universality of pseudo-labeling algorithms. However, they rely on prior knowledge to design a task-specific function for adjusting this value, which limits their practical applicability. In the next section, we therefore devise a framework that enables sample-aware confidence thresholds without requiring pre-defined functions.

4 Proposed method

Overview. We construct a meta-net (threshold generation network, TGN) to dynamically produce sample-level thresholds. First, we rewrite the learning objective of threshold-based SSL methods. Second, we introduce the architecture of TGN. Then, we solve the resulting meta-optimization problem with a bi-level strategy that alternately trains the classifier and TGN. Finally, we analyse the convergence of our algorithm and present a green version of our method with lower training time.

4.1 Learning with sample-level thresholds

To alleviate the aforementioned issues of previous methods, we construct a meta-learning framework that generates a sample-level confidence threshold for every unlabeled example at each training step. Specifically, given a meta-net \(\mathcal {V}\) with parameters \(\Theta\), the confidence threshold of unlabeled example \(\textbf{x}_m\) is \(\tau _m \leftarrow \mathcal {V}_m(\textbf{w}, \Theta )\); the architecture and inputs of \(\mathcal {V}\) are detailed in Sect. 4.2. The fixed constant \(\tau\) in Eq. (1) can then be replaced with the sample-level threshold \(\tau _{m}\), and the loss of unlabeled example \(\textbf{x}_m\) becomes

$$\begin{aligned} \ell _{\textbf{x}_m}(\textbf{w}, \Theta ) = \mathbbm {1}(\max (p_m) > \mathcal {V}_m(\textbf{w}, \Theta )) \cdot H(\hat{\textbf{y}}_m, f(\mathcal {A}^s(\textbf{x}_m); \textbf{w})). \end{aligned}$$
(2)

However, the indicator function \(\mathbbm {1}(\cdot )\) is discontinuous at the threshold and has zero derivative everywhere else, so differentiating Eq. (2) with respect to \(\Theta\) yields no useful gradient. In practice, we therefore replace it with a modified sigmoid function \(\mathcal {S}(x) = \frac{1}{1+\exp (-\beta x)}\), whose input is \(\max (f(\mathcal {A}^w(\textbf{x}_m); \textbf{w})) - \mathcal {V}_m(\textbf{w}, \Theta )\) and where \(\beta\) is a slope parameter that controls the shape of the function.

Fig. 2

Comparison of the indicator function with the approximate function \(\mathcal {S}(\cdot )\) for varying \(\beta\)

Discussion of the approximate function \(\mathcal {S}(\cdot )\). In Fig. 2, we compare the indicator function \(\mathbbm {1}(\cdot )\) with the surrogate function \(\mathcal {S}(\cdot )\). Since the maximum predicted probability and the generated threshold both lie in [0, 1], the input of the function satisfies \(\max (f(\mathcal {A}^w(\textbf{x}_m); \textbf{w})) - \mathcal {V}_m(\textbf{w}, \Theta ) \in [-1, 1]\). Moreover, the sigmoid function has first- and second-order derivatives everywhere, which makes backpropagation of the training loss in Eq. (2) possible.
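The sketch below compares the indicator with the sigmoid surrogate \(\mathcal{S}(x) = 1/(1+\exp(-\beta x))\) on the input range [-1, 1] discussed above, and evaluates the derivative \(d\mathcal{S}/dx = \beta \mathcal{S}(1-\mathcal{S})\) that backpropagation uses; because the input is the maximum confidence minus the threshold, the gradient with respect to the threshold is the negative of this value. The numerically stable branch for negative inputs is an implementation choice of this sketch, and the \(\beta\) values are arbitrary.

```python
import math


def indicator(x):
    """Hard selection 1(x > 0), with x = max confidence - threshold."""
    return 1.0 if x > 0 else 0.0


def surrogate(x, beta):
    """S(x) = 1 / (1 + exp(-beta * x)), written to avoid overflow for large beta."""
    z = beta * x
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def surrogate_grad(x, beta):
    """dS/dx = beta * S(x) * (1 - S(x)); nonzero everywhere, unlike the indicator."""
    s = surrogate(x, beta)
    return beta * s * (1.0 - s)


# Confidence 0.90 against thresholds just below and just above it.
for beta in (10, 50, 200):
    for tau in (0.85, 0.95):
        x = 0.90 - tau
        print(f"beta={beta:>3} tau={tau:.2f} indicator={indicator(x):.0f} "
              f"S={surrogate(x, beta):.4f} dS/dx={surrogate_grad(x, beta):.4f}")
```

Larger \(\beta\) brings \(\mathcal{S}\) closer to the indicator but concentrates the gradient in a narrower band around the threshold, so \(\beta\) trades approximation accuracy against how many samples receive a useful gradient signal.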

Finally, the optimal classifier parameters \(\textbf{w}^*(\Theta )\) are obtained by minimizing the loss

$$\begin{aligned} \textbf{w}^*({\Theta }) = \mathop {\arg \min }\limits _{\textbf{w}} L_u = \frac{1}{M} \sum \nolimits _{\textbf{x}_m \in D^u} {\ell _{\textbf{x}_m}}({\textbf{w}, \Theta }). \end{aligned}$$
(3)

4.2 Threshold generation network TGN

In this subsection, we design a threshold generation network (TGN) that serves as the meta model. Following previous works (Zhang et al. 2021; Guo and Li 2022), we found that the average class confidence provides valuable information for generating thresholds and improves the applicability of methods to extreme data distributions. We therefore construct a meta-net that takes both the instance confidence and the average class confidence as input and outputs a sample-aware threshold for each unlabeled example.

Formally, given a weakly-augmented version of unlabeled example \(\textbf{x}_m\), the classifier network \(f_{\textbf{w}}\) produces the prediction (a soft label) \(g(p_m^t)\) at the t-th iteration, where \(g(\cdot )\) denotes the Softmax function. The pseudo-label is \({\hat{{{\textbf{y}}}}}_m^t = \arg \max (g(p_m^t))\). Meanwhile, the average class confidence is \(\overline{\textrm{p}}_c^t = \frac{1}{M}\sum \nolimits _{m=1}^{M} g(p_m^t | c=\hat{{{\textbf{y}}}}_m^t)\). Note that \(\overline{\textrm{p}}_c^t\) can be regarded as the average soft label of class c at iteration t. Therefore, for unlabeled example \(\textbf{x}_m\), the threshold generated at the t-th iteration is

$$\begin{aligned} \tau _m^t = \mathcal {V}(g(f(\textrm{x}_m; \textbf{w})), \, \overline{\textrm{p}}_c^t; \Theta ). \end{aligned}$$
(4)

As shown in Fig. 3b, TGN is a lightweight network (e.g., three fully-connected layers). For \(\textbf{x}_m\), we concatenate its prediction \(g(p_m^t)\) (a C-dimensional soft label) with the average class confidence \(\overline{\textrm{p}}_c^t\) (a C-dimensional vector), so the input layer of TGN is 2C-dimensional.
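A minimal sketch of the TGN input construction and forward pass, assuming a \(\{2C, h, 1\}\) MLP (the structure given in the implementation details) with a ReLU hidden layer and a sigmoid output so that each threshold lies in (0, 1); the activations and initialization are our assumptions. For the average class confidence, the sketch averages the softmax vectors of the samples currently pseudo-labelled as class c, normalizing by the number of such samples, and falls back to a uniform vector for classes with no samples; both choices are assumptions about how the quantity above is computed in practice. All names are hypothetical.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [v / total for v in exps]


def average_class_confidence(probs, num_classes):
    """p_bar[c]: mean soft label over samples whose pseudo-label is c.
    Empty classes fall back to a uniform vector (assumption)."""
    sums = [[0.0] * num_classes for _ in range(num_classes)]
    counts = [0] * num_classes
    for p in probs:
        c = p.index(max(p))                       # pseudo-label
        counts[c] += 1
        sums[c] = [s + v for s, v in zip(sums[c], p)]
    return [[s / counts[c] for s in sums[c]] if counts[c]
            else [1.0 / num_classes] * num_classes
            for c in range(num_classes)]


class TGN:
    """Hypothetical {2C, h, 1} MLP mapping [g(p_m), p_bar_c] to a threshold."""

    def __init__(self, num_classes, hidden=100, seed=0):
        rng = random.Random(seed)
        d_in = 2 * num_classes
        self.w1 = [[rng.gauss(0.0, d_in ** -0.5) for _ in range(d_in)]
                   for _ in range(hidden)]
        self.b1 = [0.0] * hidden
        self.w2 = [rng.gauss(0.0, hidden ** -0.5) for _ in range(hidden)]
        self.b2 = 0.0

    def __call__(self, p_m, p_bar_c):
        x = list(p_m) + list(p_bar_c)             # concatenation: 2C inputs
        h = [max(0.0, sum(w * v for w, v in zip(row, x)) + b)   # ReLU
             for row, b in zip(self.w1, self.b1)]
        z = sum(w * v for w, v in zip(self.w2, h)) + self.b2
        return 1.0 / (1.0 + math.exp(-z))         # threshold in (0, 1)


probs = [softmax(z) for z in ([2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.5, 0.2, 0.0])]
p_bar = average_class_confidence(probs, num_classes=3)
tgn = TGN(num_classes=3, hidden=8)
taus = [tgn(p, p_bar[p.index(max(p))]) for p in probs]   # one threshold per sample
print(taus)
```

Because each sample is paired with the average confidence of its own pseudo-class, two samples with the same pseudo-label but different confidence profiles can still receive different thresholds, which is what makes the thresholds sample-level rather than class-level.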

Fig. 3

a Flowchart of our learning algorithm. Solid and dashed lines represent forward and backward propagation, respectively. Each training iteration consists of six steps. Step 1: feed weakly-augmented images to the classifier network and obtain pseudo-labels with their prediction confidences. Step 2: input pairs of average class confidence and predicted confidence into the meta-net TGN. Step 3: use the generated sample-level thresholds \(\tau\) to select high-reliability data and compute the loss \(L_u\). Step 4: update the classifier parameters while retaining the computation graph for their gradient. Step 5: use the meta-data to compute the loss \(L_\textrm{meta}\) and update \(\Theta\). Step 6: recompute the gradient of \(L_u\) w.r.t. \(\textbf{w}\) and update \(\textbf{w}\). b Architecture of TGN. Given an unlabeled sample \(\textbf{x}_m\), TGN's input consists of two parts

4.3 Meta-optimization problem

There are two networks in our training framework: a classification network \(f_{\textbf{w}}\) and a meta-net \(\mathcal {V}_{\Theta }\). The parameters \(\textbf{w}\) and \(\Theta\) can be optimized following the meta-learning idea (Andrychowicz et al. 2016; Shu et al. 2019). Specifically, we require a small meta-data set, which can be sampled from the labeled data in an SSL task. Since previous works (Shu et al. 2019; Sun et al. 2022) showed that the generalization performance of the meta-model benefits substantially from a large meta-data set, we simply use all labeled data as the meta-data set, i.e., \(D^{\textrm{meta}} = D^l = \{(\textbf{x}^{l}_i, \textbf{y}^{l}_i)\}_{i=1}^N\). The optimal parameters \(\Theta ^*\) are obtained by minimizing the following loss

$$\begin{aligned} \Theta ^* = \mathop {\arg \min }\limits _{\Theta } L_{\textrm{meta}}(\textbf{w}^*({\Theta })) = \frac{1}{N} \sum \nolimits _{i=1}^N H_i(\textbf{w}^*({\Theta })). \end{aligned}$$
(5)

For clarity, we represent \(H_i(\textbf{w})\) as \(H(\textbf{y}_i^l, f(\textbf{x}_i^l; \textbf{w}))\).

Obtaining the optimal parameters \(\textbf{w}^*\) in Eq. (3) and \(\Theta ^*\) in Eq. (5) is a nested optimization problem. We therefore adopt a bi-level training strategy as in MAML (Finn et al. 2017) and update the parameters of the meta-net online. Specifically, the training losses of the classifier network and the meta-net (Eqs. (3) and (5)) are optimized with SGD. In each training iteration, given a mini-batch size n, we sample a batch of meta-data \(\{(\textbf{x}_1^l, \textbf{y}_1^l),...,(\textbf{x}_n^l, \textbf{y}_n^l)\}\) and a batch of unlabeled data \(\{\textbf{x}_1,...,\textbf{x}_{(\mu \times n)}\}\). Note that increasing \(\mu\) enlarges the number of unlabeled samples per iteration. At the t-th iteration, we denote the parameters of the classifier network by \(\textbf{w}^{(t)}\) and those of the meta-net by \(\Theta ^{(t)}\). The two networks are updated in the following three phases.

Algorithm 1

Learning algorithm of Meta-T.

  • Formulating the learning manner of the classifier network. With step size \(\alpha\), one descent step of the objective in Eq. (3) on a mini-batch of unlabeled data gives

    $$\begin{aligned} {\hat{{{\textbf{w}}}}}^{(t)}(\Theta ) = \textbf{w}^{(t)} - \alpha \frac{1}{n\mu } \sum \nolimits _{i=1}^{n\mu } \nabla _{\textbf{w}} \ell _{\textbf{x}_i}(\textbf{w}^{(t)}, \Theta ^{(t)}), \end{aligned}$$
    (6)

    where \(\ell _{\textbf{x}_i}\) is calculated by Eq. (2).

  • Updating parameters \(\Theta\). Given \({\hat{{{\textbf{w}}}}}^{(t)}(\Theta )\) from Eq. (6), which depends on \(\Theta\), the meta-net TGN is updated on a mini-batch of meta-data \(\{(\textbf{x}_1^l, \textbf{y}_1^l),...,(\textbf{x}_n^l, \textbf{y}_n^l)\}\). Specifically, \(\Theta ^{(t)}\) moves along the negative gradient of the objective in Eq. (5):

    $$\begin{aligned} \Theta ^{(t+1)} = \Theta ^{(t)} - \psi \frac{1}{n} \sum \nolimits _{i=1}^{n} \nabla _{\Theta } H_i({\hat{{{\textbf{w}}}}}^{(t)}(\Theta )), \end{aligned}$$
    (7)

    where \(\psi\) denotes the step size of the SGD optimizer. Note that \(\Theta\) is treated as a variable in this equation, which enables computation of the gradient \(\frac{\partial {\hat{{{\textbf{w}}}}}^{(t)}(\Theta )}{\partial \Theta }\).

  • Updating parameters \(\textbf{w}\) of the classifier network. Finally, we use the updated TGN \(\Theta ^{(t+1)}\) to regenerate confidence thresholds for the unlabeled data and update the parameters \(\textbf{w}\) of the classifier network:

    $$\begin{aligned} {\textbf{w}}^{(t+1)} = \textbf{w}^{(t)} - \alpha \frac{1}{n\mu } \sum \nolimits _{i=1}^{n\mu } \nabla _{\textbf{w}} \ell _{\textbf{x}_i}(\textbf{w}^{(t)}, \Theta ^{(t+1)}). \end{aligned}$$
    (8)

We illustrate the flowchart of our learning algorithm in Fig. 3a, where Steps 4, 5, and 6 correspond to Eqs. (6), (7), and (8), respectively, and we summarize the overall update steps in Algorithm 1. Compared to current SSL methods, Meta-T does not rely on any prior knowledge to predefine a function for adjusting the threshold. We expect this property to broaden the applicability of our method in settings where the data distribution cannot be modeled.
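The three phases can be traced on a deliberately tiny problem. The sketch below assumes a 1-D logistic classifier \(p = \sigma(w x)\), a two-parameter stand-in for TGN (\(\tau_m = \sigma(\theta_0 + \theta_1 \cdot \textrm{conf}_m)\)) in place of the MLP, the sigmoid surrogate for the indicator, and a labeled meta-set scored with binary cross-entropy. It differs from the described method in one important way: the hypergradient of Eq. (7) is estimated by central finite differences rather than by differentiating through Eq. (6) with automatic differentiation (e.g., via the Higher library mentioned in the implementation details). Pseudo-labels and confidences are held constant w.r.t. \(\textbf{w}\), as is usual in pseudo-labeling. All names are hypothetical.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def threshold(theta, conf):
    """Hypothetical stand-in for the meta-net V: tau = sigmoid(theta0 + theta1 * conf)."""
    return sigmoid(theta[0] + theta[1] * conf)


def unlabeled_grad(w, theta, xs, beta):
    """d/dw of (1/n) * sum_m S(conf_m - tau_m) * CE(y_hat_m, p_m) for p = sigmoid(w * x).
    The pseudo-label and confidence are held constant w.r.t. w (stop-gradient),
    while the surrogate mask still depends on theta through tau."""
    g = 0.0
    for x in xs:
        p = sigmoid(w * x)
        y_hat = 1.0 if p >= 0.5 else 0.0
        conf = max(p, 1.0 - p)
        mask = sigmoid(beta * (conf - threshold(theta, conf)))   # surrogate S
        g += mask * (p - y_hat) * x                              # dCE/dw for logistic p
    return g / len(xs)


def meta_loss(w, labeled):
    """Mean binary cross-entropy on the labeled meta-data, as in Eq. (5)."""
    total = 0.0
    for x, y in labeled:
        p = min(max(sigmoid(w * x), 1e-12), 1.0 - 1e-12)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(labeled)


def meta_t_step(w, theta, xs, labeled, alpha=0.1, psi=0.1, beta=50.0, eps=1e-4):
    def meta_objective(th):
        # Eq. (6): virtual classifier step, viewed as a function of theta
        w_hat = w - alpha * unlabeled_grad(w, th, xs, beta)
        return meta_loss(w_hat, labeled)

    # Eq. (7): hypergradient by central finite differences (autograd in practice)
    grad_theta = []
    for i in range(len(theta)):
        up, down = list(theta), list(theta)
        up[i] += eps
        down[i] -= eps
        grad_theta.append((meta_objective(up) - meta_objective(down)) / (2 * eps))
    new_theta = [t - psi * g for t, g in zip(theta, grad_theta)]

    # Eq. (8): actual classifier step from w^(t) with regenerated thresholds
    new_w = w - alpha * unlabeled_grad(w, new_theta, xs, beta)
    return new_w, new_theta


xs = [-2.0, -0.5, 0.3, 1.5]          # unlabeled inputs (single view, no augmentation)
labeled = [(-1.0, 0), (1.0, 1)]      # meta-data pairs (x, y)
w, theta = 0.5, [0.0, 0.0]
for _ in range(20):
    w, theta = meta_t_step(w, theta, xs, labeled)
print(w, theta)
```

Finite differences need two extra virtual steps per meta-parameter, which is only viable because this toy meta-net has two parameters; differentiating through Eq. (6) with automatic differentiation is what makes the update scale to an MLP.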

4.4 Convergence analysis

We analyze the convergence of Meta-T and give a rigorous theoretical guarantee.

Lemma 1

(Smoothness). Suppose the loss function H is L-Lipschitz and smooth, the approximate function \(\mathcal {S}\) is \(\zeta\)-Lipschitz, \(\mathcal {V}(\cdot )\) is differentiable with \(\delta\)-bounded gradient and twice differentiable with \(\mathcal {B}\)-bounded Hessian, and the loss function H has \(\rho\)-bounded gradients w.r.t. training/meta data and is upper-bounded by \(\phi\). Then, with the indicator function replaced by \(\mathcal {S}\), the gradient of the meta loss w.r.t. \(\Theta\) is Lipschitz continuous.

The proof is given in Appendix A.1. Lemma 1 implies that the meta loss is smooth w.r.t. the parameters of the meta-network.

Theorem 1

(Convergence) Based on Lemma 1, let the learning rate \(\alpha _t\) satisfy \(\alpha _t = \min \{1, \frac{k}{T}\}\) for some \(k > 0\) such that \(\frac{k}{T} < 1\), and let \(\psi _t\), \(1 \le t \le T\), be a monotonically decreasing sequence with \(\psi _t = \min \{\frac{1}{L}, \frac{\mathrm{{c}}}{\sigma \sqrt{T}}\}\) for some \({\textrm{c}}>0\), such that \(\frac{\sigma \sqrt{T}}{\textrm{c}} \ge {L}\) and \(\sum \nolimits _{t=1}^{\infty } \psi _t \le \infty , \sum \nolimits _{t=1}^{\infty } \psi _t^2 \le \infty\). Then we have \({\frac{1}{T} \sum \nolimits _{t=1}^T \mathbb {E} \Big [ \left\| \nabla L_{\textrm{meta}} \Big ( \hat{{{\textbf{w}}}}^{(t)}(\Theta ^{(t)}) \Big ) \right\| _2^2 \Big ] \le \mathcal {O}(\frac{1}{\sqrt{T}})}\).

The proof is given in Appendix A.2. Theorem 1 means that our algorithm achieves \({\mathbb {E} \Big [ \left\| \nabla L_\textrm{meta} \Big ( \hat{{{\textbf{w}}}}^{(t)}(\Theta ^{(t)}) \Big ) \right\| _2^2 \Big ] \le \epsilon }\) within \(\mathcal {O}(1/\epsilon ^2)\) steps and converges to a stationary point as the number of training iterations increases.
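The step count follows from the rate in Theorem 1 by a short calculation. Writing \(C\) for the constant hidden in \(\mathcal{O}(1/\sqrt{T})\) (our notation), the best iterate is no worse than the average, and the average falls below \(\epsilon\) once \(T\) is large enough:

```latex
\min_{1 \le t \le T} \mathbb{E}\Big[\big\|\nabla L_{\textrm{meta}}\big(\hat{\textbf{w}}^{(t)}(\Theta^{(t)})\big)\big\|_2^2\Big]
\;\le\; \frac{1}{T}\sum_{t=1}^{T} \mathbb{E}\Big[\big\|\nabla L_{\textrm{meta}}\big(\hat{\textbf{w}}^{(t)}(\Theta^{(t)})\big)\big\|_2^2\Big]
\;\le\; \frac{C}{\sqrt{T}} \;\le\; \epsilon
\quad \text{whenever} \quad T \;\ge\; \frac{C^2}{\epsilon^2} = \mathcal{O}\!\left(\frac{1}{\epsilon^2}\right).
```

The guarantee therefore applies to the best iterate (equivalently, in expectation, to an iterate drawn uniformly from the first \(T\)), not necessarily to the last one.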

4.5 Green meta-T: training with lower complexity

Training complexity analysis. Compared with a single-step training procedure, each training iteration of Meta-T consists of three parts: (1) forward and backward passes of the classifier network to compute \(\hat{{{\textbf{w}}}}(\Theta )\); (2) forward and backward passes of TGN to update \(\Theta\); (3) forward and backward passes of the classifier network to update \(\textbf{w}\). Hence, compared with FixMatch, which involves only one forward and backward pass per iteration, Meta-T requires roughly three times the training time.

As summarized by Xu et al. (2021), most of the training time is spent on backpropagation when updating the meta-net parameters \(\Theta\), since the meta-gradient in Eq. (7) requires computing the similarity between each meta-data sample and the unlabeled data. Therefore, computing \(\hat{{{\textbf{w}}}}(\Theta )\) less often would significantly decrease training time. To this end, we integrate the proposed Meta-T algorithm with curriculum learning and name the resulting procedure Green Meta-T. Specifically, we perform the bi-level strategy (i.e., Meta-T) once to learn the classifier network and TGN, and then continuously perform \(\mathrm k\)-step classifier learning. The training complexity of Green Meta-T is given below.

Proposition 1

Suppose a fixed number of training iterations T, and let the training times of FixMatch and Meta-T be \(\mathcal {T}\) and \(3\mathcal {T}\), respectively. Given a hyper-parameter \(\mathrm k\), the training time of Green Meta-T is \(\frac{\textrm{k}+2}{\textrm{k}} \mathcal {T}\).

Proposition 1 implies that the training time of Green Meta-T approaches \(\mathcal {T}\) as \({\textrm{k}}\) increases.
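The factor in Proposition 1 can be reproduced by simple counting if the schedule is read as one Meta-T iteration (about three passes, i.e., cost 3 relative to FixMatch) in every block of k iterations, with plain classifier updates (cost 1) for the remaining k - 1 iterations. This reading is our assumption, chosen because it matches \((\textrm{k}+2)/\textrm{k}\) and gives 3 (plain Meta-T) at k = 1. The sketch builds such a schedule and checks the average cost exactly with `fractions.Fraction`; the function names are hypothetical.

```python
from fractions import Fraction


def green_schedule(total_iters, k):
    """Mark iteration t as a bi-level (Meta-T) step when t % k == 0, else a plain step."""
    if k < 1:
        raise ValueError("k must be a positive integer")
    return ["meta" if t % k == 0 else "plain" for t in range(total_iters)]


def relative_cost(schedule, meta_cost=3, plain_cost=1):
    """Average per-iteration cost, in units of one FixMatch iteration."""
    total = sum(meta_cost if step == "meta" else plain_cost for step in schedule)
    return Fraction(total, len(schedule))


for k in (1, 2, 4, 10):
    cost = relative_cost(green_schedule(total_iters=20 * k, k=k))
    print(k, cost, cost == Fraction(k + 2, k))
```

As k grows, the average cost approaches 1, i.e., the cost of FixMatch, in line with Proposition 1.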

5 Experiments

Table 1 Details about five tested benchmarks

5.1 Experimental settings

Datasets. We select five image classification datasets and three text classification datasets to evaluate the effectiveness of Meta-T: the image benchmarks CIFAR-10 (Krizhevsky and Hinton 2009), CIFAR-100 (Krizhevsky and Hinton 2009), SVHN (Netzer et al. 2011), STL-10 (Coates et al. 2011), and ImageNet (Deng et al. 2009), and the text benchmarks IMDb (Maas et al. 2011), Amazon-5 (Zhang et al. 2015) and Yelp-5 (Zhang et al. 2015). Detailed statistics of these datasets are shown in Table 1.

Implementation Details. Our code is implemented in PyTorch 1.9.0 and run on an NVIDIA GeForce RTX 3090. We implement our algorithm with Higher (Grefenstette et al. 2019), a PyTorch library that supports higher-order optimization. Each experiment is repeated five times with different random seeds. Further settings for the two networks are as follows:

  • For the classifier, more information about data preprocessing and training procedure can be found in Table 2.

  • For TGN, we set the size of the meta-data to 32 and use the Adam optimizer with a learning rate of 1e-3 for all training epochs. TGN is a three-layer fully-connected MLP with structure \(\{2 \mathcal {C}, h, 1\}\), where h is set to 100 for all image datasets and 1000 for all text datasets, and \(\mathcal {C}\) is the number of categories.

Table 2 Detailed settings about training procedure of the backbone (the classifier network)
Table 3 Error rates (%) for previous SOTA methods on CIFAR-10 and CIFAR-100 with varying size of labeled set

5.2 Results on typical SSL task

Baselines. We categorize the compared methods into two types. 1) Threshold-based methods, including Pseudo-Labeling (PL) (Lee 2013), FixMatch (Sohn et al. 2020), FlexMatch (Zhang et al. 2021) and Dash (Xu et al. 2021). 2) Others, including \(\Pi\)-Model (Sajjadi et al. 2016), MixMatch (Berthelot et al. 2019), UDA (Xie et al. 2020), CoMatch (Li et al. 2021) and SimMatch (Zheng et al. 2022).

Results on five image datasets. We conduct experiments on CIFAR-10, CIFAR-100, SVHN, STL-10 and ImageNet; the results are shown in Tables 3 and 4. On CIFAR-10 & 100, Meta-T outperforms previous methods in most settings, and its advantage is most pronounced when the labeled set is extremely small. For example, we achieve a 1.64% top-1 accuracy improvement on CIFAR-100 with only 4 samples per class. Compared with threshold-based methods (Lee 2013; Sohn et al. 2020; Zhang et al. 2021; Xu et al. 2021), the improvement of our method is significant: Meta-T consistently outperforms them in all settings. Finally, our method also achieves SOTA performance on ImageNet. Using only 1% labeled data, Meta-T attains 67.7% top-1 accuracy on the test set, a 0.5% improvement over the previous state-of-the-art method SimMatch, which is notable at ImageNet scale. This result suggests that Meta-T is effective on real-world SSL tasks.

Results on three text datasets. For a fair comparison, we follow the same training procedure as SoftMatch. On two text benchmarks, IMDb and Yelp-5, our method consistently achieves the best top-1 accuracy. On Yelp-5 in particular, Meta-T outperforms the second-best method, FlexMatch, by 0.57%, a substantial margin on a dataset of this scale.

Table 4 (Left) Error rates (%) for previous methods on SVHN and STL-10 with varying size of labeled set
Table 5 Top-1 test accuracy (%) on imbalanced CIFAR-10 under three imbalance ratios and two sizes of the labeled set. The backbone is Wide ResNet-28-2

5.3 Results on imbalanced SSL task

We categorize the compared methods into two groups. 1) Threshold-based methods: FixMatch (Sohn et al. 2020), Dash (Xu et al. 2021) and FlexMatch (Zhang et al. 2021). 2) Others: cRT (Kang et al. 2019), LDAM, MixMatch (Berthelot et al. 2019), ReMixMatch (Berthelot et al. 2019), DARP (Kim et al. 2020), CReST (Wei et al. 2021) and Adsh (Guo and Li 2022). We construct the imbalanced datasets following Guo and Li (2022). Specifically, we write the sizes of the labeled and unlabeled training sets as \(N = \sum \nolimits _{c=1}^C N_c\) and \(M = \sum \nolimits _{c=1}^C M_c\). To construct imbalanced datasets, two imbalance ratios are introduced, i.e., \(\gamma _l = \frac{N_1}{N_C}\) and \(\gamma _u = \frac{M_1}{M_C}\). Once \(\gamma _l,\gamma _u\) and \(N_1, M_1\) are given, we set \(N_c = N_1 \cdot \gamma _l^{-\frac{c-1}{C-1}}\) and \(M_c = M_1 \cdot \gamma _u^{-\frac{c-1}{C-1}}\) for \(1 < c \le C\). We conduct experiments in two settings, i.e., \(N_1=500, M_1=4000\) and \(N_1=1500, M_1=3000\), with imbalance ratios \(\gamma _l, \gamma _u \in \{50, 100, 150\}\).
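The per-class sizes follow directly from the exponential profile above. A minimal sketch, assuming sizes are rounded to the nearest integer (the rounding rule is our assumption) and using C = 10 as for CIFAR-10; the function name is hypothetical:

```python
def long_tailed_counts(n_max, gamma, num_classes):
    """Per-class sizes n_c = n_max * gamma^(-(c-1)/(C-1)) for c = 1..C."""
    if num_classes < 2:
        raise ValueError("need at least two classes")
    return [round(n_max * gamma ** (-(c - 1) / (num_classes - 1)))
            for c in range(1, num_classes + 1)]


labeled = long_tailed_counts(1500, gamma=100, num_classes=10)    # N_1 = 1500, gamma_l = 100
unlabeled = long_tailed_counts(3000, gamma=100, num_classes=10)  # M_1 = 3000, gamma_u = 100
print(labeled, sum(labeled))
print(unlabeled, sum(unlabeled))
```

The first class keeps \(N_1\) (or \(M_1\)) samples and the last keeps \(N_1/\gamma_l\) (or \(M_1/\gamma_u\)), so the ratio between the largest and smallest class equals the imbalance ratio.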

Table 5 reports the comparison results in the setting \(\gamma = \gamma _l = \gamma _u\). We observe that (1) Meta-T achieves state-of-the-art performance in most cases, showing its robustness to data imbalance; (2) as the imbalance ratio increases, the advantage of our algorithm becomes more pronounced. Compared with the second-best method (i.e., Adsh), we achieve top-1 accuracy improvements of 1.43% under \(\gamma =100\) and 2.42% under \(\gamma =150\). Meta-T performs slightly worse than Adsh in the case \(N_1=500, M_1=4000, \gamma =50\).

5.4 Effectiveness analysis

Fig. 4

Numbers of correct and wrong pseudo-labels in the selected set over training epochs. Setting 1: identical training time (FixMatch/FlexMatch: 1000 epochs, Ours: 300 epochs); Setting 2: the same number of training epochs as FixMatch

Fig. 5

Confusion matrices on unlabeled data for Meta-T, FixMatch and FlexMatch on CIFAR-10 with \(\gamma = \gamma _l = \gamma _u = 100, N_1=1500, M_1=3000\)

Pseudo-labels. We verify the quality of produced pseudo-labels on both typical and imbalanced SSL settings.

  • Typical SSL. In Fig. 4a, b (left), Meta-T generates more correct pseudo-labels, which benefits from the higher-quality thresholds produced by TGN. In the early learning stage, the number of correct labels in our method is remarkably higher than that in FixMatch, reflecting the advantage of sample-level thresholds. Fig. 4a, b (right) shows the number of wrong labels. Because TGN is still poorly trained in the early learning stage, it produces some low-quality thresholds, leading to more wrong pseudo-labels than deterministic methods such as FlexMatch. Fortunately, the number of wrong labels decreases as learning proceeds and eventually falls below that of FixMatch.

  • Imbalanced SSL. We compare the confusion matrices on unlabeled data and show the results in Fig. 5. Since the average class confidence is fed into TGN, we believe that TGN can learn the classifier's confidence on different categories under imbalanced settings and thus adaptively generate class-balanced confidence thresholds. Experimentally, FixMatch is biased toward the majority categories and thus produces unreliable pseudo-labels for minority classes. In contrast, Meta-T performs well on tail classes and attains more than 80% accuracy on all classes.
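The two quantities discussed above, the numbers of correct and wrong pseudo-labels in the selected set (Fig. 4) and the per-class accuracy read off a confusion matrix (Fig. 5), can be computed as in the following Python sketch. We assume, as in FixMatch-style methods, that an unlabeled sample enters the selected set when its maximum predicted probability reaches its own threshold and that its pseudo-label is the arg-max class; the function names are hypothetical and the toy probabilities are made up for illustration.

```python
from typing import List, Sequence, Tuple


def argmax(xs: Sequence[float]) -> int:
    return max(range(len(xs)), key=lambda j: xs[j])


def pseudo_label_stats(
    probs: Sequence[Sequence[float]],
    thresholds: Sequence[float],
    true_labels: Sequence[int],
) -> Tuple[int, int]:
    """Count (correct, wrong) pseudo-labels among the selected samples.

    A sample is selected when its max confidence reaches its own threshold
    (assumed FixMatch-style rule); its pseudo-label is the arg-max class.
    """
    correct = wrong = 0
    for p, tau, y in zip(probs, thresholds, true_labels):
        if max(p) >= tau:
            if argmax(p) == y:
                correct += 1
            else:
                wrong += 1
    return correct, wrong


def per_class_accuracy(pred: Sequence[int], true: Sequence[int], num_classes: int) -> List[float]:
    """Diagonal of the row-normalized confusion matrix (rows = true classes)."""
    conf = [[0] * num_classes for _ in range(num_classes)]
    for p, y in zip(pred, true):
        conf[y][p] += 1
    return [
        conf[c][c] / sum(conf[c]) if sum(conf[c]) > 0 else 0.0
        for c in range(num_classes)
    ]


# Toy 2-class example (made-up numbers)
probs = [[0.97, 0.03], [0.60, 0.40], [0.30, 0.70], [0.96, 0.04]]
true = [0, 0, 1, 1]
fixed = pseudo_label_stats(probs, [0.95] * 4, true)  # one fixed threshold
sample_level = pseudo_label_stats(probs, [0.95, 0.55, 0.65, 0.95], true)  # per-sample thresholds
acc = per_class_accuracy([argmax(p) for p in probs], true, num_classes=2)
print(fixed, sample_level, acc)  # prints: (1, 1) (3, 1) [1.0, 0.5]
```

In this toy case, lowering the thresholds of the two less confident samples admits two additional correct pseudo-labels, while the confidently wrong sample is selected under both rules.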

Sample-level thresholds. We show the learned thresholds from three aspects to demonstrate the effectiveness of Meta-T.

  • Accuracy. Figure 6a shows the learned confidence thresholds on CIFAR-10 and CIFAR-100. We observe that (1) most learned sample-level thresholds lie in the interval [0.9, 1.0], which agrees with the prior knowledge that the confidence threshold should be set to 0.95 for CIFAR; this verifies that TGN can learn competitive sample-level thresholds; (2) our algorithm also learns some thresholds below 0.95, and the corresponding samples can be regarded as hard (or boundary) samples. It is reasonable for TGN to assign them relatively low thresholds, which helps the model learn from these samples.

  • Robustness. Figure 6b visualizes the produced thresholds and test accuracy (%) under long-tailed semi-supervised learning. Meta-T learns lower thresholds for tail classes while keeping high thresholds for many-shot classes. Because tail classes contain few samples, the classifier has only moderate or low confidence on them. Meta-T therefore produces relatively small thresholds (around 0.5) for these samples, enabling the classifier to learn from more unlabeled tail-class samples.

  • Stability. Figure 6c compares how the thresholds evolve during training. At the beginning, Meta-T tends to initialize the thresholds of all unlabeled data at 0.5, and the thresholds then quickly rise to 0.95, which matches the setting in FixMatch. This result indicates that the thresholds learned by Meta-T are close to the optimal thresholds.

Fig. 6

Results about learned confidence thresholds from three aspects. (a) Visualization of generated sample-level thresholds \(\tau\) for all unlabeled data on balanced CIFAR-10 (250 labels) and CIFAR-100 (2500 labels). (b) Visualization of generated thresholds under imbalanced SSL (CIFAR-10 with \(N_1=1500, M_1=3000, \gamma _l=\gamma _u=100\)). (c) Class-average confidence threshold vs. training progress, comparing Meta-T with other methods on balanced STL-10 with 40 labels

5.5 Sensitivity analysis

We conduct experiments to analyse the sensitivity of Meta-T in three aspects.

The architecture of TGN. To examine the impact of the TGN architecture, we try MLP settings with different depths and widths and show the results in Table 6 (left). The five MLP settings have little effect on the final result. We therefore adopt the simple yet effective one, i.e., \(\{2\mathcal {C}, 100, 1\}\), for all datasets. Moreover, we expect that TGN can perform well even with small-scale meta-data because of its small number of parameters.
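To make the \(\{2\mathcal {C}, 100, 1\}\) setting concrete, the Python sketch below builds an MLP with that shape and counts its parameters. Several details are assumptions not stated in this section: we take the \(2\mathcal {C}\)-dimensional input to be a sample's predicted class probabilities concatenated with the average class confidence (described above as an input to TGN), we use a ReLU hidden layer, and we squash the output with a sigmoid so that the threshold lies in (0, 1). The class name `TinyTGN` and the helper `average_class_confidence` are hypothetical.

```python
import math
import random
from typing import List, Sequence


def average_class_confidence(batch_probs: Sequence[Sequence[float]]) -> List[float]:
    """Hypothetical definition: mean predicted probability of each class over a batch."""
    n = len(batch_probs)
    return [sum(col) / n for col in zip(*batch_probs)]


class TinyTGN:
    """Sketch of a {2C, 100, 1} MLP mapping sample statistics to a threshold."""

    def __init__(self, num_classes: int, hidden: int = 100, seed: int = 0) -> None:
        rng = random.Random(seed)
        d_in = 2 * num_classes
        s1, s2 = 1.0 / math.sqrt(d_in), 1.0 / math.sqrt(hidden)
        self.w1 = [[rng.uniform(-s1, s1) for _ in range(d_in)] for _ in range(hidden)]
        self.b1 = [0.0] * hidden
        self.w2 = [rng.uniform(-s2, s2) for _ in range(hidden)]
        self.b2 = 0.0

    def num_params(self) -> int:
        return sum(len(row) for row in self.w1) + len(self.b1) + len(self.w2) + 1

    def __call__(self, probs: Sequence[float], class_conf: Sequence[float]) -> float:
        x = list(probs) + list(class_conf)  # 2C-dimensional input (assumed layout)
        h = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)  # ReLU hidden layer
             for row, b in zip(self.w1, self.b1)]
        z = sum(w * hi for w, hi in zip(self.w2, h)) + self.b2
        return 1.0 / (1.0 + math.exp(-z))  # sigmoid keeps tau in (0, 1)


batch = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]
tgn = TinyTGN(num_classes=3)
class_conf = average_class_confidence(batch)
tau = tgn(batch[0], class_conf)
print(tgn.num_params(), round(tau, 3))
```

Under these assumptions, the shape used for CIFAR-10 (\(\mathcal {C} = 10\)) has 2,201 parameters, which illustrates why the meta-net is lightweight.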

The learning rate \(\psi\) of the meta-net. Compared with existing methods, our framework introduces a single extra hyperparameter (i.e., the learning rate \(\psi\) of the meta-net), which does not require a complex cross-validation process. We conduct ablation studies with different optimization settings for TGN and show the results in Table 6 (right). The results indicate that our algorithm is insensitive to the hyperparameter \(\psi\). We therefore select a standard setting, i.e., the Adam optimizer with a learning rate of 1e-3.

The slope parameter \(\beta\). We conduct experiments with \(\beta \in \{1, 10, 50, 100, 1000\}\). As shown in Figure 7a, b, the generalization performance first improves as \(\beta\) increases; once \(\beta\) exceeds 100, further improvements are marginal. We thus set \(\beta =100\) for all experiments.
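The role of \(\beta\) can be illustrated with a common smooth surrogate for the indicator \(\mathbb {1}[p \ge \tau ]\), namely the sigmoid \(\sigma (\beta (p - \tau ))\). This section does not spell out the exact form of Meta-T's surrogate, so the sigmoid form and the function names below are assumptions for illustration. The Python sketch shows that a larger \(\beta\) brings the surrogate closer to the hard indicator, while its gradient with respect to \(\tau\), which the indicator lacks almost everywhere, becomes concentrated near \(p = \tau\).

```python
import math


def indicator(p: float, tau: float) -> float:
    """Hard selection rule of typical pseudo-labeling: 1 if p >= tau else 0."""
    return 1.0 if p >= tau else 0.0


def surrogate(p: float, tau: float, beta: float) -> float:
    """Assumed sigmoid surrogate sigma(beta * (p - tau)), computed stably."""
    z = beta * (p - tau)
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def surrogate_grad_tau(p: float, tau: float, beta: float) -> float:
    """d/d tau of sigma(beta * (p - tau)) = -beta * s * (1 - s)."""
    s = surrogate(p, tau, beta)
    return -beta * s * (1.0 - s)


# A sample with confidence 0.97 against a threshold of 0.95
for beta in (1, 10, 50, 100, 1000):
    print(beta, round(surrogate(0.97, 0.95, beta), 4), indicator(0.97, 0.95))
```

With a small \(\beta\) the surrogate stays near 0.5 and barely separates selected from rejected samples; with a very large \(\beta\) it is almost a step function and its gradient vanishes away from the threshold.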

Table 6 Ablation studies of different settings of the meta-net TGN
Fig. 7

Sensitivity analysis of the slope parameter \(\beta\) in the surrogate function (a,b) and the step number \(\mathrm k\) of Green Meta-T (c, d)

5.6 Efficiency analysis

The step number \(\mathrm k\) of Green Meta-T. We conduct ablation studies on two SSL settings with \(\mathrm k \in \{1,2,\ldots,10\}\). In Fig. 7c, d, we observe that (1) as \(\mathrm k\) increases, the error rate of Green Meta-T gradually increases relative to Meta-T. This is reasonable, since TGN receives less training when more rounds of classifier learning are conducted in the outer loop of curriculum learning; (2) a relatively large \(\mathrm k\) does not necessarily degrade the performance of Green Meta-T under a mild SSL setting.

To demonstrate the efficiency of Green Meta-T, we plot learning curves whose abscissa is the number of accumulated floating point operations (FLOPs), counting both forward and backward propagation. Figure 8 shows the train loss, train accuracy, test loss and test accuracy of the two learning algorithms against the accumulated FLOPs. Since both algorithms are trained for the same number of epochs, the learning process of Green Meta-T ends after approximately 240k FLOPs. Green Meta-T converges faster than Meta-T for the same accumulated FLOPs and reduces the computation cost of second-order derivatives in the meta-learning phase.
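The trade-off studied above can be pictured with a toy cost model. In the Python sketch below we assume, for illustration only, that Green Meta-T performs the costly meta-update of TGN (with its second-order derivatives) once every \(\mathrm k\) classifier steps, so that \(\mathrm k = 1\) corresponds to updating TGN at every step as in Meta-T. The unit costs are hypothetical placeholders, not measured FLOPs, and the function names are ours.

```python
from typing import List


def meta_update_steps(total_steps: int, k: int) -> List[bool]:
    """True at the classifier steps where TGN is also meta-updated (assumed every k steps)."""
    if k < 1:
        raise ValueError("k must be >= 1")
    return [step % k == 0 for step in range(total_steps)]


def accumulated_cost(total_steps: int, k: int, cls_cost: float, meta_cost: float) -> List[float]:
    """Running total of a hypothetical per-step cost: classifier step + optional meta-update."""
    total, curve = 0.0, []
    for is_meta in meta_update_steps(total_steps, k):
        total += cls_cost + (meta_cost if is_meta else 0.0)
        curve.append(total)
    return curve


# Hypothetical unit costs: a meta-update is assumed to cost twice a classifier step.
for k in (1, 2, 5, 10):
    curve = accumulated_cost(total_steps=100, k=k, cls_cost=1.0, meta_cost=2.0)
    print(k, sum(meta_update_steps(100, k)), curve[-1])
```

For the same number of steps, a larger \(\mathrm k\) ends at a smaller accumulated cost (300, 200, 140 and 120 units for \(\mathrm k\) = 1, 2, 5 and 10 here), which mirrors, under these assumptions, why the Green Meta-T curves end earlier on the FLOPs axis; the price is fewer TGN updates.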

Fig. 8

Results w.r.t. accumulative FLOPs on CIFAR-100 with 1000 labels

6 Conclusion

In this paper, we study sample-level thresholds for pseudo-labeling methods in semi-supervised learning and propose Meta-T, a simple yet effective framework. Compared with previous methods, Meta-T contains only one hyperparameter and does not rely on preset adjustment functions. A lightweight meta-net automatically generates the sample-aware thresholds, and the classifier network and the meta-net are updated via a bi-level optimization strategy. We also design a surrogate function to replace the indicator function used in typical pseudo-labeling methods. Further, we theoretically analyze the convergence of Meta-T and provide a variant with reduced training complexity, called Green Meta-T. Extensive experiments on typical and imbalanced SSL demonstrate its effectiveness.