
1 Introduction

In machine learning, the purpose of data distillation [1] is to compress the original dataset while maintaining the performance of models trained on it. The distilled dataset should also generalize, by which we mean that it can be used to train models of architectures that were not involved in the distillation process. Since training on less data is usually faster, distillation is useful in practice: for example, it can speed up a neural architecture search (NAS) task, because the candidate networks train faster. In many recent works [1, 3, 5, 6, 7], distillation is formulated as an optimization problem in which the objects of the new dataset are the optimization variables. Thus, to distill a dataset for an image classification task, the pixels of the synthetic images have to be optimized. First, all new objects are initialized with random noise; then these objects are used to train a student (i.e., a randomly selected network); then the student's misclassification loss is calculated on real data; finally, a gradient descent step updates the synthetic objects. The gradients can be calculated by backpropagating the error through the entire learning process of the student. Such a step can be very time-consuming and memory-intensive, so there is a need for an alternative. In [2], the authors use the Implicit Function Theorem to solve the memory consumption problem. In [3], the data distillation problem is reformulated with a gradient matching loss, which speeds up the optimization of the synthetic objects and reduces memory usage. There is also an alternative to optimizing the pixels of synthetic data directly: in [4], the authors suggest optimizing the parameters of a generator model (a generative teaching network, or GTN) that produces synthetic data from noise and labels. The disadvantage is that the authors used backpropagation through the learning process for optimization. Inspired by recent ideas in the field of data distillation, we propose replacing it with gradient matching or with implicit differentiation to make the procedure less computationally expensive. We found that this not only reduces memory costs but also produces more efficient and generalizable datasets. In addition, we investigate the use of augmentation both during the distillation procedure and when training models on the distilled data.

The paper is divided into 7 sections. We first analyse the original data distillation algorithm [1] and discuss its problems in Sect. 2. A brief description of the algorithms for implicit differentiation [2] and gradient matching [3] can be found in Sects. 3 and 4. Section 5 presents the generative teaching network architecture that we use in our work. Section 6 contains the results of experiments with the MNIST image classification benchmark. In Sect. 6.1 we compare the results of all the described distillation methods under a fixed time limit. In Sects. 6.2 and 6.3 we show the results of the new distillation techniques that train a generator with gradient matching and with implicit differentiation, respectively. In Sect. 6.4 we study the use of augmentation during distillation, and in Sect. 6.5 we check the generalizability of the data obtained with the new methods. Finally, we present our findings in Sect. 7. The code can be found on our GitHub page.

2 Backpropagation Through the Student’s Learning Process

Let \(\lambda \) be the teacher parameters. These can be either the parameters of a GTN network or the parameters of the synthetic objects themselves (e.g. the pixels of synthetic images). To update \(\lambda \), we must first train a student network \(\theta \) on the synthetic data, minimizing the task-specific loss \(\mathcal {L_S}\) (e.g. cross-entropy), and then compute the loss on real data, \(\mathcal {L_T}\). To take care of generalizability, the student's initialization is sampled from a preset distribution \(p(\theta _0)\). Overall, the optimization problem for \(\lambda \) can be formulated as follows:

$$\begin{aligned}&\lambda ^* := \mathop {\textrm{argmin}}\limits _{\lambda }~\mathbb {E}_{\theta _0 \sim p(\theta _0)} \mathcal {L_T^*}\text {, where}\\ \nonumber \mathcal {L_T^*} :=&\mathcal {L_T}(\theta ^*(\lambda )),\quad \quad \theta ^*(\lambda ) := \mathop {\textrm{argmin}}\limits _{\theta }~\mathcal {L_S}(\lambda , \theta ). \end{aligned}$$
(1)

To solve problem (1) we calculate the gradient of \(\mathcal {L_T}\) with respect to \(\lambda \) and perform a gradient descent step. In this work, we use the cross-entropy loss as \(\mathcal {L_T}\); it depends explicitly only on \(\theta \) and the parameters of the real data, so \(\frac{\partial \mathcal {L_T}}{\partial \lambda } = 0\) and \(\frac{\partial \mathcal {L^*_T}}{\partial \lambda } = \frac{\partial \mathcal {L_T}}{\partial \theta } \frac{\partial \theta ^*}{\partial \lambda }\). Thus, the main difficulty is the calculation of \(\frac{\partial \theta ^*}{\partial \lambda }\), where the dependence of \(\theta ^*\) on \(\lambda \) comes from the student's training procedure. The first distillation algorithm was suggested in [1], and it is based on the assumption that the student's learning procedure is differentiable, which means that we can backpropagate the gradient through it. We denote this method unroll. It can be implemented using the Higher library [10], which allows backpropagation through many optimizers; in our paper we use SGD with momentum [8]. This distillation method is both time- and memory-consuming: to perform a single update of \(\lambda \), it is necessary to perform N student optimization steps, while all intermediate results (copies of the student weights) must be stored in memory. There is also a problem with the generalization of the resulting synthetic dataset: the performance of models whose architectures were not involved in the distillation process is much lower. This negative effect can be mitigated by sampling the student's initialization and architecture.
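
The following is a minimal sketch (not the authors' exact code) of one unroll update implemented with the Higher library; the student network, the teacher optimizer and the data batches are assumed to be defined elsewhere, the hyperparameter defaults are placeholders, and in the DD case the tensor synthetic_x itself is the teacher parameter.

import torch
import torch.nn.functional as F
import higher  # library [10] for differentiable inner-loop optimization

def unroll_update(student, teacher_opt, synthetic_x, synthetic_y, real_x, real_y,
                  n_inner=10, lr=0.01, momentum=0.5):
    """One teacher update: train the student on synthetic data for n_inner steps,
    then backpropagate the real-data loss through the whole inner loop."""
    inner_opt = torch.optim.SGD(student.parameters(), lr=lr, momentum=momentum)
    teacher_opt.zero_grad()
    with higher.innerloop_ctx(student, inner_opt, copy_initial_weights=True) as (fstudent, diffopt):
        for _ in range(n_inner):
            inner_loss = F.cross_entropy(fstudent(synthetic_x), synthetic_y)
            diffopt.step(inner_loss)              # differentiable optimizer step
        outer_loss = F.cross_entropy(fstudent(real_x), real_y)
        outer_loss.backward()                     # backprop through all n_inner steps into lambda
    teacher_opt.step()                            # update synthetic pixels / generator weights
    return outer_loss.item()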

Note that the procedure of training the student on the resulting synthetic dataset can be carried out in different ways. The new data, parameterized with \(\lambda \), can be used as a single large batch, or it can be split into several smaller ones; this split can be useful for reducing the memory consumption per training step. Instead of randomly sampling the distilled objects, the authors of the original work propose attaching each of them to a specific batch, and these batches keep the same order in every epoch. In our paper, we use the same schemes. Let ic (input count) be the number of batches of the synthetic dataset; note that it must be a divisor of N. In our experiments we try the limit values \(ic=1\) and \(ic=10\). A small sketch of this fixed curriculum is given below.
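
The following hypothetical snippet illustrates the fixed curriculum described above: the synthetic dataset is pre-split into ic batches that are replayed in the same order for all N inner steps (names and structure are ours, not taken from the original code).

def synthetic_curriculum(synthetic_batches, n_steps):
    """Yield synthetic batches in a fixed order; ic = len(synthetic_batches) must divide n_steps (N)."""
    ic = len(synthetic_batches)
    assert n_steps % ic == 0, "ic must be a divisor of N"
    for step in range(n_steps):
        yield synthetic_batches[step % ic]  # the same batch appears at the same position in every epoch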

3 Implicit Differentiation

This method, suggested in [2], is based on the Implicit Function Theorem:

Theorem 1 (Cauchy, Implicit Function Theorem)

Let \(\frac{\partial \mathcal {L_S}}{\partial \theta }(\lambda , \theta ): \varLambda \times \varTheta \rightarrow \varTheta \) be a continuously differentiable function. Fix a point \((\lambda ^{'}, \theta ^{'})\) with \(\frac{\partial \mathcal {L_S}}{\partial \theta } (\lambda ^{'}, \theta ^{'}) = 0\). If the Jacobian matrix \(\frac{\partial ^2 \mathcal {L_S}}{\partial \theta ^2}\) is invertible, then there exists an open set \( U \subseteq \varLambda \) containing \(\lambda ^{'}\) and a unique continuously differentiable function \(\theta ^*: U \rightarrow \varTheta \) such that \( \theta ^*(\lambda ^{'}) = \theta ^{'} \quad \text {and}\quad \forall \lambda \in U, \quad \frac{\partial \mathcal {L_S}}{\partial \theta } (\lambda , \theta ^*(\lambda )) = 0. \) Moreover, the partial derivatives of \(\theta ^*\) in U are given by the matrix product:

$$\begin{aligned} \frac{\partial \theta ^*}{\partial \lambda } (\lambda ) = - \Bigg [\frac{\partial ^2 \mathcal {L_S}}{\partial \theta ^2} (\lambda , \theta ^{*}(\lambda ))\Bigg ]^{-1} \frac{\partial ^2 \mathcal {L_S}}{\partial \theta \partial \lambda } (\lambda , \theta ^{*}(\lambda )). \end{aligned}$$
(2)

So, if there were an efficient way to invert the matrix, we could simply use (2) after the student \(\theta \) has reached a local minimum, assuming \(\frac{\partial \mathcal {L_S}}{\partial \theta } (\lambda , \theta ^*(\lambda )) \approx 0\). But the inversion operation is costly in time, so the authors approximate the inverse with the first few terms of a Neumann series, controlling convergence with a hyperparameter \(\alpha \) (see (3)).

The resulting algorithm (see Algorithm 1) has no problems with memory consumption, since there is no need to store copies of the student \(\theta \). And, despite the several successive approximations, the experimental results show that the method has competitive performance (see Table 4). Note that grad in Algorithm 1 denotes the dot product between the Jacobian of the given function (func) at the given point (wrt) and a vector (vec). Another interesting detail of this method is that it does not depend on which optimizer is used to train the student, nor on the order (curriculum) of the batches of synthetic data, so in our paper we use only a single large batch of synthetic data. The original work [2] lacks a detailed description of the experimental results, so it can be found in our paper (see Sect. 6.3). We used an open-source code as the basis for our implementation of the method.

$$\begin{aligned} \Bigg [\frac{\partial ^2 \mathcal {L_S}}{\partial \theta ^2} (\lambda , \theta ^{*}(\lambda ))\Bigg ]^{-1} \approx \alpha \sum \limits _{j=0}^N&\Bigg [ I - \alpha \frac{\partial ^2 \mathcal {L_S}}{\partial \theta ^2}(\lambda , \theta ^{*}(\lambda ))\Bigg ]^j. \end{aligned}$$
(3)
Algorithm 1 (pseudocode figure not reproduced here).
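
As an illustration of (3) and Algorithm 1, the sketch below computes the approximate hypergradient \(\frac{\partial \mathcal {L^*_T}}{\partial \lambda }\) with Hessian-vector products in PyTorch. This is our own simplified rendition based on [2], not the original implementation; the defaults for alpha and n_terms are placeholders, and both losses are assumed to be computed on a graph connected to the parameter lists theta and lam.

import torch

def approx_hypergradient(L_T, L_S, lam, theta, alpha=0.01, n_terms=5):
    """Approximate dL_T/dlambda = -(dL_T/dtheta) H^{-1} d2L_S/(dtheta dlambda),
    with H^{-1} replaced by the truncated Neumann series of Eq. (3)."""
    v = torch.autograd.grad(L_T, theta, retain_graph=True)           # dL_T/dtheta
    dLS_dtheta = torch.autograd.grad(L_S, theta, create_graph=True)  # dL_S/dtheta, kept differentiable
    p = [alpha * vi for vi in v]                                     # j = 0 term of the series
    cur = list(v)
    for _ in range(n_terms):
        # Hessian-vector product H @ cur via a second backward pass
        Hv = torch.autograd.grad(dLS_dtheta, theta, grad_outputs=cur, retain_graph=True)
        cur = [ci - alpha * hvi for ci, hvi in zip(cur, Hv)]          # (I - alpha*H)^{j+1} v
        p = [pi + alpha * ci for pi, ci in zip(p, cur)]
    # mixed second derivative contracted with p, with the minus sign from Eq. (2)
    mixed = torch.autograd.grad(dLS_dtheta, lam, grad_outputs=p, retain_graph=True)
    return [-g for g in mixed]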

4 Gradient Matching

The gradient matching method (GM) was proposed in [3], and it solves a problem different from the general one (1). The main difference is that we want not only to train the student \(\theta \) to achieve good performance on real data but also to obtain a solution similar to the one it would reach if it were trained on real data. To formalize this, let \(D(\nabla _{\theta }\mathcal {L_S}, \nabla _{\theta }\mathcal {L_T})\) be a function measuring how close one gradient tensor is to another.

The distance function D is simply the sum (in our GTN experiments we use the mean) of cosine distances over the student's layers \(\theta ^l\). Let A and B be the gradient tensors with respect to a layer's parameters, and let i be the index of the output axis (e.g. the index of the output channel for a convolutional layer); \(A_i\) and \(B_i\) are the flattened gradient vectors corresponding to each output element indexed by i. The most interesting detail here is that the authors of [3] suggest updating \(\lambda \) after each step of the student's optimization, so we no longer need to wait until it reaches a local minimum, as before. The authors also propose not to store student copies and to minimize \(D\big (\nabla _{\theta } \mathcal {L_{S}}(\lambda , \theta _{t-1}), \nabla _{\theta } \mathcal {L_{T}}(\theta _{t-1})\big )\) for each step separately, so there is no backpropagation through \(opt_{\theta }\). Both of these proposals make the gradient matching method very computationally efficient.

$$\begin{aligned}&\lambda ^* = \mathop {\textrm{argmin}}_{\lambda } \mathbb {E}_{\theta _0 \sim P_{\theta _0}} \Big [ \sum \limits _{n=1}^{N-1} D\big (\nabla _{\theta } \mathcal {L_{S}}(\lambda , \theta _{n}), \nabla _{\theta } \mathcal {L_{T}}(\theta _{n})\big )\Big ],\quad \text {where: }\\&D(\nabla _{\theta }\mathcal {L_S}, \nabla _{\theta }\mathcal {L_T}) = \sum \limits ^L_{l=1} d(\nabla _{\theta ^l}\mathcal {L_S}, \nabla _{\theta ^l}\mathcal {L_T}), \quad d(A, B) = \sum ^{\text {dim(A)}}_{i=1} \Bigg ( 1 - \frac{A_i \cdot B_i}{\Vert A_i\Vert \Vert B_i\Vert } \Bigg )\nonumber \end{aligned}$$
(4)
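A minimal sketch of the distance in (4), assuming the two gradient lists were obtained from the same student on a synthetic and a real batch, respectively (the function names here are ours):

import torch

def layer_distance(A, B, eps=1e-8):
    """d(A, B) from Eq. (4): A and B are the gradients of one layer, grouped per output element."""
    A = A.reshape(A.shape[0], -1)                 # one flat vector A_i per output element
    B = B.reshape(B.shape[0], -1)
    cos = (A * B).sum(dim=1) / (A.norm(dim=1) * B.norm(dim=1) + eps)
    return (1.0 - cos).sum()

def matching_loss(grads_syn, grads_real):
    """D(., .): sum (mean for our GTN experiments) of per-layer distances."""
    return sum(layer_distance(gs, gr) for gs, gr in zip(grads_syn, grads_real))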

The peculiarity of this loss function is that the gradient of one synthetic object depends on the other objects from the same batch, because of the normalization operation in the d term of (4). This makes the optimization problem harder and can cause negative effects (see Table 2), so the authors decided to distill objects separately for each class. Note that gradient matching is independent of the student's training optimization algorithm; the only assumption is that the update direction is based on the gradient. Another aspect is that the curriculum (the order of the synthetic batches in the student's learning procedure) can be learned with this distillation method. We used an open-source code as the implementation of this method.

5 Generative Teaching Network

The idea first appeared in [4], where the authors suggested using a generator as the teacher \(\lambda \). The input of the generator is a concatenation of noise and a one-hot encoded label (for conditional generation). In the original paper, the authors use backpropagation through the student's learning process to train the generator, which is inconvenient for practical use due to its high memory consumption; in our paper, we show that the same or even better results can be achieved more efficiently by using gradient matching or implicit differentiation. Experimental results in [4] show that using a generator can help to improve the students' performance. In our paper, we check whether we can improve distillation performance using larger generators. Note that the size of the generator in our experiments is controlled by the hyperparameter k (see Fig. 1). The generator consists of two linear layers and two convolutional layers: the output size of the first linear layer is k, the output size of the second is \(\lfloor k/2 \rfloor \times \text {width} \times \text {height of the picture}\), and \(\lfloor k/4\rfloor \) is the number of output channels of the first convolution. Hereinafter, unless otherwise indicated, we use the following notation: DD (Data Distillation) is distillation in which the teacher parameters \(\lambda \) are the pixels of the synthetic images, and GTN is distillation using a generator. Note that the generator has two modes: GTN-rnd is a generator with random noise as input, and GTN-lrn is a generator with a learned input.

Fig. 1. Generator's architecture; k is a hyperparameter that controls the network's size, and \(d=64\) is the size of the generator's input.
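
The sketch below translates the description above and Fig. 1 into code; the kernel sizes, activations and output non-linearity are our assumptions, since they are not fully specified here.

import torch
import torch.nn as nn

class Generator(nn.Module):
    """A GTN-style generator: two linear layers followed by two convolutions (sizes as in Fig. 1)."""
    def __init__(self, k=64, d=64, num_classes=10, img_size=28):
        super().__init__()
        self.k, self.img_size = k, img_size
        self.fc1 = nn.Linear(d + num_classes, k)                          # first linear layer: output size k
        self.fc2 = nn.Linear(k, (k // 2) * img_size * img_size)           # second: floor(k/2) x H x W
        self.conv1 = nn.Conv2d(k // 2, k // 4, kernel_size=3, padding=1)  # floor(k/4) output channels
        self.conv2 = nn.Conv2d(k // 4, 1, kernel_size=3, padding=1)       # one-channel MNIST image
        self.act = nn.LeakyReLU(0.1)

    def forward(self, noise, one_hot_labels):
        x = torch.cat([noise, one_hot_labels], dim=1)                     # conditional input
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x)).view(-1, self.k // 2, self.img_size, self.img_size)
        x = self.act(self.conv1(x))
        return torch.tanh(self.conv2(x))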

6 Experiments

6.1 Distillation with Time Limit

Neural architecture search (NAS) is one of the most promising applications of distillation, and it is important to note that the time spent on distillation should be added to the time spent on the NAS; this point was also raised in the review of [4]. So, in this section, we check the performance of all known distillation methods, and we consider it fair to distill the data with every method for the same limited time. We chose a time limit of \(\approx \)15 min, based on common sense and the time spent on the NAS in similar experiments [3]. Note that this limit may not be exact, as distillation takes an integer number of steps, each of which takes a non-deterministic time. To check performance we use the following scheme. First we train the teacher \(\lambda \) with three restarts, with the number of steps determined by the time limit indicated above. Then, to get the final results, we train five randomly initialized students \(\theta \) for each of the three teachers; each student's training takes 1000 optimization steps. In our work we use the MNIST [9] benchmark and make the same preparations as in [4]: we hold out part of the training data for validation (10 thousand images) and use it to select the best teacher hyperparameters. We use a batch size \(|\mathcal {B^T}|=256\) for the training data. For most of our experiments we use ConvNet [12] as the student. As the student's optimizer we use SGD with momentum with the same parameters as suggested in [3], and we use the same teacher optimizers as in the original papers [1, 3, 4]. The volume of synthetic data is controlled by the ipc (images per class) parameter. In each table of this paper, the largest numbers in a column are shown in bold. A sketch of this evaluation protocol is given below.
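
The evaluation protocol above can be summarized with the following sketch; distill, make_student, train_student and evaluate are hypothetical placeholders for the routines compared in Table 1.

import statistics

def evaluate_method(distill, make_student, train_student, evaluate,
                    n_teachers=3, n_students=5, n_steps=1000):
    accuracies = []
    for _ in range(n_teachers):
        teacher = distill(time_limit_minutes=15)        # one restart of the distillation procedure
        for _ in range(n_students):
            student = make_student()                    # fresh random initialization
            train_student(student, teacher, n_steps)    # 1000 optimization steps on synthetic data
            accuracies.append(evaluate(student))        # test accuracy
    return statistics.mean(accuracies), statistics.stdev(accuracies)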

Table 1. The mean and standard deviation of test accuracy for different distillation algorithms.

Table 1 shows the mean and standard deviation of the test accuracy reached by students trained on the distilled data. Note that there is only one difference from previous works: we impose a time limit on each distillation procedure, which causes a degradation in performance. For this experiment, we use \(K=1000, N=10\) as the default hyperparameter values. To check the memory consumption we use a special tool that measures GPU memory usage. Note that the unroll distillation procedure consumes the most memory. The third column shows the number of teacher parameters; although GTN (\(k=64\)) is twice as large as DD, there is not much difference in memory usage.

6.2 Training Generator with Gradient Matching

In this section we explore the use of gradient matching to train the teacher generator. We first examine the hyperparameters of this distillation method: N controls the frequency of the student's reinitialization, and \(\zeta _{\theta }\) controls the speed at which the teacher's parameters are updated. Figure 2 (a–d) shows a non-trivial relationship between performance and the hyperparameter choice. We assume that such a dependence can be caused by the time limit and the fact that increasing the values of these hyperparameters may lead to slower convergence. Note that in previous works [1, 3, 4], where no time limit was used, increasing ipc always resulted in better performance.

Fig. 2. Dependence of the student's performance on the hyperparameters of the distillation procedure. The following parameters are used as defaults: \(ipc=10, ic=1, N=10, \zeta _\theta =10, k=64\).

Table 2. Mean and standard deviation of test accuracy for different distillation algorithms.
Table 3. Mean and standard deviation of test accuracy for different distillation algorithms.

Figure 2 (e) shows that fixing the generator input is really important for gradient matching distillation, because teacher training (optimization of \(\lambda \)) diverges when random input is used. Another important aspect mentioned above is that the gradient must be computed per class: Table 2 shows the results with and without per-class gradients, and per-class distillation gives significantly better results. Figure 2 (f) shows the accuracy achieved with data distilled with generators of different sizes (marked with different k) and without a generator (DD); this plot depicts the dependence of the student's test-set performance on the number of synthetic images per class (ipc). It seems that choosing the right size of the generator gives better performance. More detailed results can be found in Tables 2 and 3. For the experiment in Table 2, we use \(ipc=10,~ic=1,~N=10,~K=110,~\zeta _\theta =10\) and \(k=64\) for GTN as the default hyperparameter values. For the experiment in Table 3, we use \(k=64, ipc=50, K=35, N=10,\text { and}~\zeta _\theta =10\). Tables 2 and 3 also show the GPU memory usage. It seems that ipc has a greater impact on memory usage than k, which is another benefit of using GTN. Note that the memory usage can be reduced by changing the ic value so that the synthetic images are optimized in smaller batches, although such a change can slow down convergence.

6.3 Distillation with Implicit Differentiation

Fig. 3. The relation between the distillation method's hyperparameters and test performance. We use the following defaults: \(ipc=10, N=10, \zeta _\theta =10,\text { and } k=64\).

The method was proposed in [2], and we abbreviate it as IFT (Implicit Function Theorem). As mentioned above (see Sect. 3), there is no detailed description of the results in the original paper, so they can be found in this section. Figure 3 (a–c) shows the relationship between the hyperparameters of the distillation method and the student's test performance. We assume that these results can be explained by the fact that increasing the values of these hyperparameters decreases the frequency of \(\lambda \) updates, which negatively affects the performance; the only exception is \(\zeta _{\theta }\).

Figure 3 (d) shows the results of distillation using a generator with random input (GTN-rnd). Such a generator can produce as much data as needed, but it does not converge when trained with gradient matching. It seems that such distillation becomes possible with implicit differentiation.

Table 4 shows the best results for each method. For this experiment, we use \(K=1080, \zeta _\theta =50, ipc=10,\text { and }~N=10\) as the default hyperparameter values. The performance seems to be the same or even better compared to backpropagation through the training procedure (unroll, see Table 1); note the difference in memory usage between the two tables. Also note that implicit differentiation distillation is inferior to gradient matching distillation.

We think this may be connected with the difference in the frequency of \(\lambda \) updates: to perform one update using IFT, we first have to train the student, which is not needed in the case of GM. It is also important to note that this method is very sensitive to \(\alpha \) and \(\zeta _{\theta }\), and in some DD cases it starts to diverge after several iterations. Meanwhile, the use of GTN makes the procedure more stable and yields a more generalizable dataset (see Table 6).

Fig. 4. Synthetic images for the MNIST classification task obtained with different distillation methods: a) GM+DD, b) IFT+DD, c) GM+GTN-lrn, d) IFT+GTN-lrn, e) GM+GTN-rnd, f) IFT+GTN-rnd. We use the same hyperparameters as mentioned in Table 5. Hyperparameters for GM+GTN-rnd are described in the caption of Table 4.

Table 4. Mean and standard deviation of test accuracy for different distillation algorithms.

Figure 4 shows part of the final synthetic dataset for GM (see a, c and e) and IFT (see b, d and f). The greatest difference is obtained when the data are distilled without a generator (see a, b). Synthetic data obtained using implicit differentiation look less realistic and therefore could be useful for federated learning [13]. Also note that the images distilled using a generator have higher contrast.

6.4 Distillation with Augmentation

In previous works, augmentation has been used in different ways. In [4] it takes place during distillation (we call it train augmentation) by applying transformations to the real images \(\mathcal {B^T}\). In [1, 3] it is used when training the student on synthetic data (we call it test augmentation). In our study, we decided to compare these augmentation techniques. Table 5 shows the test performance for various distillation and augmentation techniques. It seems that for the MNIST classification problem only test augmentation gives improvements (see Tables 2, 3 and 4). To augment images we use random crop and rotation. For this experiment, we use \(K=1080, ipc=10,~\zeta _\theta =10,\text { and }~N=10\) as the default hyperparameter values.
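
For reference, a possible implementation of this augmentation in torchvision is shown below; the crop padding and rotation angle are illustrative choices, not necessarily the exact values used in our experiments.

from torchvision import transforms

# applied to 28x28 MNIST images (real images for train augmentation,
# synthetic images for test augmentation)
augmentation = transforms.Compose([
    transforms.RandomCrop(28, padding=2),   # random shift via padded crop
    transforms.RandomRotation(15),          # random rotation up to +/-15 degrees
])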

Table 5. The mean and standard deviation of test accuracy for different distillation algorithms and different augmentations.

6.5 Generalizability

The generalization problem of distilled data was first mentioned in [1] and then studied in [4] and [7].

Table 6. The mean and standard deviation of test accuracy for different distillation algorithms and student’s architectures.

The problem is that such data cannot guarantee convergence for students that did not participate in the distillation procedure. This problem is of great importance, since the main practical use of synthetic data is the NAS. For this experiment, we use \(K=1080, ipc=10,~\zeta _\theta =10,\text { and }~N=10\) as the default hyperparameter values. Table 6 shows the results of students with different architectures trained on data distilled with different methods. For distillation we used the ConvNet student architecture, and all results were obtained with test augmentation. It seems that the best generalizability is obtained using GTN with GM. For a comparison with ConvNet, see the second column of Table 5.

7 Conclusion

This work explores the latest ideas in the dataset distillation field suggested in [1,2,3,4]. We fairly compared the performance of all known methods while limiting their running time. We also proposed new methods based on the joint use of generators and memory-efficient distillation techniques. Experiments with the MNIST benchmark show that selecting the correct size for the generator achieves better performance for gradient matching distillation and improves the generalizability of implicit differentiation distillation. This paper also presents results on the impact of augmentation on distillation. We also provide a detailed description of the experimental results for implicit differentiation distillation, as we could not find them in the original work [2]. As future work, we would like to experiment with more diverse datasets and architectures. We also want to improve the generalization ability of the distilled data using stochastic depth networks [11], and we are interested in experiments that bring the distribution of synthetic objects closer to the original one.