1 Introduction

Humans often think about how they can alter the outcome of a situation. What do I need to change for the bank to approve my loan? or Which symptoms would lead to a different medical diagnosis? are common examples. This form of counterfactual reasoning comes naturally to us and explains how to arrive at a desired outcome in an interpretable manner. Moreover, examples of counterfactual instances resulting in a different outcome can give powerful insights into what is important to the underlying decision process, making counterfactuals a compelling method to explain predictions of machine learning models (Fig. 1).

In the context of predictive models, given a test instance and the model’s prediction, a counterfactual instance describes the necessary change in input features that alter the prediction to a predefined output [21]. For classification models the predefined output can be any target class or prediction probability distribution. Counterfactual instances can then be found by iteratively perturbing the input features of the test instance until the desired prediction is reached. In practice, the counterfactual search is posed as an optimization problem—we want to minimize an objective function which encodes desirable properties of the counterfactual instance with respect to the perturbations. The key insight of this formulation is the need to design an objective function that allows us to generate high quality counterfactual instances. A counterfactual instance \(x_{\text {cf}}\) should have the following desirable properties:

Fig. 1.

(a) Examples of original and counterfactual instances on the MNIST dataset along with predictions of a CNN model. (b) A counterfactual instance on the Adult (Census) dataset highlighting the feature changes required to alter the prediction of an NN model.

  1. The model prediction on \(x_{\text {cf}}\) needs to be close to the predefined output.

  2. The perturbation \(\delta \) changing the original instance \(x_{0}\) into \(x_{\text {cf}}=x_{0}+\delta \) should be sparse.

  3. The counterfactual \(x_{\text {cf}}\) needs to be interpretable. We consider an instance \(x_{\text {cf}}\) interpretable if it lies close to the model’s training data distribution. This applies not only to the overall data set, but importantly also to the training instances that belong to the counterfactual class. Let us illustrate this with an intuitive example. Assume we are predicting house prices with features including the square footage and the number of bedrooms. Our house is valued below £500,000 and we would like to know what needs to change about the house in order to increase the valuation above £500,000. By simply increasing the number of bedrooms and leaving the other features unchanged, the model predicts that our counterfactual house is now worth more than £500,000. This sparse counterfactual instance lies fairly close to the overall training distribution since only one feature value was changed. The counterfactual is, however, out-of-distribution with regard to the subset of houses in the training data valued above £500,000, because other relevant features such as the square footage still resemble a typical house valued below £500,000. As a result, we do not consider this counterfactual to be very interpretable. We show in the experiments that there is often a trade-off between sparsity and interpretability.

  4. The counterfactual instance \(x_\text {cf}\) needs to be found fast enough so that it can be used in a real-life setting.

An overly simplistic objective function may return instances which satisfy properties 1. and 2., but where the perturbations are not interpretable with respect to the counterfactual class.

In this paper we propose using class prototypes in the objective function to guide the perturbations quickly towards an interpretable counterfactual. The prototypes also allow us to remove computational bottlenecks from the optimization process which occur due to numerical gradient calculation for black box models. In addition, we propose two novel metrics to quantify interpretability which provide a principled benchmark for evaluating interpretability at the instance level. We show empirically that prototypes improve the quality of counterfactual instances on both image (MNIST) and tabular (Wisconsin Breast Cancer) datasets. Finally, we propose using pairwise distance measures between categories of categorical variables to define meaningful perturbations for such variables and illustrate the effectiveness of the method on the Adult (Census) dataset.

2 Related Work

Counterfactual instances—synthetic instances of data engineered from real instances to change the prediction of a machine learning model—have been suggested as a way of explaining individual predictions of a model as an alternative to feature attribution methods such as LIME [23] or SHAP [19].

Wachter et al. [27] generate counterfactuals by minimizing an objective function which sums the squared difference between the predictions on the perturbed instance and the desired outcome, and a scaled \(L_{1}\) norm of the perturbations. Laugel et al. [15] find counterfactuals through a heuristic search procedure by growing spheres around the instance to be explained. The above methods do not take local, class-specific interpretability into account. Furthermore, for black box models the number of prediction calls during the search process grows proportionally to either the dimensionality of the feature space [27] or the number of sampled observations [9, 15], which can result in a computational bottleneck. Dhurandhar et al. [7, 9] propose the framework of Contrastive Explanations, which finds the minimal number of features that need to be changed or kept unchanged to respectively change or keep a prediction.

A key contribution of this paper is the use of prototypes to guide the counterfactual search process. Kim et al. [14], Gurumoorthy et al. [11] use prototypes as example-based explanations to improve the interpretability of complex datasets. Besides improving interpretability, prototypes have a broad range of applications like clustering [13], classification [4, 26], and few-shot learning [25]. If we have access to an encoder [24], we follow the approach of [25] who define a class prototype as the mean encoding of the instances which belong to that class. In the absence of an encoder, we find prototypes through class specific k-d trees [3].

To judge the quality of the counterfactuals we introduce two novel metrics which focus on local interpretability with respect to the training data distribution. This is different from [8] who define an interpretability metric relative to a target model. Kim et al. [14] on the other hand quantify interpretability through a human pilot study measuring the accuracy and efficiency of the humans on a predictive task. Luss et al. [20] also highlight the importance of good local data representations in order to generate high quality explanations.

Another contribution of this paper is a principled approach to handling categorical variables during the counterfactual generation process. Some previously proposed solutions are either computationally expensive [27] or do not take relationships between categories into account [9, 22]. We propose using pairwise distance measures to define embeddings of categorical variables into numerical space which allows us to define meaningful perturbations when generating counterfactuals.

3 Methodology

3.1 Background

The following section outlines how the prototype loss term is constructed and why it improves the convergence speed and interpretability. Finding a counterfactual instance \(x_{\text {cf}} = x_{0} + \delta \), with both \(x_{\text {cf}}\) and \(x_{0}\) \(\in \) \(\mathcal {X}\subseteq \mathbb {R}^D\) where \(\mathcal {X}\) represents the D-dimensional feature space, implies optimizing an objective function of the following form:

$$\begin{aligned} \min _{\delta } c \cdot f_{\kappa }(x_{0},\delta ) + f_{\text {dist}}(\delta ). \end{aligned}$$
(1)

\(f_{\kappa }(x_{0},\delta )\) encourages the predicted class i of the perturbed instance \(x_{\text {cf}}\) to be different than the predicted class \(t_{0}\) of the original instance \(x_{0}\). Similar to [7], we define this loss term as:

$$\begin{aligned} f_{\kappa }(x_{0},\delta ) = \max \left( [f_{\text {pred}}(x_{0} + \delta )]_{t_{0}} - \max _{i \ne t_{0}} [f_{\text {pred}}(x_{0} + \delta )]_{i},\, -\kappa \right) , \end{aligned}$$
(2)

where \([f_{\text {pred}}(x_{0} + \delta )]_{i}\) is the i-th class prediction probability, and \(\kappa \ge 0\) caps the divergence between \([f_{\text {pred}}(x_{0} + \delta )]_{t_{0}}\) and \([f_{\text {pred}}(x_{0} + \delta )]_{i}\). The term \(f_{\text {dist}}(\delta )\) minimizes the distance between \(x_{0}\) and \(x_{\text {cf}}\) with the aim of generating sparse counterfactuals. We use an elastic net regularizer [28]:

$$\begin{aligned} f_{\text {dist}}(\delta ) = \beta \cdot \Vert \delta \Vert _{1} + \Vert \delta \Vert _{2}^2 = \beta \cdot L_{1} + L_{2}. \end{aligned}$$
(3)

While the objective function (1) is able to generate counterfactual instances, it does not address a number of issues:

  1. \(x_{\text {cf}}\) does not necessarily respect the training data manifold, resulting in out-of-distribution counterfactual instances. Often a trade-off needs to be made between sparsity and interpretability of \(x_{\text {cf}}\).

  2. The scaling parameter c of \(f_{\kappa }(x_{0},\delta )\) needs to be set within the appropriate range before a potential counterfactual instance is found. Finding a good range can be time consuming.

[7] aim to address the first issue by adding an additional loss term \(L_{\text {AE}}\) which represents the \(L_{2}\) reconstruction error of \(x_{\text {cf}}\), evaluated by an autoencoder AE fit on the training set:

$$\begin{aligned} L_{\text {AE}} = \gamma \cdot \Vert x_{0} + \delta - \text {AE}(x_{0} + \delta ) \Vert _{2}^2. \end{aligned}$$
(4)

The loss L to be minimized now becomes:

$$\begin{aligned} L = c \cdot L_{\text {pred}} + \beta \cdot L_{1} + L_{2} + L_{\text {AE}}. \end{aligned}$$
(5)

The autoencoder loss term \(L_{\text {AE}}\) penalizes out-of-distribution counterfactual instances, but does not take the data distribution for each prediction class i into account. This can lead to sparse but uninterpretable counterfactuals, as illustrated by Fig. 2. The first row of Fig. 2(b) shows a sparse counterfactual 3 generated from the original 5 using loss function (5). Both visual inspection and reconstruction of the counterfactual instance using AE in Fig. 2(e) make clear however that the counterfactual lies closer to the distribution of a 5 and is not interpretable as a 3. The second row adds a prototype loss term to the objective function, leading to a less sparse but more interpretable counterfactual 6.
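As a rough illustration of how these terms combine, the following minimal NumPy sketch evaluates the loss of Eq. (5) for a given perturbation. The names `predict_fn` (returning class probabilities), `ae_fn` (the autoencoder reconstruction) and the hyperparameter defaults are placeholder assumptions rather than the reference implementation, and the gradient machinery is omitted.

```python
import numpy as np

def loss_no_proto(x0, delta, t0, predict_fn, ae_fn,
                  c=1.0, beta=0.1, gamma=100.0, kappa=0.0):
    """Sketch of Eq. (5): c * L_pred + beta * L1 + L2 + L_AE."""
    x_cf = x0 + delta
    probs = predict_fn(x_cf)                    # class probabilities for the perturbed instance
    p_t0 = probs[t0]                            # probability of the original class t0
    p_other = np.max(np.delete(probs, t0))      # highest probability among classes i != t0
    l_pred = max(p_t0 - p_other, -kappa)        # hinge-style prediction loss of Eq. (2)
    l1 = np.sum(np.abs(delta))                  # L1 norm of the perturbation
    l2 = np.sum(delta ** 2)                     # squared L2 norm of the perturbation
    l_ae = gamma * np.sum((x_cf - ae_fn(x_cf)) ** 2)  # reconstruction error of Eq. (4)
    return c * l_pred + beta * l1 + l2 + l_ae
```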

Fig. 2.

First row: (a) original instance and (b) uninterpretable counterfactual 3. (c), (d) and (e) are reconstructions of (b) with respectively \(\text {AE}_{3}\), \(\text {AE}_{5}\) and \(\text {AE}\). Second row: (a) original instance and (b) interpretable counterfactual 6. (c), (d) and (e) are reconstructions of (b) with respectively \(\text {AE}_{6}\), \(\text {AE}_{5}\) and \(\text {AE}\).

The \(L_{\text {AE}}\) loss term also does not consistently speed up the counterfactual search process since it imposes a penalty on the distance between the proposed \(x_{\text {cf}}\) and its reconstruction by the autoencoder without explicitly guiding \(x_{\text {cf}}\) towards an interpretable solution. We address these issues by introducing an additional loss term, \(L_{\text {proto}}\).

3.2 Prototype Loss Term

By adding in a prototype loss term \(L_{\text {proto}}\), we obtain the following objective function:

$$\begin{aligned} L = c \cdot L_{\text {pred}} + \beta \cdot L_{1} + L_{2} + L_{\text {AE}} + L_{\text {proto}}, \end{aligned}$$
(6)

where \(L_{\text {AE}}\) becomes optional. The aim of \(L_{\text {proto}}\) is twofold:

  1. Guide the perturbations \(\delta \) towards an interpretable counterfactual \(x_{\text {cf}}\) which falls in the distribution of counterfactual class i.

  2. Speed up the counterfactual search process without too much hyperparameter tuning.

To define the prototype for each class, we can reuse the encoder part of the autoencoder from \(L_{\text {AE}}\). The encoder \(\text {ENC}(x)\) projects \(x \in \mathcal {X}\) onto an E-dimensional latent space \(\mathbb {R}^E\). We also need a representative, unlabeled sample of the training dataset. First the predictive model is called to label the dataset with the classes predicted by the model. Then for each class i we encode the instances belonging to that class and order them by increasing \(L_{2}\) distance to \(\text {ENC}(x_{0})\). Similar to [25], the class prototype is defined as the average encoding over the K nearest instances in the latent space with the same class label:

$$\begin{aligned} \mathrm {proto}_{i} := \frac{1}{K} \sum _{k=1}^{K} \text {ENC}(x_{k}^{i}) \end{aligned}$$
(7)

for the ordered \(\lbrace x_{k}^{i}\rbrace _{k=1}^{K}\) in class i. It is important to note that the prototype is defined in the latent space, not the original feature space.
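A minimal sketch of the prototype computation in Eq. (7), assuming flattened feature vectors, an `encoder` callable that maps a batch of instances to latent vectors and a `predict_fn` that returns predicted class labels; none of these names come from the paper’s implementation.

```python
import numpy as np

def class_prototypes(X, x0, encoder, predict_fn, K=5):
    """One prototype per predicted class: the mean encoding of the K instances
    of that class closest to ENC(x0) in latent space (Eq. (7))."""
    labels = predict_fn(X)                  # label the representative sample with model predictions
    enc_x0 = encoder(x0[None, :])[0]        # encoding of the instance to explain
    protos = {}
    for i in np.unique(labels):
        enc_i = encoder(X[labels == i])                 # encode instances predicted as class i
        dists = np.linalg.norm(enc_i - enc_x0, axis=1)  # L2 distance to ENC(x0)
        nearest = enc_i[np.argsort(dists)[:K]]          # K nearest encodings of class i
        protos[i] = nearest.mean(axis=0)                # class prototype in latent space
    return protos
```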

The squared Euclidean distance belongs to the class of distance functions called Bregman divergences. If we consider that the encoded instances belonging to class i define a cluster for i, then \(\mathrm {proto}_{i}\) equals the cluster mean. For Bregman divergences the cluster mean minimizes the total divergence to the points in the cluster [1]. Since we use the Euclidean distance to find the closest class to \(x_{0}\), \(\mathrm {proto}_{i}\) is a suitable class representation in the latent space. When generating a counterfactual instance for \(x_{0}\), we first find the nearest prototype \(\mathrm {proto}_{j}\) of a class \(j \ne t_{0}\) to the encoding of \(x_{0}\):

$$\begin{aligned} j = \mathop {\text {arg min}}\limits _{i \ne t_{0}} \Vert \text {ENC}(x_{0}) - \mathrm {proto}_{i} \Vert _{2}. \end{aligned}$$
(8)

The prototype loss \(L_{\text {proto}}\) can now be defined as:

$$\begin{aligned} L_{\text {proto}} = \theta \cdot \Vert \text {ENC}(x_{0} + \delta ) - \mathrm {proto}_{j} \Vert _{2}^2, \end{aligned}$$
(9)

where \(\text {ENC}(x_{0} + \delta )\) is the encoding of the perturbed instance. As a result, \(L_{\text {proto}}\) explicitly guides the perturbations towards the nearest prototype \(\mathrm {proto}_{j \ne t_{0}}\), speeding up the counterfactual search process towards the average encoding of class j. This leads to more interpretable counterfactuals as illustrated by the experiments. Algorithm 1 summarizes this approach.
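Continuing the sketch above under the same assumptions, Eqs. (8) and (9) amount to picking the closest prototype of a different class and penalizing the latent distance of the perturbed instance to it; `theta` is a placeholder hyperparameter.

```python
import numpy as np

def proto_loss(x0, delta, t0, encoder, protos, theta=100.0):
    """Nearest prototype of a class j != t0 (Eq. (8)) and the prototype loss (Eq. (9))."""
    enc_x0 = encoder(x0[None, :])[0]
    # Eq. (8): nearest prototype among all classes different from the original class t0
    j = min((i for i in protos if i != t0),
            key=lambda i: np.linalg.norm(enc_x0 - protos[i]))
    enc_cf = encoder((x0 + delta)[None, :])[0]          # encoding of the perturbed instance
    return theta * np.sum((enc_cf - protos[j]) ** 2)    # Eq. (9)
```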

[Algorithm 1 and Algorithm 2: pseudo-code listings.]

3.3 Using K-D Trees as Class Representations

If we do not have a trained encoder available, we can build class representations using k-d trees [3]. After labeling the representative training set by calling the predictive model, we represent each class i by a separate k-d tree built using the instances with class label i. This approach is similar to [12] who use class specific k-d trees to measure the agreement between a classifier and a modified nearest neighbour classifier on test instances. For each k-d tree \(j \ne t_{0}\), we compute the Euclidean distance between \(x_{0}\) and the k-th nearest item in the tree, \(x_{j, k}\). The closest \(x_{j, k}\) across all classes \(j \ne t_{0}\) becomes the class prototype \(\mathrm {proto}_{j}\). Note that we are now working in the original feature space. The loss term \(L_{\text {proto}}\) is equal to:

$$\begin{aligned} L_{\text {proto}} = \theta \cdot \Vert x_{0} + \delta - \mathrm {proto}_{j} \Vert _{2}^2. \end{aligned}$$
(10)

Algorithm 2 outlines the k-d trees approach.
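Under the same assumptions about `predict_fn`, a sketch of the k-d tree variant using `scipy.spatial.cKDTree`; it illustrates the idea rather than reproducing the paper’s exact implementation.

```python
import numpy as np
from scipy.spatial import cKDTree

def kdtree_prototype(X, x0, t0, predict_fn, k=1):
    """Prototype in the original feature space: the closest k-th nearest neighbour
    of x0 across the class-specific k-d trees of all classes j != t0."""
    labels = predict_fn(X)                          # label the sample with model predictions
    best_dist, proto = np.inf, None
    for j in np.unique(labels):
        if j == t0:
            continue
        X_j = X[labels == j]
        dist, idx = cKDTree(X_j).query(x0, k=k)     # k nearest items of class j
        dist, idx = np.atleast_1d(dist)[-1], np.atleast_1d(idx)[-1]  # keep the k-th nearest
        if dist < best_dist:
            best_dist, proto = dist, X_j[idx]
    return proto
```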

3.4 Categorical Variables

Creating meaningful perturbations for categorical data is not straightforward as the very concept of perturbing an input feature implies some notion of rank and distance between the values a variable can take. We approach this by inferring pairwise distances between categories of a categorical variable based on either model predictions (Modified Value Distance Metric) [6] or the context provided by the other variables in the dataset (Association-Based Distance Metric) [16]. We then apply multidimensional scaling [5] to project the inferred distances into one-dimensional Euclidean space, which allows us to perform perturbations in this space. After applying a perturbation in this space, we map the resulting number back to the closest category before evaluating the classifier’s prediction.
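A hedged sketch of this pipeline using scikit-learn’s MDS; the pairwise distance matrix is assumed to come from ABDM or MVDM, and the example category names and distances below are illustrative only.

```python
import numpy as np
from sklearn.manifold import MDS

def embed_categories(pairwise_dist, categories):
    """Project a precomputed matrix of pairwise distances between categories
    onto a one-dimensional numerical axis with multidimensional scaling."""
    mds = MDS(n_components=1, dissimilarity='precomputed', random_state=0)
    coords = mds.fit_transform(pairwise_dist).ravel()
    return dict(zip(categories, coords))

def to_nearest_category(value, embedding):
    """Map a perturbed numerical value back to the closest category."""
    return min(embedding, key=lambda cat: abs(embedding[cat] - value))

# illustrative usage with a made-up 3x3 distance matrix
D = np.array([[0.0, 1.0, 3.0],
              [1.0, 0.0, 2.0],
              [3.0, 2.0, 0.0]])
emb = embed_categories(D, ['HS-grad', 'Bachelors', 'Doctorate'])
print(to_nearest_category(emb['Bachelors'] + 0.4, emb))
```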

3.5 Removing \(L_{\text { pred}}\)

In the absence of \(L_{\text {proto}}\), only \(L_{\text {pred}}\) encourages the perturbed instance to predict class \(i \ne t_{0}\). For black box models where we only have access to the model’s prediction function, \(L_{\text {pred}}\) can become a computational bottleneck: even if the underlying model is a neural network, we can no longer take advantage of automatic differentiation and need to evaluate the gradients numerically. Let us express the gradient of \(L_{\text {pred}}\) with respect to the input features x as follows:

$$\begin{aligned} \frac{\partial L_{\text {pred}}}{\partial x} = \frac{\partial f_{\kappa }(x)}{\partial x} = \frac{\partial f_{\kappa }(x)}{\partial f_{\text {pred}}} \frac{\partial f_{\text {pred}}}{\partial x}, \end{aligned}$$
(11)

where \(f_{\text {pred}}\) represents the model’s prediction function. The numerical gradient approximation for \(f_{\text {pred}}\) with respect to input feature k can be written as:

$$\begin{aligned} \frac{\partial f_{\text {pred}}}{\partial x_{k}} \approx \frac{f_{\text {pred}}(x + \epsilon _{k}) - f_{\text {pred}}(x - \epsilon _{k})}{2 \epsilon }, \end{aligned}$$
(12)

where \(\epsilon _{k}\) is a perturbation with the same dimension as x and taking value \(\epsilon \) for feature k and 0 otherwise. As a result, the prediction function needs to be evaluated twice for each feature per gradient step just to compute \(\tfrac{\partial f_{\text {pred}}}{\partial x_{k}}\). For a \(28\times 28\) MNIST image, this translates into a batch of \(28 \cdot 28 \cdot 2=1568\) prediction function calls. Eliminating \(L_{\text {pred}}\) would therefore speed up the counterfactual search process significantly. By using the prototypes to guide the counterfactuals, we can remove \(L_{\text {pred}}\) and only call the prediction function once per gradient update on the perturbed instance to check whether the predicted class i of \(x_{0} + \delta \) is different from \(t_{0}\). This eliminates the computational bottleneck while ensuring that the perturbed instance moves towards an interpretable counterfactual \(x_{\text {cf}}\) of class \(i \ne t_{0}\).
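A minimal sketch of the central-difference approximation in Eq. (12) for a black box `predict_fn` returning class probabilities; it makes the 2D prediction calls per gradient step explicit.

```python
import numpy as np

def numerical_gradient(predict_fn, x, eps=1e-2):
    """Central-difference approximation of d f_pred / d x (Eq. (12)).
    Every feature costs two prediction calls, i.e. 2 * D calls per gradient step."""
    n_classes = predict_fn(x).shape[0]
    grad = np.zeros((n_classes, x.shape[0]))    # one row per output class
    for k in range(x.shape[0]):
        e = np.zeros_like(x)
        e[k] = eps                              # perturb only feature k
        grad[:, k] = (predict_fn(x + e) - predict_fn(x - e)) / (2 * eps)
    return grad
```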

3.6 FISTA Optimization

Like [7], we optimize our objective function by applying a fast iterative shrinkage-thresholding algorithm (FISTA) [2] where the solution space for the output \(x_{\text {cf}} = x_{0} + \delta \) is restricted to \(\mathcal {X}\). The optimization algorithm iteratively updates \(\delta \) with momentum for N optimization steps. It also strips the \(\beta \cdot L_{1}\) regularization term out of the objective function and instead shrinks perturbations \(| \delta _{k} | < \beta \) for feature k to 0. The optimal counterfactual is defined as \(x_{\text {cf}} = x_{0} + \delta ^{n^{*}}\) where \(n^{*} = \mathop {\text {arg min}}\limits _{n \in \{1, \dots , N\}} \beta \cdot \Vert \delta ^{n} \Vert _{1} + \Vert \delta ^{n} \Vert _{2}^{2}\) and the predicted class on \(x_{\text {cf}}\) is \(i \ne t_{0}\).
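A hedged sketch of the FISTA-style update, assuming `grad_fn` returns the gradient of the smooth loss terms (everything except the \(\beta \cdot L_{1}\) term) with respect to the perturbation; the soft-thresholding step plays the role of the shrinkage described above, and keeping track of the best iterate by elastic net loss is omitted for brevity.

```python
import numpy as np

def fista_search(x0, grad_fn, beta=0.1, lr=1e-2, n_steps=1000, x_min=0.0, x_max=1.0):
    """FISTA-style counterfactual search: gradient step on the smooth terms,
    soft-thresholding for the beta * L1 term, Nesterov momentum, and projection
    of x0 + delta back onto the feasible range [x_min, x_max]."""
    delta = np.zeros_like(x0)       # current perturbation
    y = delta.copy()                # momentum ("look-ahead") point
    t = 1.0
    for _ in range(n_steps):
        z = y - lr * grad_fn(y)                                          # gradient step
        delta_new = np.sign(z) * np.maximum(np.abs(z) - beta * lr, 0.0)  # shrink small perturbations to 0
        delta_new = np.clip(x0 + delta_new, x_min, x_max) - x0           # keep x_cf inside X
        t_new = (1.0 + np.sqrt(1.0 + 4.0 * t ** 2)) / 2.0
        y = delta_new + ((t - 1.0) / t_new) * (delta_new - delta)        # momentum update
        delta, t = delta_new, t_new
    return x0 + delta
```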

4 Experiments

The experiments are conducted on an image and tabular dataset. The first experiment on the MNIST handwritten digit dataset [17] makes use of an autoencoder to define and construct prototypes. The second experiment uses the Breast Cancer Wisconsin (Diagnostic) dataset [10]. The latter dataset has lower dimensionality so we find the prototypes using k-d trees. Finally, we illustrate our approach for handling categorical data on the Adult (Census) dataset [10].

4.1 Evaluation

The counterfactuals are evaluated on their interpretability, sparsity and speed of the search process. The sparsity is evaluated using the elastic net loss term \(\text {EN}(\delta ) = \beta \cdot \Vert \delta \Vert _{1} + \Vert \delta \Vert _{2}^2\) while the speed is measured by the time and the number of gradient updates required until a satisfactory counterfactual \(x_{\text {cf}}\) is found. We define a satisfactory counterfactual as the optimal counterfactual found using FISTA for a fixed value of c for which counterfactual instances exist.

In order to evaluate interpretability, we introduce two interpretability metrics IM1 and IM2. Let \(\text {AE}_{i}\) and \(\text {AE}_{t_{0}}\) be autoencoders trained specifically on instances of classes i and \(t_{0}\), respectively. Then IM1 measures the ratio between the reconstruction errors of \(x_{\text {cf}}\) using \(\text {AE}_{i}\) and \(\text {AE}_{t_{0}}\):

$$\begin{aligned} \text {IM1}(\text {AE}_{i}, \text {AE}_{t_{0}}, x_{\text {cf}}) := \frac{\Vert x_{\text {cf}} - \text {AE}_{i}(x_{\text {cf}}) \Vert _{2}^{2}}{\Vert x_{\text {cf}} - \text {AE}_{t_{0}}(x_{\text {cf}}) \Vert _{2}^{2} + \epsilon }. \end{aligned}$$
(13)

A lower value for IM1 means that \(x_{\text {cf}}\) can be better reconstructed by the autoencoder which has only seen instances of the counterfactual class i than by the autoencoder trained on the original class \(t_{0}\). This implies that \(x_{\text {cf}}\) lies closer to the data manifold of counterfactual class i compared to \(t_{0}\), which is considered to be more interpretable.

The second metric IM2 compares how similar the reconstructed counterfactual instances are when using \(\text {AE}_{i}\) and an autoencoder trained on all classes, \(\text {AE}\). We scale IM2 by the \(L_{1}\) norm of \(x_{\text {cf}}\) to make the metric comparable across classes:

$$\begin{aligned} \text {IM2}(\text {AE}_{i}, \text {AE}, x_{\text {cf}}) := \frac{\Vert \text {AE}_{i}(x_{\text {cf}}) - \text {AE}(x_{\text {cf}}) \Vert _{2}^{2}}{\Vert x_{\text {cf}} \Vert _{1} + \epsilon }. \end{aligned}$$
(14)

A low value of IM2 means that the reconstructed instances of \(x_{\text {cf}}\) are very similar when using either \(\text {AE}_{i}\) or \(\text {AE}\). As a result, the data distribution of the counterfactual class i describes \(x_{\text {cf}}\) as well as the distribution over all classes. This implies that the counterfactual is interpretable. Figure 2 illustrates the intuition behind IM1 and IM2.

The uninterpretable counterfactual 3 (\(x_{\text {cf,1}}\)) in the first row of Fig. 2(b) has an IM1 value of 1.81 compared to 1.04 for \(x_{\text {cf,2}}\) in the second row because the reconstruction of \(x_{\text {cf,1}}\) by \(\text {AE}_{5}\) in Fig. 2(d) is better than by \(\text {AE}_{3}\) in Fig. 2(c). The IM2 value of \(x_{\text {cf,1}}\) is higher as well—0.15 compared to 0.12 for \(x_{\text {cf,2}}\)—since the reconstruction by \(\text {AE}\) in Fig. 2(e) yields a clear instance of the original class 5.
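The two metrics are straightforward to compute once the class-specific and full autoencoders are trained; the sketch below assumes `ae_i`, `ae_t0` and `ae_full` are reconstruction callables and adds a small `eps` guard against division by zero.

```python
import numpy as np

def im1(x_cf, ae_i, ae_t0, eps=1e-8):
    """IM1 (Eq. (13)): ratio of reconstruction errors under the counterfactual-class
    autoencoder AE_i and the original-class autoencoder AE_t0 (lower is better)."""
    return np.sum((x_cf - ae_i(x_cf)) ** 2) / (np.sum((x_cf - ae_t0(x_cf)) ** 2) + eps)

def im2(x_cf, ae_i, ae_full, eps=1e-8):
    """IM2 (Eq. (14)): distance between the reconstructions under AE_i and the
    autoencoder AE trained on all classes, scaled by the L1 norm of x_cf (lower is better)."""
    return np.sum((ae_i(x_cf) - ae_full(x_cf)) ** 2) / (np.sum(np.abs(x_cf)) + eps)
```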

Finally, for MNIST we apply a multiple model comparison test based on the maximum mean discrepancy [18] to evaluate the relative interpretability of counterfactuals generated by each method.

4.2 Handwritten Digits

The first experiment is conducted on the MNIST dataset. The experiment analyzes the impact of \(L_{\text {proto}}\) on the counterfactual search process with an encoder defining the prototypes for K equal to 5. We further investigate the importance of the \(L_{\text {AE}}\) and \(L_{\text {pred}}\) loss terms in the presence of \(L_{\text {proto}}\). We evaluate and compare counterfactuals obtained by using the following loss functions:

$$\begin{aligned} \begin{aligned} A&= c \cdot L_{\text {pred}} + \beta \cdot L_{1} + L_{2}\\ B&= c \cdot L_{\text {pred}} + \beta \cdot L_{1} + L_{2} + L_{\text {AE}}\\ C&= c \cdot L_{\text {pred}} + \beta \cdot L_{1} + L_{2} + L_{\text {proto}}\\ D&= c \cdot L_{\text {pred}} + \beta \cdot L_{1} + L_{2} + L_{\text {AE}} + L_{\text {proto}}\\ E&= \beta \cdot L_{1} + L_{2} + L_{\text {proto}}\\ F&= \beta \cdot L_{1} + L_{2} + L_{\text {AE}} + L_{\text {proto}} \end{aligned} \end{aligned}$$
(15)

For each of the ten classes, we randomly sample 50 numbers from the test set and find counterfactual instances for 3 different random seeds per sample. This brings the total number of counterfactuals to 1,500 per loss function.

The model used to classify the digits is a convolutional neural network with 2 convolution layers, each followed by a max-pooling layer. The output of the second pooling layer is flattened and fed into a fully connected layer followed by a softmax output layer over the 10 possible classes. For objective functions B to F, the experiment also uses a trained autoencoder for the \(L_{\text {AE}}\) and \(L_{\text {proto}}\) loss terms. The autoencoder has 3 convolution layers in the encoder and 3 deconvolution layers in the decoder. Full details of the classifier and autoencoder, as well as the hyperparameter values used can be found in the supplementary material.

Results. Table 1 summarizes the findings for the speed and interpretability measures.

Fig. 3.

(a) Mean time in seconds and number of gradient updates needed to find a satisfactory counterfactual for objective functions A to F across all MNIST classes. The error bars represent the standard deviation to illustrate variability between approaches. (b) Mean IM1 and IM2 for objective functions A to F across all MNIST classes (lower is better). The error bars represent the 95% confidence bounds. (c) Sparsity measure \(\text {EN}(\delta )\) for loss functions A to F. The error bars represent the 95% confidence bounds.

Speed. Figure 3(a) shows the mean time and number of gradient steps required to find a satisfactory counterfactual for each objective function. We also show the standard deviations to illustrate the variability between the different loss functions. For loss function A, the majority of the time is spent finding a good range for c to find a balance between steering the perturbed instance away from the original class \(t_{0}\) and the elastic net regularization. If c is too small, the \(L_{1}\) regularization term cancels out the perturbations, but if c is too large, \(x_{\text {cf}}\) is not sparse anymore.

The aim of \(L_{\text {AE}}\) in loss function B is not to speed up convergence towards a counterfactual instance, but to have \(x_{\text {cf}}\) respect the training data distribution. This is backed up by the experiments. The average speed improvement and reduction in the number of gradient updates compared to A of respectively 36% and 54% is significant but very inconsistent given the high standard deviation. The addition of \(L_{\text {proto}}\) in C however drastically reduces the time and iterations needed by respectively 77% and 84% compared to A. The combination of \(L_{\text {AE}}\) and \(L_{\text {proto}}\) in D improves the time to find a counterfactual instance further: \(x_{\text {cf}}\) is found 82% faster compared to A, with the number of iterations down by 90%.

Table 1. Summary statistics with 95% confidence bounds for each loss function for the MNIST experiment.

So far we have assumed access to the model architecture to take advantage of automatic differentiation during the counterfactual search process. \(L_{\text {pred}}\) can however form a computational bottleneck for black box models because numerical gradient calculation results in a number of prediction function calls proportional to the dimensionality of the input features. Consider \(A'\) the equivalent of loss function A where we can only query the model’s prediction function. E and F remove \(L_{\text {pred}}\), which results in approximately a 100x speed-up of the counterfactual search process compared to \(A'\). The results can be found in the supplementary material.

Quantitative Interpretability. IM1 peaks for loss function A and improves by respectively 13% and 26% as \(L_{\text {AE}}\) and \(L_{\text {proto}}\) are added (Fig. 3(b)). This implies that including \(L_{\text {proto}}\) leads to more interpretable counterfactual instances than \(L_{\text {AE}}\) which explicitly minimizes the reconstruction error using \(\text {AE}\). Removing \(L_{\text {pred}}\) in E yields an improvement over A of 29%. While \(L_{\text {pred}}\) encourages the perturbed instance to predict a different class than \(t_{0}\), it does not impose any restrictions on the data distribution of \(x_{\text {cf}}\). \(L_{\text {proto}}\) on the other hand implicitly encourages the perturbed instance to predict \(i \ne t_{0}\) while minimizing the distance in latent space to a representative distribution of class i.

Fig. 4.

(a) shows the original instance; (b) to (g) on the first row illustrate counterfactuals generated using loss functions A to F; (b) to (g) on the second row show the reconstructed counterfactuals using AE.

The picture for IM2 is similar. Adding in \(L_{\text {proto}}\) brings IM2 down by 34% while the combination of \(L_{\text {AE}}\) and \(L_{\text {proto}}\) only reduces the metric by 24%. For large values of K the prototypes are further from \(\text {ENC}(x_{0})\) resulting in larger initial perturbations towards the counterfactual class. In this case, \(L_{\text {AE}}\) ensures the overall distribution is respected which makes the reconstructed images of \(\text {AE}_{i}\) and \(\text {AE}\) more similar and improves IM2. The impact of K on IM1 and IM2 is illustrated in the supplementary material. The removal of \(L_{\text {pred}}\) in E and F has little impact on IM2. This emphasizes that \(L_{\text {proto}}\)—optionally in combination with \(L_{\text {AE}}\)—is the dominant term with regards to interpretability.

Finally, performing kernel multiple model comparison tests [18] indicates that counterfactuals generated by methods not including the prototype term (A and B) result in high rejection rates for faithfully modelling the predicted class distribution (see supplementary material).

Visual Interpretability. Figure 4 shows counterfactual examples on the first row and their reconstructions using \(\text {AE}\) on the second row for different loss functions. The counterfactuals generated with A or B are sparse but uninterpretable and are still close to the manifold of a 2. Including \(L_{\text {proto}}\) in Fig. 4(d) to (g) leads to a clear, interpretable 0 which is supported by the reconstructed counterfactuals on the second row. More examples can be found in the supplementary material.

Sparsity. The elastic net evaluation metric \(\text {EN}(\delta )\) is also the only loss term present in A besides \(L_{\text {pred}}\). It is therefore not surprising that A results in the most sparse counterfactuals (Fig. 3(c)). The relative importance of sparsity in the objective function goes down as \(L_{\text {AE}}\) and \(L_{\text {proto}}\) are added. \(L_{\text {proto}}\) leads to more sparse counterfactuals than \(L_{\text {AE}}\) (C and E), but this effect diminishes for large K.

4.3 Breast Cancer Wisconsin (Diagnostic) Dataset

The second experiment uses the Breast Cancer Wisconsin (Diagnostic) dataset which describes characteristics of cell nuclei in an image and labels them as malignant or benign. The real-valued features for the nuclei in the image are the mean, error and worst values for characteristics like the radius, texture or area of the nuclei. The dataset contains 569 instances with 30 features each. The first 550 instances are used for training, the last 19 to generate the counterfactuals. For each instance in the test set we generate 5 counterfactuals with different random seeds. Instead of an encoder we use k-d trees to find the prototypes. We evaluate and compare counterfactuals obtained by using the following loss functions:

$$\begin{aligned} \begin{aligned} A&= c \cdot L_{\text {pred}} + \beta \cdot L_{1} + L_{2}\\ B&= c \cdot L_{\text {pred}} + \beta \cdot L_{1} + L_{2} + L_{\text {proto}}\\ C&= \beta \cdot L_{1} + L_{2} + L_{\text {proto}} \end{aligned} \end{aligned}$$
(16)

The model used to classify the instances is a two-layer feedforward neural network with 40 neurons in each layer. More details can be found in the supplementary material.

Results. Table 2 summarizes the findings for the speed and interpretability measures.

Speed. \(L_{\text {proto}}\) drastically reduces the time and iterations needed to find a satisfactory counterfactual. Loss function B finds \(x_{\text {cf}}\) in 13% of the time needed compared to A while bringing the number of gradient updates down by 91%. Removing \(L_{\text {pred}}\) and solely relying on the prototype to guide \(x_{\text {cf}}\) reduces the search time by 92% and the number of iterations by 93%.

Quantitative Interpretability. Including \(L_{\text {proto}}\) in the loss function reduces IM1 and IM2 by respectively 55% and 81%. Removing \(L_{\text {pred}}\) in C results in similar improvements over A.

Sparsity. Loss function A yields the most sparse counterfactuals. Sparsity and interpretability should however not be considered in isolation. The dataset has 10 attributes (e.g. radius or texture) with 3 values per attribute (mean, error and worst). B and C which include \(L_{\text {proto}}\) perturb relatively more values of the same attribute than A which makes intuitive sense. If for instance the worst radius increases, the mean should typically follow as well. The supplementary material supports this statement.

Table 2. Summary statistics with 95% confidence bounds for each loss function for the Breast Cancer Wisconsin (Diagnostic) experiment.
Fig. 5.

Left: embedding of the categorical variable “Education” in numerical space using the association-based distance metric (ABDM). Right: frequency-based embedding.

4.4 Adult (Census) Dataset

The Adult (Census) dataset consists of individuals described by a mixture of numerical and categorical features. The predictive task is to determine whether a person earns more than $50k/year. As the dataset contains categorical features, it is important to use a principled approach to define perturbations over these features. Figure 5 illustrates our approach using the association-based distance metric (ABDM) [16] to embed the feature “Education” into one-dimensional numerical space, over which perturbations can be defined. The resulting embedding defines a natural ordering of categories in agreement with common sense for this interpretable variable. By contrast, the frequency embedding method proposed by [9] does not capture the underlying relation between categorical values.

Since ABDM infers distances between categories from the other variables by computing dissimilarities based on the Kullback-Leibler divergence, it can break down if the categorical variable is independent of the other variables. In such cases one can use MVDM [6], which uses the difference between the conditional model prediction probabilities of each category. A counterfactual example changing categorical features is shown in Fig. 1.

5 Discussion

In this paper we introduce a model-agnostic counterfactual search process guided by class prototypes. We show that including a prototype loss term in the objective results in more interpretable counterfactual instances, as measured by two novel interpretability metrics. We demonstrate that prototypes speed up the search process and remove the numerical gradient evaluation bottleneck for black box models, thus making our method more appealing for practical applications. By fixing selected features to their original values during the search process we can also obtain actionable counterfactuals which describe concrete steps to take to change a model’s prediction. To facilitate the practical use of counterfactual explanations we provide an open source library with our implementation of the method.