
1 Introduction

Standard supervised learning has shown impressive results when training and test samples follow the same distribution. However, many real-world applications do not conform to this setting, so that research successes do not readily translate into practice [20]. Domain Generalization (DG) addresses this problem: it aims at training models that generalize well under domain shift. In contrast to domain adaptation, where a few labeled and/or many unlabeled examples are provided for each target test domain, in DG absolutely no data is available from the test domains’ distributions, making the problem unsolvable in general.

In this work, we view the problem of DG specifically using ideas from causal discovery. This viewpoint makes the problem of DG well-posed: we assume that there exists a feature vector \(h^\star (\mathbf {X})\) whose relation to the target variable Y is invariant across all environments. Consequently, the conditional probability \(p(Y\mid h^\star (\mathbf {X}))\) has predictive power in each environment. From a causal perspective, changes between domains or environments can be described as interventions; and causal relationships – unlike purely statistical ones – remain invariant across environments unless explicitly changed under intervention. This is due to the fundamental principle of “Independent Causal Mechanisms" which will be discussed in Sect. 3. From a causal standpoint, finding robust models is therefore a causal discovery task [4, 24]. Taking a causal perspective on DG, we aim at identifying features which (i) have an invariant relationship to the target variable Y and (ii) are maximally informative about Y. This problem has already been addressed with some simplifying assumptions and a discrete combinatorial search by [22, 35], but we make weaker assumptions and enable gradient-based optimization. The latter is attractive because it readily scales to high dimensions and offers the possibility to learn very informative features, instead of merely selecting among predefined ones. Approaches to invariant relations similar to ours were taken by [10], who restrict themselves to linear relations, and [2, 19], who consider a weaker notion of invariance. Problems (i) and (ii) are quite intricate because the search space has combinatorial complexity and testing for conditional independence in high dimensions is notoriously difficult.
Our main contributions to this problem are the following: First, by connecting invariant (causal) relations with normalizing flows, we propose a differentiable two-part objective of the form \(I(Y; h(\mathbf {X})) + \lambda _I \mathcal {L}_{I}\), where I is the mutual information and \(\mathcal {L}_{I}\) enforces the invariance of the relation between \(h(\mathbf {X})\) and Y across all environments. This objective operationalizes the ICM principle with a trade-off between feature informativeness and invariance controlled by parameter \(\lambda _I\). Our formulation generalizes existing work because our objective is not restricted to linear models. Second, we take advantage of the continuous objective in three important ways: (1) We can learn invariant new features, whereas graph-based methods as in e.g. [22] can only select features from a pre-defined set. (2) Our approach does not suffer from the scalability problems of combinatorial optimization methods as proposed in e.g. [30] and [35]. (3) Our optimization via normalizing flows, i.e. in the form of a density estimation task, facilitates accurate maximization of the mutual information. Third, we show how our objective simplifies in important special cases and under which conditions its optimal solution identifies the true causal parents of the target variable Y. We empirically demonstrate that the new method achieves good results on two datasets proposed in the literature.

2 Related Work

Different types of invariances have been considered in the field of DG. One type is defined on the feature level, i.e. features \(h(\mathbf {X})\) are invariant across environments if they follow the same distribution in all environments (e.g. [5, 8, 27]). However, this form of invariance is problematic since the distribution of the target variable might change between environments, which induces a corresponding change in the distribution of \(h(\mathbf {X})\). A more plausible and theoretically justified assumption is the invariance of relations [22, 30, 35]. The relation between a target Y and features \(h(\mathbf {X})\) is invariant across environments if the conditional distribution \(p(Y\mid h(\mathbf {X}))\) remains unchanged in all environments. Existing approaches exhaustively model conditional distributions for all possible feature selections and check for the invariance property [22, 30, 35], which scales poorly for large feature spaces. We derive a theoretical result connecting normalizing flows and invariant relations, which enables gradient-based learning of an invariant solution. In order to exploit our formulation, we also use the Hilbert-Schmidt Independence Criterion, which has been used for robust learning by [11] in the single-environment setting. [2, 19, 38] also propose gradient-based learning frameworks, which exploit a weaker notion of invariance: they aim to match the conditional expectations across environments, whereas we address the harder problem of matching the entire conditional distributions. The connection between DG, invariance and causality has been pointed out for instance by [24, 35, 39]. From a causal perspective, DG is a causal discovery task [24]. For studies on causal discovery in the purely observational setting see e.g. [6, 29, 36], but they do not take advantage of variations across environments. The case of different environments has been studied by [4, 9, 15, 16, 22, 26, 30, 37].
Most of these approaches rely on combinatorial optimization or are restricted to linear mechanisms, whereas our continuous objective efficiently optimizes very general non-linear models. The distinctive property of causal relations to remain invariant across environments in the absence of direct interventions has been known since at least the 1930s [7, 13]. However, its crucial role as a tool for causal discovery was – to the best of our knowledge – only recently recognized by [30]. Their estimator – Invariant Causal Prediction (ICP) – returns the intersection of all subsets of variables that have an invariant relation w.r.t. Y. The output is shown to be the set of the direct causes of Y under suitable conditions. Again, this approach requires linear models and exhaustive search over all possible variable sets \(\mathbf {X}_S\). Extensions to time series and non-linear additive noise models were studied in [14, 33]. Our treatment of invariance is inspired by these papers and also discusses identifiability results, i.e. conditions under which the identified variables are indeed the direct causes, with two key differences: First, we propose a formulation that allows for gradient-based learning and does not need strong assumptions on the underlying causal model. Second, while ICP tends to exclude features from the parent set when in doubt, our algorithm prefers to err towards best predictive performance in this situation.

3 Preliminaries

In the following we introduce the basics of this work as well as the connection between DG and causality. Basics on causality are presented in Appendix A. We first define our notation as follows: We denote the set of all variables describing the system under study as \(\widetilde{\mathbf {X}} = \{X_1, \dots , X_D\}\). One of these variables will be singled out as our prediction target, whereas the remaining ones are observed and may serve as predictors. To clarify notation, we call the target variable \(Y \equiv X_i\) for some \(i \in \{1, \dots , D \}\), and the remaining observations are \(\mathbf {X}= \widetilde{\mathbf {X}} \setminus \{Y\}\). Realizations of a random variable (RV) are denoted with lower case letters, e.g. \(x_i\). We assume that observations can be obtained in different environments \(e \in \mathcal {E}\). Symbols with superscript, e.g. \(Y^e\), refer to a specific environment, whereas symbols without refer to data pooled over all environments. We distinguish known environments \(e \in \mathcal {E}_{\text {seen}}\), where training data are available, from unknown ones \(e\in \mathcal {E}_{\text {unseen}}\), to which we wish our models to generalize. The set of all environments is \(\mathcal {E}= \mathcal {E}_{\text {seen}}\cup \mathcal {E}_{\text {unseen}}\). We assume that all RVs have a density \(p_A\) with probability distribution \(P_A\) (for some variable or set A). We consider the environment to be a RV E and therefore a system variable similar to [26]. This gives an additional view on causal discovery and the DG problem. Independence and dependence of two variables A and B are written as \(A \perp B\) and \(A\not \perp B\) respectively. Two RVs A, B are conditionally independent given C if \(P(A, B\mid C) = P(A\mid C) P(B \mid C)\). This is denoted with \(A \perp B \mid C\). It means A does not contain any information about B if C is known (see e.g. [31]). Similarly, one can define independence and conditional independence for sets of RVs.

3.1 Invariance and the Principle of ICM

DG is in general unsolvable because distributions between seen and unseen environments could differ arbitrarily. In order to transfer knowledge from \(\mathcal {E}_{\text {seen}}\) to \(\mathcal {E}_{\text {unseen}}\), we have to make assumptions on how seen and unseen environments relate. These assumptions have a close link to causality. We assume certain relations between variables remain invariant across all environments. A subset \(\mathbf {X}_S \subset \mathbf {X}\) of variables elicits an invariant relation or satisfies the invariance property w.r.t. Y over a subset \(W \subset \mathcal {E}\) of environments if

$$\begin{aligned} \forall e, e' \in W:\quad P(Y^e \mid \mathbf {X}_S^e = u) = P(Y^{e'} \mid \mathbf {X}_S^{e'}=u) \end{aligned}$$
(1)

for all u where both conditional distributions are well-defined. Equivalently, we can state the invariance property as \(Y \perp E \mid \mathbf {X}_S\), i.e. \(I(Y;E \mid \mathbf {X}_S) = 0\), for E restricted to W. The invariance property for computed features \(h(\mathbf {X})\) is defined analogously by the relation \(Y \perp E \mid h(\mathbf {X})\). Although we can only test for Eq. 1 in \(\mathcal {E}_{\text {seen}}\), taking a causal perspective allows us to derive plausible conditions for an invariance to remain valid in all environments \(\mathcal {E}\). In brief, we assume that environments correspond to interventions in the system and invariance arises from the principle of independent causal mechanisms [31]. We specify these conditions later in Assumptions 1 and 2. At first, consider the joint density \(p_{\widetilde{\mathbf {X}}}(\widetilde{\mathbf {X}})\). The chain rule offers a combinatorial number of ways to decompose this distribution into a product of conditionals. Among those, the causal factorization

$$\begin{aligned} p_{\widetilde{\mathbf {X}}} (x_1, \dots , x_D) = {\textstyle \prod }_{i=1}^D p_i ( x_i \mid \mathbf {x}_{pa(i)}) \end{aligned}$$
(2)

is singled out by conditioning each \(X_i\) onto its direct causes or causal parents \(\mathbf {X}_{pa(i)}\), where pa(i) denotes the appropriate index set. The special properties of this factorization are discussed in [31]. The conditionals \(p_i\) of the causal factorization are called causal mechanisms. An intervention onto the system is defined by replacing one or several factors in the decomposition with different (conditional) densities \(\overline{p}\). Here, we distinguish soft-interventions where \(\overline{p}_j(x_j \mid \mathbf {x}_{pa(j)}) \ne p_j(x_j \mid \mathbf {x}_{pa(j)})\) and hard-interventions where \(\overline{p}_j(x_j \mid \mathbf {x}_{pa(j)}) = \overline{p}_j(x_j)\) is a density which does not depend on \(x_{pa(j)}\) (e.g. an atomic intervention where \(x_j\) is set to a specific value \(\overline{x}\)). The resulting joint distribution for a single intervention is

$$\begin{aligned} \overline{p}_{\widetilde{\mathbf {X}}} (x_1, \dots , x_D) = \overline{p}_j (x_j \mid \mathbf {x}_{pa(j)}) {\textstyle \prod }_{i=1, i\ne j}^D p_i ( x_i \mid \mathbf {x}_{pa(i)}) \end{aligned}$$
(3)

and extends to multiple simultaneous interventions in the obvious way. The principle of independent causal mechanisms (ICM) states that every mechanism acts independently of the others [31]. Consequently, an intervention replacing \(p_j\) with \(\overline{p}_j\) has no effect on the other factors \(p_{i\ne j}\), as indicated by Eq. 3. This is a crucial property of the causal decomposition – alternative factorizations do not exhibit this behavior. Instead, a coordinated modification of several factors is generally required to model the effect of an intervention in a non-causal decomposition. We utilize this principle as a tool to train robust models. To do so, we make two additional assumptions, similar to [30] and [14]:
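The ICM principle can be illustrated with a small simulation (a numpy sketch with hypothetical coefficients, not taken from the paper): a soft intervention on the distribution of a parent of Y leaves the mechanism \(p(y \mid \mathbf {x}_{pa(Y)})\), and hence the residual of the causal regression, untouched, while a non-causal regression of Y on one of its children shifts between environments.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_env(n, x1_std):
    """Chain X1 -> Y -> X2; environments differ only by a soft intervention
    on the factor p(X1), so the mechanism p(Y | X1) stays invariant."""
    x1 = rng.normal(0.0, x1_std, n)          # intervened factor
    y = 2.0 * x1 + rng.normal(0.0, 1.0, n)   # invariant causal mechanism
    x2 = -y + rng.normal(0.0, 1.0, n)        # child of Y
    return x1, y, x2

n = 100_000
envs = [sample_env(n, 1.0), sample_env(n, 3.0)]

# Residual of Y given its parent: identically distributed in both environments.
res = [y - 2.0 * x1 for x1, y, x2 in envs]
print([round(r.std(), 3) for r in res])

# Regressing Y on the child X2 instead: the slope shifts with the environment.
coef = [np.polyfit(x2, y, 1)[0] for x1, y, x2 in envs]
print([round(c, 3) for c in coef])
```

In this construction the causal residual keeps unit variance in both environments, while the anti-causal slope moves from \(-5/6\) to \(-37/38\) because the intervention changes the variance of \(X_1\) and thereby of Y.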

Assumption

(1) Any differences in the joint distributions \(p^e_{\widetilde{\mathbf {X}}}\) from one environment to the other are fully explainable as interventions: replacing factors \(p_i^e ( x_i \mid \mathbf {x}_{pa(i)})\) in environment e with factors \(p_i^{e'} ( x_i \mid \mathbf {x}_{pa(i)})\) in environment \(e'\) (for some subset of the variables) is the only admissible change.

(2) The mechanism \(p( y \mid \mathbf {x}_{pa(Y)})\) for the target variable Y is invariant under changes of environment, i.e. we require conditional independence \(Y \perp E \mid \mathbf {X}_{pa(Y)}\).

Assumption 2 implies that Y must not directly depend on E. Consequences in case of omitted variables are discussed in Appendix B. If we knew the causal decomposition, we could use these assumptions directly to train a robust model for Y – we would simply regress Y on its parents \(\mathbf {X}_{pa(Y)}\). However, we only require that a causal decomposition with these properties exists, but do not assume that it is known. Instead, our method uses the assumptions indirectly – by simultaneously considering data from different environments – to identify a stable regressor for Y. We call a regressor stable if it solely relies on predictors whose relationship to Y remains invariant across environments, i.e. is not influenced by any intervention. By Assumption 2, such a regressor always exists. However, predictor variables beyond \(\mathbf {X}_{pa(Y)}\), e.g. children of Y or parents of children, may be included into our model as long as their relationship to Y remains invariant across all environments. We discuss this and further illustrate Assumption 2 in Appendix B. In general, prediction accuracy will be maximized when all suitable predictor variables are included into the model. Accordingly, our algorithm will asymptotically identify the full set of stable predictors for Y. In addition, we will prove under which conditions this set contains exactly the parents of Y.

3.2 Domain Generalization

To exploit the principle of ICM for DG, we formulate the DG problem as follows

$$\begin{aligned} h^{\star }&:= \mathop {\mathrm {arg\,max}}\limits _{h\in \mathcal {H}} \Big \{ \min _{e \in \mathcal {E}} I(Y^e; h(\mathbf {X}^e)) \Big \}&\quad \text {s.t.}\quad Y \perp E \mid h(\mathbf {X}) \end{aligned}$$
(4)

Here, \(h\in \mathcal {H}\) denotes a learnable feature extraction function \(h :\mathbb {R}^D \rightarrow \mathbb {R}^M\), where M is a hyperparameter. Eq. 4 defines a maximin objective: the features \(h(\mathbf {X})\) should be as informative as possible about the response Y even in the most difficult environment, while conforming to the ICM constraint that the relationship between features and response must remain invariant across all environments. In principle, our approach can also optimize related objectives, such as the average mutual information over environments. However, very good performance in a majority of the environments could then mask failure in a single (outlier) environment; we opted for the maximin formulation to avoid this. On the other hand, the maximin formulation has its own limitations: for instance, when the training signal is very noisy in one environment, the model might discard valuable information from the other environments. As it stands, Eq. 4 is hard to optimize, because traditional independence tests for the constraint \( Y \perp E \mid h(\mathbf {X})\) cannot cope with conditioning variables selected from a potentially infinitely large space \(\mathcal {H}\). A re-formulation of the DG problem that circumvents these issues is our main theoretical contribution.

3.3 Normalizing Flows

Normalizing flows form a class of probabilistic models that has recently received considerable attention, see e.g. [28]. They model complex distributions by means of invertible functions T (chosen from some model space \(\mathcal {T}\)), which map the densities of interest to latent normal distributions. Normalizing flows are typically built with specialized neural networks that are invertible by construction and have tractable Jacobian determinants. We represent the conditional distribution \(P(Y\mid h(\mathbf {X}))\) by a conditional normalizing flow (see e.g. [1]). The literature typically deals with Structural Causal Models restricted to additive noise. With normalizing flows, we are able to lift this restriction to the much broader setting of arbitrary distributions (for details see Appendix C). The corresponding loss is the negative log-likelihood (NLL) of Y under T, given by

$$\begin{aligned} \mathcal {L}_{\mathrm {NLL}} (T, h) := \mathbb {E}_{h(\mathbf {X}),Y} \big [\Vert T(Y; h(\mathbf {X})) \Vert ^2/2 - \log |\det \nabla _y T(Y;h(\mathbf {X}))|\big ]+C \end{aligned}$$
(5)

where \(\det \nabla _y T\) is the Jacobian determinant and \(C= \dim (Y) \log (\sqrt{2\pi })\) is a constant that can be dropped [28]. Equation 5 can be derived from the change of variables formula and the assumption that T maps to a standard normal distribution [28]. If we consider the NLL on a particular environment \(e \in \mathcal {E}\), we denote this with \(\mathcal {L}^e_\mathrm {NLL}\). Lemma 1 shows that normalizing flows optimized by NLL are indeed applicable to our problem:
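To make Eq. 5 concrete, consider the simplest conditional flow, an affine map \(T(y; h) = (y - \mu (h))/\sigma \) with \(|\det \nabla _y T| = 1/\sigma \). The following numpy sketch (toy data and functions are our own illustration, not the paper's flow architecture) evaluates the NLL and shows that the true conditional mechanism scores better than a flow that ignores the features:

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy mechanism: Y | h(X) ~ N(2 h, 0.5^2).
h = rng.normal(0.0, 1.0, 50_000)
y = 2.0 * h + 0.5 * rng.normal(0.0, 1.0, 50_000)

def nll(y, h, mu_fn, sigma):
    """Eq. (5) for T(y; h) = (y - mu(h)) / sigma:
    E[ T(y; h)^2 / 2 - log |dT/dy| ] with dT/dy = 1/sigma (constant C dropped)."""
    r = (y - mu_fn(h)) / sigma
    return np.mean(r**2 / 2 + np.log(sigma))

good = nll(y, h, lambda h: 2.0 * h, 0.5)   # true conditional mechanism
bad = nll(y, h, lambda h: 0.0 * h, 0.5)    # ignores the features
print(good, bad)
```

At the optimum the latent variable \(R = T(Y; h)\) becomes standard normal and independent of h, in line with Lemma 1(b) below.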

Lemma 1

(proof in Appendix C) Let \( h^{\star }, T^{\star } := \arg \min _{h \in \mathcal {H}, T \in \mathcal {T}} \mathcal {L}_{\mathrm {NLL}} (T, h)\) be the solution of the NLL minimization problem on a sufficiently rich function space \(\mathcal {T}\). Then the following properties hold for any set \(\mathcal {H}\) of feature extractors:

  (a) \(h^{\star }\) also maximizes the mutual information, i.e. \(h^{\star } = \arg \max _{g \in \mathcal {H}} I(g(\mathbf {X}); Y)\);

  (b) \(h^{\star }\) and the latent variables \(R=T^{\star }(Y; h^{\star }(\mathbf {X}))\) are independent: \(h^{\star }(\mathbf {X}) \perp R\).

Statement (a) guarantees that \(h^\star \) extracts as much information about Y as possible. Hence, the objective (4) becomes equivalent to optimizing (5) when we restrict the space \(\mathcal {H}\) of admissible feature extractors to the subspace \(\mathcal {H}_{\perp }\) satisfying the invariance constraint \(Y \perp E \mid h(\mathbf {X})\): \(\mathop {\mathrm {arg\,min}}\nolimits _{h \in \mathcal {H}_\perp }\max _{e \in \mathcal {E}}\min _{ T\in \mathcal {T}} \mathcal {L}_\mathrm {NLL}^e(T;h) = \mathop {\mathrm {arg\,max}}\nolimits _{h \in \mathcal {H}_\perp } \min _{e \in \mathcal {E}} I(Y^e; h(\mathbf {X}^e))\) (Appendix C). Statement (b) ensures that the flow indeed implements a valid structural equation, which requires that R can be sampled independently of the features \(h(\mathbf {X})\).

4 Method

In the following we propose a way of indirectly expressing the constraint in Eq. 4 via normalizing flows. Thereafter, we combine this result with Lemma 1 to obtain a differentiable objective for solving the DG problem. We also present important simplifications for least squares regression and softmax classification and discuss relations of our approach with causal discovery.

4.1 Learning the Invariance Property

The following theorem establishes a connection between invariant relations, prediction residuals and normalizing flows. The key consequence is that a suitably trained normalizing flow translates the statistical independence of the latent variable R from the features and environment \((h(\mathbf {X}), E)\) into the desired invariance of the mechanism \(P(Y \mid h(\mathbf {X}))\) under changes of E. We will exploit this for an elegant reformulation of the DG problem (4) into the objective (7) below.

Theorem 1

Let h be a differentiable function and \(Y, \mathbf {X}, E\) be RVs. Furthermore, let T be a continuous, differentiable function that is a diffeomorphism in Y, and set \(R=T(Y; h(\mathbf {X}))\). Suppose that \(R \perp (h(\mathbf {X}), E)\). Then it holds that \(Y \perp E \mid h(\mathbf {X})\).

Proof

The decomposition rule for the assumption (i) \(R \perp (h(\mathbf {X}), E)\) implies (ii) \(R \perp h (\mathbf {X})\). To simplify notation, we define \(Z:=h(\mathbf {X})\). Because T is invertible in Y and due to the change of variables (c.o.v.) formula, we obtain

$$\begin{aligned} p_{Y\mid Z,E}(y \mid z,e) \overset{(c.o.v.)}{=}&p_{R \mid Z, E}( T(y, z) \mid z, e ) \left| \det \frac{\partial T}{\partial y} (y, z) \right| \\ \overset{(i)}{=} \;\; p_R(r) \left| \det \frac{\partial T}{\partial y} (y, z) \right| \overset{(ii)}{=}\;\;&p_{R \mid Z} (r \mid z)\left| \det \frac{\partial T}{\partial y} (y,z )\right| \overset{(c.o.v.)}{=} p_{Y \mid Z}(y\mid z). \end{aligned}$$

This implies \(Y \perp E \mid Z\). \(\square \)

The theorem states in particular that if there exists a suitable diffeomorphism T such that \(R \perp (h(\mathbf {X}), E)\), then \(h(\mathbf {X})\) satisfies the invariance property w.r.t. Y. Note that if Assumption 2 is violated, the condition \(R \perp (h(\mathbf {X}), E)\) is in general unachievable and the theorem is not applicable (see Appendix B). We use Theorem 1 in order to learn features h that meet this requirement. In the following, we denote a conditional normalizing flow parameterized by \(\theta \) with \(T_\theta \); likewise, \(h_\phi \) denotes a feature extractor implemented as a neural network parameterized by \(\phi \). We relax the condition \(R\perp (h_\phi (\mathbf {X}), E)\) by means of the Hilbert-Schmidt Independence Criterion (HSIC), a kernel-based independence measure (see Appendix D for the definition and [12] for details). This loss, denoted as \(\mathcal {L}_I\), penalizes dependence between the distributions of R and \((h_\phi (\mathbf {X}),E)\). The HSIC guarantees that

$$\begin{aligned} \mathcal {L}_I\big (P_{R}, P_{h_\phi (\mathbf {X}), E}\big ) = 0\quad \Longleftrightarrow \quad R \perp (h_\phi (\mathbf {X}) , E) \end{aligned}$$
(6)

where \(R = T_\theta (Y ; h_\phi (\mathbf {X}))\) and \(P_{R}, P_{h_\phi (\mathbf {X}), E}\) are the distributions implied by the parameter choices \(\phi \) and \(\theta \). Due to Theorem 1, minimization of \(\mathcal {L}_I (P_R, P_{h_\phi (\mathbf {X}), E})\) w.r.t. \(\phi \) and \(\theta \) will thus approximate the desired invariance property \(Y \perp E \mid h_\phi (\mathbf {X})\), with exact validity upon perfect convergence. When \(R \perp (h_\phi (\mathbf {X}) , E)\) is fulfilled, the decomposition rule implies \(R \perp E\) as well. However, if the differences between environments are small, empirical convergence is accelerated by adding a Wasserstein loss which enforces the latter (see Appendix D and Sect. 5.2).
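For intuition, a standard biased HSIC estimator with Gaussian kernels can be written in a few lines of numpy. This is a generic stand-in; the exact criterion and kernel choices used in the paper are those of Appendix D and [12].

```python
import numpy as np

def hsic(a, b, sigma=1.0):
    """Biased empirical HSIC: trace(K H L H) / (n-1)^2 with Gaussian kernels
    and the centering matrix H = I - (1/n) 11^T."""
    a, b = a.reshape(len(a), -1), b.reshape(len(b), -1)
    n = len(a)
    gram = lambda m: np.exp(-np.sum((m[:, None] - m[None, :]) ** 2, -1)
                            / (2 * sigma**2))
    H = np.eye(n) - 1.0 / n
    return np.trace(gram(a) @ H @ gram(b) @ H) / (n - 1) ** 2

rng = np.random.default_rng(2)
r = rng.normal(0.0, 1.0, 500)
z = rng.normal(0.0, 1.0, 500)
print(hsic(r, z))            # independent inputs: small value
print(hsic(r, r + 0.1 * z))  # dependent inputs: clearly larger
```

The estimator vanishes in expectation (up to bias of order 1/n) exactly when the two arguments are independent, which is the property exploited in Eq. 6.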

4.2 Exploiting Invariances for Prediction

Equation 4 can be re-formulated as a differentiable loss using a Lagrange multiplier \(\lambda _I\) on the HSIC loss. \(\lambda _I\) acts as a hyperparameter to adjust the trade-off between the invariance property of \(h_\phi (\mathbf {X})\) w.r.t. Y and the mutual information between \(h_\phi (\mathbf {X})\) and Y. See Appendix F for algorithm details. In the following, we consider normalizing flows in order to optimize Eq. 4. Using Lemma 1(a), we maximize \(\min _{e \in \mathcal {E}} I(Y^e; h_\phi (\mathbf {X}^e))\) by minimizing \( \max _{e \in \mathcal {E}} \{ \mathcal {L}_{\mathrm {NLL}}(T_\theta ; h_\phi ) \}\) w.r.t. \(\phi , \theta \). To achieve the described trade-off between goodness-of-fit and invariance, we therefore optimize

$$\begin{aligned}&\arg \min _{\theta , \phi } \Big ( \max _{e \in \mathcal {E}} \Big \{ \mathcal {L}^e_{\mathrm {NLL}}(T_\theta , h_\phi ) \Big \} + \lambda _I \mathcal {L}_I (P_R, P_{h_\phi (\mathbf {X}), E}) \Big ) \end{aligned}$$
(7)

where \(R^e = T_\theta (Y^e, h_\phi (\mathbf {X}^e))\) and \(\lambda _I >0\). The first term maximizes the mutual information between \(h_\phi (\mathbf {X})\) and Y in the environment where the features are least informative about Y, and the second term aims to ensure an invariant relation. In the special case that the data is governed by additive noise, Eq. 7 simplifies: Let \(f_\theta \) be a regression function; solving for the noise term gives the residual \(Y - f_\theta (\mathbf {X})\), which corresponds to a diffeomorphism in Y, namely \(T_\theta (Y; \mathbf {X}) = Y- f_\theta (\mathbf {X})\). Under certain assumptions (see Appendix E) we obtain an approximation of Eq. 7 via

$$\begin{aligned} \arg \min _{\theta } \Big ( \max _{e \in \mathcal {E}_{\text {seen}}} \Big \{&\mathbb {E}\big [(Y^e-f_\theta (\mathbf {X}^e))^2\big ] \Big \} + \lambda _I \mathcal {L}_I (P_R, P_{f_\theta (\mathbf {X}), E}) \Big ) \end{aligned}$$
(8)

where \(R^e = Y^e- f_\theta ( \mathbf {X}^e)\) and \(\lambda _I > 0\). Here, \(\arg \max _\theta I(f_\theta (\mathbf {X}^e), Y^e)\) corresponds to the argmin of the L2 loss in the respective environment. Alternatively, we can view the problem as finding features \(h_\phi :\mathbb {R}^D \rightarrow \mathbb {R}^m\) such that \(I(h_\phi (\mathbf {X}), Y)\) is maximized, under the assumption that there exists a model \(f_\theta (h_\phi (\mathbf {X})) + R = Y\) where R is Gaussian and independent of \(h_\phi (\mathbf {X})\). In this case we obtain the learning objective

$$\begin{aligned} \arg \min _{\theta , \phi } \Big ( \max _{e \in \mathcal {E}_{\text {seen}}} \Big \{&\mathbb {E}\big [(Y^e-f_\theta (h_\phi (\mathbf {X}^e)))^2\big ] \Big \} + \lambda _I \mathcal {L}_I (P_R, P_{h_\phi (\mathbf {X}), E}) \Big ) \end{aligned}$$
(9)
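The additive-noise objective of Eq. 8 can be evaluated with a short numpy sketch on two simulated environments (the SCM, its coefficients, and the inlined HSIC estimator are our own illustrative choices, not the paper's implementation). A predictor that uses only the parent of Y attains a lower objective than one based on a child whose mechanism varies across environments:

```python
import numpy as np

rng = np.random.default_rng(3)

def hsic(a, b, sigma=1.0):
    # Biased empirical HSIC with Gaussian kernels (stand-in for L_I).
    a, b = a.reshape(len(a), -1), b.reshape(len(b), -1)
    n = len(a)
    gram = lambda m: np.exp(-np.sum((m[:, None] - m[None, :]) ** 2, -1)
                            / (2 * sigma**2))
    H = np.eye(n) - 1.0 / n
    return np.trace(gram(a) @ H @ gram(b) @ H) / (n - 1) ** 2

def env(b, n=400):
    x1 = rng.normal(0.0, 1.0, n)
    y = 2.0 * x1 + rng.normal(0.0, 1.0, n)   # invariant mechanism p(Y | X1)
    x2 = b * y + rng.normal(0.0, 0.1, n)     # child of Y; varies with e
    return x1, x2, y

data = [env(1.0), env(0.5)]                   # two training environments

def objective(predict, feature, lam=10.0):
    """Eq. (8): max-environment squared error + lambda * HSIC(R, (h(X), E))."""
    mses, res, feats = [], [], []
    for e, (x1, x2, y) in enumerate(data):
        r = y - predict(x1, x2)               # residual R^e
        mses.append(np.mean(r**2))
        res.append(r)
        feats.append(np.stack([feature(x1, x2), np.full(len(r), e)], axis=1))
    return max(mses) + lam * hsic(np.concatenate(res), np.concatenate(feats))

stable = objective(lambda x1, x2: 2.0 * x1, lambda x1, x2: x1)  # parent only
unstable = objective(lambda x1, x2: x2, lambda x1, x2: x2)      # child-based
print(stable, unstable)
```

The child-based predictor fits one environment very well, but its residual distribution depends on the environment, so both the worst-case error and the HSIC penalty grow.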

For the classification case, we consider the expected cross-entropy loss

$$\begin{aligned} - \mathbb {E}_{\mathbf {X},Y} \Big [f(\mathbf {X})_Y - \log \Big ( \sum _{c} \exp \big ( f(\mathbf {X})_c\big )\Big )\Big ] \end{aligned}$$
(10)

where \(f:\mathcal {X} \rightarrow \mathbb {R}^m\) returns the logits. Minimizing the expected cross-entropy loss amounts to maximizing the mutual information between \(f(\mathbf {X})\) and Y [3, 34, Eq. 3]. We set \(T(Y; f(\mathbf {X})) = Y\cdot \mathrm {softmax}(f(\mathbf {X}))\) with component-wise multiplication. Then T is invertible in Y conditioned on the softmax output and therefore Theorem 1 is applicable. Now we can apply the same invariance loss as above in order to obtain a solution to Eq. 4.
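The map \(T(Y; f(\mathbf {X})) = Y\cdot \mathrm {softmax}(f(\mathbf {X}))\) is easy to invert in Y once the softmax output is known, which is the property Theorem 1 needs. A minimal numpy check (the logits are arbitrary illustrative values):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())  # subtract max for numerical stability
    return e / e.sum()

s = softmax(np.array([1.0, -0.5, 2.0]))  # softmax output, all entries > 0

y = np.array([0.0, 0.0, 1.0])            # one-hot label for class 2
t = y * s                                 # T(Y; f(X)) = Y * softmax(f(X))

# Conditioned on s, T is invertible: divide component-wise to recover Y.
y_rec = t / s
print(y_rec)
```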

4.3 Relation to Causal Discovery

Under certain conditions, solving Eq. 4 leads to features which correspond to the direct causes of Y (identifiability). In this case we obtain the causal mechanism by computing the conditional distribution of Y given the direct causes. Hence Eq. 4 can be seen as an approximation of the causal mechanism when the identifiability conditions are met. The following proposition states the conditions under which the direct causes of Y can be found by exploiting Theorem 1.

Proposition 1

We assume that the underlying causal graph G is faithful with respect to \(P_{\widetilde{\mathbf {X}}, E}\). We further assume that every child of Y in G is also a child of E in G. A variable selection \(h(\mathbf {X})= \mathbf {X}_S\) corresponds to the direct causes if the following conditions are met: (i) \(T(Y;h(\mathbf {X})) \perp (E, h(\mathbf {X}))\) is satisfied for a diffeomorphism \(T(\cdot ; h(\mathbf {X}))\), (ii) \(h(\mathbf {X})\) is maximally informative about Y and (iii) \(h(\mathbf {X})\) contains only variables from the Markov blanket of Y.

The Markov blanket of Y is the smallest set of vertices that renders Y independent of all remaining variables, and thus suffices to predict Y (see Appendix A). We give a proof of Proposition 1 as well as a discussion in Appendix G. To facilitate explainability and explicit causal discovery, we employ the same gating function and complexity loss as in [17]. The gating function \(h_\phi \) is a 0-1 mask that marks the selected variables, and the complexity loss \(\mathcal {L}(h_\phi )\) is a soft counter of the selected variables. Intuitively speaking, if we search for a variable selection that conforms to the conditions in Proposition 1, the complexity loss will exclude all variables that are not task-relevant. Therefore, if \(\mathcal {H}\) is the set of gating functions, then \(h^\star \) in Eq. 4 corresponds to the direct causes of Y under the conditions listed in Proposition 1. The complexity loss as well as the gating function can be optimized by gradient descent.
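A simplified, deterministic stand-in for such a gating architecture (the actual stochastic gates and complexity loss of [17] differ in detail) can be sketched as a steep sigmoid over per-variable scores, whose sum acts as a soft counter of selected variables:

```python
import numpy as np

def gate(scores, temperature=0.1):
    """Smooth 0-1 mask; a low temperature pushes entries toward {0, 1}."""
    return 1.0 / (1.0 + np.exp(-scores / temperature))

scores = np.array([4.0, -3.0, 5.0, -6.0])  # hypothetical learned scores
m = gate(scores)                            # approx. binary variable mask

x = np.array([1.2, -0.7, 0.3, 2.1])        # one observation of X
h = m * x                                   # gated features h_phi(x)

complexity = m.sum()                        # soft count of selected variables
print(np.round(m, 3), complexity)
```

Adding the soft count to the objective penalizes every additional selected variable, so only variables that improve the invariant fit survive; both the scores and the downstream model remain differentiable.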

5 Experiments

The main focus of this work is on the theoretical and methodological improvements of causality-based domain generalization using information-theoretical concepts. A complete and rigorous quantitative evaluation is beyond the scope of this work. In the following we present proof-of-concept experiments.

5.1 Synthetic Causal Graphs

Fig. 1.

(a) Detection accuracy of the direct causes for baselines and our gating architectures, broken down for different target variables (left) and mechanisms (right: Linear, Tanhshrink, Softplus, ReLU, Multipl. Noise). (b) Logarithmic plot of L2 errors, normalized by CERM test error. For each method (ours in bold), from left to right: training error, test error on seen environments, domain generalization error on unseen environments.

To evaluate our methods for the regression case, we follow the experimental design of [14]. It rests on the causal graph in Fig. 2. Each variable \(X_1,...,X_6\) is chosen as the regression target Y in turn, so that a rich variety of local configurations around Y is tested. The corresponding structural equations are selected among four model types of the form \(f(\mathbf {X}_{pa(i)}, N_i) = \sum _{j\in pa(i)} \texttt {mech}( a_j X_j) + N_i\), where \(\texttt {mech}\) is either the identity (hence we get a linear Structural Causal Model (SCM)), Tanhshrink, Softplus or ReLU, and one multiplicative noise mechanism of the form \(f_{i} ( \mathbf {X}_{pa(i)}, N_i) = (\sum _{j \in pa(i)} a_j X_j ) \cdot ( 1+ (1/4) N_i) +N_i\), resulting in 1365 different settings. For each setting, we define one observational environment (using exactly the selected mechanisms) and three interventional ones, where soft or do-interventions are applied to non-target variables according to Assumptions 1 and 2 (full details in Appendix H). Each inference model is trained on 1024 realizations of three environments, whereas the fourth one is held back for DG testing. The tasks are to identify the parents of the current target variable Y, and to train a transferable regression model based on this parent hypothesis. We measure performance by the accuracy of the detected parent sets and by the L2 regression errors relative to the regression function using the ground-truth parents. We evaluate four models derived from our theory: two normalizing flows as in Eq. 7 with and without gating mechanisms (FlowG, Flow) and two additive noise models, again with and without gating mechanism (ANMG, ANM), using a feed-forward network with the objective in Eq. 9 (ANMG) and Eq. 8 (ANM).
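The mechanism families above can be reproduced with a few lines of numpy (the coefficients and graph fragment are illustrative; the full experimental design is specified in Appendix H):

```python
import numpy as np

rng = np.random.default_rng(4)

# The four additive mechanisms used in the benchmark.
mechs = {
    "linear": lambda z: z,
    "tanhshrink": lambda z: z - np.tanh(z),
    "softplus": lambda z: np.log1p(np.exp(z)),
    "relu": lambda z: np.maximum(z, 0.0),
}

def additive(parents, coeffs, mech, n):
    """f(X_pa, N) = sum_j mech(a_j X_j) + N with standard normal noise N."""
    noise = rng.normal(0.0, 1.0, n)
    return sum(mechs[mech](a * x) for a, x in zip(coeffs, parents)) + noise

def multiplicative(parents, coeffs, n):
    """f(X_pa, N) = (sum_j a_j X_j) * (1 + N/4) + N."""
    noise = rng.normal(0.0, 1.0, n)
    return sum(a * x for a, x in zip(coeffs, parents)) * (1 + noise / 4) + noise

n = 1000
x1 = rng.normal(0.0, 1.0, n)                   # root variable
x2 = additive([x1], [1.5], "softplus", n)      # softplus mechanism
x3 = multiplicative([x1, x2], [1.0, -0.5], n)  # multiplicative-noise mechanism
print(x2.std(), x3.std())
```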

Fig. 2.

Directed graph of our SCM. Target variable Y is chosen among \(X_1,\dots , X_6\) in turn.

For comparison, we train three baselines: ICP (a causal discovery algorithm also exploiting ICM, but restricted to linear regression, [30]), a variant of the PC-Algorithm (PC-Alg, see Appendix H.4) and standard empirical risk minimization (ERM), a feed-forward network minimizing the L2-loss, which ignores the causal structure by regressing Y on all other variables. We normalize our results with a ground-truth model (CERM), which is identical to ERM, but restricted to the true causal parents of the respective Y. The accuracy of parent detection is shown in Fig. 1a. The score indicates the fraction of the experiments where the exact set of all causal parents was found and all non-parents were excluded. We see that the PC algorithm performs unsatisfactorily, whereas ICP exhibits the expected behavior: it works well for variables without parents and for linear SCMs, i.e. exactly within its specification. Among our models, only the gating ones explicitly identify the parents. They clearly outperform the baselines, with a slight edge for ANMG, as long as its assumption of additive noise is fulfilled. Figure 1b and Table 1 report regression errors for seen and unseen environments, with CERM indicating the theoretical lower bound. The PC algorithm is excluded from this experiment due to its poor detection of the direct causes. ICP wins for linear SCMs, but otherwise has the largest errors, since it cannot accurately account for non-linear mechanisms. ERM gives reasonable test errors (while overfitting the training data), but generalizes poorly to unseen environments, as expected. Our models perform quite similarly to CERM. We again find a slight edge for ANMG, except under multiplicative noise, where ANMG's additive noise assumption is violated and Flow is superior. All methods (including CERM) occasionally fail in the domain generalization task, indicating that some DG problems are more difficult than others, e.g. when the differences between seen environments are too small to reliably identify the invariant mechanism, or when the unseen environment requires extrapolation beyond the training data boundaries. Models without gating (Flow, ANM) seem to be slightly more robust in this respect. A detailed analysis of our experiments can be found in Appendix H.
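The parent-detection score described above can be sketched in a few lines; the function name is a hypothetical helper, not part of our released code:

```python
def exact_recovery_accuracy(predicted_sets, true_sets):
    """Fraction of runs where the predicted parent set equals the true set
    exactly, i.e. all causal parents were found AND all non-parents were
    excluded. Partial matches count as failures."""
    hits = sum(set(p) == set(t) for p, t in zip(predicted_sets, true_sets))
    return hits / len(true_sets)
```

Note that this strict all-or-nothing criterion is harder than per-edge accuracy: predicting one spurious parent or missing one true parent both count as a complete miss.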

Table 1. Medians and upper \(95\%\) quantiles for domain generalization L2 errors (i.e. on unseen environments) for different model types and data-generating mechanisms (lower is better).

5.2 Colored MNIST

To demonstrate that our model is able to perform DG in the classification case, we use the same data generating process as in the colored variant of the MNIST dataset established by [2], but create training instances online rather than upfront. The response is reduced to two labels – 0 for all images with digits \(\{0,\dots , 4\}\) and 1 for digits \(\{5, \dots , 9\}\) – with deliberate label noise that limits the achievable shape-based classification accuracy to 75%. To confuse the classifier, digits are additionally colored such that colors are spuriously associated with the true labels at accuracies of 90% and 80%, respectively, in the first two environments, whereas the association is only 10% correct in the third environment. A classifier naively trained on the first two environments will identify color as the best predictor, but will perform terribly when tested on the third environment. In contrast, a robust model will ignore the unstable relation between colors and labels and use the invariant relation, namely the one between digit shapes and labels, for prediction. We supplement the HSIC loss with a Wasserstein term to explicitly enforce \(R \perp E\), i.e. \(\mathcal {L}_{I} = \mathrm {HSIC} + \mathrm {L2}( \mathrm {sort}(R^{e_1}), \mathrm {sort}(R^{e_2}))\) (see Appendix D). This gives a better training signal than the HSIC alone, since the difference in label-color association between environments 1 and 2 (90% vs. 80%) is deliberately chosen very small to make the task hard to learn. Experimental details can be found in Appendix I. Figure 3a shows the results for our model: Naive training (\(\lambda _I=0\), i.e. invariance of residuals is not enforced) gives accuracies corresponding to the association between colors and labels and thus completely fails in test environment 3. In contrast, our model performs close to the best possible rate for invariant classifiers in environments 1 and 2 and still achieves 68.5% in environment 3. This is essentially on par with existing methods. For instance, IRM achieves 71% on the third environment for this particular dataset, although the dataset itself is not particularly suitable for meaningful quantitative comparisons. Figure 3b demonstrates the trade-off between goodness of fit in the training environments 1 and 2 and the robustness of the resulting classifier: the model's ability to perform DG to the unseen environment 3 improves as \(\lambda _I\) increases. If \(\lambda _I\) is too large, it dominates the classification training signal and performance breaks down in all environments. However, the choice of \(\lambda _I\) is not critical, as good results are obtained over a wide range of settings.
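The invariance penalty \(\mathcal {L}_{I}\) above can be sketched as follows, assuming the residuals of the two environments are given as equal-length 1-D arrays. The biased HSIC estimator with Gaussian kernels and the fixed bandwidth are illustrative choices, not necessarily those of Appendix D.

```python
import numpy as np

def _gram_rbf(x, sigma):
    """Gaussian (RBF) Gram matrix for a 1-D sample."""
    d2 = (x[:, None] - x[None, :]) ** 2
    return np.exp(-d2 / (2.0 * sigma ** 2))

def hsic(x, y, sigma=1.0):
    """Biased HSIC estimator trace(K H L H) / (n-1)^2.

    Values near 0 indicate (approximate) independence of x and y."""
    n = len(x)
    h = np.eye(n) - np.ones((n, n)) / n   # centering matrix
    k, l = _gram_rbf(x, sigma), _gram_rbf(y, sigma)
    return np.trace(k @ h @ l @ h) / (n - 1) ** 2

def invariance_penalty(res_e1, res_e2, sigma=1.0):
    """L_I = HSIC(R, E) + L2(sort(R^e1), sort(R^e2)).

    The sorted-L2 term is (proportional to) the squared 1-D Wasserstein-2
    distance between the empirical residual distributions; it requires
    equally many residuals per environment."""
    r = np.concatenate([res_e1, res_e2])
    e = np.concatenate([np.zeros(len(res_e1)), np.ones(len(res_e2))])
    w2 = np.mean((np.sort(res_e1) - np.sort(res_e2)) ** 2)
    return hsic(r, e, sigma) + w2
```

In training, this scalar would be weighted by \(\lambda _I\) and added to the classification loss; a differentiable (e.g. PyTorch) version of the same computation is needed for backpropagation.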

Fig. 3.

(a) Accuracy of a standard classifier and our model (b) Performance of the model in the three environments, depending on the hyperparameter \(\lambda _I\).

6 Discussion

In this paper, we have introduced a new method to find invariant and causal models by exploiting the principle of ICM. Our method relies on gradient descent, in contrast to combinatorial optimization procedures. This circumvents scalability issues and allows us to extract invariant features even when the raw data representation is not in itself meaningful (e.g. we only observe pixel values). In comparison to alternative approaches, our use of normalizing flows places fewer restrictions on the underlying true generative process. We have also shown under which circumstances our method is guaranteed to find the underlying causal model. Moreover, we demonstrated theoretically and empirically that our method is able to learn models that are robust to distribution shift. Future work includes ablation studies to better understand the influence of individual components, e.g. the choice of the maxmin objective over the average mutual information, or of the Wasserstein loss and the HSIC loss. As a next step, we will examine our approach in more complex scenarios where, for instance, the invariance assumption may only hold approximately.