1 Introduction

Causal inference has drawn considerable attention across research areas including statistics [2, 25], economics and finance [3, 7, 15], commercial social network applications [5, 10], and health care [8, 12]. One of the main tasks of causal inference is to estimate the average treatment effect (ATE). For example, a biotech company needs to know to what extent a newly developed vaccine reduces the probability of infection for the whole population. The classical way to obtain the ATE is to conduct randomized controlled trials (RCTs), where the treatment is assigned to the population randomly rather than selectively. The effect of the vaccine (treatment) on infection (outcome) is then measured by the difference between the average infection rate of the vaccinated group (treated group) and that of the unvaccinated group (controlled group). RCTs are regarded as the gold standard for treatment effect estimation, but conducting them is costly and time-consuming [9, 21]. Thus, estimating treatment effects from observational data instead of RCTs has become increasingly appealing.

When estimating the ATE from observational data, we need to handle selection bias. Selection bias arises because the treatment assignment is non-random: it may be influenced by covariates that also directly affect the outcome. In the vaccine example, limited vaccines tend to be distributed to vulnerable individuals who are more susceptible to infection. Such a non-random treatment assignment mechanism naturally results in a covariate shift phenomenon; that is, the covariates of the treated population can differ substantially from those of the controlled population.

Two classical methods have been developed to adjust for the shifted covariates: inverse propensity weighting (IPW) and regression adjustment (see [26] for details). IPW weights the instances by their propensity scores to mimic the principle of RCTs when estimating the ATE. Nevertheless, IPW estimators are sensitive to misspecification of the propensity score. Regression adjustment methods directly estimate the outcome model instead of the propensity score, but they inevitably lead to biased ATE estimates due to overfitting and regularization bias [3]. Researchers have improved these classical methods from both statistical and methodological perspectives.

The orthogonal score function proposed in [3] is a statistical correction that incorporates both the outcome model and the propensity score estimations. Since the score function satisfies the orthogonal condition, the ATE estimator derived from it is consistent as long as one of the two underlying relations is correctly specified; this is also known as the doubly robust property. Recently, balanced representation learning techniques have attracted researchers’ attention. The intuitive idea is to construct a pair of “twins” in the representation space by minimizing the imbalance between the distributions of the treated and controlled groups [23]. However, such methods mainly focus on balance and overlook the discrimination between treated and controlled units. If the distributions of the treated and controlled groups in the representation space are too similar to be distinguished, it is difficult to infer the ATE accurately. This trade-off plays a crucial role in identifying treatment effects [23]. The importance of the undiscriminating problem is also emphasized by [10].

In this paper, with the tool of orthogonal machine learning, we propose a moderately-balanced representation learning (MBRL) framework to estimate treatment effects. MBRL trains in a multi-task framework and early stops on a perturbation error metric to obtain a moderately-balanced representation. The merits of MBRL include i) preserving predictive information for inferring individual outcomes; ii) designing a multi-task learning framework to achieve a moderately-balanced rather than over-balanced representation; iii) fully utilizing the orthogonality information during the training and validation stages to achieve superior treatment effect estimations.

2 Preliminaries

Potential Outcome Framework. Let \(\textbf{Z} \in \mathcal {Z}\) be the s-dimensional covariates, where \(\mathcal {Z}\) is the sample space of covariates. \(D \in \{0, 1\}\) denotes the treatment variable. \(Y(0), Y(1) \in \mathcal {Y}\) represent the potential outcomes for the treatments \(D=0\) and \(D=1\) respectively, with \(\mathcal {Y}\) being the sample space of the outcome. We denote \(w=(\textbf{z}, d, y)\) as the realizations of the random variables \(W=(\textbf{Z}, D, Y)\). If the observed treatment is d, then the factual outcome \(Y^F\) equals Y(d). We suppose the observational dataset contains N individuals and the \(m^{th}\) individual is observed as \((\textbf{z}_m, d_m, y_m)\). The target quantity, the ATE \(\tau \), is defined as \(\tau :=\mathbb {E}\left[ Y(1)-Y(0)\right] \).

Identifying the treatment effects under the potential outcome framework [22] requires some fundamental assumptions: Strong Ignorability, Overlap, Consistency and Stable Unit Treatment Value Assumption (SUTVA). These assumptions guarantee that treatment effects can be inferred if we specify the relation \(\mathbb {E}\left[ Y \mid D, \textbf{Z}\right] \), which is equivalent to estimating \(g_{0}(D,\textbf{Z})\) in the following interactive model when the treatment variable takes a binary value [3]:

$$\begin{aligned} \begin{aligned} Y&=g_{0}(D,\textbf{Z})+\xi ,{} & {} \mathbb {E}\left[ \xi \mid D,\textbf{Z}\right] =0,\\ D&=m_{0}(\textbf{Z}) + \nu ,{} & {} \mathbb {E}\left[ \nu \mid \textbf{Z}\right] =0. \end{aligned} \end{aligned}$$
(1)

Here, \(g_0\) and \(m_0\) are the true nuisance functions, and \(\xi \) and \(\nu \) are the noise terms. \(m_{0}(\textbf{Z})=\mathbb {E}\left[ D \mid \textbf{Z}\right] \) is the propensity score. For \(i \in \{0, 1\}\), the true causal parameter \(\theta _{0}^i\) is defined as \(\theta _{0}^i:=\mathbb {E}\left[ Y(i)\right] =\mathbb {E}\left[ g_0(i, \textbf{Z})\right] \), and the true ATE \(\tau \) is computed by \(\tau =\theta _{0}^1-\theta _{0}^0\). We denote the estimates of \((\theta ^{i}_{0}, g_0, m_0)\) by \((\hat{\theta }^{i}, \hat{g}, \hat{m})\); the estimated ATE is then computed by \(\hat{\tau }=\hat{\theta }^1-\hat{\theta }^0\).
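For instance, a simple (non-orthogonal) regression-adjustment plug-in estimate of the ATE, shown here only for illustration and not as part of the estimators used below, averages the estimated outcome differences over the sample:

$$\begin{aligned} \hat{\tau }_{\text {plug-in}}=\frac{1}{N}\sum _{m=1}^{N}\big (\hat{g}(1, \textbf{z}_m)-\hat{g}(0, \textbf{z}_m)\big ). \end{aligned}$$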

Orthogonal Estimators. We aim to estimate the true causal parameters \(\theta ^1_0\) and \(\theta ^0_0\) given N i.i.d. samples \(\{W_m=(\textbf{Z}_m, D_m, Y_m)\}^{N}_{m=1}\). The standard procedure to acquire the estimated causal parameters \(\hat{\theta }^1\) and \(\hat{\theta }^0\) is: 1) getting the estimated nuisance functions \(\hat{\rho }\), e.g., \(\hat{\rho }=(\hat{g}, \hat{m})\); 2) constructing a score function \(\psi (W, \theta ^i, \rho )\) such that we can derive the estimated causal parameter \(\hat{\theta }^i\) by solving \(\mathbb {E}\left[ \psi (W, \theta ^i, \hat{\rho })\right] =0\), where \(\theta ^i\) is a causal parameter that lies in the causal parameter space. According to [3], the estimator \(\hat{\theta }^i\) solved from \(\mathbb {E}\left[ \psi (W, \theta ^i, \hat{\rho })\right] =0\) is robust to the estimated nuisance functions \(\hat{\rho }\) if the corresponding score function \(\psi (W, \theta ^i, \rho )\) satisfies the orthogonal condition that is stated in Definition 1.

Definition 1 (Orthogonal Condition)

Let \(W=(\textbf{Z}, D, Y)\), \(\rho _0=(h_{0,1},\dots ,h_{0,\gamma })\) be the true nuisance functions and \(\theta _0\) be the true causal parameter with \(\theta \) being a causal parameter that lies in the causal parameter space. A score function \(\psi (W, \theta , \rho )\) is said to satisfy the orthogonal condition with respect to \(\rho =(h_1,...,h_\gamma )\) if

$$\begin{aligned} \begin{aligned} \mathbb {E}\left[ \partial _{h_i}\psi (W, \theta , \rho )\mid _{\rho =\rho _{0}, \theta =\theta _0} \mid \textbf{Z} \right] = 0 \;\;\; \forall 1 \le i \le \gamma . \end{aligned} \end{aligned}$$

Under the interactive model setup (1), the nuisance functions are \((g, m)\), and the true ones are \((g_0, m_0)\). In this case, the orthogonal condition guarantees that the estimator is consistent if either one of the two nuisance functions, but not necessarily both, is accurately estimated. This is well known as the doubly robust property. In this paper, we introduce two orthogonal estimators \(\hat{\theta }_{1}\) [3] and \(\hat{\theta }_{2}\) [14] in Proposition 1, and we can estimate the ATE by plugging the learned nuisance functions into the orthogonal estimators.

Proposition 1 (Orthogonal Estimators)

Let the nuisance functions be \(\rho =(g,m)\) and the causal parameter be \(\theta ^i\) for \(i \in \{0, 1\}\). The score functions \(\psi _{1}(W, \theta ^{i}, \rho )\) and \(\psi _{2}(W, \theta ^{i}, \rho )\) that satisfy the orthogonal condition (Definition 1) are:

$$\begin{aligned}&\psi _{1}(W, \theta ^{i}, \rho ) = \theta ^{i}-g(i,\textbf{Z}) -(Y-g(i,\textbf{Z}))\frac{iD+(1-i)(1-D)}{im(\textbf{Z})+(1-i)(1-m(\textbf{Z}))}; \end{aligned}$$
(2)
$$\begin{aligned}&\psi _{2}(W, \theta ^{i}, \rho ) = \theta ^{i}-g(i,\textbf{Z}) -(Y(i)-g(i,\textbf{Z}))\frac{\left( (D-m(\textbf{Z}))-\mathbb {E}\left[ \nu \mid \textbf{Z} \right] \right) ^2}{\mathbb {E}\left[ \nu ^2 \mid \textbf{Z}\right] }. \end{aligned}$$
(3)

The corresponding orthogonal estimators are:

$$\begin{aligned} \hat{\theta }_1^i \quad \text {solves} \quad \frac{1}{N}\sum \limits _{m=1}^{N}\psi _1(W_m, \theta ^{i}, \hat{\rho })=0; \quad \hat{\theta }_2^i \quad \text {solves} \quad \frac{1}{N}\sum \limits _{m=1}^{N}\psi _2(W_m, \theta ^{i}, \hat{\rho })=0. \end{aligned}$$
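Solving the first estimating equation in closed form yields an AIPW-style sample average. The following NumPy sketch illustrates this computation; the function name and array layout are ours, and \(\hat{g}\), \(\hat{m}\) may come from any fitted outcome and propensity models.

```python
import numpy as np

def orthogonal_ate(y, d, g_hat, m_hat):
    """Doubly robust ATE from the psi_1 score of Proposition 1.

    y, d   : observed outcomes and binary treatments, shape (N,)
    g_hat  : estimated outcomes, g_hat[:, i] approximates g(i, z), shape (N, 2)
    m_hat  : estimated propensity scores m(z), shape (N,)
    """
    theta = {}
    for i in (0, 1):
        ind = d if i == 1 else 1 - d              # indicator {D = i}
        denom = m_hat if i == 1 else 1 - m_hat    # i*m(z) + (1-i)*(1-m(z))
        # Solving (1/N) * sum_m psi_1(W_m, theta, rho_hat) = 0 for theta^i:
        theta[i] = np.mean(g_hat[:, i] + (y - g_hat[:, i]) * ind / denom)
    return theta[1] - theta[0]
```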

3 Method

In this section, we first introduce the orthogonality information in Sect. 3.1. Then we present the network structure, objective function and model selection criterion of the proposed MBRL method based on the orthogonality information in Sect. 3.2.

3.1 Orthogonality Information

Recall that the ATE estimators \(\hat{\theta }^i_{1}\) and \(\hat{\theta }^i_{2}\) are doubly robust because they are orthogonal estimators. Still, they could become non-orthogonal if the model setup (1) relaxed the restrictions on the noise terms \(\xi \) and \(\nu \), since the score functions \(\psi _1\) and \(\psi _2\) might then violate the orthogonal condition. Hence, we state the Noise Conditions, which enforce that the learned nuisance functions are compatible with the orthogonal estimators.

Proposition 2 (Noise Conditions)

Under the interactive model setup (1), the conditions on the noise terms \(\xi \) and \(\nu \), i.e., \(\mathbb {E}\left[ \xi \mid D, \textbf{Z} \right] =0\) and \(\mathbb {E}\left[ \nu \mid \textbf{Z} \right] =0\), are sufficient conditions for \(\psi _1\) and \(\psi _2\) being orthogonal score functions (\(\hat{\theta }_{1}^i\) and \(\hat{\theta }_{2}^i\) being orthogonal estimators).

Given the noise conditions, we can exploit an essential property, the noise orthogonality property.

Property 1 (Noise Orthogonality)

Under the interactive model setup (1) and the noise conditions, we have \(\mathbb {E}[(Y-g_0(D, \textbf{Z}))(D-m_0(\textbf{Z}))]=0\).

The noise conditions are sufficient for the estimators \(\hat{\theta }_{1}^i\) and \(\hat{\theta }_{2}^i\) to be orthogonal, so they play an important role when we approximate the true nuisance functions \((g_0, m_0)\) with estimated ones \((\hat{g}, \hat{m})\). Moreover, under the noise conditions, the noise orthogonality can be utilized for model selection. Decompositions similar to the noise orthogonality also appear in [3, 11].
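As an illustration, the noise orthogonality can be monitored empirically from the residuals of the fitted nuisance functions. A minimal sketch (names are ours):

```python
import numpy as np

def noise_orthogonality_gap(y, d, y_hat, d_hat):
    """Empirical analogue of E[(Y - g0(D, Z)) (D - m0(Z))] in Property 1;
    it should be close to zero when the noise conditions (approximately) hold."""
    return abs(np.mean((y - y_hat) * (d - d_hat)))
```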

3.2 The Proposed Framework

We propose a moderately-balanced representation learning (MBRL) framework to obtain \((\hat{g}, \hat{m})\) and thereby estimate the ATE; the MBRL architecture is illustrated in Fig. 1. The MBRL network maps the original covariate space to the representation space through an encoder \(\Phi \) such that 1) the representation preserves predictive information for outcomes; 2) the map makes the distributional discrepancy between the treated group and the controlled group small enough; 3) the domain (treated or controlled) of each individual is well discriminated; 4) the orthogonality information is involved.

Fig. 1. The MBRL network architecture.

Learning Representation of Covariates. The distributions of the treated group and the controlled group are inherently disparate due to selection bias. Previous works handle this problem using balanced representation learning [16, 23], which forces the distributions of the treated and controlled groups to be similar enough in the representation space. Specifically, a representation is learned by minimizing an integral probability metric (IPM), which measures the imbalance between the distributions of the treated population and the controlled population (see the details in [23]):

$$\begin{aligned} \text {IPM}_{\mathcal {G}}\big (P(\Phi (\textbf{Z})), Q(\Phi (\textbf{Z}))\big )=\sup _{g \in \mathcal {G}}\Big |\mathbb {E}_{P(\Phi (\textbf{Z}))}\left[ g(\Phi (\textbf{Z}))\right] -\mathbb {E}_{Q(\Phi (\textbf{Z}))}\left[ g(\Phi (\textbf{Z}))\right] \Big |. \end{aligned}$$
(4)

Here, \(P(\Phi (\textbf{Z}))\) and \(Q(\Phi (\textbf{Z}))\) are the representation distributions of the treated and controlled groups, and \(\mathcal {G}\) is a family of functions; the specific choices are given in Sect. 4.2.

The Prediction of Outcome and Treatment. MBRL predicts the outcome by the function \(f(D, \Phi (\textbf{Z}))\), which is partitioned into two functions \(f_0\) and \(f_1\):

$$\begin{aligned} {\begin{aligned} f(d_m, \Phi (\textbf{z}_m))=d_mf_1(\Phi (\textbf{z}_m))+(1-d_m)f_0(\Phi (\textbf{z}_m)). \end{aligned}} \end{aligned}$$
(5)

\(f_1\) and \(f_0\) are the output functions that map the representation to the potential outcomes for \(D=1\) and \(D=0\), respectively. \(f(d_m, \Phi (\textbf{z}_m))\) is the predicted factual outcome and we aim to minimize the factual outcome loss such that

$$\begin{aligned} \min \ \frac{1}{N}\sum _{m=1}^{N}L\big (y_m, f(d_m, \Phi (\textbf{z}_m))\big ), \end{aligned}$$
(6)

where \(L(\cdot ,\cdot )\) is a prediction loss appropriate for the outcome type (e.g., squared error for continuous outcomes).

Here, \(\hat{g}(d_m, \textbf{z}_m)=f(d_m, \Phi (\textbf{z}_m))\) is the estimated factual outcome of the \(m^{th}\) unit. Aside from making a low-error prediction over factual outcomes with a small divergence between treated and controlled groups, the distinguishability of the treated units from the controlled ones is also non-negligible. Therefore, we propose to maximize the distinguishability loss (measured by log-likelihood) such that

$$\begin{aligned} \max \ \frac{1}{N}\sum _{m=1}^{N}\Big [ d_m\log \pi (\Phi (\textbf{z}_m))+(1-d_m)\log \big (1-\pi (\Phi (\textbf{z}_m))\big )\Big ]. \end{aligned}$$
(7)

Here, \(\hat{m}(\textbf{z}_m)=\pi (\Phi (\textbf{z}_m))\) is the estimated probability of the \(m^{th}\) unit being assigned the treatment \(D=1\) (aka the estimated propensity score).
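To make the heads concrete, the following PyTorch sketch combines an encoder \(\Phi \), outcome heads \(f_0\)/\(f_1\) as in Eq. (5), and a propensity head \(\pi \). Class and function names and layer sizes are illustrative only; Sect. 4.2 reports the configuration used in our experiments.

```python
import torch
import torch.nn as nn

def mlp(in_dim, hidden, depth, out_dim):
    layers, d = [], in_dim
    for _ in range(depth):
        layers += [nn.Linear(d, hidden), nn.ELU()]
        d = hidden
    layers.append(nn.Linear(d, out_dim))
    return nn.Sequential(*layers)

class MBRLNet(nn.Module):
    """Encoder Phi with two outcome heads (f0, f1) and a propensity head pi."""
    def __init__(self, z_dim, rep_dim=200):
        super().__init__()
        self.phi = mlp(z_dim, 200, 4, rep_dim)   # representation encoder Phi
        self.f0 = mlp(rep_dim, 100, 3, 1)        # outcome head for D = 0
        self.f1 = mlp(rep_dim, 100, 3, 1)        # outcome head for D = 1
        self.pi = mlp(rep_dim, 200, 4, 1)        # treatment discriminator pi

    def forward(self, z, d):
        r = self.phi(z)
        y0 = self.f0(r).squeeze(-1)
        y1 = self.f1(r).squeeze(-1)
        y_pred = d * y1 + (1 - d) * y0                    # Eq. (5): factual prediction
        e_pred = torch.sigmoid(self.pi(r)).squeeze(-1)    # estimated propensity score
        return y_pred, e_pred, r
```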

The Noise Regularizations. Recall Proposition 2 that \(\mathbb {E}\left[ \xi \mid D, \textbf{Z}\right] =0\) and \(\mathbb {E}\left[ \nu \mid \textbf{Z}\right] =0\) are sufficient conditions for score functions \(\psi _{1}\) and \(\psi _{2}\) being orthogonal. Empirically, we want to involve the following constraints:

$$\begin{aligned} {\begin{aligned}&\frac{1}{N}\sum _{m=1}^{N}\left[ y_m-f(d_m, \Phi (\textbf{z}_m))\right] =0,\\&\frac{1}{N}\sum _{m=1}^{N}\left[ d_m-\pi (\Phi (\textbf{z}_m))\right] =0. \end{aligned}} \end{aligned}$$
(8)

This motivates us to formalize \(\Omega _{y}\) and \(\Omega _{d}\) such that

$$\begin{aligned} {\begin{aligned}&\Omega _{y} = \epsilon _y \big | \frac{1}{N}\sum _{m=1}^{N}[y_m-f(d_m, \Phi (\textbf{z}_m))] \big |, \\&\Omega _{d} = \epsilon _d \big | \frac{1}{N}\sum _{m=1}^{N}[d_m-\pi (\Phi (\textbf{z}_m))] \big |. \end{aligned}} \end{aligned}$$
(9)

Setting the partial derivative of \(\Omega _{y}\) w.r.t. \(\epsilon _y\) (or of \(\Omega _{d}\) w.r.t. \(\epsilon _d\)) to zero forces the learned nuisance functions to satisfy Eqn. (8). Therefore, minimizing the noise regularizations \(\Omega _{y}\) and \(\Omega _{d}\) adapts the entire learning process toward satisfying the orthogonal condition of the score functions. This idea corresponds to the targeted regularizations (see more discussions in [11, 24]).
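A minimal sketch of the regularizers in Eqn. (9), assuming \(\epsilon _y\) and \(\epsilon _d\) are implemented as trainable scalar coefficients; class and attribute names are ours.

```python
import torch
import torch.nn as nn

class NoiseRegularizers(nn.Module):
    """Omega_y and Omega_d of Eqn. (9) with trainable coefficients eps_y, eps_d."""
    def __init__(self):
        super().__init__()
        self.eps_y = nn.Parameter(torch.zeros(()))
        self.eps_d = nn.Parameter(torch.zeros(()))

    def forward(self, y, d, y_pred, d_pred):
        omega_y = self.eps_y * torch.abs(torch.mean(y - y_pred))
        omega_d = self.eps_d * torch.abs(torch.mean(d - d_pred))
        return omega_y, omega_d
```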

Multi-task Learning and Perturbation Error. MBRL learns the nuisance functions through multi-task learning with the following three tasks in each iteration:

(10)

Instead of putting the IPM term (4) into Task 3 as a regularization, we let minimizing the IPM be one of the multiple tasks. To be specific, Task 1 updates \(\pi \) to produce the propensity scores, and Task 2 achieves a balance between \(\{\Phi (\textbf{z}_m)\}_{m:d_m=1}\) and \(\{\Phi (\textbf{z}_m)\}_{m:d_m=0}\). Additionally, MBRL incorporates a novel model selection criterion, the Perturbation Error, based on the noise orthogonality property. It takes advantage of the noise orthogonality information by perturbing the main evaluation metric. For example, if the final model is selected by the root-mean-square error (\(RMSE = \sqrt{\frac{1}{N}\sum _{m=1}^{N}(y_m-\hat{y}_m)^2}\)), then the perturbation error \(\epsilon _{p}\) is defined as

$$\begin{aligned} \epsilon _{p}=RMSE + \beta |\frac{1}{N}\sum _{m=1}^{N}(y_m-\hat{y}_m)(d_m-\hat{d}_m)|. \end{aligned}$$

Here, \(\beta \) is the perturbation coefficient, which is a constant; \(\hat{y}_m\) and \(\hat{d}_m\) are the predicted values \(f(d_m, \Phi (\textbf{z}_m))\) and \(\pi (\Phi (\textbf{z}_m))\), respectively. The final model is selected on the validation set based on the minimum \(\epsilon _{p}\). If either the outcome model or the propensity score is well specified (i.e., the representations are moderately-balanced instead of over-balanced), the second term in \(\epsilon _{p}\) is small.
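A minimal sketch of this validation-time criterion; the default \(\beta \) below is shown only for illustration (Sect. 4.2 gives the values used per dataset). Model selection keeps the checkpoint with the smallest \(\epsilon _{p}\).

```python
import numpy as np

def perturbation_error(y, d, y_hat, d_hat, beta=0.1):
    """epsilon_p = RMSE + beta * |mean((y - y_hat) * (d - d_hat))|."""
    rmse = np.sqrt(np.mean((y - y_hat) ** 2))
    return rmse + beta * abs(np.mean((y - y_hat) * (d - d_hat)))
```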

4 Experiments

In this section, we conduct comprehensive experiments on benchmark datasets to compare MBRL with other prevalent causal inference methods. We further test the effectiveness of MBRL on simulated datasets with different levels of selection bias. All experiments are run on a Dell 7920 with 1\(\,\times \,\)16-core Intel Xeon Gold 6250 3.90 GHz CPU and 3\(\,\times \,\)NVIDIA Quadro RTX 6000 GPUs.

4.1 Dataset Description

Since the ground truth of treatment effects is inaccessible for real-world data, it is difficult to evaluate the performance of causal inference methods for ATE estimation. Previous causal inference literature assesses methods on two prevalent semi-synthetic datasets: IHDP and Twins. IHDP. The IHDP dataset is a well-known benchmark for causal inference introduced by [13]. It includes 747 samples with 25-dimensional covariates associated with the information of infants and their mothers, such as birth weight and mother’s age. These covariates are collected from a real-world randomized experiment. Our aim is to study the treatment effect of specialist visits (binary treatment) on cognitive scores (continuous-valued outcome). The outcome is generated using the NPCI package [6], and the selection bias is created by removing a subset of the treated population. We use the same 1000 IHDP datasets as those used in [23], where each dataset is split by the ratio of \(63\%/27\%/10\%\) into training/validation/test sets.

Twins. The Twins dataset [19] collects twin births in the USA between 1989 and 1991 [1]. After data processing, each unit has 30 covariates relevant to the parents, pregnancy and birth [28]. The treatment \(D=1\) indicates the heavier twin while \(D=0\) indicates the lighter twin, and the outcome Y is a binary variable defined as the 1-year mortality. Similar to [28], we only select twins who have the same gender and both weigh less than 2 kg, which finally gives 11440 pairs of twins, whose mortality rate is \(17.7\%\) for the lighter twin and \(16.1\%\) for the heavier twin. To create the selection bias, we selectively choose one of the two twins as the factual observation based on the covariates of the \(m^{th}\) individual: \(D_m|\textbf{Z}_m \sim \) Bernoulli(Sigmoid(\(\textbf{w}^T\textbf{Z}_m+n\))), where \(\textbf{w}\) and n are a randomly drawn weight vector and noise term. We repeat the data generating process 100 times, and the generated 100 Twins datasets are all split by the ratio of \(56\%/24\%/20\%\) into training/validation/test sets.
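The selection-bias mechanism can be sketched as follows; since the distributions of \(\textbf{w}\) and n are not spelled out in this section, the ones below are placeholders, not the ones used to build the datasets.

```python
import numpy as np

def assign_twins_treatment(Z, rng=np.random.default_rng(0)):
    """Draw D_m | Z_m ~ Bernoulli(Sigmoid(w^T Z_m + n)).

    The distributions of w and n below are placeholders (assumptions)."""
    w = rng.uniform(-0.1, 0.1, size=Z.shape[1])   # placeholder weight vector
    n = rng.normal(0.0, 0.1)                      # placeholder noise term
    p = 1.0 / (1.0 + np.exp(-(Z @ w + n)))
    return rng.binomial(1, p)
```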

4.2 Performance Measurement and Experimental Settings

Performance Measurement. Generally, the comparisons are based on the absolute error in ATE: \(\epsilon _{ATE}=|\tau -\hat{\tau }|\). Additionally, we also test the performance of MBRL on individual treatment effect (ITE) estimations. For IHDP datasets, we adopt Precision in Estimation of Heterogeneous Effect (PEHE):

$$\epsilon _{PEHE}=\frac{1}{N}\sum _{m=1}^{N}\left( [y_m(1)-y_m(0)]-[\hat{y}_m(1)-\hat{y}_m(0)]\right) ^2.$$

For Twins datasets, we follow [19] to adopt Area Under ROC Curve (AUC).
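For clarity, these evaluation metrics can be computed as below (a small NumPy sketch; function names are ours).

```python
import numpy as np

def ate_error(y1_true, y0_true, y1_pred, y0_pred):
    """epsilon_ATE = |tau - tau_hat|."""
    return abs((y1_true - y0_true).mean() - (y1_pred - y0_pred).mean())

def root_pehe(y1_true, y0_true, y1_pred, y0_pred):
    """Square root of the PEHE defined above."""
    ite_true = y1_true - y0_true
    ite_pred = y1_pred - y0_pred
    return np.sqrt(np.mean((ite_true - ite_pred) ** 2))

# For Twins (binary outcomes), AUC can be computed with scikit-learn's
# roc_auc_score(labels, predicted_probabilities) on the predicted outcomes.
```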

Baseline Models. We compare our MBRL method with the following baseline models: linear regression with the treatment as a feature (OLS/LR\(_1\)), separate linear regression for each treatment group (OLS/LR\(_2\)), k-nearest neighbor (k-NN), Bayesian additive regression trees (BART) [4], causal forest (CF) [25], balancing linear regression (BLR) [16], balancing neural network (BNN) [16], treatment-agnostic representation network (TARNet) [23], counterfactual regression with Wasserstein distance (CFR-WASS) [23], causal effect variational autoencoders (CEVAE) [19], local similarity preserved individual treatment effect (SITE) [27], generative adversarial networks for inference of treatment effect (GANITE) [28] and Dragonnet [24]. Experimental Details. In our experiments, the IPM in (4) is chosen as the Wasserstein distance. Let the empirical representation distribution be \(P(\Phi (\textbf{Z}))=P(\Phi (\textbf{Z}) \mid D=1)\) for the treated group and \(Q(\Phi (\textbf{Z}))=Q(\Phi (\textbf{Z}) \mid D=0)\) for the controlled group. Assuming \(\mathcal {G}\) is defined as the functional space of a family of 1-Lipschitz functions, we obtain the 1-Wasserstein distance \(Wass(P, Q)\) for the IPM [23]:

$$\begin{aligned} Wass(P, Q)=\inf _{k \in \mathcal {K}}\int \Vert k(\Phi (\textbf{z}))-\Phi (\textbf{z})\Vert \, P(\Phi (\textbf{z}))\, d\Phi (\textbf{z}). \end{aligned}$$

Here, \(\mathcal {K}\) defines the set of push-forward functions that transform the representation distribution of the treated group \(P(\Phi (\textbf{Z}))\) to that of the controlled group \(Q(\Phi (\textbf{Z}))\).
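As a rough numerical illustration (not the training-time computation of [23], which uses an approximation suitable for gradient-based learning), the distance between the two representation distributions can be approximated with random 1-D projections; function name and projection scheme are ours.

```python
import numpy as np
from scipy.stats import wasserstein_distance

def sliced_wasserstein(rep, d, n_proj=50, seed=0):
    """Crude empirical stand-in for Wass(P, Q): average 1-D Wasserstein distance
    between treated and controlled representations over random projections."""
    rng = np.random.default_rng(seed)
    treated, control = rep[d == 1], rep[d == 0]
    dists = []
    for _ in range(n_proj):
        v = rng.normal(size=rep.shape[1])
        v /= np.linalg.norm(v)
        dists.append(wasserstein_distance(treated @ v, control @ v))
    return float(np.mean(dists))
```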

Table 1. Performance comparisons and ablation study with mean ± standard error on 1000 IHDP datasets. \(\epsilon _{ATE}\): Lower is better. \(\sqrt{\epsilon _{PEHE}}\): Lower is better.

In addition, we adopt the ELU activation function and set 4 fully connected layers with 200 units for both the representation encoder network \(\Phi (\cdot )\) and the discriminator \(\pi (\cdot )\), and 3 fully connected layers with 100 units for the outcome prediction networks \(f_0(\cdot )\) and \(f_1(\cdot )\). The optimizer is Adam [17] with a learning rate of \(10^{-3}\). We set (batch size, epochs) to (100, 1000)/(1000, 250) for the IHDP/Twins experiments, and the hyperparameters \((\lambda _1,\lambda _2)\) to (0.01, 0.01)/(0.1, 0.1) for the IHDP/Twins experiments. The final model early stops on the metric \(\epsilon _p\), and we choose \(\beta \) in \(\epsilon _p\) as 0.1 and 100 for the IHDP and Twins experiments, respectively.

For the baseline models, we follow the same settings of hyperparameters as in their published paper and code. For our MBRL network, the optimal hyperparameters are chosen in the same way as [23]. The searching ranges are reported in Table 4.

Table 2. Performance comparisons with mean ± standard error on 100 Twins datasets. \(\epsilon _{ATE}\): Lower is better. AUC: Higher is better.

4.3 Results Analysis

Table 1 and Table 2 report the performance of the baseline methods and MBRL on the IHDP and Twins datasets. We present the average values and standard errors of \(\epsilon _{ATE}\), \(\epsilon _{PEHE}\) and AUC (mean ± std). Lower \(\epsilon _{ATE}\) and \(\epsilon _{PEHE}\) or higher AUC is better. Bold indicates the best method for each dataset.

As shown in Table 1 and Table 2, we have the following observations. 1) MBRL achieves significant improvements in both ITE and ATE estimations across all datasets compared to the baseline models. 2) The advanced representation learning methods that focus on estimating ITE (such as SITE, TARNet and CFR-WASS) show their inapplicability to ATE estimation. By contrast, MBRL not only significantly outperforms these representation learning methods in ITE estimation but also remains among the best ATE results. 3) The state-of-the-art ATE estimation method, Dragonnet, achieves superior ATE estimations across all the baseline models but yields a substantial error in ITE estimations. Although Dragonnet shares a similar basic network architecture with MBRL, MBRL obtains a substantially lower \(\epsilon _{ATE}\) than Dragonnet owing to the multi-task learning framework and the utilization of orthogonality information. These observations indicate that the proposed MBRL method is highly effective for estimating treatment effects.

We further conduct an ablation study on the IHDP datasets to test whether orthogonality information is practical in real applications. The relevant results are reported in Table 3. We let MBRL* denote MBRL without the perturbation error \(\epsilon _{p}\), and MBRL** denote MBRL without any orthogonality information (\(\epsilon _{p}\), \(\Omega _{d}\) and \(\Omega _{y}\)). We find that incorporating orthogonality information enhances the power of estimating treatment effects, whether or not orthogonal estimators are used. This enhancement is especially pronounced when orthogonal estimators are plugged in for in-sample data.

Table 3. Ablation study on IHDP datasets.
Table 4. The searching ranges of hyperparameters.

4.4 Simulation Study

In this part, we mainly investigate two questions. Q1. Does MBRL perform more stably to the level of selection bias than the state-of-the-art model Dragonnet? Q2. Can the noise orthogonality information, the perturbation error \(\epsilon _{p}\), improve ATE estimations regardless of different models/estimators/selection bias levels?

We generate 2500 treated samples whose covariates are drawn from a 10-dimensional distribution with mean vector \(\boldsymbol{\mu }^1\), and 5000 controlled samples whose covariates are drawn from a 10-dimensional distribution with mean vector \(\boldsymbol{\mu }^0\). The level of selection bias, measured by the KL divergence of the treated covariate distribution with respect to the controlled one, is varied by fixing \(\boldsymbol{\mu }^0\) and adjusting \(\boldsymbol{\mu }^1\). The potential outcomes of the \(m^{th}\) individual are generated as \(Y(1) \mid \textbf{Z}_m \sim (\textbf{w}_{1}^{T}\textbf{Z}_m+n_1)\) and \(Y(0) \mid \textbf{Z}_m \sim (\textbf{w}_{0}^{T}\textbf{Z}_m+n_0)\), where \(\textbf{w}_{1}\), \(\textbf{w}_{0}\), \(n_1\) and \(n_0\) are randomly drawn weight vectors and noise terms. By adjusting \(\boldsymbol{\mu }^1\) and fixing \(\boldsymbol{\mu }^0\), we obtain five datasets with KL divergences in \(\{0, \ 62.85, \ 141.41, \ 565.63, \ 769.89\}\). We run experiments on each dataset 100 times and draw box plots of \(\epsilon _{ATE}\) on the test set in Fig. 2.
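A sketch of this data-generating process; the covariance matrix and the distributions of the weights and noise terms are placeholders (assumptions), since only the mean vectors are varied here.

```python
import numpy as np

def simulate(mu1, mu0, n_treated=2500, n_control=5000, seed=0):
    """One simulated dataset; covariance, weight and noise choices are placeholders."""
    rng = np.random.default_rng(seed)
    dim = len(mu1)
    Z = np.vstack([rng.multivariate_normal(mu1, np.eye(dim), size=n_treated),
                   rng.multivariate_normal(mu0, np.eye(dim), size=n_control)])
    d = np.concatenate([np.ones(n_treated), np.zeros(n_control)])
    w1, w0 = rng.normal(size=dim), rng.normal(size=dim)   # placeholder weights
    n1, n0 = rng.normal(), rng.normal()                    # placeholder noise terms
    y1, y0 = Z @ w1 + n1, Z @ w0 + n0                      # potential outcomes
    y = d * y1 + (1 - d) * y0                              # factual outcomes
    return Z, d, y, y1, y0
```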

In Fig. 2(a), we first find that MBRL shows stronger robustness and achieves significantly better ATE estimations across different selection bias levels compared with Dragonnet. In addition, it is noticeable that choosing the perturbation error \(\epsilon _{p}\) as the model selection metric yields a smaller \(\epsilon _{ATE}\) for either model (Dragonnet or MBRL). In particular, \(\epsilon _{p}\) corrects more errors for MBRL than for Dragonnet, which indicates that \(\epsilon _{p}\) works better if a model utilizes the orthogonality information in the training stage. In Fig. 2(b), we have two main observations: i) the criterion \(\epsilon _{p}\) improves ATE estimations for all estimators across different selection bias levels; ii) the improvement brought by \(\epsilon _{p}\) becomes more substantial when the selection bias increases.

Fig. 2. Comparisons between models with and without \(\epsilon _{p}\) w.r.t. varying levels of selection bias.

5 Related Work

Representation Learning. Our work has a strong connection with the balanced representation learning methods proposed in [16, 23], which mainly focus on minimizing the imbalance between the different treatment groups in the representation space but overlook maximizing the discrimination of each unit’s treatment domain. The IGNITE framework proposed in [10] infers individual treatment effects from networked data by learning a balanced representation that captures patterns of hidden confounders predictive of treatment assignments. This inspires us to study treatment effects by training a moderately-balanced representation via multi-task learning. Other works relevant to representation learning include [18, 19, 24, 27, 28] and references therein.

Orthogonal Score Function. The authors of [3] develop the theory of double/debiased machine learning (DML) building on [20]. They define the notion of the orthogonal condition, which allows their DML estimator to be doubly robust. Based on the theory of [3], another orthogonal estimator is proposed by [14], aiming to overcome the high variance that DML suffers from under a misspecified propensity score. Despite the success of orthogonal estimators, establishing them requires the noise conditions to guarantee that the corresponding score functions satisfy the orthogonal condition. None of the existing literature emphasizes the critical role of the noise conditions or utilizes the orthogonality information for model selection.

6 Conclusion

This paper proposes an effective representation learning method, MBRL, to estimate treatment effects. Specifically, MBRL avoids the over-balancing issue by leveraging the treatment domains of the representations via multi-task learning. MBRL further takes advantage of the orthogonality information and involves it in the training and validation stages. The extensive experiments show that 1) MBRL has strong predictability for the potential outcomes, distinguishability for the treatment assignment, applicability to orthogonal estimators, and robustness to the selection bias; 2) MBRL achieves substantial improvements on treatment effect estimations compared with existing state-of-the-art methods.