
1 Introduction

Machine learning (ML) based image classification algorithms, such as deep neural networks, are increasingly employed in settings where transparency and comprehensibility of decisions are crucial, such as medical diagnostics or industrial quality control. Research on explainable artificial intelligence (XAI) addresses these requirements [1] by providing techniques to explain the decision making of ML black-box models and thereby allow users to develop justified trust [14]. Many XAI methods identify the information in the input that is most relevant for the classifier decision. While this information is helpful for model developers, e.g., to detect overfitting [13], it might not be expressive enough to explain model decisions to domain experts such as medical experts or quality engineers [14].

Cognitive science research provides theories as well as empirical evidence that explanations by examples are highly effective for humans to grasp complex concepts [7, 11, 12]. In this paper, we therefore consider two kinds of examples, with the specific goal of explaining image classifier AI models to end users and domain experts without expertise in machine learning:

  1. Prototypes, representing typical instances of some image class, as a global explanation of the model, and

  2. Near hits and misses of some given input, representing examples from the training data that are similar to the input image and from the same (or opposite, respectively) class, as local explanations.

In combination, prototypes, near hits and near misses allow users to get a better understanding of the information considered relevant as well as of the decision boundaries of a given classification algorithm.

Numerous algorithms for computing prototypes of a given data set exist. In this paper, we primarily use the approach of Kim et al. [8], a widely used state-of-the-art method based on Maximum Mean Discrepancy (see Subsect. 3.1 for details). ProtoDash [4] builds on the former, but at the time of writing, no adequate implementation with sufficient adaptability for our experiments could be found. We additionally use Partitioning Around Medoids [15], an improved version of the simple k-Medoids clustering algorithm [6], as a baseline for comparison, interpreting the medoid of each cluster as a prototype.

Near hits and misses (NHMs) of classified data are much less well covered in the existing literature, especially as an explanatory tool. One notable exception is [11], where NHMs are computed specifically for Prolog clauses to explain classifications in the context of Inductive Logic Programming. Conceptually, however, finding close matches to a given input according to some metric is a ubiquitous tool in many distinct areas, such as feature selection [17] or – more closely related to our purposes – content-based image retrieval [5].

To provide more faithful explanations, we differentiate between two vector embeddings for handling images: a model-specific embedding relying on the CNN-based classification model to be explained, and a model-agnostic embedding that is unbiased by our data sets and unrelated to our classification model.

In the following, we describe the algorithms used for example-based explanations, with a focus on their evaluation on two data sets – the classic MNIST and a real-world data set of casting manufacturing images for industrial quality control [2]. We start in Sect. 2 by describing the setup for our experiments, i.e. the data sets and classifier models used, and give a brief overview of the final user-centric architecture. Sections 3 and 4 deal with prototypes and near hits and misses, respectively, the algorithms used, their parameters, and our evaluations thereof. Lastly, in Sect. 5 we present our demonstrator implementation.

2 Methodology

2.1 Data Sets

We primarily use the casting data set [2] for our experiments, consisting of 1100 grayscale images of cast metal components of size 512\(\,\times \,\)512, each labelled with one of two classes, “ok” (419 entries) and “defective” (681 entries), see Fig. 1. The entries of the latter class show various kinds of defects, e.g., blow holes, abrasions, and scratches (see Fig. 1a). Notably, the data set is highly homogeneous in that the objects in the images are very similar to each other (except for the defects, which are usually subtle), but differ with respect to features that are irrelevant for classification, e.g., lighting conditions and angle (see Fig. 1b). The data set occasionally contains multiple images of the same object from different angles, which makes it especially interesting for evaluating near hits and misses.

Fig. 1. Some examples from the casting data set [2].

For comparison, we additionally use the MNIST data set of handwritten digits [9]. Since the casting data set is restricted to two classes, we correspondingly restrict MNIST to two classes – namely “1” and “7” (consisting of 7877 and 7293 entries, respectively) – which are uniformly white digits on black background, but differ significantly in their shapes within their respective classes.
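The restriction to the two classes is straightforward; the following is a minimal sketch, assuming the Keras-bundled copy of MNIST (variable names are illustrative):

```python
# Minimal sketch: restricting MNIST to the classes "1" and "7".
import numpy as np
from tensorflow.keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
train_mask = np.isin(y_train, (1, 7))  # keep only labels 1 and 7
test_mask = np.isin(y_test, (1, 7))
x_train, y_train = x_train[train_mask], y_train[train_mask]
x_test, y_test = x_test[test_mask], y_test[test_mask]
```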

2.2 Models and Embeddings

For each of our two data sets, we trained a small standard convolutional neural network (CNN) with three convolutional and two fully connected layers on the respective classification task, with resulting accuracies of 96.82% and 99.72%, respectively.
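Such an architecture can be sketched as follows; this is a minimal Keras sketch in which the filter counts and the hidden layer width are illustrative assumptions, since only the number of layers is fixed above:

```python
# Sketch of the classifier: three convolutional and two fully connected
# layers. Filter counts and layer widths are illustrative assumptions.
from tensorflow import keras
from tensorflow.keras import layers

def build_classifier(input_shape=(512, 512, 1), num_classes=2):
    return keras.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(16, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(32, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),             # fully connected 1
        layers.Dense(num_classes, activation="softmax"),  # fully connected 2
    ])
```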

These models actually serve two purposes: Firstly, they naturally serve as toy classifier models to be explained by our overall approach. Secondly, we can use feature extraction on the models to obtain embeddings for our images, which should be sensitive to those aspects of an image that relate to its inferred class. We consequently expect these embeddings to map images with similar class-relevant features near each other, leading to more informative near hits and misses. However, it should be noted that by using embeddings depending on the classifier model, our approach is model-specific. That is, it is required that the model to be explained is a neural network (or otherwise induces a suitable vector embedding). We therefore additionally use a generic state-of-the-art image classification model (VGG16 [16]) to obtain a second embedding unbiased by our data sets and unrelated to our classifier model, allowing us to remain model-agnostic. We refer to the embedding obtained via feature extraction on our classifier models as \(\mathtt {E}_C\), and the one using VGG16 as \(\mathtt {E}_{\mathtt {VGG}}\). We will occasionally use the raw image vectors for comparison, which we denote as the (trivial) embedding \(\mathtt {E}_0\).
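A minimal sketch of how both embeddings can be obtained via feature extraction follows; the choice of the penultimate layer for \(\mathtt {E}_C\) and of average pooling for \(\mathtt {E}_{\mathtt {VGG}}\) are assumptions, and grayscale inputs would need to be replicated to three channels for VGG16:

```python
# Sketch: the two embeddings. E_C truncates our classifier at its
# penultimate layer; E_VGG uses a pretrained VGG16 without its top layers.
from tensorflow import keras
from tensorflow.keras.applications import VGG16

model = build_classifier()  # trained classifier from the sketch above
embed_c = keras.Model(inputs=model.input,
                      outputs=model.layers[-2].output)        # E_C

embed_vgg = VGG16(weights="imagenet", include_top=False,
                  pooling="avg", input_shape=(512, 512, 3))   # E_VGG
# E_0 is simply the flattened raw image vector.
```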

2.3 Architecture Overview

Figure 2 shows our approach as envisioned in practice. A user selects an image, for instance, of an industrial manufacturing component, which is classified by a CNN (or other black-box model). The inferred label is used to obtain a set of prototypes with the same label from the training data set. Both the label and the input image – under some vector embedding – are used to select a number of comparable near hits and misses from the training set. All three combined are provided to the user, allowing them to comprehend both the returned classification – by comparing it to prototypical samples and to the most similar (ground-truth labelled) elements from the training data (near hits) – and the decision boundary, in a contrastive manner via the near misses.

Fig. 2. Overview of the implemented architecture.

3 Prototype Selection

3.1 Prototype Selection Using Maximum Mean Discrepancy

Kim et al. [8] propose an approach for prototype selection based on Maximum Mean Discrepancy (MMD), a similarity measure on distributions rather than individual data points: given (finite approximations of) distributions X and Y, the expression

$$\begin{aligned} \mathtt {MMD}^2(X,Y) \; := \;&\frac{1}{|X|^2}\sum _{x_1,x_2\in X}k(x_1,x_2)+\frac{1}{|Y|^2}\sum _{y_1,y_2\in Y}k(y_1,y_2)\\&-\frac{2}{|X|\cdot |Y|}\sum _{x\in X,y\in Y}k(x,y) \end{aligned}$$

approaches 0 as X and Y become more similar with respect to a Hilbert space of test functions with reproducing kernel k. For our purposes, we use the radial basis kernel function \(k(x,y) := e^{-\gamma || x - y ||^2}\) for a real-valued scaling parameter \(\gamma \). We can use this for selecting prototypes as follows:

Given a set of embedded data points X with \(|X|=n\) and a kernel function \(k:X\times X\rightarrow \mathbb {R}\), our objective is to find a subset \(S\subseteq X\) with \(|S| = m\) that minimizes \(\mathtt {MMD}^2(X,S)\). Since the first summand \(\frac{1}{n^2}\sum _{x_1,x_2\in X}k(x_1,x_2)\) does not depend on S, this is equivalent to maximizing the following cost function:

$$\begin{aligned} J(S):= \frac{2}{nm} \sum _{x\in X,s\in S}k(x,s) - \frac{1}{m^2}\sum _{s_1,s_2\in S}k(s_1,s_2) \end{aligned}$$

The remaining aspects of the selection algorithm are straightforward (see Algorithm 1).

Algorithm 1. MMD-based prototype selection.
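In lieu of the original listing, the following is a minimal greedy sketch of the selection in the spirit of [8], assuming a precomputed kernel matrix K over the embedded data points:

```python
# Minimal greedy sketch of Algorithm 1: repeatedly add the candidate that
# maximizes the cost J(S), given a kernel matrix K with K[i, j] = k(x_i, x_j).
import numpy as np

def select_prototypes(K, m):
    n = K.shape[0]
    selected = []
    candidates = set(range(n))
    for _ in range(m):
        best, best_cost = None, -np.inf
        for c in candidates:
            s = selected + [c]
            cost = (2.0 / (n * len(s))) * K[:, s].sum() \
                 - (1.0 / len(s) ** 2) * K[np.ix_(s, s)].sum()
            if cost > best_cost:
                best, best_cost = c, cost
        selected.append(best)
        candidates.remove(best)
    return selected  # indices of the m prototypes, in greedy order
```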

Surprisingly, [8] suggests using raw image data as input for the algorithm. While this works well for some data sets, e.g., MNIST, we agree with [10] that feature embeddings should yield better results in general. Nevertheless, we compare both variants.

3.2 Parameter Selection

Algorithm 1 depends on two parameters: the number m of desired prototypes and the scaling factor \(\gamma \) of the kernel function. To determine the optimal value for the latter, [10] suggests training a 1-Nearest-Neighbour (1-NN) classifier on the selected prototypes to classify a test set. Additionally, we averaged the results over a k-fold cross-validation for robustness. Notably, the best value for \(\gamma \) seems to depend on both the underlying data set and the embedding used (see Fig. 3).
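A minimal sketch of this scoring procedure, assuming scikit-learn and the select_prototypes sketch above; selecting m prototypes per class is an assumption made so that the 1-NN classifier sees both labels:

```python
# Sketch: score one candidate gamma via k-fold cross-validated recall of a
# 1-NN classifier trained on m prototypes selected per class.
import numpy as np
from sklearn.model_selection import KFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import recall_score
from sklearn.metrics.pairwise import rbf_kernel

def score_gamma(X, y, gamma, m=3, folds=5):
    scores = []
    for tr, te in KFold(folds, shuffle=True, random_state=0).split(X):
        protos, labels = [], []
        for label in np.unique(y[tr]):           # m prototypes per class
            Xc = X[tr][y[tr] == label]
            idx = select_prototypes(rbf_kernel(Xc, gamma=gamma), m)
            protos.append(Xc[idx])
            labels += [label] * m
        knn = KNeighborsClassifier(n_neighbors=1).fit(np.vstack(protos), labels)
        scores.append(recall_score(y[te], knn.predict(X[te])))
    return float(np.mean(scores))  # average recall for this gamma
```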

Fig. 3. \(\gamma \)-values plotted against recall on two data sets and two embeddings, using 1-NN.

Regarding m, the natural but expensive approach would be a survey of end users. Instead, we opt for an analytical approach based on two methods:

  1. We perform an Elbow method based on a k-Means algorithm, plotting the distortion (i.e. the sum of squared distances from each point to its assigned cluster center) against the number of clusters/prototypes (see Fig. 4a).

  2. We use the fact that \(\mathtt {MMD}^2(X,S)\) measures how “representative” the elements of S are with respect to our full data set X. We therefore consider a Scree plot (see Fig. 4b) of the MMD\(^2\)-value against m.

Fig. 4. Representative plots using an Elbow method (a) and a Scree plot (b).

We applied both methods to all three embeddings \(E_0\), \(E_C\) and \(E_{\mathtt {VGG}}\) and to both data sets, and observed at which values the respective curves flatten. In all cases, this happens noticeably at a value of about \(m=3\), which strongly suggests that more than three prototypes do not convey significantly more information. Note, however, that the optimal number can be expected to vary with the specific data set under consideration.
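Both diagnostics can be sketched as follows, again assuming scikit-learn and the select_prototypes sketch above (the plotting itself is omitted):

```python
# Sketch of both m-selection diagnostics: k-Means distortion (Elbow) and
# MMD^2 between the data X and the first m selected prototypes (Scree).
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import rbf_kernel

def m_diagnostics(X, gamma, max_m=10):
    K = rbf_kernel(X, gamma=gamma)
    order = select_prototypes(K, max_m)  # greedy order from Algorithm 1
    n = len(X)
    elbow, scree = [], []
    for m in range(1, max_m + 1):
        elbow.append(KMeans(n_clusters=m, n_init=10).fit(X).inertia_)
        s = order[:m]
        scree.append(K.mean() + K[np.ix_(s, s)].sum() / m**2
                     - 2.0 * K[:, s].sum() / (n * m))   # MMD^2(X, S)
    return elbow, scree  # plot each against m; look for the flattening point
```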

3.3 Evaluation

Similarly to our strategy for selecting parameters, we evaluate the resulting prototypes using a 1-NN approach with respect to the training data set. We additionally use an off-the-shelf Partitioning Around Medoids [15] implementation of k-Medoids clustering as a baseline. The results are shown in Table 1.

Table 1. Accuracy of a 1-NN classifier trained on the selected prototypes, evaluated with respect to each embedding (best performance in bold).

Overall, the MMD-based approach mostly performs better than k-Medoids, although the difference is surprisingly small. As expected, the embedding \(\mathtt {E}_C\) obtained via feature extraction on the classifier model itself shows consistently better results than the alternative embeddings. Surprisingly, the unbiased embedding (\(\mathtt {E}_{\mathtt {VGG}}\)) is also superior to the raw data (\(\mathtt {E}_0\)) on the casting data set, while not on MNIST – probably due to the latter's simple structure and uniform background.

Fig. 5. Selected prototypes of the casting data set.

Figure 5 shows the resulting prototypes of our primary data set. Notably, the defective prototypes using \(\mathtt {E}_C\) cover exactly the three primary kinds of defects occurring in the data set – a blowhole in the first, abrasions in the second, and a scratch in the third – whereas the \(\mathtt {E}_{\mathtt {VGG}}\)-based prototypes are noticeably less diverse in that respect. Like the corresponding data samples themselves, the prototypes for the “ok” class are largely very similar, regardless of the embedding used.

4 Near Miss and Hit Selection

Regarding NHMs, our algorithm is conceptually straightforward (see Algorithm 2). Given a data sample e, we choose as a subset X of our training data either those samples with the same inferred label as e (near hits) or those with a different label (near misses). We then compare each element of X to e using some given metric \(m:\mathbb R^n\times \mathbb R^n\rightarrow \mathbb R\): (i) the Euclidean metric \(\sqrt{\sum _i (x_i - y_i)^2}\), (ii) the Manhattan metric \(\sum _i |x_i - y_i|\), or (iii) the cosine metric \(1-{x\cdot y}/{(\Vert x\Vert \cdot \Vert y\Vert )}\), again using all three of our embeddings.

Algorithm 2. Near hit and miss selection.
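In lieu of the original listing, the following is a minimal sketch of Algorithm 2, assuming embedded training samples X with ground-truth labels y and SciPy's distance functions (cityblock is the Manhattan metric):

```python
# Minimal sketch of Algorithm 2: near hits share the inferred label of e,
# near misses carry a different label; both are ranked by the chosen metric.
import numpy as np
from scipy.spatial.distance import euclidean, cityblock, cosine

def near_hits_misses(e, label, X, y, metric=euclidean, k=5):
    dists = np.array([metric(e, x) for x in X])
    order = np.argsort(dists)                         # ascending distance to e
    hits = [i for i in order if y[i] == label][:k]    # near hits
    misses = [i for i in order if y[i] != label][:k]  # near misses
    return hits, misses
```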

4.1 Evaluation

Evaluating the accuracy of NHMs analytically is difficult in that no objective measure of similarity – especially with respect to those features that are relevant for classification – exists that could serve as a baseline for comparison. While this applies equally to prototypes, the problem becomes much more prominent here, where comparisons between individual pairs of data samples need to be considered. Ideally, we would evaluate the possible vector embeddings and metrics in a large-scale user study. In lieu of that, we opted for manually inspecting random samples of NHMs on both data sets with varying parameters.

One clear and unsurprising result is the superiority of the classifier embedding \(\mathtt {E}_C\). This is particularly noticeable with near misses on the MNIST data set. Figure 6 shows some near misses for the class “7” for both embeddings. Notably, the near misses obtained using \(\mathtt {E}_C\) all have something resembling a corner at the upper end, which could indicate a number 7, whereas using \(\mathtt {E}_{\mathtt {VGG}}\) quickly yields plain lines, much more reminiscent of a 1.

Furthermore, the near misses using \(\mathtt {E}_C\) seem to vary much less with the metric used, or even with the input image. This makes sense, assuming the data samples are distributed such that near misses reduce to those data samples which most closely resemble the opposite class. For example, Fig. 7 shows the first five near hits for an image of class “1”, which are notably similar for all three metrics and for several input images.

Fig. 6. Near misses for an input of class “7”.

Fig. 7. Regularly occurring near hits for the class “1”.

The advantage of \(\mathtt {E}_C\) over \(\mathtt {E}_{\mathtt {VGG}}\) is much less noticeable on the casting data set (see Fig. 8), however. We conjecture that the homogeneity of the data set allows either embedding to focus primarily on the relevant differences, i.e. exactly the defects, since even generic embeddings should largely be able to abstract from rotation, angle, and similar unimportant variations.

With respect to the metric used, different choices yield results that differ in detail (identical in only \({\approx }30\%\) of cases) but are qualitatively very similar. In fact, we could not discern a clear advantage of any metric over the others, regardless of the choice of embedding, with possibly a slight advantage of the Euclidean and Manhattan distances over cosine.

Fig. 8. \(\mathtt {E}_C\)-based NHMs exhibit more striking features with respect to the example input image, indicating more quickly that this image of class “defective” was misclassified.

5 Demonstrator

We implemented a web-based interactive demonstrator (see Fig. 9). A user can choose a model-specific (\(\mathtt E_C\)) or model-agnostic (\(\mathtt E_{\mathtt {VGG}}\)) embedding, one of our two example data sets, a metric (cosine, Euclidean, Manhattan), the number of near hits and misses to show, and an input image from the test sets. The system then displays the input image itself, its classification according to the CNN, the corresponding probability, prototypes for the classes, and near hits and misses with their corresponding distances according to the chosen metric. The demonstrator and all code relating to our evaluations are available online (see Footnote 1).

Fig. 9. Screenshot of our XAI demonstrator.

6 Conclusion and Future Work

We presented an example-based XAI approach for image classification models, providing prototypes as global explanations as well as near misses and hits to explain the local decision boundary of a prediction. Our experiments showed that model-specific embeddings are more informative with respect to decision boundaries than model-agnostic ones. In a next step, more advanced prototype selection algorithms can be evaluated, e.g., by re-implementing the ProtoDash algorithm from [4].

Although there already exists some empirical evidence that humans can profit from these types of example-based explanations [7, 11], we plan to conduct user studies to evaluate the helpfulness of our demonstrator for the visual quality control task. Prediction accuracies for the class decisions of the CNN models will be compared for (a) visual highlighting, (b) prototype explanations, (c) near hit and miss explanations, and (d) both prototype and near hit/miss explanations. Another useful enhancement could be highlighting the dissimilarities or similarities between the test image and the near miss or hit by using saliency maps – e.g. similarity-based saliency maps stemming from CBIR [3] – to enable a much more precise and faster indication of the decision boundaries to the domain expert. Finally, the user interface of the demonstrator can be improved with respect to intuitive interaction, ease of information acquisition, and positive user experience.