1 Introduction

Supervised learning models (e.g., Convolutional Neural Networks, CNNs) with appropriately learned parameters can generalize well to the test data [1, 2], under the assumption that both the training and test data are i.i.d. samples from a single joint probability distribution \(P(\varvec{x},y)\) of the features \(\varvec{x}\) and the class label y. While this is a reasonable assumption to make, it could be violated in real-world applications. For example, in object recognition, the training and test visual data can be collected from different joint distributions, each of which represents a certain combination of image content, image style, illumination, and background [3, 4]. Under such circumstances, even the powerful and expressive CNN models may fail to produce accurate predictions for the test data.

Domain generalization [5, 6], the science of generalizing a prediction model across different domains, aims precisely at addressing this non-i.i.d. supervised learning problem. Specifically, its goal is to train a prediction model by leveraging labeled data from multiple source domains, and to boost the generalization ability of this model in an unseen target domain. In line with the terminology of Muandet et al. [6], we refer to a domain as a joint probability distribution \(P(\varvec{x},y)\). So far, domain generalization has been studied in various applications, such as action recognition [7, 8], object recognition [9, 10], and medical diagnosis [11, 12].

A basic premise behind domain generalization is that the unseen target domain should somehow relate to the source ones. Otherwise, it may be impractical to perform cross-domain model generalization. From a statistical point of view, since a joint distribution \(P(\varvec{x},y)\) can be decomposed into \(P(\varvec{x},y)=P(\varvec{x})P(y \vert \varvec{x})\) or \(P(\varvec{x},y)=P(\varvec{x} \vert y)P(y)\), existing works generally assume that the source and target domains are related by a feature transformation, which makes the marginal distributions \(P(\varvec{x})\), the class-conditional distributions \(P(\varvec{x} \vert y)\), or the posterior distributions \(P(y \vert \varvec{x})\) similar among the domains [4, 6, 7, 13,14,15]. Based on this feature transformation and other assumptions on the distributions (e.g., the prior distribution P(y) being stable), the discrepancy between the source and target joint distributions is expected to decrease, and thus a source-trained prediction model can generalize well to the target domain. Muandet et al. [6] assumed that the posterior distribution is stable across domains in the problem of automatic gating of flow cytometry data, and learned a dimension reduction matrix to match multiple source marginal distributions under the distributional variance, a metric that relies on the kernel mean embedding technique [16, 17]. Li et al. [7] made the same assumption in computer vision, and generalized a classification model to an unseen visual domain via learning domain-invariant representations, which are produced by a neural network that matches the source marginal distributions in the activation space. However, from a causal learning perspective, reference [4] shows that for computer vision tasks such as object recognition, where the object classes are the causes of image features, not only the marginal distribution but also the posterior distribution can change across the visual domains. Accordingly, another line of works [3, 4, 18] proposes to match the class-conditional distributions among multiple source domains, and assumes that the prior distribution is stable. For instance, Conditional Invariant Deep Domain Generalization (CIDDG) [4] plays a minimax adversarial game between a network mapping and multiple discriminators to align the source class-conditional distributions in the activation space, which is also shown to be equivalent to matching these distributions under the generalized Jensen-Shannon (JS) divergence [19]. Based on a similar observation that the stability assumption of the posterior distribution may not hold, Zhao et al. [14] proposed to align the source posterior distributions as well as the source marginal distributions through adversarial training.

In this study, we address the domain generalization problem by exploiting a neural network model that consists of a network mapping and a probabilistic classifier. In contrast to prior works [3, 4, 7, 14], we characterize the domain relationship as a network mapping that makes the source and target joint distributions similar. Under this characterization, we learn the network mapping via matching the multiple source joint distributions to their mixture distribution, all of which are reflected by the accessible labeled source data, and learn a subsequent probabilistic model for target domain classification. After matching the distributions, we expect the network mapping to reduce the discrepancy between the source mixture joint distribution and the target joint distribution in the activation space, and consequently to boost the generalization ability of the probabilistic classifier in the target domain.

To be specific, we match the source joint distributions under the Kullback-Leibler (KL) divergence, and show that by introducing a domain label variable l, the problem of approximating this divergence can be transformed into the problem of estimating a domain label posterior distribution \(P(l \vert \varvec{x},y)\). We then model this discrete posterior distribution as multiple linear functions, and show that by minimizing the \(L^{2}\)-distance, the optimal parameters of these functions can be estimated analytically. As a result, we obtain an explicit estimate for the KL divergence. When matching distributions in the activation space, this explicit KL divergence estimate frees us from tackling challenging minimax problems (e.g., the ones in prior adversarial works [4, 14]), and allows us to solve a straightforward minimization problem similar to the ones in the kernel mean matching works [3, 6, 18]. Furthermore, to learn the downstream probabilistic classifier, we minimize the typical cross-entropy loss. Our overall optimization model is a minimization problem that jointly optimizes the parameters of the network mapping and the probabilistic classifier to minimize a combination of the estimated KL divergence and the cross-entropy loss. In the remainder of this paper, we refer to our solution as DGDE (Domain Generalization by Distribution Estimation). To summarize, our contributions are as follows.

  • We propose the DGDE solution for domain generalization, which trains a neural network by optimizing its parameters to minimize the estimated KL divergence among the source joint distributions in the activation space, and the cross-entropy loss of the probabilistic classifier.

  • We show that in our case the KL divergence can be approximated via estimating a domain label posterior distribution, and that the distribution estimation can be conducted in an analytic manner by appropriately selecting the loss function and the hypothesis space. This brings an explicit estimate for the KL divergence, allowing us to solve a simple and straightforward minimization problem when matching the distributions.

  • We demonstrate the effectiveness of DGDE on several real-world applications, including object recognition, action recognition, and face recognition.

2 Related work

We discuss the domain adaptation and domain generalization works that are most related to our solution.

2.1 Domain adaptation

Domain adaptation [20, 21] is closely related to domain generalization in the sense that it also aims at learning a prediction model from the source data and generalizing it to the related but differently distributed target data. The key difference between them is that in domain adaptation the unlabeled target data are available for training the prediction model, while in domain generalization they are only accessible when testing the model. In general, domain adaptation can be addressed by matching the distributions between domains and training a prediction model [22,23,24,25,26]. Cicek and Soatto [27] aligned the source and target class-conditional distributions, encouraged them to have disjoint support, and finally employed semi-supervised learning tools to improve the generalization ability of the classifier. Hu et al. [28] consistently aligned the marginal and class-conditional distributions between domains by constraining the gradient of marginal and class-conditional alignment to be synchronous. Yang et al. [29] proposed a bi-directional class-level adversaries framework for domain adaptation, which optimizes the bi-directional adversarial loss and the class-level discrepancy loss. Chen et al. [30] made use of the CNN to directly align the source and target joint distributions under the relative chi-squared divergence, and simultaneously learned a probabilistic classifier for classifying the target data.

Different from these domain adaptation methods, our domain generalization solution DGDE works under the setting where the unlabeled target data are not observed in advance, and improves the generalization ability of the neural network model in the unseen target domain by matching the source joint distributions.

2.2 Domain generalization

The problem of generalizing a prediction model from multiple source domains to an unseen target domain was first explored in machine learning and computer vision [5, 31]. Muandet et al. [6] formally introduced the terminology “domain generalization” for this problem, and generalized a source-trained classifier to an unseen target domain by proposing the Domain-Invariant Component Analysis (DICA) approach. In particular, DICA finds a feature transformation by minimizing the distributional variance among multiple source marginal distributions, while also preserving the functional relationship between input and output variables. Thereafter, matching the distributions of multiple source domains has become a fundamental solution to domain generalization [3, 4, 7, 14, 15, 32].

Scatter Component Analysis (SCA) [13] learns a projection matrix via minimizing the distributional variance among the source marginal distributions, as well as maximizing the separability of the classes and the separability of the unlabeled data. MMD-based Adversarial Auto-Encoder (MMD-AAE) [7] aligns the distributions of the coded source features via minimizing the Maximum Mean Discrepancy (MMD) [17], and simultaneously matches the aligned distribution to a prior Laplacian distribution via minimizing the chi-squared divergence between them. Building on the well-known adversarial work [22], Adversarial Feature Learning with Accuracy Constraint (AFLAC) [32] not only learns source domain invariant features, but also ensures that the domain invariance does not interfere with the classification accuracy. By contrast, Conditional Invariant Domain Generalization (CIDG) [18] finds a dimension reduction matrix to match the source class-conditional distributions, and the source class prior-normalized marginal distributions, both under the distributional variance. This approach is then extended to its end-to-end deep counterpart CIDDG [4], in which the projection matrix is replaced by the neural network mapping, and the distributional variance by the generalized JS divergence. Moreover, Domain Generalization via Entropy Regularization (DGER) [14] aligns the source marginal distributions, and further matches the source posterior distributions via entropy regularization.

Apart from distribution matching, domain generalization is also addressed in other manners [9, 33, 34]. Li et al. [8] designed an episodic training procedure to train a deep network in a way that exposes it to the distribution shift that characterizes a novel domain at runtime. Dou et al. [11] proposed to enforce semantic features via global class alignment and local sample clustering, with losses explicitly derived in an episodic learning procedure. Zhang et al. [34] proposed a disentangled learning framework for domain generalization, which separates semantic and variation representations into different subspaces while enforcing invariance constraints. Gao et al. [35] performed meta-learning to find a reusable white-box loss function, which is solved using the Implicit Function Theorem (IFT) to obtain gradients of the target domain performance with respect to the source domain loss parameters.

Our DGDE explores the distribution matching solution to domain generalization, but it differs substantially from previous attempts [4, 6, 13,14,15, 18] in this line. In particular, DGDE directly matches the source joint distributions for domain generalization, rather than separately matching their components (the marginal distributions, the class-conditional distributions, etc.), as practiced in [6, 14, 18]. Additionally, as a crucial building block of DGDE, the explicit KL divergence approximator (i.e., Eq. (17)), which is derived via a novel estimation of the domain label posterior distribution, enables our approach to match the distributions by solving a simple minimization problem, rather than the challenging minimax problems tackled in prior works [4, 14], which also leverage the KL divergence for distribution comparison.

3 Domain generalization by distribution estimation

In this section, we first describe the domain generalization problem and present our motivation. Following that, we elaborate on the estimation of the distribution \(P(l \vert \varvec{x},y)\) for divergence approximation in Sect. 3.1, and the estimation of the distribution \(P(y \vert \varvec{x})\) for classification in Sect. 3.2. Finally, we present the optimization model and the learning algorithm in Sect. 3.3. For clarity, we present in Table 1 an overview of the mathematical symbols used to describe our solution.

Table 1 Symbols and their descriptions

Let \(\mathcal {X}\) be an input feature space, \(\mathcal {Y}=\{1, \ldots , c\}\) be a class label space, and \(\mathcal {L}=\{1, \ldots , n\}\) be a domain label space. With random variables \(\varvec{x} \in \mathcal {X}\), \(y \in \mathcal {Y}\), and \(l \in \mathcal {L}\), we define a joint probability distribution for each domain l as \(P(\varvec{x},y \vert l)\). In domain generalization, the training data consist of n i.i.d. datasets \(\mathcal {D}^{1}=\{(\varvec{x}_{i}^{1},y_{i}^{1})\}_{i=1}^{m_{1}}\), \(\ldots\), \(\mathcal {D}^{n}=\{(\varvec{x}_{i}^{n},y_{i}^{n})\}_{i=1}^{m_{n}}\), which are respectively drawn from n related source domains \(P(\varvec{x},y \vert l=1)\), \(\ldots\), \(P(\varvec{x},y \vert l=n)\). Note that the union of these datasets can also be viewed as an i.i.d. set \(\mathcal {D} = \mathcal {D}^{1}\cup \ldots \cup \mathcal {D}^{n}=\{(\varvec{x}_{i},y_{i})\}_{i=1}^{m}\) sampled from the source mixture joint distribution \(P(\varvec{x},y)=\sum _{s=1}^{n}P(\varvec{x},y \vert l=s)P(l=s)\), where the number of samples \(m=m_{1}+\ldots +m_{n}\). Given these training data, the goal of domain generalization is to learn a classification model \(f: \mathcal {X} \rightarrow \mathcal {Y}\) that generalizes well to an unknown but related target domain. Namely, the model should accurately predict the labels of samples governed by the target joint distribution \(P(\varvec{x},y \vert l=t)\).
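To fix the notation, the following toy sketch (in PyTorch, which our implementation uses; all sizes and names here are illustrative stand-ins, not the experimental settings) pools n source datasets into the mixture sample \(\mathcal {D}\):

```python
import torch

n, d, c = 3, 800, 10                          # source domains, feature dim, classes
sizes = [120, 150, 100]                       # m_1, ..., m_n
X = [torch.randn(m_s, d) for m_s in sizes]    # features of D^1, ..., D^n (random stand-ins)
Y = [torch.randint(0, c, (m_s,)) for m_s in sizes]  # class labels

# The union D = D^1 u ... u D^n is an i.i.d. sample from the mixture
# P(x, y) = sum_s P(x, y | l = s) P(l = s), with m = m_1 + ... + m_n.
x_all, y_all = torch.cat(X), torch.cat(Y)
l_all = torch.cat([torch.full((m_s,), s + 1) for s, m_s in enumerate(sizes)])
m = x_all.shape[0]                            # m = 370 in this toy setup
```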

Fig. 1 The logic behind our solution to domain generalization

To ensure successful model generalization, it is crucial to exploit the relationship among domains. Here, we characterize the domain relationship as a neural network mapping F parameterized by \(\varvec{\Theta }_{F}\), which matches the source and target joint distributions in the activation space, i.e., Fig. 1-\(\textcircled {1}\). Note that this characterization is appropriate and similar ones have also been introduced in [4, 14, 15]. Since neither the target joint distribution nor its random samples are accessible, we learn this mapping via matching the n source joint distributions to their mixture distribution in the activation space (see Footnote 1), i.e., Fig. 1-\(\textcircled {2}\). We expect that such a mapping F can also generalize to the target joint distribution and make it similar to the source mixture joint distribution, i.e., Fig. 1-\(\textcircled {3}\). Under such circumstances, a source-trained probabilistic classifier \(f(\varvec{x})=\mathop {\textrm{argmax}}_{y \in \mathcal {Y}}P(y \vert F(\varvec{x};\varvec{\Theta }_{F});\varvec{\Theta }_{C})\), with \(\varvec{\Theta }_{C}\) being its parameters, generalizes well to the target domain, i.e., Fig. 1-\(\textcircled {4}\). Specifically, we exploit the KL divergence and the cross-entropy loss to respectively quantify the distribution discrepancy and the classification loss, and match the distributions and learn the classification model via minimizing a cost function in the form

$$\begin{aligned} \mathcal {L}(\varvec{\Theta }_{C},\varvec{\Theta }_{F})=\mathcal {L}_{\text {c}}(\varvec{\Theta }_{C},\varvec{\Theta }_{F}) + \gamma \mathcal {L}_{\text {d}}(\varvec{\Theta }_{F}). \end{aligned}$$
(1)

Here, \(\mathcal {L}_{\text {c}}(\varvec{\Theta }_{C},\varvec{\Theta }_{F})\) is the cross-entropy loss of the probabilistic classifier, \(\mathcal {L}_{\text {d}}(\varvec{\Theta }_{F})\) is the estimated KL divergence from the n source joint distributions to their mixture distribution in the activation space, and \(\gamma ~(>0)\) is a tradeoff parameter for balancing the two terms. Below, we detail the forms of \(\mathcal {L}_{\text {d}}(\varvec{\Theta }_{F})\) and \(\mathcal {L}_{\text {c}}(\varvec{\Theta }_{C},\varvec{\Theta }_{F})\).

Remark 1

In real-world scenes, the objects from source domains \(P(\varvec{x}, y \vert l = 1), \ldots , P(\varvec{x}, y \vert l = n)\) may be quite different due to different camera viewpoints, backgrounds, or lighting conditions. By mapping the data to the activation space using the network mapping F, such redundant information irrelevant to object recognition could probably be reduced, making the source joint distributions similar in the space, i.e., \(P(F(\varvec{x}), y \vert l = 1) \approx \cdots \approx P(F(\varvec{x}), y \vert l = n)\).

3.1 Distribution estimation for divergence approximation

The KL divergence from the n source joint distributions \(P(\varvec{x},y \vert l=1),\ldots , P(\varvec{x},y \vert l=n)\) to their mixture distribution \(P(\varvec{x},y)=\sum _{s=1}^{n}P(\varvec{x},y \vert l=s)P(l=s)\) is defined as (see Footnote 2)

$$\begin{aligned}&\sum _{s=1}^{n}\textrm{KL}\big (P(\varvec{x},y \vert l=s)\Vert P(\varvec{x},y)\big )\nonumber \\&=\sum _{s=1}^{n}\int P(\varvec{x},y \vert l=s)\log \frac{P(\varvec{x},y \vert l=s)}{P(\varvec{x},y)}d\varvec{x}dy. \end{aligned}$$
(2)

Clearly, Eq. (2) is non-negative and attains 0 when \(P(\varvec{x},y \vert l=1)=\cdots =P(\varvec{x},y \vert l=n)=P(\varvec{x},y)\).

In the following derivations, we show that the KL divergence in Eq. (2) can be expressed in terms of the domain label posterior distribution \(P(l \vert \varvec{x},y)\).

$$\begin{aligned}&\sum _{s=1}^{n}\textrm{KL}\big (P(\varvec{x},y \vert l=s)\Vert P(\varvec{x},y)\big ) \nonumber \\&\quad =\sum _{s=1}^{n}\int P(\varvec{x},y \vert l=s)\log \frac{P(l=s \vert \varvec{x},y)P(\varvec{x},y)}{P(\varvec{x},y)P(l=s)}d\varvec{x}dy \end{aligned}$$
(3)
$$\begin{aligned}&\quad =\sum _{s=1}^{n}\int P(\varvec{x},y \vert l=s)\log \frac{P(l=s \vert \varvec{x},y)}{P(l=s)}d\varvec{x}dy \end{aligned}$$
(4)
$$\begin{aligned}&\quad =\sum _{s=1}^{n}\int P(\varvec{x},y \vert l=s)\log P(l=s \vert \varvec{x},y)d\varvec{x}dy \nonumber \\&\qquad - \sum _{s=1}^{n}\int P(\varvec{x},y \vert l=s) \log P(l=s)d\varvec{x}dy \end{aligned}$$
(5)
$$\begin{aligned}&\quad =\sum _{s=1}^{n}\int P(\varvec{x},y \vert l=s)\log P(l=s \vert \varvec{x},y)d\varvec{x}dy \nonumber \\&\qquad - \sum _{s=1}^{n}\log P(l=s) . \end{aligned}$$
(6)

Equation (3) makes use of Bayes' rule and writes \(P(\varvec{x},y \vert l=s)\) as \(P(\varvec{x},y \vert l=s)=\frac{P(l=s \vert \varvec{x},y)P(\varvec{x},y)}{P(l=s)}\) for \(s \in \{1,2,\ldots , n\}\). Equation (4) cancels out the factor \(P(\varvec{x},y)\). Equation (5) expands the \(\log\) term in Eq. (4). Equation (6) holds since \(\int P(\varvec{x},y \vert l=s)d\varvec{x}dy=1\). For the first term in Eq. (6), we approximate the expectations with respect to the distributions \(P(\varvec{x},y \vert l=1), \ldots , P(\varvec{x},y \vert l=n)\) by the empirical averages over their samples \(\mathcal {D}^{1}, \ldots , \mathcal {D}^{n}\), and estimate the KL divergence as

$$\begin{aligned}&\sum _{s=1}^{n}\textrm{KL}\big (P(\varvec{x},y \vert l=s)\Vert P(\varvec{x},y)\big ) \nonumber \\&\quad \approx \sum _{s=1}^{n}\frac{1}{m_{s}}\sum _{i=1}^{m_{s}}\log P(l=s \vert \varvec{x}_{i}^{s},y_{i}^{s}) \nonumber \\&\quad - \sum _{s=1}^{n}\log P(l=s), \end{aligned}$$
(7)

where \((\varvec{x}_{i}^{s},y_{i}^{s}) \in \mathcal {D}^{s}\). As such, the problem of divergence approximation naturally transforms into the problem of estimating the domain label posterior distribution \(P(l \vert \varvec{x},y)\).
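As a quick sanity check on Eq. (7) (a toy example of our own, not drawn from the experiments): with two equally sized source domains whose samples are indistinguishable, the estimated posterior equals the prior \(P(l=s)=1/2\) everywhere, and the plug-in estimate correctly evaluates to zero.

```python
import math

n = 2                                        # two equally sized domains, so P(l=s) = 1/2
posts = [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]   # estimated P(l=s | x_i^s, y_i^s) per domain s
kl_hat = sum(sum(math.log(p) for p in ps) / len(ps) for ps in posts) \
         - n * math.log(1.0 / n)             # Eq. (7): mean log posteriors minus log priors
print(kl_hat)                                # 0.0: the domains coincide
```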

To estimate the discrete posterior distribution \(P(l \vert \varvec{x},y)\), we model it as multiple linear functions and learn the function parameters via minimizing the \(L^{2}\)-distance between distributions. As will be shown shortly, this choice leads to an analytic solution for the parameters and, consequently, an explicit estimate of the KL divergence. Note that with other choices, such as nonlinear functions or other divergences (distances) between distributions, an analytic solution may not be possible. In particular, let us model the discrete domain label posterior distribution \(P(l \vert \varvec{x},y)\) as:

$$\begin{aligned}&P(l=1 \vert \varvec{x},y;\varvec{\alpha }^{1}) = \sum _{i=1}^{m}\alpha _{i}^{1}p\big ((\varvec{x},y),(\varvec{x}_{i},y_{i})\big ), \end{aligned}$$
(8)
$$\begin{aligned}&\quad \cdots , \nonumber \\&\quad P(l=n \vert \varvec{x},y;\varvec{\alpha }^{n}) = \sum _{i=1}^{m}\alpha _{i}^{n}p\big ((\varvec{x},y),(\varvec{x}_{i},y_{i})\big ) , \end{aligned}$$
(9)

where \(p\big ((\varvec{x},y),(\varvec{x}_{i},y_{i})\big ) = k(\varvec{x},\varvec{x}_{i})\delta (y,y_{i})\) is a product of the feature and label kernels, and \(\varvec{\alpha }^{s} = (\alpha _{1}^{s}, \ldots , \alpha _{m}^{s})^{\top }~(s=1, \ldots , n)\) are the parameters of the functions. The feature kernel \(k(\varvec{x},\varvec{x}_{i})=\exp \Big (\frac{-\Vert \varvec{x}-\varvec{x}_{i}\Vert ^{2}}{\sigma }\Big )\) is the Gaussian kernel with positive kernel width \(\sigma\), and the label kernel \(\delta (y,y_{i})\) is the delta kernel that evaluates 1 if \(y=y_{i}\) and 0 otherwise. The linear-in-parameter functions from Eqs. (8) to (9) resemble the Radial Basis Function (RBF) networks and are reasonable choices for function approximation [36]. We learn parameters \(\varvec{\alpha }^{s}\) by matching \(P(l=s \vert \varvec{x},y;\varvec{\alpha }^{s})\) to the true distribution \(P(l=s \vert \varvec{x},y)\) under the \(L^{2}\)-distance:

$$\begin{aligned}&(\varvec{\alpha }_{\text {opt}}^{1}, \ldots, \varvec{\alpha }_{\text {opt}}^{n}) \nonumber \\&= \mathop {\textrm{argmin}}_{(\varvec{\alpha }^{1}, \ldots, \varvec{\alpha }^{n})}\Big (\int \sum _{s=1}^{n}\big (P(l=s \vert \varvec{x},y) \nonumber \\&~~~~~ - P(l=s \vert \varvec{x},y;\varvec{\alpha }^{s})\big )^{2} \times P(\varvec{x},y)d\varvec{x}dy\Big ) \end{aligned}$$
(10)
$$\begin{aligned}&= \mathop {\textrm{argmin}}_{(\varvec{\alpha }^{1}, \ldots, \varvec{\alpha }^{n})}\Big (\sum _{s=1}^{n}\int P(l=s \vert \varvec{x},y;\varvec{\alpha }^{s})^{2}P(\varvec{x},y)d\varvec{x}dy \nonumber \\&~~~~~ - 2\int P(l \vert \varvec{x},y;\varvec{\alpha }^{l})P(\varvec{x},y,l)d\varvec{x}dydl\Big ) . \end{aligned}$$
(11)

Here, the objective function in Eq. (10) is the \(L^{2}\)-distance between posterior distributions \(P(l=s \vert \varvec{x},y)\) and \(P(l=s \vert \varvec{x},y;\varvec{\alpha }^{s})\). Equation (11) expands the quadratic term in Eq. (10) and discards the constant \(\sum _{s=1}^{n}\int P(l=s \vert \varvec{x},y)^{2}P(\varvec{x},y)d\varvec{x}dy\). By approximating expectations via sample averages, we arrive at the empirical counterpart of Eq. (11):

$$\begin{aligned}&(\widehat{\varvec{\alpha }}^{1}, \ldots, \widehat{\varvec{\alpha }}^{n}) \nonumber \\&= \mathop {\textrm{argmin}}_{(\varvec{\alpha }^{1}, \ldots, \varvec{\alpha }^{n})}\Big (\frac{1}{m}\sum _{s=1}^{n}\sum _{i=1}^{m}P(l=s \vert \varvec{x}_{i},y_{i};\varvec{\alpha }^{s})^{2} \nonumber \\&~~~ -\frac{2}{m}\sum _{i=1}^{m}P(l_{i} \vert \varvec{x}_{i},y_{i};\varvec{\alpha }^{l_{i}}) + \lambda \sum _{s=1}^{n}\Vert \varvec{\alpha }^{s}\Vert ^{2}\Big ) \end{aligned}$$
(12)
$$\begin{aligned}&=\mathop {\textrm{argmin}}_{(\varvec{\alpha }^{1}, \ldots, \varvec{\alpha }^{n})}\Big (\frac{1}{m}\sum _{s=1}^{n}(\varvec{\alpha }^{s})^{\top }(\varvec{P}^{\top }\varvec{P})\varvec{\alpha }^{s} \nonumber \\&~~~ - \frac{2}{m}\sum _{s=1}^{n}\textbf{1}_{m_{s}}^{\top }\varvec{P}^{s}\varvec{\alpha }^{s} + \lambda \sum _{s=1}^{n}(\varvec{\alpha }^{s})^{\top }\varvec{\alpha }^{s}\Big ) \end{aligned}$$
(13)
$$\begin{aligned}&=\mathop {\textrm{argmin}}_{(\varvec{\alpha }^{1}, \ldots, \varvec{\alpha }^{n})}\sum _{s=1}^{n}\big ((\varvec{\alpha }^{s})^{\top }(\varvec{H}+\lambda \textbf{I}_{m})\varvec{\alpha }^{s} - 2(\varvec{b}^{s})^{\top }\varvec{\alpha }^{s}\big ) \end{aligned}$$
(14)
$$\begin{aligned}&= \big ((\varvec{H}+\lambda \textbf{I}_{m})^{-1}\varvec{b}^{1}, \ldots, (\varvec{H}+\lambda \textbf{I}_{m})^{-1}\varvec{b}^{n}\big ). \end{aligned}$$
(15)

In Eq. (12), a regularization term \(\lambda \sum _{s=1}^{n}\Vert \varvec{\alpha }^{s}\Vert ^{2}\) with regularization parameter \(\lambda ~(>0)\) is added to the empirical averages to avoid overfitting. Equation (13) writes the objective function in matrix form, where \(\varvec{1}_{m_{s}}\) is an \(m_{s}\)-dimensional column vector of ones, \(\varvec{P}^{s} \in \mathbb {R}^{m_{s} \times m}\), and \(\varvec{P} = \big ((\varvec{P}^{1})^{\top }, \ldots , (\varvec{P}^{n})^{\top }\big ) \in \mathbb {R}^{m \times m}\). The (i, j)-th element of \(\varvec{P}^{s}\) is defined as \(p_{ij}^{s} = p\big ((\varvec{x}_{i}^{s},y_{i}^{s}),(\varvec{x}_{j},y_{j})\big )\). Equation (14) introduces three notations, where \(\varvec{H} = \frac{1}{m}\varvec{P}^{\top }\varvec{P}\), \(\varvec{b}^{s} = \frac{1}{m}(\varvec{P}^{s})^{\top }\textbf{1}_{m_{s}}\), and \(\textbf{I}_{m}\) is the \(m \times m\) identity matrix. These notations make Eq. (14) an unconstrained quadratic optimization problem, whose analytic solution is presented in Eq. (15).
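To make the procedure concrete, below is a minimal PyTorch sketch of Eqs. (8)-(15): it builds the kernel matrix \(\varvec{P}\) from the product kernel and computes the analytic estimates \(\widehat{\varvec{\alpha }}^{s}=(\varvec{H}+\lambda \textbf{I}_{m})^{-1}\varvec{b}^{s}\). The function names are our own, and the pooled sample is assumed to be ordered by domain so that the rows of \(\varvec{P}\) stack the blocks \(\varvec{P}^{1}, \ldots , \varvec{P}^{n}\).

```python
import torch

def product_kernel(x1, y1, x2, y2, sigma):
    """p((x, y), (x', y')) = exp(-||x - x'||^2 / sigma) * delta(y, y')."""
    sq_dists = torch.cdist(x1, x2).pow(2)                  # pairwise squared distances
    delta = (y1.unsqueeze(1) == y2.unsqueeze(0)).float()   # delta label kernel
    return torch.exp(-sq_dists / sigma) * delta

def fit_domain_posteriors(x_all, y_all, sizes, sigma, lam):
    """Analytic estimates of Eq. (15): alpha^s = (H + lam * I)^{-1} b^s.

    x_all, y_all : pooled sample of size m, ordered by domain
    sizes        : [m_1, ..., m_n]
    """
    m = x_all.shape[0]
    P = product_kernel(x_all, y_all, x_all, y_all, sigma)  # (m, m), rows stack P^1, ..., P^n
    H = P.t() @ P / m                                      # H = P^T P / m
    A = H + lam * torch.eye(m)                             # H + lambda * I_m
    alphas, start = [], 0
    for m_s in sizes:
        b_s = P[start:start + m_s].sum(dim=0) / m          # b^s = (P^s)^T 1_{m_s} / m
        alphas.append(torch.linalg.solve(A, b_s))          # solve rather than invert
        start += m_s
    return P, torch.stack(alphas)                          # alphas: (n, m)
```

Because a probability distribution is non-negative and sums to 1, we normalize the estimated domain label posterior distribution as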

$$\begin{aligned}&P_{\text {nor}}(l=s \vert \varvec{x},y;\widehat{\varvec{\alpha }}^{s}) \nonumber \\&= \frac{\max \{10^{-8}, P(l=s \vert \varvec{x},y;\widehat{\varvec{\alpha }}^{s}) \}}{\sum _{j=1}^{n}\max \{10^{-8}, P(l=j \vert \varvec{x},y;\widehat{\varvec{\alpha }}^{j})\}}. \end{aligned}$$
(16)

Plugging this estimated distribution into Eq. (7), we obtain the KL divergence approximator:

$$\begin{aligned}&\sum _{s=1}^{n}\widehat{\textrm{KL}}\big (P(\varvec{x},y \vert l=s)\Vert P(\varvec{x},y)\big )\nonumber \\&= \sum _{s=1}^{n}\frac{1}{m_{s}}\sum _{i=1}^{m_{s}}\log P_{\text {nor}}(l=s \vert \varvec{x}_{i}^{s},y_{i}^{s};\widehat{\varvec{\alpha }}^{s}) \nonumber \\&~ - \sum _{s=1}^{n}\log P(l=s). \end{aligned}$$
(17)

According to the above derivations, the estimated KL divergence from the n source joint distributions to their mixture distribution in the activation space, \(\mathcal {L}_{\text {d}}(\varvec{\Theta }_{F})\), is therefore defined as

$$\begin{aligned}&\mathcal {L}_{\text {d}}(\varvec{\Theta }_{F})\nonumber \\&=\sum _{s=1}^{n}\frac{1}{m_{s}}\sum _{i=1}^{m_{s}}\log P_{\text {nor}}(l=s \vert F(\varvec{x}_{i}^{s};\varvec{\Theta }_{F}),y_{i}^{s};\widehat{\varvec{\alpha }}^{s}), \end{aligned}$$
(18)

where \(F(\varvec{x};\varvec{\Theta }_{F})\) denotes the activation features produced by the network mapping F. Note that since our goal is to minimize the estimated divergence by optimizing the network mapping, we drop the term \(\sum _{s=1}^{n}\log P(l=s)\), which is independent of the network mapping.
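Continuing the sketch above (again with illustrative names), the normalization of Eq. (16) and the divergence term \(\mathcal {L}_{\text {d}}\) of Eq. (18) might be computed as follows; here \(\varvec{P}\) is evaluated on the activation features \(F(\varvec{x};\varvec{\Theta }_{F})\), and the constant \(\sum _{s=1}^{n}\log P(l=s)\) is dropped as discussed.

```python
import torch

def divergence_loss(P, alphas, sizes, eps=1e-8):
    """Eqs. (16)-(18): estimated KL term L_d on the pooled sample.

    P      : kernel matrix computed on the activation features F(x)
    alphas : analytic estimates from Eq. (15), shape (n, m)
    """
    post = (P @ alphas.t()).clamp(min=eps)        # (m, n): max{1e-8, P(l=s | x_i, y_i)}
    post = post / post.sum(dim=1, keepdim=True)   # normalization of Eq. (16)
    loss, start = 0.0, 0
    for s, m_s in enumerate(sizes):
        # (1/m_s) * sum_i log P_nor(l=s | F(x_i^s), y_i^s)
        loss = loss + post[start:start + m_s, s].log().mean()
        start += m_s
    return loss
```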

3.2 Distribution estimation for classification

After matching distributions via the network mapping F, we learn a downstream probabilistic model for target domain classification. To be specific, we aim to estimate another posterior distribution \(P(y \vert F(\varvec{x};\varvec{\Theta }_{F});\varvec{\Theta }_{C})\), which is the final softmax output of the network. Following the common practice in [4, 32, 37], we exploit the cross-entropy loss to quantify the loss of this probabilistic model and define \(\mathcal {L}_{\text {c}}(\varvec{\Theta }_{C},\varvec{\Theta }_{F})\) as

$$\begin{aligned}&\mathcal {L}_{\text {c}}(\varvec{\Theta }_{C},\varvec{\Theta }_{F}) \nonumber \\&=\frac{-1}{m}\sum _{i=1}^{m}\sum _{j=1}^{c}\delta (y_{i}, j)\log P(y=j \vert F(\varvec{x}_{i};\varvec{\Theta }_{F});\varvec{\Theta }_{C}), \end{aligned}$$
(19)

where \(\delta (\cdot , \cdot )\) is the delta kernel function previously defined in Sect. 3.1.
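Since the delta kernel one-hot encodes the class label, Eq. (19) is exactly the standard multi-class cross-entropy of the softmax output. As a minimal illustration in PyTorch (with random stand-in tensors):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 5)                # classifier outputs for 8 samples, c = 5 classes
labels = torch.randint(0, 5, (8,))        # ground-truth class labels y_i
loss_c = F.cross_entropy(logits, labels)  # = -(1/m) sum_i log P(y_i | F(x_i))
```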

3.3 Optimization model and learning algorithm

Algorithm 1 The minibatch SGD optimization procedure of DGDE

Putting Eqs. (18) and (19) together, we present the optimization model of our DGDE solution as

$$\begin{aligned} \mathop {\textrm{min}}_{\varvec{\Theta }_{C},\varvec{\Theta }_{F}}\mathcal {L}(\varvec{\Theta }_{C},\varvec{\Theta }_{F})=\mathcal {L}_{\text {c}}(\varvec{\Theta }_{C},\varvec{\Theta }_{F}) + \gamma \mathcal {L}_{\text {d}}(\varvec{\Theta }_{F}). \end{aligned}$$
(20)

Note that our solution is general and can be implemented with either a shallow or a deep neural network model. As aforementioned, in the network model, the network mapping is parameterized by \(\varvec{\Theta }_{F}\), and the downstream probabilistic classifier is parameterized by \(\varvec{\Theta }_{C}\). In the experiments, we implement our DGDE with both shallow and deep neural network models to show its effectiveness.

We employ the minibatch Stochastic Gradient Descent (SGD) algorithm to solve Problem (20), and provide the pseudocode of the optimization procedure in Algorithm 1. In the algorithm, \(\nabla _{\varvec{\Theta }_{F}}\mathcal {L}(\varvec{\Theta }_{C},\varvec{\Theta }_{F})\) is the gradient with respect to \(\varvec{\Theta }_{F}\), \(\nabla _{\varvec{\Theta }_{C}}\mathcal {L}(\varvec{\Theta }_{C},\varvec{\Theta }_{F})\) is the gradient with respect to \(\varvec{\Theta }_{C}\), and \(\eta ~(>0)\) is the learning rate.
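As an illustration, below is a minimal PyTorch sketch of one iteration of Algorithm 1, reusing the helper sketches from Sect. 3.1. The module names `mapping` (implementing \(F(\cdot ;\varvec{\Theta }_{F})\)) and `classifier` (producing the class logits of the probabilistic classifier) are our own placeholders, and whether gradients flow through the analytic estimates \(\widehat{\varvec{\alpha }}^{s}\) is an implementation choice; here they are held fixed within the iteration.

```python
import torch
import torch.nn.functional as F

def train_step(mapping, classifier, batches, optimizer, gamma, lam, sigma):
    """One minibatch SGD step for Problem (20); `batches` holds one
    (x, y) minibatch per source domain."""
    feats = [mapping(x) for x, _ in batches]          # activation features F(x)
    sizes = [f.shape[0] for f in feats]
    f_all = torch.cat(feats)
    y_all = torch.cat([y for _, y in batches])

    # L_c: cross-entropy of the probabilistic classifier, Eq. (19)
    loss_c = F.cross_entropy(classifier(f_all), y_all)

    # L_d: explicit KL estimate in the activation space, Eqs. (8)-(18);
    # the alphas are obtained in closed form, so no inner minimax loop is needed
    P, alphas = fit_domain_posteriors(f_all, y_all, sizes, sigma, lam)
    loss_d = divergence_loss(P, alphas.detach(), sizes)

    loss = loss_c + gamma * loss_d                    # Eq. (20)
    optimizer.zero_grad()
    loss.backward()                                   # gradients w.r.t. Theta_F and Theta_C
    optimizer.step()
    return loss.item()
```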

4 Experiments

Below, we evaluate our DGDE solution on 6 real-world visual datasets (see Fig. 2), which are popular in the domain generalization and domain adaptation literature [7, 14, 38, 39]. We start by describing the datasets in Sect. 4.1, then introduce the experimental setup in Sect. 4.2, present the experimental results in Sect. 4.3, and finish with the empirical analysis in Sect. 4.4. Our solution is implemented using PyTorch (see Footnote 3), and the experiments are run on a PC equipped with an NVIDIA RTX 3090 GPU and 24 GB of RAM.

4.1 Datasets

We first summarize the statistics of the datasets in Table 2, and then describe each dataset in the following.

Table 2 Statistics of the visual datasets
Fig. 2 Example images from the 6 datasets. a IXMAS [40]. b Office-Caltech [41]. c PIE-Multiview [39]. d VLCS [42]. e PACS [43]. f Office-Home [44]

IXMAS [40] is a cross-view action recognition dataset. It contains videos of 11 human actions recorded from 5 different views (domains): View0 (V0), View1 (V1), View2 (V2), View3 (V3), and View4 (V4). Following prior works [7, 8, 13], we keep the first 5 actions and exclude the irregular ones, resulting in 91 image samples in each domain. See Fig. 2(a) for the example images.

Office-Caltech [41] contains 4 different visual object datasets: Amazon (A), Caltech (C), DSLR (D), and Webcam (W), which are acquired in different environments and share 10 object categories. In the experiments, each dataset is regarded as a domain, and the number of samples in each domain is 958, 1123, 157, and 295, respectively. See Fig. 2(b) for the example images.

PIE-Multiview [39] is a face recognition dataset containing face images of 67 individuals captured from different views, illumination conditions, and expressions. This dataset has 6 subsets (domains): looking forward (C27), looking downward (C09), and looking towards the left at increasing angles (C05, C37, C25, C02). These 6 domains respectively contain 1404, 1407, 1407, 1404, 1407, and 1407 face images. See Fig. 2(c) for the example images.

VLCS [42] contains images from 4 well-known datasets (domains): VOC2007 (V) [45], LabelMe (L) [46], Caltech-101 (C) [47], and SUN09 (S) [48]. These domains (V, L, C, S) share 5 categories (i.e., bird, car, chair, dog, and person) and have 3376, 2656, 1415, and 3282 image samples in each of them. See Fig. 2(d) for the example images.

PACS [43] is composed of 4 subsets corresponding to 4 different image styles: ArtPainting (A), Cartoon (C), Photo (P), and Sketch (S), with images from 7 classes: dog, elephant, giraffe, guitar, house, horse, and person. In the experiments, each image style is viewed as a domain, and the number of images in each domain is 2048, 2344, 1670, and 3929, respectively. See Fig. 2(e) for the example images.

Office-Home [44] is a large visual recognition dataset that comprises 4 domains: Art (A, artistic depictions of objects), Clipart (C, clipart images), Product (P, objects without a background) and RealWorld (R, objects captured with a regular camera). There are 65 categories shared by these domains, and the number of images in each domain is 2421, 4379, 4428, and 4357, respectively. See Fig. 2(f) for the example images.

4.2 Experimental setup

4.2.1 Comparison methods

Our DGDE solution is general and can be implemented with both shallow and deep neural networks. For completeness, we therefore respectively compare our shallow and deep implementations against existing shallow and deep domain generalization methods. To be specific, the shallow competitors include DICA [6], SCA [13], CIDG [18], and Multidomain Discriminant Analysis (MDA) [3]. Since these are dimensionality reduction methods, we enable their classification ability via appending a softmax classifier to them. The end-to-end deep competitors encompass Deeper, Broader and Artier Domain Generalization (DBADG) [43], CIDDG [4], Cross-Gradient (CrossGrad) [49], Jigsaw puzzle based Generalization (JiGen) [9], Deep Domain-Adversarial Image Generation (DDAIG) [50], Mixture of Multiple Latent Domains (MMLD) [51], EISNet [52], Representation Self-Challenging (RSC) [53], DGER [14], Domain Invariant Representation learning with domain Transformations via Generative Adversarial Networks (DIRT-GAN) [15], Adversarial Teacher-Student Representation Learning (ATSRL) [54], Stochastic Weight Averaging Densely (SWAD) [55], and Proxy-based Contrastive Learning (PCL) [56].

4.2.2 Evaluation protocol

We run the domain generalization methods on n source domains, i.e., a collection of n image subsets here, and then employ them to classify the samples from an unseen target domain, i.e., a held-out image subset. This is also known as the leave-one-domain-out evaluation protocol [4, 8, 11, 57]. Particularly, following [7, 39, 58], we run the shallow methods on the IXMAS dataset with the 5000-dimensional dense trajectories features [13], on the Office-Caltech dataset with the handcrafted 800-dimensional SURF features [41], and on the PIE-Multiview dataset with the 1024-dimensional gray-scale image pixel features [39]. Moreover, we run the deep end-to-end methods on the VLCS dataset with the AlexNet [1] backbone following [4, 14, 15, 43], and on the PACS and Office-Home datasets with the ResNet50 [2] backbone following [14, 53, 54, 56]. To ensure fair comparison with prior deep results, on the VLCS dataset we randomly divide each domain into a training set (70%) and a test set (30%), and evaluate on the test set of the held-out target domain [14]. On the PACS and Office-Home datasets, we split the data from the source domains into a 9 (train) : 1 (val) ratio and test on the whole held-out target domain [54]. Finally, following the general practice in domain generalization, the multi-class classification accuracy (%) on the target domain is adopted as the performance metric for all the methods.

4.2.3 Implementation details

On datasets IXMAS, Office-Caltech, and PIE-Multiview, we implement our DGDE solution using one-Hidden-Layer Neural Networks (1HLNN) with 2500, 400, and 512 hidden neurons, respectively. Thus, on each dataset the number of hidden neurons is half the number of input neurons. These shallow networks with the ReLU activation are trained from scratch by minibatch SGD, with the learning rate \(\eta\) set to \(10^{-3}\) on IXMAS, and to \(10^{-2}\) on Office-Caltech and PIE-Multiview. Besides, the tradeoff parameter \(\gamma\) is selected from the range \(\{10^{-2}, 10^{-1}, 1, 10^{1}, 10^{2}\}\) via cross-validation on the training data. On datasets VLCS, PACS, and Office-Home, we start training our solution with the CNN backbones from their ImageNet pre-trained models, following [14, 56]. The optimizer is still minibatch SGD with the learning rate \(\eta =10^{-3}\). Furthermore, \(\gamma\) is not selected through a grid search as in the experiments with shallow networks, since the corresponding procedure would be computationally costly. Instead, following [22], \(\gamma\) is initialized at 0 and gradually increased to 1 using the formula \(\gamma =\frac{2}{1+\exp (-10\times \textrm{iter}/\mathrm {total\_iter})}-1\), where \(\textrm{iter}\) is the current iteration count and \(\mathrm {total\_iter}\) is the total number of iterations. For all the experiments, a minibatch in every iteration of the SGD algorithm consists of n minibatches, which are randomly sampled from the n source domains. The Gaussian kernel width \(\sigma\) is set to the median squared distance between the training data, in a manner similar to [59]. The regularization parameter \(\lambda\) is fixed at \(10^{-2}\); its sensitivity analysis is presented in Sect. 4.4.4. Additionally, following [11, 14, 57], we repeat the experiment on each task with different random seeds, and report the average classification accuracy of our approach over 5 independent runs.
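For reference, a sketch of the two heuristics just described (function names are ours):

```python
import math
import torch

def gamma_schedule(it, total_it):
    """Ramp gamma from 0 to 1 as in [22]: 2 / (1 + exp(-10 * it / total_it)) - 1."""
    return 2.0 / (1.0 + math.exp(-10.0 * it / total_it)) - 1.0

def median_heuristic(x):
    """Gaussian kernel width sigma: median squared pairwise distance of the data."""
    sq_dists = torch.cdist(x, x).pow(2)
    return sq_dists[sq_dists > 0].median().item()   # ignore the zero diagonal
```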

4.3 Experimental results

4.3.1 Results of shallow domain generalization methods

Table 3 Classification accuracy (%) of shallow domain generalization methods on dataset IXMAS. Under the leave-one-domain-out evaluation protocol, the names of the source domains are omitted
Table 4 Classification accuracy (%) of shallow domain generalization methods on dataset Office-Caltech
Table 5 Classification accuracy (%) of shallow domain generalization methods on dataset PIE-Multiview

We report the experimental results of the shallow domain generalization methods on datasets IXMAS, Office-Caltech, and PIE-Multiview in Tables 3, 4, and 5, respectively. In every table, the names of the source domains are omitted under the leave-one-domain-out evaluation protocol. For every column in the table, the best result is highlighted in bold.

The results in Tables 3, 4, and 5 show that on the majority of the tasks, our DGDE solution with shallow implementation performs significantly better than the comparison methods. In particular, DICA and SCA do not perform well since they only match the marginal distributions and neglect the posterior distributions, which are critical in domain generalization [14, 15]. By contrast, our DGDE solution matches the joint distributions, which subsumes matching both the marginal and the posterior distributions. Therefore, our solution naturally outperforms its competitors. Note that to construct the domain generalization tasks from PIE-Multiview, i.e., Table 5, we follow the popular leave-one-domain-out evaluation protocol [4, 8, 11, 57]. However, on this dataset with 6 domains, we are also curious about how our solution behaves under another protocol of leaving several domains out. Hence, we explore this later in Sect. 4.4.2.

4.3.2 Results of deep domain generalization methods

Table 6 Classification accuracy (%) of deep domain generalization methods with the AlexNet backbone on dataset VLCS. Under the leave-one-domain-out evaluation protocol, the names of the source domains are omitted
Table 7 Classification accuracy (%) of deep domain generalization methods with the ResNet50 backbone on dataset PACS
Table 8 Classification accuracy (%) of deep domain generalization methods with the ResNet50 backbone on dataset Office-Home

We report the experimental results of the deep domain generalization methods on datasets VLCS, PACS, and Office-Home in Tables 6, 7, and 8, respectively. Note that the results of the deep comparison methods are all cited from prior works, since the corresponding experimental settings are the same. To be specific, the results of the comparison methods in Table 6 are quoted from [14, 15, 60], the results in Table 7 from [14, 54, 56], and the results in Table 8 from [54, 56].

From the results in Tables 6, 7, and 8, we observe that on most tasks, our solution with deep implementation again outperforms the comparison methods, some of which are complex methods with a large set of model parameters that require heavy tuning (e.g., MASF, DGER). For example, in Table 7, our DGDE performs much better than MASF and DGER on every task, and also yields superior results to the most recent method PCL. We believe that in domain generalization, where one does not have much knowledge about the target domain, the solution should aim at addressing the most critical problem. Our DGDE is targeted exactly at the critical joint distribution mismatch problem in domain generalization [14, 15], and therefore produces better results than the other methods. Together with the previous shallow results, we remark that matching the source joint distributions under our KL divergence approximator is indeed effective for addressing the domain generalization problem. Moreover, our solution works well with both shallow models and end-to-end deep architectures.

4.4 Empirical analysis

4.4.1 Model assumption

Fig. 3 Estimated KL divergence values computed using the activation features from the Baseline and our DGDE networks, which are trained on the tasks with the target domains being DSLR and Webcam. a Divergence values from the source joint distributions to their mixture distribution. b Divergence values from the source mixture joint distribution to the target joint distribution

We examine the model assumption behind our solution to better understand why it can improve the generalization ability of a neural network in the domain generalization problem. To this end, we first check the estimated KL divergence from the source joint distributions to their mixture distribution, using the activation features from the Baseline and our DGDE networks. Here, the Baseline trains a normal neural network by simply minimizing the cross-entropy loss, i.e., the one defined in Eq. (19). The Baseline and DGDE networks are trained on the domain generalization tasks from dataset Office-Caltech, with the target domains being DSLR and Webcam. The results are illustrated in Fig. 3(a). We observe from this figure that the two divergence values of DGDE are much smaller than those of the Baseline, indicating that our DGDE network mappings (i.e., the hidden layers) are indeed effective in reducing the discrepancy among multiple source joint distributions in the activation space. More importantly, we then check the estimated KL divergence from the source mixture joint distribution to the target joint distribution, and depict the results in Fig. 3(b). From Fig. 3(b), we again observe that the two divergence values of DGDE are below those of the Baseline. This suggests that although the differently distributed target data are not available when training our DGDE networks, their mappings successfully generalize to the related target joint distributions on the tested tasks, and narrow the distribution gap between source and target. Consequently, our DGDE networks perform well in the target domain classification tasks.

4.4.2 Leave several domains out

Table 9 Classification accuracy (%) of shallow domain generalization methods on the PIE-Multiview dataset under the leave-several-domains-out evaluation protocol

We study our solution under the interesting evaluation protocol of leave-several-domains-out. Namely, we first run our DGDE on the same set of source domains, and then observe how the resulting network model generalizes to several different but related target domains. In particular, we employ the PIE-Multiview dataset with 6 domains, and construct 3 domain generalization tasks with the same set of source domains {C09, C05, C25} and the different target domains C27, C37, and C02. Note that such a construction ensures that the 3 different target domains are all related to the source ones, and hence makes cross-domain model generalization possible. Subsequently, we run our DGDE and other shallow domain generalization methods on {C09, C05, C25}, and report in Table 9 the classification accuracy in target domains C27, C37, and C02. Clearly, the results in Table 9 show that our DGDE network (a single network) manages to generalize well to different target domains and consistently outperforms the other domain generalization methods. This confirms that our solution is also effective under this new evaluation protocol for domain generalization.

4.4.3 Further comparison

Table 10 Comparison of the methods in terms of their parameters (\(>0\))

We compare our solution with other methods in terms of their parameters and present the results in Table 10. We note that our solution has the same number of parameters as most of the comparison methods. Specifically, as aforementioned, we set the value of our parameter \(\gamma\) by cross-validation or by following the strategy in [22], and fix the other parameter \(\lambda =10^{-2}\), whose sensitivity analysis is presented in Sect. 4.4.4.

Table 11 Comparison of the methods in terms of their execution time (seconds)

We compare our solution with others in terms of their execution time (seconds). Table 11 reports the comparison results on the domain generalization task where the target domain is Amazon (Office-Caltech). The results are recorded under the same environment. According to Table 11, while our DGDE solution is not computationally more efficient than the comparison methods, its execution time is within a reasonable range, considering its superior performance.

4.4.4 Parameter sensitivity

We investigate the sensitivity of DGDE with respect to different choices of its parameter \(\lambda\), which is kept fixed at \(10^{-2}\) in the main experiments. To this end, we run the sensitivity experiments on the domain generalization tasks where the target domains are DSLR (Office-Caltech), C25 (PIE-Multiview), and Sketch (PACS), with \(\lambda\) varying in the range \(\{10^{-4}, 10^{-3}, \cdots , 10^{3}\}\). Figure 4 depicts the classification accuracy of DGDE versus different choices of \(\lambda\) on these tasks. From Fig. 4, we observe that on different tasks from different datasets, DGDE attains its best performance when \(\lambda\) is around \(10^{-2}\). Therefore, as a general guideline for choosing this parameter, we suggest fixing \(\lambda =10^{-2}\).

Fig. 4 Parameter sensitivity of DGDE on domain generalization tasks from different datasets

4.4.5 Feature visualization

Fig. 5 t-SNE visualization of source and target data in the activation spaces of DDAIG, SWAD, PCL, and DGDE. The source domains are Clipart, Product, and RealWorld, and the target domain is Art. a DDAIG, b SWAD, c PCL, d DGDE

We exploit the t-SNE visualization tool [61] and visualize in Fig. 5(a)-(d) the source and target data from the activation spaces of DDAIG, SWAD, PCL, and DGDE. The domain generalization task has the source domains Clipart, Product, and RealWorld, and the target domain Art from the Office-Home dataset. Comparing Fig. 5(d) against Fig. 5(a)-(c), we observe that our DGDE solution better aligns the source and target data in the network activation space than its competitors DDAIG, SWAD, and PCL. These visualization results show that our joint distribution matching under the KL divergence is a powerful solution to domain generalization.

5 Conclusion

In this paper, we propose the DGDE solution to domain generalization, which minimizes the discrepancy among the source joint distributions to match them in the neural network activation space, and optimizes the probabilistic classifier for target domain classification. Specifically, we quantify the distribution discrepancy using the KL divergence, and derive a novel KL divergence approximator via estimating the domain label posterior distribution. This distribution is estimated using multiple linear-in-parameter functions and the \(L^{2}\)-distance, leading to an analytic solution for the parameters. In the experiments, we implement the proposed DGDE solution with both shallow and deep neural network models, and demonstrate the effectiveness of these implementations on several publicly available datasets. As future work, we intend to explore the application of our DGDE solution to the problem of RGB-Infrared cross-modality person ReID (RGB-IR ReID) [62], where the RGB and infrared (IR) images come from different but related modalities.