1 Introduction

Cross-domain few-shot learning (CDFSL) addresses the problem that deep learning methods, such as convolutional neural networks (CNN) for image classification, generally require a large amount of labelled training data to achieve high predictive accuracy when trained from scratch. CDFSL algorithms are designed for scenarios where only a few labelled training instances are available in the form of a so-called “support set”. The aim is to nevertheless achieve high accuracy when predicting labels for instances of the target domain that have never been seen before, i.e., the so-called “query set”. This can generally only be achieved by applying transfer learning: taking knowledge gleaned from one or several source domains with large-scale training data and using this knowledge to inform learning in a few-shot target domain.

In CDFSL, the source domain(s) and the target domain are assumed to have potentially very distinct properties. This cross-domain setting is arguably more realistic than the “in-domain” scenario, used in some few-shot learning literature (Vinyals et al., 2016), where the source and the target domains comprise mutually exclusive sets of classes obtained from the same dataset. It also yields harder learning problems due to greater domain shift.

In a CDFSL setting with multiple source domains, an algorithm needs to both select relevant source domains and effectively transfer their knowledge into a target domain using a few-shot support set. Performance is measured by “meta-testing”—transferring model(s) using target domain support sets and evaluating their predictive accuracy on corresponding query sets. Recent work considering performance in image classification, which is the setting we also focus on in this paper, shows that single-domain learning (SDL) and vanilla multi-domain learning (MDL), which applies one feature extractor and multiple classification heads, fail to achieve competitive performance compared to methods specifically designed for CDFSL (Triantafillou et al., 2020; Li et al., 2021).

A majority of recently published CDFSL methods involve building a universal model from a collection of extractors, with each extractor pretrained in a distinct source domain. This comprises the so-called “meta-training” phase, which is performed before meta-testing begins. The universal-model paradigm is generally efficient when performing meta-testing because a single universal feature extractor is used and fine-tuned on the support set, usually in conjunction with a simple robust classifier that turns extracted feature vectors into predictions. However, training the universal model is computationally expensive, and some methods constrain all extractors to the same architecture as the intended universal model (Triantafillou et al., 2021), rendering them inapplicable to heterogeneous extractor collections that are likely to occur in real-world practice. Moreover, they may require adjustment based on pre-existing domain knowledge to function well. For example, given a source domain/extractor collection for image classification consisting of ImageNet (Deng et al., 2009; Russakovsky et al., 2015), along with other, less comprehensive source domains, authors often assign greater importance to the ImageNet extractor during training (Triantafillou et al., 2021; Li et al., 2021). This achieves good performance on benchmarks, which normally include target domains such as CIFAR-10 that are quite similar to ImageNet in nature, but may not be as useful in real-world applications involving less similar data. Lastly, the process of deriving a universal model is non-incremental, which means it needs to be re-run whenever an extractor is updated or added, and normally requires access to the entire meta-training dataset (Triantafillou et al., 2021; Li et al., 2021).

As an alternative approach that avoids these shortcomings, we propose a novel “lazy” CDFSL method, termed feature extractor stacking (FES), that fine-tunes each extractor independently and trains a classifier using a form of stacked generalisation (Wolpert, 1992) during meta-testing. The “meta-training” phase in FES consists solely of training individual feature extractors, one for each source domain, using standard single-domain supervised learning. In practical applications, it may be possible to skip meta-training entirely if a set of suitable feature extractors has been obtained from other sources. FES is fully compatible with heterogeneous extractor collections, imposing no constraints on their architecture or fine-tuning configuration. It assumes equal importance of the extractors a priori, determining their task-specific relevance based purely on the support set, and does not require derivation of a universal model.

Along with the basic FES algorithm, which applies a simple linear stacking classifier and is described in Sect. 3.1, we present two variants: convolutional FES (ConFES) in Sect. 3.3 and regularised FES (ReFES) in Sect. 3.4. ConFES replaces the flat global kernel of FES with a hierarchy of depthwise convolutional kernels, reducing the number of parameters in the stacking classifier. ReFES applies fused lasso regularisation (Tibshirani et al., 2005) to the stacking classifier of FES to reduce the weights of irrelevant snapshots and induce smooth weight transition between adjacent snapshots.

We evaluate FES and its variants on the Meta-Dataset benchmark (Triantafillou et al., 2020), which contains eight source domains and five target domains, and include five additional target domains: CropDisease, EuroSAT, ISIC, ChestX, and Food101 (Guo et al., 2020; Bossard et al., 2014). We show that FES outperforms three recent universal-model methods: URL (Li et al., 2021), FLUTE (Triantafillou et al., 2021), and a URL extractor with TSA fine-tuning (Li et al., 2022), and advances the state of the art on Meta-Dataset. We also discuss practical advantages of FES in real-world scenarios, as FES can work with heterogeneous extractors out of the box and does not need to train a universal model.

2 Related work

Our empirical comparison of CDFSL methods is based on the Meta-Dataset framework, so we review this benchmark first before discussing methods that we compare to our approach. We also briefly review other noteworthy methods in the literature.

2.1 The Meta-Dataset benchmark

The Meta-Dataset (Triantafillou et al., 2020) benchmark has multiple configurations; we describe the CDFSL configuration that we use—most recent publications in the field use this configuration as well. It contains eight source domains: ILSVRC-2012 (ImageNet), Omniglot, Aircraft, CUB-200-2011 (Birds), Describable Textures, Quick Draw, Fungi, and VGG Flower. Recent work utilising Meta-Dataset (Requeima et al., 2019) has extended its original set of two target domains, Traffic Signs and MSCOCO, by adding three more: MNIST, CIFAR10, and CIFAR100. For an even more comprehensive evaluation, we add four target domains from the CDFSL benchmark in Guo et al. (2020)—CropDisease, EuroSAT, ISIC, and ChestX—and additionally employ Food101 (Bossard et al., 2014). Only the 250 sanitised test images in each Food101 class are used in our experiments.

The Meta-Dataset framework splits each source domain into three partitions: training, validation, and test. The partitions are mutually exclusive in terms of their classes, with the training partition containing approximately 70% of source domain classes and the validation and test partitions containing approximately 15% each. The training and validation partitions are made available to the CDFSL method for “meta-training”, where the training partition is generally used to train extractors and the validation split to aid hyperparameter tuning. The test partition is reserved for evaluating the CDFSL method by sampling few-shot episodes (i.e., meta-testing): the term “episode” refers to the process of sampling a support set and a query set, training a classifier on the support set, and evaluating it on the query set.

In contrast, the entire target domain data can be used for sampling episodes to evaluate few-shot learning in these domains. It is important to note that, by definition, only tasks sampled from target domains truly measure CDFSL performance. Using terminology that is common in this context, good performance in these domains indicates “strong generalisation”; good performance on tasks sampled from source domain test partitions indicates “weak generalisation”.

The most commonly used method to evaluate CDFSL algorithms on Meta-Dataset is to generate 600 any-way any-shot episodes from each dataset, and measure each algorithm’s mean classification accuracy on these 600 episodes, as well as the 95% confidence interval. Any-way any-shot sampling means the number of classes for each episode and the number of support instances per class are arbitrary, leading to imbalanced support sets more representative of real-world scenarios than fixed-way fixed-shot episodes. The query set is balanced in the Meta-Dataset setting. We adhere to this evaluation method in our experiments.
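
For reference, the per-dataset numbers reported under this protocol can be computed as in the following sketch, which assumes the commonly used normal approximation for the 95% confidence interval over the 600 episode accuracies; the function name is ours.

```python
import numpy as np

def mean_and_ci95(episode_accuracies):
    """Mean accuracy and 95% confidence interval over sampled episodes,
    using the normal approximation of the standard error."""
    acc = np.asarray(episode_accuracies, dtype=float)
    mean = acc.mean()
    ci95 = 1.96 * acc.std(ddof=1) / np.sqrt(len(acc))
    return mean, ci95
```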

2.2 Methods included in the experimental comparison

Two recently published CDFSL methods that advanced the state-of-the-art on Meta-Dataset are Few-shot Learning with a Universal TEmplate (FLUTE) (Triantafillou et al., 2021) and Universal Representation Learning (URL) (Li et al., 2021). Even more recently, based on a URL universal model, a fine-tuning method using Task-Specific Adaptors (TSA) (Li et al., 2022) improved accuracy on some target domains even further. We compare our new FES approach to these methods in our experiments.

2.2.1 Few-shot learning with a universal template

Based on the FiLM approach (Perez et al., 2018), FLUTE trains a universal model in the source domains, employing the ResNet18 architecture (He et al., 2016) widely used in CDFSL, but maintaining a separate set of batch normalisation (Ioffe & Szegedy, 2015) parameters for each domain. The ResNet “template” contains one set of convolutional weights shared across all source domains, and only the batch normalisation parameters are specific to each source domain. FLUTE jointly trains the template in all source domains. At each training iteration, a random source domain is selected—with ImageNet having a 50% probability of being selected and the other seven source domains evenly sharing the other 50% probability—and a batch of input data is sampled from the selected source domain. In forward propagation, the input batch flows through the shared convolutional layers and the selected domain’s set of batch normalisation layers, and loss is computed by applying a cosine classifier (Chen et al., 2019, 2021). A nuance of FLUTE training is that backpropagation is performed using a “meta-batch” of eight individual batches: the intention is to stabilise training by aggregating loss values across multiple domains. Hyperparameter tuning is performed using episodes sampled from source domain validation partitions.

While the template is being trained, snapshots are saved frequently. The final template is chosen as the snapshot that performs best on the source domains’ validation partitions. To establish this performance, few-shot episodes are sampled from these partitions. For each episode, feature vectors are extracted using the shared convolutional layers and the corresponding domain’s set of batch normalisation layers. Accuracy is computed using a nearest-centroid classifier (Mensink et al., 2013; Snell et al., 2017).

One more component of FLUTE, produced in a separate meta-training phase, is a blender network, which is a dataset classifier based on a permutation-invariant set encoder (Zaheer et al., 2017) followed by a linear layer. Given a batch of instances, the blender predicts, as a probability distribution, the source domain from which the batch is sampled. It is trained on batches sampled from the source domains’ training partitions, and the final blender model is chosen using batches from the validation partitions.

Given a few-shot episode at meta-test time, the blender uses the support set to produce a probability distribution. These probabilities in turn are used to form a linear combination of the source-domain-specific batch normalisation weights. Along with the shared convolutional weights from the template, this forms the initial set of parameters for the ResNet18 feature extractor, which is applied in conjunction with a nearest-centroid classifier. The model’s batch normalisation parameters are then fine-tuned on the support set while its convolutional weights remain fixed.
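
For illustration, a minimal sketch of this blending step is given below. The dictionary-based parameter representation and the function name are our own assumptions, not FLUTE’s actual implementation; only the idea of forming a convex combination of per-domain batch normalisation parameters follows the description above.

```python
import torch

def blend_batch_norm_params(domain_bn_params, blend_probs):
    """Combine per-domain batch normalisation parameters with blender weights.

    domain_bn_params: list of dicts (one per source domain) mapping a batch-norm
        parameter name, e.g. "layer1.0.bn1.weight", to a tensor.
    blend_probs: 1-D tensor of blender probabilities, one per source domain.
    """
    blended = {}
    for name in domain_bn_params[0]:
        stacked = torch.stack([params[name] for params in domain_bn_params])
        probs = blend_probs.view(-1, *([1] * (stacked.dim() - 1)))
        blended[name] = (probs * stacked).sum(dim=0)  # convex combination across domains
    return blended
```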

2.2.2 Universal representation learning

The URL algorithm also generates a universal model. It first pretrains domain-specific ResNet18 extractors independently. Then, a separate ResNet18 feature extractor is trained to form a universal model by distillation. This model is trained to match each extractor’s output feature vectors and logits using instances sampled from the extractor’s corresponding domain. To this end, the universal model contains pairs of auxiliary domain-specific components that each comprise 1) a projection layer that transforms the universal extractor’s feature vectors to match those of each domain-specific extractor, and 2) a classifier layer trained to match the logits produced by each extractor.

In the experiments by Li et al. (2021), ImageNet is made more prominent in distillation: ImageNet instances make up 50% of each mini-batch and the other seven source domains evenly make up the rest. Snapshots of the universal feature extractor are saved at predefined intervals during knowledge distillation. Episodes sampled from source domain validation partitions are used to select the best snapshot as a form of early stopping.

After meta-training, the auxiliary components of the universal model are discarded, leaving only the feature extractor. During meta-testing, this extractor is frozen, and a projection layer is initialised with an identity weight matrix and trained using the support set. The projected feature vectors are used to build a nearest-centroid classifier. Cosine similarity values between a feature vector to be classified and the centroids are used as logits. Fine-tuning minimises cross-entropy loss on the support set. Note that during fine-tuning, as the projection layer is optimised, projected support feature vectors change, and their centroids change as well. The fine-tuning effect can be interpreted as forming better clusters with projected support feature vectors.
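
The sketch below illustrates this fine-tuning loop under simplifying assumptions: the support features are precomputed from the frozen extractor, and the optimiser type, learning rate, and number of steps are placeholders rather than the settings from the URL paper.

```python
import torch
import torch.nn.functional as F

def url_style_finetune(support_feats, support_labels, num_steps=40, lr=0.1):
    """Fine-tune a linear feature projection on frozen support features (sketch).

    support_feats: (N, D) features from the frozen universal extractor;
    support_labels: (N,) integer class labels. Returns the trained projection.
    """
    n_classes = int(support_labels.max()) + 1
    projection = torch.eye(support_feats.size(1), requires_grad=True)  # identity initialisation
    optimiser = torch.optim.Adadelta([projection], lr=lr)  # optimiser settings are placeholders
    for _ in range(num_steps):
        optimiser.zero_grad()
        z = support_feats @ projection
        # Nearest-centroid classifier: centroids recomputed from the projected support features.
        centroids = torch.stack([z[support_labels == c].mean(dim=0) for c in range(n_classes)])
        logits = F.normalize(z, dim=1) @ F.normalize(centroids, dim=1).t()  # cosine similarities
        F.cross_entropy(logits, support_labels).backward()
        optimiser.step()
    return projection
```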

2.2.3 Task-specific adaptors

TSA (Li et al., 2022) is a fine-tuning method suitable for CDFSL. Given a pretrained extractor, trainable task-specific adaptors are attached to it, and the support set is used to optimise the adaptors with the extractor’s original weights frozen. Like URL, TSA also attaches a trainable linear projection layer and a robust classifier to the end of the feature extractor during fine-tuning, but it adds further adaptor components. Among multiple configurations examined, the most effective approach for few-shot image classification found by Li et al. (2022) is to attach channel projection matrices as residual connections to a model’s convolutional layers. Li et al. (2022) used TSA in conjunction with a URL-distilled universal extractor, but TSA can be applied to other CNN architectures as well.
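
The sketch below shows one way to realise such a residual channel adaptor around a frozen convolution. The exact placement and initialisation used by TSA may differ; treat this as an illustration of the idea rather than the reference implementation.

```python
import torch.nn as nn

class ResidualChannelAdaptor(nn.Module):
    """A frozen convolution with a trainable 1x1 channel projection attached
    in parallel as a residual connection (illustrative sketch)."""

    def __init__(self, conv: nn.Conv2d):
        super().__init__()
        self.conv = conv
        for p in self.conv.parameters():
            p.requires_grad = False  # original pretrained weights stay frozen
        # 1x1 convolution acting as a channel projection; zero-initialised so
        # fine-tuning starts from the pretrained behaviour.
        self.adaptor = nn.Conv2d(conv.in_channels, conv.out_channels,
                                 kernel_size=1, stride=conv.stride, bias=False)
        nn.init.zeros_(self.adaptor.weight)

    def forward(self, x):
        return self.conv(x) + self.adaptor(x)
```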

2.3 Other work on CDFSL

We review additional noteworthy CDFSL methods here. These methods precede FLUTE, URL, and TSA chronologically and achieve lower accuracy than results presented by Triantafillou et al. (2021) and Li et al. (2021, 2022). Hence, in the experiments presented in this paper, we only compare to FLUTE, URL, and a URL extractor with TSA fine-tuning.

2.3.1 Selecting relevant features from a universal representation

SUR (Dvornik et al., 2020) is a CDFSL method that utilises independently pretrained feature extractors directly for meta-testing. Each extractor is used to extract a set of feature vectors from the support set, with a trainable weight assigned to it. Feature vectors are multiplied by their respective weights and concatenated to provide input to a nearest-centroid classifier. The weights are trained by optimising the loss of this classifier on the support set. SUR is similar to URL in the meta-testing phase, as both make predictions with a nearest-centroid classifier and optimise parameters on the support set; the primary difference is that URL maintains a universal model while SUR uses the original extractors directly.

2.3.2 Universal representation transformer

URT (Liu et al., 2021) also assigns a weight to each source domain extractor during meta-testing. However, it utilises a weight assignment model learned using meta-training instead of direct optimisation on the support set to obtain the weights. To this end, URT trains an attention mechanism (Vaswani et al., 2017) that learns to assign appropriate weights to source domain feature extractors given a few-shot episode. The weight assignment model is trained and has its hyperparameters selected using episodes sampled from the source domains’ training and validation partitions.

2.3.3 Conditional neural adaptive processes

The CNAPs method, as proposed in Requeima et al. (2019), uses an extractor pretrained in a large source domain, e.g., ImageNet (Deng et al., 2009; Russakovsky et al., 2015), and meta-trains adaptation networks, using episodes sampled from the source domains, to produce task-specific FiLM (Perez et al., 2018) transformations and a linear classifier for each few-shot episode.

A variant, Simple CNAPs (Bateni et al., 2020), was later proposed utilising a non-parametric Mahalanobis distance (Galeano et al., 2015) measure in place of the classifier adaptation network of CNAPs, reducing the parameter count and improving CDFSL performance. A transductive version of Simple CNAPs was subsequently also proposed (Bateni et al., 2022), making use of clustering of query instances in feature space to achieve better performance than Simple CNAPs, assuming that the query set is available as a batch instead of a sequential stream of incoming instances. As most other CDFSL methods do not rely on such an assumption, they cannot be compared to transductive CNAPs on an even footing.

2.3.4 Multi-mode modulator

Tri-M (Liu et al., 2021), akin to CNAPs, uses an extractor pretrained in a large-scale source domain, and meta-trains a modulation network using source domain episodes to generate appropriate FiLM transformations for each few-shot episode. Tri-M maintains two sets of transformations—a domain-specific one and a domain-cooperative one—and its resulting FiLM transformation is a combination of the two. Tri-M determines a source domain for its domain-specific transformation in a way similar to how FLUTE (Triantafillou et al., 2021) utilises its blender network and uses an attention mechanism (Vaswani et al., 2017) to compose its domain-cooperative transformation from relevant source domains.

3 Cross-domain few-shot learning using stacking

Considering the CDFSL methods discussed in the previous section, the SUR method stands out because its meta-training process is straightforward: all it involves is pretraining individual source domain feature extractors. Once these have been obtained, SUR performs “lazy” learning in the sense that significant work is only performed once the support set for a few-shot episode becomes available. This makes it very flexible because new extractors can be added at any time. However, SUR does not yield state-of-the-art performance. The new methods presented in this paper are inspired by SUR and the well-established method of stacked generalisation, which learns a classifier that combines the predictions of multiple base classifiers. Henceforth, we will refer to this classifier as the “stacking classifier”. There are four primary differences between SUR and our stacking-based methods: 1) the source domain extractors are fine-tuned on the support set, with appropriate classifier layers attached, to extract more information from this data, 2) two-fold cross-validation is used to generate training data for the stacking classifier to tackle overfitting, 3) the feature vectors of this training data consist of logits obtained from the classifier layers attached to the extractors, and 4) multiple snapshots of each extractor are stored during fine-tuning and used to obtain sets of logits, adding further richness to the data available for training the stacking classifier.

In the following, we first explain the basic method of feature extractor stacking (FES) in detail and prove convexity of its optimisation, before describing two variants: convolutional FES (ConFES) and regularised FES (ReFES).

3.1 Feature extractor stacking

Given pretrained feature extractors, FES has three key components: fine-tuning extractors to obtain snapshots, two-fold cross-validation to produce training data for the stacking classifier, and training of the stacking classifier. Figure 1 depicts the FES framework.

Fig. 1

Framework of FES. Given an extractor collection with K extractors, each extractor \(\Phi\) is set up as a network \(\Psi\) for fine-tuning. The support set S is split into \(S_1\) and \(S_2\) using stratified cross-validation. Each network \(\Psi\) is fine-tuned on one split, producing J snapshots in the process, and these snapshots are used to extract logits from the other split. Logits extracted from both splits are combined into cross-validated logits of the full support set, which are used to train a stacking classifier W to fit S’s labels. The full S is then used to fine-tune \(\Psi\), producing snapshots to extract logits for the query set Q. W takes Q’s logits as input and predicts Q’s labels

3.1.1 Fine-tuning the extractors

We use \(f_{\Phi _1}, f_{\Phi _2},..., f_{\Phi _{K}}\) (or just \(\Phi _1, \Phi _2,..., \Phi _{K}\) for brevity) to denote the collection of pretrained feature extractors, where \(\Phi\) represents the corresponding extractor’s parameters and K is the number of source domains. The support set of a few-shot episode is denoted S and the query set Q. S contains N instances belonging to C classes. We fine-tune each extractor independently on S. As \(f_{\Phi }\) is a feature extractor, a classifier g with parameters \(\Theta _1\) is attached to \(f_{\Phi }\) to produce logits. Auxiliary components with parameters \(\Theta _2\) may also be introduced to the model to aid fine-tuning, such as with TSA (Li et al., 2022). The resulting model is defined as \(h_{\Psi } = g_{\Theta _1} \circ f_{(\Phi , \Theta _2)}\), where we use \(\Psi\) to denote the combination of all parameters. It is possible for \(\Theta _2\) to be \(\varnothing\), as auxiliary fine-tuning components are optional. J snapshots are saved sequentially at different fine-tuning iterations of \(h_{\Psi }\). Each snapshot contains parameters \(\Psi _k^j[S]\), where \(k \in [1,K]\) and \(j \in [1,J]\), with S denoting the fine-tuning set used.
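
A minimal sketch of this snapshot procedure is given below; the helper name, the plain full-batch training loop, and the use of state_dict copies are illustrative assumptions.

```python
import copy
import torch.nn.functional as F

def finetune_with_snapshots(model, support_images, support_labels, optimiser, num_iterations):
    """Fine-tune one network h_Psi on the support set, saving a parameter snapshot
    before fine-tuning and after every iteration (J = num_iterations + 1 snapshots)."""
    snapshots = [copy.deepcopy(model.state_dict())]  # snapshot of the initial state
    for _ in range(num_iterations):
        optimiser.zero_grad()
        logits = model(support_images)
        F.cross_entropy(logits, support_labels).backward()
        optimiser.step()
        snapshots.append(copy.deepcopy(model.state_dict()))  # one snapshot per iteration
    return snapshots
```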

3.1.2 Cross-validation to obtain training data for stacked generalisation

In stacked generalisation (Wolpert, 1992), cross-validation is employed to obtain training data for the stacking classifier to combat overfitting, and it is applied in FES as well. More specifically, we apply stratified two-fold cross-validation to the support set S, producing two splits \(S_1\) and \(S_2\), which will take turns serving as the training split \(S^{train}\) and the test split \(S^{test}\). It is possible to employ more folds in FES, but using additional folds did not yield performance gains in our experiments.

Training on one of the training splits amounts to fine-tuning a network \(h_{\Psi }\) on this data. In principle, this could be done for a fixed number of iterations, and once complete, logits on the corresponding test split could be obtained as training data for the stacking classifier. However, this naive approach may not work well because it is not known how many iterations should be performed for fine-tuning to maximise accuracy of the full learning system. The approach we propose and evaluate in this paper is instead based on the idea that we can take multiple snapshots of the models during fine-tuning and use all the snapshots’ logits on the test folds for training the stacking classifier. In other words, the learning algorithm for the stacking classifier will be responsible for deciding which extractor snapshots are the most useful ones for making accurate predictions on the test folds.

More specifically, given a pair \((S^{train}, S^{test})\) and an extractor \(h_{\Psi }\), we fine-tune \(h_{\Psi }\) on \(S^{train}\) with the same configuration used to obtain \(h_{\Psi ^j[S]}\), e.g., optimiser, learning rate, etc., and save snapshots \(h_{\Psi ^j[S^{train}]}\) at the same iterations as \(h_{\Psi ^j[S]}\). Logits \(L^j[S^{test}]\) are extracted from \(S^{test}\) with each \(h_{\Psi ^j[S^{train}]}\), i.e., \(L^j[S^{test}] = h_{\Psi ^j[S^{train}]}(S^{test})\). Using this approach, the two splits \(S_1\) and \(S_2\) can be used to alternately fine-tune extractors and produce logits \(L^j[S_1]\) and \(L^j[S_2]\), which are combined into \(L^j[CV]\), i.e., logits for every support set instance extracted using cross-validation. Considering the logits from all K extractors jointly, \(L_K^J[CV]\) is a tensor of shape \(N \times K \times J \times C\), i.e., N support instances converted into logits for C classes extracted by \(K \times J\) snapshot models, ready to serve as training data for the stacking classifier.
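
A sketch of this procedure is shown below. Stratified splitting via scikit-learn, the network builder callables, and the fine-tuning helper are assumptions made for illustration; the output matches the \(N \times K \times J \times C\) tensor described above.

```python
import numpy as np
import torch
from sklearn.model_selection import StratifiedKFold

def cross_validated_logits(network_builders, support_images, support_labels, finetune_fn, J):
    """Construct the N x K x J x C tensor of cross-validated logits.

    network_builders: one callable per extractor, each returning a freshly
        initialised fine-tuning network h_Psi (assumption for this sketch).
    finetune_fn(network, images, labels): fine-tunes the network and returns
        its J snapshot models (assumption for this sketch).
    """
    N = support_labels.numel()
    K = len(network_builders)
    C = int(support_labels.max()) + 1
    cv_logits = torch.zeros(N, K, J, C)

    labels_np = support_labels.cpu().numpy()
    skf = StratifiedKFold(n_splits=2, shuffle=True)
    for train_idx, test_idx in skf.split(np.zeros(N), labels_np):
        train_idx, test_idx = torch.as_tensor(train_idx), torch.as_tensor(test_idx)
        for k, build_network in enumerate(network_builders):
            snapshots = finetune_fn(build_network(),
                                    support_images[train_idx], support_labels[train_idx])
            with torch.no_grad():
                for j, snapshot in enumerate(snapshots):
                    # Logits of the held-out split, slotted into the combined tensor.
                    cv_logits[test_idx, k, j] = snapshot(support_images[test_idx])
    return cv_logits
```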

3.1.3 Stacking classifier training

Fig. 2

FES uses a global kernel to compute stacking classifier logits from the snapshots’ base logits. The global kernel is essentially flat since it makes no use of the snapshots’ temporal relations. For demonstration purposes, this figure and the following ones assume three extractors (\(K = 3\)), five fine-tuning snapshots per extractor (\(J = 5\)), and a two-class problem (\(C = 2\))

The FES stacking classifier is a weight matrix W of shape \(K \times J\), with \(W_k^j\) representing \(\Psi _k^j\)’s weight. Given an instance l of shape \(K \times J \times C\), the stacking classifier’s output logits \(l^W\) are obtained as a weighted sum:

$$\begin{aligned} l^W[c] = \sum _{k = 1}^K\sum _{j = 1}^JW_k^j \cdot l_k^j[c], \end{aligned}$$
(1)

where c is one of the C classes. We compute the cross-entropy loss using the N support set logits \(L^W\) output by the stacking classifier and the one-hot-encoded labels Y, i.e., \(-\sum \limits _{n = 1}^NY_n \log (\text {softmax}(L_n^W))\), which we minimise by training W. For interpretability, we constrain all values in W to be non-negative by clipping negative weights with ReLU. The FES stacking classifier is shown in Fig. 2.

After training, W is used with Eq. 1 to compute meta logits for the query set Q using the logits \(L_K^J[Q]\) computed by the saved snapshots \(\Psi _K^J[S]\). Then, a softmax function is used to obtain class probability estimates.
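
A compact sketch of training and applying the FES stacking classifier is shown below, using the LBFGS optimiser with line search and the small ridge term described in Sect. 4; the function names are ours, and the number of optimiser steps is a placeholder.

```python
import torch
import torch.nn.functional as F

def train_fes_classifier(cv_logits, support_labels, ridge=1e-2, init=1e-3):
    """Train the FES weight matrix W on cross-validated logits.

    cv_logits: tensor of shape (N, K, J, C); support_labels: tensor of shape (N,).
    """
    N, K, J, C = cv_logits.shape
    W = torch.full((K, J), init, requires_grad=True)
    optimiser = torch.optim.LBFGS([W], line_search_fn="strong_wolfe")

    def closure():
        optimiser.zero_grad()
        weights = F.relu(W)                                  # non-negativity via ReLU clipping
        # Eq. 1: weighted sum of base logits over extractors (k) and snapshots (j).
        meta_logits = torch.einsum("nkjc,kj->nc", cv_logits, weights)
        loss = F.cross_entropy(meta_logits, support_labels) + ridge * (W ** 2).sum()
        loss.backward()
        return loss

    for _ in range(10):   # a few LBFGS steps suffice for this convex problem
        optimiser.step(closure)
    return F.relu(W.detach())

def fes_predict(W, query_logits):
    """Apply Eq. 1 to query logits of shape (M, K, J, C) and return class probabilities."""
    return torch.softmax(torch.einsum("mkjc,kj->mc", query_logits, W), dim=1)
```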

3.2 Proof of convexity

Given a stacking instance l consisting of base logits obtained from the extractor snapshots, which the stacking classifier transforms into meta-level logits \(l^W\), and the label \(c_y\), the negative log-likelihood loss \(\ell\) associated with the stacking classifier’s parameters W is

$$\begin{aligned} \ell (W) = \log \left(\sum _{i = 1}^Ce^{l^W[c_i]}\right) - l^W[c_y]. \end{aligned}$$
(2)

To prove that optimising FES is a convex problem, we show that for any two values of W, named A and B, the loss at any convex combination of A and B never exceeds the corresponding convex combination of the losses at A and B, i.e.,

$$\begin{aligned} \ell (\lambda A + (1 - \lambda ) B) \le \lambda \ell (A) + (1 - \lambda )\ell (B), \lambda \in [0, 1]. \end{aligned}$$
(3)

Applying Eq. 2 to Eq. 3, we get

$$\begin{aligned}&\log \left(\sum _{i = 1}^Ce^{l^{(\lambda A + (1 - \lambda ) B)}[c_i]}\right) - l^{(\lambda A + (1 - \lambda ) B)}[c_y] \le \\&\lambda \left(\log \left(\sum _{i = 1}^Ce^{l^A[c_i]}\right) - l^A[c_y]\right) + (1 - \lambda )\left(\log \left(\sum _{i = 1}^Ce^{l^B[c_i]}\right) - l^B[c_y]\right), \end{aligned}$$

which can be simplified into

$$\begin{aligned} \log \left(\sum _{i = 1}^Ce^{l^{(\lambda A + (1 - \lambda ) B)}[c_i]}\right) \le \lambda \log \left(\sum _{i = 1}^Ce^{l^A[c_i]}\right) + (1 - \lambda )\log \left(\sum _{i = 1}^Ce^{l^B[c_i]}\right), \end{aligned}$$
(4)

because using Eq. 1, we have

$$\begin{aligned}&l^{(\lambda A + (1 - \lambda ) B)}[c_y] \\ =&\sum _{k = 1}^K\sum _{j = 1}^J(\lambda A_k^j + (1 - \lambda ) B_k^j) \cdot l_k^j[c_y] \\ =&\sum _{k = 1}^K\sum _{j = 1}^J\lambda A_k^j \cdot l_k^j[c_y] + \sum _{k = 1}^K\sum _{j = 1}^J(1 - \lambda ) B_k^j \cdot l_k^j[c_y] \\ =&\lambda \sum _{k = 1}^K\sum _{j = 1}^JA_k^j \cdot l_k^j[c_y] + (1 - \lambda )\sum _{k = 1}^K\sum _{j = 1}^JB_k^j \cdot l_k^j[c_y] \\ =&\lambda l^A[c_y] + (1 - \lambda ) l^B[c_y]. \end{aligned}$$

Similarly, Eq. 4 can be transformed using Eq. 1 into

$$\begin{aligned} \log \left(\sum _{i = 1}^Ce^{\lambda l^A[c_i] + (1 - \lambda ) l^B[c_i]}\right) \le \lambda \log \left(\sum _{i = 1}^Ce^{l^A[c_i]}\right) + (1 - \lambda )\log \left(\sum _{i = 1}^Ce^{l^B[c_i]}\right). \end{aligned}$$
(5)

It is known that the LogSumExp function \(LSE(x) = \log (\sum \limits _{i = 1}^ne^{x_i})\) is convex. Therefore, we have

$$\begin{aligned} \forall n \in \mathbb {Z}^+, \alpha , \beta \in \mathbb {R}^n: LSE(\lambda \alpha + (1 - \lambda ) \beta ) \le \lambda LSE(\alpha ) + (1 - \lambda ) LSE(\beta ). \end{aligned}$$
(6)

Hence, Eq. 5 is true because we can make the following assignments:

$$\begin{aligned} n&= C, \\ \alpha _i&= l^A[c_i],\\ \beta _i&= l^B[c_i]. \end{aligned}$$

This completes the proof of Eq. 3, and thus optimising FES on a single instance l is a convex problem. As the sum of convex functions is a convex function, optimising FES on a full batch L is also a convex problem. Therefore, FES is a convex optimisation problem.

3.3 Convolutional feature extractor stacking

Fig. 3

ConFES replaces the flat kernel of FES with a two-level kernel hierarchy. The base-level kernel is a one-dimensional depthwise, i.e., feature-extractor-wise, convolutional kernel, with predefined kernel and stride sizes. The high-level kernel is global like the one in FES but applied to the output of the base-level kernel, which requires substantially fewer parameters

The basic FES approach does not exploit the temporal relation between logits obtained from adjacent snapshots produced during fine-tuning. Convolutional FES (ConFES) replaces the global kernel of FES with a kernel hierarchy, as shown in Fig. 3, to treat the collection of logits as a time series. The hierarchy comprises one or more lower-level one-dimensional depthwise convolutional kernels and a top-level global kernel. The depthwise kernels condense the logit output sequence from each extractor’s snapshots into a 1D feature map, while keeping the extractors separate, and the global kernel summarises the feature maps produced by the lower-level kernels.

ConFES is motivated by the assumption that when each extractor is fine-tuned on the support set, it undergoes gradual changes between iterations, and the logits output by sequentially saved snapshots can be considered a time series. Therefore, 1D convolutions can be used to discern informative patterns in the time series data and compute feature maps, which are smaller in size than the raw logit time series, and therefore require fewer parameters in the global kernel than standard FES.

Given K extractors and J snapshots for each extractor, FES requires \(K \times J\) parameters. Assuming a two-level ConFES hierarchy, with a base-level convolutional kernel of size \(J_b\) and stride T, the feature map for each extractor will be of length \(J_m = \frac{J - J_b}{T} + 1\), leading to a global kernel size of \(K \times (\frac{J - J_b}{T} + 1)\). Including the \(K \times J_b\) parameters in the convolutional kernel, ConFES contains \(K \times (\frac{J - J_b}{T} + 1 + J_b)\) parameters. In practice, it can generally be assumed that \(J \gg 1\): a two-level ConFES architecture should be configured so that \(J \gg J_b \ge T \gg 1\) in order to cover all snapshots with significantly fewer parameters than FES.

ConFES utilises the sequential relation of each extractor’s snapshots through its lower-level 1D depthwise convolutional layers and has substantially fewer parameters than FES, making it less prone to overfitting. Note that Fig. 3 is simplified for demonstration purposes and does not reflect well that ConFES maintains fewer parameters; for a practical example of ConFES kernels, please refer to Fig. 12.
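
The sketch below implements a two-level ConFES stacking classifier with a depthwise 1-D convolution over the snapshot dimension. For simplicity, non-negativity is enforced on both kernels via ReLU, and the class and parameter names are ours.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConFESClassifier(nn.Module):
    """Two-level ConFES stacking classifier: a depthwise 1-D convolution over
    each extractor's snapshot sequence, followed by a global kernel."""

    def __init__(self, K, J, kernel_size, stride, init=1e-3):
        super().__init__()
        self.stride = stride
        J_m = (J - kernel_size) // stride + 1
        w0 = init ** 0.5  # two levels: the product of the initial weights is approx. init
        # Depthwise (feature-extractor-wise) kernel: one 1-D filter per extractor.
        self.depthwise = nn.Parameter(torch.full((K, 1, kernel_size), w0))
        # Global kernel applied to the K x J_m feature maps.
        self.global_kernel = nn.Parameter(torch.full((K, J_m), w0))

    def forward(self, logits):
        # logits: (N, K, J, C) base logits -> (N, C) meta logits.
        N, K, J, C = logits.shape
        x = logits.permute(0, 3, 1, 2).reshape(N * C, K, J)  # classes treated as batch entries
        fmap = F.conv1d(x, F.relu(self.depthwise), stride=self.stride, groups=K)
        fmap = fmap.reshape(N, C, K, -1)                     # (N, C, K, J_m)
        return torch.einsum("nckm,km->nc", fmap, F.relu(self.global_kernel))
```

With \(K = 8\), \(J = 41\), a kernel of size 9, and stride 4 (the configuration used in Sect. 4), this module has \(8 \times 9 + 8 \times 9 = 144\) parameters.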

3.4 Regularised feature extractor stacking

Fig. 4

ReFES uses the same global kernel as FES and applies fused lasso regularisation to the kernel’s training process. Fused lasso drives each individual weight towards zero with a regularisation strength of \(\lambda _1\) and applies depthwise smoothing to the weight matrix by penalising the weight difference between adjacent snapshots with a regularisation strength of \(\lambda _2\)

To combat overfitting, an alternative to reducing the number of parameters is to perform regularisation. Regularised FES (ReFES) introduces fused lasso regularisation (Tibshirani et al., 2005) to the stacking classifier used in FES, as shown in Fig. 4. Non-zero weights are penalised with a strength of \(\lambda _1\), and each feature-extractor-wise weight sequence is smoothed with a strength of \(\lambda _2\). The loss is a combination of cross-entropy loss and depthwise fused lasso loss, as formulated in Eq. 7, given K extractors, J snapshots per extractor, and a 2D global kernel W of shape \(K \times J\).

$$\begin{aligned} \ell = \ell _{\text {cross-entropy}} + \lambda _1 \sum _{k = 1}^K \sum _{j = 1}^J \Vert W_k^j\Vert + \lambda _2 \sum _{k = 1}^K \sum _{j = 1}^{J - 1} \Vert W_k^j - W_k^{j + 1}\Vert . \end{aligned}$$
(7)

In addition to encouraging sparse weights like standard lasso, fused lasso also encourages smaller differences between adjacent weights (Tibshirani et al., 2005). Each extractor’s snapshots are ordered by their fine-tuning iterations, and adjacent snapshots are likely to be similar. By applying fused lasso regularisation, differences between adjacent weights are penalised, and weight sequences are smoothed.
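
The combined objective of Eq. 7 can be written down directly. The sketch below reuses the logit aggregation of Eq. 1 and applies the ReLU clipping before computing the penalties, which is a simplification; the function name is ours.

```python
import torch
import torch.nn.functional as F

def refes_loss(W, cv_logits, support_labels, lambda1, lambda2):
    """Cross-entropy plus depthwise fused lasso penalty (cf. Eq. 7).

    W: (K, J) stacking weights; cv_logits: (N, K, J, C); support_labels: (N,)."""
    weights = F.relu(W)
    meta_logits = torch.einsum("nkjc,kj->nc", cv_logits, weights)
    ce = F.cross_entropy(meta_logits, support_labels)
    sparsity = weights.abs().sum()                               # lasso term, strength lambda1
    smoothness = (weights[:, 1:] - weights[:, :-1]).abs().sum()  # fused term over adjacent snapshots
    return ce + lambda1 * sparsity + lambda2 * smoothness
```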

The stratified two-fold splits \(S_1\) and \(S_2\) can be used to select appropriate \(\lambda _1\) and \(\lambda _2\) values for a few-shot episode. In the spirit of grid search with cross-validation, a ReFES stacking classifier is trained on the logits of one split, e.g., \(L_K^J[S_1]\), and tested on the logits of the other split, e.g., \(L_K^J[S_2]\). Different values for \(\lambda _1\) and \(\lambda _2\) can be explored and the best configuration selected based on the combined accuracy on the two folds. This configuration is then used to train a newly initialised ReFES stacking classifier on the full set of cross-validation logits \(L_K^J[CV]\), and this stacking classifier is used to label the query set instances Q based on their logits \(L_K^J[Q]\).

3.5 Handling single-instance classes

Meta-Dataset’s sampling scheme (Triantafillou et al., 2020) sometimes produces support sets containing single-instance classes. During cross-validation, single-instance classes need to be removed: if a class’s only instance is in the test split \(S^{test}\), then the training split \(S^{train}\) will have no instance of that class. FES and its variants can train their stacking classifiers on a subset of the support classes, containing \(C_{sub} \le C\) classes, because their kernels only encode the weights of the snapshots and are inherently independent of the number of classes C. In Figs. 2, 3, and 4, C can simply be replaced by \(C_{sub}\) during training.

Given a strict one-shot problem, where all classes exhibit exactly one instance, FES cross-validation is infeasible, as all classes need to be removed during cross-validation, leading to \(L_K^J[CV] = \varnothing\). Therefore, support logits obtained from ordinary fine-tuning need to be used in place of cross-validation logits, i.e., \(L_K^J[S]\) is used to train the FES classifier W instead of using \(L_K^J[CV]\).

4 Experimental setup

To evaluate FES and its variants on the Meta-Dataset benchmark described in Sect. 2.1, we use an extractor collection containing eight extractors, each independently pretrained on a Meta-Dataset source domain. In our primary set of experiments, all extractors are ResNet18 models (He et al., 2016) and identical to the source domain extractors used in the publication introducing URL (Li et al., 2021). Note that the extractors are trained on the training split of the source domain data only. The source domain validation split is used to select a trained checkpoint.

FES is compatible with any fine-tuning algorithm that is applicable to the individual extractors. In our experiments, we save a snapshot of each extractor before fine-tuning and save a snapshot after each iteration. We evaluate FES with three fine-tuning methods used by state-of-the-art CDFSL methods in the literature:

  • TSA (Li et al., 2022)—matrix residual adaptors attached to convolutional layers, and a fully-connected layer to project feature vectors.

  • URL (Li et al., 2021)—only a fully-connected layer to project feature vectors.

  • FLUTE (Triantafillou et al., 2021)—scaling and shift factors of batch normalisation layers.

When performing each fine-tuning method for FES, we use the hyperparameters as stated in the source publications, including optimiser type, learning rate, number of iterations, etc., and we compare FES to each source method. The URL (Li et al., 2021) and TSA (Li et al., 2022) papers fine-tune their feature extractors for 40 iterations, leading to 41 FES snapshots per extractor. The FLUTE (Triantafillou et al., 2021) paper fine-tunes its feature extractor for six iterations, leading to seven FES snapshots per extractor.

We adhere to the TSA, URL, and FLUTE papers when replicating and evaluating their methods as benchmarks. Pretrained universal extractors are obtained from the official repositories, and hyperparameter settings are consistent with the papers’ specifications. Note that both the URL and TSA papers used the same URL-distilled universal extractor, and their difference is in fine-tuning, i.e., only fine-tuning a feature projection (URL) or additionally fine-tuning convolutional channel projections (TSA).

We use an LBFGS optimiser to train the stacking classifier, applying its default hyperparameters in the PyTorch library (Paszke et al., 2019), except that we utilise its line search function. A ridge regularisation of strength \(1\textrm{e}^{-2}\) is applied to FES and ConFES to make the LBFGS optimiser more numerically stable. Adjusting the regularisation strength up or down by an order of magnitude does not substantially affect classification accuracy.

Table 1 Meta-Dataset episode statistics

Meta-Dataset’s sampling randomness may cause the accuracy of evaluated methods to fluctuate by one or two percentage points between runs, as also stated in the URL and TSA code repositories (Li et al., 2022). This fluctuation may exceed the 95% confidence interval of most results, so to eliminate it, we sample 600 episodes from each domain once, cache them, and use the same cached episodes to evaluate all CDFSL methods. For each dataset, the number of classes and the number of instances are randomly sampled for each episode, so different episodes can contain different numbers of classes and instances. Within an episode, the number of instances is randomly sampled for each class, so different classes can contain different numbers of instances and episodes can be class-imbalanced. However, the query set is stratified and always contains 10 instances per class.

Triantafillou et al. (2021) pointed out that Meta-Dataset instances need to be shuffled during sampling because some datasets have a particular ordering, e.g., traffic_sign contains consecutive frames from the same video. However, their shuffling solution was implemented as a moving window of size 1,000 over the stream of instances of each class, which we found to be potentially insufficient: it yields approximately 1% higher accuracy on mscoco and 3% higher accuracy on ChestX than true random sampling. We found that a window size of 10,000 yielded virtually the same level of accuracy as true random sampling, but nevertheless use true random sampling in our experiments, i.e., instances in each class are fully randomised and have an equal chance of being selected, and episodes are completely independent of each other. Statistics of our sampling run are shown in Table 1. Using exactly the same sampled episodes for every learning scheme compared also allows us to perform a paired t-test on a per-dataset basis, which is a more sensitive statistical test than simply comparing two algorithms’ mean accuracy and confidence intervals. In addition, we rank the algorithms and show their critical difference diagrams (Demsar, 2006) for weak and strong generalisation.

Considering the complexity of the optimisation problem when learning the stacking classifier, it is worth noting that the FES and ReFES stacking classifiers each maintain \(8 \times 41 = 328\) parameters if the extractors are fine-tuned for 40 iterations, and \(8 \times 7 = 56\) parameters if the extractors are fine-tuned for 6 iterations.

ConFES is applied with a two-level hierarchy, i.e., a low-level depthwise 1D convolutional kernel and a high-level global kernel. For 40-iteration fine-tuning, the convolutional kernel has size \(J_b = 9\) with stride \(T = 4\), leading to a feature sequence/global kernel of length 9. Consequently, ConFES has \(8 \times 9 + 8 \times 9 = 144\) parameters in total. For 6-iteration fine-tuning, the convolutional kernel has size 3 with stride 2, leading to a global kernel of length 3, and therefore ConFES contains \(8 \times 3 + 8 \times 3 = 48\) parameters in total. All parameters are initialised with a constant \((1\textrm{e}^{-3})^{\frac{1}{h}}\), where h is the number of hierarchical levels in the stacking classifier. Therefore, FES and ReFES are initialised with \(1\textrm{e}^{-3}\), and a two-level ConFES hierarchy is initialised with \((1\textrm{e}^{-3})^{\frac{1}{2}}\). This initialisation is deterministic and ensures that the product of weights from all levels is close to \(1\textrm{e}^{-3}\), which is small enough for optimisation to go in either direction, but also large enough to avoid exceedingly small derivatives in gradient-based optimisers.

To facilitate grid search for the \(\lambda _1\) and \(\lambda _2\) values of ReFES, a pool of eight potential values is provided for each hyperparameter: 1, \(1\textrm{e}^{-1}\), \(1\textrm{e}^{-2}\), \(1\textrm{e}^{-3}\), \(1\textrm{e}^{-4}\), \(1\textrm{e}^{-5}\), \(1\textrm{e}^{-6}\), and 0.

5 Results

We present CDFSL results of FES, ConFES, ReFES, and the competing methods URL, FLUTE, and a URL extractor with TSA fine-tuning, on the Meta-Dataset benchmark and show that FES and its variants advance the state of the art on this benchmark. We then visually analyse an example of trained FES, ConFES, and ReFES kernels. Lastly, we examine the ability of FES, ConFES, and ReFES to omit snapshots with their non-negative kernels.

Table 2 Meta-Dataset results with TSA fine-tuning
Table 3 Statistically significant number of wins of column algorithm over row algorithm using paired t-test results with TSA fine-tuning
Fig. 5

TSA weak generalisation critical difference diagram

Fig. 6

TSA strong generalisation critical difference diagram

Table 4 Meta-Dataset results with URL fine-tuning
Table 5 Statistically significant number of wins of column algorithm over row algorithm using paired t-test results with URL fine-tuning
Fig. 7

URL weak generalisation critical difference diagram (\(p > \alpha\))

Fig. 8

URL strong generalisation critical difference diagram

Table 6 Meta-Dataset results with FLUTE fine-tuning
Table 7 Statistically significant number of wins of column algorithm over row algorithm using paired t-test results with FLUTE fine-tuning
Fig. 9

FLUTE weak generalisation critical difference diagram

Fig. 10

FLUTE strong generalisation critical difference diagram

5.1 Meta-Dataset results

Results are organised by fine-tuning algorithms used, to provide a comparison between different CDFSL algorithms with the same fine-tuning scheme. The universal model of URL (Li et al., 2021), applied with TSA fine-tuning (Li et al., 2022), is the most recent and strongest CDFSL approach in the literature. Hence, we compare to this universal-model approach first, applying TSA fine-tuning in our FES methods as well in this comparison. Following that, we present experiments with the simpler fine-tuning approach used in the original URL (Li et al., 2021) paper. Finally, we evaluate FLUTE (Triantafillou et al., 2021) fine-tuning, which fine-tunes batch norm parameters only, and compare to the FLUTE universal template model.

Results with TSA fine-tuning are shown in Table 2, and paired t-test results based on the 600 individual accuracy values per dataset are shown in Table 3. Results with URL fine-tuning are shown in Tables 4 and 5, and those with FLUTE fine-tuning are shown in Tables 6 and 7.

In these tables, mean accuracy over 600 episodes and 95% confidence intervals are shown for each algorithm and dataset, and weak and strong generalisation accuracy and ranks averaged over all individual episodes are listed below the datasets. The best result of each row is shown in bold. If a paired t-test between a FES algorithm and the corresponding universal model/template (in the leftmost column) returns a p value less than 0.05, the null hypothesis (that there is no statistically significant difference) is rejected, and the FES result is marked with either \(\circ\) if it has higher accuracy, or \(\bullet\) if its competitor has higher accuracy.

The tables showing paired t-test results are split by weak generalisation (the eight source domains) and strong generalisation (the ten target domains). Each value indicates the number of datasets where the algorithm in the value’s column significantly outperforms the algorithm in its row according to the paired t-test.

Figures 5, 6, 7, 8, 9, and 10 are critical difference diagrams produced by the Nemenyi test applied with \(\alpha =0.05\), where algorithms are ranked using all relevant accuracy values (8 datasets \(\times\) 600 episodes for weak generalisation, and 10 datasets \(\times\) 600 episodes for strong generalisation). A Friedman test is first performed on all algorithms with the same \(\alpha\) to reject the null hypothesis. A Nemenyi test is then performed to group algorithms with no statistically significant difference into cliques via horizontal lines. Note that the Friedman p value is greater than \(\alpha\) for URL weak generalisation, i.e., Fig. 7, and the null hypothesis over all classifiers cannot be rejected in this case.

When using the same fine-tuning scheme, FES and its variants outperform their competitor CDFSL algorithms—building a universal model using knowledge distillation for URL and its TSA fine-tuning variant, and training a universal template with FiLM layers for FLUTE—in strong generalisation, where learning problems qualify as being cross-domain. The FES algorithms achieve better average accuracy and obtain more wins than losses in paired t-tests. They also rank higher than their competitors in the critical difference diagram.

Considering results with all three fine-tuning methods, the FES algorithms consistently outperform their competitors by a substantial margin on traffic_sign, CropDisease, and Food101, while being outperformed on cifar10 and cifar100. This phenomenon may indicate that FES and its variants perform better in domains that are more specialised, while their competitors gain an edge on datasets more similar to ImageNet, such as the CIFAR datasets. This speculation is supported by the fact that the competitor methods artificially attach greater importance to ImageNet when their universal models are obtained (Triantafillou et al., 2021; Li et al., 2021, 2022).

All three FES variants exhibit good CDFSL performance. Which variant is to be preferred depends on the specific use case: FES is the simplest and most versatile; ConFES has fewer parameters and therefore a smaller search space; and ReFES uses regularisation to achieve smoother and more interpretable snapshot selections.

Fig. 11

FES kernel for traffic_sign

Fig. 12

ConFES kernels for traffic_sign

Fig. 13

ReFES kernel for traffic_sign

5.2 Weight visualisation

Weights of the FES, ConFES, and ReFES kernels after fine-tuning with TSA on traffic_sign are visualised in Figs. 11, 12, and 13. The weights are averaged over 600 episodes.

ConFES maintains two kernels: a low-level depthwise 1D convolutional kernel (12a) and a high-level global kernel (12b). The two kernels can be expanded back into a global kernel (12c) for interpretation because the output of the convolutional kernel 12a serves as direct input to the global kernel 12b, without any intermediate non-linear activation. Figure 12 demonstrates how ConFES emulates a 328-parameter FES kernel with only 144 parameters. The stepped pattern in the expanded ConFES kernel, where every fourth snapshot is assigned relatively greater weight than its neighbours, is an artefact of 1D convolution—with a kernel size of 9 and a stride size of 4, this pattern results from kernel overlaps.

FES determines that the fine-tuned ilsvrc_2012 (ImageNet) and quickdraw extractors are the most prominent contributors to its predictions, indicated by the dark regions on the right end of these two extractors’ rows in Fig. 11. ConFES and ReFES arrive at similar conclusions regarding contributors, but exhibit characteristics that reflect their classifiers’ behaviours: ConFES shows stepped patterns due to 1D convolution as in Fig. 12; ReFES shows smoother weight changes due to fused lasso regularisation as in Fig. 13.

Additional heatmaps visualising kernel weights on the other target domains are in Appendix A, shown by Figs. 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39 and 40.

5.3 Snapshot omission

Table 8 Percentage of snapshots omitted by the stacking classifier

As FES kernels are constrained to be non-negative by clipping their weights with ReLU, some snapshots may have their corresponding weights set to 0 after clipping, which means logits from these snapshots do not contribute to the aggregated meta logits, and these snapshots can be omitted, i.e., they do not need to be saved and are not used for inference.

Table 8 shows the average percentage of snapshots omitted by a FES, ConFES, or ReFES stacking classifier, using TSA, URL, or FLUTE fine-tuning. Note that ConFES omission rates are computed using the expanded kernel, because zero values need to exist in the expanded kernel, instead of merely in one of ConFES’ hierarchical kernels, for the corresponding snapshots to be omitted. A higher omission percentage is considered better because omitting snapshots saves storage space and inference computation. Among the three methods, FES achieves the highest percentage of omission, generally between 60% and 80%, followed by ConFES, which achieves 30% to 70% omission in general, while ReFES achieves the least amount of omission, mostly below 40%. FES achieves higher omission rates than ConFES and ReFES, but trades off mean strong generalisation accuracy as shown in Tables 2, 4, and 6.
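
As a minimal sketch, the omission rate of a trained FES kernel (or an expanded ConFES kernel) can be computed as the fraction of exactly-zero weights after clipping; the function name is ours.

```python
import torch

def omission_rate(kernel: torch.Tensor) -> float:
    """Fraction of snapshots whose kernel weight is exactly zero after ReLU
    clipping, i.e. snapshots that never need to be stored or evaluated."""
    clipped = torch.relu(kernel)
    return (clipped == 0).float().mean().item()
```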

6 Ablation study

Table 9 Ablation results with TSA fine-tuning
Table 10 Ablation results with URL fine-tuning
Table 11 Ablation results with FLUTE fine-tuning

We perform an ablation study by removing cross-validation from the framework and/or using only the first or last snapshots in fine-tuning. When cross-validation is not used, training logits for the “stacking” classifier are extracted from the support set using snapshots fine-tuned on the entire support set, akin to how one-shot episodes are handled in Sect. 3.5. When using only the first or last snapshots, the stacking classifier is a degenerate weight kernel with a singleton dimension for fine-tuning iterations, simply containing one weight value for each extractor. Results are shown in Tables 9, 10, and 11, organised by the fine-tuning algorithm used.

The results show that methods using cross-validation outperform their counterparts without cross-validation. Moreover, using all snapshots achieves better performance than using only the first or last snapshots in terms of mean strong generalisation performance for URL and TSA fine-tuning. For strong generalisation with FLUTE fine-tuning, using only the last snapshots leads to better performance. This could be due to the smaller number of fine-tuning iterations performed by FLUTE, as the last snapshots constitute a more substantial part of all snapshots. It is worth noting that cross-validation is helpful even when only using the first snapshots before any fine-tuning because the training logits are computed using a nearest centroid classifier, and cross-validation keeps the support instances for logit extraction separate from those used to compute the centroids, hence avoiding instance re-use and reducing overfitting.

7 Heterogeneous extractors

Table 12 Results of replacing the ResNet18 ImageNet extractor with a Small EfficientNetV2 pretrained on the 21K version of ImageNet, while the other seven extractors remain the same
Table 13 Comparison between applying FES to an ImageNet-pretrained EfficientNetV2 extractor alone and applying FES to an extractor collection containing it and the seven other ResNet18 source domain extractors

FES and its variants operate in logit space, which means they are independent of the architecture and feature size of each extractor. Therefore, they can naturally work with heterogeneous extractor collections. We demonstrate this by replacing the ResNet18 ImageNet extractor in the source domain collection with a more advanced Small EfficientNetV2 model (Tan & Le, 2021) pretrained on the 21K-class version of ImageNet, while keeping the seven other source domain ResNet18 extractors unchanged. The Small EfficientNetV2 model produces feature vectors of length 1280, as opposed to feature vectors of length 512 generated by ResNet18.

URL-style fine-tuning is used, i.e., a square matrix initialised as an identity matrix is used for feature projection. The results are shown in Table 12 and are compared to the results obtained when all eight extractors are ResNet18 models. Using the EfficientNetV2 model consistently improves FES performance in both weak and strong generalisation. Note that the main purpose of this evaluation is to show FES compatibility with heterogeneous model zoos; its results are not directly comparable to the main results because the 21K-class ImageNet dataset used to pretrain the EfficientNetV2 model contains the Meta-Dataset ImageNet test split, which makes the ImageNet evaluation over-optimistic. Moreover, test classes in the other domains may also be present in the 21K pretraining classes.

Since the EfficientNetV2 model is much more advanced than ResNet18, we investigate whether it dominates the extractor collection and effectively makes the other ResNet18 extractors irrelevant by performing FES using the single EfficientNetV2 extractor, with results in Table 13. Interestingly, all three FES variants obtain very similar accuracy when applied to only one EfficientNetV2 extractor, while their differences are shown more clearly when applied to a collection of eight extractors. Although using only EfficientNetV2 leads to better performance in a number of ImageNet-adjacent domains, e.g., ilsvrc_2012, dtd, vgg_flower, mscoco, and cifar10, it under-performs in most other domains, especially those significantly different from ImageNet, e.g., omniglot, aircraft, quickdraw, fungi, traffic_sign, mnist, CropDisease, ISIC, and ChestX.

Our EfficientNetV2 evaluation indicates that 1) FES and its variants are compatible with heterogeneous extractor collections, and 2) they are robust to discrepancies in extractor architectures and able to select relevant models from a diverse model zoo.

8 Limitations and discussion

Table 14 ResNet152 feature extractors with URL fine-tuning
Table 15 ResNet152 feature extractors with TSA fine-tuning
Table 16 Computational resource consumption of FES variants using TSA fine-tuning, compared to the official TSA algorithm applied to a URL ResNet18 or ResNet152 extractor
Table 17 Comparing the official URL model to a URL model distilled without favouring ImageNet

FES requires no universal extractor, which means the meta-training phase only requires pretraining a collection of extractors, similar to SUR. This cost is reduced to zero if pretrained extractors are readily available. However, FES is more expensive in the meta-testing phase in terms of both computation and storage, as it needs to fine-tune each extractor and save its snapshots instead of utilising a single universal extractor. The good performance of FES could be attributed to this increased capacity: it maintains individual extractors instead of a single universal extractor. In the context of Meta-Dataset, FES maintains eight extractors, i.e., \(8\times\) the parameters of a universal model with the same architecture. Hence, in an additional experiment, we investigate larger universal models with capacities comparable to FES.

Originally, Li et al. (2021) distill eight ResNet18 extractors into a universal ResNet18 extractor. We distilled a universal ResNet152 (He et al., 2016) extractor using the same process; ResNet18 has 11 M parameters while ResNet152 has 60 M. We elected to use the same eight ResNet18 extractors for distillation: pretraining eight ResNet152 extractors from scratch is prohibitively expensive for us, and reusing the ResNet18 extractors avoids introducing a confounding factor into the meta-model evaluation, as different base-model architectures may capture source domain semantics differently. We also pretrained a universal ResNet152 model using “vanilla” multi-domain learning (MDL), i.e., one feature extractor is pretrained with all eight source domains’ data using eight classification heads, one for each domain. Compared to the official ResNet18 URL training, we halved the mini-batch size (and doubled the number of iterations) to fit ResNet152 URL or MDL training into the 48GB memory of an NVIDIA A6000 GPU (the most advanced at our disposal). Tables 14 and 15 show the results with URL and TSA fine-tuning, respectively, and compare them to using the official ResNet18 URL model, as well as FES variants with ResNet18 extractor collections. As TSA fine-tuning has high memory consumption, we omitted adaptors in the first and second convolutional blocks (shown to have a small impact on accuracy by Li et al. (2022)) to fit the ResNet152 TSA experiments on our NVIDIA A6000 GPU. In both tables, the ResNet152 URL model generally outperforms the ResNet18 URL and ResNet152 MDL models, and it achieves the best average weak generalisation accuracy. Its mean strong generalisation accuracy is comparable to that of the FES variants, but individual results show that the methods excel at different tasks: the ResNet152 URL model performs better on mscoco, cifar10, cifar100, EuroSAT, and Food101, while the FES methods perform better on traffic_sign, mnist, CropDisease, ISIC, and ChestX. It appears that the ResNet152 URL model is better at ImageNet-adjacent tasks, while the FES methods are better at tasks that differ more substantially from ImageNet.
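
For clarity, the sketch below illustrates the “vanilla” MDL setup described above: a single shared backbone with one linear classification head per source domain. It is only a schematic; the torchvision `resnet152` constructor and the per-domain class counts are assumptions, not the exact training code.

```python
import torch.nn as nn
from torchvision.models import resnet152

class VanillaMDL(nn.Module):
    def __init__(self, num_classes_per_domain):
        super().__init__()
        backbone = resnet152(weights=None)        # train from scratch
        feat_dim = backbone.fc.in_features        # 2048 for ResNet152
        backbone.fc = nn.Identity()               # drop the default head
        self.backbone = backbone
        # One classification head per source domain (eight for Meta-Dataset).
        self.heads = nn.ModuleList(
            nn.Linear(feat_dim, c) for c in num_classes_per_domain
        )

    def forward(self, x, domain_idx):
        # Instances from domain `domain_idx` are classified by that domain's head.
        return self.heads[domain_idx](self.backbone(x))
```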

Table 16 compares the cost of FES inference on an NVIDIA A6000 GPU to that of the URL ResNet18 and ResNet152 extractors; TSA fine-tuning is used by all methods in this table. It is worth pointing out that, due to the few-shot nature of each episode, meta-testing is generally not time-consuming. Table 16 represents an approximate upper bound on FES computation cost, because 1) the time presented in the table was measured using the largest traffic_sign episode in our cached sample, which contains 497 support instances, whereas smaller episodes consume less time, 2) URL and FLUTE fine-tuning are much less time-consuming than TSA, and 3) Sect. 5.3 shows that a portion of the snapshots does not in fact need to be computed and stored.

FES requires approximately \(2 \times K\) times as much backpropagation as a universal extractor fine-tuned once, where 2 accounts for one fine-tuning run on the cross-validated support set (performed in two splits) and another on the full support set, and K is the number of extractors. This is reflected in Table 16: fine-tuning time for the FES methods is approximately 16 times that of fine-tuning the URL ResNet18 model. The time required to train a FES or ConFES stacking classifier is comparatively trivial, while ReFES requires more time to determine its regularisation strength using grid search with cross-validation. FES stores multiple snapshots of each extractor during fine-tuning, but not all model parameters need to be saved: only weights that are updated during fine-tuning need to be included in a snapshot, as the unchanged weights can be loaded from the original extractor. Common CDFSL fine-tuning algorithms only update a relatively small set of weights: FLUTE fine-tunes batch normalisation weights, URL fine-tunes a feature projection, and TSA fine-tunes channel projections and a feature projection. Therefore, FES snapshots are normally lightweight. Table 16 shows that FES with TSA fine-tuning needs to store approximately 580 M parameters (2.32GB), which can fit in most modern GPUs during inference. As FES can fine-tune its extractors sequentially, its memory requirement is comparable to fine-tuning a single extractor with the same method. On the other hand, FES can easily be parallelised to fine-tune multiple extractors at once, should multiple GPUs be available.
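
A minimal sketch of this lightweight-snapshot idea is given below (assuming the fine-tuning method leaves `requires_grad=False` on all parameters it does not update); only the updated parameters are written to disk, and everything else is restored from the original extractor when a snapshot is loaded.

```python
import torch

def save_snapshot(model, path):
    # Persist only the parameters the fine-tuning algorithm updates
    # (e.g. batch-norm weights for FLUTE, projections for URL/TSA).
    updated = {name: p.detach().cpu().clone()
               for name, p in model.named_parameters() if p.requires_grad}
    torch.save(updated, path)

def load_snapshot(model, path):
    # `model` already holds the original pretrained weights; strict=False
    # overwrites only the fine-tuned subset stored in the snapshot.
    model.load_state_dict(torch.load(path), strict=False)
```

Depending on the fine-tuning method, buffers such as batch-normalisation running statistics may also change and would then need to be included in the snapshot.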

Considering the computational effort required for meta-training, it is worth noting that even though a universal extractor only needs to be trained once, this training process may take days (for ResNet18) to weeks (for ResNet152) on an NVIDIA A6000 GPU; whenever an individual extractor is added or updated, the universal extractor needs to be retrained.

The official URL model was distilled in a process favouring ImageNet: each mini-batch contains as many ImageNet instances as instances from the other seven source domains combined (Li et al., 2021). We distilled an alternative URL model while treating all source domains equally. The comparison is shown in Table 17. The official model performs better in a majority of domains. This indicates that URL distillation may require external knowledge to focus on the right domains to achieve optimal performance. FES and its variants treat all extractors equally a priori and determine their task-specific relevance based purely on the support set.
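
For illustration, the ImageNet-favouring mini-batch composition could look like the following sketch (in-memory lists stand in for the real data loaders; this is our paraphrase, not the official distillation code): half of every mini-batch is drawn from ImageNet and the remaining half is split evenly across the other seven source domains.

```python
import random

def imagenet_favouring_batch(imagenet_pool, other_pools, batch_size=64):
    # `imagenet_pool` is a list of ImageNet instances; `other_pools` is a list
    # of instance lists, one per remaining source domain (seven for Meta-Dataset).
    half = batch_size // 2
    batch = random.sample(imagenet_pool, half)    # as many ImageNet instances...
    per_domain = half // len(other_pools)         # ...as the others combined
    for pool in other_pools:                      # (rounding ignored in this sketch)
        batch += random.sample(pool, per_domain)
    return batch
```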

9 Future work

FES exhibits good CDFSL performance with multiple source domains. It may be feasible to generalise it to other multi-domain learning problems, e.g., multi-domain transfer learning with a more substantial amount of labelled target domain training data.

The heatmaps show that FES generally assigns significant weights to only a small subset of extractor snapshots, implicitly nullifying the majority of snapshots that it deems irrelevant. Pruning strategies could be applied to FES to explicitly eliminate irrelevant snapshots and thereby reduce computational cost.
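
A possible realisation of such pruning is sketched below (the `stacking_weights` matrix of learned FES coefficients, one row per snapshot, and the threshold are hypothetical): snapshots whose total weight mass is negligible are simply dropped from storage and inference.

```python
import torch

def prune_snapshots(snapshots, stacking_weights, threshold=1e-3):
    # stacking_weights: (num_snapshots, num_classes) learned coefficients.
    importance = stacking_weights.abs().sum(dim=1)    # per-snapshot weight mass
    keep = importance > threshold
    return [snap for snap, k in zip(snapshots, keep.tolist()) if k]
```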

FES maintains no prior bias towards any source domain extractor, and its posterior bias depends on the support set only. In scenarios where prior knowledge about the relation between source and target domains is available, it may be beneficial to let the user apply explicit prior biases to certain source domains. This could be achieved through regularisation, e.g., by applying different regularisation pressures to the weights associated with different source domains.
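
One hedged way to realise such a prior bias, assuming the same hypothetical `stacking_weights` layout as above plus a user-supplied penalty per source domain, is a domain-weighted L2 term added to the stacking classifier's loss:

```python
import torch

def domain_weighted_l2(stacking_weights, domain_of_snapshot, penalty_per_domain):
    # stacking_weights: (num_snapshots, num_classes); domain_of_snapshot maps
    # each snapshot index to its source domain; penalty_per_domain encodes the
    # user's prior trust (a smaller penalty expresses a stronger prior preference).
    penalty = stacking_weights.new_zeros(())
    for i, domain in enumerate(domain_of_snapshot):
        penalty = penalty + penalty_per_domain[domain] * stacking_weights[i].pow(2).sum()
    return penalty
```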

10 Conclusion

We present the stacking-based CDFSL method FES and the variants ConFES and ReFES. The FES algorithms create snapshots from fine-tuning independent extractors on the support set, use cross-validation to avoid overfitting from support data reuse, and train a simple stacking classifier to appropriately weight the snapshots. FES, ConFES, and ReFES advance the state-of-the-art on the Meta-Dataset benchmark.

Perhaps more importantly, the FES approaches have some practical advantages in real-world scenarios compared to recent methods based on universal models. FES can work with out-of-the-box heterogeneous extractors. If the extractors are readily available, FES does not require access to their pretraining data in any downstream step. Its stacking classifier requires little hyperparameter tuning. FES is also computationally cheaper, unless the number of few-shot learning tasks is very large, e.g., in the thousands, in which case the total cost of performing FES on all tasks begins to exceed that of training a universal model once. Therefore, for practitioners in the field who wish to use extractors and fine-tuning algorithms specific to their work, FES is likely more flexible and user-friendly than universal-model methods.