
1 Introduction

Learning algorithms often proceed by minimizing a loss function that measures the discrepancy between a data distribution and a model distribution. Given a parametric model and a metric in probability space, the loss can be minimized by the Riemannian gradient descent method, also known as the natural gradient method. An important metric in this context is the Fisher-Rao information metric [4, 18], which induces the Fisher-Rao natural gradient [1]. Another important metric is the Wasserstein metric [15, 20], which induces the Wasserstein natural gradient [8, 9, 12, 14]. Natural gradient methods have numerous applications in learning; see, e.g., [2, 3, 10, 13, 16, 17].

In spite of having numerous theoretical advantages, applying natural gradient methods is often challenging. In particular, machine learning models usually have many parameters, making the direct computation of the parameter updates too costly. Each update requires computing the Jacobian matrix of the model and the inverse of the metric tensor in parameter space. An alternative, implicit, way to formulate the update is via a proximal operator. Recently, [11] proposed proximal methods as an approach to natural gradients and demonstrated their viability in state-of-the-art generative modeling. The idea is to compute the proximity penalty in closed form over an approximation space, which results in a tractable iterative regularization for the parameter updates.

We develop this idea to obtain a general natural proximal method, and provide explicit formulas for the Fisher-Rao and the Wasserstein metrics. These serve three purposes: (i) The proximal operator and its approximation can enable efficient and effective expressions for the time discretized parameter updates of the natural gradient flow. (ii) The proximal method, as an implicit method, naturally regularizes the objective function, and can be used to optimize non-smooth objective functions. (iii) The metric regularization is expressed in terms of statistics, such as mean and variance, and can be estimated from samples.

2 Natural Proximal Gradient

We review the natural gradient flow in a statistical manifold with Wasserstein and Fisher-Rao metrics, present the natural proximal operators, and introduce a systematic approximation which is suitable for estimation from samples.

2.1 Natural Gradient Flows

Learning problems are often formulated as the minimization of a loss function, \(\min _{\theta \in \varTheta } F(\theta )\), where \(\varTheta \subseteq \mathbb {R}^d\) is the parameter space of the hypothesis class, and \(F:\varTheta \rightarrow \mathbb {R}\) is the loss function. As the hypothesis class, we consider a parametrized probability model \(\rho :\varTheta \rightarrow \mathcal {P}(\varOmega )\), where \(\varOmega \) is the sample space, i.e., the discrete or continuous set on which the distributions are supported. The loss is usually a divergence (sometimes a distance) function between the empirical data distribution \(\hat{\rho }_{\text {data}}\) and the model distribution \(\rho _\theta \).

To find a minimizer, the gradient flow approach is often considered. This flow follows the steepest descent direction of the loss function with respect to a given Riemannian metric. In general, this is defined by

$$\begin{aligned} \dot{\theta }(t)=-G(\theta (t))^{-1}\nabla _\theta F(\theta (t)), \end{aligned}$$
(1)

where \(G(\theta )\in \mathbb {R}^{d\times d}\) is the matrix representation of the Riemannian metric tensor (for our choice of coordinates), and \(\nabla _\theta = (\frac{\partial }{\partial \theta _1},\ldots , \frac{\partial }{\partial \theta _d})^\top \) is the standard (Euclidean) gradient operator. In the context of probability distributions, the metric \(G(\theta )\) is pulled back from a natural metric structure on probability space. This implies that for any choice of the parametrization, (1) defines the same flow of probability distributions. Hence it is said to be parametrization invariant.

We will focus on two important statistical metrics on probability space: the Wasserstein metric and the Fisher-Rao metric. These metrics induce the following metric tensors in parameter space. We write \((\cdot ,\cdot )\) for the Euclidean or \(L^2\) inner product on the sample space \(\varOmega \) (which might be continuous or discrete).

Definition 1

(Statistical metric tensor on parameter space). Consider the probability space \((\mathcal {P}(\varOmega ), g)\) with metric tensor g, and a smoothly parametrized probability model \(\rho _\theta \) with parameter \(\theta \in \varTheta \). Then the pull-back G of g is given by

$$\begin{aligned} G(\theta )=\Big (\nabla _\theta \rho _\theta , g(\rho _\theta ) \nabla _\theta \rho _\theta \Big ). \end{aligned}$$
  1. (i)

    If \(g_\theta =-(\varDelta _{\rho _\theta })^{-1}\), with \(\varDelta _{\rho _\theta }=\nabla \cdot (\rho _\theta \nabla )\) being the weighted elliptic operator [6, 7, 15], then \(G(\theta )\) is the Wasserstein metric tensor, given by

    $$\begin{aligned} G_W(\theta )_{ij}=\Big (\nabla _{\theta _i}\rho _{\theta }, (-\varDelta _{\rho _\theta })^{-1}\nabla _{\theta _j}\rho _\theta \Big ), \end{aligned}$$
  2. (ii)

    If \(g_\theta =\frac{1}{\rho _\theta }\), then \(G(\theta )\) is the Fisher-Rao metric tensor, given by

    $$\begin{aligned} G_{FR}(\theta )_{ij}=\Big (\nabla _{\theta _i}\rho _{\theta }, \frac{1}{\rho _\theta }\nabla _{\theta _j}\rho _\theta \Big ). \end{aligned}$$
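For a finite sample space, the Fisher-Rao tensor in (ii) reduces to a sum over states. The following is a minimal NumPy sketch, assuming the probabilities and their parameter derivatives are available as arrays (hypothetical inputs):

```python
import numpy as np

def fisher_rao_tensor(rho, drho):
    """Fisher-Rao metric tensor for a model over a finite sample space.

    rho  : (S,)   probabilities rho_theta(x) for the S states
    drho : (d, S) partial derivatives d rho_theta(x) / d theta_i
    Returns the (d, d) matrix with entries sum_x drho_i(x) drho_j(x) / rho(x).
    """
    return (drho / rho) @ drho.T
```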

Given a metric tensor on parameter space, the standard approach for numerical computation of the gradient flow (1) is the forward Euler method, i.e.,

$$\begin{aligned} \theta ^{k+1}=\theta ^k-h G(\theta ^k)^{-1}\nabla _\theta F(\theta ^k), \end{aligned}$$

where \(h>0\) is a step-size. This is known as the natural gradient descent method [2]. In practice, we need to compute the matrix \(G(\theta )\) and its inverse at each parameter update, which is difficult in high dimensional parameter spaces.
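To make this cost concrete, the following is a minimal NumPy sketch of one such update for the Fisher-Rao metric, assuming a hypothetical callable score(theta, x) that returns \(\nabla _\theta \log \rho _\theta (x)\); the metric is estimated by Monte Carlo, and the \(d\times d\) solve is the step that becomes prohibitive for large d.

```python
import numpy as np

def natural_gradient_step(theta, grad_F, score, samples, h=0.1, eps=1e-8):
    """One forward-Euler natural gradient step with the Fisher-Rao metric.

    theta   : (d,) current parameters
    grad_F  : (d,) Euclidean gradient of the loss at theta
    score   : callable, score(theta, x) -> (d,) gradient of log rho_theta(x)
    samples : list of samples x drawn from rho_theta
    """
    d = theta.shape[0]
    # Monte Carlo estimate of G_FR(theta) = E_theta[score score^T].
    G = np.zeros((d, d))
    for x in samples:
        s = score(theta, x)
        G += np.outer(s, s)
    G /= len(samples)
    # Forming and solving this d x d system is the costly step for large d.
    return theta - h * np.linalg.solve(G + eps * np.eye(d), grad_F)
```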

2.2 Natural Proximal Operators

We next present another way to approximate the gradient flow, known as the backward Euler or proximal operator method. The proximal operator is defined as

$$\begin{aligned} \theta ^{k+1}=\text {Prox}_{hF}(\theta ^k)=\arg \min _{\theta }~F(\theta )+\frac{D(\theta , \theta ^k)}{2h}, \end{aligned}$$
(2)

where D is a proximity term that penalizes deviation from the current point, and h adjusts its strength. In the limit \(h\rightarrow \infty \), the proximal operator returns the global minimizer of F. The proximity term is given by the metric function:

$$\begin{aligned} \begin{aligned} D(\theta , \theta ^k)=&\inf _{\theta (t)}\Big \{\int _0^1 \dot{\theta }(t)^{\top }G(\theta (t))\dot{\theta }(t)\,dt:\theta (0)=\theta ,~\theta (1)=\theta ^k\Big \}\\ =&\inf _{\theta (t)}\Big \{\int _0^1 \big (\partial _t\rho _{\theta (t)},g(\rho _{\theta (t)})\partial _t\rho _{\theta (t)}\big )\, dt:\theta (0)=\theta ,~\theta (1)=\theta ^k\Big \}. \end{aligned} \end{aligned}$$
(3)

Only in rare cases can the proximal operator (2) be written explicitly.

We shall approximate D in a way that allows for a more tractable computation of the proximal operator. Consider the iterative proximal update

$$\begin{aligned} \theta ^{k+1}=\arg \min _{\theta }~F(\theta )+\frac{1}{2h}\Big ( \rho _\theta -\rho _{\theta ^k} , g(\rho _{\tilde{\theta }})(\rho _\theta -\rho _{\theta ^k})\Big ), \end{aligned}$$
(4)

where \(\tilde{\theta }=\frac{\theta +\theta ^k}{2}\). Here the D term in (2) is replaced by a mid-point expression, which is exact up to the order \(o(\Vert \theta -\theta ^k\Vert ^2)\). This new proximal operator corresponds to a numerical method known as the semi-backward Euler method. Both (2) and (4) are time discretizations of (1) with first order accuracy. We shall focus on (4), and derive a tractable approximation of the regularization term.

3 Affine Space Approximation of the Metric

Consider the proximity term (similar to a squared Mahalanobis distance)

$$\begin{aligned} \tilde{D}(\theta , \theta ^k)=\Big ( \rho _\theta -\rho _{\theta ^k} , g(\rho _{\tilde{\theta }})(\rho _\theta -\rho _{\theta ^k})\Big ). \end{aligned}$$
(5)

In the following we derive an explicit and computer-friendly approximation. To this end, we first consider the variational formulation

$$\begin{aligned} \frac{1}{2}\tilde{D}(\theta , \theta ^k)=\sup _{\varPhi :\varOmega \rightarrow \mathbb {R}}(\varPhi , \rho _\theta -\rho _{\theta ^k})-\frac{1}{2}\Big ( \varPhi , g(\rho _{\tilde{\theta }})^{\dagger }\varPhi \Big ), \end{aligned}$$
(6)

where \({}^{\dagger }\) denotes the pseudo-inverse, and the maximizer \(\varPhi =g(\rho _{\tilde{\theta }})(\rho _\theta -\rho _{\theta ^k})\) recovers the previous formula. This corresponds to expressing (5) in terms of its Legendre dual, which pairs the tangent and cotangent spaces of probability space; for a discussion see [7].

Now we restrict the optimization domain (i.e., the set of functions \(\varPhi :\varOmega \rightarrow \mathbb {R}\)) to an affine space of functions of the form

$$\begin{aligned} \mathcal {F}_\varPsi =\Big \{\varPhi (x) = \sum _{j=1}^n\xi _j \psi _j(x)=\xi ^{\top }\varPsi (x) :\xi \in \mathbb {R}^n\Big \}, \end{aligned}$$

where \(\xi =(\xi _j)_{j=1}^n\) is a parameter vector and \(\varPsi =(\psi _j)_{j=1}^n\) collects a choice of basis functions \(\psi _j:\varOmega \rightarrow \mathbb {R}\). This results in the following optimization problems:

  1. (i)

    For the Wasserstein metric, we have

    $$\begin{aligned} \begin{aligned} \frac{1}{2}\tilde{D}^{W}_\varPsi (\theta , \theta ^k) =&\sup _{\varPhi = \xi ^\top \varPsi } \mathbb {E}_{\theta }[\varPhi ]-\mathbb {E}_{\theta ^k}[\varPhi ]- \frac{1}{2} \mathbb {E}_{\tilde{\theta }}[\Vert \nabla \varPhi \Vert ^2]; \end{aligned} \end{aligned}$$
  2. (ii)

    For the Fisher-Rao metric, we have

    $$\begin{aligned} \begin{aligned} \frac{1}{2}\tilde{D}^{FR}_\varPsi (\theta ,\theta ^k) =&\sup _{\varPhi = \xi ^\top \varPsi } \mathbb {E}_{\theta }[\varPhi ]-\mathbb {E}_{\theta ^k}[\varPhi ]- \frac{1}{2} \mathbb {E}_{\tilde{\theta }}\Big [(\varPhi -\mathbb {E}_{\tilde{\theta }}[\varPhi ])^2\Big ]. \end{aligned} \end{aligned}$$

These are quadratic programs in \(\xi \) with negative semi-definite quadratic part. In practice, when using small-sample estimates for the expectations, one can add a regularization \(-\lambda \Vert \xi \Vert ^2\), with a small \(\lambda >0\), to ensure strict concavity and a unique solution. We proceed to solve these problems. We write \(\mathbb {E}_\theta [\psi ] = \mathbb {E}_{x\sim \rho _\theta }[\psi (x)]\) and \(\partial _l = \frac{\partial }{\partial x_l}\) for the partial derivative w.r.t. the l-th sample space variable.
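For illustration, with \(b=\mathbb {E}_{\theta }[\varPsi ]-\mathbb {E}_{\theta ^k}[\varPsi ]\) and M a sample estimate of \((\varPsi , g(\rho _{\tilde{\theta }})^{\dagger }\varPsi )\), the regularized problem \(\sup _\xi \xi ^\top b-\frac{1}{2}\xi ^\top M\xi -\lambda \Vert \xi \Vert ^2\) has the closed-form maximizer \(\xi ^*=(M+2\lambda I)^{-1}b\) and optimal value \(\frac{1}{2}b^\top \xi ^*\). A minimal sketch (hypothetical helper):

```python
import numpy as np

def regularized_proximity(b, M, lam=1e-6):
    """Closed-form value of sup_xi xi^T b - 0.5 xi^T M xi - lam ||xi||^2.

    b : (n,) moment difference E_theta[Psi] - E_theta^k[Psi]
    M : (n, n) positive semi-definite sample estimate of (Psi, g(rho)^+ Psi)
    Returns an estimate of 0.5 * D_tilde(theta, theta^k).
    """
    xi = np.linalg.solve(M + 2.0 * lam * np.eye(len(b)), b)   # maximizer
    return 0.5 * b @ xi                                       # optimal value
```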

Theorem 1

(Affine space approximation). Given a basis \(\varPsi \), the proximity term \(\tilde{D}\) within the affine function space \(\mathcal {F}_\varPsi =\{\xi ^\top \varPsi :\xi \in \mathbb {R}^n\}\) is given by

$$\begin{aligned} \tilde{D}_\varPsi (\theta , \theta ^k) = (\mathbb {E}_{\theta }[\varPsi ] -\mathbb {E}_{{\theta ^k}}[\varPsi ])^\top \Big (\varPsi , g(\rho _{\tilde{\theta }})^{\dagger }\varPsi \Big )^{\dagger }(\mathbb {E}_{\theta }[\varPsi ] - \mathbb {E}_{{\theta ^k}}[\varPsi ]). \end{aligned}$$
  1. (i)

    For the Wasserstein metric, we have

    $$\begin{aligned} \tilde{D}^W_\varPsi (\theta , \theta ^k) = (\mathbb {E}_{\theta }[\varPsi ] - \mathbb {E}_{{\theta ^k}}[\varPsi ])^\top \Big (\mathfrak {C}^W(\tilde{\theta })\Big )^{-1}(\mathbb {E}_{\theta }[\varPsi ] - \mathbb {E}_{{\theta ^k}}[\varPsi ]), \end{aligned}$$

    where \(\mathfrak {C}^W(\tilde{\theta }) = \mathbb {E}_{{\tilde{\theta }}}[\sum _l \Big (\partial _{l} \varPsi \Big ) \Big (\partial _{l}\varPsi \Big )^\top ]\).

  2. (ii)

    For the Fisher-Rao metric, we have

    $$\begin{aligned} \tilde{D}^{FR}_\varPsi (\theta , \theta ^k) = (\mathbb {E}_{\theta }[\varPsi ] - \mathbb {E}_{{\theta ^k}}[\varPsi ])^\top \Big (\mathfrak {C}^{FR}(\tilde{\theta }) \Big )^{-1}(\mathbb {E}_{\theta }[\varPsi ] -\mathbb {E}_{{\theta ^k}}[\varPsi ]), \end{aligned}$$

    where \(\mathfrak {C}^{FR}(\tilde{\theta }) = \mathbb {E}_{{\tilde{\theta }}}[\Big (\varPsi (x)-\mathbb {E}_{{\tilde{\theta }}}[\varPsi ]\Big ) \Big (\varPsi (x)-\mathbb {E}_{{\tilde{\theta }}}[\varPsi ]\Big )^\top ]\).

Fig. 1. Illustration of the proximity term over an affine space. Intuitively, the metric between two distributions is measured along a chosen set of statistics.

Remark 1

The matrix \(\mathfrak {C}\) has size \(n\times n\), corresponding to the dimension of \(\varPsi \). For the Fisher-Rao metric, it is the covariance matrix of the basis functions \(\varPsi \) w.r.t. \(\rho _{\tilde{\theta }}\), and it coincides with the Fisher-Rao matrix when the basis is a sufficient statistic of the model. See Fig. 1. The resulting metric bears a similarity to the Relative Fisher Information Metric approach proposed in [19]. Similar observations apply for the Wasserstein metric.
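A minimal sketch of how these matrices can be estimated from samples of \(\rho _{\tilde{\theta }}\), assuming the basis \(\varPsi \) and its Jacobian w.r.t. x are available as (hypothetical) callables:

```python
import numpy as np

def wasserstein_matrix(grad_Psi, samples):
    """Monte Carlo estimate of C^W = E[ sum_l (d_l Psi)(d_l Psi)^T ].

    grad_Psi : callable, grad_Psi(x) -> (n, m) Jacobian of the basis Psi at x
    samples  : list of samples x drawn from rho_tilde
    """
    J = np.stack([grad_Psi(x) for x in samples])          # shape (N, n, m)
    return np.einsum('kil,kjl->ij', J, J) / len(samples)

def fisher_rao_matrix(Psi, samples):
    """Monte Carlo estimate of C^FR = Cov[Psi] under rho_tilde (1/N normalization)."""
    vals = np.stack([Psi(x) for x in samples])            # shape (N, n)
    centered = vals - vals.mean(axis=0)
    return centered.T @ centered / len(samples)
```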

Remark 2

In the case of implicit generative models (used in GANs), where \(\rho _\theta \) is expressed as the push-forward measure of a latent variable z by a parametrized family of functions \(\mathfrak {g}_\theta \), we obtain

$$\begin{aligned} \tilde{D}(\theta , \theta ^k)=(\mathbb {E}_z[\varPsi (\mathfrak {g}_\theta (z))] -\mathbb {E}_{z} [\varPsi (\mathfrak {g}_{\theta ^k}(z))])^\top \mathbb {E}_{z}[C(\mathfrak {g}_{\tilde{\theta }}(z))]^{-1}(\mathbb {E}_z[\varPsi (\mathfrak {g}_\theta (z))]-\mathbb {E}_{z} [\varPsi (\mathfrak {g}_{\theta ^k}(z))]), \end{aligned}$$

where C is the corresponding term inside the expectation in Theorem 1.
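As an illustration of this formula in the Wasserstein case, the following sketch estimates \(\tilde{D}^W_\varPsi \) from latent samples only; the generator callables and helper names are assumptions made for the example.

```python
import numpy as np

def wasserstein_proximity_implicit(Psi, grad_Psi, g_new, g_old, g_mid,
                                   z_samples, lam=1e-6):
    """Estimate of D^W_Psi for a push-forward model, using latent samples only.

    Psi      : callable, Psi(x) -> (n,) basis statistics
    grad_Psi : callable, grad_Psi(x) -> (n, m) Jacobian of Psi w.r.t. x
    g_new, g_old, g_mid : generators at theta, theta^k and the mid-point
    z_samples : (N, latent_dim) array of latent samples
    """
    E_new = np.mean([Psi(g_new(z)) for z in z_samples], axis=0)
    E_old = np.mean([Psi(g_old(z)) for z in z_samples], axis=0)
    b = E_new - E_old
    J = np.stack([grad_Psi(g_mid(z)) for z in z_samples])   # (N, n, m)
    C = np.einsum('kil,kjl->ij', J, J) / len(z_samples)
    # Small ridge term for numerical stability with few samples.
    return b @ np.linalg.solve(C + lam * np.eye(len(b)), b)
```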

Proof

(i) For the Wasserstein metric, with \(\varPhi \) constrained to the affine space \(\mathcal {F}_\varPsi \), the gradient of \(\varPhi \) w.r.t. the sample space variable x is \(\nabla \varPhi (x) = \big (\sum _{i=1}^n \xi _i \partial _l\psi _i(x)\big )_{l=1}^m\). The squared norm is then

$$\begin{aligned} \Vert \nabla \varPhi (x)\Vert ^2 = \sum _l (\sum _i \xi _i \partial _l \psi _i(x))^2 = \sum _l \sum _i \xi _i \partial _l \psi _i(x) \sum _j \xi _j \partial _l \psi _j(x) = \xi ^\top C^W(x) \xi , \end{aligned}$$

where \(C^W_{ij}(x) = \sum _l\partial _l\psi _i(x) \partial _l\psi _j(x)\). Now we consider the distance

$$\begin{aligned} \begin{aligned} \frac{1}{2}\tilde{D}^W_\varPsi (\theta , \theta ^k) =&\sup _{\varPhi = \xi ^\top \varPsi } \Big ( \varPhi , \rho _\theta -\rho _{\theta ^k}\Big ) - \frac{1}{2} \Big ( \Vert \nabla \varPhi \Vert ^2 , \rho _{\tilde{\theta }}\Big ) \\ =&\sup _{\xi } \xi ^\top (\mathbb {E}_{\theta }[\varPsi ] - \mathbb {E}_{{\theta ^k}}[\varPsi ]) -\frac{1}{2} \xi ^\top \mathbb {E}_{\tilde{\theta }}[C^W] \xi . \end{aligned} \end{aligned}$$

In turn, by first order optimality conditions, at the maximizer we have

$$\begin{aligned} \xi ^*=(\mathbb {E}_{\tilde{\theta }}[C^W])^{-1} (\mathbb {E}_{\theta }[\varPsi ] - \mathbb {E}_{{\theta ^k}}[\varPsi ]). \end{aligned}$$

Thus \(\tilde{D}^W_\varPsi (\theta , \theta ^k) = (\mathbb {E}_{\theta }[\varPsi ] -\mathbb {E}_{{\theta ^k}}[\varPsi ])^\top (\mathbb {E}_{\tilde{\theta }}[C^W])^{-1} (\mathbb {E}_{\theta }[\varPsi ] -\mathbb {E}_{{\theta ^k}}[\varPsi ])\).

(ii) For the Fisher-Rao metric, the term \(\big (\varPhi (x)- \mathbb {E}_{{\tilde{\theta }}}[\varPhi ]\big )^2\) equals

$$\begin{aligned} \big (\xi ^\top \varPsi (x)- \xi ^\top \mathbb {E}_{{\tilde{\theta }}}[\varPsi ]\big )^2 = \xi ^{\top }(\varPsi (x)-\mathbb {E}_{{\tilde{\theta }}}[\varPsi ])(\varPsi (x)-\mathbb {E}_{{\tilde{\theta }}}[\varPsi ])^{\top }\xi = \xi ^\top C^{FR}(x)\xi , \end{aligned}$$

where \(C^{FR}(x) = (\varPsi (x)-\mathbb {E}_{{\tilde{\theta }}}[\varPsi ])(\varPsi (x)-\mathbb {E}_{{\tilde{\theta }}}[\varPsi ])^{\top }\). The optimization over \(\xi \) then proceeds exactly as in (i), with \(\mathbb {E}_{\tilde{\theta }}[C^{FR}]\) in place of \(\mathbb {E}_{\tilde{\theta }}[C^W]\).    \(\square \)

Example 1

(Order 1 approximation). For the metric approximation with the space of linear functions, \(\mathcal {F}_1=\Big \{\varPhi (x)=a^\top x+b :a\in \mathbb {R}^m,~b\in \mathbb {R} \Big \}\), we have:

  1. (i)
    $$\begin{aligned} \tilde{D}^W_1(\theta , \theta ^k)=(\mathbb {E}_\theta [x]-\mathbb {E}_{\theta ^k}[x])^{\top }(\mathbb {E}_\theta [x]-\mathbb {E}_{\theta ^k}[x]). \end{aligned}$$
  2. (ii)
    $$\begin{aligned} \tilde{D}^{FR}_1(\theta , \theta ^k)=(\mathbb {E}_\theta [x] - \mathbb {E}_{\theta ^k}[x])^{\top }\Big (\mathbb {E}_{\tilde{\theta }}\Big [(x-\mathbb {E}_{\tilde{\theta }}x)(x-\mathbb {E}_{\tilde{\theta }}x)^\top \Big ]\Big )^{-1}(\mathbb {E}_\theta [x] - \mathbb {E}_{\theta ^k}[x]). \end{aligned}$$
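For instance, given samples from \(\rho _\theta \), \(\rho _{\theta ^k}\) and \(\rho _{\tilde{\theta }}\), the two order 1 proximities can be estimated as follows (a minimal sketch; \(\lambda \) is the small regularizer mentioned above):

```python
import numpy as np

def order1_proximities(x_new, x_old, x_mid, lam=1e-6):
    """Order 1 proximity terms of Example 1, estimated from samples.

    x_new, x_old, x_mid : (N, m) samples from rho_theta, rho_theta^k and rho_mid
    """
    b = x_new.mean(axis=0) - x_old.mean(axis=0)
    d_wasserstein = b @ b                                  # squared distance of means
    cov = np.cov(x_mid, rowvar=False, bias=True) + lam * np.eye(len(b))
    d_fisher_rao = b @ np.linalg.solve(cov, b)             # Mahalanobis-type distance
    return d_wasserstein, d_fisher_rao
```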

Example 2

(Order 2 approximation). For the space of quadratic functions, \(\mathcal {F}_2=\Big \{\varPhi (x)=\frac{1}{2}x^\top Q x+a^\top x+b:Q \in \mathbb {R}^{m\times m},~a\in \mathbb {R}^m,~b\in \mathbb {R}\Big \}\), we have:

  1. (i)
    $$\begin{aligned} \tilde{D}^W_2(\theta , \theta ^k)= \Big (\mathbb {E}_\theta \left[ \!{\begin{matrix}x\\ \frac{x\otimes x}{2}\end{matrix}}\!\right] -\mathbb {E}_{\theta ^k} \left[ \!{\begin{matrix}x\\ \frac{x\otimes x}{2}\end{matrix}}\!\right] \Big )^{\top } \mathbb {E}_{\tilde{\theta }}\left[ \!{\begin{matrix} I_m &{} x^\top \otimes I_m\\ x \otimes I_m &{} I_m \otimes xx^\top \end{matrix}}\!\right] ^{-1} \Big (\mathbb {E}_\theta \left[ \!{\begin{matrix}x\\ \frac{x\otimes x}{2}\end{matrix}}\!\right] -\mathbb {E}_{\theta ^k} \left[ \!{\begin{matrix}x\\ \frac{x\otimes x}{2}\end{matrix}}\!\right] \Big ) . \end{aligned}$$
  2. (ii)
    $$\begin{aligned} \tilde{D}^{FR}_2(\theta , \theta ^k)= \Big (\mathbb {E}_\theta \left[ \!{\begin{matrix}x\\ \frac{x\otimes x}{2}\end{matrix}}\!\right] -\mathbb {E}_{\theta ^k} \left[ \!{\begin{matrix}x\\ \frac{x\otimes x}{2}\end{matrix}}\!\right] \Big )^{\top } \Big (\mathfrak {C}^{FR}(\tilde{\theta }) \Big )^{-1} \Big (\mathbb {E}_\theta \left[ \!{\begin{matrix}x\\ \frac{x\otimes x}{2}\end{matrix}}\!\right] -\mathbb {E}_{\theta ^k} \left[ \!{\begin{matrix}x\\ \frac{x\otimes x}{2}\end{matrix}}\!\right] \Big ), \end{aligned}$$

    where \(\otimes \) is the Kronecker product (e.g., \(x\otimes x\) is an \(m^2\times 1\) vector), and

    $$\begin{aligned} \mathfrak {C}^{FR}= \mathbb {E}_{\tilde{\theta }}\begin{bmatrix} \Big (\left[ \!{\begin{matrix}x\\ \frac{x\otimes x}{2}\end{matrix}}\!\right] -\mathbb {E}_{\tilde{\theta }} \left[ \!{\begin{matrix}x\\ \frac{x\otimes x}{2}\end{matrix}}\!\right] \Big ) \Big (\left[ \!{\begin{matrix}x\\ \frac{x\otimes x}{2}\end{matrix}}\!\right] -\mathbb {E}_{\tilde{\theta }} \left[ \!{\begin{matrix}x\\ \frac{x\otimes x}{2}\end{matrix}}\!\right] \Big )^\top \end{bmatrix}. \end{aligned}$$
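The order 2 statistics and their Jacobian w.r.t. x, which can be plugged into the sample estimators sketched after Remark 1, can be written as follows (a minimal sketch):

```python
import numpy as np

def psi_order2(x):
    """Order 2 statistics Psi(x) = [x, (x kron x)/2] from Example 2."""
    x = np.asarray(x, dtype=float)
    return np.concatenate([x, 0.5 * np.kron(x, x)])

def grad_psi_order2(x):
    """Jacobian of psi_order2 w.r.t. x, of shape (m + m^2, m)."""
    x = np.asarray(x, dtype=float)
    m = len(x)
    I = np.eye(m)
    # Row i*m + j of the lower block is d/dx [x_i x_j / 2].
    lower = 0.5 * (np.kron(I, x.reshape(m, 1)) + np.kron(x.reshape(m, 1), I))
    return np.vstack([I, lower])
```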

4 Numerical Examples

The optimization loop can be implemented as shown in Algorithm 1. Here the proximal operator is computed by a short gradient iteration. In practice we can replace the expectations by sample averages, \(\mathbb {E}_\theta [f]\approx \frac{1}{N}\sum _{i=1}^Nf(x^{(i)})\), with \(x^{(i)}\) drawn i.i.d. from \(\rho _\theta \). For the basis \(\varPsi \) we can choose low order polynomials, as in Examples 1 and 2, but even random functions worked well in our experiments. The optimal choice balances a low dimension against statistics that are relevant for the model under consideration. Orthogonality of the basis functions tends to be beneficial.

Algorithm 1. Natural proximal optimization loop.
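A rough sketch of this loop is given below; the callables grad_loss and grad_proximity are assumptions standing for the Euclidean gradient of F and for the gradient of the sample-based proximity term, and the inner iteration plays the role of the short gradient iteration mentioned above.

```python
import numpy as np

def natural_proximal_descent(theta0, grad_loss, grad_proximity,
                             n_outer=100, n_inner=10, h=0.1, alpha=1e-2):
    """Sketch of the natural proximal optimization loop.

    grad_loss(theta)               : Euclidean gradient of F at theta
    grad_proximity(theta, theta_k) : gradient in theta of the sample-based
                                     proximity term D(theta, theta_k)
    """
    theta_k = np.array(theta0, dtype=float)
    for _ in range(n_outer):
        theta = theta_k.copy()
        # Short inner gradient iteration on F(theta) + D(theta, theta_k) / (2h).
        for _ in range(n_inner):
            g = grad_loss(theta) + grad_proximity(theta, theta_k) / (2.0 * h)
            theta = theta - alpha * g
        theta_k = theta
    return theta_k
```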

4.1 Maximum Likelihood Estimation for Hierarchical Models

We consider binary k-interaction models, which are exponential families \(\rho _\theta (x) = \exp (\theta ^\top A(x))/Z(\theta )\), \(x\in \{0,1\}^m\), with sufficient statistics \(A_\lambda (x) = \prod _{i\in \lambda }(-1)^{x_i}\), for \(\lambda \subseteq \{1,\ldots , m\}\), \(|\lambda | \leqslant k\). We use the basis \(\varPsi _j(x)=(-1)^{x_j}\), \(j\in \{1,\ldots , m\}\), i.e., the sufficient statistics of the 1-interaction (independence) model. We draw target distributions uniformly from the probability simplex and compute the maximum likelihood estimators (MLEs). We compare Euclidean, Fisher-Rao, and Wasserstein gradient descent, as well as the corresponding proximal methods. For each problem and method we run a grid search over the step size \(\alpha \) and the proximal strength h, which are kept fixed during optimization. The results are shown in Fig. 2.
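For reference, the statistics used here can be written as follows (a small sketch; the enumeration assumes nonempty \(\lambda \)):

```python
import numpy as np
from itertools import combinations

def psi_independence(x):
    """Basis Psi_j(x) = (-1)^{x_j}, j = 1,...,m, for binary x in {0,1}^m."""
    return (-1.0) ** np.asarray(x)

def k_interaction_statistics(x, k):
    """Statistics A_lambda(x) = prod_{i in lambda} (-1)^{x_i},
    for nonempty lambda with |lambda| <= k."""
    s = (-1.0) ** np.asarray(x)
    return np.array([np.prod(s[list(lam)])
                     for size in range(1, k + 1)
                     for lam in combinations(range(len(x)), size)])
```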

Fig. 2. Left: MLE wall-clock computation times until the KL-divergence is within \(10^{-9}\) of optimal, for 4 binary variables and \(\varPsi \) the independence model, together with typical optimization curves. Right: Learning curves for the image classification task on CIFAR-10. Each experiment was averaged over 5 runs; the bold lines represent the average, and the envelopes the minimum and maximum achieved.

4.2 Classification on CIFAR-10

Here we present an image classification task on the CIFAR-10 dataset [5] using the Wasserstein proximal method. We use a simple CNN with two convolutional layers followed by two fully-connected layers, with ReLU activations. In this experiment, F is the categorical cross-entropy loss and \(D = \tilde{D}_{\varPsi }^W\) is the Order 1 or Order 2 Wasserstein approximation. The specific details of our experiments can be found in Appendix A. Figure 2 shows the results as curves of the validation error per epoch. As a baseline, we also report results obtained by performing the SGD update several times per epoch, but without the proximity regularization. The best result is obtained with the Order 2 Wasserstein approximation.

5 Discussion

We studied sampling-friendly implementations of the natural gradient method based on the proximal operator. We approximate the proximity penalty by an affine space restriction in the Legendre dual formulation. This gives rise to a lower dimensional metric, expressed in expectation parameters, which can be estimated from samples. We cover both the Fisher-Rao and the Wasserstein metrics. Especially for the Wasserstein proximal method, our approach offers significant savings in computation time and provides improvements in validation error (in CIFAR-10 classification).