
1 Introduction

Deep Convolutional Neural Networks (DCNNs) have demonstrated breakthrough results on a variety of computer vision tasks, including but not limited to image classification [1, 2] and object detection [3, 4]. However, deploying DCNNs on embedded devices remains highly difficult due to their massive storage requirements and the large number of multiply-accumulate operations they perform.

Substantial efforts have been made to solve this problem. The most common approach is to compress a fully trained network directly. [5] proposed vector quantization techniques to compress deep CNNs, replacing the weights in fully connected layers with the floating-point centers obtained from k-means clustering. HashedNets [6] reduced model size by using a hash function to group pre-trained weights into buckets and force the weights in each bucket to share the same value. However, both methods concentrated on the fully connected layers only.
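To make the weight-sharing idea behind [5, 6] concrete, the following NumPy sketch (our illustration, not the authors' code) clusters the weights of a single fully connected layer with a toy 1-D k-means and replaces every weight with its cluster center; the layer shape, the number of clusters and the function name kmeans_quantize are arbitrary choices for the example.

```python
import numpy as np

def kmeans_quantize(weights, n_clusters=16, n_iters=20):
    """Toy 1-D k-means over the flattened weights of one layer.

    Real vector-quantization schemes cluster sub-vectors, but the
    compression idea is the same: store a small codebook plus an index
    per weight instead of a 32-bit float per weight."""
    w = weights.ravel()
    # Initialize centers evenly over the weight range.
    centers = np.linspace(w.min(), w.max(), n_clusters)
    for _ in range(n_iters):
        # Assign each weight to its nearest center.
        assign = np.argmin(np.abs(w[:, None] - centers[None, :]), axis=1)
        # Move each center to the mean of its assigned weights.
        for k in range(n_clusters):
            if np.any(assign == k):
                centers[k] = w[assign == k].mean()
    return centers[assign].reshape(weights.shape), centers, assign

# Example: quantize a random 512x256 fully connected layer to a 16-entry codebook.
fc = np.random.randn(512, 256).astype(np.float32)
fc_quantized, codebook, indices = kmeans_quantize(fc)
```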

Another common approach is to use lower-precision weights, which not only reduces the size of the network but also speeds up execution. [7] showed that an 8-bit fixed-point implementation using SIMD instructions can improve inference performance, yielding a 3\( \times \) speed-up over a floating-point baseline. [8] trained deep neural networks with low-precision multipliers and high-precision accumulators. [9] introduced an approach that eliminates the need for floating-point multiplication by converting multiplication into binary shifts. Moreover, [10] eliminated the need for multiplications by forcing the weights used in the forward and backward propagations to be binary (not necessarily 0 and 1), and achieved near state-of-the-art results on the MNIST and CIFAR-10 datasets, but performed worse than the full-precision counterparts by a wide margin on the ImageNet [11] dataset. Furthermore, [12] introduced a high-performance fixed-point optimization method that allows networks with ternary {−1, 0, +1} weights and 2 or 3 bits of fixed-point signals, which greatly reduces the word length of the weights and signals when implementing networks on embedded devices. However, the performance of these networks degrades noticeably on large datasets. Later, [13] proposed ternary weight networks (TWNs) with weights quantized to {−α, 0, +α} to find a balance between a high model compression rate and high accuracy; thanks to the increased weight precision and the scaling factors, TWNs achieved better performance on large datasets than previous quantized networks. However, using the same scaling factor for positive and negative weights limits the expressive power of ternary weight networks. Recently, many new methods have been proposed to train CNNs with low-precision weights, including but not limited to BinaryNet [14], XNOR-Net [15], DoReFa-Net [16], Bitwise Neural Networks [17] and TTQ [18].

This paper makes the following contributions:

  1. We introduce Three-Means Ternary Quantization (TMTQ), a new method to quantize the weights to ternary values {\( - \alpha_{1} , 0, + \alpha_{2} \)} for each layer during forward and backward propagations (Sect. 3).

  2. We show that TMTQ performs better than the existing quantization methods and obtains near state-of-the-art results on the MNIST, CIFAR-10 and ImageNet datasets (Sect. 4).

2 Related Quantization Methods

Recently, more and more researchers have concentrated on deploying deep neural networks on embedded devices. To address the limitations in storage and computing power, they have proposed low-precision alternatives for deep learning tasks. The following are some recent studies on low-precision network quantization.

2.1 BinaryConnect

BinaryConnect [10] proposed a method to quantize full-precision weights to binary values, as shown in Eq. (1), which constrains the weights to {+1, −1} during the forward and backward propagations.

$$ W_{l}^{b} = \begin{cases} +1 & \text{if } w_{l} \ge 0, \\ -1 & \text{otherwise} \end{cases} $$
(1)

The key point of BinaryConnect is that the weights are binarized only during the forward and backward propagations, not during the parameter update, where the reserved full-precision weights are used. The real-valued weights are restricted to [−1, 1] to reduce the impact of large weights. During inference only the binary weights are needed, so a model roughly 32\( \times \) smaller can be deployed on embedded devices.
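A minimal sketch of this training pattern for one weight tensor is given below (our illustration, not the BinaryConnect code), assuming a hypothetical grad_fn callback that returns the gradient with respect to the binarized weights.

```python
import numpy as np

def binarize(w_real):
    """Eq. (1): +1 where the real-valued weight is >= 0, -1 otherwise."""
    return np.where(w_real >= 0, 1.0, -1.0)

def binaryconnect_step(w_real, grad_fn, lr=0.01):
    """One training step: binary weights in the forward/backward passes,
    the reserved real-valued weights in the parameter update."""
    w_bin = binarize(w_real)
    grad = grad_fn(w_bin)                 # gradient w.r.t. the binary weights
    w_real = w_real - lr * grad           # update the real-valued copy
    return np.clip(w_real, -1.0, 1.0)     # keep real weights in [-1, 1]
```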

2.2 Fixed-Point Feedforward Deep Neural Networks

Hwang [12] proposed a direct 3-point quantization method that constrains the weights to {−1, 0, +1}, as shown in Eq. (2).

$$ W_{l}^{t} = \begin{cases} +1 & w_{l} > +\Delta \\ 0 & |w_{l}| < \Delta \\ -1 & w_{l} < -\Delta \end{cases} $$
(2)

Here ∆ is the threshold used to quantize the continuous weights. Determining ∆ is difficult, because there is no clear relation between the parameters and the final output error caused by the quantization. Therefore, the threshold ∆ is first determined with an L2-error minimizing approach and then fine-tuned by exhaustive search to find the value that minimizes the output error.

After training, by using 2 bits to store the ternary values, they obtained an almost 16\( \times \) compression rate compared with the full-precision weights. According to their paper, the fixed-point networks show only negligible performance loss compared to their full-precision counterparts on small datasets. In addition, the "0" value ensures the sparsity of the networks, which helps prevent over-fitting.
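For illustration, a small NumPy sketch of the quantization in Eq. (2) given a threshold ∆ (the threshold search itself is not shown, since it depends on the network's output error); the function name and example shapes are ours.

```python
import numpy as np

def ternarize(w, delta):
    """Eq. (2): map full-precision weights to {-1, 0, +1} with threshold delta.

    In [12], delta is first set with an L2-error-minimizing criterion and
    then refined by exhaustive search against the output error (not shown)."""
    return np.sign(w) * (np.abs(w) > delta).astype(w.dtype)

# Storing the ternary values in 2 bits gives roughly 16x compression
# over 32-bit floating-point weights.
w = np.random.randn(64, 64).astype(np.float32)
w_ternary = ternarize(w, delta=0.5)
```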

2.3 Ternary Weight Networks

Ternary weight networks (TWNs) [13] are neural networks with weights constrained to {+α, 0, −α}. A scaling factor α is used to reduce the loss between the ternary and full-precision weights, as shown in Eq. (3).

$$ W_{l}^{t} = \begin{cases} +\alpha & w_{l} > +\Delta \\ 0 & |w_{l}| < \Delta \\ -\alpha & w_{l} < -\Delta \end{cases} $$
(3)

Again, ∆ is a threshold used to quantize the continuous weights. During training, α and ∆ are optimized by minimizing the L2 error between the full-precision and ternary weights. However, the joint optimization of α and ∆ has no straightforward solution, as in [12] (described in Sect. 2.2). To overcome this, approximate values are used, as shown in Eqs. (4) and (5).

$$ \Delta = 0.7 \cdot E\left( |w_{l}| \right) $$
(4)
$$ \alpha = \frac{1}{|I_{\Delta}|} \sum\nolimits_{i \in I_{\Delta}} |w_{i}|, \quad I_{\Delta} = \{\, i \mid |w_{i}| > \Delta \,\} $$
(5)
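The approximations in Eqs. (4) and (5) are easy to express in code. The following NumPy sketch (our illustration, not the TWN code) quantizes the weights of one layer with the rule of Eq. (3):

```python
import numpy as np

def twn_quantize(w):
    """Quantize one layer with the TWN rule of Eqs. (3)-(5): a single
    threshold delta and a single scaling factor alpha are shared by the
    positive and negative weights."""
    delta = 0.7 * np.mean(np.abs(w))                       # Eq. (4)
    mask = np.abs(w) > delta                               # I_delta
    alpha = np.abs(w[mask]).mean() if mask.any() else 0.0  # Eq. (5)
    return alpha * np.sign(w) * mask                       # Eq. (3): {+alpha, 0, -alpha}
```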

The training process of ternary weight networks is the same as that of the binary weight networks described before. With this quantization method, the authors obtained models 16\( \times \) smaller than their full-precision counterparts and, according to their paper, achieved near state-of-the-art results on several datasets.

3 Three-Means Ternary Quantization

In this section we give a detailed view of TMTQ: how to obtain ternary values from the full-precision weights, and how to train deep neural networks with these ternary weights.

3.1 Quantization Method

Our method is shown in Eq. (6). First, we set two different thresholds \( \Delta _{l}^{p} \) and \( \Delta _{l}^{n} \) for the positive and negative weights, and then quantize the full-precision weights to the ternary values {\( W_{l}^{p} \), 0, \( - W_{l}^{n} \)} according to these thresholds.

$$ W_{l}^{t} = \begin{cases} W_{l}^{p} & W_{l} > \Delta_{l}^{p} \\ 0 & -\Delta_{l}^{n} < W_{l} < \Delta_{l}^{p} \\ -W_{l}^{n} & W_{l} < -\Delta_{l}^{n} \end{cases} $$
(6)

Here we introduce four independent factors {\( \Delta _{l}^{n} , \Delta _{l}^{p} , W_{l}^{n} , W_{l}^{p} \)} to quantize the continuous full-precision weights. Using different thresholds and scaling factors for the positive and negative weights gives the networks stronger learning ability. Unlike previous works, which set the thresholds \( \Delta _{l}^{*} \) and scaling factors \( W_{l}^{*} \) empirically, we propose a novel algorithm that optimizes these four factors simultaneously from the full-precision weights, shown in Algorithm 1.

As shown in Algorithm 1, our quantization method is similar to k-means with k = 3, but with some differences. First, we do not choose the centers randomly. If the weights \( W_{l} \) are being quantized for the first time, we initialize the three centers with \( {\text{Min}}(W_{l} ) \), 0 and \( {\text{Max}}(W_{l} ) \) to accelerate clustering convergence [19]. Otherwise, because the parameter updates in each training iteration are small, reusing the centers from the previous training iteration is a good way to reduce the number of clustering iterations. Second, the centers are updated in each clustering iteration except for \( {\text{center}}[1] \), whose value is fixed at 0 to preserve the sparsity of the networks, which, like dropout, helps prevent over-fitting. Furthermore, although the quantization is described by four independent factors, TMTQ never computes these values explicitly: the ternary weights are obtained automatically through the implicit thresholds after a few clustering iterations, without setting any approximate values.

The benefits of using TMTQ are: (i) TMTQ obtains all parameters automatically from the weights, without any hand-set factors, which makes it easy to implement for arbitrary networks and datasets. (ii) The asymmetry of the ternary values {\( + W_{l}^{p} , 0, - W_{l}^{n} \)} gives the networks more model capacity.

Algorithm 1. Three-Means Ternary Quantization (TMTQ).
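Since Algorithm 1 is given as a figure, the NumPy sketch below is only our reconstruction of it from the description in Sect. 3.1: three centers initialized to {Min(\( W_{l} \)), 0, Max(\( W_{l} \))} on the first quantization (or carried over from the previous training iteration), the middle center fixed at 0, and a small number of clustering iterations.

```python
import numpy as np

def tmtq_quantize(w, prev_centers=None, n_iters=2):
    """Reconstruction of Algorithm 1: three-means clustering of one layer's
    weights with the middle center pinned to 0, giving {-W_n, 0, +W_p}."""
    flat = w.ravel().astype(np.float64)
    if prev_centers is None:
        # First quantization of this layer: initialize with {min, 0, max}.
        centers = np.array([flat.min(), 0.0, flat.max()])
    else:
        # Warm start from the previous training iteration's centers.
        centers = prev_centers.copy()

    for _ in range(n_iters):
        # Assign every weight to its nearest center (implicit thresholds).
        assign = np.argmin(np.abs(flat[:, None] - centers[None, :]), axis=1)
        # Update the outer centers; centers[1] stays fixed at 0 for sparsity.
        for k in (0, 2):
            members = flat[assign == k]
            if members.size > 0:
                centers[k] = members.mean()

    # Map every weight to its final center to obtain the ternary weights.
    assign = np.argmin(np.abs(flat[:, None] - centers[None, :]), axis=1)
    w_ternary = centers[assign].reshape(w.shape).astype(w.dtype)
    return w_ternary, centers
```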

3.2 Train Ternary Networks with TMTQ

As described before, we use the ternary weights during the forward and backward propagations and update the parameters with the reserved full-precision weights. Stochastic gradient descent (SGD) is used to train the networks. The training steps are shown in Algorithm 2.

Note that our training steps are similar to normal training, except that the ternary weights are used in the forward and backward propagations. In addition, some useful tricks are used to speed up training and improve inference accuracy. Batch Normalization (BN) [20] not only accelerates training by reducing internal covariate shift, but also reduces the impact of the weight scales. Learning rate scaling and momentum are likewise effective in optimizing network training.

Furthermore, our TMTQ method does not increase training time much, because we initialize the clustering centers with the center values obtained in the previous training iteration (Algorithm 1). In this way, 2 clustering iterations are enough to obtain good results in each training iteration.

Algorithm 2. Training a deep neural network with TMTQ.
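Algorithm 2 is likewise given as a figure; the sketch below is only our reading of the training loop, reusing tmtq_quantize from the previous sketch and a hypothetical forward_backward callback that stands in for the framework's forward and backward passes (it is not a Caffe API).

```python
import numpy as np

# tmtq_quantize is the function sketched after Algorithm 1.

def train_with_tmtq(w_real, forward_backward, n_steps, lr=0.0001, momentum=0.9):
    """Sketch of the training loop: ternary weights in the forward/backward
    passes, SGD with momentum on the reserved full-precision weights."""
    velocity = [np.zeros_like(w) for w in w_real]
    centers = [None] * len(w_real)          # clustering centers per layer

    for _ in range(n_steps):
        # Quantize every layer with TMTQ (2 clustering iterations suffice
        # because the centers are warm-started from the previous step).
        ternary = []
        for i, w in enumerate(w_real):
            w_t, centers[i] = tmtq_quantize(w, centers[i])
            ternary.append(w_t)

        # Gradients are computed with the ternary weights ...
        grads = forward_backward(ternary)

        # ... but the update is applied to the full-precision weights.
        for i, g in enumerate(grads):
            velocity[i] = momentum * velocity[i] - lr * g
            w_real[i] += velocity[i]
    return w_real
```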

3.3 Inference

In the previous sections we described how to train deep neural networks with the TMTQ method. During inference, only the ternary weights and the scaling factors are needed. By storing the weights as 2-bit values, we can reduce the model size by about 16\( \times \). Furthermore, since \( W_{l}^{p} \) and \( W_{l}^{n} \) are fixed during inference, the scaling factors can be folded into the activation function in advance, which is an effective way to speed up the forward propagation on specialized hardware, because most multiplications are replaced with addition or subtraction operations.
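To make the last point concrete, the following sketch (our illustration) computes the pre-activation of a layer with ternary weights: each input element only needs to be added to a positive or a negative accumulator, and the two scaling factors are applied once per output value.

```python
import numpy as np

def ternary_layer(x, w_t, w_p, w_n):
    """Pre-activation x @ W for a layer whose weights are in {+w_p, 0, -w_n}.

    The two matrix products below only sum input elements (the masks are
    0/1), so no multiplication is needed per input element; the scaling
    factors w_p and w_n are applied once per output value and could be
    folded into the following activation/BN step on dedicated hardware."""
    pos_mask = (w_t > 0).astype(x.dtype)   # positions holding +w_p
    neg_mask = (w_t < 0).astype(x.dtype)   # positions holding -w_n
    return w_p * (x @ pos_mask) - w_n * (x @ neg_mask)
```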

4 Experiments

In this section we compare our TMTQ method with existing quantization methods on three benchmark datasets: MNIST, CIFAR-10 and ImageNet. For a fair comparison, the same hyper-parameters are used during training, such as the network structure, learning rate, regularization method and optimization method (SGD). In addition, the MNIST and CIFAR-10 experiments are repeated 4 times and averaged, to reduce the effect of random initialization and data augmentation. We implement our experiments in the Caffe [21] framework.

4.1 MNIST

MNIST is an image classification benchmark dataset containing 60 thousand training images and 10 thousand test images. We train a LeNet-5 network on MNIST without any data augmentation or preprocessing. The LeNet-5 consists of: "32-C5 + MP2 + 64-C5 + MP2 + 512-FC + 10-SoftMax", where 32-C5 denotes a convolution layer with 32 kernels of size 5 \( \times \) 5, MP2 a 2 \( \times \) 2 max-pooling layer, FC a fully connected layer and SoftMax the output layer. We use SGD with momentum 0.9 to update the parameters. The minibatch size is set to 100. The learning rate is initialized to 0.0001 and reduced in steps. Moreover, we add a Batch Normalization layer after every convolution layer to reduce internal covariate shift.
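For readers unfamiliar with this notation, the sketch below is a PyTorch equivalent of the LeNet-5 structure above (the paper itself uses Caffe, so this is only an illustration of the notation); the ReLU activations are our assumption, and BN is placed after each convolution as stated in the text.

```python
import torch.nn as nn

# "32-C5 + MP2 + 64-C5 + MP2 + 512-FC + 10-SoftMax" for 1x28x28 MNIST inputs.
lenet5 = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=5),   # 32-C5: 28x28 -> 24x24
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.MaxPool2d(2),                   # MP2: 24x24 -> 12x12
    nn.Conv2d(32, 64, kernel_size=5),  # 64-C5: 12x12 -> 8x8
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(2),                   # MP2: 8x8 -> 4x4
    nn.Flatten(),
    nn.Linear(64 * 4 * 4, 512),        # 512-FC
    nn.ReLU(),
    nn.Linear(512, 10),                # 10-SoftMax (softmax applied in the loss)
)
```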

To make the quantized networks converge as quickly as possible, we first train a full-precision model on MNIST as a baseline, and then fine-tune this baseline with the binary and ternary quantization methods. The training curves are shown in Fig. 1. The results (Table 1) show that our ternary model obtained with TMTQ outperforms the BinaryConnect model and the TWNs model by 0.31% and 0.05% respectively, with only 0.02% accuracy degradation relative to the full-precision model.

Fig. 1. Test accuracy of LeNet-5 on MNIST with different quantization methods.

Table 1. Accuracy rates on MNIST, CIFAR-10 and ImageNet.

4.2 CIFAR-10

CIFAR-10 is an image classification benchmark dataset containing 50 thousand 32 \( \times \) 32 RGB training images and 10 thousand test images. We train a VGG13 network, inspired by VGG16 [22], on CIFAR-10 with some data augmentation: we pad 2 pixels on each side of the images and randomly crop 32 \( \times \) 32 patches from the padded images during training. During inference, the original 32 \( \times \) 32 images are used to test the networks. Our VGG13 network is denoted as: "(2 \( \times \) 128-C3) + MP2 + (2 \( \times \) 256-C3) + MP2 + (2 \( \times \) 512-C3) + MP2 + (2 \( \times \) 512-C3) + MP2 + (2 \( \times \) 512-C3) + MP2 + (2 \( \times \) 1024-FC) + 10-SoftMax". These layers have the same meaning as described in Sect. 4.1. The parameters are updated by SGD with momentum 0.9, and the learning rate is initialized to 0.0001. The minibatch size is set to 100. Furthermore, Batch Normalization (BN) is used after the convolution layers to speed up training.

As before, we first use a fully trained VGG13 model as a baseline, and then fine-tune the baseline with the binary and ternary quantization methods. The training curves are shown in Fig. 2. The results (Table 1) show that our ternary model obtained with TMTQ outperforms the BinaryConnect model and the TWNs model by 1.36% and 0.44% respectively, with only 0.17% accuracy degradation relative to the full-precision model.

Fig. 2. Test accuracy of VGG13 on CIFAR-10 with different quantization methods.

4.3 ImageNet

ImageNet is an image classification dataset with over 1.28 million training images and 50 thousand validation images. We use the AlexNet structure in our experiment, keeping full-precision weights for the first convolution layer and the last fully connected layer. During training, images are resized to 256 \( \times \) 256 and randomly cropped to 227 \( \times \) 227 before being fed to the network. SGD with momentum 0.9 is used to update the parameters. The minibatch size is set to 256. The learning rate is initialized to 0.0001 and multiplied by 0.1 at iteration 200,000.

We download a fully trained AlexNet model from the Caffe model zoo as a baseline and then fine-tune this baseline with the TMTQ and TWNs quantization methods. The training curves of top-1 accuracy on the validation set are shown in Fig. 3. The results (Table 1) show that our TMTQ model outperforms the TWNs model by 2.42%, with only 0.97% accuracy degradation relative to the full-precision counterpart.

Fig. 3. Validation accuracy of AlexNet on ImageNet with different quantization methods.

5 Conclusion

We propose TMTQ, a novel method that quantizes continuous weights to ternary values during the forward and backward propagations. With TMTQ, we do not need to set any thresholds \( \Delta _{l}^{*} \) in advance or compute the scaling factors \( W_{l}^{*} \) approximately: all factors are obtained automatically by learning the centers of the full-precision weights. Furthermore, our quantization method reduces the model size by about 16\( \times \), since only the ternary weights (2 bits each) and the scaling factors are needed during inference. The experiments above show that TMTQ performs better than the BinaryConnect and TWNs quantization methods on the CIFAR-10 and ImageNet datasets, with only slight accuracy degradation relative to the full-precision counterparts. Future work will extend these results to other models and datasets, and explore the deeper relationship between the ternary values and the network outputs.