1 Introduction

Behavioural cloning (BC) (Pomerleau, 1988, 1991) is one of the simplest imitation learning methods for obtaining a decision-making policy from expert demonstrations. BC frames the imitation learning problem as a supervised learning one. Given expert trajectories - the expert’s paths through the state space - a policy network is trained to reproduce the expert behaviour: for a given observation, the action taken by the policy must closely approximate the one taken by the expert. Although a simple method, BC has been shown to be very effective across many application domains (Kadous et al., 2005; Pearce & Zhu, 2022; Pomerleau, 1988; Sammut et al., 1992), and has been particularly successful in cases where the dataset is large and has wide coverage (Codevilla et al., 2019). An appealing aspect of BC is that it is applied in an offline setting, using only historical data. Unlike reinforcement learning (RL) methods, BC does not require further interactions with the environment. Offline policy learning can be advantageous in many circumstances, especially when collecting new data through interactions is expensive, time-consuming or dangerous, or in cases where deploying a partially trained, sub-optimal policy in the real world may be unethical, e.g. in autonomous driving and medical applications.

BC extracts the behaviour policy that created the dataset. Consequently, when applied to sub-optimal data (i.e. when some or all trajectories have been generated by non-expert demonstrators), the resulting cloned policy is also expected to be sub-optimal. This is because BC has no mechanism to infer the importance of each state-action pair. Other drawbacks of BC are its tendency to overfit when given a small number of demonstrations and the state distributional shift between training and test distributions (Codevilla et al., 2019; Ross et al., 2011). In the area of imitation learning, significant efforts have been made to overcome such limitations; however, the available methodologies generally rely on interacting with the environment (Finn et al., 2016; Ho & Ermon, 2016; Le et al., 2018; Ross et al., 2011). So, a question arises: can we help BC infer a superior policy using only the available sub-optimal data, without the need to collect additional expert demonstrations?

Fig. 1

Simplified illustration of Model-Based Trajectory Stitching. Each original trajectory (a sequence of states and actions) in the dataset \(\mathcal {D}\) is indicated as \(\mathcal {T}_i\) with \(i=1,\ldots ,3\). A first stitching event is seen in trajectory \(\mathcal {T}_1\) whereby a transition to a state originally visited in \(\mathcal {T}_2\) takes place. A second stitching event involves a jump to a state originally visited in \(\mathcal {T}_3\). At each event, jumping to a new state increases the current trajectory’s future expected returns. The resulting trajectory (in bold) consists of a sequence of states, all originally visited in \(\mathcal {D}\), but connected by imagined actions; it replaces \(\mathcal {T}_1\) in the new dataset

Our investigation is related to the emerging body of work on offline RL, which is motivated by the aim of inferring expert policies with only a fixed set of sub-optimal data (Lange et al., 2012; Levine et al., 2020). A major obstacle towards this aim is posed by the notion of action distributional shift (Fujimoto et al., 2019; Kumar et al., 2019; Levine et al., 2020). This is introduced when the policy being optimised deviates from the behaviour policy, and is caused by the action-value function overestimating out-of-distribution (OOD) actions. A number of existing methods address the issue by constraining the actions that can be taken. In some cases, this is achieved by constraining the policy to actions close to those in the dataset (Fujimoto et al., 2019; Fujimoto & Gu, 2021; Jaques et al., 2019; Kumar et al., 2019; Wu et al., 2019; Zhou et al., 2020), or by manipulating the action-value function to penalise OOD actions (Kumar et al., 2020; An et al., 2021; Kostrikov et al., 2021; Yu et al., 2021). In situations where the data is sub-optimal, offline RL has been shown to recover a superior policy to BC (Fujimoto et al., 2019; Kumar et al., 2022). Improving BC will in turn improve many offline RL policies that rely on an explicit behaviour policy of the dataset (Argenson & Dulac-Arnold, 2020; Fujimoto & Gu, 2021; Zhan et al., 2021).

In contrast to existing offline learning approaches, we turn the problem on its head: rather than trying to regularise or constrain the policy, we investigate whether the quality of the data itself can be improved using only the available demonstrations. To explore this avenue, we propose a model-based data improvement method called Model-Based Trajectory Stitching (MBTS). Our ultimate aim is to develop a procedure that identifies sub-optimal trajectories and replaces them with better ones. New trajectories are obtained by stitching existing ones together, without the need to generate unseen states. The proposed strategy consists of replaying each existing trajectory in the dataset: for each state-action pair leading to a particular next state along a trajectory, we ask whether a different action could have been taken instead, which would have landed at a different seen state from a different trajectory. An actual jump to the new state only occurs when generating such an action is plausible and the jump is expected to improve the quality of the original trajectory - in which case we have a stitching event.

An illustrative representation of this procedure can be seen in Fig. 1, where we assume to have at our disposal only three historical trajectories. In this example, a trajectory has been improved through two stitching events. To determine the stitching points, MBTS uses a probabilistic view of state-reachability that depends on learned dynamics models of the environment. These models are evaluated only on in-distribution states enabling accurate prediction. In order to assess the expected future improvement introduced by a potential stitching event, we utilise a state-value function and a reward model. Thus, MBTS can be thought of as a data-driven, automated procedure yielding highly plausible and higher-quality demonstrations to facilitate supervised learning; at the same time, sub-optimal demonstrations are removed altogether whilst keeping the diverse set of seen states.

Our experimental results show that MBTS produces higher-quality data, with BC-derived policies always superior to those inferred from the original data. Remarkably, we demonstrate that MBTS-augmented data allow BC to compete with state-of-the-art offline RL algorithms on highly complex continuous control OpenAI Gym tasks implemented in MuJoCo using the D4RL offline benchmarking suite (Fu et al., 2020). Furthermore, we show that integrating MBTS with existing offline learning methods that explicitly use BC, such as model-based planning (Argenson & Dulac-Arnold, 2020) and TD3+BC (Fujimoto & Gu, 2021), can significantly boost their performance.

2 Related work

2.1 Imitation learning

Imitation learning aims to emulate a policy from expert demonstrations (Hussein et al., 2017). BC is the simplest method in this category and uses supervised learning to clone the actions in the dataset. BC is a powerful method and has been used successfully in many applications such as teaching a quadrotor to fly (Giusti et al., 2015), self-driving cars (Bojarski et al., 2016; Farag & Saleh, 2018) and games (Pearce & Zhu, 2022). These applications are highly complex and show that accurate policies can be estimated from high-quality offline data.

One drawback of using BC is the state distributional shift between training and test distributions. Improved imitation learning methods have been introduced to reduce this distributional shift; however, they usually require online exploration. For instance, DAgger (Ross et al., 2011) is an online learning approach that iteratively updates a deterministic policy; it addresses the state distributional shift problem of BC through an on-policy method for data collection; similarly to MBTS, the original dataset is augmented, but this involves online interactions. Another algorithm, GAIL (Ho & Ermon, 2016), iteratively updates a generative adversarial network (Goodfellow et al., 2014) to determine whether a state-action pair can be deemed as expert; a policy is then inferred using a trust region policy optimisation step (Schulman et al., 2015). MBTS also uses generative modelling, but to create transitions that are likely under the data distribution and connect high-value regions.

While expert demonstrations are crucial for imitation learning, our MBTS approach generates higher quality datasets from existing, potentially sub-optimal data, thereby enhancing offline policy learning. Furthermore, MBTS leverages a reward function to learn an improved policy, which distinguishes it from the imitation learning setting where access to rewards may not always be available. This key difference enables MBTS to deliver better performance in certain scenarios compared to traditional imitation learning methods.

2.2 Offline reinforcement learning

Offline RL aims to learn an optimal policy from sub-optimal datasets without further interactions with the environment (Lange et al., 2012; Levine et al., 2020). Similarly to BC, offline RL suffers from distributional shift. However, this shift comes from the policy selecting OOD actions, leading to overestimation of the value function (Fujimoto et al., 2019; Kumar et al., 2019). In the online setting, this overestimation encourages the agent to explore, but offline it leads to a compounding of errors where the agent believes OOD actions lead to high returns. Many offline RL algorithms bias the learned policy towards the behaviour-cloned one (Argenson & Dulac-Arnold, 2020; Fujimoto & Gu, 2021; Zhan et al., 2021) to ensure the policy does not deviate too far from the behaviour policy. Many of these offline methods are therefore expected to directly benefit from enhanced datasets yielding higher-achieving behaviour policies.

2.2.1 Model-free methods

Model-free offline RL methods typically deal with distributional shift either by regularising the policy to stay close to actions given in the dataset (Fujimoto & Gu, 2021; Fujimoto et al., 2019; Jaques et al., 2019; Kumar et al., 2019; Wu et al., 2019; Zhou et al., 2020) or by pessimistically evaluating the Q-value to penalise OOD actions (An et al., 2021; Kostrikov et al., 2021; Kumar et al., 2020). Both options involve explicitly or implicitly capturing information about the unknown underlying behaviour policy, which can be fully captured using BC. For instance, batch-constrained Q-learning (BCQ) (Fujimoto et al., 2019) is a policy constraint method which uses a variational autoencoder to generate likely actions in order to constrain the policy. The TD3+BC algorithm (Fujimoto & Gu, 2021) offers a simplified policy constraint approach; it adds a behavioural cloning regularisation term to the policy update, biasing actions towards those in the dataset. Alternatively, conservative Q-learning (CQL) (Kumar et al., 2020) adjusts the values of state-action pairs to “push down” on OOD actions and “push up” on in-distribution actions, so that OOD actions are discouraged and in-distribution actions are encouraged. Implicit Q-learning (IQL) (Kostrikov et al., 2021) avoids querying OOD actions altogether by learning a state-value function and performing SARSA-style updates. All of the above methods try to deal directly with OOD actions, either by avoiding them or by handling them safely in the policy improvement or evaluation step. In contrast, our method rethinks the problem of learning from sub-optimal data: rather than using RL to learn a policy, we use RL-based approaches to enrich the data, enabling BC to extract an improved policy. Our method generates unseen actions between in-distribution states; by doing so, we avoid distributional shift, evaluating a state-value function only on seen states.

2.2.2 Model-based methods

Model-based algorithms rely on an approximation of the environment’s dynamics (Janner et al., 2019; Sutton, 1991), that is, probability distributions that predict the next state and reward from the current state and action. In the online setting, model-based methods tend to improve sample efficiency (Buckman et al., 2018; Chua et al., 2018; Feinberg et al., 2018; Janner et al., 2019; Kalweit & Boedecker, 2017). In an offline learning context, the learned dynamics have been exploited in various ways.

One approach consists of using the models to improve policy learning. For instance, Model-based offline RL (MOReL) (Kidambi et al., 2020) is an algorithm which constructs a pessimistic Markov decision process (P-MDP), based on a learned forward dynamics model and an unknown state-action detector. The P-MDP is given an additional absorbing state, which gives a large negative reward for unknown state-action pairs. Model-based Offline policy Optimization (MOPO) (Yu et al., 2020) augments the dataset by performing rollouts using a learned, uncertainty-penalised, MDP. Unlike MOPO, MBTS does not introduce imagined states, but only actions between reachable unconnected states.

Another opportunity to exploit learnt models of the environment is in decision-time planning. Model-based offline planning (MBOP) (Argenson & Dulac-Arnold, 2020) uses the learnt environment dynamics and a BC policy to roll out a trajectory from the current state, one transition at a time. The best trajectory from the current state is selected, with its horizon extended using a value function, and its first action is executed. This process is repeated for each new state. Model-based offline planning with trajectory pruning (MOPP) (Zhan et al., 2021) extends the MBOP idea, but prunes the trajectory roll-outs based on an uncertainty measure, safely handling the problem of distributional shift. Diffuser (Janner et al., 2022) uses a diffusion probabilistic model to predict a whole trajectory in one step. Rather than using a model to predict a single next state at decision time, Diffuser can generate unseen trajectories that have high likelihood under the data and maximise the cumulative rewards of a trajectory, ensuring long-horizon accuracy. However, generating individual plans with Diffuser is slow, which limits its use in real-world applications. Our MBTS method can be used in direct conjunction with planning, especially with MBOP and MOPP, which both use a BC policy to guide the trajectory sampling.

2.3 State similarity metrics

A central aspect of the proposed MBTS approach is the stitching event, which uses a notion of state similarity to determine whether two states are “close” together. Relying on geometric distances alone would often be inappropriate; e.g. two states may be close in Euclidean distance, yet reaching one from the other may be impossible (e.g. in navigation task environments where walls or other obstacles preclude reaching a nearby state). Bisimulation metrics (Ferns et al., 2004) capture state similarity based on the dynamics of the environment. These have been used in RL mainly for system state aggregation (Ferns et al., 2012; Kemertas et al., 2021; Zhang et al., 2020); they are expensive to compute (Chen et al., 2012) and usually require full-state enumeration (Bacci et al., 2013a, b; Dadashi et al., 2021). A scalable approach has recently been introduced through a pseudometric (Castro, 2020), which facilitates the calculation of state similarity in offline RL. PLOFF (Dadashi et al., 2021) is an offline RL algorithm that uses a state-action pseudometric to bias the policy evaluation and improvement steps. Whereas PLOFF uses a pseudometric to stay close to the dataset, we bypass this notion altogether by only using states in the dataset and generating unseen actions connecting them. Our stitching events are based on a decomposition of the trajectory distribution, which allows us to select unseen but high-likelihood actions conditioned on the future state.

2.4 Data re-sampling and augmentation approaches

In offline RL, data re-sampling strategies aim to learn only from high-performing transitions. For instance, best-action imitation learning (BAIL) (Chen et al., 2020) imitates state-action pairs based on the upper envelope of the dataset. Monotonic Advantage Re-Weighted Imitation Learning (MARWIL) (Wang et al., 2018) weights state-action pairs with an exponentially-weighted advantage function during policy learning by BC. Return-based data re-balance (ReD) (Yue et al., 2022) re-samples the data based on the trajectory returns and then applies offline reinforcement learning methods. The proposed MBTS differs from BAIL, MARWIL and ReD as we enlarge the dataset by adding impactful stitching transitions as well as removing the low-quality transitions. MBTS has the effect of re-sampling high-value transitions in the trajectory as well as supplementing the dataset with stitched transitions, connecting high-value regions.

Best action trajectory stitching (BATS) (Char et al., 2022) is a related trajectory stitching method: it augments the dataset by adding transitions through model-based planning. MBTS differs from BATS in a number of fundamental ways. First, BATS takes a geometric approach to defining state similarity; state-actions are rolled out using the dynamics model until a state is found that is within a short distance of a state in the dataset. Relying exclusively on geometric distances can be unreliable; instead, our stitching events are based on the dynamics of the environment and are only assessed between two in-distribution states. Second, BATS generates new states that are not in the dataset. Due to compounding model error, which results in unlikely rollouts, the rewards of the generated transitions are penalised, favouring state-action pairs within the dataset. In contrast, we only allow one-step stitching between in-distribution states and use the value function to extend the horizon rather than a learned model. Finally, BATS adds all stitched actions to the original dataset, then creates a new dataset by running value iteration, which is eventually used to learn a policy through BC. In contrast, our MBTS method has been designed to be more directly suited to policy learning through BC: since the lower-value experiences have been removed through stitching events, the resulting dataset contains only high-quality trajectories to learn from.

3 Methods

3.1 Problem setup

We consider the offline RL problem setting, which consists of finding an optimal decision-making policy from a fixed dataset. The policy is a mapping from states to actions, \(\pi : \mathcal {S} \rightarrow \mathcal {A}\), where \(\mathcal {S}\) and \(\mathcal {A}\) are the state and action spaces, respectively. The dataset is made up of transitions \(\mathcal {D} = \{(s_t,a_t,r_t,s_{t+1})\}\) that include the current state, \(s_t\), the action performed in that state, \(a_t\), the next state after the action has been taken, \(s_{t+1}\), and the reward received for the transition, \(r_t\). The actions are assumed to follow an unknown behavioural policy, \(\pi _{\beta }\), acting in a Markov decision process (MDP). The MDP is defined as \(\mathcal {M} = (\mathcal {S}, \mathcal {A}, \mathcal {P}, \mathcal {R}, \gamma )\), where \(\mathcal {P}: \mathcal {S} \times \mathcal {A} \times \mathcal {S} \rightarrow [0,1]\) is the transition probability function which defines the dynamics of the environment, \(\mathcal {R}:\mathcal {S}\times \mathcal {A} \times \mathcal {S} \rightarrow \mathbb {R}\) is the reward function and \(\gamma \in (0,1]\) is a scalar discount factor (Sutton & Barto, 1998).

In offline RL, the agent must learn a policy, \(\pi (a_t \mid s_t)\), that maximises the returns defined as the expected sum of discounted rewards, \(\mathbb {E}_{\pi }\left[ \sum _{t=0}^{\infty } r_t \gamma ^t\right]\), without ever having access to \(\pi _{\beta }\). Here we are interested in performing imitation learning through BC, which mimics \(\pi _{\beta }\) by performing supervised learning on the state-action pairs in \(\mathcal {D}\) (Pomerleau, 1988, 1991). More specifically, BC finds a deterministic policy,

$$\begin{aligned} \pi ^{\text {BC}}(s) = \mathop {\mathrm {arg\,min}}\limits _{\pi } \mathbb {E}_{s_t, a_t \sim \mathcal {D}}\left[ (\pi (s_t) - a_t)^2\right] . \end{aligned}$$

This solution is known to minimise the KL-divergence between the trajectory distributions of \(\pi _{\beta }\) and of the learned policy (Ke et al., 2020). Our objective is to enhance the dataset, such that it has the effect of being collected by an improved behaviour policy. Thus, training a policy by BC on the improved dataset will lead to higher returns than \(\pi _{\beta }\).
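To make the objective concrete, the snippet below is a minimal sketch of BC as supervised regression. The architecture, the Tanh squashing (assuming actions normalised to \([-1,1]\)) and the hyperparameters are illustrative choices rather than the settings used in the paper.

```python
# Minimal BC sketch (PyTorch). Assumes (states, actions) batches are already
# available as tensors; architecture and hyperparameters are illustrative.
import torch
import torch.nn as nn


class BCPolicy(nn.Module):
    def __init__(self, state_dim: int, action_dim: int, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, action_dim), nn.Tanh(),  # actions assumed in [-1, 1]
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.net(state)


def bc_update(policy: BCPolicy, optimiser: torch.optim.Optimizer,
              states: torch.Tensor, actions: torch.Tensor) -> float:
    """One gradient step on the mean-squared BC loss E[(pi(s) - a)^2]."""
    loss = ((policy(states) - actions) ** 2).mean()
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    return loss.item()
```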

3.2 Model-based trajectory stitching

Under our modelling assumptions, the probability distribution of any given trajectory \(\mathcal {T} = (s_0, a_0, s_1, a_1, s_2, \dots , s_H)\) in \(\mathcal {D}\) can be expressed as

$$\begin{aligned} p(\mathcal {T}) = p(s_0)\prod _{t=0}^{H-1} p(a_t \mid s_t)\, p(s_{t+1} \mid s_t, a_t), \end{aligned}$$
(1)

where \(p(a_t \mid s_t)\) is the policy and \(p(s_{t+1} \mid s_t, a_t)\) is the environment’s dynamics. First, we note that, in the offline case, Eq. (1) can be re-written in an alternative, but equivalent form as

$$\begin{aligned} p(\mathcal {T}) = p(s_0)\prod _{t=0}^{H-1} p(s_{t+1} \mid s_t)\, p(a_t \mid s_t, s_{t+1}), \end{aligned}$$
(2)

which now depends on two different conditional distributions: \(p(s_{t+1}\mid s_t)\), the environment’s forward dynamics, and \(p(a_t\mid s_t,s_{t+1})\), its inverse dynamics. Both distributions can be approximated using the available data, \(\mathcal {D}\) (see Section 3.3). We also pre-train a state-value function \(V_{\pi _{\beta }}\) to estimate the expected future sum of rewards for being in a state s and following the behaviour policy \(\pi _{\beta }\), as well as a reward function (see Section 3.4), which will be used to predict \(r(s_t, \hat{a}_t, \hat{s}_{t+1})\) for any generated action \(\hat{a}_t\) not in \(\mathcal {D}\).

Equation (2) informs our data-improvement strategy, as follows. For a given transition, \((s_t, a_t, s_{t+1}) \in \mathcal {D}\), our aim is to replace \(s_{t+1}\) with \(\hat{s}_{t+1} \in \mathcal {D}\) using a synthetic connecting action \(\hat{a}_t\). A necessary condition for such a state swap to occur is that \(\hat{s}_{t+1}\) should be plausible, conditional on \(s_t\), according to the learnt forward dynamics model, \(p(s_{t+1} \mid s_t)\). Furthermore, such a state swap should only happen when landing on \(\hat{s}_{t+1}\) leads to higher expected returns. Accordingly, two criteria need to be satisfied in order to allow swapping states: \(p(\hat{s}_{t+1}\mid s_t) \ge p(s_{t+1}\mid s_t)\) and \(V_{\pi _{\beta }}( \hat{s}_{t+1}) > V_{\pi _{\beta }}( s_{t+1})\). The first criterion ensures that the candidate next state is at least as likely to be reached from \(s_t\), under the learnt dynamics model, as the original next state. The second criterion, which uses the pre-trained value function, ensures that the candidate next state is beneficial: it must not only be reachable from \(s_t\) under the environment dynamics, but must also lead to higher expected returns than the current \(s_{t+1}\). In practice, finding a suitable candidate \(\hat{s}_{t+1}\) involves a search amongst all the states that have been visited by any trajectory in \(\mathcal {D}\) (see Section 3.3). Where the two criteria above are satisfied, a plausible action connecting \(s_t\) and the newly found \(\hat{s}_{t+1}\) is obtained by generating an action that maximises the learnt inverse dynamics model. In summary, we have:

Definition 1

A candidate stitching event consists of a transition \((s_t, \hat{a}_t, \hat{s}_{t+1}, r(s_t, \hat{a}_t, \hat{s}_{t+1}))\) that replaces \((s_t, a_t, s_{t+1}, r(s_t, a_t, s_{t+1}))\) and is such that, starting from \(s_t\), the new state satisfies

$$\begin{aligned} \hat{s}_{t+1} = \mathop {\mathrm {arg\,max}}\limits _{s' \in \mathcal {D}} V_{\pi _{\beta }} (s') \quad \text {s.t. } p(s' \mid s_t)>p(s_{t+1}\mid s_t), \end{aligned}$$

and the new action is generated by

$$\begin{aligned} \hat{a}_t = \mathop {\mathrm {arg\,max}}\limits _{\hat{a}} p(\hat{a} \mid s_t,\hat{s}_{t+1}). \end{aligned}$$
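As an illustration of Definition 1, the sketch below scores a set of candidate next states against the original one. The helper interfaces (forward_likelihood, value_fn, inverse_dynamics) are hypothetical stand-ins for the learned models described in Sections 3.3-3.5, not part of the original implementation.

```python
# Illustrative check for a candidate stitching event (Definition 1).
# `forward_likelihood`, `value_fn` and `inverse_dynamics` are hypothetical
# callables standing in for the learned models of Sections 3.3-3.5.
import numpy as np


def candidate_stitching_event(s_t, s_next, candidate_states,
                              forward_likelihood, value_fn, inverse_dynamics):
    """Return (s_hat, a_hat) for the best admissible candidate, or None.

    forward_likelihood(s_t, states) -> array of p-hat(state | s_t) estimates;
    value_fn(states)                -> array of V(state) estimates;
    inverse_dynamics(s_t, s_hat)    -> most plausible connecting action.
    """
    p_orig = forward_likelihood(s_t, s_next[None])[0]
    v_orig = value_fn(s_next[None])[0]

    p_cand = forward_likelihood(s_t, candidate_states)
    v_cand = value_fn(candidate_states)

    # Admissible candidates must be at least as reachable and strictly higher value.
    admissible = (p_cand > p_orig) & (v_cand > v_orig)
    if not admissible.any():
        return None

    best = int(np.argmax(np.where(admissible, v_cand, -np.inf)))
    s_hat = candidate_states[best]
    a_hat = inverse_dynamics(s_t, s_hat)  # arg-max of p(a | s_t, s_hat)
    return s_hat, a_hat
```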

For every trajectory in the dataset, starting from the initial state, we sequentially identify candidate stitching events. For instance, in Fig. 1, two such events have been identified along the \(\mathcal {T}_1\) trajectory and eventually they yield a new trajectory, \(\hat{\mathcal {T}}_1\). When the cumulative sum of rewards along the newly formed trajectory is higher than that observed in the original trajectory, the old trajectory is replaced by the new one in \(\mathcal {D}\). This is captured by the following definition.

Definition 2

A trajectory replacement event is such that, if a new trajectory \(\hat{\mathcal {T}}\), starting at the initial state \(s_0\) of \(\mathcal {T}\), has been compiled through a sequence of candidate stitching events, then \(\hat{\mathcal {T}}\) replaces \(\mathcal {T}\) in \(\mathcal {D}\) when the following condition is satisfied:

$$\begin{aligned} (1+\tilde{p})\sum _{t \in \mathcal {T}} r_t < \sum _{u \in \hat{\mathcal {T}}} r_u. \end{aligned}$$

In this definition, \(\tilde{p}\) is a small positive constant and the \((1+\tilde{p})\) term ensures that the cumulative sum of rewards in the new trajectory improves upon the old one by a given margin. This conservative approach takes into account potential prediction errors incurred by using the learnt reward model when assessing the rewards for \(\hat{\mathcal {T}}\).

The procedure above is repeated for all the trajectories in the current dataset. When any of the original trajectories are replaced by new ones, a new and improved dataset is formed. The new dataset can then be thought of as being collected by a different, and improved, behaviour policy. Using the new data, the value function is trained again, and a search for trajectory replacement events is started again. This iterative procedure is summarised below.

Definition 3

Trajectory Stitching is an iterative process whereby every trajectory in a dataset \(\mathcal {D}\) may be entirely replaced by a new one formed through trajectory replacement events. When such replacements take place, resulting in a new dataset, an updated value function is inferred and the process is repeated again.

The trajectory stitching method enforces a greedy next state selection policy (Definition 1) and guarantees that the trajectories produced by this policy have higher returns than under the previous policy (Definition 2). Therefore, we obtain a new dataset (Definition 3) collected under a new behaviour policy for which a new value function can be learned and the trajectory stitching process can be repeated. This iterative data improvement process is terminated when no more trajectory replacements are possible, or earlier (see Section 4).
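For orientation, the following sketch summarises the iterative procedure of Definitions 1-3 at a high level. The callables stitch_trajectory, fit_value_function and predicted_return are hypothetical stand-ins for the components described in the rest of this section, and the five-iteration cap and margin value are illustrative.

```python
# High-level sketch of the MBTS loop (Definitions 1-3). The three callables are
# hypothetical stand-ins for the components described in Sections 3.3-3.5.
def model_based_trajectory_stitching(dataset, stitch_trajectory, fit_value_function,
                                     predicted_return, n_iterations=5, margin=0.05):
    for _ in range(n_iterations):
        value_fn = fit_value_function(dataset)      # re-trained at every iteration
        replaced_any = False
        new_dataset = []
        for traj in dataset:
            # Definition 1: sequentially apply candidate stitching events.
            new_traj = stitch_trajectory(traj, dataset, value_fn)
            # Definition 2: replace only if returns improve by the (1 + margin) factor.
            if predicted_return(new_traj) > (1.0 + margin) * predicted_return(traj):
                new_dataset.append(new_traj)
                replaced_any = True
            else:
                new_dataset.append(traj)
        dataset = new_dataset
        if not replaced_any:                        # terminate when nothing changes
            break
    return dataset
```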

The MBTS approach is highly versatile and can be implemented in a variety of ways. In the remainder of this section, we describe our chosen methods for modelling the two conditional distributions featured in Eq. (2), as well as our techniques for estimating the state-value function and predicting the environment’s rewards. Our discussion is primarily motivated by continuous control problems, focusing on continuous state and action space domains. For applications in discrete spaces, alternative models should be employed.

3.3 Candidate next state search

The search for a candidate next state requires a learned forward dynamics model, i.e. \(p(s_{t+1} \mid s_t)\). Model-based RL approaches typically use such dynamics models, conditioned on the action as well as the state, to make predictions (Janner et al., 2019; Yu et al., 2020; Kidambi et al., 2020; Argenson & Dulac-Arnold, 2020). Here, we use the model differently, only to guide the search process and identify a suitable next state to transition to. Specifically, conditional on \(s_t\), the dynamics model is used to assess the relative likelihood of observing any other \(s_{t+1}\) in the dataset compared to the observed one. The environment dynamics are assumed to follow a Gaussian distribution whose mean vector and covariance matrix are approximated by a neural network, i.e.

$$\begin{aligned} \hat{p}_{\xi }(s_{t+1}\mid s_t) = \mathcal {N}(\mu _{\xi _1}(s_t), \Sigma _{\xi _2}(s_t)) \end{aligned}$$

where \(\xi = (\xi _1, \xi _2)\) indicate the parameters of the neural network. This modelling assumption is fairly common in applications involving continuous state spaces (Janner et al., 2019; Kidambi et al., 2020; Yu et al., 2020, 2021).

In our implementation, we take an ensemble of N Gaussian models, \(\mathcal {E}\); each component of \(\mathcal {E}\) is characterised by its own parameter set, \((\mu _{\xi ^i_1}, \Sigma _{\xi ^i_2})\). This approach has been shown to take into account epistemic uncertainty, i.e. the uncertainty in the model parameters (Argenson & Dulac-Arnold, 2020; Buckman et al., 2018; Chua et al., 2018; Yu et al., 2021). Each individual model’s parameter vector is estimated via maximum likelihood by optimising

$$\begin{aligned} \begin{aligned} \mathcal {L}_{\hat{p}}(\xi ) = \mathbb {E}_{s_t,s_{t+1} \sim \mathcal {D}} [ (\mu _{\xi _1} (s_t) - s_{t+1})^T \Sigma ^{-1}_{\xi _2}(s_t)(\mu _{\xi _1} (s_t) - s_{t+1}) + \log \mid \Sigma _{\xi _2}(s_t)\mid ], \end{aligned} \end{aligned}$$

where \(\mid \cdot \mid\) is the determinant of a matrix, and each model’s parameter set is initialised differently prior to estimation. Upon fitting the models, a state \(s_{t+1}\) is replaced by \(\hat{s}_{t+1}\) only when

$$\begin{aligned} \min _{i \in \mathcal {E}} \hat{p}_{\xi ^i} (\hat{s}_{t+1}\mid s_t) > \mathop {\mathrm {\,mean}}\limits _{i \in \mathcal {E}} \hat{p}_{\xi ^i}(s_{t+1}\mid s_t). \end{aligned}$$

In this context, we adopt a conservative approach, as we have greater confidence in the likelihood prediction of observed state-next state pairs, \(\hat{p}_{\xi ^i}(s_{t+1}\mid s_t)\), compared to unseen state-next state pairs, \(\hat{p}_{\xi ^i} (\hat{s}_{t+1}\mid s_t)\).
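A sketch of one ensemble member and of the conservative acceptance test above is given below. For brevity it uses a diagonal covariance and drops the Gaussian normalisation constant (which cancels in the comparison); the paper's loss uses a full covariance matrix, and the layer sizes are illustrative.

```python
# Sketch of a Gaussian forward-dynamics ensemble member and the min/mean test.
# Diagonal covariance and illustrative sizes; the normalising constant of the
# Gaussian is dropped since it cancels when comparing log-likelihoods.
import math
import torch
import torch.nn as nn


class GaussianDynamics(nn.Module):
    """One ensemble member modelling p(s_{t+1} | s_t) as a diagonal Gaussian."""
    def __init__(self, state_dim: int, hidden: int = 512):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(state_dim, hidden), nn.ReLU(),
                                   nn.Linear(hidden, hidden), nn.ReLU())
        self.mean = nn.Linear(hidden, state_dim)
        self.log_var = nn.Linear(hidden, state_dim)

    def log_prob(self, s_t: torch.Tensor, s_next: torch.Tensor) -> torch.Tensor:
        h = self.trunk(s_t)
        mu, log_var = self.mean(h), self.log_var(h).clamp(-10.0, 4.0)
        return (-0.5 * ((s_next - mu) ** 2 / log_var.exp() + log_var)).sum(-1)

    def loss(self, s_t: torch.Tensor, s_next: torch.Tensor) -> torch.Tensor:
        return -self.log_prob(s_t, s_next).mean()   # maximum-likelihood objective


def accept_swap(ensemble, s_t, s_next, s_hat):
    """Conservative test: the minimum likelihood of the unseen pair over the
    ensemble must exceed the mean likelihood of the observed pair."""
    with torch.no_grad():
        lp_hat = torch.stack([m.log_prob(s_t, s_hat) for m in ensemble])
        lp_obs = torch.stack([m.log_prob(s_t, s_next) for m in ensemble])
        mean_log_obs = torch.logsumexp(lp_obs, dim=0) - math.log(len(ensemble))
    return lp_hat.min(dim=0).values > mean_log_obs
```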

Performing this search over all next states in the dataset can be computationally inefficient. To address this, we assume that states that are far apart in Euclidean distance are not reachable and, therefore, we only evaluate the models on nearby states. We employ a nearest-neighbour search organised by a KD-tree, as shown in line 7 of Algorithm 1, with the complete procedure detailed in the Appendix.
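This pre-filtering step might look like the sketch below, using SciPy's KD-tree; the neighbourhood size k is an illustrative value rather than the setting used in our experiments.

```python
# KD-tree pre-filtering of candidate next states; k is an illustrative value.
import numpy as np
from scipy.spatial import cKDTree


def build_candidate_index(all_next_states: np.ndarray) -> cKDTree:
    """Build the KD-tree once per MBTS iteration over every next state in D."""
    return cKDTree(all_next_states)


def nearby_candidates(tree: cKDTree, all_next_states: np.ndarray,
                      s_t: np.ndarray, k: int = 10) -> np.ndarray:
    """Return the k nearest next states to s_t in Euclidean distance; only these
    candidates are then scored with the learned forward-dynamics ensemble."""
    _, idx = tree.query(s_t, k=k)
    return all_next_states[idx]
```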

3.4 Value and reward function estimation

Value functions are widely used in reinforcement learning to determine the quality of an agent’s current position (Sutton & Barto, 1998). In our context, we use a state-value function to assess whether a candidate next state offers a potential improvement over the original next state. To accurately estimate the future returns given the current state, we calculate a state-value function dependent on the behaviour policy of the dataset. The function \(V_{\theta }(s)\) is approximated by an MLP neural network parameterised by \(\theta\). The parameters are learned by minimising the squared Bellman error (Sutton & Barto, 1998),

$$\begin{aligned} \mathcal {L}_V(\theta ) = \mathbb {E}_{s_t,r_t,s_{t+1}\sim \mathcal {D}} [ (r_t +\gamma V_{\theta }(s_{t+1}) - V_{\theta }(s_t))^2]. \end{aligned}$$
(3)

In our context, \(V_\theta\) is only used to assess the value of in-distribution states, thus avoiding the OOD issue that arises when evaluating value functions in offline RL. The value function is only queried to determine whether a candidate stitching event has been found (Definition 1).
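A minimal sketch of fitting Eq. (3) is shown below. Stopping gradients through the bootstrapped target and masking terminal transitions are standard implementation choices not spelled out in Eq. (3); the architecture and discount factor are illustrative.

```python
# Sketch of fitting the state-value function of Eq. (3). Target detaching and
# terminal masking are common implementation choices assumed here.
import torch
import torch.nn as nn


class ValueFunction(nn.Module):
    def __init__(self, state_dim: int, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 1))

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        return self.net(s).squeeze(-1)


def value_update(v, optimiser, s_t, r_t, s_next, done, gamma=0.99):
    """One step on the squared Bellman error; `done` is 1 for terminal transitions."""
    with torch.no_grad():
        target = r_t + gamma * (1.0 - done) * v(s_next)   # bootstrapped target
    loss = ((v(s_t) - target) ** 2).mean()
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    return loss.item()
```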

Value functions require rewards for training; therefore, a reward must be estimated for unseen tuples \((s_t, \hat{a}_t, \hat{s}_{t+1})\). There are many different modelling choices available; e.g., under a Gaussian model, the mean and variance of the reward can be estimated, allowing uncertainty quantification. Other alternatives include a Wasserstein-GAN, a VAE, and a standard multilayer neural network. In practice, the choice of reward model appears to have a negligible impact on MBTS (e.g. see Section 4.4.1). In the remainder of this section, we provide further details for one such model, based on a Wasserstein-GAN (Arjovsky et al., 2017; Goodfellow et al., 2014), which we have used extensively in all our experiments (Section 4) and in our early investigations (Hepburn & Montana, 2022).

Wasserstein-GANs consist of a generator, \(G_{\phi }\), and a discriminator, \(D_{\psi }\), with neural network parameters \(\phi\) and \(\psi\), respectively. The discriminator takes in the state, action, next state and reward, and determines whether this transition is from the dataset. The generator loss function is:

$$\begin{aligned} \mathcal {L}_G (\phi )= \mathbb {E}_{\begin{array}{c} z \sim p(z) \\ s_t,a_t,s_{t+1}\sim \mathcal {D} \\ \tilde{r}_t \sim G_{\phi }(z,s_t,a_t,s_{t+1}) \end{array}}[D_{\psi }(s_t,a_t,s_{t+1},\tilde{r}_t)]. \end{aligned}$$

Here \(z \sim p(z)\) is a noise vector sampled independently from \(\mathcal {N} (0,1)\), the standard normal. The discriminator loss function is:

$$\begin{aligned} \begin{aligned} \mathcal {L}_D (\psi )= \mathbb {E}_{s_t,a_t, r_t, s_{t+1} \sim \mathcal {D}}[D_{\psi }(s_t,a_t,s_{t+1},r_t)] - \mathbb {E}_{\begin{array}{c} z \sim p(z) \\ s_t,a_t,s_{t+1}\sim \mathcal {D} \\ \tilde{r_t} \sim G_{\phi }(z,s_t,a_t,s_{t+1}) \end{array}}[D_{\psi }(s_t,a_t,s_{t+1},\tilde{r}_t)]. \end{aligned} \end{aligned}$$

Once trained, a reward will be predicted for the stitching event when a new action has been generated between two previously disconnected states.
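A compact sketch of the two networks and their losses is given below. The latent dimension and layer widths are illustrative, the losses are written as quantities to be minimised, and the Lipschitz constraint on the critic (weight clipping or a gradient penalty) is omitted for brevity.

```python
# Sketch of the WGAN reward model; sizes are illustrative and the Lipschitz
# constraint on the critic (weight clipping / gradient penalty) is omitted.
import torch
import torch.nn as nn


class RewardGenerator(nn.Module):
    """G_phi: maps (z, s_t, a_t, s_{t+1}) to a predicted reward."""
    def __init__(self, state_dim, action_dim, z_dim=8, hidden=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim + 2 * state_dim + action_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1))

    def forward(self, z, s, a, s_next):
        return self.net(torch.cat([z, s, a, s_next], dim=-1)).squeeze(-1)


class RewardCritic(nn.Module):
    """D_psi: scores how much (s, a, s', r) looks like a dataset transition."""
    def __init__(self, state_dim, action_dim, hidden=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * state_dim + action_dim + 1, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1))

    def forward(self, s, a, s_next, r):
        return self.net(torch.cat([s, a, s_next, r.unsqueeze(-1)], dim=-1)).squeeze(-1)


def wgan_losses(G, D, s, a, s_next, r, z_dim=8):
    """Return (critic loss, generator loss) for one batch, both to be minimised."""
    z = torch.randn(s.shape[0], z_dim)
    r_fake = G(z, s, a, s_next)
    d_loss = -(D(s, a, s_next, r).mean() - D(s, a, s_next, r_fake.detach()).mean())
    g_loss = -D(s, a, s_next, r_fake).mean()
    return d_loss, g_loss
```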

3.5 Action generation

Sampling a suitable action that leads from \(s_t\) to the newly found state \(\hat{s}_{t+1}\) requires an inverse dynamics model. Specifically, we require that a synthetic action must maximise the estimated conditional density, \(p(a_t\mid s_t,\hat{s}_{t+1})\). Given our requirement of sampling synthetic actions, a conditional variational autoencoder (CVAE) (Kingma & Welling, 2013; Sohn et al., 2015) provides a suitable approximation for the inverse dynamics model. The CVAE consists of an encoder \(q_{\omega _1}\) and a decoder \(p_{\omega _2}\) where \(\omega _1\) and \(\omega _2\) are the respective parameters of the neural networks.

The encoder maps the input data onto a lower-dimensional latent representation z, whereas the decoder generates data from the latent space. We train a CVAE to maximise the conditional marginal log-likelihood, \(\log p(a_t\mid s_t,\hat{s}_{t+1})\). While this quantity is intractable, the CVAE allows us to maximise a variational lower bound instead,

$$\begin{aligned} \begin{aligned} \max _{\omega _1, \omega _2} \log p(a_t\mid s_t,\hat{s}_{t+1})&\ge \max _{\omega _1,\omega _2} \mathbb {E}_{z \sim q_{\omega _1}}[\log p_{\omega _2}(a_t\mid s_t,\hat{s}_{t+1},z)] \\&\quad - D_{\text {KL}}[q_{\omega _1}(z\mid a_t,s_t,\hat{s}_{t+1})\mid \mid p(z\mid s_t,\hat{s}_{t+1})], \end{aligned} \end{aligned}$$

where \(z \sim \mathcal {N}(0,1)\) is the prior for the latent variable z, and \(D_{\text {KL}}\) represents the KL-divergence (Kullback & Leibler, 1951; Kullback, 1997). To generate an action between two unconnected states, \(s_t \text { and }\hat{s}_{t+1}\), we use the decoder \(p_{\omega _2}\) to sample from \(p(a_t\mid s_t,\hat{s}_{t+1})\). This process ensures that a highly plausible action is generated conditional on \(s_t\) and \(\hat{s}_{t+1}\).
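The inverse dynamics model might be implemented along the lines of the sketch below; the latent dimension, the KL weight and the choice of decoding the prior mean to obtain the connecting action are illustrative assumptions rather than the paper's exact settings.

```python
# Sketch of a CVAE inverse dynamics model p(a_t | s_t, s_{t+1}); sizes, the KL
# weight and the prior-mean decoding are illustrative assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F


class InverseDynamicsCVAE(nn.Module):
    def __init__(self, state_dim, action_dim, z_dim=16, hidden=512):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(2 * state_dim + action_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, z_dim)
        self.log_var = nn.Linear(hidden, z_dim)
        self.decoder = nn.Sequential(
            nn.Linear(2 * state_dim + z_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, action_dim), nn.Tanh())   # actions assumed in [-1, 1]
        self.z_dim = z_dim

    def elbo_loss(self, s, a, s_next, beta=0.5):
        """Negative variational lower bound: reconstruction plus weighted KL term."""
        h = self.encoder(torch.cat([s, a, s_next], dim=-1))
        mu, log_var = self.mu(h), self.log_var(h)
        z = mu + torch.randn_like(mu) * (0.5 * log_var).exp()   # reparameterisation
        a_hat = self.decoder(torch.cat([s, s_next, z], dim=-1))
        recon = F.mse_loss(a_hat, a)
        kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).mean()
        return recon + beta * kl

    @torch.no_grad()
    def connecting_action(self, s, s_next):
        """Decode the prior mean (z = 0) as a highly plausible connecting action."""
        z = torch.zeros(s.shape[0], self.z_dim, device=s.device)
        return self.decoder(torch.cat([s, s_next, z], dim=-1))
```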

Algorithm 1 Model-Based Trajectory Stitching (the complete procedure is detailed in the Appendix)

4 Experimental results

In this section, we first investigate whether MBTS can improve the quality of existing datasets for the purpose of inferring decision-making policies through BC in an offline fashion, without collecting any more data from the environment. Furthermore, we show that MBTS can help existing methods that explicitly use a BC term for offline learning to achieve higher performance. Specifically, we explore the use of MBTS in combination with two algorithms: model-based offline planning (MBOP) (Argenson & Dulac-Arnold, 2020), which uses an explicit BC policy to select new actions, and TD3+BC (Fujimoto & Gu, 2021), which has an explicit BC policy constraint. Our experiments rely on the D4RL datasets, a collection of commonly used benchmarking tasks, and include comparisons with selected offline RL methods. These comparisons provide an insight into the potential gains that can be achieved when MBTS is combined with BC-based algorithms, which often reach or even improve upon current state-of-the-art performance levels in offline RL. In Section 4.2, we show empirically that, even with a small amount of expert data, the MBTS+BC policies become closer to the expert policy in KL-divergence. In all experiments, we run MBTS for five iterations; these have been found to be sufficient to increase the quality of the data without being overly computationally expensive (Section 4.3). Finally, we provide ablation studies on the choice of reward model, as well as on an alternative policy-extraction approach to plain BC.

Fig. 2

Comparative performance of BC and MBTS+BC as the fraction of expert trajectories increases up to \(40\%\). For two environments, Hopper (left) and Walker2d (right), we report the average return of 10 trajectory evaluations of the best checkpoint during BC training. BC has been trained over 5 random seeds and MBTS has produced 3 datasets over different random seeds

4.1 Performance assessment on D4RL data

We compare our MBTS method on the D4RL (Fu et al., 2020) benchmarking datasets of the OpenAI Gym MuJoCo tasks. Three complex continuous-control environments are tested - Hopper, Halfcheetah and Walker2d - each with datasets of different quality levels. The “medium” datasets were gathered by the original authors using a single policy produced from the early-stopping of an agent trained by soft actor-critic (SAC) (Haarnoja et al., 2018a, b). The “medium-replay” datasets are the replay buffers from the training of the “medium” policies. The “expert” datasets were obtained from a policy trained to an expert level, and the “medium-expert” datasets are the combination of both the “medium” and “expert” datasets. A BC-cloned policy that used an MBTS dataset is denoted by MBTS+BC. All results and comparisons are summarised in Table 1; detailed explanations of our methods follow. We run MBTS for 3 different seeds, giving 3 datasets; we then train BC over 5 seeds for each new dataset, giving 15 MBTS+BC policies.

Table 1 Average normalised scores of state-of-the-art offline RL methods achieved on three locomotion tasks (Hopper, Halfcheetah and Walker2d) using the D4RL v2 data sets

4.1.1 Behaviour cloning: MBTS+BC

The first method we investigate in combination with MBTS on the D4RL datasets is BC. Enriching the dataset with more high-value transitions and removing low-quality ones leaves the dataset with closer-to-expert trajectories, making BC the most suitable policy-extraction algorithm. From Table 1 we can see that MBTS+BC improves over BC in all cases, showing that MBTS creates a higher-quality dataset, as claimed.

4.1.2 Model-based offline planning: MBTS+MBOP

Given the evidence presented above that MBTS improves over BC, a natural next step is to investigate whether MBTS can also improve other methods that rely on BC. Model-based offline planning (MBOP) (Argenson & Dulac-Arnold, 2020) is an offline model-based planning method that uses a BC policy to roll out multiple trajectories, picking the action that leads to the trajectory with the highest returns. For this study, we alter MBOP slightly to obtain MBTS+MBOP: in this version, actions are selected using our MBTS-extracted policy and we use our trained value function.

As can be observed in Table 1, MBTS+MBOP improves over the MBOP baseline in all cases. We also compare MBTS+MBOP to state-of-the-art model-based algorithms such as MOPO (Yu et al., 2020), MOReL (Kidambi et al., 2020) and Diffuser (Janner et al., 2022); in these comparisons, MBTS+MBOP achieves higher performance in 5 out of the 9 comparable tasks. Only in the Hopper medium and medium-replay tasks does another model-based method outperform MBTS+MBOP.

4.1.3 Model-free offline RL: TD3+MBTS+BC

We also investigate the benefits of using MBTS in conjunction with a model-free offline RL algorithm. TD3+BC (Fujimoto & Gu, 2021) explicitly uses BC in the policy improvement step to regularise the policy to take actions close to the dataset. As MBTS removes low-quality data, the learned Q-values will be inaccurate when trained solely on the new MBTS data. To counter this, we warm-start TD3+BC on the original dataset, then use the new MBTS data to fine-tune both the critic and actor after the Q-values have been sufficiently trained. To keep the comparison fair, we train the policy over the same number of iterations as reported in Fujimoto and Gu (2021). We make one small amendment for the Walker2d medium-replay dataset, where we train the critic only using the original data and use the MBTS data only to fine-tune the policy. We run TD3+MBTS+BC on the same 5 seeds as reported in the original work.

As reported in Table 1, we find that, in all cases, TD3+MBTS+BC outperforms the baseline method, thus solidifying the positive effect of MBTS in offline RL. For this comparison, we also consider two additional state-of-the-art model-free offline RL algorithms: IQL (Kostrikov et al., 2021) and CQL (Kumar et al., 2020). In 6 out of the 9 comparable tasks, TD3+MBTS+BC significantly improves over the model-free baselines. In the Hopper medium-replay task, we find that TD3+MBTS+BC under-performs compared to the other model-free methods (IQL and CQL).

Fig. 3

Estimated KL-divergence and MSE of the BC and MBTS+BC policies on the Hopper and Walker2d environments as the fraction of expert trajectories increases. (Left) Relative difference between the KL-divergence of the BC policy and the expert and the KL-divergence of the MBTS+BC policy and the expert. Larger values represent the MBTS+BC policy being closer to the expert than the BC policy. MSE between actions evaluated from the expert policy and the learned policy on states from the Hopper (Middle) and Walker2d (Right) environments. The y-axes (Middle and Right) are on a log-scale. All policies were collected by training BC over 5 random seeds, with MBTS being evaluated over 3 different random seeds. All KL-divergences were scaled between 0 and 1, depending on the minimum and maximum values per task, before the difference was taken

4.2 Expected performance on sub-optimal data

It is well known that BC minimises the KL-divergence between the trajectory distributions of the learned policy and \(\pi _{\beta }\) (Ke et al., 2020). As MBTS has the effect of improving \(\pi _{\beta }\), this suggests that the KL-divergence between the trajectory distributions of the learned policy and the expert policy should be smaller after applying MBTS. To investigate this hypothesis, we use two complex locomotion tasks, Hopper and Walker2d, in OpenAI Gym (Brockman et al., 2016). Independently for each task, we first train an expert policy, \(\pi ^*\), with TD3 (Fujimoto et al., 2018), and use this policy to generate a baseline noisy dataset by sampling the expert policy in the environment and adding white noise to the actions, i.e. \(a = \pi ^*(s) + \epsilon\), which gives us a stochastic behaviour policy. A range of different, sub-optimal datasets are created by adding a certain amount of expert trajectories to the noisy dataset so that they make up \(x\%\) of the total trajectories. Using this procedure, we create eight different datasets by controlling x, which takes values in the set \(\{0, 0.1, 2.5, 5, 10, 20, 30, 40\}\). BC is run on each dataset for 5 random seeds. We run MBTS (for five iterations) on each dataset over three different random seeds and then create BC policies over the 5 random seeds, giving 15 MBTS+BC policies. Random seeds cause different MBTS trajectories as they affect the latent variables sampled for the reward function and inverse dynamics model. Also, the initialisation of weights is randomised for the value function and BC policies; hence, the robustness of the methods is tested over multiple seeds. The KL-divergences are calculated following Ke et al. (2020) as

$$\begin{aligned} D_{KL}(p_{\pi ^*}(\mathcal {T}), p_{\pi }(\mathcal {T})) = \mathbb {E}_{s \sim p_{\pi ^*}, a \sim \pi ^*(s)}[\log \pi ^*(a \mid s) - \log \pi (a\mid s)]. \end{aligned}$$
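This expectation can be approximated by Monte Carlo over expert roll-outs, along the lines of the sketch below; expert_log_prob and learned_log_prob are hypothetical callables returning \(\log \pi(a \mid s)\) for the two policies, which assumes both policies expose a tractable density.

```python
# Monte-Carlo estimate of the KL term above; the two log-density callables are
# hypothetical and assume both policies expose a tractable density.
import torch


def kl_trajectory_estimate(expert_log_prob, learned_log_prob,
                           expert_states, expert_actions):
    with torch.no_grad():
        diff = (expert_log_prob(expert_states, expert_actions)
                - learned_log_prob(expert_states, expert_actions))
    return diff.mean().item()
```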

Figure 2 shows the scores as average returns from 10 trajectory evaluations of the learned policies. MBTS+BC consistently improves on BC across all levels of expertise for both the Hopper and Walker2d environments. As the percentage of expert data increases, MBTS is able to leverage more high-value transitions, consistently improving over the BC baseline. Fig. 3 (left) shows the average difference in KL-divergences of the BC and MBTS+BC policies against the expert policy. Precisely, the y-axis represents \(D_{KL}(p_{\pi ^*}(\mathcal {T}), p_{\pi ^{\text {BC}}}(\mathcal {T})) - D_{KL}(p_{\pi ^*}(\mathcal {T}), p_{\pi ^{\text { MBTS+BC}}}(\mathcal {T}))\), where \(p_{\pi }(\mathcal {T})\) is the trajectory distribution for policy \(\pi\), Eq. (1). A positive value represents the MBTS+BC policy being closer to the expert, and a negative value represents the BC policy being closer to the expert, with the absolute value representing the degree to which this is the case. We also scale the average KL-divergence between 0 and 1, where 0 is the smallest KL-divergence and 1 is the largest, per task. This makes the scale comparable between Hopper and Walker2d. The figure shows that BC can extract a behaviour policy closer to the expert after performing MBTS on the dataset, except in the \(0\%\) case for Walker2d; however, the difference is not significant. MBTS seems to work particularly well with a minimum of \(2.5\%\) expert data for Hopper and \(0.1\%\) for Walker2d.

Furthermore, Fig. 3 (middle and right) shows the mean squared error (MSE) between actions from the expert policy and the learned policy for the Hopper (middle) and Walker2d (right) tasks. Actions are selected by collecting 10 trajectory evaluations of an expert policy. As we expect, the MBTS+BC policies produce actions closer to the expert’s at most levels of dataset expertise. A surprising result is that, for \(0\%\) expert data on the Walker2d environment, the BC policy produces actions closer to the expert than the MBTS+BC policy. This is likely due to MBTS not having any expert data to leverage. However, even in this case, MBTS still produces a higher-quality dataset than the original, as shown by the increase in average returns. Overall, these results offer empirical confirmation that MBTS does have the effect of improving the underlying behaviour policy of the dataset.

4.3 On the number of MBTS iterations

We investigate empirically how the quality of the dataset improves after each iteration; see Definition 3. We repeat MBTS on each D4RL dataset, each time using a newly estimated value function to take into account the newly generated transitions. In all our experiments, we choose 5 iterations. Figure 4 shows the scores of the D4RL environments over the different iterations, with the standard deviation across seeds shown as the error bar. Iteration 0 indicates the BC score obtained on the original D4RL datasets. For all datasets, we observe that the average scores of BC increase initially over a few iterations, then remain stable with only some minor random fluctuations. We see less improvement in the expert datasets as there are fewer trajectory improvements to be made. Conversely, for the medium-expert datasets more iterations are required to reach an improved performance. For Hopper and Walker2d medium-replay, there is a higher degree of standard deviation across the seeds, which gives a less stable average as the number of iterations increases.

Fig. 4

Returns of BC-extracted policies as the number of MBTS iterations is increased. Iteration 0 shows the BC scores on the original D4RL datasets. The error bars represent the standard deviation of the average returns of 10 trajectory evaluations over 5 random seeds of BC and 3 random seeds of MBTS

4.4 Ablation studies

In this Section we perform ablation studies to assess the impact of the reward model on MBTS performance and the effect of value-weighted BC.

4.4.1 Choice of reward model

MBTS requires a predictive model for the rewards associated with the stitched transitions, enabling a value function to be learned on the new dataset. Unlike some online methods (Chua et al., 2018; Nagabandi et al., 2018), we do not have access to the true reward function during training time, and so a model must be trained to predict rewards. There are many choices of models. For example, MBPO (Janner et al., 2019), MOPO (Yu et al., 2020) and MBOP (Argenson & Dulac-Arnold, 2020) use a neural network that outputs the parameters of a Gaussian distribution, jointly predicting the next state and the reward. We solely want to predict the reward and consider the following options: a Gaussian distribution whose parameters are modelled by a neural network, a Wasserstein-GAN, a VAE, and a multilayer neural network (MLP) that minimises the mean squared error between true and predicted rewards.

We evaluate the reward models on the D4RL hopper-medium dataset and perform a 95:5 training and test split. To make the comparison fair, all models are trained on the same training data and each model has two hidden layers of dimension 512. Fig. 5 shows the mean squared error (MSE) between predicted and true rewards during training, on both the training and test sets. Clearly, the VAE and MLP models perform best, attaining training and test errors of order \(10^{-5}\). The average reward for a transition in the hopper-medium dataset is 3.11, so in fact the GAN also performs very well, attaining a training and test error of order \(10^{-4}\).

Fig. 5

Assessment of different types of models to predict the reward on the hopper-medium D4RL dataset. The MSE between predicted and true rewards is assessed during training on a test set and a training set of the same size

In MBTS we want to predict a reward for an unseen transition, where s and \(s'\) are in the dataset but have never been connected by an observed action. Therefore, we evaluate the trained reward models on unseen data to test their OOD performance. Table 2 shows the MSE between predicted and true rewards of the models on the remaining D4RL hopper datasets: random, expert and medium-replay. The GAN, VAE and MLP perform very similarly, achieving accurate predictions on all three datasets. The VAE and MLP outperform the GAN in predicting rewards of the expert dataset. The Gaussian model performed very poorly on these datasets.

Table 2 MSE between true and predicted rewards from the reward functions evaluated on the other D4RL hopper datasets

Finally, we compare MBTS(WGAN)+BC with MBTS(MLP)+BC on the D4RL datasets; here, either a WGAN or an MLP is used to predict the reward. Table 3 shows that the choice between a WGAN and an MLP is insignificant, as both are sufficiently accurate at predicting rewards.

Table 3 Comparison of BC, MBTS(WGAN)+BC and MBTS(MLP)+BC on the D4RL locomotion tasks

4.4.2 Value-weighted BC

MBTS uses a value function to estimate the future returns from any given state. Therefore, MBTS+BC has a natural advantage over plain BC, which uses only the states and actions. To test whether using a value function alone is sufficient to improve the performance of BC, we investigate a weighted version of the BC loss function whereby the weights are given by the estimated value function, i.e.

$$\begin{aligned} \pi ^{\text {BC}}(s) = \mathop {\mathrm {arg\,min}}\limits _{\pi } \mathbb {E}_{s, a \sim \mathcal {D}}[V_{\theta }(s)(\pi (s) - a)^2]. \end{aligned}$$

This weighted-BC method gives larger weight to the high-value states and lower weight to the low-value states during training.
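A sketch of this weighted objective follows; how (or whether) the value weights are normalised or shifted to remain positive is an implementation detail not specified here.

```python
# Sketch of the value-weighted BC loss used in the ablation; weight normalisation
# or shifting (to keep weights positive) is an unspecified implementation choice.
import torch


def weighted_bc_loss(policy, value_fn, states, actions):
    with torch.no_grad():
        w = value_fn(states)                        # V_theta(s) as per-sample weight
    return (w * ((policy(states) - actions) ** 2).sum(dim=-1)).mean()
```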

On the Hopper medium and medium-expert datasets, training this weighted-BC method gives only a slight improvement over the original BC-cloned policy. For Hopper medium, weighted-BC achieves an average score of 59.21 (with standard deviation 3.4); this is an improvement over BC (55.3), but lower than MBTS+BC (64.3). Weighted-BC on Hopper medium-expert achieves an average score of 66.02 (with standard deviation 6.9); again, this is a slight improvement over BC (62.3), but significantly lower than MBTS+BC (94.8). These experiments indicate that using a value function to weight the relative importance of seen states when optimising the BC objective is not sufficient to achieve the performance gains introduced by MBTS.

5 Discussion

The proposed method, MBTS, has been presented for learning an optimal policy in continuous state and action spaces. For other domains, alternative modelling techniques for the dynamics models should be considered. As demonstrated in Section 4.2, MBTS is expected to be beneficial in settings with sub-optimal data, even with a small percentage of additional expert data. Notably, MBTS does not require expert transitions within the initial dataset, as evidenced by the nearly \(80\%\) improvement on Walker2d medium-replay, which contains no expert data.

Empirically, MBTS does not damage performance; however, this does not guarantee that performance will not decrease under all circumstances. We believe that the primary risk of MBTS harming performance lies in the use of imperfect models. To mitigate model error in the forward models, we employ an ensemble and take a conservative approach to determining the next state, as described in Section 3.3. To further account for model error, we only replace existing trajectories with new ones if they significantly increase returns, as per Definition 2. Due to Definition 2, we believe that the MBTS procedure is consistently safe when replacing existing trajectories, as it guarantees performance improvement. However, this may result in the loss of potentially useful information that MBTS can no longer use in future iterations.

Our method might be perceived as computationally intensive due to the number of models and the need to iterate over the entire dataset. However, we use our models with their limitations in mind, avoiding extrapolation to unseen states and generating actions only between in-distribution reachable states. The primary computational burden comes from searching for reachable next states. We have reduced this burden by evaluating the forward model only on “reasonably close” states using a nearest-neighbour search organised by a KD-tree. KD-trees are well studied and have a worst-case time complexity of \(\mathcal {O}(k\cdot n^{1-\frac{1}{k}})\) for k-dimensional trees and n data points (Lee & Wong, 1977). Consequently, we assume that states far apart in Euclidean distance are not reachable and do not require evaluation. Our approach is detailed in the Appendix; however, alternative methods could be employed to reduce complexity, and the technique we used is not strictly integral to the MBTS framework. Moreover, our method converges in remarkably few iterations (at most 5), as shown in Fig. 4, significantly reducing the computational cost.

6 Conclusion

In this paper, we have proposed an iterative data improvement strategy, Model-Based Trajectory Stitching, which can be applied to historical datasets containing demonstrations of sequential decisions taken to solve a complex task. At each iteration, MBTS performs one-step stitching between reachable states within the dataset that lead to higher future expected returns. We have demonstrated that, without further interactions with the environment, MBTS improves the quality of the historical demonstrations, which in turn has the effect of boosting the performance of BC-extracted policies significantly. Extensive experimental results using the D4RL benchmarking data have demonstrated that MBTS always improves the underlying behaviour policy. We have also demonstrated that MBTS is beneficial beyond BC, when combined with existing offline reinforcement learning methods. In particular, MBTS can be used to extract an improved explicit BC-based regulariser for TD3+BC, as well as an improved BC prior for offline model-based planning (MBOP). MBTS-based methods achieve state-of-the-art results in 10 out of the 12 D4RL datasets considered.

We believe that this work opens up a number of directions for future investigation. For example, MBTS could be extended to multi-agent offline policy learning by reformulating Eq. (2) to account for actions taken by multiple agents. Beyond the realm of offline RL, MBTS may also be useful for learning with sub-optimal demonstrations, e.g. by inferring a reward function through inverse RL. Historical demonstrations can also be used to guide RL and improve the data efficiency of online RL (Hester et al., 2018). In these cases, BC can be used to initialise or regularise the training policy (Rajeswaran et al., 2017; Nair et al., 2018).