
1 Introduction

Deep neural networks (DNNs) have been widely used for machine learning applications. Despite their success, it has been shown that training DNNs requires large-scale labeled and unbiased data. However, in many real-world applications, training set biases are prevalent [9, 21, 27, 28] and typically come in two forms: i) class-imbalanced data distributions; and ii) noisy labels. For example, in autonomous driving, the vast majority of the training data consists of standard vehicles, yet models must also recognize rarely seen classes such as emergency vehicles or animals with very high accuracy. This can lead to biased models that do not perform well in practice. Moreover, large-scale high-quality annotations are expensive and time-consuming to obtain. Although coarse labels are cheap and readily available, the presence of noise hurts model performance. Therefore, it is desirable to develop machine learning algorithms that can accommodate not only class-imbalanced training sets, but also the presence of label noise.

Both learning with noisy labels and class-imbalanced learning (a.k.a. long-tailed learning) have been studied for many years. When dealing with label noise, the most popular approach is sample selection, where correctly-labeled examples are identified by capturing the training dynamics of DNNs [11, 29]. When dealing with class imbalance, many existing works reweight examples or design unbiased loss functions by taking into account the class distribution of the training set [3, 8, 26]. However, most existing methods address only one of these two training set biases.

Fig. 1. Illustration of a normal classifier and Prototypical Classifier.

In this paper, we address both training set biases simultaneously. As shown in Fig. 1a, it is known that a classifier learned directly on class-imbalanced data is biased towards head classes [8, 32], which results in poor generalization on tail classes. Moreover, using the sample loss/confidence produced by biased classifiers fails to detect label noise, because both clean and noisy samples of tail classes have large loss and low confidence. To solve this problem, we propose to use Prototypical Classifier, which is demonstrated to produce balanced predictions even though the training set is class-imbalanced. Our basic idea is that there exists an embedding in which examples cluster around a single prototype representation for each class. To this end, we learn a non-linear mapping of the input into an embedding space using a neural network and take a class's prototype to be the normalized mean vector of its examples in the embedding space. Classification is then performed for an embedded test example by simply finding the nearest class prototype. Notably, Prototypical Classifier needs no additional learnable parameters given the embeddings of examples. Unfortunately, simply using prototypes for classification may lead to many wrong predictions for samples of head classes, as shown in Fig. 1b. The reason is that the representations need to be adjusted as the classification boundaries of tail classes expand. We therefore train the neural network to pull the embeddings of examples towards the prototype of their class, while pushing them apart from the prototypes of other classes. Doing so avoids many misclassifications for samples of head classes, as shown in Fig. 1c. Subsequently, we find that the confidence scores produced by Prototypical Classifier are balanced and comparable across classes. Leveraging this property, we can detect noisy labels by simple thresholding, where the threshold is dynamically adjusted, followed by a sample re-weighting strategy.

In summary, the key contributions of this work are:

  • We propose to learn from training sets with mixed biases, which is practical but has been understudied;

  • Our approach, Prototypical Classifier, is simple yet powerful. It produces more balanced predictions over all classes than normal classifiers, even when the training set is class-imbalanced. This property further benefits the detection of label noise.

  • On both simulated datasets and the real-world dataset WebVision with label noise, Prototypical Classifier achieves substantial performance improvements.

2 Related Work

Class-Imbalanced Learning. Recently, many approaches have been proposed to handle class-imbalanced training sets. Most existing approaches can be categorized into three types according to what they modify: (i) the inputs to a model, by re-balancing the training data [16, 22, 32]; (ii) the outputs of a model, for example by post-hoc adjustment of the classifier [8, 17, 25]; and (iii) the internals of a model, by modifying the loss function [2, 6, 20, 23]. Each of the above methods is intuitive and has shown strong empirical performance. However, these methods assume the training examples are correctly labeled, an assumption that is often difficult to satisfy in real-world applications. Instead, we study the realistic problem of learning from class-imbalanced data with label noise.

Label Noise Detection. Plenty of methods have been proposed to detect noisy labels [4, 7, 10]. Many works adopt the small-loss trick, which treats samples with small training losses as correctly labeled. In particular, MentorNet [7] reweights samples with small loss so that noisy samples contribute less to the loss. Co-teaching [4] trains two networks, where each network selects small-loss samples in a mini-batch to train the other. DivideMix [10] fits a Gaussian mixture model on the per-sample loss distribution to divide the training data into a clean set and a noisy set. In addition, AUM [19] introduces a margin statistic to identify noisy samples by measuring the average difference between the logit values for a sample's assigned class and its highest non-assigned class. The above methods only consider class-balanced training sets and thus are not directly applicable to class-imbalanced problems. Ref. [12] observes that real-world datasets with label noise also have an imbalanced number of samples per class. Nevertheless, they only inspect a particular setup of class imbalance.

3 Prototypical Classifier with Dynamic Threshold

3.1 Motivation

Consider a binary classification problem with the data generating distribution \(\mathbb {P}_{XY}\) being a mixture of two Gaussians. In particular, the label Y is either positive (+1) or negative (−1) with equal probability (i.e., \(\frac{1}{2}\)). Conditioned on \(Y = +1\), \(\mathbb {P} (X \mid Y = +1) \sim \mathcal {N} (\mu _1, \sigma _1)\), and similarly \(\mathbb {P} (X \mid Y = -1) \sim \mathcal {N} (\mu _2, \sigma _2)\). Without loss of generality, let \(\mu _1 > \mu _2\). When \(\sigma _1 = \sigma _2\), it is straightforward to verify that the optimal Bayes classifier is \(f(x) = \operatorname{sign}(x - \frac{\mu _1+\mu _2}{2})\) [30], i.e., classify x as +1 if \(x > \frac{\mu _1+\mu _2}{2}\). This is reminiscent of the nearest-neighbor classifier, whose classification boundary lies at the midpoint between two data points (i.e., a balanced classification boundary). For general multi-class tasks, this motivates us to measure the distance of samples to class prototypes, which is empirically observed to produce balanced classification boundaries even when the training set is class-imbalanced, as shown in Fig. 2.
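For completeness, the midpoint boundary follows from a one-line check of the log-likelihood ratio under the equal-variance assumption \(\sigma _1 = \sigma _2 = \sigma \):

$$\begin{aligned} \log \frac{\mathbb {P}(X=x \mid Y=+1)}{\mathbb {P}(X=x \mid Y=-1)} = \frac{(x-\mu _2)^2 - (x-\mu _1)^2}{2\sigma ^2} = \frac{\mu _1-\mu _2}{\sigma ^2}\Big (x - \frac{\mu _1+\mu _2}{2}\Big ), \end{aligned}$$

which is positive, and hence predicts +1, exactly when \(x > \frac{\mu _1+\mu _2}{2}\), since \(\mu _1 > \mu _2\).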

Fig. 2. Experiment on CIFAR-100-LT. The x-axis shows class labels sorted by decreasing number of training samples; the y-axis shows the marginal likelihood p(y) on the test set.

To this end, we learn a non-linear mapping of the input into an embedding space via a neural network \(f_{\theta }\), parameterized by \(\theta \), trained on the training set \(\mathcal {D} = \{({\boldsymbol{x}}_i, y_i)\}_{i=1}^N\). A class prototype is taken as the normalized mean vector of the embedded examples belonging to that class. Specifically, the prototype for class \(k \in \{1,\dots , K\}\) is computed as:

$$\begin{aligned} \boldsymbol{c}_{k} = {\text {Normalize}}\bigg ( \frac{1}{|\mathcal {D}_k|} \sum _{ i \in \mathcal {D}_k } f_\theta ({\boldsymbol{x}}_i) \bigg ), \mathcal {D}_k = \left\{ i \mid y_{i}=k \right\} . \end{aligned}$$
(1)

Prototypical Classifier produces a distribution over classes for a sample \({\boldsymbol{x}}\) based on a softmax over distances to the prototypes in the embedding space. In particular, when using cosine similarity as the distance measure, we have:

$$\begin{aligned} \mathbb {P}_{\theta }( Y=k \mid {\boldsymbol{x}})=\frac{\exp \left( f_{\theta }({\boldsymbol{x}})^{\top } \mathbf {c}_{k}\right) }{\sum _{k^{\prime }} \exp \left( f_{\theta }({\boldsymbol{x}})^{\top } \mathbf {c}_{k^{\prime }}\right) }. \end{aligned}$$
(2)
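As a minimal sketch (not the authors' implementation), the following PyTorch-style snippet computes the prototypes of Eq. (1) and the probabilities of Eq. (2); all function and variable names are illustrative, and the embeddings are assumed to be produced by \(f_\theta \) elsewhere:

```python
import torch
import torch.nn.functional as F

def class_prototypes(embeddings, labels, num_classes):
    """Eq. (1): normalized mean embedding of each class."""
    protos = torch.zeros(num_classes, embeddings.size(1), device=embeddings.device)
    for k in range(num_classes):
        mask = labels == k
        if mask.any():
            protos[k] = embeddings[mask].mean(dim=0)
    return F.normalize(protos, dim=1)  # unit-norm prototypes c_k

def prototype_probs(embeddings, prototypes):
    """Eq. (2): softmax over cosine similarities to the class prototypes."""
    sims = F.normalize(embeddings, dim=1) @ prototypes.t()  # (N, K) cosine similarities
    return sims.softmax(dim=1)
```

Normalizing both the embeddings and the prototypes makes the inner product in Eq. (2) a cosine similarity, matching the distance measure described above.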

Learning proceeds by minimizing the negative log-probability \(J(\theta )=-\log \mathbb {P}_{\theta }(Y=k \mid {\boldsymbol{x}})\) of the true class label k via SGD. Notably, the model in Eq. (2) is equivalent to a linear model with a particular parameterization [18]. To see this, expand the term in the exponent:

$$\begin{aligned} \mathbf {c}_{k}^{\top } f_{\theta }({\boldsymbol{x}}) = \mathbf {w}_{k}^{\top } f_{\theta }({\boldsymbol{x}})+b_{k}, \text{ where } \mathbf {w}_{k}= \mathbf {c}_{k} \text{ and } b_{k}=0. \end{aligned}$$
(3)

Our results indicate that Prototypical Classifier is effective despite the equivalence to a linear model. We hypothesize this is because all of the required non-linearity can be learned within the embedding function [24]. Indeed, this is the approach that modern neural network classification systems currently use.

3.2 Dynamic Thresholding for Label Noise Detection

However, the existence of label noise may hurt the representation learning of the network. To tackle this issue, it is common practice to correct noisy labels. Let \(\hat{{\boldsymbol{y}}} = [\hat{y}_1, \cdots , \hat{y}_K] = \mathbb {P}_{\theta }( Y \mid {\boldsymbol{x}}_i)\) be the prediction of Prototypical Classifier for sample \({\boldsymbol{x}}_i\); its label is then refined according to the following rule:

$$\begin{aligned} \tilde{y}_i = \left\{ \begin{array}{ll} y_{i} &{} \text{ if } \hat{y}_{y_i} > \tau _t \\ \hbox {arg max}_j \hat{y}_{j} &{} \text{ otherwise. } \end{array}\right. \end{aligned}$$
(4)

In words, we deem a sample clean if the confidence score on its original label is greater than a threshold \(\tau _t\). Notably, normal classifiers cannot achieve this goal due to their biased predictions, whereas the predictions of Prototypical Classifier are balanced and comparable across classes. We illustrate this finding in Fig. 3.
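A minimal sketch of the refinement rule in Eq. (4), where the dynamic threshold \(\tau _t\) is supplied as an argument (its schedule is given in Eq. (5) below) and the names are illustrative:

```python
import torch

def refine_labels(probs, labels, tau_t):
    """Eq. (4): keep a label if its confidence exceeds tau_t, otherwise use the argmax."""
    conf_on_label = probs.gather(1, labels.view(-1, 1)).squeeze(1)  # \hat{y}_{y_i}
    is_clean = conf_on_label > tau_t
    refined = torch.where(is_clean, labels, probs.argmax(dim=1))
    return refined, is_clean
```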

Fig. 3. Experiment on CIFAR-100-LT. The x-axis shows class labels sorted by decreasing number of training samples; the y-axis shows the classifiers' confidence scores on the training set.

We then need to specify \(\tau _t\). Intuitively, as the optimization iteration t increases, the predictive confidence generally increases as well, so \(\tau _t\) should also increase. Mathematically, we set the dynamic threshold \(\tau _t\) to be an increasing function of t:

$$\begin{aligned} \tau _t = \gamma ^{t} \tau _0. \end{aligned}$$
(5)

Here, \(\tau _0\) is the initial threshold and \(\gamma \) is set to 1.005 in our experiments. We provide more analysis of \(\tau _t\) in the supplementary material. Lemma 1 summarizes the performance bound of the label noise detection method.

Lemma 1

With probability at least p, the F\(_1\)-score of detecting noisy labels in \(\mathcal {D}_j\) by thresholding the predictive scores of Prototypical Classifier is at least \(1-\frac{e^{-v} \max \left( N^{-}, N^{+}\right) +\alpha }{N^{-}}\) when the noise ratio is known, where \(p=\int _{-1}^{\mu ^{\mathrm {true}}-\mu ^{\mathrm {false}}-\varDelta } f(t)\, dt\) and f(t) is the probability density function of the difference \(\beta _1 - \beta _2\) of two independent beta-distributed random variables with \(\beta _{1} \sim {\text {Beta}}\left( N^{-}, 1\right) \) and \(\beta _{2} \sim {\text {Beta}}\left( \alpha +1, N^{+}-\alpha \right) \).

Lemma 1 shows that the performance of noise detection depends on the intra-class concentration of clean samples in the embedding space (captured by \(\frac{\varDelta ^2}{v}\)), which is optimized by the prototypical contrastive loss defined in Eq. (6). We refer the reader to Ref. [33] for the proof of Lemma 1. We further illustrate the effectiveness of our method in Fig. 4, which shows high F\(_1\)-scores for both head and tail classes.

Fig. 4. Experiment on CIFAR-100-LT. F\(_1\)-score of the clean-example selection module for many, medium, and few classes.

3.3 Example Reweighting

In standard training, we minimize the expected loss over the training set, where each example is weighted equally. Here, we instead learn a reweighting of the examples to cope with hard mislabeled samples whose labels are not correctly refined, and minimize a weighted loss:

$$\begin{aligned} \mathcal {L}_{\text{pc}}= \frac{-1}{ \sum _{i=1}^N w_i } \sum _{i=1}^{N} w_i \log \frac{\exp \left( f_{\theta }({\boldsymbol{x}}_i) \cdot \boldsymbol{c}_{y_{i}} / \tau \right) }{\sum _{k=1}^{K} \exp \left( f_{\theta }({\boldsymbol{x}}_i) \cdot \boldsymbol{c}_{k} / \tau \right) }. \end{aligned}$$
(6)

With a slight abuse of notation, we re-define \(w_i\) to be the weight of the i-th example, and \(\tau \) is a temperature parameter. We expect the weights to reflect the likelihood of examples being correctly labeled. In that regard, we devise a weighted version of the prototype computation:

$$\begin{aligned} \boldsymbol{c}_{k} = {\text {Normalize}}\bigg ( \frac{1}{\sum _{i \in \mathcal {D}_k} w_i} \sum _{ i \in \mathcal {D}_k } w_i f_\theta ({\boldsymbol{x}}_i) \bigg ), \mathcal {D}_k = \left\{ i \mid y_{i}=k \right\} . \end{aligned}$$
(7)
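The following sketch combines the weighted prototype update of Eq. (7) with the weighted prototypical contrastive loss of Eq. (6); it is illustrative only and assumes the same conventions as the earlier snippets:

```python
import torch
import torch.nn.functional as F

def weighted_prototypes(embeddings, labels, weights, num_classes):
    """Eq. (7): prototypes as weighted, normalized class means."""
    protos = torch.zeros(num_classes, embeddings.size(1), device=embeddings.device)
    for k in range(num_classes):
        mask = labels == k
        if mask.any():
            w = weights[mask].unsqueeze(1)
            protos[k] = (w * embeddings[mask]).sum(dim=0) / w.sum()
    return F.normalize(protos, dim=1)

def prototypical_contrastive_loss(embeddings, labels, weights, protos, tau=0.1):
    """Eq. (6): weighted cross-entropy over similarities to the prototypes."""
    logits = F.normalize(embeddings, dim=1) @ protos.t() / tau
    per_sample = F.cross_entropy(logits, labels, reduction="none")  # -log p(y_i | x_i)
    return (weights * per_sample).sum() / weights.sum()
```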

Recall that one appealing property of Prototypical Classifier is its balanced predictions across all classes, as opposed to biased normal classifiers. We therefore simply set example weights to the predicted score of Prototypical Classifier on the training label, i.e., for the i-th example we set \(w_i = \mathbb {P}_{\theta }(Y=y_i \mid {\boldsymbol{x}}_i)\), where \(y_i\) is the training label of \({\boldsymbol{x}}_i\). For samples whose labels are rectified, we update their weights by \(w' = \frac{\tau _t - w}{2}\) to reflect the uncertainty. The modified example weights are always non-negative, since a label is refined if and only if \(w = \mathbb {P}_{\theta }(Y = y_i \mid {\boldsymbol{x}}_i) \le \tau _t\).

The optimization of \(\mathcal {L}_{\text{pc}}\) is realized by contrastive learning, which has been demonstrated to be effective for representation learning [13]. Observing that the presence of label noise may have a negative effect on representation learning, we additionally train the network with an unsupervised contrastive loss, which does not use the possibly corrupted training labels. The basic idea of unsupervised contrastive learning is to pull together two embeddings of the same example, while pushing them apart from the embeddings of other examples. Formally, let \({\boldsymbol{z}}_i = f_{\theta }({\boldsymbol{x}}_i)\) and let \({\boldsymbol{z}}_i^{\prime }\) be the embedding of an augmented view of \({\boldsymbol{x}}_i\); the unsupervised contrastive loss is computed as:

$$\begin{aligned} \mathcal {L}_{\text{cc}}^{i}=-\log \frac{\exp \left( \boldsymbol{z}_{i} \cdot \boldsymbol{z}_{i}^{\prime } / \tau \right) }{\sum _{b=1}^{B} \exp \left( {\boldsymbol{z}}_i \cdot \boldsymbol{z}_{b}^{\prime } / \tau \right) }, \end{aligned}$$
(8)

where \(\tau \) is a scalar temperature parameter and B is mini-batch size.
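A sketch of the unsupervised contrastive term of Eq. (8) for one mini-batch of two augmented views (a simplified, one-directional variant with illustrative names):

```python
import torch
import torch.nn.functional as F

def unsup_contrastive_loss(z, z_prime, tau=0.1):
    """Eq. (8): pull two views of an example together, push other examples apart."""
    z = F.normalize(z, dim=1)
    z_prime = F.normalize(z_prime, dim=1)
    logits = z @ z_prime.t() / tau                       # (B, B) pairwise similarities
    targets = torch.arange(z.size(0), device=z.device)   # positive pairs on the diagonal
    return F.cross_entropy(logits, targets)              # batch average of Eq. (8)
```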

Given the above definitions and denoting by \(\mathcal {L}_{\mathrm {ce}}\) the conventional cross-entropy loss, the overall training objective is written as:

$$\begin{aligned} \mathcal {L}=\mathcal {L}_{\mathrm {ce}}+\lambda _{1} \mathcal {L}_{\mathrm {cc}}+\lambda _{2} \mathcal {L}_{\mathrm {pc}}, \end{aligned}$$
(9)

where \(\lambda _{1}\) and \(\lambda _{2}\) are trade-off hyperparameters. We adopt a DNN as the feature extractor and a linear layer as the projector that generates the latent representation \(\boldsymbol{z}_i\); another linear layer following the feature extractor serves as the normal classifier. When minimizing \(\mathcal {L}_{\mathrm {pc}}\), we apply mixup [31] to improve generalization, which has been shown to be effective for learning with noisy labels [29].
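Assembling the pieces, one possible way to combine the three terms of Eq. (9) is sketched below; it relies on the helper functions from the earlier sketches, omits the mixup/AugMix wiring, and uses illustrative names throughout:

```python
import torch.nn.functional as F

def total_loss(logits, z, z_aug, refined_labels, weights, protos,
               lambda1=1.0, lambda2=5.0):
    """Eq. (9): L = L_ce + lambda1 * L_cc + lambda2 * L_pc."""
    loss_ce = F.cross_entropy(logits, refined_labels)                            # normal classifier
    loss_cc = unsup_contrastive_loss(z, z_aug)                                   # Eq. (8)
    loss_pc = prototypical_contrastive_loss(z, refined_labels, weights, protos)  # Eq. (6)
    return loss_ce + lambda1 * loss_cc + lambda2 * loss_pc
```

The default \(\lambda _1=1\) and \(\lambda _2=5\) mirror the CIFAR settings reported in the experiments.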

4 Experiments

We perform experiments on the CIFAR-10 and CIFAR-100 datasets by controlling the label noise ratio and the imbalance factor of the training set. Additionally, we perform experiments on WebVision, a commonly used dataset with real-world label noise.

4.1 Results on Simulated Datasets

Class-Imbalanced Dataset Generation. Formally, for a dataset with K classes and N training examples per class, given an imbalance factor \(\rho \), the number of examples retained for the k-th class is set to \(N_k={N}/{\rho ^{\frac{k-1}{K-1}}}\).
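For instance, the per-class sample counts can be generated as in the short sketch below (illustrative only; the actual sub-sampling of the dataset is omitted):

```python
def class_sizes(num_classes, n_per_class, rho):
    """N_k = N / rho^((k-1)/(K-1)) for classes k = 1, ..., K."""
    return [int(n_per_class / rho ** (k / (num_classes - 1)))
            for k in range(num_classes)]

# CIFAR-10-LT with imbalance factor rho = 100: class_sizes(10, 5000, 100)
# keeps 5000 examples for the most frequent class down to 50 for the rarest.
```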

Label Noise Injection. Let Y denote the clean label, \(\bar{Y}\) the noisy label, and X the instance/feature; the noise transition matrix is defined as \(T_{ij}(X = x) = \mathbb {P}(\bar{Y}=j \mid Y=i, X=x)\). In this work, we follow the setup in RoLT+ [28] and set \(T(X = x)\) according to the estimated class priors \(\mathbb {P}(y)\), e.g., the empirical class frequencies in the training dataset. Formally, given the noise proportion \(\gamma \in [0,1]\), we define:

$$\begin{aligned} T_{ij}(X = x) = \mathbb {P}(\bar{Y}=j \mid Y=i, X=x) = \left\{ \begin{array}{ll} 1 - \gamma &{} i = j \\ \frac{N_j}{N - N_i} \gamma &{} \text{ otherwise. } \end{array}\right. \end{aligned}$$
(10)

Here, N is the size of the training set and \(N_j\) is the number of training examples of class j.
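A sketch of how noisy labels can be injected according to Eq. (10); the function and variable names are illustrative:

```python
import numpy as np

def build_transition_matrix(class_sizes, gamma):
    """Eq. (10): T[i, j] = P(noisy label = j | clean label = i)."""
    sizes = np.asarray(class_sizes, dtype=float)
    n_total, num_classes = sizes.sum(), len(sizes)
    T = np.zeros((num_classes, num_classes))
    for i in range(num_classes):
        for j in range(num_classes):
            T[i, j] = 1 - gamma if i == j else gamma * sizes[j] / (n_total - sizes[i])
    return T  # each row sums to one

def inject_noise(labels, T, seed=0):
    """Resample each label from the row of T indexed by its clean label."""
    rng = np.random.default_rng(seed)
    return np.array([rng.choice(len(T), p=T[y]) for y in labels])
```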

Table 1. Test accuracy (%) on CIFAR-10. \(^*\) denotes ensemble models.
Table 2. Test accuracy (%) on CIFAR-100. \(^*\) denotes ensemble models.

Result. We train a PreAct ResNet-18 network using the SGD optimizer with momentum 0.9 for all methods. We set \(\lambda _1=1\) and \(\lambda _2=5\). We use \(\tau _0 = 0.1\) for CIFAR-10 and \(\tau _0=0.01\) for CIFAR-100. Tables 1 and 2 summarize the results for CIFAR-10 and CIFAR-100, respectively. We compare our method with several commonly used baselines for long-tailed learning (1–3) and learning with noisy labels (4–5). As the results show, the performance of previous methods degrades severely as the noise ratio and imbalance factor increase, while our method remains robust. In particular, compared with CE, Prototypical Classifier improves the test accuracy by 9% on average. The improvement becomes more significant when the noise ratio is high, benefiting from the proposed noise detection method.

DivideMix [10] and RoLT+ [28] are two strong baselines for this task; accordingly, (4) and (5) obtain much higher performance than (1–3), particularly when the noise ratio is high. Although (4) and (5) use an ensemble of two networks, our method (6) outperforms them in most cases. On CIFAR-100, Prototypical Classifier achieves the best results among all approaches and outperforms the others by a large margin for both head and tail classes, as shown in Fig. 5.

Fig. 5. Experiment on CIFAR-100-LT. Accuracy for many (#inst > 100), medium (#inst \(\in [20, 100]\)), and few (#inst < 20) classes.

4.2 Results on Real-World Dataset

We test the performance of our method on a real-world dataset. WebVision [14] contains 2.4 million images collected from Flickr and Google, with real-world label noise and class imbalance. Following previous literature, we train on a subset, mini WebVision, which contains the first 50 classes. In Table 3, we report results against state-of-the-art approaches, including MentorNet [7], Co-teaching [4], ELR [15], HAR [1], and DivideMix [10]. We use InceptionResNet-v2 for all methods and set \(\tau _0 = 0.05\), \(\lambda _1=1\), and \(\lambda _2=2\) in all experiments. The results show that, using a single model, the proposed method achieves performance competitive with DivideMix and outperforms the other baselines.

4.3 Ablation Studies

We examine the effectiveness of each module of our method by removing it and comparing the performance with that of the full framework. The results are reported in Table 4. In general, removing any part of the method significantly degrades performance, or even causes training to fail in some cases. The results for re-weighting and dynamic thresholding show their effectiveness in dealing with label noise. Although we do not use the normal classifier trained via \(\mathcal {L}_{\mathrm {ce}}\) for prediction, it is observed to improve representation learning; we have a similar observation for the unsupervised contrastive loss \(\mathcal {L}_{\mathrm {cc}}\). The strong augmentation method AugMix [5] also provides a substantial improvement.

Table 3. Accuracy (%) on WebVision and ImageNet. \(^*\) denotes ensemble models.

Additionally, we test our method on class-balanced training sets with label noise in Table 5. Prototypical Classifier outperforms the other methods in most cases, even though both DivideMix and RoLT+ use an ensemble of two networks, which shows the generality of Prototypical Classifier.

Table 4. Ablation studies with noise ratio \(\gamma =0.5\) and imbalance factor \(\rho =100\). Numbers in parentheses indicate performance loss (gain) compared with Prototypical Classifier.
Table 5. Accuracy (%) on class-balanced datasets. \(^*\) denotes ensemble models.

5 Conclusion

We propose Prototypical Classifier for learning with training set biases. Prototypical Classifier is shown to produce balanced predictions for all classes even when learned on class-imbalanced training sets. This appealing property provides a way to detect label noise by thresholding the predicted scores of examples. Experiments demonstrate the superiority of the proposed method. We believe Prototypical Classifier can motivate solutions to further problems with class-imbalanced training sets, for instance semi-supervised and self-supervised learning.