Why Batch Norm Works Explained With Minimal Math
Whenever I look at innovations in machine learning I always think: “wow, how did someone come up with that?” Often the surprising genius of a discovery lies in its simplicity. From GANs to batch norm, there’s a dazzling amount of complexity behind why simple equations or techniques make machine learning models so accurate. The wonders of batch norm specifically stem from how we can represent a dataset mathematically as an affine transform of a Gaussian distribution.
Batch normalization
Batch normalization is a layer you can add to your neural network which normalizes the inputs you give it with respect to the batch they are in. It was introduced in 2015 and has been shown to reduce overfitting in networks as well as improve generalization accuracy. For an input $x$, the output of a BN layer is $x$ normalized with the mean $\mu_B$ and variance $\sigma_B^2$ of the batch $B$ it belongs to, followed by a learned scale $\gamma$ and shift $\beta$:
$$\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y = \gamma \hat{x} + \beta,$$
where $\epsilon$ is a small constant added for numerical stability.
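As a rough illustration, here is a minimal sketch of the training-mode forward pass for a single feature, written in plain Python so it has no dependencies. The function name `batch_norm_forward` and its defaults (`gamma=1.0`, `beta=0.0`, `eps=1e-5`) are my own choices for this example; real frameworks operate on whole tensors and also keep running averages of the statistics for inference, which this sketch skips.

```python
import math
from typing import List


def batch_norm_forward(
    batch: List[float], gamma: float = 1.0, beta: float = 0.0, eps: float = 1e-5
) -> List[float]:
    """Normalize one feature across a mini-batch, then scale and shift it.

    Training-mode forward pass only: no running averages are tracked here.
    """
    n = len(batch)
    mean = sum(batch) / n                                   # batch mean, mu_B
    var = sum((v - mean) * (v - mean) for v in batch) / n   # biased batch variance, sigma_B^2
    inv_std = 1.0 / math.sqrt(var + eps)                    # eps avoids division by zero
    return [gamma * (v - mean) * inv_std + beta for v in batch]


out = batch_norm_forward([150.0, 160.0, 170.0, 180.0, 190.0])
print([round(v, 3) for v in out])  # [-1.414, -0.707, 0.0, 0.707, 1.414]
```

With the default `gamma` and `beta` the outputs have (approximately) zero mean and unit variance; learning `gamma` and `beta` lets the network rescale or shift the normalized values if that turns out to help.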
Why it works
But why does it help a network train better? Let’s think about this in terms of how a neural network learns.
Whenever you feed your neural network a sample $x$ from a training set, that sample has an underlying distribution associated with it. That is, all of the samples can be thought of as coming from some distribution. For simplicity, let’s say this distribution can be approximated by a Gaussian parameterized by a mean $\mu$ and variance $\sigma^2$ (something you will no doubt be familiar with if you have ever taken a STATS 101 class).
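To make “parameterized by a mean and variance” concrete, here is a small sketch that estimates those two parameters from a list of samples using the maximum-likelihood formulas. The helper `fit_gaussian` is hypothetical, written for this post, not taken from any library.

```python
from typing import List, Tuple


def fit_gaussian(samples: List[float]) -> Tuple[float, float]:
    """Maximum-likelihood estimates of a Gaussian's mean and variance."""
    n = len(samples)
    mean = sum(samples) / n
    # The MLE divides by n (population variance); dividing by n - 1 would
    # give the unbiased sample variance instead.
    variance = sum((x - mean) ** 2 for x in samples) / n
    return mean, variance


print(fit_gaussian([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]))  # (5.0, 4.0)
```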
If you take an arbitrary sample $x$ from the distribution of your data, you can think of it as an affine transformation of a standardized value $x'$, determined by the (approximate) mean and standard deviation of the dataset: $x = \sigma x' + \mu$, or equivalently $x' = (x - \mu)/\sigma$.
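A minimal sketch of this affine view, assuming made-up parameters $\mu = 170$ and $\sigma = 8$: draw standardized values $x'$ from $N(0, 1)$, map them through $x = \sigma x' + \mu$, and check that the results have roughly that mean and standard deviation, and that $(x - \mu)/\sigma$ maps them back. The helper names `to_raw` and `to_standard` are hypothetical.

```python
import random
import statistics


def to_standard(x: float, mu: float, sigma: float) -> float:
    """x' = (x - mu) / sigma: remove the dataset's shift and scale."""
    return (x - mu) / sigma


def to_raw(x_std: float, mu: float, sigma: float) -> float:
    """x = sigma * x' + mu: affine map from a standard-normal value to a raw sample."""
    return sigma * x_std + mu


rng = random.Random(0)
mu, sigma = 170.0, 8.0  # hypothetical parameters, for illustration only
standard = [rng.gauss(0.0, 1.0) for _ in range(10_000)]
raw = [to_raw(z, mu, sigma) for z in standard]

# Sample mean and standard deviation should land close to mu and sigma.
print(round(statistics.fmean(raw), 1), round(statistics.pstdev(raw), 1))
# Mapping back recovers x' (up to floating-point rounding).
print(max(abs(to_standard(x, mu, sigma) - z) for x, z in zip(raw, standard)))
```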
But why is this useful? The raw value $x$ we are actually dealing with depends on the mean and variance of the distribution it came from, and knowing those lets us normalize $x$ so that, no matter which distribution it comes from, the relative meaning of each sample stays the same. For a Gaussian, a value sitting exactly at the mean normalizes to 0, and the normalized value grows in magnitude the further the sample lies from the mean, measured in standard deviations.
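A quick worked example, with made-up values $\mu = 10$ and $\sigma = 2$:

```latex
x' = \frac{x - \mu}{\sigma}: \quad
x = 10 \;\Rightarrow\; x' = \frac{10 - 10}{2} = 0, \qquad
x = 14 \;\Rightarrow\; x' = \frac{14 - 10}{2} = 2, \qquad
x = 7  \;\Rightarrow\; x' = \frac{7 - 10}{2} = -1.5
```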
For instance, take the distributions of male and female heights. Both are approximately Gaussian, but we expect the distribution of male heights to have a higher mean than the distribution of female heights. Yet if we fed this data to our neural network to decide whether someone is “tall” or not, we should treat each sample relative to its own distribution: a man of average male height and a woman of average female height are both exactly average, regardless of gender.
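The heights example can be sketched as follows. The means and standard deviations below are hypothetical numbers chosen for illustration, not real population statistics, and the rule “tall means at least one standard deviation above your group’s mean” is likewise an assumption for this sketch.

```python
from typing import Dict, Tuple

# Hypothetical (mean, standard deviation) in cm, chosen only for illustration;
# these are NOT real population statistics.
HEIGHT_STATS: Dict[str, Tuple[float, float]] = {
    "male": (178.0, 7.0),
    "female": (164.0, 6.5),
}


def relative_height(height_cm: float, group: str) -> float:
    """Standardize a height against its own group's distribution."""
    mu, sigma = HEIGHT_STATS[group]
    return (height_cm - mu) / sigma


def is_tall(height_cm: float, group: str, threshold: float = 1.0) -> bool:
    """Hypothetical rule: 'tall' means >= `threshold` std devs above the group mean."""
    return relative_height(height_cm, group) >= threshold


print(relative_height(178.0, "male"), relative_height(164.0, "female"))  # 0.0 0.0
print(is_tall(185.0, "male"), is_tall(171.0, "female"))  # True True
```

Both average heights map to the same normalized value, so a single threshold on the normalized scale works for both groups, whereas a single threshold in raw centimetres would not.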
We can explain why our neural network prefers learning this normalized signal if we understand what happens when we pass data through it WITHOUT applying batch norm. You can think of each input as having undergone the affine transformation $x = \sigma x' + \mu$ before it reaches the network, and that transformation depends on $\mu$ as well as $\sigma$. This is a HUGE problem because $\mu$ and $\sigma$ can change from one input sample $x$ to the next! Mathematically, this just means our gradients depend on MORE quantities during training, including the mean and variance of the distribution each sample comes from (cough covariate shift cough). This can significantly reduce generalization accuracy out in the wild, since your training data distribution may not match your test set distribution. We would benefit from normalizing our input data with respect to the underlying distribution of the data we expect to feed it, since this reduces the phenomenon (though it does not get rid of it entirely).
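To see the gradient dependence in the simplest possible setting, consider a toy model that is my own simplification, not a real network: a single linear unit $y = w x$ with squared loss $L = (y - t)^2$, whose gradient is $\partial L / \partial w = 2(wx - t)x$. Below, two samples sit at the same relative position (one standard deviation above the mean) in two hypothetical distributions A and B.

```python
from typing import Tuple

# Hypothetical (mu, sigma) for two input distributions, for illustration only.
DIST_A: Tuple[float, float] = (170.0, 8.0)
DIST_B: Tuple[float, float] = (150.0, 4.0)


def grad_w(w: float, x: float, target: float) -> float:
    """dL/dw for a single linear unit y = w * x with loss L = (y - target)^2."""
    return 2.0 * (w * x - target) * x


w, target = 0.5, 1.0
x_std = 1.0  # same relative position (one std above the mean) in both distributions

for name, (mu, sigma) in [("A", DIST_A), ("B", DIST_B)]:
    x_raw = sigma * x_std + mu     # what the unit sees without normalization
    x_norm = (x_raw - mu) / sigma  # what it sees after normalization (== x_std)
    print(name, "raw grad:", grad_w(w, x_raw, target),
          "normalized grad:", grad_w(w, x_norm, target))
```

The raw-input gradients differ between A and B (31328.0 vs 23408.0) even though both samples mean the same thing relative to their own distributions, whereas after normalization both gradients are -1.0.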
In this case our network would have a hard time learning without batch norm if we fed it the un-normalized data, because when gradients propagate they would have to account for the fact that the data has undergone this affine transformation with respect to the distribution it came from, which is exactly what we assume when we say the data comes from a Gaussian. Put even more simply: without batch norm, our network learns from the un-normalized input $x = f(x') = \sigma x' + \mu$, a transform that DEPENDS on the distribution $x$ came from, but it is in fact way simpler for the network to learn on $x'$ itself. Of course, this short summary doesn’t capture the full picture of what actually goes on when you train a neural network, which the longer explanation above tries to describe.
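One way to check the claim that normalization lets the network learn on $x'$ itself: standardizing a batch with its own mean and standard deviation undoes any affine transform $x = \sigma x' + \mu$ with $\sigma > 0$. The sketch below (hypothetical helper `standardize`, made-up values) applies three different transforms to the same values and standardizes each result.

```python
import math
from typing import List


def standardize(batch: List[float], eps: float = 1e-12) -> List[float]:
    """Subtract the batch mean and divide by the batch standard deviation."""
    n = len(batch)
    mean = sum(batch) / n
    std = math.sqrt(sum((v - mean) * (v - mean) for v in batch) / n + eps)
    return [(v - mean) / std for v in batch]


x_std = [-1.2, -0.3, 0.0, 0.4, 1.1]  # hypothetical underlying values x'
for mu, sigma in [(0.0, 1.0), (170.0, 8.0), (-40.0, 0.5)]:
    raw = [sigma * v + mu for v in x_std]  # a different affine transform f each time
    print([round(v, 6) for v in standardize(raw)])
```

All three printed lists should match (up to floating-point rounding), so the normalized input no longer carries any information about which $\mu$ and $\sigma$ produced it.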
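Notice that the batch normalization step uses the mean and variance of the current batch, not of the entire dataset: we use the batch statistics as an approximation of the dataset statistics. The reason this works is that, with random shuffling, each batch is a random sample of the whole dataset, so in theory its mean and variance should approximate those of the entire dataset.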
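The sketch below is a rough empirical check of that argument on synthetic Gaussian data (not a real dataset): shuffle, split into batches, and measure how far each batch mean falls from the full-dataset mean. The helper `batch_means` and the batch sizes are my own choices for illustration.

```python
import random
import statistics
from typing import List


def batch_means(data: List[float], batch_size: int, seed: int = 0) -> List[float]:
    """Shuffle a copy of the data, split it into full batches, return each batch's mean."""
    shuffled = data[:]
    random.Random(seed).shuffle(shuffled)
    n_batches = len(shuffled) // batch_size  # drop any final partial batch
    return [
        statistics.fmean(shuffled[i * batch_size:(i + 1) * batch_size])
        for i in range(n_batches)
    ]


rng = random.Random(42)
dataset = [rng.gauss(170.0, 8.0) for _ in range(8192)]  # synthetic data, not real measurements
full_mean = statistics.fmean(dataset)

for bs in (8, 64, 512):
    errors = [abs(m - full_mean) for m in batch_means(dataset, bs)]
    print(bs, round(statistics.fmean(errors), 3))
```

The average error should shrink as the batch size grows (roughly like $1/\sqrt{\text{batch size}}$), which is one reason very small batches tend to give noisier batch norm statistics.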
Conclusion
I feel like this topic is generally explained poorly. I didn’t learn it in my machine learning class, though I think it can be explained very simply with a basic understanding of statistics. Hopefully this post has helped anyone still struggling to understand why batch norm is used. There are a ton of ways of looking at this topic, including “covariate shift”, but I feel it can be explained better without too much statistical terminology.