1 Introduction

Deformable image registration is an integral part of medical image processing. With the advances in machine learning in recent years, deep learning-based deformable registration has become increasingly prevalent, and image registration algorithms can broadly be divided into approaches based on conventional optimisation, deep learning-based approaches, and those that combine both. Intra-patient lung registration is a common medical registration task with clinical applications such as regional COPD analysis or nodule tracking. As the Learn2Reg 2021 challenge [10] has shown, methods that rely solely on deep learning do not yet achieve state-of-the-art performance for intra-patient lung registration: the four best submissions to the challenge were either purely based on conventional optimisation or employed instance optimisation after applying a deep learning model. Deep learning-based submissions to the Learn2Reg challenge aimed to predict the displacement directly from fixed and moving image features through feedforward networks. We propose a learn-to-optimise (L2O) network that takes a different approach to deformable image registration by employing a recurrent network to emulate instance optimisation.

1.1 Related Work

Neural networks, especially recurrent ones, have previously been used to mimic and improve upon the procedure of optimisation.

Andrychowicz et al. [1] discretely parameterised an optimiser and, based on that, employed a recurrent network to learn an optimiser function. Previous works have emulated optimisation specifically for image registration: Teed and Deng [23] introduced the RAFT architecture, which uses a recurrent network to iteratively update a displacement field by sampling from a (hypothetical) dense correlation volume. They achieved state-of-the-art performance on KITTI and demonstrated that their network is highly efficient and generalises well. Liu et al. [14] used a recurrent multi-scale network to solve a diffeomorphic registration task with geometric constraints. While the spatial transformer loss [11] is widely employed for unsupervised (metric-based) learning of feedforward registration networks [6], its gradients may provide limited guidance for complex registration tasks, which may explain some of the limitations of DL-based methods in lung registration. The work in [12] aims at improving gradient estimation through linearised multi-sampling but does not employ a trainable optimiser.

Recurrent networks have also been successfully applied to other medical image registration tasks: Lu et al. [15] used recurrent networks to incorporate the temporal aspect of 4DCT data. Sandkühler et al. [20] calculated a sequence of local displacements with a recurrent network, which they then composed to obtain the final global displacement. Sun et al. [22] used recurrent reinforcement learning to achieve robustness in multimodal brain image registration. These methods differ from our work in that they either directly rely on multiple frames of an input sequence, are limited to 2D data, or realise a cascade of multiple feedforward networks that do not resemble gradient-based optimisation of a cost function.

1.2 Adam Optimisation

In image registration we aim to compute a displacement field \(\varphi \) that minimises an objective function \(f(\varphi )\) that usually consists of a distance metric of a warped moving image \(M\circ \varphi \) and a fixed image F along with a regularisation for \(\varphi \).
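One common instantiation of this objective, stated here as an illustrative assumption (using the MIND-based sum of squared distances that appears later in Sect. 2.2, and a diffusion regulariser as one typical choice), is:

```latex
f(\varphi) \;=\;
\underbrace{\sum_{x} \bigl\| \mathrm{MIND}(F)(x) - \mathrm{MIND}(M)\bigl(x + \varphi(x)\bigr) \bigr\|_2^2}_{\text{dissimilarity } D(F,\,M\circ\varphi)}
\;+\; \alpha
\underbrace{\sum_{x} \bigl\| \nabla \varphi(x) \bigr\|_2^2}_{\text{regularisation } R(\varphi)}
```

Here \(\alpha\) trades off image similarity against smoothness of the displacement field.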

In iterative optimisation, the displacement field is updated in every iteration. In conventional first-order gradient descent, the update depends on the gradient of f. For high-dimensional optimisation problems such as image registration, it can be beneficial to approximate the gradient \(\nabla f\) with a transformation that comprises a smaller number of control points, e.g. a free-form deformation. Stochastic gradient descent can speed up convergence by adding a factor of randomness through sampling [21].
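The first-order update can be sketched on a toy 1-D problem; this is a minimal illustration with hand-picked profiles and step size, not the paper's 3-D pipeline (which would additionally use MIND features and control-point parameterisation):

```python
import numpy as np

# Toy 1-D registration: the "moving" profile must be shifted onto the "fixed" one.
x = np.linspace(0.0, 1.0, 64)
fixed = np.exp(-((x - 0.5) ** 2) / 0.01)
moving = np.exp(-((x - 0.4) ** 2) / 0.01)

def warp(img, phi):
    """Resample `img` at the displaced coordinates x + phi (linear interpolation)."""
    return np.interp(x + phi, x, img)

def ssd(a, b):
    """Sum of squared differences, a stand-in for the MIND-based dissimilarity."""
    return float(np.sum((a - b) ** 2))

phi = np.zeros_like(x)                # dense displacement field, one value per voxel
grad_moving = np.gradient(moving, x)  # spatial image gradient, reused every iteration
initial_cost = ssd(fixed, warp(moving, phi))

lr = 1e-3  # step size chosen by hand for this toy problem
for _ in range(200):
    warped = warp(moving, phi)
    # chain rule: derivative of (warped_i - fixed_i)^2 with respect to phi_i
    grad = 2.0 * (warped - fixed) * warp(grad_moving, phi)
    phi -= lr * grad

final_cost = ssd(fixed, warp(moving, phi))
```

Replacing the dense `phi` with a coarse control-point grid interpolated to full resolution would give the free-form deformation variant mentioned above.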

In Adam optimisation [13], the update rule is extended by including estimates of the first and second moments of the gradients, so that each update is determined by the gradient of the current sample as well as the previous update. Each update is thus dependent on all previous updates.
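The Adam update that our learned optimiser mimics can be written compactly; the sketch below follows the standard formulation of [13] applied to a displacement field, with default hyperparameters as an assumption:

```python
import numpy as np

def adam_step(phi, grad, state, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update of a displacement field `phi` given gradient `grad`.

    `state` carries exponential moving averages of the first and second
    moments across iterations, so every update depends on all previous ones.
    """
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad       # first moment
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad ** 2  # second moment
    m_hat = state["m"] / (1 - beta1 ** state["t"])             # bias correction
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    return phi - lr * m_hat / (np.sqrt(v_hat) + eps)

phi = np.zeros(3)
state = {"t": 0, "m": np.zeros(3), "v": np.zeros(3)}
grad = np.array([1.0, -2.0, 0.5])
phi = adam_step(phi, grad, state)  # first step is approximately -lr * sign(grad)
```

Because of the bias-corrected second moment in the denominator, the very first step has magnitude close to `lr` in every component regardless of the gradient scale.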

1.3 Our Contribution

We propose a novel recurrent framework using an iterative dynamic cost sampling step and a trainable optimiser that is specifically aimed at mimicking Adam optimisation but can substantially reduce the required number of iterations. The sampled features are identical to the minimal information required to estimate gradients of the dissimilarity between an image pair in one iteration of conventional displacement optimisation. Specifically, image features are only provided to the network indirectly in the form of displacement costs. In addition, hidden state features give the network the information necessary to calculate gradient momentum, as Adam optimisation does. Based on this information, the optimiser network can learn larger gradient steps (leap-frogging) through deep (multi-step) supervision with keypoint correspondences, which also promotes nearly equal-sized optimisation steps of the displacement field with each iteration.

2 Methods

2.1 Pre-registration

In order to evaluate our method as an alternative to instance optimisation, we first employ a recent feedforward (non-recurrent) learning-based deformable registration method called VoxelMorph++ (VM++) [8]. For VM++, VoxelMorph [2] is adapted to the lung registration task by training it with keypoint correspondences automatically extracted by an accurate but long-running conventional registration algorithm (corrField [7]). The network is intended to be fine-tuned with instance optimisation.

2.2 Extraction of Optimisation Inputs

We extract a total of 45 optimisation inputs for each image voxel/control point. Three of these channels are the current predicted displacements \(\varphi ^{t}\); for the first iteration they are initialised either with the displacement field obtained by VoxelMorph++ or with an identity transform. For each voxel, eight displacement coordinates are computed using a fixed grid of subpixel offsets, resulting in 24 feature channels. For each sampled coordinate a dissimilarity cost (i.e. 8 features in total) is calculated and also fed into the network. To determine the cost term, MIND features [9] of the moving image are sampled at the displacement coordinates, and for each voxel and displacement the dissimilarity to the fixed image, i.e. the sum of squared distances to the fixed features, is calculated. Displacements and coordinates can be used to approximate the gradients for an iterative gradient descent, following e.g. [18]. The remaining 10 feature channels are hidden states that are propagated through all iterations; through these, the model can store information about previous update steps and incorporate it similarly to how Adam uses momenta. All features are updated with each recurrent application of the network, i.e. the coordinates and dissimilarity costs change dynamically across recurrent states and mimic the iterative fashion of conventional registration.
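The channel bookkeeping can be sketched as follows; the offsets and costs below are random placeholders (the actual method uses a fixed offset grid and MIND-based dissimilarities), and only the 3 + 24 + 8 + 10 = 45 layout is taken from the text:

```python
import numpy as np

rng = np.random.default_rng(0)

# toy dimensions: flattened control points, K = 8 sampled offsets per point
n_points, K, hidden_dim = 1000, 8, 10

phi = np.zeros((n_points, 3))            # current displacement (3 channels)
offsets = rng.normal(size=(K, 3)) * 0.5  # placeholder for the fixed subpixel offset grid
coords = phi[:, None, :] + offsets[None, :, :]  # 8 candidate displacements -> 24 channels
costs = rng.random((n_points, K))        # placeholder per-candidate dissimilarity (8 channels)
hidden = np.zeros((n_points, hidden_dim))  # hidden state propagated across iterations (10 channels)

features = np.concatenate(
    [phi, coords.reshape(n_points, -1), costs, hidden], axis=1
)
# 3 + 24 + 8 + 10 = 45 input channels per voxel/control point
```

At every recurrent step, `phi`, `coords`, `costs` and `hidden` would all be recomputed or updated before the map is assembled again.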

2.3 Optimiser Network

The feature map is fed into a U-Net [19] architecture with roughly one million trainable parameters. The network output comprises 13 feature channels: three form the predicted displacements, while the remaining ten represent the hidden states of the network. During training, we apply eight recurrent iterations of the network, for which we employ deep keypoint supervision. We weight the loss linearly more strongly with each iteration to encourage the network to get increasingly closer to the optimal value at each step, aiming at a behaviour similar to gradient descent. All weights are shared across iterations, which means the number of iterations does not have to be fixed and can be varied at inference time. A schematic illustration of our method is depicted in Fig. 1.
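The recurrent loop with linearly weighted deep supervision can be sketched as follows; `toy_unet` is a hand-crafted stand-in for the real shared-weight U-Net (its dynamics, the targets and all dimensions except the 3 + 10 channel split are illustrative assumptions):

```python
import numpy as np

def toy_unet(features):
    """Stand-in for the shared-weight U-Net: moves the displacement estimate
    halfway towards the target and passes the hidden state through unchanged.
    (Toy dynamics only -- the real network learns this update.)"""
    phi, hidden = features[:, :3], features[:, 3:]
    return np.concatenate([phi + 0.5 * (1.0 - phi), hidden], axis=1)

n_points, n_iters = 100, 8
phi = np.zeros((n_points, 3))      # initial displacement (identity transform)
hidden = np.zeros((n_points, 10))  # hidden state, propagated across iterations
target = np.ones((n_points, 3))    # toy keypoint supervision target

weights = np.arange(1, n_iters + 1) / n_iters  # linearly increasing loss weights
losses = []
for t in range(n_iters):
    # the same network (same weights) is applied at every iteration
    out = toy_unet(np.concatenate([phi, hidden], axis=1))
    phi, hidden = out[:, :3], out[:, 3:]
    # deep supervision: every intermediate displacement contributes to the loss
    losses.append(weights[t] * np.mean((phi - target) ** 2))
total_loss = sum(losses)
```

Because the weights are shared, `n_iters` could be changed at inference time without retraining, as noted above.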

Fig. 1.

Schematic depiction of our learn to optimise network. The feature map used in our recurrent network consists of four types of features. Displacement Field: Last predicted displacement. Coordinates: For every voxel, a fixed number of possible displacements is sampled. Dissimilarities: Displacement cost for each of the sampled displacements. Hidden states: Part of the U-Net output that is not supervised.

2.4 Comparison to Feed-Forward Nets and Adam Optimisation

Our method differs from networks that aim to directly predict a displacement field from two input images. Most importantly, the moving and fixed image features themselves are not given directly as input to the network; instead, the network is only provided with dissimilarity costs and displacement coordinates. Moreover, the employed U-Net architecture has a small capacity with few parameters, which forces the model to generalise more strongly.

The method, however, also differs from Adam optimisation in some important aspects. Since the hidden states are not supervised, the model does not necessarily make direct use of the first and second moments but could, during training, learn more useful nonlinear relations between previous updates, current coordinates, dissimilarity costs and the desired displacement update. Additionally, while Adam optimisation only has information about the current image pair it is optimising, our model can exploit further information about the image population, e.g. keypoint positions across multiple patients, through training.

3 Experiments and Results

Table 1. Mean target registration errors in mm of our and comparison methods on the 4DCT and COPD datasets.

We train our network on the public EMPIRE10 (selected cases), DIR-Lab COPD and DIR-Lab 4DCT datasets [3,4,5, 17], which contain a total of 28 pairs of ex- and inspiratory lung CT scans. Experiments were conducted in a 5-fold cross-validation. We evaluate our method on the public COPD and 4DCT datasets. Each registration pair is annotated with 300 manually placed corresponding landmarks, which are used to assess the registration accuracy. We compare our learning-based L2O approach with continuous Adam-based optimisation in two settings: 1) on the original inhale and exhale scans and 2) as an instance optimisation step for image pairs that are pre-registered using the widely used VoxelMorph framework (in the VM++ variant described above). LapIRN [16], as winner of the recent Learn2Reg challenge [10], was chosen as a comparison method representing the deep learning-based state of the art. It was evaluated with and without subsequent instance optimisation using Adam. Our proposed optimisation framework is itself analysed in ablation studies investigating the effect of different numbers of optimisation iterations and the use of deep supervision.

Fig. 2.

Left: Residual target registration error (TRE) for Adam and our proposed L2O approach (with and without pre-registration) after 1, 2, 4 and 8 iterations on case #5 of the COPD dataset. Warm and cold colours correspond to high and low errors, respectively. Right: Moving, warped and fixed image (masked around the lung) of the same case for L2O with pre-registration.

Implementation Details: All experiments were conducted on an Nvidia RTX A4000 using PyTorch 1.10. The 850 employed epochs of training for our L2O framework need 8.2 h per fold and 7 GB GPU memory. Hyperparameters were optimised on a single training fold and kept fixed for all further experiments.

General Results: Quantitative results of our and comparison methods can be found in Table 1. Our learned optimisation model reduces the TRE from 6.33 mm to 1.69 mm and from 11.99 mm to 2.24 mm for the 4DCT and COPD datasets, respectively, whereas VM++ alone, which serves as pre-registration in our experiments, only achieves TREs of 4.40 mm and 5.30 mm. Our method clearly outperforms Adam in settings without pre-registration, while yielding lower accuracy on the 4DCT dataset with pre-registration, where Adam reaches a target registration error of 1.33 mm, and achieving on-par results for the COPD dataset with pre-registration (2.18 mm vs 2.24 mm). However, the log-Jacobian determinants of the displacement fields yield mean standard deviations of 0.04 for Adam and 0.02 for L2O, as well as a significantly lower number of foldings for L2O, indicating that our optimisation approach produces smoother and less complex deformations. In addition, the results reported for L2O are already achieved after 8 optimisation steps, whereas Adam needed at least 50 iterations to converge. The comparison with the state of the art (LapIRN + Adam) is also favourable: while the accuracy is roughly comparable on the 4DCT dataset, the TRE is almost halved on the COPD dataset (3.83 mm versus 2.24 mm). A visual comparison of individual optimisation steps of Adam and L2O in Fig. 2 provides further insights. The spatial residual error of the target registration is shown after 1, 2, 4 and 8 iterations. The iterative improvement is clearly visible for both methods; however, Adam gets stuck in local minima at the lung surfaces. The faster convergence of L2O is also evident: the biggest improvements can already be seen in the first two iterations. When starting from a pre-registered scan pair, the improvement of the registration errors is particularly evident in the boundary regions.

Ablation Studies: Figure 3 shows target registration errors over the course of the individual optimisation steps. L2O shows the fastest convergence in all settings and is only outperformed by Adam after 16 and 24 iterations on the 4DCT and COPD datasets with pre-registration, respectively. It should be noted that the number of iterations was only altered during inference; it is thus likely that the results for L2O can be further improved when more optimisation steps are considered during training (which may, of course, come at the cost of longer training times). Finally, deep supervision appears to be of great importance for learning a meaningful optimisation step at each iteration: omitting it increases the TRE by 0.5 mm and 1 mm for 4DCT and COPD, respectively.

Fig. 3.

Comparison of target registration error (TRE) between Adam and our L2O approach depending on the number of iterations. We present results with and without pre-registration (using VM++).

4 Discussion

We proposed a framework that learns gradient-based optimisation steps for deformable image registration using a recurrent deep learning network. The structure of our model as well as the evaluation results indicate that our model emulates an optimisation rather than directly estimating a displacement field for given input images.

The network uses fewer iterations than Adam optimisation, indicating that it does not merely approximate first-order gradients and imitate gradient descent, but rather learns how to optimise using all available information, thereby overcoming this limitation of gradient descent algorithms.

Since the registration error slightly increases at inference time when using a large number of iterations, we hypothesise that our model learns typical properties of the given image pairs related to optimal alignment, e.g. the distances between landmarks in the training images, and optimises its internal parameters so that these properties are matched after executing the set number of iterations. In contrast to conventional optimisation, it thus makes use of the population-wide information provided during training. Additionally, our model is less dependent on an accurate pre-registration, demonstrating that it is also more robust with respect to local minima.

In conclusion, even though our model does not yet always surpass Adam instance optimisation, it combines the advantages of deep learning and gradient descent algorithms and yields promising results when it comes to compensating for the disadvantages of conventional optimisation.