
1 Introduction

Over the past few decades, deep neural networks (DNNs) have enjoyed great success in computer vision [20, 25], for tasks such as real-time semantic segmentation [7] and object detection [15]. However, powerful DNNs often have large numbers of parameters and require substantial computational and storage resources, which is undesirable for industrial applications. To address this issue, a number of model compression techniques have been proposed, including model pruning [2, 16, 30], quantization [9], and knowledge distillation [5], with knowledge distillation proving to be a mature method for improving the performance of small models.

Traditional knowledge distillation (KD) [5] (see Fig. 1(a)) uses soft labels from a pretrained teacher to supervise the student so that it attains performance close to the teacher's; this is a two-stage training process and is not flexible. Recently, online knowledge distillation [14, 25] has proposed single-stage schemes that encourage networks to train each other, training the teacher and student networks jointly to improve their consistency from different points of view [1, 4, 23]. For example (see Fig. 1(b)), deep mutual learning (DML) [22] aligns the predictions produced after the classifiers of the teacher and student. Chung et al. [3] introduce a transfer of intermediate-layer features between teacher and student, as shown in Fig. 1(c). Existing online knowledge distillation methods are a form of collaborative teaching and learning, and we hope to further enhance this collaboration, bringing the student and teacher together as a whole.

Fig. 1. Illustration of (a) KD, (b) DML, (c) DML with feature comparison, and (d) the proposed knowledge distillation framework with joint regularization loss.

When the differences between the teacher and student models are too great, distillation can adversely affect the student [17]. Strengthening the connection between the teacher and student networks can improve distillation performance, from traditional knowledge distillation to online knowledge distillation (see Sect. 4.2). Recently, Wei et al. [29] proposed a robust learning method called JoCoR that maximizes the similarity between DNNs by reducing the divergence between their predictions. Based on this insight, we believe that the teacher and student can obtain consistent joint supervision on their predictions, enhancing the integrity of the two classifiers and thus improving distillation performance, as shown in Fig. 1(d).

This paper proposes a new knowledge distillation framework called Joint Regularization Knowledge Distillation (JRKD). Specifically, we train the teacher and student networks with a joint loss to maximize the consistency between the two networks. Inspired by curriculum learning [21], we propose a confidence-based continuous scheduler (CBCS), which divides examples into central examples and edge examples based on the confidence distributions computed by the teacher and the student. The central examples reduce the prediction error in the joint training of the two networks, promote their mutual learning, and reduce the accumulation of error flows in the networks. The proportion of central examples gradually increases as training progresses, ensuring the integrity of the training set. Extensive experiments on three representative benchmarks show that JRKD can effectively train a high-performance student network. Our main contributions are summarized as follows:

1. We propose Joint Regularization Knowledge Distillation (JRKD), which effectively reduces the differences between networks.

2. We use a joint regularization loss to regularize the teacher and student networks and maximize the consistency between them.

3. We develop a confidence-based continuous scheduler (CBCS); selecting loss instances with it mitigates the negative interaction between networks and reduces the difficulty of consistency training.

2 Related Literature

In this section, we discuss work related to online knowledge distillation and the "disagreement" strategy. Various approaches have been proposed in both areas over the past few years; we summarize them below.

2.1 Online Knowledge Distillation

Traditional knowledge distillation is achieved by a pretrained teacher network that transfers its knowledge (extracted logits [5] or intermediate feature representations [11]) to guide the student model during training. This approach is simple and effective, but it requires a high-performance teacher model. In online knowledge distillation, the teacher and student models update their parameters at the same time, giving an end-to-end process. The concept of online distillation was first proposed by Zhang et al. [22] to enable collaborative learning and mutual teaching between students and teachers. To address the interference caused by mutual learning between networks, SwitOKD [25] adaptively calibrates the gap during training by switching between two modes, an expert mode (pause the teacher, keep the student learning) and a learning mode (restart the teacher), so that teacher and student keep an appropriate distillation gap while learning from each other. Chung et al. [26] add a feature-map-based criterion to the original logit-based prediction, and the feature-map-based loss lets the teacher and student distill each other through a discriminator.

Online distillation is a single-stage training scheme with efficient parallel computation. Existing online knowledge distillation methods are a form of collaborative teaching and learning, and we hope to further enhance this collaboration, bringing students and teachers together as a whole rather than treating them individually.

2.2 Disagreement

Weakly supervised learning [27] addresses the time-consuming and labor-intensive collection of large, accurately labeled datasets, but data gathered through online queries and similar methods is inevitably affected by noisy labels. In recent years, the "disagreement" strategy has been introduced to address such issues. For example, Decoupling [19] uses two different networks and updates their parameters only when their predictions disagree; when the two networks agree, no update is made. The "disagreement" strategy expects the examples that produce different predictions to steer the networks away from their current errors. In 2019, Chen et al. [12] combined the "disagreement" strategy with Co-teaching [28], yielding good robustness of DNNs to noisy labels. Recently, Wei et al. [29] proposed a robust learning paradigm called JoCoR from a different perspective, which aims to reduce the divergence of the two networks on training examples and updates the parameters of both networks at the same time by selecting small-loss examples. Under the joint loss, the two networks become more and more similar due to the effect of co-regularization.

We hope to apply the idea of the "disagreement" strategy to knowledge distillation, aiming to reduce the differences between the teacher and student networks, thereby improving the integrity between the networks and the distillation performance.

Fig. 2. JRKD flowchart. CBCS selects central examples based on the confidence of each network's outputs, and the teacher and student receive joint supervised training.

3 Approach

In this section, we describe how CBCS selects central examples (Sect. 3.1) and how the joint regularization loss trains the two networks collaboratively (Sect. 3.2).

3.1 Confidence-Based Continuous Scheduler

According to recent research [29], while networks can improve their mutual consistency through joint regularization, they are vulnerable to error flows caused by biased selection. To address this issue, we design a confidence-based continuous scheduler (CBCS) that divides the examples into central examples and edge examples. Training with central examples better reduces the prediction bias between the networks, as shown in Fig. 2.

The teacher and student select different central examples; we only show how the teacher selects, and the student does the same. We use the dataset \(\mathcal D=\left\{ \left( \boldsymbol{x}_{i},{y}_i\right) \right\} _{i=1}^N\) as the network input for each batch with N examples. Let the teacher network be \({\mathcal N}_T\), and let its prediction probabilities on \(\mathcal D\) be \(\{{\mathcal N}_T({\textbf{x}}_i)\}_{i=1}^N\). For an m-class classification task, \(\max ({\mathcal N}_T({\textbf{x}}_i))\) is the maximum confidence of the teacher network for example \(\boldsymbol{x}_{i}\), i.e., the largest of the m class probabilities. The K-means clustering algorithm is used to obtain the maximum-confidence centroid of the N examples:

$$\begin{aligned} M\_p_{\text{target}}=\frac{\sum _{i=1}^{N} \max ({\mathcal N}_T({\textbf{x}}_i))}{N}. \end{aligned}$$
(1)

\(M\_p_{\text{target}}\) is the confidence centroid. Computing the absolute distance from each \(\max ({\mathcal N}_T({\textbf{x}}_i))\) to \(M\_p_{\text{target}}\) yields the set \({d}_{t}=\left[ {d}_{1}, {d}_{2},\ldots , {d}_{N}\right] \). The smaller a value in \({d}_{t}\), the closer the example is to the confidence center.
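To make the selection concrete, the following PyTorch-style sketch (our illustration, not the authors' released code) computes the per-example maximum confidences, the centroid of Eq. (1), and the distances \(d_t\); the function name and tensor shapes are assumptions.

```python
import torch
import torch.nn.functional as F

def confidence_distances(logits: torch.Tensor) -> torch.Tensor:
    """Return |max-confidence - centroid| for each example in a batch.

    `logits` has shape (N, m): raw outputs of one network (teacher or student)
    on the current batch. The centroid follows Eq. (1): the mean of the
    per-example maximum softmax confidences.
    """
    probs = F.softmax(logits, dim=1)       # (N, m) prediction probabilities
    max_conf, _ = probs.max(dim=1)         # maximum confidence per example
    centroid = max_conf.mean()             # M_p_target in Eq. (1)
    return (max_conf - centroid).abs()     # d_t (or d_s) in the text
```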

CBCS controls the proportion of central examples in each epoch through a continuous scheduling function \(\lambda (t)\), where \(T_{\text{total}}\) is the total number of training epochs, \(\lambda _{0}\) is the initial proportion of selected central examples, and t is the current training epoch:

$$\begin{aligned} \lambda (t)=\min \left( 1, \lambda _{0}+\frac{1-\lambda _{0}}{T_{\text{total}}} \cdot t\right) . \end{aligned}$$
(2)

Using \({d}_{t}\) as the basis for selecting central examples, the number of selected examples is controlled by \(\lambda (t)\). \(\textrm{Index}(k, d)\) is a function that returns the indices of the k smallest values of d. This gives the current set of central examples \(D_{t}\); the same procedure yields \(D_{s}\):

$$\begin{aligned} D_{t}=D \times \textrm{Index}\left( \lambda (t) \times N, d_{t}\right) , \quad D_{t} \subseteq D, \end{aligned}$$
(3)
$$\begin{aligned} D_{s}=D \times \textrm{Index}\left( \lambda (t) \times N, d_{s}\right) , \quad D_{s} \subseteq D. \end{aligned}$$
(4)
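Continuing the sketch, \(\lambda (t)\) of Eq. (2) and the index selection of Eqs. (3)-(4) could be implemented as follows; `lambda_t` and `central_indices` are hypothetical helper names, and the default \(\lambda _{0}=0.3\) follows the setting discussed in Sect. 4.2.

```python
import torch

def lambda_t(t: int, t_total: int, lambda_0: float = 0.3) -> float:
    """Continuous scheduler of Eq. (2): fraction of central examples at epoch t."""
    return min(1.0, lambda_0 + (1.0 - lambda_0) / t_total * t)

def central_indices(distances: torch.Tensor, t: int, t_total: int,
                    lambda_0: float = 0.3) -> torch.Tensor:
    """Indices of the examples closest to the confidence centroid (Eqs. (3)-(4))."""
    n = distances.numel()
    k = max(1, int(lambda_t(t, t_total, lambda_0) * n))
    return torch.topk(distances, k, largest=False).indices
```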

3.2 Joint Regularization Knowledge Distillation

Consider an m-class classification task. We use two deep neural networks to describe the proposed JRKD method. For clarity, we let \({p}_{s}=\left[ p_{s}^{1}, p_{s}^{2},\ldots , p_{s}^{m}\right] \) and \({p}_{t}=\left[ p_{t}^{1}, p_{t}^{2}, \ldots , p_{t}^{m}\right] \) denote the final prediction probabilities of the student and teacher for example \(\boldsymbol{x}_{i}\), respectively. They are obtained by softening the network outputs with a softmax at distillation temperature \(T = 3\).

Joint Regularization Loss. We train the two networks together with a joint regularization loss, which pulls the predictions of each network closer to those of its peer. Under joint training, the networks become more and more similar to each other. Because the Kullback-Leibler (KL) divergence is asymmetric, we apply it in both directions:

$$\begin{aligned} \mathcal {L}_{con}=D_{\textrm{KL}}\left( \boldsymbol{p}_{s} \Vert \boldsymbol{p}_{t}\right) +D_{\textrm{KL}}\left( \boldsymbol{p}_{t} \Vert \boldsymbol{p}_{s}\right) . \end{aligned}$$
(5)

\(\mathcal {L}_{con}\) denotes the joint regularization loss. CBCS selects different central examples to participate in joint training based on the confidence of the examples predicted by the teacher and the student:

$$\begin{aligned} D_{\textrm{KL}}\left( \boldsymbol{p}_{s} \Vert \boldsymbol{p}_{t}\right) =\sum _{i=1}^{N} \sum _{m=1}^{M} p_{s}^{m}\left( \boldsymbol{x}_{i}\right) \log \frac{p_{s}^{m}\left( \boldsymbol{x}_{i}\right) }{p_{t}^{m}\left( \boldsymbol{x}_{i}\right) }, x \in D_{t}, \end{aligned}$$
(6)
$$\begin{aligned} D_{\textrm{KL}}\left( \boldsymbol{p}_{t} \Vert \boldsymbol{p}_{s}\right) =\sum _{i=1}^{N} \sum _{m=1}^{M} p_{t}^{m}\left( \boldsymbol{x}_{i}\right) \log \frac{p_{t}^{m}\left( \boldsymbol{x}_{i}\right) }{p_{s}^{m}\left( \boldsymbol{x}_{i}\right) }, x \in D_{s}. \end{aligned}$$
(7)
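A minimal sketch of the joint regularization loss follows, assuming the central-example index sets from CBCS are available as `idx_t` and `idx_s`; it implements Eqs. (5)-(7) with the temperature-\(T\) softmax of this section (the \(T^2\) gradient rescaling sometimes used in distillation is omitted here).

```python
import torch
import torch.nn.functional as F

def joint_regularization_loss(logits_s: torch.Tensor,
                              logits_t: torch.Tensor,
                              idx_t: torch.Tensor,
                              idx_s: torch.Tensor,
                              T: float = 3.0) -> torch.Tensor:
    """L_con of Eq. (5): KL(p_s || p_t) over D_t plus KL(p_t || p_s) over D_s."""
    p_s = F.softmax(logits_s / T, dim=1)   # student probabilities, softened
    p_t = F.softmax(logits_t / T, dim=1)   # teacher probabilities, softened
    # F.kl_div(log_q, p) computes KL(p || q).
    kl_s_t = F.kl_div(p_t[idx_t].log(), p_s[idx_t], reduction="sum")  # Eq. (6) on D_t
    kl_t_s = F.kl_div(p_s[idx_s].log(), p_t[idx_s], reduction="sum")  # Eq. (7) on D_s
    return kl_s_t + kl_t_s
```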

Total Losses. In JRKD, the joint regularization loss improves the integrity between the networks, and the conventional supervised loss maintains the correctness of learning. JRKD trains the networks by minimizing the following losses:

$$\begin{aligned} \mathcal {L}_{T}=\mathcal {L}_{T C E}+\mathcal {L}_{con}, \end{aligned}$$
(8)
$$\begin{aligned} \mathcal {L}_{S}=\mathcal {L}_{S C E}+\mathcal {L}_{con}. \end{aligned}$$
(9)

\(\mathcal {L}_{S C E}\) and \(\mathcal {L}_{T C E}\) denote the conventional supervised (cross-entropy) losses of the student and teacher, respectively. Finally, the overall procedure of JRKD is summarized in Algorithm 1.

Algorithm 1. JRKD
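Putting the pieces together, one JRKD update on a batch might look like the sketch below; it reuses the hypothetical helpers `confidence_distances`, `central_indices`, and `joint_regularization_loss` from the earlier snippets and updates each network with the gradient of its own loss (Eqs. (8)-(9)). This is an illustration of the procedure in Algorithm 1, not the authors' implementation.

```python
import torch
import torch.nn.functional as F

def jrkd_step(net_t, net_s, opt_t, opt_s, x, y, t, t_total,
              T=3.0, lambda_0=0.3):
    """One JRKD update on a batch (x, y); a sketch under the stated assumptions."""
    logits_t, logits_s = net_t(x), net_s(x)

    # CBCS: each network selects its own central examples (Sect. 3.1).
    idx_t = central_indices(confidence_distances(logits_t), t, t_total, lambda_0)
    idx_s = central_indices(confidence_distances(logits_s), t, t_total, lambda_0)

    # Joint regularization loss shared by both objectives (Eq. (5)).
    l_con = joint_regularization_loss(logits_s, logits_t, idx_t, idx_s, T)
    loss_t = F.cross_entropy(logits_t, y) + l_con   # Eq. (8)
    loss_s = F.cross_entropy(logits_s, y) + l_con   # Eq. (9)

    # Update each network with the gradient of its own loss.
    grads_t = torch.autograd.grad(loss_t, list(net_t.parameters()), retain_graph=True)
    grads_s = torch.autograd.grad(loss_s, list(net_s.parameters()))
    for p, g in zip(net_t.parameters(), grads_t):
        p.grad = g
    for p, g in zip(net_s.parameters(), grads_s):
        p.grad = g
    opt_t.step()
    opt_s.step()
    return loss_t.item(), loss_s.item()
```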

4 Experiments

In this section, we select three representative image classification tasks to evaluate the performance of JRKD (Sect. 4.1). The ablation experiments in Sect. 4.2 confirm the effectiveness of CBCS and of the joint regularization loss; we also analyze the effect of the initial central-example ratio \(\lambda _{0}\) on performance. In Sect. 4.3, we visualize the probability distributions of the teacher and student network outputs.

Experiment Setup. We optimize with stochastic gradient descent (SGD), setting the learning rate, weight decay, and momentum to 0.1, \(5 \times 10^{-4}\), and 0.9, respectively. Each dataset uses a standard data augmentation scheme, and the input images are normalized [17] with the channel means and standard deviations.
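As a concrete (hypothetical) configuration matching the stated hyper-parameters, the optimizer and a standard CIFAR-style augmentation pipeline could be set up as follows; the normalization statistics shown are commonly used CIFAR-100 values, which the paper does not list explicitly.

```python
import torch
import torchvision.transforms as transforms

def make_optimizer(model: torch.nn.Module) -> torch.optim.SGD:
    """SGD with the stated learning rate, weight decay, and momentum."""
    return torch.optim.SGD(model.parameters(), lr=0.1,
                           weight_decay=5e-4, momentum=0.9)

# Standard augmentation and normalization (assumed statistics for CIFAR-100).
CIFAR100_MEAN, CIFAR100_STD = (0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])
```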

4.1 Experiments on Benchmarks

Results on Tiny-ImageNet. Tiny-ImageNet contains 200 categories, each with 500 training images, 50 validation images, and 50 test images. With JRKD, the two groups of networks obtain accuracies of \(59.43\%\) and \(55.71\%\), respectively; JRKD effectively improves the accuracy of the student network. Compared with the other methods, ours also achieves good results. The results are shown in Table 1.

Table 1. JRKD accuracy results on the Tiny-ImageNet dataset. The accuracies of the comparison methods are taken from their original papers.

Results on CIFAR-100. The CIFAR-100 dataset has 100 classes, each with 500 training images and 100 test images. Table 2 shows the experimental results: JRKD outperforms many other methods across various network architectures. Impressively, JRKD achieves a \(1.33\%\) (WRN-40-2/WRN-16-2) accuracy improvement over DML on CIFAR-100. Besides, JRKD also shows \(0.88\%\) and \(0.19\%\) (ResNet\(32\times 4\)/ResNet\(8\times 4\)) accuracy gains over ReviewKD and DKD, respectively.

Table 2. JRKD accuracy results on the CIFAR-100 dataset. W40-2, R32x4, R8x4, and SV1 stand for WRN-40-2, ResNet\(32\times 4\), ResNet\(8\times 4\), and ShuffleNetV1, respectively. The accuracies of the other methods are mainly taken from DKD [22].

Results on CIFAR-10. The CIFAR-10 dataset has 60,000 examples in total, split into 50,000 training examples and 10,000 test examples. The experimental results, obtained with the same experimental configuration as the other methods, are shown in Table 3. Our method not only improves student performance; the teacher also achieves accuracy gains of \(0.23\%\) and \(0.8\%\) over SwitOKD and KDCL, respectively.

Table 3. Our results are averaged over 5 trials. Comparison with strong distillation techniques using the same 200 training epochs. The performance numbers of the other methods are taken from the original articles.

4.2 Ablation Experiments

CIFAR-100 is used for the ablation experiments. As shown in Table 4, we quantify the gap between the teacher and student networks using the T-S gap; compared with KD and DML, JRKD effectively reduces the differences between the networks and improves distillation performance. Comparing JRKD\(\dagger \) with the other distillation methods shows that the joint regularization loss improves the similarity between networks, and comparing JRKD with JRKD\(\dagger \) shows that CBCS benefits online training. In addition, we perform a sensitivity analysis of the manually set parameter \(\lambda _{0}\) in the continuous scheduler \(\lambda (t)\). As shown in Table 5, \(\lambda _{0}\) generally defaults to 0.3, so we only analyze values around 0.3 and find that an appropriate \(\lambda _{0}\) is conducive to distillation.

Table 4. Effectiveness of the joint regularization loss and CBCS. The student network is MobileNetV2 and the teacher is VGG13; \(K D_{T \rightarrow S}\) means the teacher network also accepts student supervision; Top-1 is the classification accuracy on CIFAR-100; the T-S gap uses the KL divergence to measure the gap between the output logits of the two networks. JRKD\(\dagger \) refers to JRKD without CBCS for selecting loss instances.
Table 5. Parameter sensitivity experiment for the continuous scheduler of CBCS. The experiments use CIFAR-100, and accuracy results are averaged over 5 trials.
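For reference, one plausible reading of the T-S gap reported in Table 4 is a symmetric KL divergence between the softmax outputs of the two networks, as sketched below; the exact formula is not specified in the text, so treat this implementation as an assumption.

```python
import torch
import torch.nn.functional as F

def ts_gap(logits_t: torch.Tensor, logits_s: torch.Tensor) -> float:
    """Hypothetical T-S gap: symmetric KL between teacher and student predictions."""
    p_t = F.softmax(logits_t, dim=1)
    p_s = F.softmax(logits_s, dim=1)
    kl_t_s = F.kl_div(p_s.log(), p_t, reduction="batchmean")  # KL(p_t || p_s)
    kl_s_t = F.kl_div(p_t.log(), p_s, reduction="batchmean")  # KL(p_s || p_t)
    return 0.5 * (kl_t_s + kl_s_t).item()
```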

4.3 Visual Analytics

We compare the traditional online knowledge distillation method DML with JRKD by feeding the same batch of examples into the trained networks and visualizing the confidence distributions produced by the teacher-student pair. As shown in Fig. 3, the confidence distributions of the teacher and student networks are more similar under JRKD, demonstrating that JRKD can improve the similarity between the networks.

Fig. 3. Confidence distributions produced by the two methods.
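A comparison of this kind can be produced with a simple histogram of per-example maximum confidences, for example as in the following matplotlib-based sketch (our own illustration; function and argument names are assumptions).

```python
import matplotlib.pyplot as plt
import torch.nn.functional as F

def plot_confidence_distribution(logits_t, logits_s, tag):
    """Histogram of per-example max confidences for teacher vs. student (cf. Fig. 3)."""
    conf_t = F.softmax(logits_t, dim=1).max(dim=1).values.detach().cpu().numpy()
    conf_s = F.softmax(logits_s, dim=1).max(dim=1).values.detach().cpu().numpy()
    plt.hist(conf_t, bins=20, alpha=0.5, label="teacher")
    plt.hist(conf_s, bins=20, alpha=0.5, label="student")
    plt.xlabel("max softmax confidence")
    plt.ylabel("count")
    plt.legend()
    plt.title(tag)
    plt.show()
```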

5 Conclusion

This paper proposes an effective method called JRKD to reduce the differences between networks. The key idea of JRKD is to train the teacher and student networks with a joint regularization loss that maximizes the consistency between the two networks. To reduce the difficulty of joint training, we develop a confidence-based continuous scheduler (CBCS), which divides examples into central examples and edge examples according to the confidence distribution of the network outputs. In the early stage of joint training, training with central examples reduces the prediction differences between the networks, and edge examples are gradually added as training proceeds to preserve the integrity of the training set. We demonstrate the effectiveness of JRKD with extensive experiments and analyze the contributions of the joint regularization loss and CBCS through ablation experiments. In future work, we will continue to explore treating the teacher and student networks as a whole during training in online knowledge distillation.