
1 Introduction

In the last decade, deep learning has taken over artificial intelligence (AI) research, driving breakthroughs in applications such as face recognition [2, 23], self-driving cars [3, 21], language modeling [4], game playing [24], and algorithmic stock trading [28]. Nonetheless, the predictive power of deep learning often comes at the cost of limited interpretability. The black-box nature of these methods makes it hard to diagnose whether model predictions are, e.g., driven by confounders or align well with domain expert knowledge.

To overcome these shortcomings of deep models, there is a growing interest in developing methods to explain neural network predictions, leading to the renaissance of the eXplainable AI (XAI) [7] field. Although XAI methods come in many flavors, a significant portion seeks to identify which input elements (e.g., features) are most relevant/salient to the corresponding predictions.

Understanding model predictions is especially important for high-stakes applications such as drug design and personalized medicine, in which inputs are structured and usually given as graphs. In particular, graph neural networks (GNNs) [22] are the gold standard for supervised learning over graph-structured data, with notable applications in molecular property prediction [8], fake-news detection [18, 19], and content recommendation [5, 30]. Not surprisingly, a flurry of recent works has focused on explaining GNN predictions [16, 26, 31].

Notably, several recent works [10, 17, 25, 27] show that GNNs are often more complex than they need to be—especially for node-level prediction tasks—and contain expendable model components. That is, it is possible to simplify GNNs without loss of predictive performance, while still benefiting from the lower computational cost and greater interpretability of the simplified models. These findings naturally lead to the following question:

— Can we build simpler explanatory models for graph neural networks?

In this work, we propose ConveX, a simple explanation strategy for GNNs. More specifically, we first distill the (possibly complex) GNN we want to explain into a simpler one. Then, we generate explanations for our simplified model by solving a convex optimization problem that identifies which nodes/edges of the input graph are most relevant for a given prediction.

Experiments with synthetic and real-world datasets show that, despite its simplicity, ConveX achieves competitive results when compared to state-of-the-art models, such as GNNExplainer [31] and PGM-Explainer [26].

The remainder of this paper is organized as follows. Section 2 provides a brief background on GNNs. Section 3 discusses related work on explanations for node classification. Section 4 proposes ConveX, a novel explanation strategy that leverages knowledge distillation. Section 5 compares ConveX against the state of the art (SOTA) on a variety of explainability benchmarks. Finally, Sect. 6 draws conclusions and discusses future work.

2 Background

Notation. We define a graph \(G = (\mathcal {V}, \mathcal {E})\), with a set of nodes \(\mathcal {V} = \{1, \ldots , n\}\) and a set of edges \(\mathcal {E} \subseteq \mathcal {V}\times \mathcal {V}\). We denote the adjacency matrix of G by \(\boldsymbol{A} \in \mathbb {R}^{n\times n}\), i.e., \(A_{ij}\) is one if \((i,j) \in \mathcal {E}\) and zero otherwise. Let \(\boldsymbol{D}\) be the diagonal degree matrix of G, i.e., \(D_{i i} {:}{=}\sum _{j} A_{i j}\). We also define the normalized adjacency matrix as \(\widetilde{\boldsymbol{A}} = \boldsymbol{D}^{-1/2} (\boldsymbol{A} + \boldsymbol{I}) \boldsymbol{D}^{-1/2}\), where \(\boldsymbol{I}\) is the n-dimensional identity matrix. Furthermore, let \(\boldsymbol{X} \in \mathbb {R}^{n \times d}\) be a matrix of d-dimensional node features. Throughout this work, we also often denote a graph G using the pair \((\boldsymbol{A}, \boldsymbol{X})\).
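
For concreteness, the snippet below is a minimal PyTorch sketch of the normalization above for a dense adjacency matrix; the clamp guarding against isolated nodes is our addition and not part of the definition.

```python
import torch

def normalized_adjacency(A: torch.Tensor) -> torch.Tensor:
    """Compute D^{-1/2} (A + I) D^{-1/2} with D_ii = sum_j A_ij, as defined above."""
    n = A.size(0)
    deg = A.sum(dim=1)                          # node degrees D_ii
    d_inv_sqrt = deg.clamp(min=1.0).pow(-0.5)   # guard against isolated nodes (our addition)
    D_inv_sqrt = torch.diag(d_inv_sqrt)
    return D_inv_sqrt @ (A + torch.eye(n)) @ D_inv_sqrt
```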

2.1 Graph Neural Networks

Graph neural networks (GNNs) have gained traction over the last few years due to their potential to extract meaningful graph representations while still preserving invariances (e.g., to node permutations). In general, modern GNNs apply a series of convolutions over the node states (initially equal to node features), after which we have refined representations for each node. Then, we can use these representations for arbitrary downstream tasks.

The output of a (fully-convolutional) GNN after \(\ell \) layers can be written as a function \(\boldsymbol{H}^{(\ell )} = f(\boldsymbol{H}^{(\ell -1)},\boldsymbol{A})\), where \(\boldsymbol{H}^{(\ell )}\) is a matrix with n rows, one for each node in \(\mathcal {V}\). Specific GNNs differ essentially in how they define \(f(\cdot ,\cdot )\). After a suitable number of layers, say L, we can apply a multi-layer perceptron (MLP) on top of the node embeddings to get predictions. For node classification, the logits for all nodes are given by \(\boldsymbol{Y} = \text {MLP}(\boldsymbol{H}^{(L)})\), where \( \boldsymbol{Y} \in \mathbb {R}^{n \times C} \) and C is the number of classes in our task.

The remainder of this subsection covers two GNN models: graph convolutional networks (GCNs) [12] and simplified graph convolutions (SGCs) [27]. The former is arguably the most popular GNN in the literature and is used extensively throughout our experiments. The latter is a linear graph model, which will be an asset for our explanation method.

Graph Convolutional Networks (GCNs) [12] are multi-layer architectures in which nodes repeatedly gather their neighbors’ states, subsequently combining them using a symmetric weighting scheme. Then, each node updates its state in a recurrent fashion using the result of the neighborhood aggregation. Similarly to conventional feedforward networks, GCNs can be seen as a sequence of linear transformations followed by non-linear activation functions (e.g., ReLU). Denoting the weights of the \(\ell \)-th GCN layer by \(\boldsymbol{\varTheta }^{(\ell )}\) and the activation function by \(\sigma \), we can compactly write the output of that layer as:

$$\begin{aligned} \boldsymbol{H}^{(\ell )} = \sigma \left( \widetilde{\boldsymbol{A}} \boldsymbol{H}^{(\ell -1)} {\boldsymbol{\varTheta }^{(\ell )}}\right) , \end{aligned}$$
(1)

where \(\boldsymbol{H}^{(\ell -1)}\) is the output of the previous layer and \(\boldsymbol{H}^{(0)}\) equals the original features, i.e., \(\boldsymbol{H}^{(0)} =\boldsymbol{X}\).
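
As an illustration, here is a dense single-layer sketch of Eq. (1); production implementations (e.g., GCNConv in PyTorch Geometric) use sparse message passing instead, and the class and argument names below are ours.

```python
import torch
import torch.nn as nn

class GCNLayer(nn.Module):
    """One GCN layer, Eq. (1): H^(l) = sigma(Ã H^(l-1) Θ^(l))."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.theta = nn.Linear(in_dim, out_dim, bias=False)  # weight matrix Θ^(l)
        self.act = nn.ReLU()

    def forward(self, A_norm: torch.Tensor, H: torch.Tensor) -> torch.Tensor:
        # aggregate neighbor states with Ã, then apply the linear map and non-linearity
        return self.act(A_norm @ self.theta(H))
```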

Simple Graph Convolution (SGC) [27] is a simplification of GCN. We can derive SGC by removing the activation functions from intermediate layers and subsequently collapsing all weight matrices into one. To make this concrete, recall that we can use the recursion in Eq. (1) to write the output of an \(\ell \)-layer GCN as:

$$\begin{aligned} \boldsymbol{H}^{(\ell )}&= \sigma \left( \widetilde{\boldsymbol{A}} \boldsymbol{H}^{(\ell -1)} \boldsymbol{\varTheta }^{(\ell )}\right) , \end{aligned}$$
(2)
$$\begin{aligned}&= \sigma \left( \widetilde{\boldsymbol{A}} \,\sigma \left( \cdots \,\sigma \left( \widetilde{\boldsymbol{A}}\,\sigma \left( \widetilde{\boldsymbol{A}} \boldsymbol{X} \boldsymbol{\varTheta }^{(1)}\right) \boldsymbol{\varTheta }^{(2)}\right) \dots \right) \boldsymbol{\varTheta }^{(\ell )}\right) . \end{aligned}$$
(3)

Removing the intermediate non-linear activations from the equation above leaves us with:

$$\begin{aligned} \boldsymbol{H}^{(\ell )}&= \sigma \left( \widetilde{\boldsymbol{A}} \cdots \widetilde{\boldsymbol{A}}\widetilde{\boldsymbol{A}} \boldsymbol{X} \boldsymbol{\varTheta }^{(1)} \boldsymbol{\varTheta }^{(2)}\cdots \boldsymbol{\varTheta }^{(\ell )}\right) \end{aligned}$$
(4)
$$\begin{aligned}&= \sigma \left( \widetilde{\boldsymbol{A}}^{\ell } \boldsymbol{X} \boldsymbol{\varTheta }^{(1)} \boldsymbol{\varTheta }^{(2)}\cdots \boldsymbol{\varTheta }^{(\ell )}\right) . \end{aligned}$$
(5)

To finish the derivation, we replace the product \(\boldsymbol{\varTheta }^{(1)}\boldsymbol{\varTheta }^{(2)} \cdots \boldsymbol{\varTheta }^{(\ell )}\) with a single weight matrix \(\boldsymbol{\varTheta }\). Setting \(\ell = L\), we wind up with the node embeddings of an L-layer SGC:

$$\begin{aligned} \boldsymbol{H} = \sigma \left( \widetilde{\boldsymbol{A}}^{L} \boldsymbol{X} \boldsymbol{\varTheta }\right) . \end{aligned}$$
(6)

Notably, Wu et al. [27] showed that SGC often performs similarly to or better than GCN in a variety of node classification tasks. On top of that, training SGCs is computationally cheaper than training GCNs, and SGC models comprise significantly fewer parameters.
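
Below is a minimal sketch of Eq. (6), with \(\sigma \) taken to be the row-wise softmax used for classification later in the paper; the class and argument names are ours. Since \(\widetilde{\boldsymbol{A}}^{L}\boldsymbol{X}\) does not depend on \(\boldsymbol{\varTheta }\), it can be precomputed once, which is precisely what makes SGC cheap to train.

```python
import torch
import torch.nn as nn

class SGC(nn.Module):
    """SGC, Eq. (6): softmax(Ã^L X Θ)."""

    def __init__(self, in_dim: int, num_classes: int, L: int):
        super().__init__()
        self.theta = nn.Linear(in_dim, num_classes, bias=False)  # single weight matrix Θ
        self.L = L

    def forward(self, A_norm: torch.Tensor, X: torch.Tensor) -> torch.Tensor:
        S = X
        for _ in range(self.L):   # L propagation steps: S = Ã^L X (can be precomputed once)
            S = A_norm @ S
        return torch.softmax(self.theta(S), dim=1)
```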

3 Related Work: Explaining Node Predictions

Despite the intuitive nature of graph convolutions, stacking GNN layers—in hope of achieving superior performance—leads to an undesirable side effect: lack of interpretability. Explanations can be crucial to, e.g., support decisions in critical applications or help domain experts validate the model. To mitigate this issue, there is a growing literature proposing instance-level explanations for node predictions. Notably, the predominant approach behind SOTA methods is to create auxiliary models (explainers) capable of identifying which input elements (e.g., nodes or edges) are most relevant to the prediction of the GNN we want to explain. The rationale behind this choice is that members of a class usually share common structures; e.g., mutagenic compounds often comprise telltale chemical groups such as \(NO_2\) (a pair of oxygen atoms together with a nitrogen atom). Consequently, checking which patterns are most relevant to a GNN prediction can be instrumental for domain experts and may help us debug spurious correlations.

GNNExplainer [31] was one of the first methods to explain GNN predictions. The method is perturbation-based and generates explanations by learning masks for edges and node attributes, which allow it to remove information that does not impact the GNN’s predictions. The masks are randomly initialized and treated as trainable variables. Then, GNNExplainer combines the masks with the original graph G, thus generating a subgraph \(G_s\) and a subset of node attributes \(\boldsymbol{X}_s\) that contain the information important for the prediction. The idea is to learn these masks by maximizing the mutual information between a candidate explanatory subgraph \((G_s, \boldsymbol{X}_s)\) and the GNN prediction \(\boldsymbol{y}\), that is,

$$\begin{aligned} \max _{G_{s}} \textrm{MI}(\boldsymbol{y} , (G_s, \boldsymbol{X}_s)) = H(\boldsymbol{y}) - H(\boldsymbol{y}|G = G_{s}, \boldsymbol{X}=\boldsymbol{X}_s). \end{aligned}$$
(7)

The entropy term \(H(\boldsymbol{y})\) is constant as the GNN to be explained \(\varPhi \) is fixed after training. So, maximizing the mutual information between the predicted label distribution \(\boldsymbol{y}\) and the explanation \((G_s, \boldsymbol{X}_s)\) is equivalent to minimizing the conditional entropy \(H(\boldsymbol{y}|G = G_{s}, \boldsymbol{X}=\boldsymbol{X}_s)\), which can be expressed as follows:

$$\begin{aligned} H(\boldsymbol{y}|G = G_{s}, \boldsymbol{X}=\boldsymbol{X}_s) = - \mathbb {E}_{\boldsymbol{y}|G_s, \boldsymbol{X}_s}[\log P_{\varPhi }(\boldsymbol{y}|G=G_{s},\boldsymbol{X}=\boldsymbol{X}_s)]. \end{aligned}$$
(8)

In practice, the program in Eq. (7) is an intractable combinatorial optimization problem. To circumvent this limitation, Ying et al. [31] relax the problem by learning a soft mask \(M \in [0,1]^{n\times n}\) that weighs the edges of the input graph G.
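
A schematic of this relaxation is sketched below, assuming a hypothetical `model(A, X)` that accepts a dense weighted adjacency and returns per-node class probabilities; the actual GNNExplainer also masks node features and adds sparsity and entropy regularizers, which we omit here.

```python
import torch

def soft_mask_explanation(model, A, X, node_idx, steps=100, lr=0.01):
    """Schematic soft-mask relaxation of Eqs. (7)-(8); `model` is a hypothetical
    GNN taking a dense weighted adjacency and returning class probabilities."""
    target = model(A, X)[node_idx].argmax()                  # prediction to be explained
    mask_logits = torch.randn_like(A, requires_grad=True)    # randomly initialized mask
    optimizer = torch.optim.Adam([mask_logits], lr=lr)
    for _ in range(steps):
        M = torch.sigmoid(mask_logits)                       # soft mask in [0, 1]^{n x n}
        prob = model(A * M, X)[node_idx, target]             # prediction on the masked graph
        loss = -torch.log(prob + 1e-12)                      # surrogate for the conditional entropy in Eq. (8)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return torch.sigmoid(mask_logits).detach()               # edge importance weights
```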

PGM-Explainer [26] builds on probabilistic graphical models to obtain instance-level explanations for GNNs. Its explanation is an interpretable Bayesian network that approximates the GNN’s prediction. The process of generating explanations has three main steps: (i) data generation, which generates, pre-processes, and records a set of input-output pairs (the sampled data) for the prediction to be explained; (ii) variable selection, which eliminates unimportant variables from the sampled data to improve the runtime and encourage compact explanations; to avoid discarding important variables, PGM-Explainer exploits the observation that the variables relevant to the GNN’s prediction lie in the Markov blanket of that prediction; and (iii) structure learning, which takes the filtered data from the previous step and generates the explanation.

While outside the scope of our work, it is worth mentioning that, besides instance-level explanations, other works focus on model-level explanations. These works aim to identify which graph patterns most impact the overall behavior of a GNN [32].

4 ConveX

We now introduce ConveX—a new method for node-level explanation of graph neural networks. ConveX follows a two-step procedure: i) we first fit a simple GNN to approximate the predictions of the model to be explained (distillation phase); ii) then, we explain the simplified GNN by solving an optimization problem that has a simple loss landscape (i.e., it is convex if the GNN is linear).

Hereafter, we denote the GNN to be explained by \(\varPhi \). The matrix \(\boldsymbol{\hat{Y}} = \varPhi (G)\) comprises the predictions for all nodes in G—the i-th row of \(\boldsymbol{\hat{Y}}\) contains the class predictions for node i. Given the GNN \(\varPhi \) and its predictions \(\boldsymbol{\hat{Y}}\), our goal is to find an explainer model capable of identifying the information most influential to \(\varPhi \)’s prediction \(\hat{\boldsymbol{y}}_i\) for node i. In this work, the GNN \(\varPhi \) is treated as a black-box model, that is, the explainer does not have access to any internal information of the model, such as parameters or hidden representations, but only to the input and output of the GNN.

4.1 Knowledge Distillation

Despite the success of deep learning models, most of them have high complexity due to their large number of parameters. Knowledge distillation [9] tackles this issue by transferring the knowledge of a complex network into a simpler and more interpretable model. The use of knowledge distillation for graph neural networks has been addressed in recent works [29, 33], which validate the effectiveness of these strategies.

In this work, we use SGC [27] to approximate the predictions obtained with the GNN \(\varPhi \). Formally, the distilled model (SGC), here denoted by \(\varPsi \), receives the input graph \(G=(\boldsymbol{A}, \boldsymbol{X})\) and provides class predictions \(\hat{\boldsymbol{Y}}^{(\varPsi )} = \varPsi (G) = \textrm{softmax}(\widetilde{\boldsymbol{A}}^L \boldsymbol{X} \boldsymbol{\varTheta })\), where \(\boldsymbol{\varTheta }\) is the matrix of model parameters, L is the number of layers of the GNN \(\varPhi \), and the softmax function acts row-wise.

The distillation process consists of adjusting the parameters \(\boldsymbol{\varTheta }\) of the SGC model \(\varPsi \) so that it approximates the predictions of the network to be explained \(\varPhi \). This can be achieved by minimizing the Kullback-Leibler (KL) divergence between the predictions of \(\varPhi \) and \(\varPsi _{\boldsymbol{\varTheta }}\). Let \(\hat{\boldsymbol{y}}_i^{(\varPsi _{\boldsymbol{\varTheta }})}\) and \(\hat{\boldsymbol{y}}_i^{(\varPhi )}\) denote the class predictions for node i from the \(\varPsi _{\boldsymbol{\varTheta }}\) and \(\varPhi \) models, respectively. More concisely, our distillation process consists in solving:

$$\begin{aligned} \min _{\boldsymbol{\varTheta }}~ \left\{ \sum _i \textrm{KL} \left( \hat{\boldsymbol{y}}_i^{(\varPhi )}, \hat{\boldsymbol{y}}_i^{(\varPsi _\varTheta )} \right) = \sum _{i\in \mathcal {V}} \sum _{c=1}^{C} \hat{y}_{ic}^{(\varPhi )} \log \frac{\hat{y}_{ic}^{(\varPhi )}}{\hat{y}_{ic}^{(\varPsi _\varTheta )} } \right\} . \end{aligned}$$
(9)
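
A minimal sketch of this objective is given below, assuming both models output row-wise class probabilities; the function and argument names are ours.

```python
import torch
import torch.nn.functional as F

def distillation_loss(probs_phi: torch.Tensor, probs_psi: torch.Tensor) -> torch.Tensor:
    """Sum over nodes of KL(y_i^(Phi) || y_i^(Psi_Theta)), as in Eq. (9).

    Both arguments are n x C matrices of class probabilities (rows sum to one).
    """
    # F.kl_div expects log-probabilities of the approximating distribution as its first argument
    return F.kl_div(probs_psi.clamp_min(1e-12).log(), probs_phi, reduction="sum")
```

Minimizing this loss with respect to the SGC parameters \(\boldsymbol{\varTheta }\) (e.g., with Adam) completes the distillation phase.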

4.2 Obtaining Node-Level Explanations

Intuitively, finding a good explanation \(\boldsymbol{\mathcal {E}}\) for a prediction \(\hat{\boldsymbol{y}}_i\)—\(\varPhi \)’s prediction for node \(i \in \mathcal {V}\) given a graph G—can be seen as finding the smallest subgraph \(G_\mathcal {E}\) of G containing the nodes that most influence that prediction. With this in mind, we define \(\boldsymbol{\mathcal {E}}\) as an n-dimensional vector of indicator variables, one for each node in G’s vertex set \(\mathcal {V}\).

Since we focus on explaining \(\varPsi \) as a surrogate of \(\varPhi \), we start off by formulating our problem as finding the explanation \(\boldsymbol{\mathcal {E}}\) that minimizes the discrepancy between the output \(\hat{\boldsymbol{y}}_i^{(\varPsi )}\) of \(\varPsi \) given the original graph and the prediction obtained using only the graph induced by \(\boldsymbol{\mathcal {E}}\):

$$\begin{aligned} \min _{\boldsymbol{\mathcal {E}}\in \{0,1\}^n}\parallel \widetilde{\boldsymbol{A}}^{L}_{i}\textrm{diag}(\boldsymbol{\mathcal {E}})\boldsymbol{X}\boldsymbol{\varTheta } - \widetilde{\boldsymbol{A}}^{L}_{i}\boldsymbol{X}\boldsymbol{\varTheta } \parallel _{2}^{2}, \end{aligned}$$
(10)

where \(\widetilde{\boldsymbol{A}}^{L}_{i}\) denotes the i-th row of the matrix \(\widetilde{\boldsymbol{A}}^{L}\). Nonetheless, the formulation in Eq. (10) has a major issue: it does not impose any budget on the node selection, admitting trivial solutions such as the all-ones vector \(\boldsymbol{\mathcal {E}} = \boldsymbol{1}_n\). To solve this and simultaneously avoid binary optimization, we replace the search space with the simplex \(\varDelta = \{\boldsymbol{r} \in \mathbb {R}^n : \sum _i r_i = 1, \forall _i\, r_i \ge 0 \}\). Implementing this change and re-arranging computations, we wind up with:

$$\begin{aligned} \min _{\boldsymbol{\mathcal {E}} \in \varDelta }\Big \Vert \widetilde{\boldsymbol{A}}^{L}_{i}\left( \textrm{diag}(\boldsymbol{\mathcal {E}}) - \boldsymbol{I}_n\right) \boldsymbol{X}\boldsymbol{\varTheta } \Big \Vert _{2}^{2}, \end{aligned}$$
(11)

where \(\boldsymbol{I}_n\) is the n-dimensional identity matrix. Notably, \(\varDelta \) is a convex set. It is also easy to show that the objective in Eq. (11) is a quadratic form, so the problem is a quadratic program. To this end, let us denote the objective function above by f. Then, it follows that:

$$\begin{aligned} f(\boldsymbol{\mathcal {E}})&= \Big \Vert \widetilde{\boldsymbol{A}}^{L}_{i}\left( \textrm{diag}(\boldsymbol{\mathcal {E}}) - \boldsymbol{I}_n\right) \boldsymbol{X}\boldsymbol{\varTheta } \Big \Vert _{2}^{2}\\&= \Big \Vert \widetilde{\boldsymbol{A}}^{L}_{i}\textrm{diag}(\boldsymbol{\mathcal {E}}) \boldsymbol{X}\boldsymbol{\varTheta } \Big \Vert _{2}^{2} - 2\, \boldsymbol{\mathcal {E}}^\intercal \textrm{diag}\left( \left( \widetilde{\boldsymbol{A}}^{L}_{i}\right) ^\intercal \right) \boldsymbol{X}\boldsymbol{\varTheta }\left( \widetilde{\boldsymbol{A}}^{L}_{i} \boldsymbol{X}\boldsymbol{\varTheta }\right) ^\intercal + \delta \\&= \Big \Vert \widetilde{\boldsymbol{A}}^{L}_{i}\textrm{diag}(\boldsymbol{\mathcal {E}}) \boldsymbol{X}\boldsymbol{\varTheta } \Big \Vert _{2}^{2} + \boldsymbol{\mathcal {E}}^\intercal \boldsymbol{c} + \delta \\&= \boldsymbol{\mathcal {E}}^\intercal \textrm{diag}\left( \left( \widetilde{\boldsymbol{A}}^{L}_{i}\right) ^\intercal \right) \boldsymbol{X} \boldsymbol{\varTheta } \boldsymbol{\varTheta }^\intercal \boldsymbol{X}^\intercal \textrm{diag}\left( \left( \widetilde{\boldsymbol{A}}^{L}_{i}\right) ^\intercal \right) \boldsymbol{\mathcal {E}} + \boldsymbol{\mathcal {E}}^\intercal \boldsymbol{c} + \delta \\&= \boldsymbol{\mathcal {E}}^\intercal \boldsymbol{Q} \boldsymbol{\mathcal {E}} + \boldsymbol{\mathcal {E}}^\intercal \boldsymbol{c} + \delta \end{aligned}$$

which is the canonical quadratic form, with \(\delta \) being a constant and

$$\begin{aligned}&\boldsymbol{c} = - 2\, \textrm{diag}\left( \left( \widetilde{\boldsymbol{A}}^{L}_{i}\right) ^\intercal \right) \boldsymbol{X}\boldsymbol{\varTheta }\left( \widetilde{\boldsymbol{A}}^{L}_{i} \boldsymbol{X}\boldsymbol{\varTheta }\right) ^\intercal ,\\&\boldsymbol{Q} = \textrm{diag}\left( \left( \widetilde{\boldsymbol{A}}^{L}_{i}\right) ^\intercal \right) \boldsymbol{X} \boldsymbol{\varTheta } \boldsymbol{\varTheta }^\intercal \boldsymbol{X}^\intercal \textrm{diag}\left( \left( \widetilde{\boldsymbol{A}}^{L}_{i}\right) ^\intercal \right) , \end{aligned}$$

and since \(\boldsymbol{Q}\) has the form \(\boldsymbol{P}^\intercal \boldsymbol{P}\), with \(\boldsymbol{P} = \boldsymbol{\varTheta }^\intercal \boldsymbol{X}^\intercal \textrm{diag}\left( \left( \widetilde{\boldsymbol{A}}^{L}_{i}\right) ^\intercal \right) \), it is positive semidefinite, and our optimization problem in Eq. (11) is convex.
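
The quadratic-form terms can be assembled directly from the distilled SGC, as in the hedged sketch below (function and variable names are ours, and Theta denotes the d x C weight matrix of the distilled SGC); the eigenvalue check merely confirms numerically that Q is positive semidefinite.

```python
import torch

def quadratic_terms(A_norm, X, Theta, node_idx, L=3):
    """Build Q and c from the derivation above for the prediction of `node_idx`."""
    a_i = torch.linalg.matrix_power(A_norm, L)[node_idx]   # row Ã^L_i
    D = torch.diag(a_i)                                    # diag((Ã^L_i)^T)
    P = Theta.t() @ X.t() @ D                              # so that Q = P^T P
    Q = P.t() @ P
    c = -2.0 * D @ X @ Theta @ (a_i @ X @ Theta)           # linear term
    assert torch.linalg.eigvalsh(Q).min() >= -1e-6         # Q is (numerically) PSD
    return Q, c
```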

Since a conventional GNN layer uses first-order information (i.e., the immediate neighborhood) to update node representations, only nodes within distance L of node i influence its prediction in a GNN of depth L. Therefore, \(G_\mathcal {E}\) must be a subgraph of the graph induced by i and its L-hop neighborhood in G. For this reason, we mask out nodes outside this neighborhood, holding their importance at zero. For ease of implementation, we solve Eq. (11) over the unconstrained space \(\mathbb {R}^n\), mapping vectors from \(\mathbb {R}^n\) to \(\varDelta \) using the softmax function. After optimization, the values in \(\boldsymbol{\mathcal {E}}\) serve as a ranking of the importance of each node.
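
Putting the pieces together, a minimal sketch of the explanation step is given below; the names and default hyper-parameters are ours. The masking with negative infinity keeps nodes outside the L-hop neighborhood at zero importance, and the softmax keeps \(\boldsymbol{\mathcal {E}}\) on the simplex \(\varDelta \).

```python
import torch

def convex_explanation(A_norm, X, Theta, node_idx, L=3, steps=100, lr=0.1):
    """Solve Eq. (11) via gradient-based optimization with a softmax reparameterization.

    A_norm: n x n normalized adjacency, X: n x d node features,
    Theta: d x C weight matrix of the distilled SGC.
    """
    a_i = torch.linalg.matrix_power(A_norm, L)[node_idx]   # row Ã^L_i
    in_reach = a_i > 0                                     # L-hop neighborhood of node i
    target = a_i @ X @ Theta                               # unmasked logits Ã^L_i X Θ

    z = torch.zeros(A_norm.size(0), requires_grad=True)    # unconstrained parameters in R^n
    optimizer = torch.optim.Adam([z], lr=lr)
    for _ in range(steps):
        e = torch.softmax(z.masked_fill(~in_reach, float("-inf")), dim=0)  # E lies in the simplex
        loss = ((a_i * e) @ X @ Theta - target).pow(2).sum()               # objective of Eq. (11)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    e = torch.softmax(z.masked_fill(~in_reach, float("-inf")), dim=0)
    return e.detach()                                      # node-importance ranking
```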

5 Experiments

In this section, we validate the efficacy of ConveX on synthetic and real-world benchmarks. We implemented all experiments using PyTorch [20]. The code and data used in this work are publicly available (see Footnote 1).

5.1 Datasets

Synthetic Datasets. We consider six popular synthetic datasets for node-level explanation of GNNs: BA-House-Shapes, BA-Community, BA-Grids, Tree-Cycles, Tree-Grids, and BA-Bottle-Shaped. We note that these datasets are available in [31] and [26].

Each dataset consists of a single graph comprising multiple copies of the same motif (i.e., a subgraph with a specific structural pattern) connected to base subgraphs. For BA-House-Shapes, BA-Community, BA-Grids, and BA-Bottle-Shaped, the base subgraphs are randomly generated using the Barabási-Albert (BA) model [1]. For Tree-Cycles and Tree-Grids, the base subgraph is an 8-level balanced binary tree. The class of each motif node depends only on elements belonging to the same motif; consequently, the explanation associated with any node must include only motif elements, i.e., it cannot include base nodes. Therefore, the base subgraphs constitute information irrelevant to the prediction of any node.

Real-World Datasets. We also use two real-world datasets: Bitcoin-Alpha and Bitcoin-OTC [13, 14]. These datasets denote networks in which nodes correspond to user accounts that trade Bitcoin. A directed edge (u, v) between users u and v carries a score denoting the degree of trust assigned by u to v. The Bitcoin-Alpha and Bitcoin-OTC networks have 3783 and 5881 accounts, respectively. Platform members rate other members on a scale from \(-10\) (total distrust) to \(+10\) (total trust). Each account is labeled as trusted or untrusted based on the ratings of other members. In addition, accounts have features that capture outgoing rating information, such as the account’s average rating or its normalized number of votes. Target explanations for each node are provided by experts.

5.2 Experimental Setup

We compare ConveX against three explainers: GNNExplainer [31], PGM-Explainer [26], and an extension of SHapley Additive exPlanations (SHAP) [15] to GNNs—SHAP is an additive feature attribution method. To ensure a fair comparison, we closely follow the guidelines and evaluation setup from the original works. Following the setup in [26], the GNN to be explained is a 3-layer GCN [12] with ReLU activations. We use an 80%/10%/10% (train/val/test) split for all datasets. Moreover, the GCN model is trained for 10000 epochs using early stopping with a patience of 100 epochs.

For the distillation phase of ConveX, we use an SGC [27] model with 3 layers. During model distillation, the predictions for all nodes are used so that the distilled model better fits the model to be explained. We use a learning rate of 0.1 with a weight decay of \(5.0 \times 10^{-6}\) for 10000 epochs. To obtain node-level explanations, ConveX applies gradient descent with a learning rate of 0.1 for a maximum of 100 iterations. For the synthetic benchmarks, which do not have node features, we train the SGCs using one-hot node features that encode each node’s degree.
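
For the degree-based features mentioned above, one possible construction is sketched below; the `max_degree` cap is our assumption to keep the feature dimension bounded.

```python
import torch
import torch.nn.functional as F

def one_hot_degree_features(A: torch.Tensor, max_degree: int) -> torch.Tensor:
    """One-hot encoding of (capped) node degrees, used in place of missing node features."""
    deg = A.sum(dim=1).long().clamp(max=max_degree)       # node degrees, capped at max_degree
    return F.one_hot(deg, num_classes=max_degree + 1).float()
```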

All implementations were developed using the PyTorch [20] and PyTorch Geometric [6] libraries. In addition, we use the Adam [11] optimizer in all experiments.

Fig. 1. Illustration of how the synthetic datasets are assembled from their motifs and base nodes, with node labels represented by colors. To show how each explainer behaves, we also include examples of explanations extracted by GNNExplainer, PGM-Explainer, and ConveX for the same prediction in each dataset.

Our evaluation setup closely follows that of [26]. In particular, we report accuracy for the synthetic datasets, as ground-truth explanations are available. Importantly, the models only aim to explain predictions for motif nodes. An explanation consists of a ranking of n nodes, where n is the number of nodes in the corresponding motif. For Bitcoin-Alpha and Bitcoin-OTC, we use precision as the evaluation metric. Since the true explanations for these datasets have varying sizes, the explanation methods output a fixed number n of nodes as explanations (with \(n \in \{3, 4, 5\}\)). Lastly, we do not provide error bars (standard deviations) to allow comparison with the numbers reported in the original works [26, 31].
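
As a concrete reading of this protocol, the helper below computes the precision of a top-k explanation against a set of ground-truth nodes; the names are ours, and for the synthetic datasets we assume k equals the motif size, in which case this quantity coincides with the reported accuracy.

```python
import torch

def topk_precision(importance: torch.Tensor, true_nodes: set, k: int) -> float:
    """Fraction of the k highest-ranked nodes that belong to the ground-truth explanation."""
    topk = importance.topk(k).indices.tolist()
    return len(set(topk) & set(true_nodes)) / k
```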

5.3 Results

Table 1 shows the results achieved by the distiller \(\varPsi \) on the eight datasets. The evaluation metric is the agreement (accuracy) between the model to be explained and the distilled model. Table 1 reports the mean and standard deviation over 10 independent runs.

Table 1. Distillation accuracy.

In general, the results show that the distiller \(\varPsi \) closely approximates the model to be explained \(\varPhi \): for most datasets, we obtain accuracy above 90%. Interestingly, even when distillation results are poor (e.g., for the BA-Community dataset), ConveX achieves significantly better results than the other state-of-the-art explainers (as we will observe in Table 2). One possible explanation for this counter-intuitive result is that the distiller can differentiate between motif and base nodes, and this suffices to obtain good explanations since we are only interested in explaining motif nodes. Figure 2 reports the confusion matrix for the BA-Community dataset. Despite the low distillation accuracy (86.6%), the model correctly predicts base nodes (classes 1 and 5). Therefore, the model achieves high accuracy on the binary classification problem of distinguishing motif and base nodes, supporting our hypothesis.

Fig. 2. Confusion matrix for the BA-Community dataset. Classes 1 and 5 correspond to base nodes. Note that the distiller differentiates well between motif and base nodes.

Table 2 reports the results for the synthetic datasets. ConveX generates better explanations than all other explainers on four out of six datasets: BA-House, BA-Community, BA-Grids, and BA-Bottle. PGM-Explainer achieves the highest accuracy on Tree-Cycles and Tree-Grids.

Figure 1 shows the motifs, base subgraphs, and explanations provided by the different explainers for the synthetic datasets. Overall, ConveX is the model that best captures the motif structure on BA-House, BA-Community, BA-Grids, and BA-Bottle, which is consistent with its higher accuracy on these benchmarks.

Table 2. Performance (accuracy) of explanatory models for synthetic datasets. ConveX is the best performing model on 4/6 datasets.
Table 3. Performance (precision) of explanatory models for real-world datasets.

Table 3 shows the results for the Bitcoin-Alpha and Bitcoin-OTC datasets. ConveX outperforms GNNExplainer and SHAP by a large margin. Notably, ConveX is the best-performing model for top-3 and top-4 explanations, while PGM-Explainer achieves the highest precision for top-5.

6 Conclusion

In this work, we focused on finding simpler formulations for models that explain graph neural networks, and we proposed ConveX, a convex model for generating explanations for GNN predictions. ConveX seeks to identify, in a simple way, which information is relevant to GNN predictions. The proposed model uses knowledge distillation to condense the knowledge of a possibly complex network into a simpler and more interpretable model. To explain node predictions, ConveX computes a node-importance vector (weights) that preserves the predictions of the distilled network. To emphasize the notion that we must prioritize nodes, ConveX further restricts the importance vector to lie in the simplex.

We evaluated ConveX using eight datasets and compared it against three baseline explainers. The results show that ConveX generates better explanations in most cases. Overall, our assessment provides evidence towards a positive answer to our initial research question: it is possible to build simpler explainers for graph neural networks.