
1 Introduction

In recent years, Deep Neural Networks (DNNs) have gained prominence in various computer vision tasks and practical applications. This progress has been in part accelerated by multiple innovations in key parts of DNN pipelines, e.g., architecture design [18, 30, 47, 49], optimization [27], initialization [12, 17], regularization [22, 48], etc., along with a pool of effective heuristics identified by practitioners. Modern DNNs now achieve strong accuracy across tasks and domains, making them potential key building blocks in real-world applications.

However, DNNs have also been shown to make mostly over-confident predictions [15], a side effect of the heuristics used in modern DNNs. This means that for ambiguous instances bordering two classes (e.g., a human wearing a cat costume), or for unrelated instances (e.g., a plastic bag not “seen” during training and classified with high probability as a rock), DNNs are likely to fail silently, which is a critical drawback for decision-making systems. This has motivated several works to address the predictive uncertainty of DNNs [6, 10, 31], usually taking inspiration from Bayesian approaches. Knowledge about the distribution of the network weights during training opens the way for studying the evolution of the underlying covariance matrix and the uncertainty of the model parameters, referred to as the epistemic uncertainty [26]. In this work we propose a method for estimating the distribution of the weights by tracking their trajectory during training. This enables us to sample an ensemble of networks, estimate the epistemic uncertainty more reliably, and detect out-of-distribution samples.

Fig. 1. Our algorithm uses Kalman filtering for tracking the distribution \(\mathcal {W}\) of all DNN weights across training steps from a generic prior \(\mathcal {W}(0)\) to the final estimate \(\mathcal {W}(t^*)\). We also estimate the covariance matrix of all the trainable network parameters. Popular alternative approaches typically rely either on ensembles of models trained independently [31], at a significant computational cost, on approximate ensembles [10], or on averaging weights collected from different local minima [23].

The common practice in training DNNs is to first initialize their weights using an appropriate random initialization strategy and then slowly adjust the weights through optimization according to the correctness of the network predictions on many mini-batches of training data. Once the stopping criterion is met, the final state of the weights is kept for evaluation. We argue that the trajectory of the weights towards the (local) optimum reveals abundant information about the structure of the weight space that we could exploit, instead of discarding it and looking only at the final point values of the weights. Popular DNN weight initialization techniques [12, 17] consist of an effective layer-wise scaling of random weight values sampled from a Normal distribution. Assuming that weights follow a Gaussian distribution at time \(t=0\), owing to the central limit theorem, the weights will also converge towards a Gaussian distribution. The final state is reached through a noisy process, where the stochasticity is induced by the weight initialization, the order and configuration of the mini-batches, etc. We thus find it reasonable to view optimization as a random walk leading to a (local) minimum, in which case “tracking” the distribution makes sense (Fig. 1). To this end, Kalman filtering (KF) [14] is an appropriate strategy, both for tractability reasons and for its guaranteed optimality as long as the underlying assumptions are valid (a linear dynamic system with Gaussian assumptions in the predict and update steps). To the best of our knowledge, our work is the first attempt to use such a technique to track the DNN weight distributions, and subsequently to estimate their epistemic uncertainty.

Contributions. The key points of our contributions are: (a) this is the first work that filters, in a tractable manner, the trajectory of the entire set of trainable parameters of a DNN during the training process; (b) we propose a tractable approximation for estimating the covariance matrix of the network parameters; (c) we achieve competitive or state-of-the-art results on most regression datasets, and in out-of-distribution experiments our method is better calibrated on three segmentation datasets (CamVid [7], StreetHazards [20], and BDD Anomaly [20]); (d) our approach strikes an appealing trade-off between performance and computational time (training + prediction).

2 TRAcking of the Weight DIstribution (TRADI)

In this section, we detail our approach to first estimate the distribution of the weights of a DNN at each training step, and then generate an ensemble of networks by sampling from the computed distributions at training conclusion.

2.1 Notations and Hypotheses

  • X and Y are two random variables, with \(X\sim \mathcal {P}_X\) and \(Y\sim \mathcal {P}_Y\). Without loss of generality we consider the observed samples \(\{\mathbf {x}_i\}_{i=1}^{n}\) as vectors and the corresponding labels \(\{y_i\}_{i=1}^{n}\) as scalars (class index for classification, real value for regression). From this set of observations, we derive a training set of \(n_l\) elements and a testing set of \(n_{\tau }\) elements: \(n=n_l+ {n_{\tau }}\).

  • Training/Testing sets are denoted respectively by \(\mathcal {D}_l=(\mathbf {x}_i,y_i)_{i=1}^{n_l}\), \(\mathcal {D}_{\tau }=(\mathbf {x}_i,y_i)_{i=1}^{n_{\tau }}\). Data in \(\mathcal {D}_l\) and \(\mathcal {D}_{\tau }\) are assumed to be i.i.d. according to their respective unknown joint distributions \(\mathcal {P}_l\) and \(\mathcal {P}_{\tau }\).

  • The DNN is defined by a vector containing the K trainable weights \(\varvec{\mathbf {\omega }}=\{\omega _k\}_{k=1}^{K}\). During training, \(\varvec{\mathbf {\omega }}\) is iteratively updated for each mini-batch, and we denote by \(\varvec{\mathbf {\omega }}(t)\) the state of the DNN at iteration t of the optimization algorithm, a realization of the random variable W(t). Let g denote the architecture of the DNN associated with these weights and \(g_{\varvec{\mathbf {\omega }}(t)}(\mathbf {x}_i)\) its output at t. The initial set of weights \(\varvec{\mathbf {\omega }}(0)=\{\omega _k(0)\}_{k=1}^{K}\) follows \(\mathcal {N}(0,\sigma ^2_k)\), where the values \(\sigma _k^2\) are fixed as in [17].

  • \(\mathcal {L}(\varvec{\mathbf {\omega }}(t),y_i)\) is the loss function used to measure the dissimilarity between the output \(g_{\varvec{\mathbf {\omega }}(t)}(\mathbf {x}_i)\) of the DNN and the expected output \(y_i\). Different loss functions can be considered depending on the type of task.

  • Weights in different layers are assumed to be independent of one another at all times. This assumption is not necessary from a theoretical point of view, but we need it to limit the complexity of the computation. Many works in the related literature rely on such assumptions [13], and some go even further: e.g., [5], one of the most popular modern BNNs, assumes that all weights are independent (even within the same layer). Each weight \(\omega _k(t)\), \(k=1,\ldots ,K\), follows a non-stationary Normal distribution (i.e., \(W_{k}(t) \sim \mathcal {N}(\mu _k(t),\sigma _k^2(t))\)) whose two parameters are tracked.

2.2 TRAcking of the DIstribution (TRADI) of Weights of a DNN

Tracking the Mean and Variance of the Weights. DNN optimization typically starts from a set of randomly initialized weights \(\varvec{\mathbf {\omega }}(0)\). Then, at each training step t, several SGD updates are performed on randomly chosen mini-batches towards minimizing the loss. This makes the trajectory of the weights vary or oscillate, not necessarily in the right direction each time [33]. Since gradients are averaged over mini-batches, we can consider that weight trajectories are averaged over each mini-batch. After a certain number of epochs, the DNN converges, i.e., it reaches a local optimum with a specific configuration of weights that is then used for testing. However, this general training approach does not consider the evolution of the distribution of the weights, which may be estimated from the training trajectory and from the dynamics of the weights over time. In our work, we argue that the history of the weight evolution up to their final state is an effective tool for estimating the epistemic uncertainty.

More specifically, our goal is to estimate, for all weights \(\omega _k(t)\) of the DNN and at each training step t, \(\mu _k(t)\) and \(\sigma _k^2(t)\), the parameters of their Normal distribution. Furthermore, for small networks we can also estimate the covariance \(\text {cov}(W_{k}(t),W_{k'}(t))\) for any pair of weights \((\omega _k(t),\omega _{k'}(t))\) at t in the DNN (see the supplementary material for details). To this end, we leverage mini-batch SGD to optimize the loss between two weight realizations. The loss derivative with respect to a given weight \({\omega _{k}(t)}\) over a mini-batch B(t) is given by:

$$\begin{aligned} \nabla \mathcal {L}_{\omega _{k}(t)} = \frac{1}{|B(t)|} \sum _{ (\mathbf {x}_i,y_i) \in B(t)}\frac{\partial \mathcal {L}(\varvec{\mathbf {\omega }}(t-1),y_i)}{\partial \omega _{k}(t-1)} \end{aligned}$$
(1)

Weights \(\omega _{k}(t)\) are then updated as follows:

$$\begin{aligned} \omega _{k}(t)=\omega _{k}(t-1)-\eta \nabla \mathcal {L}_{\omega _{k}(t)} \end{aligned}$$
(2)

with \(\eta \) the learning rate.

The weights of DNNs are randomly initialized at \(t=0\) by sampling \(W_{k}(0)\sim \mathcal {N}(\mu _k(0),\sigma ^2_k(0))\), where the parameters of the distribution are set empirically on a per-layer basis [17]. Taking the expectation of both sides of Eq. (2) and using the linearity of expectation, we get:

$$\begin{aligned} \mu _k(t)=\mu _k(t-1) -\mathbb {E}\left[ \eta \nabla \mathcal {L}_{\omega _{k}(t)}\right] \end{aligned}$$
(3)

We can see that \(\mu _k(t)\) depends only on \(\mu _k(t-1)\) and on the expected gradient, a quantity computed at time \((t-1)\): the means of the weights therefore follow a Markov process.

As in [2, 53], we assume that the weights are independent during the forward and backward passes. We then get:

$$\begin{aligned} \sigma _k^2(t)=\sigma _k^2(t-1)+\eta ^2 \mathbb {E}\left[ (\nabla \mathcal {L}_{\omega _{k}(t)} )^2\right] -\eta ^2 \mathbb {E}^2\left[ \nabla \mathcal {L}_{\omega _{k}(t)} \right] \end{aligned}$$
(4)
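The two moment equations can be checked numerically. The following Python sketch (an illustration, not part of the method) simulates one SGD step, Eq. (2), for many independent copies of a single weight, drawing the mini-batch gradient independently of the weight as assumed above; the Gaussian gradient model and all numeric values are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 200_000
eta = 0.1  # learning rate (illustrative value)

# Previous weight state: W_k(t-1) ~ N(mu_prev, var_prev)
mu_prev, var_prev = 0.5, 0.04
w_prev = rng.normal(mu_prev, np.sqrt(var_prev), n_samples)

# Mini-batch gradient, drawn independently of the weight (independence assumption)
g_mean, g_var = 0.2, 0.09
grad = rng.normal(g_mean, np.sqrt(g_var), n_samples)

# One SGD step, Eq. (2)
w_new = w_prev - eta * grad

# Closed forms: Eq. (3) for the mean, Eq. (4) for the variance,
# using E[grad^2] = Var[grad] + E[grad]^2
mu_pred = mu_prev - eta * g_mean
var_pred = var_prev + eta**2 * (g_var + g_mean**2) - eta**2 * g_mean**2

print(f"mean: empirical {w_new.mean():.4f} vs Eq. (3) {mu_pred:.4f}")
print(f"var:  empirical {w_new.var():.5f} vs Eq. (4) {var_pred:.5f}")
```

The empirical and closed-form values agree up to Monte Carlo error; if the gradient were correlated with the weight, Eq. (4) would miss a covariance term.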

This leads to the following state and measurement equations for \(\mu _{k}(t)\):

$$\begin{aligned} \left\{ \begin{array}{l} \mu _{k}(t)=\mu _k(t-1)-\eta \nabla \mathcal {L}_{\omega _{k}(t)} + \varepsilon _{\mu } \\ \omega _{k}(t)=\mu _{k}(t)+ \tilde{\varepsilon }_{\mu } \end{array} \right. \end{aligned}$$
(5)

with \(\varepsilon _\mu \) the state noise and \(\tilde{\varepsilon }_{\mu }\) the observation noise, realizations of \(\mathcal {N}(0, \sigma ^2_{\mu })\) and \(\mathcal {N}(0, \tilde{\sigma }^2_{\mu })\), respectively. The state and measurement equations for the variance \(\sigma _{k}^2\) are given by:

$$\begin{aligned} \left\{ \begin{array}{l} \sigma _{k}^2(t)=\sigma _{k}^2(t-1)+ \left( \eta \nabla \mathcal {L}_{\omega _{k}(t)} \right) ^2+ \varepsilon _{\sigma }\\ z_k(t)=\sigma _{k}^2(t)+\mu _{k}(t)^2+\tilde{\varepsilon }_{\sigma } \\ \text {with } z_k(t) =\omega _{k}(t)^2 \end{array} \right. \end{aligned}$$
(6)

with \(\varepsilon _\sigma \) the state noise and \(\tilde{\varepsilon }_{\sigma }\) the observation noise, realizations of \( \mathcal {N}(0, \sigma ^2_{\sigma })\) and \(\mathcal {N}(0, \tilde{\sigma }^2_{\sigma })\), respectively. Compared with Eq. (4), we drop the squared empirical mean of the gradient from the state equation, since in practice its value is below the state noise.

Approximating the Covariance. Using the state and measurement equations in Eqs. (5) and (6), we can apply a Kalman filter to track the state of each trainable parameter. As the computational cost of tracking the full covariance matrix is significant, we instead track only the variance of each weight. We then approximate the covariance with a model inspired by Gaussian Processes [52]; we choose a Gaussian kernel for its simplicity and good results. Let \(\varvec{\mathbf {\Sigma }}(t)\) denote the covariance of W(t), and let \(\mathbf {v}(t)=\begin{pmatrix} \sigma _{1}(t),\sigma _{2}(t), \ldots , \sigma _{K}(t)\end{pmatrix}\) be the vector of size K composed of the standard deviations of all weights at time t. The covariance matrix is approximated by \(\hat{\varvec{\mathbf {\Sigma }}}(t) = (\mathbf {v}(t) \mathbf {v}(t)^{T})\odot \varvec{\mathcal {K}}(t)\), where \(\odot \) is the Hadamard product, and \(\varvec{\mathcal {K}}(t)\) is the kernel corresponding to the \(K\times K\) Gram matrix of the weights of the DNN, with coefficient \((k,k')\) given by \(\varvec{\mathcal {K}}(\omega _k(t),\omega _{k'}(t))=\exp \left( -{\frac{\Vert \omega _k(t)-\omega _{k'}(t)\Vert ^{2}}{2\sigma _{\text {rbf}} ^{2}}}\right) \). However, the cost of storing and processing the kernel \(\varvec{\mathcal {K}}(t)\) is prohibitive in practice, as it grows quadratically with the number of weights (e.g., \(K\approx 10^9\) in recent DNNs).
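For a handful of weights, the approximation above can be formed explicitly. The NumPy sketch below is illustrative only: it assumes the RBF kernel is evaluated on the scalar weight values, and the weight values, standard deviations and \(\sigma_{\text{rbf}}\) are made-up numbers. Because \(\varvec{\mathcal {K}}(t)\) has a unit diagonal, \(\hat{\varvec{\mathbf {\Sigma }}}(t)\) keeps the tracked variances on its diagonal, and it is positive semi-definite as the Hadamard product of two positive semi-definite matrices.

```python
import numpy as np


def rbf_gram(w, sigma_rbf):
    """K x K Gram matrix with entries exp(-(w_k - w_k')^2 / (2 sigma_rbf^2))."""
    diff = w[:, None] - w[None, :]
    return np.exp(-diff ** 2 / (2.0 * sigma_rbf ** 2))


def approx_covariance(w, sigma, sigma_rbf):
    """Sigma_hat = (v v^T) * K, where * is the elementwise (Hadamard) product and v = sigma."""
    return np.outer(sigma, sigma) * rbf_gram(w, sigma_rbf)


rng = np.random.default_rng(0)
K = 6                                   # toy number of weights
w = rng.normal(0.0, 0.1, K)             # current weight values omega_k(t)
sigma = rng.uniform(0.01, 0.05, K)      # tracked standard deviations sigma_k(t)
Sigma_hat = approx_covariance(w, sigma, sigma_rbf=0.1)
print(np.allclose(np.diag(Sigma_hat), sigma ** 2))  # True: unit kernel diagonal keeps the variances
```

Both matrices built here are \(K\times K\), which is exactly the quadratic memory cost that rules out this direct construction for real DNNs.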

Rahimi and Recht [45] alleviate this problem by approximating non-linear kernels, e.g., the Gaussian RBF kernel, in an unbiased way using random feature representations. Their approach applies to any translation-invariant positive definite kernel \( \varvec{\mathcal {K}}(t)\), i.e., one for which, for all \((\omega _k(t),\omega _{k'}(t))\), \( \varvec{\mathcal {K}}(\omega _k(t),\omega _{k'}(t))\) depends only on \(\omega _k(t)-\omega _{k'}(t)\). We can then write the kernel as:

$$\varvec{\mathcal {K}}(\omega _k(t),\omega _{k'}(t)) = 2\,\mathbb {E}\left[ \cos (\varTheta \omega _k(t) + \varPhi )\cos (\varTheta \omega _{k'}(t)+ \varPhi )\right] $$

where \(\varTheta \sim \mathcal {N}(0,\sigma _{\text {rbf}}^{-2})\) (this distribution is the Fourier transform of the kernel) and \(\varPhi \sim \mathcal {U}_{[0,2\pi ]}\). In detail, we approximate the high-dimensional feature space by projecting onto the following N-dimensional feature vector:

$$\begin{aligned} \mathbf {z}(\omega _k(t)) {\equiv } \sqrt{\frac{2}{N}}\begin{bmatrix}\cos (\theta _1\omega _k(t) + \phi _1), \ldots , \cos (\theta _N \omega _k(t)+ \phi _N)\end{bmatrix}^\top \end{aligned}$$
(7)

where \(\theta _1,\ldots ,\theta _N\) are i.i.d. from \(\mathcal {N}(0,\sigma _{\text {rbf}}^{-2})\) and \(\phi _1,\ldots ,\phi _N\) are i.i.d. from \(\mathcal {U}_{[0,2\pi ]}\). In this new feature space we can approximate the kernel \(\varvec{\mathcal {K}}(t)\) by \(\hat{\varvec{\mathcal {K}}}(t)\) defined by:

$$\begin{aligned} \hat{\varvec{\mathcal {K}}}(\omega _{k}(t),\omega _{k'}(t)) =\mathbf {z}(\omega _{k}(t))^\top \mathbf {z}(\omega _{k'}(t)) \end{aligned}$$
(8)
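A minimal sketch of Eqs. (7) and (8), assuming the RBF kernel defined above: for \(\exp(-d^2/(2\sigma_{\text{rbf}}^2))\) the frequencies are drawn with standard deviation \(1/\sigma_{\text{rbf}}\). The number of features N, \(\sigma_{\text{rbf}}\) and the weight values are hypothetical; the point is that \(\mathbf z^\top \mathbf z\) approaches the exact kernel as N grows.

```python
import numpy as np


def rff_features(w, theta, phi):
    """Eq. (7): rows z(w_k) = sqrt(2/N) [cos(theta_1 w_k + phi_1), ..., cos(theta_N w_k + phi_N)]."""
    N = theta.shape[0]
    return np.sqrt(2.0 / N) * np.cos(np.outer(w, theta) + phi)  # shape (len(w), N)


rng = np.random.default_rng(0)
sigma_rbf = 0.1
N = 5000                                       # number of random features
theta = rng.normal(0.0, 1.0 / sigma_rbf, N)    # spectral samples of the RBF kernel
phi = rng.uniform(0.0, 2.0 * np.pi, N)         # random phases

w = rng.normal(0.0, 0.1, 8)                    # a few scalar weights
Z = rff_features(w, theta, phi)
K_hat = Z @ Z.T                                # Eq. (8) for all pairs at once
K_exact = np.exp(-(w[:, None] - w[None, :]) ** 2 / (2 * sigma_rbf ** 2))
print(np.abs(K_hat - K_exact).max())           # Monte Carlo error, shrinks as N grows
```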

Furthermore, it was proved in [45] that the probability of an approximation error greater than \(\epsilon \in \mathbb {R}^+\) decays exponentially with N, following a bound of the form \(\exp (-N\epsilon ^2)/\epsilon ^2\) up to constants. To avoid the Hadamard product of matrices of size \(K\times K\), we evaluate \(\textit{\textbf{r}}(\omega _k(t)) =\sigma _{k}(t)\textit{\textbf{z}}(\omega _{k}(t)) \), and the value at index \((k,k')\) of the approximate covariance matrix \(\hat{\varvec{\mathbf {\Sigma }}}(t)\) is given by:

$$\begin{aligned} \hat{\varvec{\mathbf {\Sigma }}}(t)(k,k') =\textit{\textbf{r}}(\omega _k(t))^\top \textit{\textbf{r}}(\omega _{k'}(t)). \end{aligned}$$
(9)
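The sketch below (NumPy, with hypothetical sizes and values) stacks the vectors \(\textit{\textbf{r}}(\omega_k(t))\) into a \(K\times N\) matrix, so that Eq. (9) for all pairs becomes one matrix product, and compares the result with the exact \((\mathbf v\mathbf v^\top)\odot\varvec{\mathcal K}\). It forms the \(K\times K\) matrices only for this small check; in practice only the \(K\times N\) factor, with \(N\ll K\), needs to be kept.

```python
import numpy as np

rng = np.random.default_rng(1)
K, N, sigma_rbf = 50, 4000, 0.1

w = rng.normal(0.0, 0.1, K)                  # weight values omega_k(t)
sigma = rng.uniform(0.01, 0.05, K)           # tracked standard deviations sigma_k(t)
theta = rng.normal(0.0, 1.0 / sigma_rbf, N)  # RBF spectral samples
phi = rng.uniform(0.0, 2.0 * np.pi, N)

Z = np.sqrt(2.0 / N) * np.cos(np.outer(w, theta) + phi)  # rows z(omega_k), Eq. (7)
R = sigma[:, None] * Z                                   # rows r(omega_k) = sigma_k z(omega_k)

Sigma_lowrank = R @ R.T                                  # Eq. (9) for every (k, k')
K_exact = np.exp(-(w[:, None] - w[None, :]) ** 2 / (2 * sigma_rbf ** 2))
Sigma_exact = np.outer(sigma, sigma) * K_exact           # (v v^T) Hadamard K

print(R.shape)                                           # (50, 4000) in this toy check
print(np.abs(Sigma_lowrank - Sigma_exact).max())
```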

2.3 Training the DNNs

In our approach, for classification we use the cross-entropy loss to obtain the log-likelihood, similarly to [31]. For regression tasks, we train with two losses sequentially and modify \(g_{\varvec{\mathbf {\omega }}(t)}(\textit{\textbf{x}}_i)\) to have two output heads: the classical regression output \(\mu _{pred}(\textit{\textbf{x}}_i)\) and the predicted variance of the output \(\sigma _{pred}^2(\textit{\textbf{x}}_i)\), a modification inspired by [31]. The first loss is the MSE \(\mathcal {L}_1(\varvec{\mathbf {\omega }}(t),\textit{\textbf{y}}_i)=\Vert g_{\varvec{\mathbf {\omega }}(t)}(\textit{\textbf{x}}_i)-\textit{\textbf{y}}_i \Vert ^2_2\), as used in traditional regression tasks. The second loss is the negative log-likelihood (NLL) [31], which reads:

$$\begin{aligned} \mathcal {L}_2(\varvec{\mathbf {\omega }}(t),y_i) = \frac{1}{2\sigma _{\text {pred}}(\mathbf {x}_i)^2} \Vert \mu _{\text {pred}}(\mathbf {x}_i) - y_i \Vert ^2 + \frac{1}{2}\log \sigma _{\text {pred}}(\mathbf {x}_i)^2 \end{aligned}$$
(10)

We first train with loss \(\mathcal {L}_1(\varvec{\mathbf {\omega }}(t),y_i)\) until reaching a satisfactory \(\varvec{\mathbf {\omega }}(t)\). In the second stage, we add the variance prediction head and fine-tune from \(\varvec{\mathbf {\omega }}(t)\) with loss \(\mathcal {L}_2(\varvec{\mathbf {\omega }}(t),y_i)\). In our experiments, we observed that this sequential training is more stable, as it allows the network to first learn features for the target task and then to predict its own variance, rather than doing both at the same time (which is particularly unstable in the first steps).
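The two regression losses can be written in a few lines. This NumPy sketch parameterizes the variance head by its logarithm so that \(\sigma_{\text{pred}}^2\) stays positive; that choice and the function names are assumptions of the sketch, not details given above. Stage 1 would minimize `mse`, and stage 2 would fine-tune with `gaussian_nll` once the variance head is added.

```python
import numpy as np


def mse(mu_pred, y):
    """Stage-1 loss L1: squared error, averaged over the batch."""
    return np.mean((mu_pred - y) ** 2)


def gaussian_nll(mu_pred, log_var_pred, y):
    """Stage-2 loss L2, Eq. (10), averaged over the batch.

    log_var_pred is the log of sigma_pred^2 (assumed parameterization).
    """
    var = np.exp(log_var_pred)
    return np.mean(0.5 * (mu_pred - y) ** 2 / var + 0.5 * log_var_pred)


y = np.array([1.0, 2.0, 3.0])
mu = np.array([1.1, 1.9, 3.2])
print(mse(mu, y))                          # mean of 0.01, 0.01, 0.04
print(gaussian_nll(mu, np.zeros(3), y))    # with sigma_pred^2 = 1 this is half the MSE
```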

2.4 TRADI Training Algorithm Overview

We detail the TRADI steps during training in the Appendix, Sect. 1.3. For tracking purposes, we only need to store \(\mu _k(t)\) and \(\sigma _k(t)\) for all the weights of the network. Hence, the method is computationally lighter than Deep Ensembles, whose training cost scales with the number of networks composing the ensemble. In addition, TRADI can be applied to any DNN without modifying its architecture, in contrast to MC Dropout, which requires adding dropout layers to the underlying DNN. For clarity, we define \( \mathcal {L}(\varvec{\mathbf {\omega }}(t),B(t)) = \frac{1}{|B(t)|} \sum _{ (\mathbf {x}_i,y_i) \in B(t)}\ \mathcal {L}(\varvec{\mathbf {\omega }}(t),y_i)\). In the training algorithm, \(\mathbf {P}_\mu \), \(\mathbf {P}_\sigma \) are the noise covariance matrices of the mean and variance, respectively, and \(\mathbf {Q}_\mu \), \(\mathbf {Q}_\sigma \) are the optimal gain matrices of the mean and variance, respectively. These matrices are used during Kalman filtering [24].
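The appendix algorithm is not reproduced here; as a rough illustration only, the NumPy sketch below runs one scalar Kalman filter per weight (weights treated independently) for the mean, following Eq. (5), and for the variance, following Eq. (6). The noise variances are hypothetical hyperparameters; `p_mu` and `p_var` are the estimate-error variances of a textbook Kalman filter and need not match the paper's \(\mathbf P\) matrices; and the variance filter uses the Gaussian second-moment identity \(\mathbb E[\omega^2]=\sigma^2+\mu^2\) as its measurement function, which is an assumption of this sketch. The toy training loop on a noisy quadratic loss is also illustrative.

```python
import numpy as np


class WeightTracker:
    """Per-weight scalar Kalman filters for the mean (Eq. 5) and variance (Eq. 6).

    Weights are tracked independently, so every quantity is a length-K vector.
    The noise variances q_* (state) and r_* (observation) are hypothetical values.
    """

    def __init__(self, w0, var0, q_mu=1e-6, r_mu=1e-4, q_var=1e-8, r_var=1e-6):
        self.mu = np.asarray(w0, dtype=float).copy()   # mu_k(0): start at the initial weights
        self.var = np.full_like(self.mu, var0)         # sigma_k^2(0), e.g. from the init scheme
        self.p_mu = np.zeros_like(self.mu)             # estimate-error variance of mu_k
        self.p_var = np.zeros_like(self.mu)            # estimate-error variance of sigma_k^2
        self.q_mu, self.r_mu = q_mu, r_mu
        self.q_var, self.r_var = q_var, r_var

    def step(self, grad, w_obs, eta):
        """grad: mini-batch gradient used for the SGD step; w_obs: weights after that step."""
        # Mean filter. Predict: mu(t) = mu(t-1) - eta * grad; observe: w(t) = mu(t) + noise.
        mu_pred = self.mu - eta * grad
        p_pred = self.p_mu + self.q_mu
        gain = p_pred / (p_pred + self.r_mu)
        self.mu = mu_pred + gain * (w_obs - mu_pred)
        self.p_mu = (1.0 - gain) * p_pred

        # Variance filter. Predict: sigma^2(t) = sigma^2(t-1) + (eta * grad)^2.
        var_pred = self.var + (eta * grad) ** 2
        pv_pred = self.p_var + self.q_var
        # Observe z = w^2, whose expectation for a Gaussian weight is sigma^2 + mu^2.
        z_pred = var_pred + self.mu ** 2
        gain_v = pv_pred / (pv_pred + self.r_var)
        self.var = np.maximum(var_pred + gain_v * (w_obs ** 2 - z_pred), 1e-12)
        self.p_var = (1.0 - gain_v) * pv_pred


# Toy run: noisy gradient descent on 0.5 * ||w - target||^2 (illustrative only).
rng = np.random.default_rng(0)
K, eta = 1000, 0.05
target = rng.normal(0.0, 1.0, K)
w = rng.normal(0.0, 0.1, K)
tracker = WeightTracker(w, var0=0.01)
for t in range(200):
    grad = (w - target) + rng.normal(0.0, 0.1, K)   # stochastic gradient
    w = w - eta * grad                                # SGD step, Eq. (2)
    tracker.step(grad, w, eta)
```

Only two length-K vectors (plus their error variances) are stored per tracked quantity, which is what keeps this kind of tracking cheap compared with training several networks.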

2.5 TRADI Uncertainty During Testing

After having trained a DNN, we can evaluate its uncertainty by sampling new realizations of the weights from the tracked distribution. We denote by \(\tilde{\varvec{\mathbf {\omega }}}(t)=\{\tilde{\omega }_k(t)\}_{k=1}^{K}\) the vector of size K containing these realizations. Note that this vector is different from \(\varvec{\mathbf {\omega }}(t)\), since it is sampled from the distribution computed with TRADI, which does not correspond exactly to the DNN weight distribution. In addition, we denote by \(\varvec{\mathbf {\mu }}(t)\) the vector of size K containing the means of all weights at time t.

Then, two cases can occur. In the first case, we have access to the covariance matrix of the weights (by tracking or by an alternative approach) that we denote \(\varvec{\mathbf {\Sigma }}(t)\), and we simply sample new realizations of W(t) using the following formula:

$$\begin{aligned} \tilde{\varvec{\mathbf {\omega }}}(t)= \varvec{\mathbf {\mu }}(t) + \varvec{\mathbf {\Sigma }}^{1/2}(t) \times \textit{\textbf{m}}_1 \end{aligned}$$
(11)

in which \(\textit{\textbf{m}}_1\) is drawn from the multivariate Gaussian \(\mathcal {N}(\textit{\textbf{0}}_K,\textit{\textbf{I}}_K)\), where \(\textit{\textbf{0}}_K,\textit{\textbf{I}}_K\) are respectively the K-size zero vector and the \(K\times K\) size identity matrix.

When we deal with a DNN (the case considered in this paper), we are constrained, for tractability reasons, to approximate the covariance matrix with the random feature projection described in Sect. 2.2, and we generate new realizations of W(t) as follows:

$$\begin{aligned} \tilde{\varvec{\mathbf {\omega }}}(t)= \varvec{\mathbf {\mu }}(t) + \textit{\textbf{R}}(\varvec{\mathbf {\omega }}(t)) \times \textit{\textbf{m}}_2 \end{aligned}$$
(12)

where \( \textit{\textbf{R}}(\varvec{\mathbf {\omega }}(t))\) is a matrix of size \(K\times N\) whose rows \(k \in [1,K]\) contain the \(\textit{\textbf{r}}(\omega _k(t))^\top \) defined in Sect. 2.2. \(\textit{\textbf{R}}(\varvec{\mathbf {\omega }}(t))\) depends on \((\theta _1,\ldots ,\theta _N)\) and on \((\phi _1,\ldots ,\phi _N)\) defined in Eq. (7). \(\textit{\textbf{m}}_2\) is drawn from the multivariate Gaussian \(\mathcal {N}(\textit{\textbf{0}}_N,\textit{\textbf{I}}_N)\), where \(\textit{\textbf{0}}_N,\textit{\textbf{I}}_N\) are respectively the zero vector of size N and the identity matrix of size \(N\times N\). Note that since \(N \ll K\), computations are significantly accelerated.
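A sketch of Eq. (12) in NumPy with hypothetical sizes: the \(K\times N\) factor is built from the random features of Eq. (7), and each sampled model requires only an N-dimensional Gaussian draw. As a simplifying assumption, the final weights \(\varvec{\mathbf\omega}(t)\) passed to the features are taken equal to the tracked means.

```python
import numpy as np


def sample_weights(mu, sigma, w, theta, phi, n_models, rng):
    """Eq. (12): draw n_models vectors w_tilde = mu + R(w) m2, with m2 ~ N(0, I_N)."""
    N = theta.shape[0]
    Z = np.sqrt(2.0 / N) * np.cos(np.outer(w, theta) + phi)  # (K, N), rows z(omega_k), Eq. (7)
    R = sigma[:, None] * Z                                    # (K, N), rows r(omega_k)^T
    m2 = rng.standard_normal((N, n_models))                   # one N-dim Gaussian per model
    return (mu[:, None] + R @ m2).T                           # (n_models, K)


rng = np.random.default_rng(0)
K, N, sigma_rbf = 10_000, 256, 0.1
mu = rng.normal(0.0, 0.1, K)                 # tracked means mu_k(t)
sigma = rng.uniform(0.01, 0.05, K)           # tracked standard deviations sigma_k(t)
theta = rng.normal(0.0, 1.0 / sigma_rbf, N)  # RBF spectral samples
phi = rng.uniform(0.0, 2.0 * np.pi, N)

# Assumption: the final weights omega(t) are taken equal to the tracked means.
ensemble = sample_weights(mu, sigma, mu, theta, phi, n_models=20, rng=rng)
print(ensemble.shape)  # (20, 10000)
```

The cost per model is one \(K\times N\) matrix-vector product, instead of a square root of a \(K\times K\) covariance as in Eq. (11).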

Then, similarly to [26, 37], given input data \((\mathbf {x}^*,y^*)\in \mathcal {D}_{\tau }\) from the testing set, we estimate the marginal likelihood by Monte Carlo integration. First, a sequence \(\{\tilde{\varvec{\mathbf {\omega }}}^j(t)\}_{j=1}^{N{_{\text {model}}}}\) of \(N{_{\text {model}}}\) realizations of W(t) is drawn (typically, \(N{_{\text {model}}}=20\)). Then, the marginal likelihood of \(y^*\) over W(t) is approximated by:

$$\begin{aligned} \mathcal {P}(y^*|x^*) = \frac{1}{N{_{\text {model}}}} \sum _{j=1}^{N{_{\text {model}}}} \mathcal {P}(y^*|\tilde{\varvec{\mathbf {\omega }}}^j(t),\mathbf {x}^*) \end{aligned}$$
(13)

For regression, we follow the strategy of [31] to compute the log-likelihood, and consider the outputs of the DNN applied on \(\mathbf {x}^*\) as the parameters \(\{\mu ^j_{\text {pred}}(\mathbf {x}^*),(\sigma ^j_{\text {pred}}(\mathbf {x}^*))^2\}_{j=1}^{N{_{\text {model}}}}\) of a Gaussian distribution (see Sect. 2.3). Hence, the final output is a mixture of \(N{_{\text {model}}}\) Gaussian distributions \(\mathcal {N}(\mu ^j_{\text {pred}}(\mathbf {x}^*),(\sigma ^j_{\text {pred}}(\mathbf {x}^*))^2)\). During testing, if the DNN has BatchNorm layers, we first update the BatchNorm statistics of each of the sampled \(\tilde{\varvec{\mathbf {\omega }}}^j(t)\) models, where \(j \in [1, N{_{\text {model}}}]\) [23].
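For illustration, the aggregation step can be sketched as follows (NumPy; the function names are hypothetical). For classification, Eq. (13) averages the predicted class probabilities of the sampled networks; for regression, the uniform mixture of Gaussians has a closed-form mean and variance, and its negative log-density can be evaluated stably with log-sum-exp.

```python
import numpy as np


def mc_class_probs(prob_per_model):
    """Eq. (13): average the class probabilities of the N_model sampled networks.

    prob_per_model: array of shape (N_model, n_classes) for a single input x*.
    """
    return prob_per_model.mean(axis=0)


def mixture_moments(mu_j, var_j):
    """Mean and variance of the uniform mixture of the Gaussians N(mu_j, var_j)."""
    mean = mu_j.mean()
    var = (var_j + mu_j ** 2).mean() - mean ** 2
    return mean, var


def mixture_nll(y, mu_j, var_j):
    """Negative log-density of the uniform Gaussian mixture at y (log-sum-exp for stability)."""
    log_comp = -0.5 * np.log(2.0 * np.pi * var_j) - 0.5 * (y - mu_j) ** 2 / var_j
    m = log_comp.max()
    return -(m + np.log(np.mean(np.exp(log_comp - m))))


print(mc_class_probs(np.array([[0.9, 0.1], [0.6, 0.4]])))
print(mixture_moments(np.array([1.0, 3.0]), np.array([1.0, 1.0])))
```

In the example, two components with means 1 and 3 and unit variances give a mixture mean of 2 and a variance of 2: the average component variance (1) plus the spread of the means (1).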

3 Related Work

Uncertainty estimation is an important aspect of any machine learning model, and it has been thoroughly studied over the years in statistical learning. In the context of DNNs, there has been renewed interest in dealing with uncertainty. In the following, we briefly review methods related to our approach.

Bayesian Methods. Bayesian approaches deal with uncertainty by identifying a distribution of the parameters of the model. The posterior distribution is computed from a prior distribution assumed over the parameters and the likelihood of the model for the current data. The posterior distribution is iteratively updated across training samples. The predictive distribution is then computed through Bayesian model averaging by sampling models from the posterior distribution. This simple formalism is at the core of many machine learning models, including neural networks. Early approaches from Neal [39] leveraged Markov chain Monte Carlo variants for inference on Bayesian Neural Networks. However, for modern DNNs with millions of parameters, such methods are intractable for computing the posterior distribution, which has shifted the focus to gradient-based methods.

Modern Bayesian Neural Networks (BNNs). Progress in variational inference [28] has enabled a recent revival of BNNs. Blundell et al. [6] learn distributions over weights via a Gaussian mixture prior. While such models are easy to reason about, they are limited to rather medium-sized networks. Gal and Ghahramani [10] suggest that Dropout [48] can be used to mimic a BNN by sampling different subsets of neurons at each forward pass during test time and using them as an ensemble. MC Dropout is currently the most popular instance of BNNs due to its speed and simplicity, with multiple recent extensions [11, 32, 50]. However, the benefits of Dropout are more limited for convolutional layers, where specific architectural design choices must be made [25, 38]. A potential drawback of MC Dropout is that its uncertainty does not decrease with more training steps [41, 42]. TRADI is compatible with both fully-connected and convolutional layers, and its uncertainty estimates are expected to improve with training, as it relies on the Kalman filter formalism.

Ensemble Methods. Ensemble methods are arguably the top performers for measuring epistemic uncertainty and are widely applied in various areas, e.g., active learning [3]. Lakshminarayanan et al. [31] propose training an ensemble of DNNs with different initialization seeds. The major drawback of this method is its computational cost, since one has to train multiple DNNs, a cost which is particularly high for computer vision architectures, e.g., for semantic segmentation or object detection. Alternatives to ensembles use a network with multiple prediction heads [35], collect weight checkpoints from local minima and average them [23], or fit a distribution over them and sample networks [37]. Although the latter approaches are faster to train than ensembles, their limitation is that the observations from these local minima are relatively sparse for such a high-dimensional space and are less likely to capture the true distributions of the space around these weights. TRADI mitigates these issues by collecting weight statistics at each step of the SGD optimization. Furthermore, our algorithm has a lower computational cost than [31] during training.

Kalman Filtering (KF). The KF [24] is a recursive estimator that infers unknown variables from measurements observed over time. With the advent of DNNs, researchers have tried integrating ideas from KF into DNN training: for SLAM using RNNs [8, 16], for optimization [51], and for DNN fusion [36]. In our approach, we employ KF to keep track of the statistics of the network during training, such that at “convergence” we have a better coverage of the distribution around each parameter of a multi-million-parameter DNN. The KF provides a clean and relatively easy-to-deploy formalism to this effect.

Weight Initialization and Optimization. Most DNN initialization techniques [12, 17] start from weights sampled from a Normal distribution, and further scale them according to the number of units and the activation function. BatchNorm [22] stabilizes training by enforcing a Normal distribution of intermediate activations at each layer. WeightNorm [46] has a similar effect on the weights, helping them stay close to their initial distributions. From a Bayesian perspective, \(L_2\) regularization, known as weight decay, is equivalent to placing a Gaussian prior over the weights [4]. We also consider a Gaussian prior over the weights, similarly to previous works [6, 23], for its numerous properties, ease of use, and natural compatibility with KF. Note that we use it only in the filtering, in order to reduce any major drift in the estimation of the weight distributions across training, while mitigating potential instabilities in SGD steps.

4 Experiments

We evaluate TRADI on a range of tasks and datasets. For regression, in line with prior works [10, 31], we consider a toy dataset and the regression benchmark [21]. For classification, we evaluate on MNIST [34] and CIFAR-10 [29]. Finally, we address the out-of-distribution task for classification, on MNIST/notMNIST [31], and for semantic segmentation, on CamVid-OOD, StreetHazards [19], and BDD-Anomaly [19]. Unless otherwise specified, we use mini-batches of size 128 and the Adam optimizer with a fixed learning rate of 0.1 in all our experiments.

Fig. 2. Results on a synthetic regression task comparing MC dropout, Deep Ensembles, and TRADI. x-axis: spatial coordinate of the Gaussian process. Black lines: ground truth curve. Blue points: training points. Orange areas: estimated variance.

4.1 Toy Experiments

Experimental Setup. As our main evaluation metric, we use the NLL. In addition, for classification we consider the accuracy, while for regression we use the root mean squared error (RMSE). For the out-of-distribution experiments, we use the AUC, AUPR, and FPR-95%-TPR as in [20], and the Expected Calibration Error (ECE) as in [15]. For our implementations we use PyTorch [44]. We provide other implementation details on a per-experiment basis.

First, we perform a qualitative evaluation on a one-dimensional synthetic dataset generated with a Gaussian Process with zero mean and an RBF covariance kernel \(\varvec{\mathcal {K}}\) with \(\sigma ^2=1\), denoted \(GP(\mathbf {0},\varvec{\mathcal {K}})\). We add to this process zero-mean Gaussian noise of variance 0.3. We train a neural network composed of one hidden layer with 200 neurons. In Fig. 2 we plot the regression estimates provided by TRADI, MC Dropout [10] and Deep Ensembles [31]. Although \(GP(\mathbf {0},\varvec{\mathcal {K}})\) is one of the simplest stochastic processes, the results clearly show that the compared approaches do not estimate the variance robustly, while TRADI neither overestimates nor underestimates the uncertainty.
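A possible way to generate such data, sketched in NumPy: the input range, the number of points and the kernel length scale are not specified above and are hypothetical choices; only the unit kernel variance and the noise variance of 0.3 come from the text.

```python
import numpy as np


def sample_gp_toy(n_points=100, length_scale=1.0, noise_var=0.3, seed=0):
    """One draw of GP(0, K) with a unit-variance RBF kernel, plus N(0, noise_var) observation noise.

    n_points, the input range and length_scale are hypothetical choices.
    """
    rng = np.random.default_rng(seed)
    x = np.linspace(-5.0, 5.0, n_points)
    K = np.exp(-(x[:, None] - x[None, :]) ** 2 / (2.0 * length_scale ** 2))  # sigma^2 = 1
    L = np.linalg.cholesky(K + 1e-6 * np.eye(n_points))    # small jitter for numerical stability
    f = L @ rng.standard_normal(n_points)                  # noiseless ground-truth curve
    y = f + rng.normal(0.0, np.sqrt(noise_var), n_points)  # noisy training targets
    return x, f, y


x, f, y = sample_gp_toy()
```

Drawing \(\mathbf f = \mathbf L\,\boldsymbol\epsilon\) with \(\mathbf L\mathbf L^\top=\varvec{\mathcal K}\) and \(\boldsymbol\epsilon\sim\mathcal N(\mathbf 0,\mathbf I)\) gives \(\operatorname{cov}(\mathbf f)=\varvec{\mathcal K}\), which is why the Cholesky factor is used.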

4.2 Regression Experiments

For the regression task, we follow the experimental protocol and use the datasets from [21], also used in related works [10, 31]. Here, we consider a neural network with one hidden layer of 50 hidden units, trained for 40 epochs. For each dataset, we perform 20-fold cross-validation. For all datasets, we set the dropout rate to 0.1, except for Yacht Hydrodynamics and Boston Housing, for which it is set to 0.001 and 0.005, respectively. We compare against MC Dropout [10] and Deep Ensembles [31] and report results in Table 1. TRADI outperforms both methods in terms of both RMSE and NLL. Aside from the proposed approach to tracking the weight distribution, we hypothesize that an additional reason for which our technique outperforms the alternative methods lies in the sequential training (MSE and NLL) proposed in Sect. 2.3.

Table 1. Comparative results on regression benchmarks

4.3 Classification Experiments

For the classification task, we conduct experiments on two datasets. The first one is MNIST [34], which is composed of a training set of 60k images and a test set of 10k images, all of size \(28\times 28\). Here, we use a neural network with 3 hidden layers of 200 neurons each, followed by ReLU non-linearities and BatchNorm, with a fixed learning rate \(\eta = 10^{-2}\). We report our results in Table 2. For MNIST, we generate \(N{_{\text {model}}} =20\) models to ensure a fair comparison with Deep Ensembles. The evaluation shows that, in terms of performance, TRADI lies between Deep Ensembles and MC Dropout. However, in contrast to Deep Ensembles, our algorithm is significantly lighter because only a single model needs to be trained, whereas Deep Ensembles approximates the weight distribution through a very costly set of independent training runs (20 in this case).

Table 2. Comparative results on image classification

We conduct the second experiment on CIFAR-10 [29], with WideResNet \(28\times 10\) [55] as the DNN. The optimization algorithm is SGD with \(\eta =0.1\), and the dropout rate is fixed to 0.3. Due to the long training time required by Deep Ensembles, we set \(N{_{\text {model}}} =15\). Comparative results on this dataset, presented in Table 2, lead to conclusions similar to those on MNIST.

4.4 Uncertainty Evaluation for Out-of-Distribution (OOD) Test Samples

In these experiments, we evaluate uncertainty on OOD classes. We consider four datasets, and the objective of these experiments is to evaluate to what extent the trained DNNs are overconfident on instances belonging to classes which are not present in the training set. We report results in Table 3.

Baselines. We compare against Deep Ensembles and MC Dropout, and propose two additional baselines. The first is the Maximum Classifier Prediction (MCP), which uses the maximum softmax value as prediction confidence and has shown competitive performance [19, 20]. The second baseline is meant to emphasize the ability of TRADI to capture the distribution of the weights. We take a trained network and randomly perturb its weights with noise sampled from a Normal distribution. In this way, we generate an ensemble of networks, each with a different noise perturbation; in effect, we sample networks from the vicinity of the local minimum. We refer to this baseline as the Gaussian perturbation ensemble.
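The Gaussian perturbation ensemble baseline can be sketched as follows. A toy linear softmax classifier stands in for the trained DNN, and the noise standard deviation `sigma`, the number of models and the seed are illustrative assumptions, since the values used in the experiments are not stated here. The MCP of the averaged prediction is returned as confidence.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def linear_logits(weights, bias, x):
    # weights[c][d]: weight from input feature d to class c (toy stand-in for a trained DNN)
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]

def perturb(weights, bias, sigma, rng):
    """Return a copy of the trained parameters with i.i.d. N(0, sigma^2) noise added."""
    new_w = [[w + rng.gauss(0.0, sigma) for w in row] for row in weights]
    new_b = [b + rng.gauss(0.0, sigma) for b in bias]
    return new_w, new_b

def gaussian_perturbation_ensemble(weights, bias, x, n_models=20, sigma=0.1, seed=0):
    """Average the softmax outputs of n_models noisy copies of one trained model.

    sigma, n_models and seed are illustrative assumptions.
    """
    rng = random.Random(seed)
    avg = [0.0] * len(bias)
    for _ in range(n_models):
        w, b = perturb(weights, bias, sigma, rng)
        p = softmax(linear_logits(w, b, x))
        avg = [a + pi / n_models for a, pi in zip(avg, p)]
    mcp = max(avg)  # Maximum Classifier Prediction of the averaged output
    return avg, mcp
```

Unlike TRADI, every member here is drawn from an isotropic Gaussian centered on a single solution, so the noise ignores which directions in weight space the training trajectory actually explored; this is the contrast the baseline is meant to expose.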

First, we apply the MNIST-trained DNNs to a test set composed of 10k MNIST images and 19k images from notMNIST [1], a dataset of letters from ten classes. Standard DNNs assign notMNIST letters to one of the digit classes with high confidence, as shown in [1]. For these OOD instances, our approach decreases the confidence, as illustrated in Fig. 3a, where we plot the accuracy vs confidence curves as in [31].

Table 3. Distinguishing in- and out-of-distribution data for semantic segmentation (CamVid, StreetHazards, BDD Anomaly) and image classification (MNIST/notMNIST)

The accuracy vs confidence curve is constructed by selecting, for each confidence threshold, all test samples for which the classifier reports a confidence above the threshold, and then evaluating the accuracy on these samples. The confidence of a DNN is defined as its maximum prediction score. We also evaluate the OOD uncertainty with the AUC, AUPR and FPR-95%-TPR metrics, following [20], and with the ECE metric, following [15]. The first three criteria characterize how well the confidence separates test samples that are OOD with respect to the training dataset from in-distribution ones, while ECE measures calibration. We also measured the training time of all algorithms, implemented in PyTorch, on a PC equipped with an Intel Core i9-9820X and one GeForce RTX 2080 Ti, and report it in Table 3. We note that, with 20 models, TRADI makes its incorrect predictions on these OOD samples with lower confidence than Deep Ensembles and MC Dropout.
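The evaluation criteria above can be sketched in plain Python. In this sketch, `confidences` holds the maximum prediction score of each test sample and `correct` flags whether its prediction is right. For the OOD metrics, in-distribution samples are treated as positives and the confidence as the score; this is one common convention, and others treat OOD samples as positives, notably for AUPR. The function names are hypothetical and ties are handled in the simplest way, so values may differ slightly from library implementations.

```python
import math

def accuracy_vs_confidence(confidences, correct, thresholds):
    """For each threshold tau, accuracy on the test samples with confidence >= tau.

    Returns (tau, accuracy) pairs; accuracy is None when no sample passes.
    """
    curve = []
    for tau in thresholds:
        kept = [ok for conf, ok in zip(confidences, correct) if conf >= tau]
        curve.append((tau, sum(kept) / len(kept) if kept else None))
    return curve

def auroc(in_conf, ood_conf):
    """Area under the ROC curve, computed as the probability that a random
    in-distribution sample is more confident than a random OOD one (ties count 1/2)."""
    wins = 0.0
    for a in in_conf:
        for b in ood_conf:
            wins += 1.0 if a > b else 0.5 if a == b else 0.0
    return wins / (len(in_conf) * len(ood_conf))

def fpr_at_95_tpr(in_conf, ood_conf):
    """Fraction of OOD samples accepted at the highest threshold that still
    accepts at least 95% of the in-distribution samples."""
    ranked = sorted(in_conf, reverse=True)
    k = math.ceil(0.95 * len(ranked))  # number of in-distribution samples to accept
    threshold = ranked[k - 1]
    return sum(b >= threshold for b in ood_conf) / len(ood_conf)

def aupr(in_conf, ood_conf):
    """Average precision with in-distribution samples as the positive class.

    On tied scores positives are ranked first, which is slightly optimistic.
    """
    scored = sorted([(c, 1) for c in in_conf] + [(c, 0) for c in ood_conf], reverse=True)
    tp, ap = 0, 0.0
    for rank, (_, label) in enumerate(scored, start=1):
        if label:
            tp += 1
            ap += tp / rank  # precision at each recall step
    return ap / len(in_conf)
```

The pairwise AUROC loop is quadratic in the number of samples and is meant only to make the definition explicit; for pixel-level segmentation benchmarks a sort-based implementation would be needed.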

Fig. 3. (a) and (b) Accuracy vs confidence plots for the MNIST/notMNIST and CamVid experiments, respectively. (c) Calibration plot for the CamVid experiment.

In the second experiment, we train an ENet [43] for semantic segmentation on the CamVid dataset [7]. During training, we remove three classes (pedestrian, bicycle, and car) by marking the corresponding pixels as unlabeled. We then test on data containing both the classes seen during training and the deleted ones. The goal of this experiment is to evaluate the DNN behavior on the deleted classes, which thus represent OOD classes. We refer to this setup as CamVid-OOD. In this experiment, we use \(N{_{\text {model}}}=10\) models trained for 90 epochs with SGD and a learning rate \(\eta =5\times 10^{-4}\). In Fig. 3b and 3c we show the accuracy vs confidence curves and the calibration curves [15] for the CamVid experiment. As explained in [15], the calibration curve is obtained by dividing the test set into bins of equal width according to confidence and computing the accuracy over each bin. Both the calibration and the accuracy vs confidence curves show whether the DNN predictions are reliable at different levels of confidence; however, the calibration curve gives a clearer picture of what happens at each score level.
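The calibration curve and the ECE can be sketched as below, assuming equal-width confidence bins as in [15]; for segmentation, each pixel counts as one test sample. `n_bins=10` is an illustrative default, the function names are hypothetical, and bin boundaries are assigned to the upper bin, a simplification with respect to [15].

```python
def reliability_bins(confidences, correct, n_bins=10):
    """Split [0, 1] into n_bins equal-width confidence bins and return, for each
    non-empty bin, (mean confidence, accuracy, number of samples)."""
    bins = [[] for _ in range(n_bins)]
    for conf, ok in zip(confidences, correct):
        idx = min(int(conf * n_bins), n_bins - 1)  # conf == 1.0 goes to the last bin
        bins[idx].append((conf, ok))
    stats = []
    for b in bins:
        if b:
            mean_conf = sum(c for c, _ in b) / len(b)
            acc = sum(ok for _, ok in b) / len(b)
            stats.append((mean_conf, acc, len(b)))
    return stats

def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE: sample-weighted average of |accuracy - mean confidence| over the bins."""
    n = len(confidences)
    return sum(cnt / n * abs(acc - conf)
               for conf, acc, cnt in reliability_bins(confidences, correct, n_bins))
```

Plotting the accuracy of each bin against its mean confidence gives the calibration curve; a perfectly calibrated model lies on the diagonal and has an ECE of zero.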

Finally, we conduct experiments on two recent OOD benchmarks for semantic segmentation, StreetHazards [19] and BDD Anomaly [19]. The former consists of 5,125/1,031/1,500 (train/test-in-distribution/test-OOD) synthetic images [9], annotated with 12 classes for training and a 13th OOD class that appears only in the test-OOD set. The latter is a subset of BDD [54] composed of 6,688/951/361 images, with the classes motorcycle and train as anomalous objects. We follow the experimental setup from [19], i.e., PSPNet [56] with a ResNet50 [18] backbone. On StreetHazards, TRADI outperforms Deep Ensembles, while on BDD Anomaly Deep Ensembles obtains the best results, closely followed by TRADI.

The results show that TRADI outperforms the alternative methods in terms of calibration and that it may provide more reliable confidence scores. Regarding accuracy vs confidence, the most significant results are those at high confidence levels, typically above 0.7, since they reveal how overconfident the network tends to be; in this range, our results are similar to those of Deep Ensembles. Lastly, in all experiments TRADI obtains AUPR and AUC values close to the best ones, while requiring significantly less training time than Deep Ensembles.

Qualitative Discussion. Fig. 4 shows an example scene featuring the three OOD classes of interest (bike, car, pedestrian). Overall, MC Dropout outputs a noisy uncertainty map but fails to highlight the OOD samples. Deep Ensembles, for its part, is overconfident, with higher uncertainty values mostly around object borders. TRADI's uncertainty is high on borders and also on pixels belonging to the actual OOD instances, as shown in the zoomed-in crop of the pedestrian in Fig. 4 (row 3).

Fig. 4. Qualitative results on CamVid-OOD. Columns: (a) input image and ground truth; (b)-(d) predictions and confidence scores by MC Dropout, Deep Ensembles, and TRADI. Rows: (1) input and confidence maps; (2) class predictions; (3) zoomed-in area on input and confidence maps.

5 Conclusion

In this work, we propose a novel technique for computing the epistemic uncertainty of a DNN. TRADI is conceptually simple and easy to plug into the optimization of any DNN architecture. We show the effectiveness of TRADI through extensive studies and compare it against the popular MC Dropout and the state-of-the-art Deep Ensembles. Our method performs strongly on evaluation metrics for uncertainty quantification and, in contrast to Deep Ensembles, whose training time grows with the number of models, it does not add any significant cost over conventional training.

Future work involves extending this strategy to new tasks, e.g., object detection, or to new settings, e.g., active learning. Another line of future research concerns transfer learning. So far, TRADI starts from randomly initialized weights sampled from a given Normal distribution. In transfer learning, we instead start from a pre-trained network whose weights are expected to follow a different distribution. If we had access to the distribution of the pre-trained DNN weights, we could improve the effectiveness of transfer learning with TRADI.