What You Should Know About Batch Size

Recently I experimented with the batch size in my training runs, and I learned that there is more to it than the fact that there are three regimes of batch size (full batch, mini-batch, and a batch size of one). In this post, I'd like to share the fundamentals of batch size, some of its theoretical properties, and some useful tricks.

Batch size is the number of training examples used in one iteration. In gradient descent, the error between the predictions and the true values of the sampled batch is used as an estimate of the gradient that updates the model.
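
To make the update rule concrete, here is a minimal sketch of one gradient descent step on a batch, written in NumPy for a linear model with mean squared error; the names (X, y, w, minibatch_sgd_step) are illustrative choices, not something from this post.

```python
import numpy as np

def minibatch_sgd_step(w, X, y, batch_size, lr=0.01, rng=None):
    """One gradient descent update using `batch_size` random samples.

    The loss is the mean squared error of a linear model y ~ X @ w;
    the gradient computed over the sampled batch is the error estimate
    used to update the weights.
    """
    rng = rng or np.random.default_rng()
    idx = rng.choice(len(X), size=batch_size, replace=False)  # draw one batch
    Xb, yb = X[idx], y[idx]
    error = Xb @ w - yb                      # prediction minus true value
    grad = 2 * Xb.T @ error / batch_size     # batch estimate of the gradient
    return w - lr * grad                     # gradient descent update

# batch_size = len(X) is full-batch gradient descent,
# batch_size = 1 is stochastic gradient descent,
# anything in between is mini-batch gradient descent.
```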


The batch size influences the speed of convergence during training. If we use the entire training set as one batch, the optimization follows the exact gradient and converges steadily (to the global optimum when the loss is convex), but it typically needs more epochs before it gets there. On the contrary, a mini-batch (a batch size larger than one but smaller than the dataset) takes the risk of falling into a local optimum, yet it converges faster. A useful tip here is to start with a small batch size and gradually increase it: at the beginning, the model trains fast, and in the end it is less likely to be trapped in a local optimum.
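
One simple way to implement the "start small, then grow" tip is a schedule that doubles the batch size every few epochs. The helper below is only a sketch; the starting value, doubling interval, and cap are assumptions you would tune for your own setup.

```python
def batch_size_schedule(epoch, initial=32, double_every=10, maximum=1024):
    """Batch size that starts small and doubles every `double_every` epochs,
    capped at `maximum`. Small batches early on give fast, noisy progress;
    larger batches later give more stable steps near a good region."""
    return min(initial * 2 ** (epoch // double_every), maximum)

# e.g. epochs 0-9 use 32, epochs 10-19 use 64, ..., until the cap of 1024
```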

A theoretical explanation of why a smaller batch size converges faster is that smaller batches produce noisier, larger gradient estimates, so each update moves the parameters further away from their initial values. Intuitively, the full batch yields a smaller gradient estimate because the per-sample gradients point in many different directions and partially compensate for each other.
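
You can see this compensation effect in a toy experiment: estimate the gradient of a random linear regression problem from batches of different sizes and compare the average gradient norms. Everything below (the data, the model, the sizes) is made up purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1024, 20))   # 1024 samples, 20 features
y = rng.normal(size=1024)
w = np.zeros(20)                  # current model weights

def batch_grad(idx):
    """Mean-squared-error gradient estimated from the samples in `idx`."""
    err = X[idx] @ w - y[idx]
    return 2 * X[idx].T @ err / len(idx)

for bs in (8, 64, 1024):          # 1024 = full batch
    norms = [np.linalg.norm(batch_grad(rng.choice(1024, bs, replace=False)))
             for _ in range(200)]
    print(f"batch size {bs:5d}: mean gradient norm {np.mean(norms):.3f}")

# Smaller batches give noisier and, on average, larger gradient estimates;
# in the full batch the per-sample gradients partially cancel each other out.
```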


Recent work has also found that the choices of batch size and learning rate are related: a small batch size tends to work better with a small learning rate, and a large batch size with a large learning rate. Since a big batch size takes less time to pass over the entire training set once, in the ideal case where an appropriate learning rate is chosen for each batch size and the number of epochs is the same, training with a big batch size takes less wall-clock time.
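
A common heuristic that matches this observation is the linear scaling rule: scale the learning rate in proportion to the batch size. A minimal sketch, where the base learning rate and base batch size are assumed placeholder values:

```python
def scaled_learning_rate(batch_size, base_lr=0.1, base_batch_size=256):
    """Linear scaling rule: if the batch is k times larger than the
    reference batch, use a k times larger learning rate."""
    return base_lr * batch_size / base_batch_size

print(scaled_learning_rate(1024))  # 0.4 with the assumed base values
```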

In practice, we sometimes have to use small batch sizes. Some problems generate their samples online, so the full batch simply does not exist yet. In other cases, the dataset is so large that using the full batch raises an out-of-memory error.
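
When the full batch does not fit, one practical workaround is to probe for the largest batch size that does: keep halving until a single training step succeeds. The sketch below is framework-agnostic; it assumes a user-supplied run_one_batch(batch_size) function and that the framework reports OOM as a RuntimeError whose message contains "out of memory" (as PyTorch typically does), both of which you should adapt to your own setup.

```python
def find_max_batch_size(run_one_batch, start=1024, minimum=1):
    """Halve the batch size until one training step fits in memory.

    `run_one_batch(batch_size)` is assumed to build a batch of that size
    and perform a single forward/backward pass, raising a RuntimeError
    containing "out of memory" when the device cannot hold it.
    """
    batch_size = start
    while batch_size >= minimum:
        try:
            run_one_batch(batch_size)
            return batch_size
        except RuntimeError as err:
            if "out of memory" not in str(err).lower():
                raise                  # a different error: re-raise it
            batch_size //= 2           # too big: try half the size
    raise RuntimeError("even the minimum batch size does not fit in memory")
```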

Furthermore, models trained with small batch sizes tend to generalize better at inference time than models trained with the full batch. Intuitively, a small batch size injects some noise into gradient descent. A more theoretical explanation is that a small batch size tends to find a flat minimizer, while a bigger batch size tends to end up in a sharp minimizer.
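
To build intuition for "flat versus sharp", one rough proxy is to perturb the trained weights with small random noise and measure how much the loss rises: a flat minimizer barely moves, a sharp one degrades quickly. The sketch below assumes a hypothetical loss_fn(w) and trained weight vector w; it illustrates the idea and is not the formal sharpness measure.

```python
import numpy as np

def sharpness_probe(loss_fn, w, radius=0.01, trials=20, rng=None):
    """Average increase in loss when the weights are perturbed by random
    noise of size `radius`. Larger values suggest a sharper minimizer,
    which is associated with large-batch training and worse generalization."""
    rng = rng or np.random.default_rng()
    base = loss_fn(w)
    increases = []
    for _ in range(trials):
        noise = rng.normal(size=w.shape)
        noise *= radius / np.linalg.norm(noise)   # scale to the given radius
        increases.append(loss_fn(w + noise) - base)
    return float(np.mean(increases))
```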


A very useful trick is to use a power of two (2^n) as the batch size, because the number of physical processors is usually also a power of two (2^m). This follows from data parallelism: all physical processors (PP) do the same thing at the same time, but on different data. The batch size decides how virtual processors (VP), one per sample, are mapped onto the physical processors. For example, if you have 16 PP, you can map 16 VP onto them, one VP per PP, and finish in a single round. However, if you map 17 VP, then after the first 16 run in parallel, the remaining 17 - 16 = 1 VP needs a second round in which 1 PP executes while the other 15 PP do nothing.
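
The arithmetic in that example generalizes: with p physical processors and b samples per batch, you need ceil(b / p) parallel rounds, and the last round leaves processors idle unless b is a multiple of p. A small sketch (the function name and the one-sample-per-processor mapping are illustrative assumptions):

```python
import math

def parallel_utilization(batch_size, num_processors):
    """Fraction of processor time actually used when `batch_size` samples
    are spread over `num_processors` processors, one sample per processor
    per round (the data-parallel mapping described above)."""
    rounds = math.ceil(batch_size / num_processors)
    return batch_size / (rounds * num_processors)

print(parallel_utilization(16, 16))  # 1.0   -> every processor busy
print(parallel_utilization(17, 16))  # ~0.53 -> the second round uses 1 of 16
```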


If you want to know more, here are some useful resources about batch size.
