1 Introduction

Traffic forecasting is important for location-based applications such as intelligent transportation systems and urban planning [11]. Accurate real-time traffic flow prediction can help traffic control improve traffic efficiency and reduce congestion. In particular, during peak periods and in accident-prone areas, accurate short-term traffic flow prediction can not only help travelers choose optimal routes, but also provide strong data support for managers to formulate effective control measures and thus reduce congestion.

In general, the traffic states or events of a spatial unit (e.g., a region or street) are not isolated but are influenced by its neighbors. This is a typical phenomenon of spatial dependency, which has been extensively considered in current traffic prediction studies [4, 21, 22, 51]. Spatial dependency can be expressed by the first law of geography: “Everything is related to everything else, but near things are more related than distant things” [42]. For instance, congestion may propagate from one road to another due to rush hour, unexpected accidents, or unreasonable traffic management. It is therefore vital to consider the spatial dependency among road segments. Historical traffic flow data are also interdependent in time. As illustrated in Fig. 1, there is apparent similarity between two adjacent weeks, and even between different days within the same week. However, dynamic changes in spatio-temporal characteristics are random and can occur at any time, which makes them difficult to capture. Accurate traffic forecasting is therefore challenging.

Fig. 1

Temporal distribution of Shenzhen taxi movements in two consecutive weeks: from 02/18/2019 (Monday) to 02/24/2019 (Sunday) and from 02/25/2019 (Monday) to 03/03/2019 (Sunday)

Numerous traffic prediction algorithms have been developed over the last few decades. Auto-regressive integrated moving average (ARIMA) methods [28, 45] exploit recurring patterns in historical time series. However, they require the data to be smooth and continuous, and their prediction accuracy is usually limited by the complex spatio-temporal characteristics of urban traffic. Machine learning models, such as k-nearest neighbors [56] and support vector regression [49], only need sufficient historical data to automatically establish nonlinear mappings between input and output. However, their prediction accuracy is also unsatisfactory because they struggle to fully capture complex nonlinear relationships. With the rise of deep neural networks (DNNs), many researchers have investigated DNNs for traffic prediction, using either CNNs [15, 57] to capture spatial dependencies or LSTMs [8] to learn temporal dependencies. However, the spatial topology of the traffic network may be lost when the data are represented as matrices or multidimensional tensors. Therefore, even though CNNs can extract spatial dependence, their effectiveness for traffic flow prediction is still limited [23]. LSTMs, in turn, are computationally slow because they cannot be trained in parallel across time steps, even though some acceleration methods [46] have been proposed.

To tackle these challenges, we propose STAtt-Net, a new deep learning framework for short-term traffic flow prediction. An attention mechanism is introduced in the framework to handle phenomena that are rarely considered in previous methods, such as spatial nonstationarity. Specifically, STAtt-Net incorporates a spatio-temporal attention module that captures global spatio-temporal dependencies by modeling region-to-region interactions. In addition, we fuse three components, namely i) trend for the weekly trend, ii) period for daily periodicity, and iii) closeness for recent temporal dependence, to capture temporal similarity more effectively.

The main contributions of our work are summarized as follows:

  • A novel model, the Spatio-Temporal Attention mechanism Network (STAtt-Net), for short-term traffic flow prediction, which effectively exploits both the dynamic temporal and the dynamic spatial dependencies in traffic.

  • An attention-based module (ST-Block) for traffic prediction, capable of dynamically modeling the association between any two locations in a city.

  • Extensive experiments with the proposed framework on three real-world datasets TaxiBJ, BikeNYC and TaxiSZ, with the experimental results revealing better performance of the proposed model over several state-of-the-art approaches.

2 Related work

2.1 Traffic flow prediction

Traffic flow prediction can be viewed as a spatio-temporal forecasting problem. Traditional methods for this problem usually build a time series model and exploit the information hidden in historical data for prediction. These methods can be categorized as parametric or nonparametric. Parametric approaches include the autoregressive integrated moving average (ARIMA) model [7, 44], Kalman filtering (KF) [32], the structural time-series model (STM) [10], and latent space models [6]. However, these models rely on stationarity assumptions for traffic time series and ignore spatial and temporal dynamics. To deal with the stochastic and nonlinear nature of traffic data, researchers have paid much attention to nonparametric approaches such as k-nearest neighbors (KNN) [3], support vector regression (SVR) [36], and random forests (RF) [2]. Unfortunately, most nonparametric approaches are limited in their ability to model complex, dynamic spatio-temporal dependencies. With the great success of deep learning in applications such as computer vision [1, 17, 18, 29, 31, 37, 40, 41], natural language processing [16], public health [30, 38, 47] and economics [19], recent studies have leveraged deep features learned by various neural networks to further improve prediction performance. In an early attempt, Huang et al. [13] used a deep belief network (DBN) with multitask learning for traffic prediction. Although capable of mining high-dimensional features from traffic data, it has difficulty extracting specific spatio-temporal features. Since RNNs are adept at capturing temporal correlations, it is unsurprising that many works in this area [8, 25] are built on RNNs and their variants, such as Long Short-Term Memory (LSTM) and the Gated Recurrent Unit (GRU).
For spatial dependency, CNNs were introduced in ST-ResNet [57] to capture spatial correlation, combined with residual units for citywide traffic forecasting. Traffic flows were treated as raster images to model temporal closeness, period, trend, and external factors. Modeling nonlinear and complex spatio-temporal data simultaneously then became a challenge. Shikhar et al. [34] employed 3D CNNs to recognize patterns in volumetric data such as videos, demonstrating the strengths of 3D CNNs. Building on this, [9] applied 3D CNNs to model spatio-temporal information automatically and thus improve prediction accuracy. However, the improvement is limited because spatio-temporal information is mined inefficiently. ConvLSTM [35] was proposed for spatio-temporal sequence forecasting, but its network structure is rather complex and becomes more difficult to train as the network deepens. A major problem with current CNN-based methods is that CNNs are suited to Euclidean data (such as images and regular grids) but cannot handle road networks with complex topology well. Zeng et al. [53] recently revisited the modifiable areal unit problem in deep traffic prediction and tried to address it with deformable convolutions in follow-up work [53].

This paper follows the work of Zeng et al. [53] in seeking to improve the performance of CNN-based models on complex road networks, but with a different strategy: introducing an attention mechanism.

2.2 Attention mechanism

Attention mechanisms are a popular recent topic and are widely used in deep learning tasks such as natural language processing [24], image classification [27], and machine translation [5]. An attention mechanism mimics the human brain's tendency to focus on things of interest and ignore low-value information. Essentially, it is a function that computes a probability distribution of attention to highlight the impact of key inputs on the output, thereby allocating information-processing resources efficiently. Mnih et al. [27] pioneered the use of attention mechanisms in image classification, combining them with recurrent neural networks. Non-local Networks [43] use self-attention as a non-local operation to capture long-range dependencies. Yan et al. [48] proposed LVSNet for liver vessel segmentation, which employs an attention-guided concatenation module to enhance segmentation details. Hu et al. [12] proposed the squeeze-and-excitation (SE) block, which explicitly models the interdependencies between feature maps and adaptively learns the importance of each feature map. For traffic flow prediction, [50] designed a spatial-temporal dynamic network with a periodic transfer attention mechanism to capture long-term periodic temporal similarity. Zheng et al. [58] proposed a graph multi-attention network with spatial and temporal attentions. However, this graph attention mechanism applies only to non-Euclidean (graph-structured) data.

We design and integrate a spatio-temporal attention block to capture the dynamic relevance within the traffic network in the spatial and temporal dimensions. Unlike Zheng et al. [58], our mechanism operates on a regular gridded map and thus handles Euclidean structured data.

3 Problem definition

Before going into the details of the method, we first introduce some definitions and formulate the problem in this section. The notations used are listed in Table 1.

Table 1 Meanings of all notations

Definition 1 (Movement)

A movement m is a continuously measured trajectory of a moving object during a time period \(\mathcal {T}\), defined by a set of spatio-temporal records \(\bigcup_{t \in \mathcal{T}} \langle t, l_{t} \rangle\), where \(l_{t}\) is the position of m at time t. We denote all movements of multiple moving objects as \(\mathcal{M}\).

Definition 2 (Region)

There are many ways to partition a geographical area into a collection of appropriate regions \(\mathcal {R} = \{r_{k}\}_{k=1}^{n}\). For example, a city can be partitioned into n = I × J equal-sized grids based on latitude and longitude, where each grid cell is regarded as an independent spatial unit. Alternatively, based on census blocks or the functions of different parts of the city, the area can be divided into non-overlapping regions called traffic analysis zones (TAZs) [26].

Definition 3 (Inflow & Outflow)

From the movements \(\mathcal{M}\), we compute the inflow \(x_{r_{i, j}, t}^{in}\) and outflow \(x_{r_{i, j}, t}^{out}\) per time slot t for each region \(r_{i, j} \in \mathcal {R}\) with the following formulas:

$$ \begin{array}{@{}rcl@{}} x_{r_{i, j}, t}^{i n} =\quad \mid\{m \in \mathcal{M} \mid m \cdot l_{t-1} \notin r_{i, j} \wedge m \cdot l_{t} \in r_{i, j}\}\mid \end{array} $$
(1)
$$ \begin{array}{@{}rcl@{}} x_{r_{i, j}, t}^{\text {out }}=\quad \mid\{m \in \mathcal{M} \mid m \cdot l_{t} \in r_{i, j} \wedge m \cdot l_{t+1} \notin r_{i, j}\}\mid \end{array} $$
(2)

where |⋅| denotes the cardinality of a set, and \(x_{r_{i, j}, t}^{in}\) and \(x_{r_{i, j}, t}^{out}\) denote the inflow and outflow of region \(r_{i,j}\) in time slot t, respectively.
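A minimal sketch of Definition 3 follows, counting a movement as inflow when it is outside a region at t-1 and inside it at t, and as outflow when it is inside the region at t and outside it at t+1. The trajectory representation (a dict from time slot to position) and the `region_of` helper are hypothetical choices for illustration, not part of the original method.

```python
from collections import Counter


def region_flows(movements, region_of, t):
    """Count inflow and outflow per region for time slot t.

    movements: dict mapping a movement id to {time_slot: position}
               (hypothetical representation of the set M).
    region_of: hypothetical helper mapping a position to a region id.
    """
    inflow, outflow = Counter(), Counter()
    for track in movements.values():
        prev, cur, nxt = track.get(t - 1), track.get(t), track.get(t + 1)
        if cur is None:
            continue  # object not observed in slot t
        r = region_of(cur)
        # Inflow: outside r at t-1, inside r at t.
        if prev is not None and region_of(prev) != r:
            inflow[r] += 1
        # Outflow: inside r at t, outside r at t+1.
        if nxt is not None and region_of(nxt) != r:
            outflow[r] += 1
    return inflow, outflow


# Positions are already grid cells here, so region_of is the identity.
movements = {
    "a": {0: (0, 0), 1: (0, 1), 2: (0, 1)},  # enters (0, 1) at t = 1
    "b": {0: (0, 1), 1: (0, 1), 2: (1, 1)},  # leaves (0, 1) after t = 1
}
inflow, outflow = region_flows(movements, lambda p: p, 1)
```

Regions that receive no traffic simply return 0 from the `Counter`, which matches a set of cardinality zero.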

3.1 Problem formulation

Problem 1 (Short-term traffic flow prediction)

Given citywide historical traffic flow data represented by a series of matrices \(\left \{X_{\mathcal {R},t} \mid t=1,2, \ldots , n\right \}\) over the regions \(\mathcal {R}\), the traffic forecasting problem is to predict the traffic flow of all regions in the next time interval n + 1, denoted as \(X_{\mathcal {R},n+1}\).

4 Methodology

4.1 Data processing

Prior work [54] has demonstrated that the modifiable areal unit problem [33] in the aggregation process can perturb the network inputs, which can eventually lead to inaccurate traffic flow forecasts and affect traffic planning decisions and other applications. In this work, we further explore the effects of partition scale and shape on the prediction accuracy of deep learning models. We use the same three datasets (TaxiBJ, BikeNYC, and TaxiSZ) for the experiments. The maps of the study areas (Beijing, New York and Shenzhen) must first be processed into CNN inputs.

The map of Beijing is divided into 32 × 32 grids based on longitude and latitude. The data of the last four weeks are kept as the test set, and the rest are used for training. For BikeNYC, the city is divided into an 8 × 16 grid map. The data of the last ten days are used for testing, and the rest for training.
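The grid partition and the chronological train/test split can be sketched as follows. This is a minimal illustration assuming a rectangular study area; the bounding box used below is hypothetical, since the text does not give the exact coordinates of the Beijing or New York grids.

```python
def to_grid_cell(lon, lat, bbox, rows, cols):
    """Map a (lon, lat) point to a (row, col) cell of a rows x cols grid.

    bbox = (min_lon, min_lat, max_lon, max_lat) is an assumed study-area
    bounding box; points outside it return None.
    """
    min_lon, min_lat, max_lon, max_lat = bbox
    if not (min_lon <= lon <= max_lon and min_lat <= lat <= max_lat):
        return None
    # Clamp so points lying exactly on the max edge fall into the last cell.
    row = min(int((lat - min_lat) / (max_lat - min_lat) * rows), rows - 1)
    col = min(int((lon - min_lon) / (max_lon - min_lon) * cols), cols - 1)
    return row, col


def chronological_split(frames, n_test):
    """Keep the last n_test time slots for testing and the rest for training."""
    cut = len(frames) - n_test
    return frames[:cut], frames[cut:]


# Hypothetical bounding box for a 32 x 32 grid.
BBOX = (116.0, 39.6, 116.8, 40.2)
cell = to_grid_cell(116.41, 39.91, BBOX, 32, 32)
train, test = chronological_split(list(range(100)), 10)
```

With the slot lengths given in Section 5.2, the last four weeks of TaxiBJ correspond to n_test = 4 × 7 × 48 = 1344 slots, and the last ten days of BikeNYC to n_test = 10 × 24 = 240 slots. Splitting by time rather than at random avoids leaking future observations into training.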

To explore the impact of different partition shapes, we process the TaxiSZ dataset with two types of partitioning: TAZs and grids. TAZs are special zones usually designated by the department of transportation for tabulating traffic-related census data. A TAZ is a geographic grouping of census units that occupies a contiguous region, generally with a minimum population of 600. In addition, the borders of a TAZ usually follow recognizable physical boundaries, such as main streets and bodies of water, and the land-use activities and population within each TAZ are relatively homogeneous. TAZ partitioning can thus better serve traditional transportation planning and demand analysis. However, the size of the regions is critical. Regions that are too small make entries into and departures from neighboring regions more frequent across prediction times, which adds considerable computational complexity. Regions that are too large greatly reduce the computational burden but are of little use for traffic prediction. For the grid-based partition, we divide Shenzhen into \(\{r_{k}\}_{k=1}^{n}\) with n = 1250 at scale 25 × 50 and n = 5000 at scale 50 × 100. As before, each day is split into 48 time slots of 30 min each. For the TAZ-based partition, we use the 491 TAZs provided by the Shenzhen Transportation Department, which divide the city into n = 491 irregular regions. These TAZ partitions must be rasterized to fit the input format of our model.

Definition 4 (Rasterization)

We overlay a grid map \(\mathcal {G}\) of size i × j on the TAZs. Each grid cell \(g \in \mathcal {G}\) can intersect an arbitrary number of TAZs. We calculate the inflow/outflow of each grid cell g at time slot t as:

$$ x_{g,t} = \sum\limits_{k=1}^{n} x_{r_{k},t} \times \frac{S(r_{k} \cap g)}{S(r_{k})}, $$
(3)

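where S(⋅) denotes the area of a region and \(r_{k} \cap g\) is the intersection of \(r_{k}\) and g. In other words, the flow of each TAZ is distributed over the grid cells in proportion to their overlapping area.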
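Eq. (3) can be sketched as below. To keep the example self-contained, TAZs and grid cells are approximated by axis-aligned rectangles (x0, y0, x1, y1); this is a simplifying assumption, since real TAZs are irregular polygons whose intersection areas would require a polygon geometry library.

```python
def rect_area(r):
    """Area of an axis-aligned rectangle (x0, y0, x1, y1); 0 if empty."""
    x0, y0, x1, y1 = r
    return max(0.0, x1 - x0) * max(0.0, y1 - y0)


def intersect(a, b):
    """Intersection rectangle of a and b (may be empty, i.e. zero area)."""
    return (max(a[0], b[0]), max(a[1], b[1]), min(a[2], b[2]), min(a[3], b[3]))


def rasterize(taz_rects, taz_flows, grid_rects):
    """x_g = sum_k x_{r_k} * S(r_k ∩ g) / S(r_k), for every grid cell g."""
    out = []
    for g in grid_rects:
        total = 0.0
        for r, x in zip(taz_rects, taz_flows):
            total += x * rect_area(intersect(r, g)) / rect_area(r)
        out.append(total)
    return out


# Two non-overlapping (hypothetical) TAZs covering a 2 x 1 area, two grid cells.
tazs = [(0.0, 0.0, 1.5, 1.0), (1.5, 0.0, 2.0, 1.0)]
flows = [6.0, 4.0]
cells = [(0.0, 0.0, 1.0, 1.0), (1.0, 0.0, 2.0, 1.0)]
grid_flows = rasterize(tazs, flows, cells)
```

Because each TAZ's flow is split by area fractions that sum to 1 when the grid covers the TAZ completely, the total flow is preserved (here 6 + 4 = 4 + 6).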

4.2 Spatial-temporal attention based convolutional networks

To model spatio-temporal dependency in traffic prediction, we design an end-to-end deep learning model, STAtt-Net. Figure 2 illustrates the overall framework, which consists of three modules: the temporal dependency module, the spatial attention block (ST-Block), and the fusion and activation module.

Fig. 2

Network architecture of STAtt-Net, which mainly consists of three modules: (i) a temporal dependency module including weekly trend, daily periodicity, and hourly closeness components to learn periodic patterns; (ii) an ST-Block module that uses the attention mechanism (b) to learn global spatio-temporal dependence; and (iii) a fusion and activation module to fuse the temporal components and activate the final prediction

4.2.1 Temporal dependency module

As can be observed from the example data in Fig. 1, and as pointed out by previous works [39, 50], daily activities follow temporal periodicity at both the daily and the weekly granularity. In particular, a morning peak occurs at around 9:00 on every weekday, whereas on weekends the peak is delayed by about half an hour. The traffic flow in any given region usually varies continuously, which means that the traffic flow at the current moment is strongly correlated with that at the next moment. Therefore, closeness, that is, recent temporal dependency, should be considered in traffic prediction.

In addition, a travel peak is usually reached on Friday night. Based on these observations, we consider hourly, daily and weekly temporal dependencies, and denote the sequence lengths of the three components by Δh, Δd and Δw, respectively. The input data are thus constructed as:

$$ X_{h} = \left[X_{\mathcal{G}, t-{\Delta} h}, X_{\mathcal{G}, t-({\Delta} h-1)}, \ldots, X_{\mathcal{G}, t-1}\right], $$
(4)
$$ X_{d} = \left[X_{\mathcal{G}, t-{\Delta} d \cdot l_{d}}, X_{\mathcal{G}, t-({\Delta} d-1)\cdot l_{d}}, \ldots, X_{\mathcal{G}, t-l_{d}}\right], $$
(5)
$$ X_{w} = \left[X_{\mathcal{G}, t-{\Delta} w \cdot l_{w}}, X_{\mathcal{G}, t-({\Delta} w-1) \cdot l_{w}}, \ldots, X_{\mathcal{G}, t-l_{w}}\right], $$
(6)

where \(l_{d}\) and \(l_{w}\) denote the number of time slots in a day and in a week, respectively (e.g., \(l_{d} = 48\) and \(l_{w} = 336\) for 30-min slots).
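The time-slot indices selected by Eqs. (4) to (6) can be sketched as follows. This is a minimal illustration, assuming slots are numbered consecutively from 0 and that t is large enough (t ≥ Δw · l_w) for every index to exist.

```python
def component_indices(t, delta_h, delta_d, delta_w, l_d, l_w):
    """Slot indices feeding the closeness, period and trend components
    when predicting slot t (Eqs. (4)-(6)), ordered oldest first."""
    closeness = [t - k for k in range(delta_h, 0, -1)]      # t-Δh, ..., t-1
    period = [t - k * l_d for k in range(delta_d, 0, -1)]   # t-Δd·l_d, ..., t-l_d
    trend = [t - k * l_w for k in range(delta_w, 0, -1)]    # t-Δw·l_w, ..., t-l_w
    return closeness, period, trend


# Settings from Section 5.1 (Δh = 3, Δd = 1, Δw = 1) with 30-min slots.
idx = component_indices(1000, 3, 1, 1, 48, 48 * 7)
```

With these settings, predicting slot 1000 uses slots 997 to 999 for closeness, slot 952 (same time yesterday) for period and slot 664 (same time last week) for trend.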

4.2.2 Spatial attention block

We notice that important features are often concentrated in certain regions, so a spatial attention mechanism can be introduced to focus on different regions of the feature map, telling the network where the regions of interest are located. Moreover, traffic flows are not only spatially correlated but also exhibit complex local spatial heterogeneity, which further motivates a spatial attention mechanism for traffic prediction. Specifically, we design a spatio-temporal unit module based on the attention mechanism, which captures the rich spatio-temporal relationships between regions across the whole city to obtain more informative spatial dependence.

Let \(X_{\mathcal {G}}^{(l)} \in \mathbb {R}^{c \times i\times j}\) be the feature map extracted by the l-th ST-Block layer, where c is the number of channels and i × j is the spatial size of the feature map. Figure 2(b) illustrates the spatial attention model. CNNs are well suited to processing data with a Euclidean structure and can model spatial correlation effectively. We first improve the nonlinear representation capability of the model with a convolution operation, which can be regarded as a weighted sum over neighboring samples:

$$ X_{\mathcal{G},conv}^{(l+1)} = f_{c}(W_{\mathcal{G}}^{(l)} \ast X_{\mathcal{G}}^{(l)}), $$
(7)

where ∗ denotes the convolution between a filter and the input feature maps, \(W_{\mathcal {G}}^{(l)}\) is a learnable filter in the l-th convolution layer, and \(f_{c}(\cdot)\) is the rectified linear unit (ReLU) activation function, \(f_{c}(z)=\max(0, z)\).
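For a single input and output channel, Eq. (7) with the 3 × 3 kernels of Section 5.1 can be sketched as below. The zero ("same") padding is an assumption, as the text does not state the padding mode; like deep learning frameworks, the sketch computes cross-correlation, which is what such frameworks call convolution.

```python
def conv3x3_relu(x, w, b=0.0):
    """Single-channel 3x3 convolution with zero padding, then ReLU.

    x: i x j feature map as nested lists; w: 3x3 kernel; b: bias.
    Each output value is a weighted sum of the 3x3 neighbourhood.
    """
    rows, cols = len(x), len(x[0])
    out = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            s = b
            for di in range(-1, 2):
                for dj in range(-1, 2):
                    ii, jj = i + di, j + dj
                    if 0 <= ii < rows and 0 <= jj < cols:  # zero padding
                        s += w[di + 1][dj + 1] * x[ii][jj]
            out[i][j] = max(0.0, s)  # f_c(z) = max(0, z)
    return out


identity = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
clipped = conv3x3_relu([[1, -2], [3, 4]], identity)
```

An identity kernel leaves the map unchanged apart from ReLU clipping negative values to zero; a real layer sums such products over all input channels.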

Moreover, the global and local density distributions of traffic exhibit certain regularities due to the constant movement of vehicle flows. To encode these two types of observations, we design a spatial attention model that can model a wide range of contextual information and capture changes in the density distribution of crowd flows. Figure 2 illustrates the structure of the spatial attention mechanism. The feature map \(X_{\mathcal {G},conv}^{(l+1)} \in \mathbb {R}^{c \times i\times j}\) output by the preceding convolution is fed into three separate 1 × 1 convolutions to generate three feature maps \(\mathcal {F}_{1}\), \(\mathcal {F}_{2}\) and \(\mathcal {F}_{3}\), which are reshaped into \(\mathbb {R}^{c \times n}\), where n = i × j. \(\mathcal {F}_{1}\) is additionally transposed into \(\mathbb {R}^{n \times c}\). Next, we apply matrix multiplication and a softmax operation to \(\mathcal {F}_{1}\) and \(\mathcal {F}_{2}\) to obtain the spatial attention map \(\mathcal {W} \in \mathbb {R}^{n \times n}\), defined as follows:

$$ \mathcal{W}_{j,i} = \frac{\exp \left( \mathcal{F}_{1}^{i} \cdot \mathcal{F}_{2}^{j}\right)}{{\sum}_{k=1}^{n} \exp \left( \mathcal{F}_{1}^{k} \cdot \mathcal{F}_{2}^{j}\right)}, $$
(8)

where \(\mathcal {W}_{j,i}\) measures the effect of position i on position j; a larger value indicates a higher similarity between positions i and j.

After obtaining the spatial attention matrix, we perform a matrix multiplication between \(\mathcal {W}\) and \(\mathcal {F}_{3}\) and reshape the result into \(\mathbb {R}^{c \times i\times j}\). The final output of the spatial attention block is defined as:

$$ X_{\mathcal{G},j}^{(l+1)} = \lambda \sum\limits_{i=1}^{n}\left( \mathcal{W}_{j,i} \cdot \mathcal{F}_{3}^{i}\right)+X_{\mathcal{G},conv}^{(l+1),j} , $$
(9)

where λ is a learnable parameter. Hence, the output at each position j is a weighted sum of the features at all positions plus the original feature at j, combining global context, selectively aggregated according to the spatial attention map, with the local features.
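Eqs. (8) and (9) can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions: the feature map is already flattened to c × n, each 1 × 1 convolution is modeled as a c × c channel-mixing matrix applied at every position (biases omitted), and F1 and F2 keep all c channels as the text describes, although some implementations reduce them.

```python
import math


def matmul(a, b):
    """Plain nested-list matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def spatial_attention(x, w1, w2, w3, lam):
    """x: c x n conv output (n = i*j positions); w1, w2, w3: c x c weights
    standing in for the three 1x1 convolutions. Returns (output, W)."""
    f1, f2, f3 = matmul(w1, x), matmul(w2, x), matmul(w3, x)
    c, n = len(x), len(x[0])
    attn = []
    for j in range(n):
        # Eq. (8): scores of every position i for target j, F1^i . F2^j
        scores = [sum(f1[ch][i] * f2[ch][j] for ch in range(c)) for i in range(n)]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        attn.append([e / z for e in exps])  # softmax over i; row j sums to 1
    # Eq. (9): out[:, j] = lam * sum_i W[j][i] * F3[:, i] + x[:, j]
    out = [[lam * sum(attn[j][i] * f3[ch][i] for i in range(n)) + x[ch][j]
            for j in range(n)] for ch in range(c)]
    return out, attn


I2 = [[1.0, 0.0], [0.0, 1.0]]
x = [[1.0, 2.0, 0.5], [0.0, -1.0, 3.0]]  # c = 2 channels, n = 3 positions
out0, attn0 = spatial_attention(x, I2, I2, I2, 0.0)
```

With λ = 0 the block reduces to the residual path, and for a spatially constant input the attention is uniform, so every position receives the mean feature. The n × n attention map also makes the cost grow quadratically with the number of grid cells.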

4.2.3 Fusion and activation module

In STAtt-Net, the last layer is a fusion layer that fuses the three temporal components, closeness, period (daily) and trend (weekly), to model spatio-temporal correlations:

$$ X_{\mathcal{G}}^{F u s i o n}=W_{h} \circ X_{\mathcal{G}, h}^{\prime}+W_{d} \circ X_{\mathcal{G}, d}^{\prime}+W_{w} \circ X_{\mathcal{G}, w}^{\prime}, $$
(10)

where \(W_{h}\), \(W_{d}\) and \(W_{w}\) are learnable parameter matrices, and \(X_{\mathcal {G}, h}^{\prime }\), \(X_{\mathcal {G}, d}^{\prime }\) and \(X_{\mathcal {G}, w}^{\prime }\) are the outputs of the three components computed from historical data. ∘ denotes the Hadamard (element-wise) product, and \(X_{\mathcal {G}}^{F u s i o n}\) is the output of the fusion layer. After fusing the three components, we apply an activation function. Denoting the prediction for the next time interval by \(Y_{\mathcal {G}, t_{n+1}}\), the final output of STAtt-Net is:

$$ Y_{\mathcal{G}, t_{n+1}}=\tanh \left( X_{\mathcal{G}}^{F u s i o n}\right), $$
(11)
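Eqs. (10) and (11) amount to an element-wise weighted sum followed by tanh. A minimal single-channel sketch on nested lists follows; the weight values in the example are arbitrary placeholders, whereas in the model they are learned.

```python
import math


def fuse_and_activate(xh, xd, xw, wh, wd, ww):
    """Y = tanh(Wh ∘ Xh + Wd ∘ Xd + Ww ∘ Xw) on i x j grids,
    where ∘ is the element-wise (Hadamard) product."""
    return [[math.tanh(wh[r][c] * xh[r][c] + wd[r][c] * xd[r][c] + ww[r][c] * xw[r][c])
             for c in range(len(xh[0]))]
            for r in range(len(xh))]


ones = [[1.0, 1.0]]
y = fuse_and_activate([[0.1, -0.5]], [[0.2, 0.0]], [[0.3, 0.0]], ones, ones, ones)
```

Because tanh maps into (-1, 1), the output range matches the [-1, 1] min-max scaling of the targets described in Section 5.1, which is a natural reason to pair this activation with that normalization.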

4.3 Training

We predict inflow and outflow simultaneously. The model is trained end-to-end via back-propagation by minimizing the mean squared error (MSE) between the predicted traffic flow \(Y_{\mathcal {G}, t_{n+1}}\) and the ground truth \(X_{\mathcal {G}, t_{n+1}}\). The loss function is defined as:

$$ \mathcal{L}(\theta)=\left\|Y_{\mathcal{G}, t_{n+1}}-X_{\mathcal{G}, t_{n+1}}\right\|_{2}^{2}, $$
(12)

where 𝜃 denotes all learnable parameters of the model.

5 Experiments

5.1 Experimental setting

All experiments were conducted on Ubuntu 16.04 (64-bit) with an AMD Ryzen 7 2700 8-core processor (16 threads @ 3.60 GHz) and an NVIDIA GeForce RTX 2080 Ti GPU. STAtt-Net is implemented in the open-source framework Keras with the TensorFlow backend. During training, the model was optimized with the Adam optimizer at a learning rate of 0.0002, and the batch size was set to 64. The datasets were scaled into the range [-1, 1] using min-max normalization; the predicted values were denormalized before being compared with the true values in the evaluation phase. To obtain optimal model parameters and prevent overfitting, we applied early stopping to control the number of training epochs. All convolution kernels were set to 3 × 3. The lengths of the three temporal components were set to Δw = 1, Δd = 1, and Δh = 3.
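The min-max scaling to [-1, 1] and the inverse transform applied before evaluation can be sketched as below. The text does not say which data the minimum and maximum are computed from; this sketch assumes they come from the training split and that the values are not all equal.

```python
def fit_minmax(values):
    """Record the min and max of the (assumed) training data."""
    return min(values), max(values)


def scale(x, lo, hi):
    """Map [lo, hi] linearly onto [-1, 1]."""
    return 2.0 * (x - lo) / (hi - lo) - 1.0


def unscale(y, lo, hi):
    """Invert scale(): map a prediction in [-1, 1] back to flow counts."""
    return (y + 1.0) / 2.0 * (hi - lo) + lo


lo, hi = fit_minmax([0, 50, 200])
scaled = [scale(v, lo, hi) for v in (0, 100, 200)]
```

Denormalizing with the same (lo, hi) pair is what makes RMSE and MAE comparable with the raw flow counts reported in the results.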

5.2 Datasets

As mentioned before, we use three real-world datasets to assess the performance of our model: TaxiBJ, BikeNYC, and TaxiSZ. The first two are publicly available and commonly used as benchmarks in CNN-based traffic prediction works [57], while TaxiSZ comes from our cooperation with a local transportation agency; we chose it to investigate the generalization ability of our model. The statistics of the datasets are summarized in Table 2, and the details are as follows:

  • TaxiBJ. This dataset contains 528 days of taxi GPS data in Beijing, collected over four different time periods. After corrupted data are discarded, the dataset is divided into 22,459 time slots of 30 min each.

  • BikeNYC. This dataset contains bicycle trajectories from the New York bike-sharing system from April 1, 2014 to September 30, 2014. Each record includes the trip duration, start and end times, trip date, start and end station names, station IDs, station longitudes and latitudes, bicycle ID, etc. There are 183 days of records in total. Trips shorter than 60 s are excluded, and the data are divided into 60-min slots, resulting in 4392 slots in total.

  • TaxiSZ. This dataset records taxi transactions in Shenzhen made by more than 20k taxis from January 1, 2019 to June 30, 2019. There are approximately 800k transaction records per day, or over 145 million in total. Each taxi transaction record contains the following attributes: taxi ID, price, operating mileage, get-on position (denoted as mp0) and time (mt0), and get-off position (mp1) and time (mt1). The raw data contain many corrupted or incomplete records, such as positions outside Shenzhen or missing get-on/get-off times. After data cleaning, 128 million valid transaction records were retained.
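The TaxiSZ cleaning step and the assignment of records to 30-min slots can be sketched as follows. The bounding box standing in for the Shenzhen boundary is hypothetical (the text does not give it), records are assumed to be dicts keyed by the attribute names above, and the check that get-off time is not earlier than get-on time is an extra assumption not stated in the text.

```python
from datetime import datetime

# Hypothetical (min_lon, min_lat, max_lon, max_lat) box for the study area.
SZ_BBOX = (113.7, 22.4, 114.7, 22.9)
SLOT_MINUTES = 30  # 48 slots per day


def in_bbox(pos, bbox=SZ_BBOX):
    lon, lat = pos
    return bbox[0] <= lon <= bbox[2] and bbox[1] <= lat <= bbox[3]


def is_valid(rec):
    """Reject records with missing times/positions or positions outside the area."""
    if any(rec.get(k) is None for k in ("mp0", "mt0", "mp1", "mt1")):
        return False
    # Assumed sanity check: a trip cannot end before it starts.
    return in_bbox(rec["mp0"]) and in_bbox(rec["mp1"]) and rec["mt0"] <= rec["mt1"]


def slot_of(ts, start):
    """Index of the 30-min slot containing timestamp ts, counted from start."""
    return int((ts - start).total_seconds() // (SLOT_MINUTES * 60))


START = datetime(2019, 1, 1)
good = {"mp0": (114.05, 22.54), "mt0": datetime(2019, 1, 1, 8, 0),
        "mp1": (114.10, 22.55), "mt1": datetime(2019, 1, 1, 8, 20)}
```

For flow counting, the get-on slot `slot_of(rec["mt0"], START)` and the get-off slot give the two observed positions of each movement.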

Table 2 Statistics of the datasets used in the experiments

5.3 Baselines

To assess the performance of our model, we compare STAtt-Net with the following baselines:

  • HA: The historical average method predicts traffic flow as the average of the historical traffic flows in the corresponding (relatively identical) time intervals over a given range.

  • ARIMA: The autoregressive integrated moving average model is a classic time series forecasting model and one of the earliest used for traffic flow prediction. ARIMA treats the data as a random time series, transforms non-stationary data into a stationary series through differencing, and fits the series with a parametric model.

  • ST-ResNet: This residual-network-based model, proposed in [57], fits traffic flow data by capturing its temporal correlations and combining them with external information (date attributes, weather data, etc.).

  • ST-3DNet: ST-3DNet uses a specially designed 3D CNN structure to jointly learn the temporal and spatial features of traffic flow data.

  • T-GCN [55]: T-GCN combines a graph convolutional network and gated recurrent units to capture complex spatial and temporal dependencies in traffic speed prediction.

  • DeFlow-Net [53]: DeFlow-Net is a deep residual network based on deformable convolutions and one of the most advanced convolution-based deep traffic flow prediction models.
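The HA baseline can be sketched for a single region as follows. The period length and the number of past periods averaged are illustrative parameters; the text only says that values at relatively identical time intervals over a given range are averaged.

```python
def historical_average(series, t, period, n_periods=None):
    """Predict series[t] as the mean of series[t - period], series[t - 2*period], ...

    period: slots per day (or per week); n_periods: how many past periods
    to average (all available ones when None). Both are assumptions.
    """
    past = []
    k = 1
    while t - k * period >= 0 and (n_periods is None or k <= n_periods):
        past.append(series[t - k * period])
        k += 1
    if not past:
        raise ValueError("no history at this relative time slot")
    return sum(past) / len(past)


# Toy series with a period of 2 slots.
series = [10, 0, 20, 0, 30, 0]
```

For the 30-min datasets one would use period = 48 (same time of day) or 336 (same time of week).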

5.4 Evaluation metrics

We use the root mean square error (RMSE) and mean absolute error (MAE) to evaluate the performance of the proposed network. They are defined as follows:

$$ R M S E =\sqrt{\frac{1}{N} \sum\limits_{g=1}^{N}\left( {x}_{g,t_{n+1}}-y_{g,t_{n+1}}\right)^{2}}, $$
(13)
$$ M A E=\frac{1}{N} \sum\limits_{g=1}^{N}\mid {x}_{g,t_{n+1}}-y_{g,t_{n+1}}\mid, $$
(14)

where \({x}_{g,t_{n+1}}\) and \(y_{g,t_{n+1}}\) are the real and predicted values for grid g at time \(t_{n+1}\), respectively, and N is the number of samples to be predicted. RMSE and MAE are common metrics in traffic forecasting. However, as pointed out in [53], RMSE is unit dependent, making it unsuitable for comparison across datasets. To address this problem, we further adopt the mean absolute scaled error (MASE), which can be expressed as:

$$ M A S E=\frac{\frac{1}{N} \sum\limits_{g=1}^{N}\mid {x}_{g,t_{n+1}}-y_{g,t_{n+1}}\mid}{\frac{1}{T-m} \sum\limits_{t=m+1}^{T}\mid x_{g,t}-x_{g,t-m}\mid}, $$
(15)

where \(x_{g,t_{n+1}}\) and \(y_{g,t_{n+1}}\) in the numerator come from the test data, while \(x_{g,t}\) and \(x_{g,t-m}\) in the denominator come from the training data. T is the total number of time slots in the training data, and m is the seasonal period of the time series (i.e., 48 for TaxiBJ and TaxiSZ, and 24 for BikeNYC). MASE is unit independent, allowing traffic flow predictions to be compared across cities and scales. Moreover, MASE can handle actual values of zero and is not biased by extreme values, both of which are problematic for the mean absolute percentage error (MAPE) [14]. In general, a MASE below 1 indicates that a model is better than the seasonal naive forecast, which predicts each value by the value one season earlier (\(x_{g,t-m}\)), and a lower MASE indicates a better model.
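The three metrics can be sketched for flat lists of values as below. The MASE scaling term is computed here from a single training series; how it is pooled across grid cells is not specified in the text, so that part is an assumption (one could, for example, concatenate or average the per-cell terms).

```python
import math


def rmse(actual, pred):
    return math.sqrt(sum((a - p) ** 2 for a, p in zip(actual, pred)) / len(actual))


def mae(actual, pred):
    return sum(abs(a - p) for a, p in zip(actual, pred)) / len(actual)


def mase(actual, pred, train, m):
    """Test MAE divided by the in-sample MAE of the seasonal naive forecast
    x_t ≈ x_{t-m} on the training series (requires len(train) > m)."""
    T = len(train)
    naive = sum(abs(train[t] - train[t - m]) for t in range(m, T)) / (T - m)
    return mae(actual, pred) / naive


actual, pred = [3, 5], [1, 5]
train = [1, 2, 3, 4]
```

Here the test MAE is 1, and the seasonal naive error on the training series is 1 for m = 1 and 2 for m = 2, giving MASE values of 1.0 and 0.5.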

5.5 Performance comparison

The experimental results are shown in Table 3, which compares the proposed model with the baselines described above; the best performance is marked in bold. Traditional time series methods, such as ARIMA and HA, cannot produce good forecasts because they rely only on historical records to predict future values. Machine learning methods such as SVR achieve better performance but are limited in modeling the complex temporal and spatial dependencies of traffic. Deep learning methods such as ST-ResNet and ST-3DNet also perform better, but they are still worse than STAtt-Net, which introduces an attention mechanism and multiple temporal components to model a wide range of contextual information and spatio-temporal dependencies. Graph-based methods have recently proven effective for traffic flow prediction, so we also compared against a recent GNN model for traffic prediction, T-GCN. Its results were poor, because the spatial features learned by a GCN are not well suited to grid-based traffic prediction. Our work predicts traffic flows for regions, for which CNNs are more suitable, since convolutions can better model spatial correlation when the traffic network is decomposed into grids; GNN models, in contrast, are more appropriate for graph-structured traffic data. STAtt-Net achieves RMSE values of 16.64, 5.95 and 5.41, MAE values of 9.39, 2.93 and 1.48, and MASE values of 0.281, 0.163 and 0.236 on TaxiBJ, BikeNYC and TaxiSZ, respectively. Overall, our method is more accurate than all the other methods in RMSE, MAE and MASE, except for the latest DeFlow-Net, which is slightly better than ours; however, the time cost of DeFlow-Net is about four times that of ours.

Table 3 Comparison with baseline models

5.6 Comparison of different partitioning shapes and scales

Deep learning-based traffic flow prediction is affected by the modifiable areal unit problem, which causes perturbations in the prediction results [52]. To explore the prediction performance of STAtt-Net under different partition shapes (grids vs. TAZs) and scales (50 × 25 vs. 100 × 50), we compare these settings on TaxiSZ; the results are listed in Table 4. The RMSE of TAZ-based partitioning is consistently better than that of grid-based partitioning at the same scale. A possible reason is that some grid cells at this scale are too large and contain multiple small TAZs, so the information of the smaller TAZs is lost and may even interfere with the prediction. In addition, compared with the 50 × 25 scale, the 100 × 50 scale yields improvements of at least 48.79% in RMSE and 54.42% in MAE with grid partitioning, and 56.54% in RMSE and 66.67% in MAE with TAZ partitioning. These results suggest that the finer 100 × 50 scale is preferable.

Table 4 Performance comparison of different partition shapes (grid vs. TAZ) and scales (50 × 25 vs. 100 × 50) on TaxiSZ

5.7 Impact of the number of ST-Block layers

The number of ST-Block layers can affect the prediction results. To investigate its effect on STAtt-Net, we vary the number of ST-Block layers from 1 to 5 and compare the resulting predictions. As shown in Fig. 3, the number of ST-Block layers greatly affects the results. Taking Fig. 3(a) as an example, as the number of ST-Block layers increases from 1 to 4, the RMSE and MAE decrease continuously to 5.96 and 2.93, respectively. The same trend holds on TaxiSZ, where the RMSE drops to 5.41 and the MAE to 1.48. This shows that choosing an appropriate number of layers can improve the prediction accuracy of the network.

Fig. 3

Performance comparison of different numbers of ST-Block layers, varied from 1 to 5

6 Conclusion

In this paper, we propose STAtt-Net, a spatio-temporal attention-based convolutional network for short-term traffic prediction. We developed the ST-Block module to enhance feature extraction for learning spatial heterogeneity. In addition, considering the temporal properties of traffic data, STAtt-Net models temporal dependency through i) trend for the weekly trend, ii) period for daily periodicity, and iii) closeness for recent temporal dependence. We evaluated our model on three large-scale datasets. The experimental results show that STAtt-Net outperforms most state-of-the-art approaches and achieves a good balance between accuracy and efficiency. Because every region in the ST-Block attends to all other regions to capture global contextual information, the attention module has a large computational complexity that grows quadratically with the number of regions. Nevertheless, it is still more efficient than the deformable convolution model DeFlow-Net, as revealed in the experiments. Another issue is that pure attention-based models are known to be quite ‘data-hungry’, as they usually require huge amounts of data for pre-training before they become applicable. Finally, the interpretability of deep models remains challenging, and the attention mechanism aggravates this issue. We note some recent work such as TFT [20], a multilayer pure deep learning model for time series with an LSTM encoder-decoder and a new attention mechanism that provides interpretable predictions, which gives us some ideas for the next step. We will also consider introducing other mechanisms, such as transformers, to improve the predictive capabilities of the model in the future.