
1 Introduction

Graph Neural Networks (GNNs), first proposed in 2009 [1], have developed rapidly in recent years thanks to their powerful ability to process data in non-Euclidean spaces, such as graph data [2,3,4,5,6]. Nowadays, GNNs are widely used in areas such as social networks [7, 8], drug discovery [9] and recommendation [10, 11]. As mentioned in [12], GNNs fall into four main categories: Recurrent Graph Neural Networks (RecGNNs) [13, 14], Convolutional Graph Neural Networks (ConvGNNs) [15,16,17,18,19,20,21,22,23], Graph Autoencoders (GAEs) [24,25,26,27] and Spatial-temporal Graph Neural Networks (STGNNs) [28,29,30]. Among them, ConvGNNs generalize the convolution operation from grid data to graph data; their typical model is the Graph Convolutional Network (GCN) [31, 32].

In order to improve the generalization performance of GCN for new nodes, the Graph SAmpling based INductive learning meThod (GraphSAINT) [33] was proposed. GraphSAINT enables effective training of deep GCNs through a special minibatch construction: it samples a set of subgraphs from the original training graph and then builds a full GCN on each subgraph. The graph sampling strategy is therefore the main contribution of GraphSAINT. This strategy also alleviates the neighbor explosion problem, so that the number of neighboring nodes no longer grows exponentially with the number of layers. Moreover, compared with GraphSAGE [34], GraphSAINT also enhances the ability to process large graphs by applying subgraph sampling. Although GraphSAINT uses the same inductive framework as GraphSAGE, the two methods differ in their sampling strategies: GraphSAINT samples multiple subgraphs from the original graph to construct minibatches for training, whereas GraphSAGE samples neighbor nodes to generate node embeddings.

Although GraphSAINT solves the neighbor explosion problem and generalizes better than GCN, the graph sampling strategy makes the network difficult to train. In general, nodes that have higher influence on each other should be selected into the same subgraph with higher probability, which ensures that the nodes can support each other within the subgraph. However, such a sampling strategy leads to different node sampling probabilities and introduces bias into the minibatch estimator. In [33], normalization techniques are developed to deal with this issue so that feature learning does not give priority to the more frequently sampled nodes. As a result, GraphSAINT effectively avoids instability and non-convergence during training and obtains a good performance improvement on the node classification task.

However, experiments show that a stability problem appears when GraphSAINT is applied to the link prediction task [35,36,37,38], which arises in many application areas [39,40,41,42,43], on the citation dataset of the standard Open Graph Benchmark (OGB) [44]. Link prediction is widely used in areas such as recommendation systems [45,46,47], biological networks [48] and knowledge graph completion [49]. Different from node classification, the main task of link prediction is to judge whether two nodes in a network are likely to have a link. The stability problem during training means that the normalization techniques in [33] are insufficient to improve the training quality for the link prediction task. Thus, training easily falls into non-convergence, which manifests as the training loss suddenly rising and then remaining unchanged. This problem has a large impact on GNN model development.

The analysis of CNN training methods offers some insight into training stability. Stochastic gradient descent and its variants, such as momentum [50] and Adagrad [51], have been widely used to train neural networks, but the training process is complicated. Moreover, as the network gets deeper, small changes to the network parameters are amplified [52]. As each layer constantly adapts to a new input distribution, the distributions of the layers' inputs shift, a problem known as covariate shift [53,54,55,56,57,58], which is harmful to convergence. Ioffe and Szegedy proposed the Batch Normalization (BN) mechanism to reduce internal covariate shift and accelerate the convergence of deep neural networks [52]. This mechanism uses the mean and variance to normalize the activations over each mini-batch, which allows a higher learning rate and makes Dropout [59] unnecessary. However, this effective BN mechanism has not been widely used in GNNs because of their typically small number of layers [12]. Nowadays, as graphs become larger and tasks become more complex, GNN models grow more complicated and the difficulty of training them increases rapidly, which leads to stability problems during training. Thus, applying the BN mechanism is of great significance for robust training on large graphs.

Therefore, in order to solve the stability problem in training for the link prediction task, this paper proposes an improved GraphSAINT that adapts the BN strategy to the special training and inference process of GraphSAINT on the OGB. By applying this normalization strategy during training, we successfully eliminate the training instability. Moreover, we reduce the training time and even improve the link prediction accuracy compared with the original model. The effectiveness of our method is validated on the citation dataset of the OGB.

Inspired by [60], the paper is organized as follows: Sect. 2 reviews related work on GraphSAINT, especially its sampling strategy, and the typical batch normalization technique; Sect. 3 describes the improved GraphSAINT and its training method; Sect. 4 presents comparative experimental results on the citation dataset; finally, conclusions are given in Sect. 5.

2 Related Work

2.1 The Sampling Strategy of GraphSAINT

GCN aggregates one-hop neighbors' information by using the adjacency matrix [31, 32]. However, because of this reliance on the adjacency matrix, when a new node is added to the graph, we must update the adjacency matrix and re-train the GCN model on the modified graph to obtain new embeddings for all nodes. Therefore, as mentioned in Sect. 1, GCN lacks generalization performance for unseen nodes and incurs a high time cost. To overcome this shortcoming, GraphSAGE was proposed [34]. In this method, the neighbors of a target node are sampled and an aggregation function is learned to aggregate the neighbor nodes and generate the embedding vector of the target node, which avoids using the adjacency matrix and reduces the training cost by sampling the nodes of each GNN layer.
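As a concrete illustration of this one-hop aggregation, the following minimal NumPy sketch applies a single GCN layer with a symmetrically normalized adjacency matrix; the names `adj`, `feats` and `weight` are illustrative assumptions and not part of any cited implementation.

```python
import numpy as np

def gcn_layer(adj: np.ndarray, feats: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """adj: (N, N) adjacency with self-loops; feats: (N, F_in); weight: (F_in, F_out)."""
    deg = adj.sum(axis=1)                                 # node degrees
    d_inv_sqrt = np.diag(1.0 / np.sqrt(np.maximum(deg, 1)))
    adj_norm = d_inv_sqrt @ adj @ d_inv_sqrt              # symmetrically normalized adjacency
    return np.maximum(adj_norm @ feats @ weight, 0.0)     # aggregate one-hop neighbors, then ReLU
```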

Furthermore, compared with GraphSAGE, GraphSAINT makes it possible to solve learning tasks on large graphs by designing a sampler, called SAMPLE, to obtain subgraphs [33]. This subgraph sampling method also improves the generalization performance, as GraphSAGE does, and allows GraphSAINT to handle the neighbor explosion problem better.

In order to preserve the connectivity characteristics of the graph, bias in the mini-batch estimation is almost inevitably introduced by the sampler. Therefore, in [33], self-designed normalization techniques are introduced to eliminate this bias. The key step is to estimate the sampling probability of each node, edge, and subgraph.

The sampling probability distribution \(P\left( u \right)\) of the node \(u\) is

$$P\left( u \right) \propto \left\| {{\tilde{\mathbf{A}}}_{:,u} } \right\|^{2}$$
(1)

where \({\mathbf{A}}\) is the adjacency matrix and \({\tilde{\mathbf{A}}}\) is its normalized version, \({\tilde{\mathbf{A}}} = {\mathbf{D}}^{-1} {\mathbf{A}}\), with \({\mathbf{D}}\) the diagonal degree matrix.

The sampling probability distribution \(P_{u,v}^{\left( l \right)}\) of the edge \(\left( {u,v} \right)\) in the \(l^{th}\) GCN layer is

$$P_{u,v}^{\left( l \right)} \propto \frac{1}{\deg \left( u \right)} + \frac{1}{\deg \left( v \right)}$$
(2)

The sampling probability distribution \(P_{u,v}\) of the subgraph is

$$P_{u,v} \propto {\mathbf{B}}_{u,v} + {\mathbf{B}}_{v,u}$$
(3)

where \({\mathbf{B}}_{u,v}\) can be interpreted as the probability of a random walk starting at \(u\) and ending at \(v\) within \(L\) hops, \({\mathbf{B}}_{v,u}\) as the probability of a random walk starting at \(v\) and ending at \(u\) within \(L\) hops, and \({\mathbf{B}} = {\tilde{\mathbf{A}}}^{L}\), since the \(L\) GCN layers can be represented as a single layer with edge weights. Thus, the sampling probabilities of each node, edge, and subgraph are all well estimated, and the subgraphs obtained by sampling are then used for GraphSAINT training.
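For illustration, the sketch below shows one way Eqs. (1)-(3) could be computed with NumPy/SciPy on a small graph. It is a simplified reading of the formulas above rather than the authors' sampler, and the names `adj` and `num_layers` are assumptions.

```python
import numpy as np
import scipy.sparse as sp

def sampling_probabilities(adj: sp.csr_matrix, num_layers: int):
    """adj: binary adjacency matrix of the (undirected) training graph."""
    deg = np.asarray(adj.sum(axis=1)).ravel()             # node degrees
    adj_norm = sp.diags(1.0 / np.maximum(deg, 1)) @ adj   # normalized adjacency A~ = D^{-1} A

    # Eq. (1): node probability proportional to the squared norm of the node's column of A~
    p_node = np.asarray(adj_norm.power(2).sum(axis=0)).ravel()
    p_node = p_node / p_node.sum()

    # Eq. (2): edge probability proportional to 1/deg(u) + 1/deg(v)
    rows, cols = adj.nonzero()
    p_edge = 1.0 / deg[rows] + 1.0 / deg[cols]
    p_edge = p_edge / p_edge.sum()

    # Eq. (3): random-walk based probability with B = A~^L (dense power, toy graphs only)
    B = np.linalg.matrix_power(adj_norm.toarray(), num_layers)
    p_walk = B[rows, cols] + B[cols, rows]
    p_walk = p_walk / p_walk.sum()
    return p_node, p_edge, p_walk
```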

2.2 The Typical Normalization Technique

Normalization techniques are proposed to eliminate internal covariate shift, which is caused by the change in the distributions of internal nodes of a deep network during training, and to offer faster training [53, 58].

The typical normalization technique for mini-batches presented in [52] basically follows mathematical statistics: first, the mini-batch mean of each feature is calculated over the mini-batch; next, the mini-batch variance is calculated based on this mean; then, each value is normalized by subtracting the mini-batch mean and dividing by the mini-batch standard deviation; finally, scale and shift parameters are introduced and learned for each feature.
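The following minimal NumPy sketch illustrates these four steps for one mini-batch; `gamma`, `beta` and `eps` are assumed parameters, and the code is only a schematic of the procedure in [52], not any specific library implementation.

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """x: mini-batch of activations with shape (batch_size, num_features)."""
    mu = x.mean(axis=0)                        # step 1: mini-batch mean per feature
    var = x.var(axis=0)                        # step 2: mini-batch variance per feature
    x_hat = (x - mu) / np.sqrt(var + eps)      # step 3: normalize with the standard deviation
    return gamma * x_hat + beta                # step 4: learned scale and shift
```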

3 Methodology

As mentioned in Sect. 1, the original normalization techniques of GraphSAINT, which are effective for the node classification task, are insufficient to improve the training quality for the link prediction task. Therefore, an improved GraphSAINT training algorithm is proposed.

Since the subgraphs in GraphSAINT are sampled according to the connectivity of the nodes, the sampler can obtain edge sampling probabilities with the smallest variance; for node selection, in contrast, it uses the random node sampler. Therefore, the node feature data of each sampled subgraph do not obey the standard normal distribution. Assume the graph dataset to be processed is a whole graph \(\zeta = \left( {V,\xi } \right)\) with N nodes \(v \in V\) and edges \(\left( {v_{i} ,v_{j} } \right) \in \xi\). For a node \(v_{i}\) in a subgraph \(\zeta_{s}\) sampled from \(\zeta\) by SAMPLE, its feature vector \(h_{i,s}\) has \(d\) elements. In order to normalize the distributions of the inputs and reduce the internal covariate shift, the input node feature vector is normalized by

$$\hat{h}_{i,s} = \frac{h_{i,s} - \mu_{s}}{\sqrt{\sigma_{s}^{2}}}, \qquad \mu_{s} = \frac{1}{d}\sum\limits_{i = 1}^{d} h_{i,s}, \qquad \sigma_{s}^{2} = \frac{1}{d}\sum\limits_{i = 1}^{d} \left( h_{i,s} - \mu_{s} \right)^{2}$$
(4)

where \(\mu_{s}\) and \(\sigma_{s}^{2}\) are the mean and the variance of node \(v_{i}\)'s features, computed over the feature dimension on the training data.

Therefore, by means of this node-wise normalization technique within each subgraph, every node feature vector is normalized to zero mean and unit variance. Besides, following the principle of the typical normalization technique, the training time can also be effectively shortened by applying node-wise normalization in each subgraph.
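A minimal PyTorch sketch of Eq. (4), assuming the subgraph node features are stored as a `(num_nodes, d)` tensor named `h_sub`, might look as follows; it is an illustration rather than the authors' implementation.

```python
import torch

def node_wise_normalize(h_sub: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """h_sub: node features of one sampled subgraph, shape (num_nodes, d)."""
    mu = h_sub.mean(dim=1, keepdim=True)                   # per-node mean over the d features
    var = h_sub.var(dim=1, unbiased=False, keepdim=True)   # per-node variance over the d features
    return (h_sub - mu) / torch.sqrt(var + eps)            # zero mean, unit variance per node
```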

Algorithm 1. The training process of the improved GraphSAINT.

The whole training process of the improved GraphSAINT is illustrated in Algorithm 1. Before training starts, we pre-process \(\zeta\) to convert the directed graph into an undirected graph and obtain sampled subgraphs \(\zeta_{s}\) with the given SAMPLE [33]. Then an iterative training process is conducted via SGD to update the model weights, where each iteration uses an independently sampled subgraph \(\zeta_{s}\). Next, the original GCN is modified by applying the normalization operation to the output of each convolutional layer, which is also the input of the following ReLU layer. Finally, the modified GCN on \(\zeta_{s}\) is used to generate embeddings, and the loss is calculated according to the Mean Reciprocal Rank (MRR). In MRR, a query whose correct result is ranked first scores 1, ranked second scores 1/2, and ranked n-th scores 1/n; if no correct result is retrieved, the score is 0. The final MRR is the mean of these reciprocal-rank scores over all queries.
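To make the layer modification and the MRR scoring concrete, the hedged PyTorch sketch below wraps an arbitrary graph-convolution module so that the node-wise normalization of Eq. (4) is applied to its output before the ReLU activation, together with a toy MRR helper. The call signature of `conv` and all names are assumptions, not the authors' actual implementation.

```python
import torch
import torch.nn as nn

class NormalizedGCNLayer(nn.Module):
    """Wraps a graph-convolution module so its output is normalized before ReLU."""

    def __init__(self, conv: nn.Module):
        super().__init__()
        self.conv = conv                                   # assumed to take (features, adjacency)

    def forward(self, x: torch.Tensor, adj) -> torch.Tensor:
        h = self.conv(x, adj)                              # graph convolution on the subgraph
        mu = h.mean(dim=1, keepdim=True)                   # node-wise mean, as in Eq. (4)
        var = h.var(dim=1, unbiased=False, keepdim=True)   # node-wise variance, as in Eq. (4)
        h = (h - mu) / torch.sqrt(var + 1e-5)              # normalize the convolution output
        return torch.relu(h)                               # ReLU on the normalized activations

def mean_reciprocal_rank(ranks: torch.Tensor) -> torch.Tensor:
    """ranks: 1-based rank of the true reference among the candidates, one per query."""
    return (1.0 / ranks.float()).mean()                    # average of the reciprocal ranks
```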

As mentioned in Sect. 2.1, GraphSAINT uses the subgraphs obtained by the subgraph sampling method for training, while it uses the whole graph to compute the output during inference. Therefore, the normalization operation can be applied independently during inference.

4 Experiments

In this section, in order to verify the effectiveness of the improved GraphSAINT algorithm, we choose the link prediction task on the citation dataset of OGB (ogbl-citation). The ogbl-citation dataset is a directed graph and can be viewed as a 'subgraph' of the citation network MAG [61]. In this dataset, each node represents a paper, whose title and abstract are encoded into a 128-dimensional word2vec feature vector, and each directed edge indicates a citation relationship between two papers.

The link prediction task requires predicting missing citations based on the existing citations in the graph. Two of each source paper's references are randomly dropped, and the model is required to rank these two missing references above 1,000 negative references that are randomly sampled from papers not cited by the source paper. Accordingly, MRR is chosen as the evaluation metric [44]. Besides, the two dropped edges of all source papers are used for validation and testing respectively, and the training set contains the remaining edges.
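A hedged sketch of loading the dataset split and the MRR evaluator through the OGB package is given below; the dataset name and the keys returned by the evaluator depend on the installed OGB version, so they should be treated as assumptions.

```python
from ogb.linkproppred import PygLinkPropPredDataset, Evaluator

dataset = PygLinkPropPredDataset(name='ogbl-citation')   # citation graph with 128-dim word2vec features
split_edge = dataset.get_edge_split()                    # train / valid / test edge splits
evaluator = Evaluator(name='ogbl-citation')

# Assumed usage (keys may differ between OGB versions):
# y_pred_pos: scores for the dropped (true) references, shape (num_queries,)
# y_pred_neg: scores for the 1000 sampled negatives,   shape (num_queries, 1000)
# result = evaluator.eval({'y_pred_pos': y_pred_pos, 'y_pred_neg': y_pred_neg})
# mrr = result['mrr_list'].mean()
```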

Table 1. Results for GraphSAINT on citation dataset.

The official results of the traditional GraphSAINT on the citation dataset are given in Table 1: the MRR of the training set is 0.8626 ± 0.0046, the MRR of the validation set is 0.7933 ± 0.0046, and the MRR of the test set is 0.7943 ± 0.0043. However, in our reproduction experiments with the traditional GraphSAINT, we found that the algorithm fails to converge during training with a probability of roughly 0.1 to 0.4. For a run in which the loss converges, as shown in Fig. 1(a), the MRR results are consistent with the official results in Table 1: the training result reaches about 0.8690, the validation result about 0.8031 and the test result about 0.8048. For a run in which the loss does not converge, the MRR results are also reported in Table 1, and the loss curve over the epochs is essentially as shown in Fig. 1(b).

As shown in Fig. 1, one run contains 200 epochs. In Fig. 1(b), after the 78th epoch the loss suddenly and sharply increases to 34.5388 and then remains unchanged. Therefore, measures need to be taken to solve this non-convergence problem during training.

After applying the improved GraphSAINT, the loss curve during training is shown by the solid line in Fig. 2, and it is clearly convergent. Besides, compared with the dotted line in Fig. 2, which is the training loss curve of the traditional GraphSAINT from Fig. 1(a), the improved GraphSAINT converges more stably and more quickly during training. Moreover, all three MRR values improve, as shown in Table 1: the training result reaches 0.9001 ± 0.0014, the validation result reaches 0.8335 ± 0.0020 and the test result reaches 0.8344 ± 0.0023. Thus, the effectiveness of our improved GraphSAINT is verified.

Fig. 1. The loss curve during training of the traditional GraphSAINT: (a) with convergence; (b) without convergence.

Fig. 2. The loss curves during training of the improved GraphSAINT (solid line) and the traditional GraphSAINT with convergence from Fig. 1(a) (dotted line).

5 Conclusions

Training stability is a crucial issue for graph neural networks. In this paper, we focus on this stability problem and propose an improved GraphSAINT that applies a normalization strategy. The proposed method not only improves the robustness of the GraphSAINT training process but also accelerates the convergence of the model. In future work, more attention will be paid to distributed training methods for large graph datasets.