1 Introduction

Over the past decade, deep neural networks (DNNs) have become the de facto standard approach for most, if not all, computer vision tasks, yielding unprecedentedly promising results. Because DNN training is time- and resource-consuming, many developers have generously released their pretrained models online, so that users may adopt these models in a plug-and-play manner without training from scratch. Nevertheless, pretrained DNNs often come with heavy architectures, making them cumbersome to deploy in real-world scenarios, especially in resource-critical applications such as edge computing. Numerous endeavors have thus been made towards reducing the sizes of DNNs, among which one mainstream scheme is Knowledge Distillation (KD). The goal of KD is to “distill” knowledge from a large pretrained model, known as the teacher, into a compact model, known as the student. The derived student is expected to master the expertise of the teacher at a much smaller size, making it applicable to edge devices. Since the seminal work of [20], a series of KD approaches have been proposed to strengthen the performance of student models [47, 51, 66].

Fig. 1. Illustration of (top) 3 types of Knowledge Distillation and (bottom) our proposed Knowledge Factorization. (a) Single-Task Learning to Single-Task Learning (STL2STL) KD distills a single-task student from a single-task teacher, (b) Multi-Task Learning to Multi-Task Learning (MTL2MTL) KD distills a multi-task student from a multi-task teacher, and (c) Sub-Knowledge Distillation distills a subset of the teacher’s knowledge to its student model.

Despite the encouraging results achieved, KD has largely been treated as a black-box procedure, in which the intrinsic knowledge flow remains opaque. Consequently, the derived student model may inherit the teacher’s task-wise competence but lacks interpretability, since it is unclear how and what knowledge has been transferred to the student. In addition, as demonstrated in Fig. 1(a) and (b), conventional KD assumes that teacher and student models master homogeneous tasks or knowledge, which greatly limits its applicability. Even if it is allowed to distill a subset of knowledge from the teacher, as shown in Fig. 1(c), the problem setup of KD, by nature, overlooks the scalability of the student. For example, given a versatile classification teacher pretrained on ImageNet, if we are to learn two students, one handling cat-dog classification and one handling cat-fish classification, we have to carry out KD twice; if, however, we are to learn all k-class classification students, for every k, from a pool of 1,000 classes, we have to conduct KD \(\sum _{k=1}^{1000} \left( {\begin{array}{c}1000\\ k\end{array}}\right) = 2^{1000}-1\) times, which is computationally intractable.
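The size of this count can be checked directly. The following is a minimal Python sketch (the function name is ours and purely illustrative) that sums the binomial coefficients over all non-empty class subsets, each of which would require its own distillation run.

```python
import math

def num_subset_students(n_classes: int) -> int:
    """Number of non-empty class subsets, i.e. sum_{k=1}^{n} C(n, k)."""
    return sum(math.comb(n_classes, k) for k in range(1, n_classes + 1))

n = 1000
count = num_subset_students(n)
# The binomial theorem gives sum_{k=0}^{n} C(n, k) = 2^n,
# so dropping the empty subset (k = 0) leaves 2^n - 1.
assert count == 2**n - 1
print(len(str(count)))  # number of decimal digits in the count
```

Even for a few dozen classes the count already exceeds what any training budget could cover, which is the scalability issue raised above.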

In this paper, we introduce a novel task, termed Knowledge Factorization (KF), that alleviates the aforementioned flaws of KD at the problem-setup level. The core idea of KF concerns the modularization and assemblability of knowledge: given a pretrained teacher, KF decomposes it into several factor networks, each of which masters one specific piece of knowledge factorized from the teacher while remaining disentangled from the others. Moreover, these factor networks are expected to be readily integratable, meaning that we may directly assemble multiple factor networks, without any fine-tuning, to produce a more competent multi-talented network. As shown in Fig. 1(d), these factor networks can be organized into an open-source model hub, and users can treat them as Lego-brick-like units of knowledge to build customized networks in a plug-and-play fashion, which gives KF great scalability. Furthermore, the disentanglement property effectively enables the IP protection of network knowledge: since the factor networks are learned in a disentangled manner, they possess only task-specific knowledge, allowing network owners to selectively conduct knowledge transfer without leaking knowledge of other tasks.

Admittedly, the aims of KF are ambitious, since the factor networks are, again, expected to be modularized and readily integratable, and at the same time knowledge-wise disentangled and hence more interpretable. Notably, despite being orthogonal in expertise, these factor networks inherit the common knowledge shared by all tasks. As such, each factor network should be designed to account for both the task-agnostic commonality and its task-relevant specialization, which in turn reduces the overall parameter overhead of KF. As demonstrated in Fig. 1, given n types of knowledge, sub-KD requires an exponential number of \(2^n\) models, each with S parameters, while KF reduces the number of models to a linear scale, with one full-sized common-knowledge model and n mini models, each with s parameters, where \(s\ll S\).
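To make the comparison concrete, the sketch below tabulates the two budgets stated above: \(2^n\) models of S parameters for sub-KD versus one model of S parameters plus n models of s parameters for KF. The sizes used in the example (S = 11M, s = 1M, n = 10) are hypothetical values chosen only for illustration, not figures from our experiments.

```python
def sub_kd_budget(n: int, S: int) -> tuple[int, int]:
    """(model count, total parameters) when each knowledge subset gets its own student.

    Follows the count of 2^n models, each with S parameters, stated in the text.
    """
    models = 2**n
    return models, models * S

def kf_budget(n: int, S: int, s: int) -> tuple[int, int]:
    """(model count, total parameters) for one common network plus n mini networks."""
    return n + 1, S + n * s

# Hypothetical sizes, for illustration only.
n, S, s = 10, 11_000_000, 1_000_000
print("sub-KD:", sub_kd_budget(n, S))
print("KF:    ", kf_budget(n, S, s))
```

The KF budget grows linearly in n, so adding a task costs only one additional mini model.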

To this end, we propose a dedicated scheme for conducting KF that comprises two mechanisms, namely structural factorization and representation factorization.

  • Structural Factorization. Structural factorization decomposes the teacher network into a set of factor networks with different functionalities. Each factor network comprises a shared common-knowledge network (CKN) and a task-specific network (TSN). CKN extracts task-agnostic representations to capture the commonality among tasks, whereas the TSN accounts for task-specific information. Factor networks are trained to specialize in an individual task via fusing task-agnostic and task-specific knowledge.

  • Representation Factorization. Representation factorization disentangles the shared knowledge and task-level representations into statistically independent components. For this purpose, we introduce a novel information-theoretic objective, termed InfoMax Bottleneck (IMB). It maximizes the mutual information between the input and the common features to encourage lossless information transmission in the CKN. Meanwhile, IMB minimizes the data-task mutual information to ensure that the task features are only predictive for a specific task. We derive a variational lower bound for IMB so that this objective can be optimized in practice.

By integrating both mechanisms, we demonstrate in the experiments that KF indeed achieves architecture-level and representation-level disentanglement. Different from KD that transmits holistic knowledge in a black-box manner, KF offers unique interpretability for the factor networks through the knowledge transfer. Moreover, the learned common-knowledge representations facilitate the transfer learning to unseen downstream tasks, as will be verified empirically in our experiments.

Our contributions are therefore summarized as follows:

  • We introduce a novel knowledge-transfer task, termed Knowledge Factorization (KF), which accounts for learning factor networks that are modularized and interpretable. Factor networks are expected to be readily integratable, without any retraining, to assemble multi-task networks.

  • We propose an effective solution towards KF. Our approach decomposes a pretrained teacher into factor networks that are task-wise disentangled.

  • We design an InfoMax Bottleneck objective to disentangle the representation between common knowledge and the task-specific representations, by exerting control over the mutual information between input and representations. We derive its variational bound for its numerical optimization.

  • Our method achieves strong performance and disentanglement capability across various benchmarks, with better modularity and transferability.

2 Related Work

Knowledge Distillation. Knowledge distillation (KD) [20] refers to the process of transferring knowledge from one model or an ensemble of models to a student model. KD was originally designed for model compression [5, 31, 36, 50, 55, 63], but it has been found to be beneficial in other tasks such as adversarial defense [46], domain adaptation [15, 43], continual learning [32, 67] and amalgamating the knowledge from multiple teachers [23, 38, 64]. Different from common KD methods that disseminate knowledge as a whole, we factorize the knowledge of a multi-talented teacher into factor networks with disentangled representations.

Disentangled Representation Learning. It is often assumed that real-world observations are controlled by a set of underlying factors. Therefore, a recent line of research argues for the importance of finding disentangled variables in representation learning [4, 13, 35, 44, 48, 62] while providing invariance in learning [1, 14, 22]. Disentanglement is usually achieved through adversarial learning [10, 34, 40, 58] or variational auto-encoders [7, 19, 26]. In this work, we aim to disentangle the task-agnostic and task-related representations by optimizing the mutual information.

InfoMax Principle and Information Bottleneck. As one of the foundations of machine learning, information theory has inspired a series of learning algorithms. InfoMax [33] is a core principle of representation learning that encourages maximizing the mutual information between multiple views or between the representation and the input. This principle gave rise to the recent trend of self-supervised learning [2, 21, 59] and contrastive learning [9, 16, 17, 25, 45, 56]. In contrast, Information Bottleneck (IB) [57] aims to compress the representation while preserving the information relevant to the target. In this study, we take a unified view of the two principles in multi-task learning: InfoMax guarantees the learning of common knowledge across tasks, while IB promotes task-specific knowledge for an individual task.

Multi-task Learning. Multi-task learning (MTL) is designed to train models that handle multiple tasks by taking advantage of the common information among tasks. Some recent solutions explore the decomposition between shared and task-specific processing [24, 39, 68]. Unlike conventional methods, we decompose a pretrained model into knowledge modules according to tasks.

3 Method

The essence of this work is to factorize a multi-task teacher into independent students by imposing fine-grained control over the information shared between teacher and students. Figure 2 provides an overall sketch of our proposed KF. In what follows, we first give a definition of knowledge factorization, and then introduce the general procedure to decompose a teacher into factorized students.

Fig. 2. The overall framework of the proposed knowledge factorization. The factor networks are trained to mimic the prediction of the teacher. The CKN learns to maximize the mutual information between input and its features, whereas the TSNs are dedicated to minimizing the task-wise mutual information.

3.1 Knowledge Factorization in Neural Network

We define Knowledge Factorization (KF) to be the process of subdividing a teacher network into multiple factor networks, each of which possesses distinctive knowledge to handle one task. Formally, assume we have a multi-task dataset \(\mathcal {D}=\{(\textbf{x}_i,y_i^{1},\dots ,y_i^{K})\}\), where each input sample \(\textbf{x}\) is associated with K different labels \(\{y^j\}_{j=1}^K\), sampled from the joint probability \(P(X, Y_1, \dots , Y_K)\). With a loose definition, we also deem multi-class classification a special case of multi-tasking, by considering each category or group of categories as a task. Given a multi-task teacher model \(\mathcal {T}\) that is able to predict K tasks simultaneously, KF aims to construct K factor networks \(\{\mathcal {S}_j\}_{j=1}^K\), each of which, again, tackles one task independently.

Specifically, we focus on decomposing the teacher knowledge into task-specific and common representations, meaning that each factor network not only masters task-specific knowledge, but also benefits from a shared common feature to make final predictions. To this end, we design two mechanisms to factorize knowledge: structural factorization to decompose the teacher network into a set of factor networks, as well as representation factorization to disentangle the common features from task-specific features by optimizing mutual information.

3.2 Structural Factorization

The goal of structural factorization is to endow different sub-networks with functional distinctions. Each factor network is expected to inherit only a portion of the knowledge from the teacher and to specialize in an individual task. Specifically, a factor network \(\mathcal {S}_j\) for the j-th task comprises two modular networks: a Common Knowledge Network (CKN) \(\mathcal {S}_{C}(\cdot ;\varTheta _{\mathcal {S}_{C}})\), which is shared across all tasks, and a Task-specific Network (TSN) \(\mathcal {S}_{T_j}(\cdot ;\varTheta _{\mathcal {S}_{T_j}})\), which is task-exclusive. \(\varTheta _{\mathcal {S}_{C}}\) and \(\varTheta _{\mathcal {S}_{T_j}}\) are the model parameters of the CKN and the TSN, respectively. For each input sample, \(\mathcal {S}_{C}\) is adopted to extract the task-agnostic feature \(\textbf{z}\):

$$\begin{aligned} \textbf{z} = \mathcal {S}_{C}(\textbf{x};\varTheta _{\mathcal {S}_{C}}). \end{aligned}$$
(1)

In contrast, \(\mathcal {S}_{T_j}\) learns the task-related knowledge \(\textbf{t}^j\) from the input \(\textbf{x}\), which together with \(\textbf{z}\) is processed by a task head \(\mathcal {H}_{j}\) to make the final prediction:

$$\begin{aligned} \textbf{t}^j = \mathcal {S}_{T_j}(\textbf{x};\varTheta _{\mathcal {S}_{T_j}}); \hat{y}_S^j = \mathcal {H}_{j}(\textbf{z}, \textbf{t}^j;\varTheta _{\mathcal {H}_{j}}), \end{aligned}$$
(2)

which constrains each factor network \(\mathcal {S}_j\) to share the same common knowledge network but maintain the task-specific one to handle different tasks.

Intuitively, we expect that \(\mathcal {S}_j\) only masters the knowledge about task j by using the common knowledge \(\textbf{z}\) and the task-specific knowledge \(\textbf{t}^j\). We accordingly define a structural factorization objective \(\mathcal {L}_{sf}^{(j)}\) to enforce each single-task factor network to imitate the teacher’s prediction while minimizing the supervised loss:

$$\begin{aligned} \mathcal {L}_{sf}^{(j)} = \mathcal {L}^{(j)}_{\text {sup}} + \lambda _{\text {kt}}\mathcal {L}^{(j)}_{\text {kt}}, \end{aligned}$$
(3)

where \(\mathcal {L}^{(j)}_{\text {sup}}\) and \(\mathcal {L}^{(j)}_{\text {kt}}\) denote the supervised loss and the knowledge transfer loss for the j-th task, respectively, and \(\lambda _{\text {kt}}\) is the weight coefficient. Notably, we may readily adopt various implementations for each of the loss terms here. For example, \(\mathcal {L}^{(j)}_{\text {sup}}\) may take the form of L2 norm for regression and cross-entropy for classification, while \(\mathcal {L}^{(j)}_{\text {kt}}\) may take the form of soft-target [20], hint-loss [51], or attention transfer [66]. More details can be found in the supplement.
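As one possible instantiation of Eq. 3 for a classification task, the sketch below combines a cross-entropy supervised loss with a soft-target transfer loss [20]. It is an illustrative, single-sample version in plain Python rather than our training code: the function names, the temperature of 4.0 and the \(T^2\) scaling are assumptions, while \(\lambda _{\text {kt}}\) = 0.1 matches the default reported in Sect. 4.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax with an optional temperature."""
    m = max(logits)
    exps = [math.exp((v - m) / temperature) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(student_logits, label):
    """Supervised loss L_sup for one classification sample."""
    return -math.log(softmax(student_logits)[label])

def soft_target_loss(student_logits, teacher_logits, temperature=4.0):
    """Transfer loss L_kt: KL(teacher || student) on temperature-softened outputs."""
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s))
    return (temperature ** 2) * kl  # T^2 scaling is a common convention (assumption)

def structural_factorization_loss(student_logits, teacher_logits, label, lambda_kt=0.1):
    """L_sf = L_sup + lambda_kt * L_kt (Eq. 3) for a single sample of task j."""
    l_sup = cross_entropy(student_logits, label)
    l_kt = soft_target_loss(student_logits, teacher_logits)
    return l_sup + lambda_kt * l_kt

print(structural_factorization_loss([2.0, 0.0], [0.0, 2.0], label=0))
```

When the student reproduces the teacher's logits exactly, the transfer term vanishes and only the supervised loss remains.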

Structural factorization therefore enables us to construct new combined-task models by assembling multiple networks without retraining. If, for example, a 3-category classifier is needed, we can readily integrate the CKN and the corresponding 3 TSNs from the pre-defined network pool. This property, in turn, greatly improves the scalability of the model.
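The assembly step can be sketched as follows. This is a toy illustration under stated assumptions: the CKN, the TSNs and the task heads are stand-in Python callables, and the `assemble` helper and the pool layout are hypothetical names, not part of a released API. The point is only that the shared feature \(\textbf{z}\) is computed once and each selected task contributes its own TSN and head, with no retraining involved.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Vector = List[float]
TSN = Callable[[Vector], Vector]
Head = Callable[[Vector, Vector], float]

def assemble(ckn: Callable[[Vector], Vector],
             pool: Dict[str, Tuple[TSN, Head]],
             tasks: Sequence[str]) -> Callable[[Vector], Dict[str, float]]:
    """Build a multi-task predictor from a pool of (TSN, head) pairs."""
    def predict(x: Vector) -> Dict[str, float]:
        z = ckn(x)  # task-agnostic feature (Eq. 1), shared by all selected tasks
        out = {}
        for name in tasks:
            tsn, head = pool[name]
            out[name] = head(z, tsn(x))  # task-specific feature and head (Eq. 2)
        return out
    return predict

# Toy stand-ins for trained networks (illustrative only).
ckn = lambda x: [sum(x)]
pool = {
    "cat": (lambda x: [x[0]], lambda z, t: z[0] + t[0]),
    "dog": (lambda x: [x[1]], lambda z, t: z[0] + t[0]),
    "fish": (lambda x: [x[2]], lambda z, t: z[0] + t[0]),
}

cat_dog = assemble(ckn, pool, ["cat", "dog"])
cat_fish = assemble(ckn, pool, ["cat", "fish"])
print(cat_dog([1.0, 2.0, 3.0]))
print(cat_fish([1.0, 2.0, 3.0]))
```

Both assembled predictors reuse the same "cat" module, which is what removes the need for a separate distillation run per task combination.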

3.3 Representation Factorization

Apart from the functionality disentanglement, we hope that the learned representations of the factor networks are statistically independent as well, so that each sub-network masters task-wise disentangled knowledge. This means that task-specific features should contain only the minimal information related to a certain task, while the common representation contains as much information as possible.

To this end, we introduce the InfoMax Bottleneck (IMB) objective to optimize the mutual information (MI) between features and input. For two random variables X and Y, the MI \(\mathcal {I}(X,Y)\) quantifies the amount of information that X conveys about Y, and is defined as the Kullback-Leibler (KL) divergence between the joint probability \(p(\boldsymbol{x}, \boldsymbol{y})\) and the product of the marginal distributions \(p(\boldsymbol{x})p(\boldsymbol{y})\):

$$\begin{aligned} \begin{aligned} \mathcal {I}(X,Y)&= D_{KL}\Big [p(\boldsymbol{x},\boldsymbol{y})||p(\boldsymbol{x})p(\boldsymbol{y})\Big ]. \end{aligned} \end{aligned}$$
(4)
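For discrete variables, Eq. 4 can be evaluated exactly from a joint probability table. The short sketch below (function and variable names are illustrative) shows the two extremes: independent variables give zero MI, while a variable that is a copy of the other gives MI equal to its entropy.

```python
import math

def mutual_information(joint):
    """I(X;Y) in nats for a table joint[i][j] = p(x_i, y_j)."""
    px = [sum(row) for row in joint]          # marginal p(x)
    py = [sum(col) for col in zip(*joint)]    # marginal p(y)
    mi = 0.0
    for i, row in enumerate(joint):
        for j, pxy in enumerate(row):
            if pxy > 0:  # terms with p(x, y) = 0 contribute nothing
                mi += pxy * math.log(pxy / (px[i] * py[j]))
    return mi

independent = [[0.25, 0.25], [0.25, 0.25]]  # p(x, y) = p(x) p(y)
copy = [[0.5, 0.0], [0.0, 0.5]]             # Y is a copy of X

print(mutual_information(independent))
print(mutual_information(copy), math.log(2))
```

For continuous, high-dimensional features such a table is not available, which is why a variational bound is needed in practice.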

In our problem, for each input sample \(\textbf{x}\sim P(X)\), we compute its common knowledge feature \(\textbf{z}\sim P(Z)\) and the task-predictive representation \(\textbf{t}^j \sim P(T_j)\). Ultimately, IMB attempts to maximize \(\mathcal {I}(X, Z)\), so that the common knowledge keeps as much information about the input as possible, while minimizing \(\mathcal {I}(X, T_j)\), so that the task representations only preserve information related to the task. The representation disentanglement can then be formulated as an optimization problem:

$$\begin{aligned} \begin{aligned} \max \mathcal {I}(T_j, Y_j); \quad \text {s.t. } \mathcal {I}(X,T_j) \le \epsilon _{1}, -\mathcal {I}(X, Z) \le \epsilon _{2}, \end{aligned} \end{aligned}$$
(5)

where \(\epsilon _1\) and \(\epsilon _2\) are the information constraints we define. In order to solve Eq. 5, we introduce two Lagrange multipliers \(\alpha>0\) and \(\beta >0\) to construct the function:

$$\begin{aligned} \mathcal {L}^{(j)}_I = \mathcal {I}(T_j, Y_j) + \alpha \mathcal {I}(X, Z) - \beta \mathcal {I}(X, T_j). \end{aligned}$$
(6)

By maximizing the first term \(\mathcal {I}(T_j, Y_j)\), we ensure that the task representation \(\textbf{t}^j\) is capable of accomplishing the individual task j. The \(\mathcal {I}(X, Z)\) term encourages lossless transmission of information and high-fidelity feature extraction in the CKN, while minimizing \(\mathcal {I}(X, T_j)\) ensures that only the task-informative representation is extracted by the TSN, thus decorrelating the task knowledge \(\textbf{t}^j\) from the common knowledge \(\textbf{z}\). Unlike the conventional information bottleneck (IB) principle [57], our proposed IMB attempts to maximize \(\mathcal {I}(X, Z)\) [21, 37, 45], so that the CKN learns a general representation \(\textbf{z}\) with high fidelity.

3.4 Variational Bound for Mutual Information

Due to the difficulty of estimating mutual information for continuous variables, we derive a variational lower bound to approximate the exact IMB objective:

$$\begin{aligned} \begin{aligned} \hat{\mathcal {L}_I}&= \mathbb {E}_{p(\textbf{y}_j, \textbf{t}_j)}[\log q(\textbf{y}_j| \textbf{t}_j)] + \alpha \big (\mathbb {E}_{p(\textbf{z}, \textbf{x})}[\log q(\textbf{z}| \textbf{x})] + H(Z)\big ) - \beta \mathbb {E}_{p(\textbf{t}_j)}\Big [D_{KL}[p(\textbf{t}_j|\textbf{x})||q(\textbf{t}_j)]\Big ], \end{aligned} \end{aligned}$$
(7)

where \(D_{KL}\) denotes the KL divergence between two distributions and \(q(\cdot )\) denotes the variational distributions. We claim that \(\mathcal {L}_I \ge \hat{\mathcal {L}_I}\), with equality achieved if and only if \(q(\textbf{y}_j| \textbf{t}_j) = p(\textbf{y}_j| \textbf{t}_j)\), \(q(\textbf{z}| \textbf{x}) = p(\textbf{z}| \textbf{x})\) and \(q(\textbf{t}_j) = p(\textbf{t}_j)\).

For better understanding, we explain the meaning of each term, specify the parametric forms of variational distribution and implementation details of Eq. 7.

Term 1. We maximize \(\mathcal {I}(T_j, Y_j)\) by maximizing its lower bound \(\mathbb {E}_{p(\textbf{y}_j, \textbf{t}_j)}[\log q(\textbf{y}_j| \textbf{t}_j)]\). We set \(q(\textbf{y}_j|\textbf{t}_j)\) to a Gaussian for regression tasks and to a multinomial distribution for classification tasks. Under this assumption, maximizing \(\mathbb {E}_{p(\textbf{y}_j, \textbf{t}_j)}[\log q(\textbf{y}_j| \textbf{t}_j)]\) amounts to minimizing the L2 norm or the cross-entropy loss of the prediction. \(q(\textbf{y}_j|\textbf{t}_j)\) is parameterized by another task head \(\mathcal {H}_{j'}\) that takes \(\textbf{t}^j\) as input and makes the task prediction. Notably, \(\mathcal {H}_{j'}\) is different from \(\mathcal {H}_{j}\), since \(\mathcal {H}_{j}\) takes both \(\textbf{z}\) and \(\textbf{t}^j\) as input.

Term 2. We maximize \(\mathcal {I}(X, Z)\) by maximizing its lower bound \(\mathbb {E}_{p(\textbf{z}, \textbf{x})}[\log q(\textbf{z}| \textbf{x})] + H(Z)\). We choose \(q(\textbf{z}| \textbf{x})\) to be an energy-based function that is parameterized by a critic function \(f(\textbf{x}, \textbf{z}): \mathcal {X}\times \mathcal {Z} \rightarrow \mathbb {R}\):

$$\begin{aligned} q(\textbf{z}|\textbf{x}) = \frac{p(\textbf{z})}{C} e^{f(\textbf{x},\textbf{z})}, \text {where } C= \mathbb {E}_{p(\textbf{z})}\big [e^{f(\textbf{x},\textbf{z})}\big ]. \end{aligned}$$
(8)

Substituting \(q(\textbf{z}|\textbf{x})\) into the second term gives us an unnormalized lower bound:

$$\begin{aligned} \mathcal {I}(X, Z) \ge \mathbb {E}_{p(\textbf{z}, \textbf{x})}[f(\textbf{x},\textbf{z})] - \log \mathbb {E}_{p( \textbf{x})}[C]. \end{aligned}$$
(9)

The same bound is also mentioned in Mutual Information Neural Estimation (MINE) [3]. Different from the original MINE, in our implementation we estimate \(\mathcal {I}(X, Z)\) through a feature-wise loss between the teacher and the students. With a slight abuse of notation, we refer to \(\textbf{z}_\mathcal {T}=\mathcal {T}(\textbf{x})_l \in \mathbb {R}^{d_\mathcal {T}}\) and \(\textbf{z}_\mathcal {C} = \mathcal {S}_\mathcal {C}(\textbf{x})_l\in \mathbb {R}^{d_\mathcal {C}}\) as the intermediate feature vectors from the teacher and the CKN at the l-th layer. Given a pair \((\textbf{z}_\mathcal {T}, \textbf{z}_\mathcal {C})\), f is defined as the inner product of the two vectors, \(f(\textbf{x},\textbf{z}_\mathcal {C}) = \langle {\textbf{z}_\mathcal {C}, FFN(\textbf{z}_\mathcal {T})} \rangle \), where \(FFN(\cdot ):\mathbb {R}^{d_\mathcal {T}}\rightarrow \mathbb {R}^{d_\mathcal {C}}\) is a feed-forward network that aligns the dimensions of \(\textbf{z}_\mathcal {T}\) and \(\textbf{z}_\mathcal {C}\).
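A batch estimate of the bound in Eq. 9 with the inner-product critic can be sketched as follows. This is a simplified illustration, not our implementation: the FFN is reduced to a single linear map `W` (an assumption), and mismatched teacher-student pairs within the batch are used as stand-ins for samples from the product of marginals, which is a common practical choice rather than something prescribed by Eq. 9.

```python
import math

def critic(z_c, z_t, W):
    """f(x, z_C) = <z_C, FFN(z_T)>, with FFN simplified to a linear map W (d_C x d_T)."""
    projected = [sum(w * v for w, v in zip(row, z_t)) for row in W]
    return sum(a * b for a, b in zip(z_c, projected))

def dv_lower_bound(student_feats, teacher_feats, W):
    """Estimate E_joint[f] - log E_marginals[exp(f)] on a batch of at least 2 samples."""
    n = len(student_feats)
    # Matched pairs (same input) approximate the joint distribution p(x, z).
    joint_term = sum(critic(student_feats[i], teacher_feats[i], W) for i in range(n)) / n
    # Mismatched pairs (i != j) approximate the product of marginals p(x) p(z).
    exp_scores = [math.exp(critic(student_feats[j], teacher_feats[i], W))
                  for i in range(n) for j in range(n) if i != j]
    return joint_term - math.log(sum(exp_scores) / len(exp_scores))

students = [[1.0, 0.0], [0.0, 1.0]]
teachers = [[1.0, 0.0], [0.0, 1.0]]
aligned_W = [[2.0, 0.0], [0.0, 2.0]]   # critic scores matched pairs highly
zero_W = [[0.0, 0.0], [0.0, 0.0]]      # uninformative critic

print(dv_lower_bound(students, teachers, aligned_W))
print(dv_lower_bound(students, teachers, zero_W))
```

An uninformative critic yields an estimate of zero, while a critic that separates matched from mismatched pairs raises it, which is the signal used to train the CKN and the alignment network jointly.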

Term 3. \(\mathbb {E}_{p(\textbf{t}_j)}\big [D_{KL}[p(\textbf{t}_j|\textbf{x})||q(\textbf{t}_j)]\big ]\) is the expected KL divergence between the posterior \(p(\textbf{t}_j|\textbf{x})\) and the prior \(q(\textbf{t}_j)\), which is an upper bound for \(\mathcal {I}(X, T_j)\). We minimize \(\mathcal {I}(X, T_j)\) by minimizing \(\mathbb {E}_{p(\textbf{t}_j)}\big [D_{KL}[p(\textbf{t}_j|\textbf{x})||q(\textbf{t}_j)]\big ]\).

Following the common practice in variational inference [19, 27], we set the prior \(q(\textbf{t}_j)\) to a zero-mean, unit-variance Gaussian. In addition, we assume that \(p(\textbf{t}_j|\textbf{x})=\mathcal {N}(\boldsymbol{\mu }_{t_j}, \text {diag}(\boldsymbol{\sigma }^2_{t_j}) )\) is a Gaussian distribution with diagonal covariance. Accordingly, we compute the mean and variance of the task feature \(\textbf{t}_j\) in each forward pass:

$$\begin{aligned} \textbf{t}_j = \mathcal {S}_{T_j}(\textbf{x}; \varTheta _{\mathcal {S}_{T_j}});\quad \boldsymbol{\mu }_{t_j}=\mathbb {E}[\textbf{t}_j], \quad \boldsymbol{\sigma }^2_{t_j}=\text {Var}[\textbf{t}_j]. \end{aligned}$$
(10)

Then, the KL divergence between \(p(\textbf{t}_j|\textbf{x})\) and \(q(\textbf{t}_j)\) can be computed in closed form as:

$$\begin{aligned} D_{KL}[p(\textbf{t}_j|\textbf{x})||q(\textbf{t}_j)] = -\frac{1}{2}\sum _{l=1}^{L}\Big (1 + \log (\sigma _{t_j}^{(l)})^2 -(\mu _{t_j}^{(l)})^2 - (\sigma _{t_j}^{(l)})^2\Big ). \end{aligned}$$
(11)

The superscript (l) denotes the l-th element of \(\boldsymbol{\mu }_{t_j}\) and \(\boldsymbol{\sigma }_{t_j}\), and L is the feature dimension.
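The closed-form KL divergence between a diagonal Gaussian \(\mathcal {N}(\boldsymbol{\mu }, \text {diag}(\boldsymbol{\sigma }^2))\) and the standard normal prior, \(\frac{1}{2}\sum _l (\sigma _l^2 + \mu _l^2 - 1 - \log \sigma _l^2)\), can be evaluated as in the sketch below. The helper names are illustrative, and computing the moments per feature dimension over a batch is an assumption about how the expectation in Eq. 10 is taken.

```python
import math

def batch_moments(features):
    """Per-dimension mean and (population) variance over a batch of feature vectors."""
    n = len(features)
    mu = [sum(col) / n for col in zip(*features)]
    var = [sum((v - m) ** 2 for v in col) / n for col, m in zip(zip(*features), mu)]
    return mu, var

def kl_to_standard_normal(mu, var):
    """KL[N(mu, diag(var)) || N(0, I)] = 0.5 * sum(var + mu^2 - 1 - log var).

    Every entry of `var` must be strictly positive.
    """
    return 0.5 * sum(v + m * m - 1.0 - math.log(v) for m, v in zip(mu, var))

# The divergence is zero exactly when the posterior equals the prior.
print(kl_to_standard_normal([0.0, 0.0], [1.0, 1.0]))

# Toy task features for a batch of two samples (illustrative values).
mu, var = batch_moments([[1.0, 0.0], [3.0, 4.0]])
print(mu, var, kl_to_standard_normal(mu, var))
```

The divergence is non-negative by construction, so minimizing it pulls the task features towards the uninformative prior unless the task loss requires otherwise.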

Training. We minimize the following overall loss to achieve both structural and representation factorization between students:

$$\begin{aligned} \min _{\varTheta _{\mathcal {S}_C}, \varTheta _{\mathcal {S}_{T_j}}, \varTheta _{\mathcal {H}_j}} \sum _{j=1}^K \Big (\mathcal {L}_{sf}^{(j)} - \lambda _{\text {I}} \mathcal {L}^{(j)}_{\text {I}}\Big ), \end{aligned}$$
(12)

where \(\lambda _{\text {I}}\) is the weighting coefficient of the IMB objective, which is optimized in practice through its variational bound in Eq. 7.

4 Experiments

In this section, we investigate how factorization promotes the performance, modularity and transferability of the model. By default, we set \(\alpha \) = 1.0, \(\beta \) = 1e−3, \(\lambda _{\text {I}}\) = 1 and \(\lambda _{\text {kt}}\) = 0.1. Due to the space limit, more hyper-parameter settings, distillation losses, implementation details, data descriptions, and definitions of the metrics are listed in the supplementary material.

4.1 Factor Networks Make Strong Task Prediction

We conduct comprehensive experiments on synthetic and real-world classification and multi-task benchmarks to investigate whether the factorized networks still maintain competitive predictive performance, especially on each subtask.

Synthetic Evaluation. We first evaluate our KF on two synthetic imagery benchmarks, dSprites [41] and Shape3D [6]. Both datasets are generated from 6 independent ground-truth latent factors. We define each latent factor as a prediction target and treat both datasets as multi-label classification benchmarks. We compare our KF with 4 baseline methods: the single-task baseline, the multi-task baseline, MTL2MTL KD and MTL2STL KD. The single-task baseline trains 6 single-task networks, while the multi-task baseline trains one model to predict all 6 tasks. MTL2MTL KD distills a multi-task student, whereas MTL2STL KD distills 6 single-task students. KF denotes our results with factor networks. The teacher network is a 6-layer CNN, and all student network encoders, including both the CKN and the TSNs, are parameterized by 3-layer CNNs. We take a random 7:3 train-test split on each dataset and report the ROC-AUC score on the test split.

Results. Figure 3 shows bar plots of the ROC-AUC scores for our KF and its KD competitors on the two datasets. Though all methods achieve a high AUC score above 0.92 on both datasets, it is evident that our KF not only surpasses the multi-task baseline but also exceeds the two distillation paradigms. In addition, multi-task models generally achieve better performance than their single-task counterparts, revealing that prediction performance benefits from learning from multiple labels on the two datasets.

Fig. 3. Test ROC-AUC comparison on dSprites and Shape3D datasets.

Table 1. Test accuracy (%) comparison on CIFAR-10 between KD and KF. We report mean ± std over 3 runs.
Table 2. Top-1 Accuracy (%) comparison on ImageNet.

Real Image Classification. We further evaluate our KF on two real image classification benchmarks, CIFAR-10 [29] and ImageNet1K [52]. To apply factorization, we construct two pseudo-multi-task datasets by considering the category hierarchy. The 10 classes in CIFAR-10 can be divided into 6 animal and 4 vehicle categories. Similarly, ImageNet1K classes are organized using the WordNet [42] synset tree, with 11 super-classes. We accordingly construct the CIFAR-10 2-task and ImageNet1K 11-task datasets, with each task covering one super-class.

On the single-task and pseudo-multi-task evaluations, we take a pretrained classifier and distill or factorize its knowledge into single-task or pseudo-multi-task students. Each pseudo-multi-task factor or distilled network only predicts the categories within one super-class, with the concatenated output serving as the final prediction. On CIFAR-10, we use ResNet-18 [18], WideResNet28-2 (WRN28-2) [65] and WideResNet28-10 (WRN28-10) [65] as teacher networks, and MobileNetv2 (MBNv2) [53], ResNet-18 and WRN28-2 as student or CKN backbones. On ImageNet1K, the teacher networks are ResNet-18, ResNet-34 [18] and ResNet-50 [18], with MBNv2 and ResNet-18 as student or CKN backbones. We select the lightweight backbone MBNv2x0.5 for the TSNs, where x0.5 denotes a width multiplier of 0.5.
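The CIFAR-10 2-task construction and the concatenation of per-task outputs can be sketched as follows. The class-to-super-class split follows the 6 animal and 4 vehicle categories described above; the helper names and the way per-task logits are scattered back into the original class order are illustrative assumptions rather than a description of our exact pipeline.

```python
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]

SUPER_CLASSES = {
    "animal": ["bird", "cat", "deer", "dog", "frog", "horse"],
    "vehicle": ["airplane", "automobile", "ship", "truck"],
}

def task_of(label: int) -> str:
    """Return the pseudo-task (super-class) that a CIFAR-10 label belongs to."""
    name = CIFAR10_CLASSES[label]
    for task, members in SUPER_CLASSES.items():
        if name in members:
            return task
    raise ValueError(f"unknown class: {name}")

def concat_predictions(task_logits: dict) -> int:
    """Scatter per-task logits into the 10-way class order and return the argmax."""
    full = [float("-inf")] * len(CIFAR10_CLASSES)
    for task, logits in task_logits.items():
        for name, value in zip(SUPER_CLASSES[task], logits):
            full[CIFAR10_CLASSES.index(name)] = value
    return max(range(len(full)), key=full.__getitem__)

# One factor network per super-class; each outputs logits for its own classes only.
prediction = concat_predictions({
    "animal": [0.1, 5.0, 0.2, 0.3, 0.0, 0.1],   # highest score on "cat"
    "vehicle": [1.0, 0.2, 0.1, 0.4],
})
print(CIFAR10_CLASSES[prediction], task_of(prediction))
```

Because the final decision compares logits across networks, poorly calibrated per-task outputs directly hurt the concatenated prediction.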

Results. Table 1 and Table 2 compare the classification accuracy of single-task or pseudo-multi-task KD and our proposed KF over 3 runs. Though both approaches improve over the baselines under the single-task setting, we note that KD fails to improve the results on the pseudo-multi-task evaluation. We also do not report the 11-task KD results on ImageNet because the accuracy is generally lower than 20%. Notably, we observe that imbalanced labeling causes the deterioration in training: when one network only masters one super-class and the rest of the classes are treated as negative samples, the distilled networks are prone to making low-confidence predictions. In comparison, KF has a CKN shared across all tasks, which considerably alleviates the imbalance problem of conventional KD. For example, factor networks obtained by 11-task KF improve the performance of ResNet18-KD on ImageNet by 1.08% and 1.31% when learning from ResNet-50 and ResNet-34, respectively. On the other evaluations, KF consistently improves over standard KD, which suggests that factorizing task-specific and task-agnostic knowledge benefits the performance.

Multi-task Dense Prediction. Two multi-task dense prediction datasets are also used to verify the effectiveness of KF: NYU Depth Dataset V2 (NYUDv2) [54] and PASCAL Context [11]. The NYUDv2 dataset contains indoor scene images annotated for segmentation and monocular depth estimation. We include 4 tasks in PASCAL Context: semantic segmentation, human part segmentation, normal prediction, and saliency detection. We use the mean intersection over union (mIoU), the mean angle error (mErr) and the root mean square error (rmse) to measure the prediction quality.

We include both the single-task and multi-task models, together with their STL2STL/MTL2STL/MTL2MTL distilled models, as our baselines. We adopt HRNet48 [61] and ResNet-50 DeepLabv3 as teachers, and HRNet18 and ResNet-18 DeepLabv3 as students or CKNs. The TSNs are set to MBNv2x0.5. We use a smaller \(\beta \) = 1e−5. The networks are initialized with ImageNet-pretrained weights.

Table 3. Performance comparison on the NYUDv2 dataset.
Table 4. Performance comparison on the PASCAL dataset.

Results. We show the evaluation results on the NYUDv2 and PASCAL datasets in Table 3 and Table 4. On NYUDv2, the multi-task baselines generally perform better than their single-task competitors. In contrast, in the PASCAL experiments with HRNet48, ResNet18 and ResNet50, the performance of the multi-task baseline degrades substantially. This reveals the negative transfer problem in MTL: the joint optimization of multiple objectives may cause conflicts between tasks, leading to an undesirable performance reduction.

The same problem remains when comparing MTL2MTL-KD to STL2STL-KD in Table 4, where the MTL teacher is inferior to the STL ones. Our factor networks automatically resolve this problem, because different TSNs are structurally and representationally independent. As a result, KF achieves strong student performance compared to the other baselines.

4.2 Factorization Brings Disentanglement

Given the distilled and factorized models from the previous section, we measure a set of disentanglement metrics and the representation similarity to confirm that knowledge factorization captures the independent variables across tasks.

Disentanglement Evaluation Setup. We first validate the disentanglement between factor models on dSprites [41] and Shape3D [6]. We measure 4 disentanglement metrics to quantify how well the learned representations summarize the factor variables: disentanglement-completeness-informativeness (DCI) [12], Mutual Information Gap (MIG) [8], the FactorVAE metric [26], and the Separated Attribute Predictability (SAP) score [30]. For all metrics, higher is better.

We compare our KF with 3 baseline methods: the single-task baseline, the multi-task baseline, and MTL2STL KD students, which have been introduced in the previous section. Following the evaluation protocol in [35], we adopt the concatenation of all average-pooled task-specific representations as our final feature vector for evaluation and compute all scores on the test set.

Results. Figure 4 illustrates the quantitative results of the different disentanglement metrics using box plots. First, we see that multi-task learning naturally comes with disentangled representations, as MTL achieves a slightly higher score than STL. Another observation is that knowledge transfer methods like KD and KF also help the model to find factors that are unappreciable for the teachers. The features extracted by our factor networks generally score the best, especially on the dSprites dataset, with median improvements of 0.47 and 0.09 on the DCI and MIG scores, respectively. This is in line with our expectation that decomposing the knowledge into parts leads to disentangled representations.

Representation Similarity. We further conduct representation similarity analysis using centered kernel alignment (CKA) [28] between teacher models, distilled models and our factorized models across 4 datasets: dSprites, Shape3D, CIFAR10 and NYUDv2. On each dataset, CKA is adopted to quantify the feature similarity among (1) the MTL teacher, (2) the MTL2MTL-KD student, (3) the MTL2STL students and (4) our CKN and TSNs. We compute linear-kernel CKA between all pairs of models at the last feature layer on the test set. The model architectures are described in the Appendix. A higher CKA index suggests a higher correlation between two networks.
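For reference, linear-kernel CKA between two feature matrices X and Y (one row per test sample, columns centered) can be written as ||Y^T X||_F^2 / (||X^T X||_F ||Y^T Y||_F). The snippet below is a minimal sketch of this standard formulation and of the pairwise similarity matrix. The names `linear_cka` and `cka_matrix` and the dictionary-based interface are illustrative assumptions, not the implementation used for the experiments.

```python
import numpy as np


def linear_cka(x, y):
    """Linear CKA between feature matrices of shape (n_samples, n_features).

    The two matrices must share n_samples but may differ in n_features.
    """
    # Center every feature column.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (norm_x * norm_y))


def cka_matrix(features):
    """Pairwise linear CKA for a dict mapping model name to feature matrix."""
    names = list(features)
    mat = np.zeros((len(names), len(names)))
    for i, a in enumerate(names):
        for j, b in enumerate(names):
            mat[i, j] = linear_cka(features[a], features[b])
    return names, mat
```

The index lies in [0, 1], equals 1 when comparing a model with itself, and is invariant to orthogonal transformations and isotropic scaling of the features, which makes it suitable for comparing networks with different feature dimensions.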

Fig. 4.

Disentanglement metrics comparison between (1) Single-Task Baseline, (2) Multi-Task Baseline, (3) KD, and (4) our proposed KF on the dSprites (top) and Shape3D (bottom) datasets. Each experiment is repeated over 10 runs.

Fig. 5.

CKA representation similarity between distilled and factorized models.

Results. Figure 5 visualizes the CKA similarity matrix between all model pairs on the 4 datasets. We make the following observations. First, models mastering the same subtask have high feature similarity. Second, our factorized TSNs capture more “pure” knowledge compared with the MTL2STL students. On each heatmap, the bottom left region has high similarity (in darker red), suggesting that the conventionally distilled models still maintain high similarity with their peers even though they are trained on dedicated tasks. In comparison, the factorized TSNs achieve smaller similarities (in the upper right region), again supporting our argument that factor networks capture the disentangled factors across tasks.

4.3 Common Knowledge Benefits Transferring

We then finetune the factorized CKN on two downstream tasks to see if the common knowledge facilitates transfer learning to unseen domains. We train ResNet-18 networks with different initializations on Caltech-UCSD Birds (CUB-200) [60] and MIT indoor scene (Scene) [49]. The trained models are then reused as teachers to educate student networks such as MBNv2 and ShuffleNetv2.

Results. Table 5 shows the transfer learning performance and distillation accuracy using different pretrained weights. R18 w/ImageNet-CKN refers to the ResNet-18 CKN factorized from the ImageNet-pretrained ResNet-18. Compared with the original pretrained weights, ImageNet-CKN achieves substantial improvement on both datasets. By reusing the finetuned ResNet-18 as the teacher network, we show in Table 5 that CKN serves as a better role model to educate the student networks. This provides compelling evidence that the common knowledge factorized from the teacher network benefits transfer learning to other tasks.

Table 5. Finetuning performance and distillation accuracy with different pretrained weights. R18 is short for ResNet-18.

5 Conclusion

In this paper, we introduce a novel knowledge-transfer task termed Knowledge Factorization. Given a pretrained teacher, KF decomposes it into task-disentangled factor networks, each of which masters the task-specific and the common knowledge factorized from the teacher. Factor networks may operate independently, or be integrated to assemble multi-task networks, allowing for great scalability. We design an InfoMax Bottleneck objective to disentangle the common and task-specific representations by optimizing the mutual information between the input and the representations. Our method achieves strong and robust performance while demonstrating great disentanglement capability across various benchmarks, with better modularity and transferability.