1 Introduction

The statistical distance between histograms plays a fundamental role in statistics and machine learning. It provides the geometric structure on statistical manifolds [3]. Learning problems usually correspond to minimizing a loss function over these manifolds. An important example is the Fisher–Rao metric on the probability simplex, which has been studied especially within the field of information geometry [3, 6]. A classic result due to Chentsov [11] characterizes this Riemannian metric as the only one, up to scaling, that is invariant with respect to natural statistical embeddings by Markov morphisms (see also [9, 21, 32]). Based on the Fisher–Rao metric, a natural Riemannian gradient descent method was introduced in [2]. This natural gradient has found numerous successful applications in machine learning (see, e.g., [1, 27, 35, 36, 40]).

Optimal transport provides another statistical distance, named the Wasserstein or Earth Mover’s distance. In recent years, this metric has attracted increasing attention within the machine learning community [5, 16, 31]. One distinct feature of optimal transport is that it provides a distance between histograms that incorporates a ground metric on sample space. The \(L^2\)-Wasserstein distance has a dynamical formulation, which exhibits a metric tensor structure. The set of probability densities with this metric forms an infinite-dimensional Riemannian manifold, named the density manifold [20]. The gradient descent method in the density manifold, called the Wasserstein gradient flow, has been widely studied in the literature; see [34, 38] and references therein.

A question intersecting optimal transport and information geometry arises: What is the natural Wasserstein gradient descent method on the parameter space of a statistical model? In optimal transport, the Wasserstein gradient flow is studied on the full space of probability densities, and it has been shown to have deep connections with the ground metrics on sample space arising from physics [33], fluid mechanics [10] and differential geometry [25]. We expect that these relations also exist on parametrized probability models, and that the Wasserstein gradient flow can be useful in the optimization of objective functions that arise in machine learning problems. By incorporating a ground metric on sample space, this method can serve to implement useful priors in learning algorithms.

We are interested in developing synergies between the information geometry and optimal transport communities. In this paper, we take a natural first step in this direction. We introduce the Wasserstein natural gradient flow on the parameter space of probability models with discrete sample spaces. The \(L^2\)-Wasserstein metric on discrete states was introduced in [12, 26, 29]. Following the settings of [13, 14, 17, 22], the probability simplex then becomes a Riemannian manifold, called the Wasserstein probability manifold. The Wasserstein metric on the probability simplex can be pulled back to the parameter space of a probability model. This metric allows us to define a natural Wasserstein gradient method on parameter space.

We note that one finds several formulations of optimal transport for continuous sample spaces. On the one hand, there is the static formulation, known as Kantorovich’s linear programming [38]. Here, the linear program is to find the minimal value of a functional over the set of joint measures with given marginal histograms. The objective functional is given as the expectation value of the ground metric with respect to a joint probability density measure. On the other hand, there is the dynamical formulation, known as the Benamou-Brenier formula [8]. This dynamic formulation gives the metric tensor for measures by lifting the ground metric tensor of sample spaces. Both static and dynamic formulations are equivalent in the case of continuous state spaces. However, the two formulations lead to different metrics on the simplex of discrete probability distributions. The major reason for this difference is that the discrete sample space is not a length space. Thus the equivalence result in classical optimal transport is no longer true in the setting of discrete sample spaces. We note that for the static formulation, there is no Riemannian metric tensor for the discrete probability simplex. See [14, 26] for a detailed discussion.

In the literature, the exploration of connections between optimal transport and information geometry was initiated in [4, 18, 39]. These works focus on the distance function induced by linear programming on discrete sample spaces. As we pointed out above, this approach cannot capture the Riemannian and differential structures induced by optimal transport. In this paper, we use the dynamical formulation of optimal transport to define a Riemannian metric structure for general statistical manifolds. With this, we obtain a natural gradient operator, which can be applied to any optimization problem over a parameterized statistical model. In particular, it is applicable to maximum likelihood estimation. Other works have studied the Gaussian family of distributions with the \(L^2\)-Wasserstein metric [30, 37]. In that particular case, the constrained optimal transport metric tensor can be written explicitly and the corresponding density submanifold is a totally geodesic submanifold. In contrast to those works, our discussion is applicable to arbitrary parametric models.

This paper is organized as follows. In Sect. 2 we briefly review the Riemannian manifold structure in probability space introduced by optimal transport in the cases of continuous and discrete sample spaces. In Sect. 3 we introduce Wasserstein statistical manifolds by isometric embedding into the probability manifold, and in Sect. 4 we derive the corresponding gradient flows. In Sect. 5 we discuss a few examples.

2 Optimal transport on continuous and discrete sample spaces

In this section, we briefly review the results of optimal transport. We introduce the corresponding Riemannian structure for simplices of probability distributions with discrete support.

2.1 Optimal transport on continuous sample space

We start with a review of the optimal transport problem on continuous spaces. This will guide our discussion of the discrete state case. For related studies, we refer the reader to [20, 38] and the many references therein.

Denote the sample space by \((\Omega , g^\Omega )\). Here \(\Omega \) is a finite dimensional smooth Riemannian manifold, for example, \(\mathbb {R}^d\) or the open unit ball therein. Its inner product is denoted by \(g^\Omega \) and its volume form by dx. Denote the geodesic distance of \(\Omega \) by \(d_\Omega :\Omega \times \Omega \rightarrow \mathbb {R}_+\).

Consider the set \(\mathcal {P}_2(\Omega )\) of Borel measurable probability density functions on \(\Omega \) with finite second moment. Given \(\rho ^0, \rho ^1\in \mathcal {P}_2(\Omega )\), the \(L^2\)-Wasserstein distance between \(\rho ^0\) and \(\rho ^1\) is denoted by \(W:\mathcal {P}_2(\Omega )\times \mathcal {P}_2(\Omega )\rightarrow \mathbb {R}_+\). There are two equivalent ways of defining this distance. On the one hand, there is the static formulation. This refers to the following linear programming problem:

$$\begin{aligned} W(\rho ^0, \rho ^1)^2=\inf _{\pi \in \Pi (\rho ^0, \rho ^1)}\int _{\Omega \times \Omega }d_\Omega (x,y)^2\pi (dx,dy), \end{aligned}$$
(1)

where the infimum is taken over the set \(\Pi (\rho ^0, \rho ^1)\) of joint probability measures on \(\Omega \times \Omega \) that have marginals \(\rho ^0\), \(\rho ^1\).

On the other hand, the Wasserstein distance W can be written in a dynamic formulation, where a probability path \(\rho :[0,1]\rightarrow \mathcal {P}_2(\Omega )\) connecting \(\rho ^0\), \(\rho ^1\) is considered. This refers to a variational problem known as the Benamou-Brenier formula:

$$\begin{aligned} W(\rho ^0, \rho ^1)^2=\inf _{\Phi }~\int _0^1\int _{\Omega } g^\Omega _x(\nabla \Phi (t,x), \nabla \Phi (t,x))\rho (t,x) dx dt, \end{aligned}$$
(2a)

where the infimum is taken over the set of Borel potential functions \([0,1]\times \Omega \rightarrow \mathbb {R}\). Each potential function \(\Phi \) determines a corresponding density path \(\rho \) as the solution of the continuity equation

$$\begin{aligned} \frac{\partial \rho (t,x)}{\partial t}+\text {div} (\rho (t,x)\nabla \Phi (t,x))=0,\quad \rho (0,x)=\rho ^0(x),\quad \rho (1,x)=\rho ^1(x). \end{aligned}$$
(2b)

Here \(\text {div}\) and \(\nabla \) are the divergence and gradient operators in \(\Omega \). The continuity equation is well known in physics.

The equivalence of the static (1) and dynamic (2) formulations is well known (for continuous \(\Omega \)). For the reader’s convenience we give a sketch of the proof in the appendix. In this paper we focus on the variational formulation (2). In fact, this formulation entails the definition of a Riemannian structure, as we now discuss. For simplicity, we only consider the set of smooth and strictly positive probability densities

$$\begin{aligned} \mathcal {P}_+(\Omega )=\Big \{\rho \in C^{\infty }(\Omega ):\rho (x)>0,~\int _{\Omega }\rho (x)dx=1\Big \} \subset \mathcal {P}_2(\Omega ). \end{aligned}$$

Denote by \( \mathcal {F}(\Omega ):=C^{\infty }(\Omega )\) the set of smooth real-valued functions on \(\Omega \). The tangent space of \(\mathcal {P}_+(\Omega )\) is given by

$$\begin{aligned} T_\rho \mathcal {P}_+(\Omega ) = \Big \{\sigma \in \mathcal {F}(\Omega ):\int _{\Omega }\sigma (x) dx=0 \Big \}. \end{aligned}$$

Given \(\Phi \in \mathcal {F}(\Omega )\) and \(\rho \in \mathcal {P}_+(\Omega )\), define

$$\begin{aligned} V_{\Phi }(x):=-\text {div} (\rho (x) \nabla \Phi (x)). \end{aligned}$$

We assume the zero flux condition

$$\begin{aligned} \int _{\Omega }V_\Phi (x)dx =0. \end{aligned}$$

In view of the continuity equation, the zero flux condition is equivalent to requiring that \(\int _\Omega \frac{\partial \rho }{\partial t}dx = 0\), which means that the space integral of \(\rho \) is always 1. When \(\Omega \) is compact without boundary, this is automatically satisfied. This is also true when \(\Omega = \mathbb {R}^d \) and \(\rho \) has finite second moment. Thus \(V_\Phi \in T_{\rho }\mathcal {P}_+(\Omega )\). The elliptic operator \(\nabla \cdot (\rho \nabla )\) identifies the function \(\Phi \) on \(\Omega \) modulo additive constants with a tangent vector \(V_{\Phi }\) of the space of densities (for more details see [20, 25]). This gives an isomorphism

$$\begin{aligned} \mathcal {F}(\Omega )/\mathbb {R}\rightarrow T_{\rho }\mathcal {P}_+(\Omega ); \quad \Phi \mapsto V_\Phi . \end{aligned}$$

Define the Riemannian metric (inner product) on the tangent space of positive densities \(g^W:{T_\rho }\mathcal {P}_+(\Omega )\times {T_\rho }\mathcal {P}_+(\Omega )\rightarrow \mathbb {R}\) by

$$\begin{aligned} g^W_\rho (V_{\Phi }, V_{\tilde{\Phi }})=\int _{\Omega }g^\Omega _x(\nabla \Phi (x), \nabla \tilde{\Phi }(x))\rho (x) dx, \end{aligned}$$

where \(\Phi (x)\), \(\tilde{\Phi }(x)\in \mathcal {F}(\Omega )/\mathbb {R}\). This inner product endows \(\mathcal {P}_+(\Omega )\) with an infinite dimensional Riemannian metric tensor. In other words, the variational problem (2) is a geometric action energy in \((\mathcal {P}_+(\Omega ), g^W)\) in the sense of [8, 25]. In the literature [20], \((\mathcal {P}_+(\Omega ), g^W)\) is called the density manifold.

2.2 Dynamical optimal transport on discrete sample spaces

We translate the dynamical perspective from the previous section to discrete state spaces, i.e., we replace the continuous space \(\Omega \) by a discrete space \(I=\{1,\ldots , n\}\).

To encode the metric tensor of discrete states, we first need to introduce a ground metric notion on sample space. We do this in terms of a graph with weighted edges, \(G=(V, E, \omega )\), where \(V=I\) is the vertex set, E is the edge set, and \(\omega =(\omega _{ij})_{i,j\in I}\in \mathbb {R}^{n\times n}\) are the edge weights. These weights satisfy

$$\begin{aligned} \omega _{ij}= {\left\{ \begin{array}{ll} \omega _{ji}>0, &{} \text {if }(i,j)\in E\\ 0, &{} \text {otherwise} \end{array}\right. }. \end{aligned}$$

As mentioned above, the weights encode the ground metric on the discrete state space. More precisely, we write

$$\begin{aligned} \omega _{ij}=\frac{1}{(d^G_{ij})^2},\quad \text {if }(i,j)\in E, \end{aligned}$$
(3)

where \(d^G_{ij}\) represents the distance or ground metric between states i and j. The set of neighbors or adjacent vertices of i is denoted by \(N(i)=\{j\in V:(i,j)\in E\}\).

The probability simplex supported on the vertices of G is defined by

$$\begin{aligned} \mathcal {P}(I) = \Big \{(p_1,\ldots , p_n)\in \mathbb {R}^n :\sum _{i=1}^n p_i=1,\quad p_i\ge 0\Big \}. \end{aligned}$$

Here \(p=(p_1,\ldots , p_n)\) is a probability vector with coordinates \(p_i\) corresponding to the probabilities assigned to each node \(i\in I\). We denote the relative interior of the probability simplex by \(\mathcal {P}_+(I)\). This consists of the strictly positive probability distributions, \(p\in \mathcal {P}(I)\) with \(p_i>0\), \(i\in I\).

Next we introduce the variational problem (2) on discrete states. First we need to define the “metric tensor” on graphs. A vector field \(v=(v_{ij})_{i,j\in V}\in \mathbb {R}^{n\times n}\) on G is a skew-symmetric matrix:

$$\begin{aligned} v_{ij}={\left\{ \begin{array}{ll}-v_{ji}, &{} \text {if }(i,j)\in E\\ 0, &{} \text {otherwise} \end{array}\right. }. \end{aligned}$$

A potential function \(\Phi =(\Phi _i)_{i=1}^n\in \mathbb {R}^{n}\) defines a gradient vector field \(\nabla _G\Phi =(\nabla _G\Phi _{ij})_{i,j\in V}\in \mathbb {R}^{n\times n}\) on the graph G by the finite differences

$$\begin{aligned} \nabla _G\Phi _{ij}={\left\{ \begin{array}{ll}\sqrt{\omega _{ij}}(\Phi _i-\Phi _j) &{} \text {if }(i,j)\in E\\ 0 &{} \text {otherwise} \end{array}\right. }. \end{aligned}$$

Here we use \(\sqrt{\omega }\) rather than \(1/d^G\) to simplify the notation. In this way, the gradient, divergence, and Laplacian matrices are expressed as multiplications by the edge weights instead of divisions by the ground metric.

We define an inner product of vector fields \(v_{ij}\), \(\tilde{v}_{ij}\) at each state \(i \in I\) by

$$\begin{aligned} g^I_i(v, \tilde{v}) := \frac{1}{2} \sum _{j\in N(i)}v_{ij}\tilde{v}_{ij}. \end{aligned}$$

In particular, the gradient vector field \(\nabla _G\Phi \) defines a kinetic energy at each state \(i \in I\) by

$$\begin{aligned} g^I_i(\nabla _G\Phi , \nabla _G\Phi ) := \frac{1}{2} \sum _{j\in N(i)} (\Phi _i -\Phi _j)^2 \omega _{ij}. \end{aligned}$$

We next define the expectation value of kinetic energy with respect to a probability distribution p:

$$\begin{aligned} \begin{aligned} (\nabla _G\Phi ,\nabla _G \Phi )_p&:= \sum _{i\in I} p_i\; g^I_i(\nabla _G\Phi , \nabla _G\Phi )=\frac{1}{2}\sum _{(i,j)\in E}\omega _{ij}(\Phi _i-\Phi _j)^2\frac{p_i+p_j}{2}. \end{aligned} \end{aligned}$$

This can also be written as

$$\begin{aligned} (\nabla _G\Phi ,\nabla _G \Phi )_p=\sum _{i=1}^n\Phi _i\sum _{j\in N(i)}{\omega _{ij}}(\Phi _i-\Phi _j)\frac{p_i+p_j}{2}=\Phi ^{\mathsf {T}}\big (-\text {div}_G(p\nabla _G\Phi )\big ), \end{aligned}$$

where

$$\begin{aligned} -\text {div}_G(p \nabla _G\Phi ):= \Bigl (\sum _{j\in N(i)}\omega _{ij}(\Phi _i-\Phi _j)\frac{p_i+p_j}{2}\Bigr )_{i\in I}. \end{aligned}$$
(4)

There are two definitions hidden in (4). First, \({\text {div}}_G:\mathbb {R}^{n\times n}\rightarrow \mathbb {R}^n\) maps any given vector field m on the graph G to a potential function

$$\begin{aligned} \text {div}_G(m)=\bigg (\sum _{j\in N(i)}\sqrt{\omega _{ij}}m_{ji}\bigg )_{i\in I}. \end{aligned}$$

Second, the probability-weighted gradient vector field \(m=p\nabla _G\Phi \) is defined by

$$\begin{aligned} m_{ij}= {\left\{ \begin{array}{ll} \frac{p_i+p_j}{2}(\Phi _i-\Phi _j)\sqrt{\omega _{ij}}, &{} \text {if }(i,j)\in E\\ 0, &{} \text {otherwise} \end{array}\right. }, \end{aligned}$$

where \(\frac{p_i+p_j}{2}\) represents the probability weight on the edge \((i,j)\in E\).
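To make these operators concrete, the following is a minimal numerical sketch (our own illustration, not code from the paper) of the graph gradient, divergence, and weighted inner product on the unweighted three-point path graph that reappears in Example 1; the probability vector p and potential Φ are arbitrary choices.

```python
import numpy as np

# Unweighted path graph 1 - 2 - 3 (0-indexed nodes), omega_ij = 1 on each edge.
edges = [(0, 1), (1, 2)]
omega = {e: 1.0 for e in edges}
n = 3

def grad_G(Phi):
    """Gradient vector field: (nabla_G Phi)_{ij} = sqrt(omega_ij) (Phi_i - Phi_j)."""
    v = np.zeros((n, n))
    for (i, j), w in omega.items():
        v[i, j] = np.sqrt(w) * (Phi[i] - Phi[j])
        v[j, i] = -v[i, j]                      # vector fields are skew-symmetric
    return v

def div_G(m):
    """Divergence: div_G(m)_i = sum_{j in N(i)} sqrt(omega_ij) m_{ji}."""
    out = np.zeros(n)
    for (i, j), w in omega.items():
        out[i] += np.sqrt(w) * m[j, i]
        out[j] += np.sqrt(w) * m[i, j]
    return out

def weighted_grad(p, Phi):
    """Probability-weighted gradient vector field m = p nabla_G Phi."""
    m = np.zeros((n, n))
    for (i, j), w in omega.items():
        m[i, j] = 0.5 * (p[i] + p[j]) * (Phi[i] - Phi[j]) * np.sqrt(w)
        m[j, i] = -m[i, j]
    return m

def inner_p(p, Phi, Psi):
    """(nabla_G Phi, nabla_G Psi)_p = sum_i p_i g^I_i(nabla_G Phi, nabla_G Psi)."""
    v, w = grad_G(Phi), grad_G(Psi)
    return sum(p[i] * 0.5 * np.dot(v[i], w[i]) for i in range(n))

p = np.array([0.5, 0.3, 0.2])
Phi = np.array([1.0, 0.0, -2.0])
# Check the identity (nabla_G Phi, nabla_G Phi)_p = Phi^T ( -div_G(p nabla_G Phi) ):
print(inner_p(p, Phi, Phi))                      # 1.4
print(Phi @ (-div_G(weighted_grad(p, Phi))))     # 1.4
```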

We are now ready to introduce the \(L^2\)-Wasserstein metric on \(\mathcal {P}_+(I)\).

Definition 1

For any \(p^0\), \(p^1\in \mathcal {P}_+(I)\), define the Wasserstein distance \(W:\mathcal {P}_+(I)\times \mathcal {P}_+(I)\rightarrow \mathbb {R}\) by

$$\begin{aligned} W(p^0,p^1)^2:= \inf _{p(t), \Phi (t)}~\Bigg \{\int _0^1(\nabla _G \Phi (t), \nabla _G \Phi (t))_{p(t)} dt\Bigg \}. \end{aligned}$$

Here the infimum is taken over pairs \((p(t), \Phi (t))\) with \(p\in H^1((0,1), \mathbb {R}^{n})\) and \(\Phi :[0, 1]\rightarrow \mathbb {R}^n\) measurable, satisfying

$$\begin{aligned} \dot{p}(t)+{\text {div}}_G(p(t) \nabla _G\Phi (t))=0,~ p(0)=p^0,~p(1)=p^1. \end{aligned}$$

Remark 1

It is worth mentioning that the metric given in Definition 1 is different from the metric defined by linear programming. More precisely, denote by \(d^G(i,j)\) the distance between two vertices i and j, namely the length of a shortest (i, j)-path. If \((i,j)\in E\), then \(d^G(i,j)\) is the same as the ground metric defined in (3). Then

$$\begin{aligned} \big (W(p^0, p^1)\big )^2 \not \equiv \min _{\pi } \Big \{ \sum _{1\le i,j\le n} d^G(i,j)^2 \pi _{ij} ~:~\sum _{i=1}^n\pi _{ij}=p^0_j ,\quad \sum _{j=1}^n\pi _{ij}=p^1_i,\quad \pi _{ij}\ge 0 \Big \}. \end{aligned}$$
(5)

The reason for this inequivalence is that the discrete sample space I is not a length space: there is no continuous path in I connecting two distinct nodes. For more details, see the discussion in the appendix.

2.3 Wasserstein geometry and discrete probability simplex

In this section we introduce the primal coordinates of the discrete probability simplex with \(L^2\)-Wasserstein Riemannian metric. Our discussion follows the recent work [22]. The probability simplex \(\mathcal {P}(I)\) is a manifold with boundary. To simplify the discussion, we focus on the interior \(\mathcal {P}_+(I)\). The geodesic properties on the boundary \(\partial \mathcal {P}(I)\) have been studied in [17].

Let us focus on the Riemannian structure. In the following we introduce an inner product on the tangent space

$$\begin{aligned} T_p\mathcal {P}_+(I) = \Big \{(\sigma _i)_{i=1}^n\in \mathbb {R}^n:\sum _{i=1}^n\sigma _i=0 \Big \}. \end{aligned}$$

Denote the space of potential functions on I by \( \mathcal {F}(I)=\mathbb {R}^{n}\). Consider the quotient space

$$\begin{aligned} \mathcal {F}(I)/ \mathbb {R}=\{[\Phi ] :(\Phi _i)_{i=1}^n\in \mathbb {R}^n\}, \end{aligned}$$

where \([\Phi ]=\{(\Phi _1+c,\ldots , \Phi _n+c) :c\in \mathbb {R}\}\) are functions defined up to addition of constants.

We introduce an identification map via (4)

$$\begin{aligned} V:\mathcal {F}(I)/\mathbb {R} \rightarrow T_p\mathcal {P}_+(I),\quad \quad V_\Phi =-\text {div}_G(p\nabla _G\Phi ). \end{aligned}$$

In [12] it is shown that \(\Phi \mapsto V_\Phi \) is a well-defined map from \(\mathcal {F}(I)/\mathbb {R}\) to \(T_p\mathcal {P}_+(I)\) which is linear and one-to-one. That is, \( \mathcal {F}(I)/\mathbb {R}\cong T_p^*\mathcal {P}_+(I)\), where \(T_p^*\mathcal {P}_+(I)\) is the cotangent space of \(\mathcal {P}_+(I)\). This identification induces the following inner product on \(T_p\mathcal {P}_+(I)\).

We first present this in a dual formulation, which is known in the literature [25].

Definition 2

(Inner product in dual coordinates) The inner product \(g_p^W :T_p\mathcal {P}_+(I)\times T_p\mathcal {P}_+(I) \rightarrow \mathbb {R}\) takes any two tangent vectors \(V_{\Phi }\) and \(V_{\tilde{\Phi }}\in T_p\mathcal {P}_+(I)\) to

$$\begin{aligned} g_p^W(V_{\Phi }, V_{\tilde{\Phi }})=(\nabla _G\Phi , \nabla _G\tilde{\Phi })_p. \end{aligned}$$
(6)

We shall now give the inner product in primal coordinates. The following matrix operator will be the key to the Riemannian metric tensor of \((\mathcal {P}_+(I), g^W)\).

Definition 3

(Linear weighted Laplacian matrix) Given \(I=\{1,\ldots , n\}\) and a weighted graph \(G=(I,E,\omega )\), the matrix function \(L(\cdot ):\mathbb {R}^n\rightarrow \mathbb {R}^{n\times n}\) is defined by

$$\begin{aligned} L(a)=D^{\mathsf {T}}\Lambda (a)D,\quad a=(a_i)_{i=1}^n\in \mathbb {R}^n, \end{aligned}$$

where

  • \(D \in \mathbb {R}^{|E|\times n}\) is the discrete gradient operator

    $$\begin{aligned} D_{(i,j)\in {E}, k\in V}={\left\{ \begin{array}{ll} \sqrt{\omega _{ij}}, &{} \text {if }i=k,\hbox { }i>j\\ -\sqrt{\omega _{ij}}, &{} \text {if }j=k,\hbox { }i>j\\ 0, &{} \text {otherwise} \end{array}\right. }, \end{aligned}$$
  • \(-D^{\mathsf {T}}\in \mathbb {R}^{n\times |E|}\) is the discrete divergence operator, also called oriented incidence matrix [15], and

  • \(\Lambda (a)\in \mathbb {R}^{|E|\times |E|}\) is a weight matrix depending on a,

    $$\begin{aligned} \Lambda (a)_{(i,j)\in E, (k,l)\in E}={\left\{ \begin{array}{ll} \frac{a_i+a_j}{2} &{} \text {if }(i,j)=(k,l)\in E\\ 0 &{} \text {otherwise} \end{array}\right. }. \end{aligned}$$

Consider some \(p\in \mathcal {P}_+(I)\). From spectral graph theory [15], we know that L(p) can be decomposed as

$$\begin{aligned} L(p)=U(p)\begin{pmatrix} 0 &{} &{} &{}\\ &{} {\lambda _{1}(p)}&{} &{}\\ &{} &{} \ddots &{} \\ &{} &{} &{} {\lambda _{n-1}(p)} \end{pmatrix}U(p)^{\mathsf {T}} . \end{aligned}$$

Here \(0<\lambda _1(p)\le \cdots \le \lambda _{n-1}(p)\) are the nonzero eigenvalues of L(p) in ascending order (they are positive since the graph G is assumed to be connected), and \(U(p)=(u_0(p),u_1(p),\ldots , u_{n-1}(p))\) is the corresponding orthogonal matrix of eigenvectors with

$$\begin{aligned} u_0=\frac{1}{\sqrt{n}}(1,\ldots , 1)^{\mathsf {T}}. \end{aligned}$$

We write \(L(p)^{\dagger }\) for the pseudo-inverse of L(p), i.e.,

$$\begin{aligned} L(p)^{\dagger }=U(p)\begin{pmatrix} 0 &{} &{} &{}\\ &{} \frac{1}{\lambda _{1}(p)}&{} &{}\\ &{} &{} \ddots &{} \\ &{} &{} &{} \frac{1}{\lambda _{n-1}(p)} \end{pmatrix}U(p)^{\mathsf {T}} . \end{aligned}$$

With \(\sigma =L(p)\Phi \), \(\tilde{\sigma }=L(p)\tilde{\Phi }\), we see that

$$\begin{aligned} {\sigma }^{\mathsf {T}}L(p)^{\dagger }\tilde{\sigma }=\Phi ^{\mathsf {T}}L(p)L(p)^{\mathcal {\dagger }}L(p)\tilde{\Phi }=\Phi ^{\mathsf {T}}L(p)\tilde{\Phi }=(\nabla _G\Phi , \nabla _G\tilde{\Phi })_p. \end{aligned}$$

Now we are ready to give the inner product in primal coordinates.

Definition 4

(Inner product in primal coordinates) The inner product \(g^{W}_p:T_p\mathcal {P}_+(I)\times T_p\mathcal {P}_+(I)\rightarrow \mathbb {R}\) is defined by

$$\begin{aligned} g^{W}_p(\sigma ,\tilde{\sigma }):={\sigma }^{\mathsf {T}}L(p)^{\dagger }\tilde{\sigma },\quad \text {for any }\sigma ,\tilde{\sigma }\in T_p\mathcal {P}_+(I). \end{aligned}$$
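To illustrate Definitions 3 and 4, here is a small sketch (ours, with an assumed edge ordering and the same toy path graph as above); it checks numerically that the primal inner product \(\sigma ^{\mathsf {T}}L(p)^{\dagger }\tilde{\sigma }\) agrees with the dual expression \(\Phi ^{\mathsf {T}}L(p)\tilde{\Phi }\).

```python
import numpy as np

# Path graph 1 - 2 - 3, edges oriented with i > j as in the definition of D.
edges = [(1, 0), (2, 1)]
omega = [1.0, 1.0]
n = 3

D = np.zeros((len(edges), n))                    # discrete gradient operator
for e, (i, j) in enumerate(edges):
    D[e, i] = np.sqrt(omega[e])
    D[e, j] = -np.sqrt(omega[e])

def L(a):
    """Linear weighted Laplacian L(a) = D^T Lambda(a) D."""
    Lam = np.diag([0.5 * (a[i] + a[j]) for (i, j) in edges])
    return D.T @ Lam @ D

p = np.array([0.5, 0.3, 0.2])
Phi = np.array([1.0, 0.0, -2.0])
Psi = np.array([0.0, 1.0, 3.0])
sigma, sigma_t = L(p) @ Phi, L(p) @ Psi          # tangent vectors V_Phi, V_Psi
Ldag = np.linalg.pinv(L(p))
print(sigma @ Ldag @ sigma_t)                    # primal inner product (Definition 4)
print(Phi @ L(p) @ Psi)                          # equals (nabla_G Phi, nabla_G Psi)_p
```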

In other words, the variational problem from Definition 1 is the minimization of a geometric action functional in \((\mathcal {P}_+(I), g^W)\), i.e.,

$$\begin{aligned} W( p^0, p^1)^2=\inf _{p(t)\in \mathcal {P}_+(I),t\in [0,1]}\Big \{\int _0^1\dot{p}(t)^{\mathsf {T}}L(p(t))^{\mathcal {\dagger }}\dot{p}(t)dt~:~ p(0)= p^0,~ p(1)= p^1\Big \}. \end{aligned}$$

This defines a Wasserstein Riemannian structure on the probability simplex. For more details on the Riemannian formulas see [22]. Following [20] we could call \((\mathcal {P}_+(I), g^W)\) the discrete density manifold. However, this could easily be confused with other notions from information geometry, and hence we will use the more explicit terminology Wasserstein probability manifold, or Wasserstein manifold for short.

3 Wasserstein statistical manifold

In this section we study parametric probability models endowed with the \(L^2\)-Wasserstein Riemannian metric. We define this in the natural way, by pulling back the Riemannian structure from the Wasserstein manifold that we discussed in the previous section. This allows us to introduce a natural gradient flow on the parameter space of a statistical model.

3.1 Wasserstein statistical manifold

Consider a statistical model defined by a triplet \((\Theta , I, p)\). Here, \(I=\{1,\ldots , n\}\) is the sample space, \(\Theta \) is the parameter space, which is an open subset of \(\mathbb {R}^d\), \(d\le n-1\), and \(p:\Theta \rightarrow \mathcal {P}_+(I)\) is the parametrization function,

$$\begin{aligned} p(\theta )=(p_i(\theta ))_{i=1}^n,\quad \theta \in \Theta . \end{aligned}$$

In the sequel we will assume that \(\text {rank}(J_{\theta } p(\theta ))=d\), so that the parametrization is locally injective.

We define a Riemannian metric g on \(\Theta \) as the pull-back of metric \(g^W\) on \(\mathcal {P}_+(I)\). In other words, we require that \(p:(\Theta ,g)\rightarrow (\mathcal {P}_+(I), g^W)\) is an isometric embedding:

$$\begin{aligned} \begin{aligned} g_\theta (a,b)&:=g^W_{ p(\theta )}(d p(\theta )(a), d p(\theta )(b))\\&=\big (d p(\theta )(a)\big )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}\big (d p(\theta )(b)\big ). \end{aligned} \end{aligned}$$

Here \(d p(\theta )(a)=\big (\sum _{j=1}^n\frac{\partial p_i(\theta )}{\partial \theta _j}a_j\big )_{i=1}^n=J_\theta p(\theta )a\), where \(J_\theta p(\theta )\) is the Jacobi matrix of \( p(\theta )\) with respect to \(\theta \). We arrive at the following definition.

Definition 5

For any pair of tangent vectors \(a,b\in T_\theta \Theta =\mathbb {R}^d\), define

$$\begin{aligned} g_\theta (a,b):=a^{\mathsf {T}}J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta )b, \end{aligned}$$

where \(J_\theta p(\theta )=(\frac{\partial p_i(\theta )}{\partial \theta _j})_{1\le i\le n, 1\le j\le d}\in \mathbb {R}^{n\times d}\) is the Jacobi matrix of the parametrization p, and \(L( p(\theta ))^{\mathcal {\dagger }}\in \mathbb {R}^{n\times n}\) is the pseudo-inverse of the linear weighted Laplacian matrix.

This inner product is consistent with the restriction of the Wasserstein metric \(g^W\) to \( p(\Theta )\). For this reason, we call \( p(\Theta )\), or \((\Theta , I, p)\), together with the induced Riemannian metric g, Wasserstein statistical manifold.

We need to make sure that the embedding procedure is valid, because the metric tensor \(L(p)^{\mathcal {\dagger }}\) is only of rank \(n-1\). The next lemma shows that \((\Theta , g)\) is a well defined d-dimensional Riemannian manifold.

Lemma 6

For any \(\theta \in \Theta \), we have

$$\begin{aligned} \lambda _{\min }(\theta )=\inf _{a\in \mathbb {R}^d, \Vert a\Vert _2=1}g_\theta (a, a)>0. \end{aligned}$$

In addition, \(g_{\theta }\) is smooth as a function of \(\theta \), so that \((\Theta , g)\) is a smooth Riemannian manifold.

Proof

We only need to show that \(J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta )\in \mathbb {R}^{d\times d}\) is a positive definite matrix. Suppose that

$$\begin{aligned} a^{\mathsf {T}}J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta )a=0 \end{aligned}$$

for some \(a\in \mathbb {R}^d\). Since 0 is a simple eigenvalue of \(L( p(\theta ))\) with eigenvector \(u_0\), it follows that

$$\begin{aligned} J_\theta p(\theta )a=cu_0,\quad \text {for some constant }c\in \mathbb {R}^1. \end{aligned}$$
(7)

Since \(u_0^{\mathsf {T}} p(\theta )=\frac{1}{\sqrt{n}}\sum _{i=1}^np_i(\theta )=\frac{1}{\sqrt{n}}\) is constant in \(\theta \), we have that \(u_0^{\mathsf {T}}\frac{\partial p(\theta )}{\partial \theta _j}=\frac{1}{\sqrt{n}}\sum _{i=1}^n\frac{\partial p_i(\theta )}{\partial \theta _j}=0\), i.e.,

$$\begin{aligned} u_0^{\mathsf {T}}J_\theta p(\theta )=0. \end{aligned}$$

Multiplying (7) from the left by \(u_0^{\mathsf {T}}\), we obtain

$$\begin{aligned} 0=u_0^{\mathsf {T}}J_\theta p(\theta )a=cu_0^{\mathsf {T}}u_0=c. \end{aligned}$$

Thus \(c=0\), and (7) becomes

$$\begin{aligned} J_\theta p(\theta )a=0. \end{aligned}$$

Since \(\text {rank}(J_{\theta } p(\theta ))=d<n\), we have \(a=0\), which finishes the proof. \(\square \)

We illustrate some geometric calculations on the parameter space \((\Theta , g)\). For simplicity of illustration, we define the matrix function \(G(\theta )\in \mathbb {R}^{d\times d}\) by \(g_\theta (\dot{\theta }, \dot{\theta })=\dot{\theta }^{\mathsf {T}}G(\theta )\dot{\theta }\), i.e.,

$$\begin{aligned} G(\theta )=(J_\theta p(\theta ))^{\mathsf {T}}L(p(\theta ))^{\mathcal {\dagger }}(J_\theta p(\theta )). \end{aligned}$$
(8)
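As a concrete illustration of (8), the following sketch (ours) evaluates \(G(\theta )\) for the two-parameter exponential (softmax) model on the path graph that is used later in Example 1; the Jacobian is approximated by central finite differences, which is an implementation convenience rather than anything prescribed by the text.

```python
import numpy as np

edges = [(0, 1), (1, 2)]                          # path graph 1 - 2 - 3
omega = [1.0, 1.0]
n, d = 3, 2

def L(a):
    """Linear weighted Laplacian of Definition 3."""
    D = np.zeros((len(edges), n))
    for e, (i, j) in enumerate(edges):
        D[e, i], D[e, j] = np.sqrt(omega[e]), -np.sqrt(omega[e])
    Lam = np.diag([0.5 * (a[i] + a[j]) for (i, j) in edges])
    return D.T @ Lam @ D

def p(theta):
    """Exponential parametrization p(theta), as in Example 1."""
    z = np.array([np.exp(theta[0]), 1.0, np.exp(theta[1])])
    return z / z.sum()

def jac(theta, eps=1e-6):
    """Central finite-difference approximation of the Jacobian J_theta p(theta)."""
    J = np.zeros((n, d))
    for k in range(d):
        h = np.zeros(d); h[k] = eps
        J[:, k] = (p(theta + h) - p(theta - h)) / (2 * eps)
    return J

def G(theta):
    """Metric tensor G(theta) = J^T L(p(theta))^dagger J of (8)."""
    J = jac(theta)
    return J.T @ np.linalg.pinv(L(p(theta))) @ J

theta = np.array([0.2, -0.3])
print(G(theta))            # a symmetric positive definite 2x2 matrix, cf. Lemma 6
```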

Under this notation, given \(\theta _0\), \(\theta _1\in \Theta \), the Riemannian distance on \((\Theta , g)\) is defined by the geometric action functional:

$$\begin{aligned} \text {Dist}(\theta _0,\theta _1)^2=\inf _{\theta (\cdot )\in {C^1([0,1];\Theta )}}\Big \{\int _0^1{\dot{\theta }(t)^{\mathsf {T}}G(\theta (t))\dot{\theta }(t)}dt~:~\theta (0)=\theta _0,~\theta (1)=\theta _1\Big \}. \end{aligned}$$
(9)

Write \(\theta (t)=\theta _t\), and let \(S_t\) be the Legendre transform of \(\dot{\theta }_t\) in \((\Theta , g)\). Then the cotangent geodesic flow satisfies

$$\begin{aligned} {\left\{ \begin{array}{ll} \dot{\theta }_t-G(\theta _t)^{-1}S_t=0\\ \dot{S}_t+\frac{1}{2}\frac{\partial }{\partial \theta } S^{\mathsf {T}}_tG(\theta _t)^{-1}S_t=0. \end{array}\right. } \end{aligned}$$
(10)

It is worth recalling the following facts. If p is the identity map, then (10) translates to

$$\begin{aligned} {\left\{ \begin{array}{ll} \dot{p}_i+\text {div}_G(p\nabla _G S)_i=0\\ \dot{S}_i+\frac{1}{4}\sum \nolimits _{j\in N(i)}\omega _{ij}(S_i-S_j)^2=0. \end{array}\right. } \end{aligned}$$

In addition, if \(I=\Omega \) and we replace i by x and \(p_i(t)\) by \(\rho (t,x)\), the above becomes

$$\begin{aligned} {\left\{ \begin{array}{ll} \frac{\partial \rho (t,x)}{\partial t}+\text {div}(\rho (t,x)\nabla S(t,x))=0\\ \frac{\partial S(t,x)}{\partial t}+\frac{1}{2}(\nabla S(t,x))^2=0, \end{array}\right. } \end{aligned}$$

which are the standard continuity and Hamilton-Jacobi equations on \(\Omega \). For these reasons, we call the two equations in (10) the continuity equation and the Hamilton-Jacobi equation on parameter space.

3.2 Geometry calculations in statistical manifold

We next present the geometric formulas in a probability model. This connects the geometric formulas on the full probability set to those on the submanifold \((p(\Theta ), g)\) and on the parameter space \((\Theta , g)\).

We first study the orthogonal projection operator from \((\mathcal {P}_+(I), g^W)\) to \((p(\Theta ), g)\).

Theorem 7

Given \(\theta \in \Theta \), for any tangent vector \(\sigma \in T_{p(\theta )}\mathcal {P}_+(I)\), there exists a unique orthogonal decomposition

$$\begin{aligned} \sigma = \sigma ^{\parallel }+\sigma ^{\perp }, \end{aligned}$$
(11)

with \(\sigma ^{\parallel }\in T_{ p(\theta )}p(\Theta )\) and \(\sigma ^{\perp }\in N_{ p(\theta )} p(\Theta )\), i.e., \(g^W_{p(\theta )}(\sigma ^{\parallel }, \sigma ^{\perp })=0\). At each point \(p(\theta )\), the projection matrix

$$\begin{aligned} H( p(\theta ))=J_\theta p(\theta )\big (J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta )\big )^{\mathcal {\dagger }}J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}\in \mathbb {R}^{n\times n}, \end{aligned}$$

gives the decomposition by

$$\begin{aligned} \sigma ^{\parallel }=H( p(\theta ))\sigma ,\quad \sigma ^{\perp }=(\mathbb {I}-H( p(\theta )))\sigma , \end{aligned}$$

where \(\mathbb {I}\) is an identity matrix in \(\mathbb {R}^{n\times n}\).

Proof

We first prove that (11) is an orthogonal decomposition. It suffices to check that \(g^W_{p(\theta )}(\sigma ^{\parallel }, \sigma ^{\perp })=0\), i.e.

$$\begin{aligned}&\sigma ^{\mathsf {T}}H( p(\theta ))^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}(\mathbb {I}-H( p(\theta )))\sigma \\&\quad =\sigma ^{\mathsf {T}}\Big (H( p(\theta ))^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}H( p(\theta ))-H( p(\theta ))^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }} \Big )\sigma =0. \end{aligned}$$

Recall \(G(\theta )=J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta )\). We check that

$$\begin{aligned} \begin{aligned}&H( p(\theta ))^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}H( p(\theta ))\\&\quad =L( p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta ) G(\theta )^{\mathcal {\dagger }}J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta )G(\theta )^{\mathcal {\dagger }}J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}\\&\quad =L( p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta ) G(\theta )^{\mathcal {\dagger }}G(\theta )G(\theta )^{\mathcal {\dagger }}J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}\\&\quad =L( p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta ) G(\theta )^{\mathcal {\dagger }}J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}\\&\quad =H( p(\theta ))^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}, \end{aligned} \end{aligned}$$

which shows the claim. We next prove the uniqueness of the decomposition (11). Suppose that \(\sigma \) admits two decompositions \(\sigma =\sigma ^{||}+\sigma ^{\perp }=\tilde{\sigma }^{||}+\tilde{\sigma }^{\perp }\), where \({\sigma }^{||}=J_\theta p(\theta )\dot{\theta }\) and \(\tilde{\sigma }^{||}=J_\theta p(\theta )\dot{\tilde{\theta }}\). Then

$$\begin{aligned} \begin{aligned} 0&=g_p^W(\sigma ^{||}-{\tilde{\sigma }}^{||}, {\tilde{\sigma }}^{\perp }-\sigma ^{\perp })=g_p^{W}(\sigma ^{||}-{\tilde{\sigma }}^{||}, \sigma ^{||}-{\tilde{\sigma }}^{||})\\&= (J_\theta p(\theta )\dot{\theta }-J_\theta p(\theta )\dot{\tilde{\theta }})^{\mathsf {T}}L(p(\theta ))^{\mathcal {\dagger }} (J_\theta p(\theta )\dot{\theta }-J_\theta p(\theta )\dot{\tilde{\theta }})\\&=(\dot{\theta }-\dot{\tilde{\theta }})^{\mathsf {T}}J_\theta p(\theta )^{\mathsf {T}}L(p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta )(\dot{\theta }-\dot{\tilde{\theta }})\\&= (\dot{\theta }-\dot{\tilde{\theta }})^{\mathsf {T}}G(\theta )(\dot{\theta }-\dot{\tilde{\theta }}). \end{aligned} \end{aligned}$$

Since \(G(\theta )\) is positive definite, we have \(\dot{\theta }=\dot{\tilde{\theta }}\) and \(\sigma ^{||}={\tilde{\sigma }}^{||}\), which finishes the proof. \(\square \)

We next present the second fundamental form for submanifold \((p(\Theta ), g)\). Given any \(\sigma \), \(\tilde{\sigma }\in T_{ p(\theta )}p(\Theta )\), consider the orthogonal decomposition of Levi–Civita connection in \((\mathcal {P}_+(I), g^W)\):

$$\begin{aligned} \nabla _{\sigma }^W\tilde{\sigma }=(\nabla ^W_{\sigma }\tilde{\sigma })^{\parallel }+(\nabla ^W_{\sigma }\tilde{\sigma })^{\perp }. \end{aligned}$$

The second fundamental form is the orthogonal part of this decomposition, i.e., \(B_{p(\theta )}(\sigma , \tilde{\sigma }):=(\nabla ^W_{\sigma }\tilde{\sigma })^{\perp }\).

Proposition 8

(Second fundamental form) Define \(\nabla _G\cdot \circ \nabla _G\cdot :\mathbb {R}^n\times \mathbb {R}^n\rightarrow \mathbb {R}^n\) by, for any \(\Phi \), \(\tilde{\Phi }\in \mathbb {R}^n\),

$$\begin{aligned} (\nabla _G\Phi \circ \nabla _G\tilde{\Phi }):=\Big (g^I_i(\nabla _G\Phi , \nabla _G\tilde{\Phi })\Big )_{i=1}^n= \frac{1}{2}\Big (\sum _{j\in N(i)} \omega _{ij}(\Phi _i-\Phi _j)(\tilde{\Phi }_i-\tilde{\Phi }_j)\Big )_{i=1}^n. \end{aligned}$$

Then

$$\begin{aligned} \begin{aligned} B_{p(\theta )}(\sigma , \tilde{\sigma })&=-\frac{1}{2}\Big (\mathbb {I}-H( p(\theta ))\Big )\Big \{L(\sigma )L( p(\theta ))^{\mathcal {\dagger }}\tilde{\sigma }+L(\tilde{\sigma })L( p(\theta ))^{\mathcal {\dagger }}\sigma \\&\quad -L( p(\theta ))(\nabla _GL( p(\theta ))^{\mathcal {\dagger }}\sigma \circ \nabla _GL( p(\theta ))^{\mathcal {\dagger }}\tilde{\sigma }) \Big \}. \end{aligned} \end{aligned}$$

Proof

As shown in [22, Proposition 11], the Christoffel formula in \((\mathcal {P}_+(I), g^W)\) satisfies

$$\begin{aligned} \nabla ^W_{\sigma }\tilde{\sigma }= & {} \frac{1}{2}\Big \{L(\sigma )L( p(\theta ))^{\mathcal {\dagger }}\tilde{\sigma }+L(\tilde{\sigma })L( p(\theta ))^{\mathcal {\dagger }}\sigma \nonumber \\&-L( p(\theta ))(\nabla _GL( p(\theta ))^{\mathcal {\dagger }}\sigma \circ \nabla _GL( p(\theta ))^{\mathcal {\dagger }}\tilde{\sigma })\Big \}. \end{aligned}$$
(12)

Following the projection operator \(H(p(\theta ))\), we finish the proof. \(\square \)

We next establish the parallel transport and geodesic equation in \((p(\Theta ), g)\).

Proposition 9

(Parallel transport) Let \(p(\theta _t)\in p(\Theta )\), \(t\in (0,1)\) be a smooth curve. Consider a vector field \(\sigma _t\in T_{p(\theta _t)} p(\Theta )\) along the curve \( p(\theta _t)\). Then \(\sigma _t\) is parallel along \( p(\theta _t)\) if and only if

$$\begin{aligned} \dot{\sigma }_t= & {} -\frac{1}{2}H( p(\theta _t))\Big \{ L(\sigma _t)L( p(\theta _t))^{\mathcal {\dagger }}\dot{p}(\theta _t) +L(\dot{p}(\theta _t))L( p(\theta _t))^{\mathcal {\dagger }}\sigma _t\\&\quad - L( p(\theta _t))(\nabla _GL( p(\theta _t))^{\mathcal {\dagger }} \sigma _t\circ \nabla _GL( p(\theta _t))^{\mathcal {\dagger }}\dot{p}(\theta _t) )\Big \}. \end{aligned}$$

If \(\sigma _t=\dot{p}(\theta _t)\), then the geodesic equation reads

$$\begin{aligned} \begin{aligned} \ddot{p}(\theta _t)&=-H( p(\theta _t))\Big \{L(\dot{p}(\theta _t))L( p(\theta _t))^{\mathcal {\dagger }}\dot{p}(\theta _t)\\&\quad -\frac{1}{2}L( p(\theta _t))(\nabla _GL( p(\theta _t))^{\mathcal {\dagger }} \dot{p}(\theta _t)\circ \nabla _GL( p(\theta _t))^{\mathcal {\dagger }}\dot{p}(\theta _t) ) \Big \}. \end{aligned} \end{aligned}$$

Proof

The parallel equation in a submanifold is given by

$$\begin{aligned} \dot{\sigma }_t+\Big (\nabla ^W_{\dot{p}(\theta _t)}\sigma _t\Big )^{\parallel }=0. \end{aligned}$$

In other words, we have

$$\begin{aligned} \dot{\sigma }_t=-H(p(\theta _t))\nabla ^W_{\dot{p}(\theta _t)}\sigma _t, \end{aligned}$$

where \(\nabla ^W\) is defined in (12). Let \(\sigma _t=\dot{p}(\theta _t)\), then

$$\begin{aligned} \ddot{p}(\theta _t)+\Big (\nabla ^W_{\dot{p}(\theta _t)}\dot{p}(\theta _t)\Big )^{\parallel }=0. \end{aligned}$$

This means that

$$\begin{aligned} \ddot{p}(\theta _t)=-H(p(\theta _t))\nabla ^W_{\dot{p}(\theta _t)}\dot{p}(\theta _t). \end{aligned}$$

Following the projection operator and (12), we finish the proof. \(\square \)

Finally, we present the curvature tensor of \((p(\Theta ), g)\), denoted by \(R(\cdot , \cdot )\cdot :T_{p(\theta )} p(\Theta )\times T_{p(\theta )} p(\Theta )\times T_{p(\theta )} p(\Theta )\rightarrow T_{p(\theta )} p(\Theta )\).

Proposition 10

(Curvature tensor) Given \(\sigma _1\), \(\sigma _2\), \(\sigma _3\), \(\sigma _4\in T_{ p(\theta )} p(\Theta )\), then

$$\begin{aligned} \begin{aligned} g_{p(\theta )}(R(\sigma _1,\sigma _2)\sigma _3, \sigma _4)&= m(\sigma _1,\sigma _4)^{\mathsf {T}}\Big (\mathbb {I}-H( p(\theta ))\Big )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}\Big (\mathbb {I}-H( p(\theta ))\Big )m(\sigma _2,\sigma _3)\\&\qquad -m(\sigma _1,\sigma _3)^{\mathsf {T}}\Big (\mathbb {I}-H( p(\theta ))\Big )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}\Big (\mathbb {I}-H( p(\theta ))\Big )m(\sigma _2,\sigma _4)\\&\qquad +\frac{1}{2}\Big \{\sigma _2^{\mathsf {T}}L( p(\theta ))^{\dagger } L(m(\sigma _1, \sigma _3))L( p(\theta ))^{\dagger }\sigma _4\\&\qquad + \sigma _1^{\mathsf {T}}L( p(\theta ))^{\dagger }L(m(\sigma _2, \sigma _4))L( p(\theta ))^{\dagger }\sigma _3\\&\qquad -\sigma _2^{\mathsf {T}}L( p(\theta ))^{\dagger } L(m(\sigma _1, \sigma _4))L( p(\theta ))^{\dagger }\sigma _3\\&\qquad - \sigma _1^{\mathsf {T}}L( p(\theta ))^{\dagger } L(m(\sigma _2, \sigma _3))L( p(\theta ))^{\dagger }\sigma _4\Big \}\\&\qquad +\frac{1}{4}\Big \{2 n(\sigma _1, \sigma _2)^{\mathsf {T}}L( p(\theta ))^{\dagger }n(\sigma _3,\sigma _4)\\&\qquad +n(\sigma _1,\sigma _3)^{\mathsf {T}}L( p(\theta ))^{\dagger }n(\sigma _2,\sigma _4)\\&\qquad -n(\sigma _2,\sigma _3)^{\mathsf {T}}L( p(\theta ))^{\dagger }n(\sigma _1,\sigma _4)\Big \}, \end{aligned} \end{aligned}$$

where m, \(n:T_{ p(\theta )} p(\Theta )\times T_{ p(\theta )} p(\Theta )\rightarrow T_{ p(\theta )} p(\Theta )\) are symmetric, antisymmetric operators respectively, which are defined by

$$\begin{aligned} m(\sigma _a, \sigma _b)&:=\nabla _{\sigma _a}^W\sigma _b=\frac{1}{2}\Big \{L(\sigma _a)L( p(\theta ))^{\mathcal {\dagger }}\sigma _b+L(\sigma _b)L( p(\theta ))^{\mathcal {\dagger }}\sigma _a\\&\quad -L( p(\theta ))(\nabla _GL( p(\theta ))^{\mathcal {\dagger }}\sigma _a\circ \nabla _GL( p(\theta ))^{\mathcal {\dagger }}\sigma _b )\Big \}, \end{aligned}$$

and

$$\begin{aligned} n(\sigma _a, \sigma _b):=L(\sigma _a)L( p(\theta ))^{\dagger }\sigma _b-L(\sigma _b)L( p(\theta ))^{\dagger }\sigma _a. \end{aligned}$$

Proof

The curvature tensor in submanifold relates to the one in full manifold as follows:

$$\begin{aligned} \begin{aligned} g_{p(\theta )}(R(\sigma _1,\sigma _2)\sigma _3, \sigma _4)&=B_{p(\theta )}(\sigma _1,\sigma _4)^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }} B_{p(\theta )}(\sigma _2,\sigma _3)\\&\quad -B_{p(\theta )}(\sigma _1,\sigma _3)^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }} B_{p(\theta )}(\sigma _2,\sigma _4)\\&\quad +g_{p(\theta )}(R_W(\sigma _1,\sigma _2)\sigma _3, \sigma _4), \end{aligned} \end{aligned}$$

where \(R_W\) is the curvature tensor of \((\mathcal {P}_+(I), g^W)\) derived in [22, Proposition 6]. Combining \(R_W\) and the second fundamental form in Proposition 8, we derive the result. \(\square \)

4 Gradient flow on Wasserstein statistical manifold

In this section we introduce the natural Riemannian gradient flow on Wasserstein statistical manifold \((\Theta , g)\).

4.1 Gradient flow on parameter space

Consider a smooth loss function \(F:\mathcal {P}_+(I)\rightarrow \mathbb {R}\). We focus on the composition \(F\circ p:\Theta \rightarrow \mathbb {R}\). The Riemannian gradient \(\nabla _g F( p(\theta ))\in T_\theta \Theta \) of \(F( p(\theta ))\) is defined by the relation

$$\begin{aligned} g_\theta (\nabla _g F( p(\theta )), a)= \nabla _\theta { F}( p(\theta ))\cdot a,\quad \text {for any}~ a\in T_\theta \Theta , \end{aligned}$$
(13)

where \(\nabla _\theta F( p(\theta ))\cdot a= \sum _{i=1}^d\frac{\partial }{\partial \theta _i}F( p(\theta ))a_i\). The gradient flow satisfies

$$\begin{aligned} \dot{\theta }_t=-\nabla _g F( p(\theta _t)). \end{aligned}$$

The next theorem establishes an explicit formulation of the gradient flow.

Theorem 11

(Wasserstein gradient flow) The gradient flow of a functional \(F: \mathcal {P}_+(I)\rightarrow \mathbb {R}\) is given by

$$\begin{aligned} \dot{\theta }_t=-G(\theta _t)^{-1}\nabla _\theta F( p(\theta _t)), \end{aligned}$$

where \(\nabla _{\theta }\) is the Euclidean gradient of \(F( p(\theta ))\) with respect to \(\theta \). More explicitly,

$$\begin{aligned} \dot{\theta }_t=-\Big (J_\theta p(\theta _t)^{\mathsf {T}}L( p(\theta _t))^{\mathcal {\dagger }}J_\theta p(\theta _t)\Big )^{\mathcal {\dagger }}J_\theta p(\theta _t)^{\mathsf {T}}\nabla _ pF( p(\theta _t)), \end{aligned}$$
(14)

where \(\nabla _p\) is the Euclidean gradient of F(p) with respect to p.

Proof

The proof follows directly from (13). Notice that

$$\begin{aligned} g_\theta (\nabla _gF( p(\theta )), a)=\nabla _g F( p(\theta ))^{\mathsf {T}}J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta ) a=\nabla _\theta F( p(\theta ))^{\mathsf {T}}a, \end{aligned}$$

and \(J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta )\) is an invertible matrix. Hence

$$\begin{aligned} \nabla _gF( p(\theta ))=\bigg (J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta )\bigg )^{\mathcal {\dagger }}\nabla _\theta F( p(\theta )). \end{aligned}$$

We compute \(\nabla _\theta F( p(\theta ))\) as

$$\begin{aligned} \nabla _\theta F(p(\theta ))=\bigg (\frac{\partial }{\partial \theta _i} F( p(\theta )) \bigg )_{i=1}^d= \bigg (\sum _{j=1}^n\frac{\partial }{\partial p_j} F( p(\theta ))\cdot \frac{\partial p_j(\theta )}{\partial \theta _i} \bigg )_{i=1}^d =J_\theta p(\theta )^{\mathsf {T}}\nabla _ pF( p(\theta )). \end{aligned}$$

This concludes the proof of (14). \(\square \)

Equation (14) generalizes the Wasserstein gradient flow on the probability simplex to the parameter space. If p is the identity map and the parameter space \(\Theta \) is the entire probability simplex, then (14) becomes

$$\begin{aligned} \dot{p}_t=-\nabla _g F(p_t)=\text {div}_G(p_t\nabla _G\nabla _p F(p_t)), \end{aligned}$$

which is the Wasserstein gradient flow on the discrete probability simplex. In particular, if \(I=\Omega \), then it represents

$$\begin{aligned} \partial _t\rho _t=-\nabla _W F(\rho _t)={\text {div}}(\rho _t \nabla \delta _\rho F(\rho _t)), \end{aligned}$$

which is the Wasserstein gradient flow on \(\Omega \). From now on, we call (14) the Wasserstein gradient flow on parameter space.
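As a toy illustration of the gradient flow on the full simplex displayed above, the following sketch (ours; the quadratic objective \(F(p)=\frac{1}{2}\Vert p-q\Vert ^2\), the target q, the step size, and the iteration count are our own choices) runs an explicit Euler discretization of \(\dot{p}=\text {div}_G(p\nabla _G\nabla _p F(p))\) on the unweighted path graph \(1-2-3\).

```python
import numpy as np

edges = [(0, 1), (1, 2)]                          # path graph 1 - 2 - 3
omega = [1.0, 1.0]
n = 3
q = np.array([0.4, 0.3, 0.3])                     # target distribution (our choice)

def minus_div(p, Phi):
    """(-div_G(p nabla_G Phi))_i = sum_{j in N(i)} omega_ij (Phi_i - Phi_j)(p_i + p_j)/2, cf. (4)."""
    out = np.zeros(n)
    for e, (i, j) in enumerate(edges):
        flux = omega[e] * (Phi[i] - Phi[j]) * 0.5 * (p[i] + p[j])
        out[i] += flux
        out[j] -= flux
    return out

p = np.array([0.2, 0.5, 0.3])
dt = 0.1
for _ in range(500):
    grad_p_F = p - q                              # Euclidean gradient of F(p) = 0.5*||p - q||^2
    p = p - dt * minus_div(p, grad_p_F)           # dot p = div_G(p nabla_G grad_p_F)
print(p)                                          # approaches q while conserving total mass
```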

The definition of Wasserstein gradient flow shares many similarities with the steepest gradient descent defined as follows. Consider

$$\begin{aligned} \arg \min _{ h\in T_\theta \Theta } F( p(\theta + h)) \quad \text {s.t.}\quad \frac{1}{2}W( p(\theta ), p(\theta + h))^2=\epsilon , \end{aligned}$$
(15)

where \(\epsilon \in \mathbb {R}_+\) is a given small constant. By taking the second-order Taylor approximation of the Wasserstein distance at \(\theta \), we get

$$\begin{aligned} W( p(\theta ), p(\theta + h))^2= h^{\mathsf {T}}G(\theta ) h+o( h^2), \end{aligned}$$

where \(G(\theta )\) is the metric tensor of \((\Theta , g)\) defined in (8), inherited from Wasserstein manifold. We take the first-order approximation of \(F( p(\theta + h))\) in (15) by

$$\begin{aligned} \arg \min _{ h\in T_\theta \Theta } F( p(\theta )) + h^{\mathsf {T}}\nabla _\theta F( p(\theta )) \quad \text {s.t.}\quad \frac{1}{2} h^{\mathsf {T}}G(\theta ) h=\epsilon . \end{aligned}$$

By the Lagrangian method with Lagrange multiplier \(\lambda \), we have

$$\begin{aligned} h=\lambda G(\theta )^{-1}\nabla _{\theta }F( p(\theta )). \end{aligned}$$

The above derivations lead to the Wasserstein natural gradient direction

$$\begin{aligned} \nabla _g F( p(\theta ))=G(\theta )^{-1}\nabla _\theta F( p(\theta )). \end{aligned}$$

Remark 2

In the standard Fisher–Rao natural gradient [2], we replace (15) by

$$\begin{aligned} \arg \min _{ h} F( p(\theta + h)) \quad \text {s.t.}\quad \text {KL}( p(\theta ) \Vert p(\theta + h))=\epsilon , \end{aligned}$$

where \({\text {KL}}\) stands for the Kullback-Leibler divergence (relative entropy) from \( p(\theta )\) to \( p(\theta + h)\). Our definition replaces the KL divergence with the Wasserstein distance.

4.2 Displacement convexity on parameter space

The Wasserstein structure on the statistical manifold not only provides us the gradient operator, but also the Hessian operator on \((\Theta , g)\). The latter allows us to introduce the displacement convexity on parameter space.

We first review some facts. One remarkable property of Wasserstein geometry is that it yields a correspondence between differential operators on sample space and differential operators on probability space. For example, the Hessian operator on the Wasserstein manifold equals the expectation of the Hessian operator on sample space.

An important example is stochastic relaxation. Given \(f(x)\in C^{\infty }(\Omega )\), consider

$$\begin{aligned} F(\rho )=\mathbb {E}_{X\sim \rho }[f(X)]=\int _\Omega f(x)\rho (x)dx. \end{aligned}$$

It is known that the Hessian operator of \( F(\rho )\) on Wasserstein manifold satisfies

$$\begin{aligned} {\text {Hess}}_W F(\rho )(V_{\Phi }, V_{\tilde{\Phi }})=\mathbb {E}_{X\sim \rho }({\text {Hess}}f(X) \nabla \Phi (X), \nabla \tilde{\Phi }(X)). \end{aligned}$$

One can show that \({\text {Hess}}f\succeq \lambda \mathbb {I}\) if and only if \({\text {Hess}}_W F(\rho )(V_\Phi , V_\Phi )\ge \lambda g^W_\rho (V_\Phi , V_\Phi )\) for all \(V_\Phi \). This means that f is \(\lambda \)-geodesically convex in \((\Omega , g^\Omega )\) if and only if \( F(\rho )\) is \(\lambda \)-geodesically convex in \((\mathcal {P}(\Omega ), g^W)\). In the literature [38], geodesic convexity on the Wasserstein manifold is known as displacement convexity.

In this section we would like to extend the displacement convexity to the parameter space \(\Theta \). In other words, we relate the parameters to the differential structure of the sample space via the constrained Wasserstein geometry \((\Theta , g)\). If \(\Theta \) is the full probability manifold, our definition coincides with the classical Hessian operator on sample space.

Definition 12

(Displacement convexity on parameter space) Given \( F\circ p:\Theta \rightarrow \mathbb {R}\), we say that \(F(p(\theta ))\) is \(\lambda \)-displacement convex if for any constant speed geodesic \(\theta _t\), \(t\in [0,1]\) connecting \(\theta _0,\theta _1\in (\Theta , g)\), it holds that

$$\begin{aligned} F(p(\theta _t))\ge (1-t) F(p(\theta _0))+t F(p(\theta _1))-\frac{\lambda }{2}t(1-t){\text {Dist}}(\theta _0,\theta _1)^2, \end{aligned}$$

where \({\text {Dist}}\) is defined in (9). If \(F(p(\theta ))=\sum _{i=1}^n f_ip_i(\theta )\) is \(\lambda \)-displacement convex, then we call \(f\in \mathbb {R}^n\) \(\lambda \)-convex in \((\Theta , I, p)\).

Remark 3

In particular, the displacement convexity of KL divergence relates to the Ricci curvature lower bound on sample space. We elaborate this notion in [23].

We next derive the displacement convexity condition for stochastic relaxation.

Theorem 13

Assume \(\Theta \subset \mathbb {R}^d\) is a compact set and \(f=(f_i)_{i=1}^n\in \mathbb {R}^n\). Then f is \(\lambda \)-convex if and only if

$$\begin{aligned}&\sum _{i=1}^n p_i(\theta ) \Big (\Gamma (\Gamma (f, \Phi ),\Phi )-\frac{1}{2}\Gamma (\Gamma (\Phi , \Phi ), f)\Big )_i+\sum _{i=1}^nf_i B_{p(\theta )}(V_\Phi , V_\Phi )_i\nonumber \\&\quad \ge \lambda \sum _{i=1}^n \Gamma (\Phi , \Phi )_i p_i(\theta ), \end{aligned}$$
(16)

for any \(\Phi \in \mathcal {F}(I)/\mathbb {R}\) and \(\theta \in \Theta \). Here \(\Gamma :\mathbb {R}^n\times \mathbb {R}^n \rightarrow \mathbb {R}^n\) is given by

$$\begin{aligned} \Gamma (\Phi ,\tilde{\Phi })_i:=g_i^I(\nabla _G\Phi , \nabla _G\tilde{\Phi })=\frac{1}{2}\sum _{j\in N(i)}\omega _{ij}(\Phi _{i}-\Phi _{j})(\tilde{\Phi }_{i}-\tilde{\Phi }_{j}), \end{aligned}$$

and B is the second fundamental form given in Proposition 8.

Proof

If \(\Theta \) is a compact set, then the \(\lambda \)-displacement convexity of \(F(p(\theta ))\) is equivalent to

$$\begin{aligned} {\text {Hess}}_g F(p(\theta ))\succeq \lambda G(\theta ), \end{aligned}$$

where \({\text {Hess}}_g F\) is the Hessian operator in \((\Theta , g)\). We next calculate this Hessian operator explicitly. Notice that

$$\begin{aligned} {\text {Hess}}_gF(\sigma , \tilde{\sigma })={\text {Hess}}_WF(\sigma ,\tilde{\sigma })+B_{p(\theta )}(\sigma , \tilde{\sigma })^{\mathsf {T}}\nabla _p F( p(\theta )), \end{aligned}$$

where \(\text {Hess}_W\) is the Hessian operator in \((\mathcal {P}_+(I), g^W)\). Writing the above in dual coordinates, i.e. \(\sigma =\tilde{\sigma }=V_\Phi =V_{\tilde{\Phi }}=L(p(\theta ))\Phi \), and following the geometric computations in [22, Proposition 18], we finish the proof. \(\square \)

Here \(\Gamma \) is the discrete Bakry-Emery Gamma one operator [7]. The geometry of Wasserstein manifold is directly related to the expectation of Bakry-Emery Gamma one operators [22]. In particular, if p is the identity mapping and \(I=\Omega \), then our definition (16) represents

$$\begin{aligned} \int _\Omega \Big (\Gamma (\Gamma (f, \Phi ),\Phi )-\frac{1}{2}\Gamma (\Gamma (\Phi , \Phi ), f)\Big )\rho (x)dx\ge \lambda \int _\Omega \Gamma (\Phi ,\Phi )\rho (x) dx, \end{aligned}$$

i.e.

$$\begin{aligned} \int _\Omega {\text {Hess}}f(x) (\nabla \Phi (x),\nabla \Phi (x))\rho (x)dx\ge \lambda \int _\Omega g^\Omega _x(\nabla \Phi , \nabla \Phi )\rho (x)dx \end{aligned}$$

for any \(\rho \) and any vector field \(\nabla \Phi \). The above inequality is the same as requiring \(\text {Hess}f\succeq \lambda \mathbb {I}\). Our definition extends this concept to parameter space.
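For concreteness, here is a tiny sketch (ours) of the discrete Gamma one operator \(\Gamma \) from Theorem 13 on the unweighted path graph \(1-2-3\); the vectors f and \(\Phi \) are arbitrary choices.

```python
import numpy as np

edges = [(0, 1), (1, 2)]                          # path graph 1 - 2 - 3
omega = [1.0, 1.0]
n = 3

def gamma(Phi, Psi):
    """Gamma(Phi, Psi)_i = (1/2) sum_{j in N(i)} omega_ij (Phi_i - Phi_j)(Psi_i - Psi_j)."""
    out = np.zeros(n)
    for e, (i, j) in enumerate(edges):
        prod = omega[e] * (Phi[i] - Phi[j]) * (Psi[i] - Psi[j])
        out[i] += 0.5 * prod
        out[j] += 0.5 * prod
    return out

f = np.array([2.0, 0.0, 1.0])
Phi = np.array([1.0, 0.0, -2.0])
# the iterated Gamma expression appearing on the left-hand side of (16)
print(gamma(gamma(f, Phi), Phi) - 0.5 * gamma(gamma(Phi, Phi), f))
```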

4.3 Numerical methods

Here we discuss the numerical computation of the Wasserstein metric and the Wasserstein gradient flow.

Let us give a simple reformulation of the gradient that can be useful in practice, where typically \(n\gg d\). Note that

$$\begin{aligned} \Big (J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta )\Big )^{\mathcal {\dagger }}=J_\theta p(\theta )^{\mathcal {\dagger }}L( p(\theta ))(J_\theta p(\theta )^{\mathsf {T}})^{\mathcal {\dagger }}. \end{aligned}$$

Hence (14) can be written as

$$\begin{aligned} \frac{d\theta }{dt}=-J_\theta p(\theta )^{\mathcal {\dagger }}L( p(\theta ))(J_\theta p(\theta )^{\mathsf {T}})^{\mathcal {\dagger }}J_\theta p(\theta )^{\mathsf {T}}\nabla _ pF( p(\theta ))\ . \end{aligned}$$

In this formulation, the computation of the pseudo-inverse of \(L( p(\theta ))\in \mathbb {R}^{n\times n}\) is not needed, and the computational complexity reduces to that of obtaining the pseudo-inverse of \(J_\theta p(\theta )\in \mathbb {R}^{n\times d}\).

Given the gradient flow (14), there are two standard choices of time discretization, namely the forward Euler scheme and the backward Euler scheme. Denote the step size by \(\lambda >0\). The forward Euler method computes a discretized trajectory by

$$\begin{aligned} \theta ^{k+1}=\theta ^k-\lambda \nabla _g F( p(\theta ^k)), \end{aligned}$$

while the backward Euler method computes

$$\begin{aligned} \theta ^{k+1}=\arg \min _{\theta \in \Theta } F( p(\theta ))+\frac{\text {Dist}(\theta , \theta ^k)^2}{2\lambda }, \end{aligned}$$

where \(\text {Dist}\) is the geodesic distance in parameter space \((\Theta , g)\).

In the information geometry literature, the forward Euler method is often referred to as the natural gradient method. In Wasserstein geometry, the backward Euler method is often called the Jordan–Kinderlehrer–Otto (JKO) scheme. In the following we give pseudo code for both numerical methods.

(Algorithm boxes with pseudo code for the forward Euler (natural gradient) method and the backward Euler (JKO) method are not reproduced here.)
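As a substitute, here is a minimal sketch (our own, not the authors' reference implementation) of the forward Euler / Wasserstein natural gradient update (14), applied to the softmax model on the path graph from Example 1 with a toy quadratic objective \(F(p)=\frac{1}{2}\Vert p-q\Vert ^2\); the target q, step size, and iteration count are arbitrary choices.

```python
import numpy as np

edges = [(0, 1), (1, 2)]                          # path graph 1 - 2 - 3
omega = [1.0, 1.0]
n, d = 3, 2
q = np.array([0.25, 0.45, 0.30])                  # target distribution (our choice)

def L(a):
    """Linear weighted Laplacian of Definition 3."""
    D = np.zeros((len(edges), n))
    for e, (i, j) in enumerate(edges):
        D[e, i], D[e, j] = np.sqrt(omega[e]), -np.sqrt(omega[e])
    Lam = np.diag([0.5 * (a[i] + a[j]) for (i, j) in edges])
    return D.T @ Lam @ D

def p(theta):                                     # softmax model, as in Example 1
    z = np.array([np.exp(theta[0]), 1.0, np.exp(theta[1])])
    return z / z.sum()

def jac(theta, eps=1e-6):                         # finite-difference Jacobian J_theta p(theta)
    J = np.zeros((n, d))
    for k in range(d):
        h = np.zeros(d); h[k] = eps
        J[:, k] = (p(theta + h) - p(theta - h)) / (2 * eps)
    return J

theta = np.array([1.0, -1.0])
lam = 0.2                                         # step size lambda
for _ in range(500):
    J = jac(theta)
    G = J.T @ np.linalg.pinv(L(p(theta))) @ J     # metric tensor (8)
    grad_theta_F = J.T @ (p(theta) - q)           # Euclidean gradient of F(p(theta))
    theta = theta - lam * np.linalg.solve(G, grad_theta_F)   # forward Euler step of (14)
print(p(theta), q)                                # p(theta) approaches q
```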

In practice, the forward Euler method is usually easier to implement than the backward Euler method, and we suggest using it to implement the natural Wasserstein gradient for minimization problems. As is known in optimization, the JKO scheme can also be useful for non-smooth objective functions. Moreover, the backward Euler method is usually unconditionally stable, which means that one can choose a large step size \(\lambda \) for the computations.

5 Examples

Example 1

(Wasserstein geodesics) Consider the sample space \(I=\{1,2,3\}\) with an unweighted graph \(1-2-3\). The probability simplex for this sample space is a triangle in \(\mathbb {R}^3\):

$$\begin{aligned} \mathcal {P}(I)=\Big \{(p_i)_{i=1}^3\in \mathbb {R}^3~:~\sum _{i=1}^3p_i=1, \quad p_i\ge 0\Big \}. \end{aligned}$$

Following Definition 1, the \(L^2\)-Wasserstein distance is given by

$$\begin{aligned} W(p^0, p^1)^2&:=\inf _{\Phi (t)}\int _0^1 \Big \{(\Phi _1(t)-\Phi _2(t))^2\frac{p_1(t)+p_2(t)}{2}\nonumber \\&\quad +(\Phi _2(t)-\Phi _3(t))^2\frac{p_2(t)+p_3(t)}{2}\Big \} dt, \end{aligned}$$
(17)

where the infimum is taken over paths \(\Phi :[0,1]\rightarrow \mathbb {R}^3\). Each \(\Phi \) defines \(p:[0,1]\rightarrow \mathbb {R}^3\) as the solution of the differential equation

$$\begin{aligned} {\left\{ \begin{array}{ll} \dot{p}_1&{}=(\Phi _1-\Phi _2)\frac{p_1+p_2}{2}\\ \dot{p}_2&{}=(\Phi _2-\Phi _1)\frac{p_1+p_2}{2}+(\Phi _2-\Phi _3)\frac{p_2+p_3}{2}\\ \dot{p}_3&{}=(\Phi _3-\Phi _2)\frac{p_2+p_3}{2} \end{array}\right. } \end{aligned}$$

with boundary condition \(p(0)=p^0\), \(p(1)=p^1\).

Consider local coordinates in (17). We parametrize a probability vector as \(p=(p_1, 1-p_1-p_3, p_3)\), with parameters \((p_1, p_3)\). Then (17) can be written as

$$\begin{aligned} W(p^0, p^1)^2:=\inf _{p(t):p(0)=p^0,~p(1)=p^1}\int _0^1\Big \{\frac{\dot{p}_1(t)^2}{1-p_3(t)}+\frac{\dot{p}_3(t)^2}{1-p_1(t)}\Big \}dt. \end{aligned}$$
(18)

where the infimum is taken over paths \(p:[0,1]\rightarrow \mathcal {P}_+(I)\). We also compare the Wasserstein metric (18) with the Fisher–Rao metric. In this case, the Fisher–Rao metric function is given by

$$\begin{aligned} \text {FR}(p^0, p^1)^2:=\inf _{p(t):p(0)=p^0,~p(1)=p^1}\int _0^1\Big \{\frac{\dot{p}_1(t)^2}{p_1(t)}+\frac{(\dot{p}_1(t)+\dot{p}_3(t))^2}{p_2(t)}+\frac{\dot{p}_3(t)^2}{p_3(t)}\Big \}dt. \end{aligned}$$

This clearly demonstrates the difference between the Wasserstein Riemannian metric and the Fisher–Rao metric. We also compare the dynamical optimal transport with the statistical one. In particular, suppose the ground metric is given by \(c_{12}=1\), \(c_{13}=2\), \(c_{23}=1\), which is of homogeneous degree one type. Then the statistical optimal transport distance defined by

$$\begin{aligned} d(p^0, p^1)=\inf _{\pi \ge 0}\Big \{c_{12}\pi _{12}+c_{13}\pi _{13}+c_{23}\pi _{23}:\sum _{i=1}^3\pi _{ij}=p^0_j,~\sum _{j=1}^3\pi _{ij}=p^1_i\Big \}, \end{aligned}$$

can be reformulated by

$$\begin{aligned} {d}(p^0, p^1)=\inf _{p(t):p(0)=p^0,~p(1)=p^1}\int _0^1\Big \{|\dot{p}_1(t)|+|\dot{p}_3(t)|\Big \}dt. \end{aligned}$$

Here the statistical formulation does not provide a Riemannian metric, but gives a Finslerian metric.

We next compute (18) numerically (see Footnote 2) for different choices of the boundary conditions \(p^0\), \(p^1\). We fix three distributions

$$\begin{aligned} q^1=\frac{1}{8}(6, 1, 1), \quad q^2=\frac{1}{8}(1, 6, 1),\quad q^3=\frac{1}{8}(1, 1, 6) \end{aligned}$$
(19)

and solve (18) for three choices of the boundary conditions:

$$\begin{aligned} p^0=q^1,\; p^1=q^2 ; \quad p^0=q^1,\; p^1=q^3 ; \quad p^0=q^2,\; p^1=q^3 . \end{aligned}$$
(20)

This gives us a geodesic triangle between \(q^1, q^2, q^3\), which is illustrated in Fig. 1. It can be seen that \((\mathcal {P}_+(I), W)\) has a non-Euclidean geometry. Moreover, we see that the geodesics depend on the graph structure of the sample space, where state 2 is qualitatively different from states 1 and 3.
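One simple way to compute such geodesics numerically (not necessarily the method referenced in Footnote 2) is to discretize the action integral in (18) in time and minimize over the interior path points. A minimal sketch, with the endpoints written in the coordinates \((p_1, p_3)\):

```python
import numpy as np
from scipy.optimize import minimize

def wasserstein_geodesic(p_start, p_end, n_steps=20):
    """Approximate a geodesic for the metric in (18) by discretizing the
    action integral in time and minimizing over the interior path points
    (p_1(t), p_3(t)); the endpoints are held fixed."""
    h = 1.0 / n_steps
    ts = np.linspace(0.0, 1.0, n_steps + 1)[:, None]
    init = ((1 - ts) * p_start + ts * p_end)[1:-1]     # linear initial guess

    def action(x):
        path = np.vstack([p_start, x.reshape(-1, 2), p_end])  # cols: (p_1, p_3)
        v = np.diff(path, axis=0) / h                          # velocities
        mid = 0.5 * (path[:-1] + path[1:])                     # midpoint rule
        return h * np.sum(v[:, 0]**2 / (1 - mid[:, 1])
                          + v[:, 1]**2 / (1 - mid[:, 0]))

    res = minimize(action, init.ravel(), method="L-BFGS-B")
    return np.vstack([p_start, res.x.reshape(-1, 2), p_end])

# geodesic between q1 = (6,1,1)/8 and q2 = (1,6,1)/8 in coordinates (p_1, p_3)
path = wasserstein_geodesic(np.array([0.75, 0.125]), np.array([0.125, 0.125]))
```

Applying the same routine to the three endpoint pairs in (20) gives an approximation of the geodesic triangle shown in Fig. 1.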

Fig. 1 The Wasserstein geodesic triangle from Example 1 plotted in the probability simplex (left) and in exponential parameter space (right). The path connecting \(q^1\) and \(q^3\) bends towards \(q^2\), something that does not happen for the other two paths. This illustrates how, as a result of the ground metric on sample space, state 2 is treated differently from states 1 and 3

We can carry out the same derivation in terms of an exponential parametrization. Consider the parameter space \(\Theta =\{\theta =(\theta _1,\theta _2)\in \mathbb {R}^2\}\) and the parametrization \(p:\Theta \rightarrow \mathcal {P}_+(I)\) with

$$\begin{aligned} p_1(\theta )= & {} \frac{e^{\theta _1}}{e^{\theta _1}+e^{\theta _2}+1}, \quad p_3(\theta )=\frac{e^{\theta _2}}{e^{\theta _1}+e^{\theta _2}+1},\\ p_2(\theta )= & {} 1-p_1(\theta )-p_3(\theta )=\frac{1}{e^{\theta _1}+e^{\theta _2}+1}. \end{aligned}$$

We rewrite the Wasserstein metric (18) in terms of \(\theta \). Denote \( p(\theta ^k)=p^k\), \(k=0, 1\). Then the Wasserstein metric in the coordinate system \(\theta \) is

$$\begin{aligned}&{\text {Dist}}(\theta ^0,\theta ^1)^2\\&\quad =\inf _{\theta (t):\theta (0)=\theta ^0,~\theta (1)=\theta ^1}\Big \{\int _0^1\dot{\theta }^{\mathsf {T}}J_\theta (p_1, p_3)^{\mathsf {T}}\begin{pmatrix} \frac{1}{1-p_3(\theta )} &{} 0\\ 0 &{} \frac{1}{1-p_1(\theta )} \end{pmatrix} J_\theta (p_1, p_3) \dot{\theta }dt\Big \}. \end{aligned}$$

The resulting geodesic triangle in \(\Theta \) is plotted in the right panel of Fig. 1.
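As a small numerical check of this change of coordinates, one can assemble the pulled-back metric matrix at a given \(\theta\) directly. The following sketch uses a finite-difference Jacobian and is purely illustrative; the function names are not from the paper.

```python
import numpy as np

def p13(theta):
    """Exponential parametrization: returns (p_1, p_3) for theta = (theta_1, theta_2)."""
    z = 1.0 + np.exp(theta[0]) + np.exp(theta[1])
    return np.array([np.exp(theta[0]) / z, np.exp(theta[1]) / z])

def metric_in_theta(theta, eps=1e-6):
    """Pullback of the Wasserstein metric (18) to exponential parameters:
    G(theta) = J^T diag(1/(1-p_3), 1/(1-p_1)) J, with the Jacobian J of
    (p_1, p_3) w.r.t. theta computed by central finite differences."""
    p = p13(theta)
    J = np.column_stack([(p13(theta + eps * e) - p13(theta - eps * e)) / (2 * eps)
                         for e in np.eye(2)])
    D = np.diag([1.0 / (1.0 - p[1]), 1.0 / (1.0 - p[0])])
    return J.T @ D @ J
```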

For comparison, we compute the exponential geodesic triangle between the same distributions \(q^1,q^2,q^3\); this is shown in Fig. 2. The exponential geodesic between two distributions \(p^0\) and \(p^1\) is given by \((p^0)^{1-t}(p^1)^{t}/ \sum _x(p^0)^{1-t}(p^1)^{t}\), \(t\in [0,1]\). In this case, there is no distinction between the states 1, 2, and 3, and the three paths are symmetric.

Fig. 2 Exponential geodesic triangle plotted in the probability simplex (left) and in exponential parameter space (right). Exponential geodesics correspond to straight lines in exponential parameter space

Example 2

(Wasserstein gradient flow on an independence model) We next illustrate the Wasserstein gradient flow over the independence model of two binary variables. The sample space is \(I = \{-1,+1\}^2\). For simplicity, we denote the states by \(a=(-1,-1)\), \(b=(-1, +1)\), \(c=(+1,-1)\), \(d=(+1,+1)\). We consider the square graph

$$\begin{aligned} \begin{matrix} b -d\\ | \phantom {--}|\\ a- c \end{matrix} \end{aligned}$$

with vertices I, edges \(E=\{\{a, b\},\{b,d\}, \{a,c\},\{c,d\} \}\), and weights \(\omega =( \omega _{ab}, \omega _{bd}, \omega _{ac}, \omega _{cd})\in \mathbb {R}^E\) attached to the edges. The edge weights correspond to the inverse squared ground metric that we assign to the sample space I. The probability simplex for this sample space is the tetrahedron

$$\begin{aligned} \mathcal {P}(I) = \Big \{(p(x))_{x\in I}\in \mathbb {R}^4~:~\sum _{x\in I}p(x)=1, \quad p(x)\ge 0\Big \}. \end{aligned}$$

Following Definition 4, the Wasserstein metric tensor is given by \(g_p^W=L(p)^{\mathcal {\dagger }}\), the pseudoinverse of the linear weighted Laplacian matrix \(L(p)\) from Definition 3. In this example the latter is

$$\begin{aligned} L(p)= { \begin{pmatrix} \omega _{ab}\frac{p_a+p_b}{2}+\omega _{ac}\frac{p_a+p_c}{2}&{} -\omega _{ab}\frac{p_a+p_b}{2}&{}-\omega _{ac}\frac{p_a+p_c}{2}&{} 0 \\ -\omega _{ab}\frac{p_a+p_b}{2}&{} \omega _{ab}\frac{p_a+p_b}{2}+\omega _{bd}\frac{p_b+p_d}{2} &{} 0 &{} -\omega _{bd}\frac{p_b+p_d}{2}\\ -\omega _{ac}\frac{p_a+p_c}{2} &{} 0 &{}\omega _{ac}\frac{p_a+p_c}{2}+\omega _{cd} \frac{p_c+p_d}{2}&{} -\omega _{cd}\frac{p_c+p_d}{2} \\ 0 &{}-\omega _{bd}\frac{p_b+p_d}{2} &{} -\omega _{cd}\frac{p_c+p_d}{2} &{} \omega _{bd}\frac{p_b+p_d}{2}+\omega _{cd}\frac{p_c+p_d}{2} \\ \end{pmatrix} } . \end{aligned}$$
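A compact way to assemble this matrix for arbitrary weights and probabilities is to loop over the edge list. The following sketch builds \(L(p)\) exactly as in the displayed formula; the state ordering \((a,b,c,d)=(0,1,2,3)\) and the function name are illustrative choices.

```python
import numpy as np

# Edges of the square graph: {a,b}, {b,d}, {a,c}, {c,d}, with states
# ordered (a, b, c, d) = (0, 1, 2, 3).
EDGES = [(0, 1), (1, 3), (0, 2), (2, 3)]

def weighted_laplacian(p, omega):
    """Weighted Laplacian L(p): each edge {i, j} with weight omega_ij
    contributes omega_ij * (p_i + p_j) / 2 to the diagonal entries (i,i),
    (j,j) and its negative to the off-diagonal entries (i,j), (j,i)."""
    L = np.zeros((4, 4))
    for (i, j), w in zip(EDGES, omega):
        c = w * (p[i] + p[j]) / 2.0
        L[i, i] += c
        L[j, j] += c
        L[i, j] -= c
        L[j, i] -= c
    return L
```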

The independence model consists of the joint distributions that satisfy \(p(x_1, x_2)=p_1(x_1)p_2(x_2)\). It can be parametrized in terms of \(\Theta =\{\xi =(\xi _1,\xi _2)\in [0,1]^2\}\), where \(\xi _1=p_1(x_1=+1)\), \(\xi _2=p_2(x_2=+1)\) describe the marginal probability distributions. The parametrization \(p:\Theta \rightarrow \mathcal {P}(I)\) is then

$$\begin{aligned} p(\xi )(x_1,x_2)={\left\{ \begin{array}{ll} (1-\xi _1)(1-\xi _2) &{}\text {if }(x_1,x_2)=(-1, -1) \\ (1-\xi _1)\xi _2 &{} \text {if }(x_1,x_2)=(-1, +1) \\ \xi _1(1-\xi _2)&{} \text {if }(x_1,x_2)=(+1, -1) \\ \xi _1\xi _2&{} \text {if }(x_1,x_2)=(+1, +1) \end{array}\right. }. \end{aligned}$$

The model \( p(\Theta ) \subset \mathcal {P}(I)\) is a two-dimensional manifold. The parameter space \(\Theta \) inherits the Riemannian structure \(g^W\) from \(\mathcal {P}(I)\), which is computed as follows. Denote the Jacobian matrix of the parametrization by

$$\begin{aligned} J_\xi p(\xi )=\begin{pmatrix} -(1-\xi _2)&{}-(1-\xi _1) \\ -\xi _2&{}1-\xi _1 \\ 1-\xi _2&{} -\xi _1 \\ \xi _2&{}\xi _1 \end{pmatrix}\in \mathbb {R}^{4\times 2}. \end{aligned}$$

Then \(g^W\) induces a metric tensor on \(\Theta \) given by

$$\begin{aligned} G(\xi )=J_\xi p(\xi )^{\mathsf {T}} L( p(\xi ))^{\mathcal {\dagger }} J_\xi p(\xi )\in \mathbb {R}^{2\times 2}. \end{aligned}$$

We now consider a discrete optimization problem via stochastic relaxation and illustrate the gradient flow. Consider the following potential function on I, taken from [28]:

$$\begin{aligned} f(x_1,x_2)=x_1+2x_2+3x_1x_2={\left\{ \begin{array}{ll} 0 &{}\text {if }(x_1,x_2)=(-1, -1) \\ -2&{} \text {if }(x_1,x_2)=(-1, +1) \\ -4 &{} \text {if }(x_1,x_2)=(+1, -1) \\ 6 &{} \text {if }(x_1,x_2)=(+1, +1) \end{array}\right. }. \end{aligned}$$

We aim to minimize \(F(\mathbf{p})=\mathbb {E}_\mathbf{p}[ f]\), i.e.,

$$\begin{aligned} F( p(\xi ))=\sum _{(x_1,x_2)\in I}f(x_1,x_2)p_1(x_1)p_2(x_2)=-4\xi _1-2\xi _2+12\xi _1\xi _2. \end{aligned}$$

By Theorem 11, the Wasserstein gradient flow is

$$\begin{aligned} \dot{\xi }=-G(\xi )^{-1}\nabla _\xi F( p(\xi )). \end{aligned}$$

For our function, the standard Euclidean gradient is \(\nabla _\xi F( p(\xi ))=(-4+12\xi _2, -2+12\xi _1)^{\mathsf {T}}\). The matrix G is computed numerically from J and L.
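The following sketch puts the pieces of this example together: the parametrization \(p(\xi)\), its Jacobian, the pulled-back metric \(G(\xi)=J^{\mathsf T}L(p(\xi))^{\dagger}J\), and a forward Euler integration of the gradient flow. It repeats the Laplacian helper from the previous sketch so that the block is self-contained; the clipping of \(\xi\) to the open unit square, the step size, and all function names are illustrative choices, not part of the paper's implementation.

```python
import numpy as np

EDGES = [(0, 1), (1, 3), (0, 2), (2, 3)]      # (a,b), (b,d), (a,c), (c,d)

def laplacian(p, omega):
    """Weighted Laplacian, as in the previous sketch."""
    L = np.zeros((4, 4))
    for (i, j), w in zip(EDGES, omega):
        c = w * (p[i] + p[j]) / 2.0
        L[i, i] += c; L[j, j] += c
        L[i, j] -= c; L[j, i] -= c
    return L

def p_of_xi(xi):
    """Independence model: (p_a, p_b, p_c, p_d) for xi = (xi_1, xi_2)."""
    x1, x2 = xi
    return np.array([(1 - x1) * (1 - x2), (1 - x1) * x2, x1 * (1 - x2), x1 * x2])

def jacobian(xi):
    """Jacobian of the parametrization, as displayed above."""
    x1, x2 = xi
    return np.array([[-(1 - x2), -(1 - x1)],
                     [-x2,        1 - x1 ],
                     [ 1 - x2,   -x1     ],
                     [ x2,        x1     ]])

def G(xi, omega):
    """Pullback Wasserstein metric G(xi) = J^T L(p(xi))^+ J."""
    J = jacobian(xi)
    return J.T @ np.linalg.pinv(laplacian(p_of_xi(xi), omega)) @ J

def grad_F(xi):
    """Euclidean gradient of F(p(xi)) = -4 xi_1 - 2 xi_2 + 12 xi_1 xi_2."""
    return np.array([-4 + 12 * xi[1], -2 + 12 * xi[0]])

def wasserstein_flow(xi0, omega, h=0.01, n_iter=500):
    """Forward Euler integration of  d(xi)/dt = -G(xi)^{-1} grad F."""
    xi = np.array(xi0, dtype=float)
    traj = [xi.copy()]
    for _ in range(n_iter):
        xi = xi - h * np.linalg.solve(G(xi, omega), grad_F(xi))
        xi = np.clip(xi, 1e-3, 1 - 1e-3)   # stay inside the open square
        traj.append(xi.copy())
    return np.array(traj)

# weights (omega_ab, omega_bd, omega_ac, omega_cd), here with a small omega_bd
traj = wasserstein_flow([0.6, 0.6], omega=(1.0, 0.1, 1.0, 1.0))
```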

Fig. 3 Negative Wasserstein gradient vector field on the parameter space \([0,1]^2\) of the two-bit independence model from Example 2. We fix the state graph shown on the top left, and a function f with values shown in gray next to the state nodes. We evaluate the gradient flow for three different choices of the graph weight \(\omega _{bd}\). When the weight \(\omega _{bd}\) is small, the flow from d towards b (a local minimum) is suppressed. A large weight has the opposite effect. The contours are those of the objective function \(F( p(\xi )) = \mathbb {E}_{ p(\xi )}[f]\)

In Fig. 3 we plot the negative Wasserstein gradient vector field in the parameter space \(\Theta =[0,1]^2\). As can be seen, the Wasserstein gradient direction depends on the ground metric on the sample space (encoded in the edge weights). If b and d are far apart, there is a higher tendency to move towards c rather than b. This reflects the intuition that the larger the ground distance between b and d, the harder it is for probability mass concentrated at b to move to d. We observe that the attraction regions of the two local minimizers change dramatically as the ground metric between b and d changes, i.e., as \(\omega _{bd}\) varies over 0.1, 1, and 10. This is different for the Fisher–Rao gradient flow, plotted in Fig. 4, which is independent of the ground metric on the sample space.

The above result illustrates the displacement convexity discussed in Theorem 16. Different ground metrics induce different displacement convexity properties of f on the parameter space \((\Theta , g)\), which lead to different convergence regions of the Wasserstein gradient flow.

Fig. 4 Fisher–Rao gradient vector field for the same objective function as in Fig. 3

Example 3

(Wasserstein gradient for maximum likelihood estimation) In maximum likelihood estimation, we seek to minimize the Kullback-Leibler divergence

$$\begin{aligned} {\text {KL}}(q \Vert p(\theta ))=\sum _{x\in I}q_x\log \frac{q_x}{p_x(\theta )}, \end{aligned}$$

where q is the empirical distribution of some given data. The Wasserstein gradient flow of \({\text {KL}}(q \Vert p(\theta ))\) satisfies

$$\begin{aligned} \frac{d\theta }{dt}=\Big (J_\theta p(\theta )^{\mathsf {T}}L( p(\theta ))^{\mathcal {\dagger }}J_\theta p(\theta )\Big )^{\mathcal {\dagger }}J_\theta p(\theta )^{\mathsf {T}} \left( \frac{q}{ p(\theta )}\right) . \end{aligned}$$
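In code, the right-hand side of this flow can be assembled generically from the model probabilities, their Jacobian, and the graph Laplacian. A minimal sketch, with all callables supplied by the user (the function name is illustrative):

```python
import numpy as np

def kl_wasserstein_direction(theta, q, p, jac, laplacian):
    """Right-hand side of the Wasserstein gradient flow for KL(q || p(theta)):
    (J^T L(p)^+ J)^+ J^T (q / p(theta)).
    `p`, `jac`, `laplacian` are user-supplied callables returning the model
    probabilities, their Jacobian in theta, and the weighted graph Laplacian."""
    pt = p(theta)
    J = jac(theta)
    G = J.T @ np.linalg.pinv(laplacian(pt)) @ J
    return np.linalg.pinv(G) @ (J.T @ (q / pt))
```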

In this example we consider hierarchical log-linear models as our parametrized probability models; these are an important type of exponential family describing interactions among groups of random variables. Concretely, for an inclusion-closed set S of subsets of \(\{1,\ldots , n\}\), the hierarchical model \(\mathcal {E}_S\) for n binary variables is the set of distributions of the form

$$\begin{aligned} p_x(\theta ) = \frac{1}{Z(\theta )}\exp \Big (\sum _{\lambda \in S} \theta _\lambda \phi _\lambda (x)\Big ), \quad x\in \{0,1\}^n, \end{aligned}$$

for all possible choices of parameters \(\theta _\lambda \in \mathbb {R}\), \(\lambda \in S\). Here the \(\phi _\lambda \) are real valued functions with \(\phi _\lambda (x)=\phi _\lambda (y)\) whenever \(x_i=y_i\) for all \(i\in \lambda \). We consider two different choices of \(\phi _\lambda \), \(\lambda \in S\), corresponding to two different parametrizations of the model.

  • Our first choice is the set of orthogonal characters

    $$\begin{aligned} \sigma _\lambda (x) = \prod _{i\in \lambda } (-1)^{x_i} = e^{ i \pi \langle 1_\lambda , x\rangle }, \quad x\in \{0,1\}^n, \end{aligned}$$

    which can be interpreted as a Fourier basis for the space of real valued functions over binary vectors.

  • As an alternative choice we consider the basis of monomials

    $$\begin{aligned} \pi _\lambda (x) = \prod _{i\in \lambda } x_i, \quad x\in \{0,1\}^n, \end{aligned}$$

    which is not orthogonal, but is frequently used in practice.

When \(S=\{ \lambda \subseteq \{1,\ldots , n\}:|\lambda |\le k\}\), the model is called the k-interaction model. We consider k-interaction models with \(k=1,\ldots , n\) (independence model, pair interaction model, three-way interaction model, etc.), with the two parametrizations \(\sigma \) (orthogonal sufficient statistics) and \(\pi \) (non-orthogonal sufficient statistics).
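To make the two parametrizations concrete, the following sketch constructs the sufficient statistics of a k-interaction model in both bases and evaluates the corresponding log-linear probabilities; the function names are illustrative, not taken from the paper.

```python
import itertools
import numpy as np

def k_interaction_statistics(n, k, basis="sigma"):
    """Sufficient statistics matrix for the k-interaction model on {0,1}^n:
    one row per subset lambda with 1 <= |lambda| <= k, one column per state x.
    basis="sigma" gives the characters prod_{i in lambda} (-1)^{x_i};
    basis="pi" gives the monomials prod_{i in lambda} x_i."""
    states = np.array(list(itertools.product([0, 1], repeat=n)))
    rows = []
    for size in range(1, k + 1):
        for lam in itertools.combinations(range(n), size):
            cols = states[:, list(lam)]
            if basis == "sigma":
                rows.append(np.prod((-1.0) ** cols, axis=1))
            else:
                rows.append(np.prod(cols, axis=1).astype(float))
    return np.vstack(rows)                 # shape: (#parameters, 2**n)

def log_linear_p(theta, Phi):
    """p_x(theta) = exp(sum_lambda theta_lambda Phi_lambda(x)) / Z(theta)."""
    logits = theta @ Phi
    w = np.exp(logits - logits.max())      # stabilized softmax
    return w / w.sum()
```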

We compare the Euclidean, Fisher, and Wasserstein gradients. For binary variables, the Hamming distance is a natural notion of ground metric. Accordingly, we define the Wasserstein metric with the uniformly weighted graph of the binary cube. We sample a few target distributions on \(\{0,1\}^n\) uniformly at random (uniform Dirichlet). For each target distribution, we initialize the model at the uniform distribution, \(\theta _0=0\). The gradient descent parameter iteration is

$$\begin{aligned} \theta _{t+1} = \theta _t - \gamma _t G(\theta _t)^{-1} \nabla {\text {KL}}(q\Vert p_{\theta _t}), \end{aligned}$$

where G is the corresponding metric (Euclidean, Fisher, or Wasserstein), \(\nabla \) is the standard gradient operator with respect to the model parameter \(\theta \), and \(\gamma _t\in \mathbb {R}_+\) is the learning rate (step size). The choice of the learning rate \(\gamma _t\) is important, and the optimal value may vary across methods and problems. We implemented a simple adaptive method to handle this: we set an initial learning rate \(\gamma _0=0.001\), and at each iteration t, if the divergence does not decrease, we scale down the learning rate by a factor of 3/4. We also tried a few other methods, including backtracking line search and Adam [19], a method based on adaptive estimates of lower-order moments of the gradient. The stopping criterion was that the expectation parameters of the model matched the data expectation parameters to within 1 percent in the infinity norm.
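The adaptive step-size rule described above admits several implementations. The following sketch shows one plausible reading, in which a step that fails to decrease the divergence is rejected and the learning rate is multiplied by 3/4; the stopping rule here is simplified relative to the expectation-matching criterion used in the experiments, and all names are illustrative.

```python
import numpy as np

def natural_gd_adaptive(theta0, kl, grad_kl, metric, lr0=1e-3,
                        shrink=0.75, max_iter=10000, min_lr=1e-8):
    """Gradient descent  theta <- theta - lr * G(theta)^{-1} grad KL,
    with a simple adaptive rule: whenever the divergence does not
    decrease, the step is rejected and the learning rate is scaled by 3/4.
    `kl`, `grad_kl`, `metric` are user-supplied callables."""
    theta, lr = np.array(theta0, dtype=float), lr0
    value = kl(theta)
    for _ in range(max_iter):
        step = np.linalg.solve(metric(theta), grad_kl(theta))
        candidate = theta - lr * step
        new_value = kl(candidate)
        if new_value < value:
            theta, value = candidate, new_value
        else:
            lr *= shrink                   # divergence did not decrease
        if lr < min_lr:
            break
    return theta
```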

The results are shown in Fig. 5. The convergence to the final value can be monitored in terms of the normalized area under the optimization curve, \(\sum _{t=1}^T (D_t-D_T)/(D_0-D_T)\), where \(D_t\) is the divergence value at iteration t, and T is the final time. All methods achieved similar values of the divergence, except for the Euclidean gradient with non-orthogonal parametrization, which did not always reach the minimum. For the Fisher and Wasserstein gradients, the learning paths were virtually identical under the two different model parametrizations, as we already expected from the fact that these are covariant gradients. On the other hand, for the Euclidean gradient, the paths (and the number of iterations) were heavily dependent on the model parametrization, with the orthogonal basis usually being a much better choice than the non-orthogonal basis. In terms of the number of iterations until the convergence criterion was satisfied, the comparison is difficult because different methods work best with different step sizes. With the simple adaptive method and a suitable initial step size, the Wasserstein gradient was faster than the Euclidean and Fisher gradients. On the other hand, using Adam to adapt the step size, orthogonal Euclidean, Fisher, and Wasserstein were comparable.

Fig. 5 Divergence minimization for random target distributions on \(\{0,1\}^n\), \(n=7\), over k-interaction models with \(k=1,\ldots ,n\). Shown are the average value of the divergence after optimization by Euclidean, Fisher, and Wasserstein gradient descent, and the corresponding number of gradient iterations. Orthogonal and non-orthogonal parametrizations are indicated by \(\sigma \) and \(\pi \). The right-hand side shows histograms of the normalized area under the optimization curves. The top figures use a simple adaptive method for selecting the step size, and the bottom figures use Adam

6 Discussion

We introduced the Wasserstein statistical manifolds, which are submanifolds of the probability simplex with the \(L^2\)-Wasserstein Riemannian metric tensor. With this, we defined an optimal transport natural gradient flow on parameter space.

The Wasserstein distance has already been discussed in connection with divergences in information geometry, and it has been shown to be useful in machine learning, for instance in training restricted Boltzmann machines and generative adversarial networks. In this work, we used the Wasserstein distance to define a geometry on the parameter space of a statistical model. Based on this geometry, we established a corresponding natural gradient and a notion of displacement convexity on parameter space.

We presented an application of the Wasserstein natural gradient method to maximum likelihood estimation in hierarchical probability models. The experiments show that, in combination with a suitable step size, the Wasserstein gradient can be a competitive optimization method and can even reduce the required number of parameter iterations compared to both the Euclidean and Fisher gradient methods. Further experimental studies will be essential to better understand the effects of the learning rate, as well as the interplay of ground metric, model, and optimization problem. In our current implementation, the Wasserstein gradient involved heavier computational costs than the Euclidean and Fisher gradients. For applications, it will be important to explore efficient computation and approximation approaches.

Regarding the theory, we suggest that many studies from information geometry will have a natural analog or extension in the Wasserstein statistical manifold. Some questions to consider include the following. Is it possible to characterize the Wasserstein metric on probability manifolds through an invariance requirement of Chentsov type? For instance, the work [32] formulates extensions of Markov embeddings for polytopes and weighted point configurations. Is there a weighted graph structure for which the corresponding Wasserstein metric recovers the Fisher metric?

The critical innovation of the Wasserstein gradient, in comparison to the Fisher gradient, is that it incorporates a ground metric on the sample space. We suggest that this could have a positive effect not only on optimization, as discussed above, but also on generalization performance, in interplay with the optimization. The reason is that the ground metric on the sample space provides a means to introduce preferences in the hypothesis space. The specific form of such a regularization still needs to be developed and investigated. In this regard, a natural question is how to define suitable ground metrics; these could be fixed in advance or trained.

We hope that this paper contributes to strengthening the emerging interactions between information geometry and optimal transport, in particular in relation to machine learning problems, and to developing better natural gradient methods.