How does PyTorch calculate gradients: a programming perspective

When using PyTorch to train a neural network model, an important step is backpropagation, which looks like this:

loss = criterion(y_pred, y)
loss.backward()

The gradients of the weight tensors are calculated here, even though neither the gradients nor the weight matrices ever appear in the code. The mathematical formulation of backpropagation is clear. In this post, we try to understand how PyTorch calculates gradients through this loss.backward() call from a programming perspective.
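To make this concrete, here is a minimal training step you can run. The single nn.Linear layer, the random data, and the choice of nn.MSELoss as the criterion are assumptions made only for illustration. The point is that the weights never appear in the backward() call, yet their .grad attributes are filled in afterwards.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(3, 1)        # weight and bias are created with requires_grad=True
criterion = nn.MSELoss()       # assumed loss function for this sketch

x = torch.randn(8, 3)          # random inputs (illustrative only)
y = torch.randn(8, 1)          # random targets (illustrative only)

print(model.weight.grad)       # None: no backward pass has run yet

y_pred = model(x)              # forward pass records the computation graph
loss = criterion(y_pred, y)
loss.backward()                # fills .grad of every leaf tensor that requires grad

print(model.weight.grad.shape) # torch.Size([1, 3]), same shape as the weight
print(model.bias.grad.shape)   # torch.Size([1])
```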

Part of PyTorch's source code is written in C++. We will mostly focus on the Python side and give one C++ example to aid our understanding.


Background Knowledge

backward() method

PyTorch uses the autograd package for automatic differentiation. For a tensor y, we can calculate its gradient with respect to the inputs in two equivalent ways:

y.backward()
torch.autograd.backward(y)

After calling .backward(), we can check the gradient stored in an input tensor x through its .grad attribute:

x.grad
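The sketch below checks that the two calls are equivalent. It builds the same small graph twice (the helper make_graph and the function y = x ** 2 are chosen only for illustration), runs one method on each, and reads the result from the .grad attribute of the input.

```python
import torch

def make_graph():
    # Hypothetical helper: a fresh input and output each time
    x = torch.tensor(3.0, requires_grad=True)
    y = x ** 2          # dy/dx = 2x = 6 at x = 3
    return x, y

# Method 1: Tensor.backward()
x1, y1 = make_graph()
y1.backward()

# Method 2: torch.autograd.backward()
x2, y2 = make_graph()
torch.autograd.backward(y2)

print(x1.grad, x2.grad)  # tensor(6.) tensor(6.)
```

Both calls are shown on a scalar y. For a non-scalar output, backward() needs an explicit gradient argument.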

Calculation Graph

PyTorch builds a dynamic computation graph when the network's forward function is called. We borrow a toy example from here.

Suppose our model looks like this, and we actually run it (rather than just defining it):

import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a + b
d = torch.tensor(4.0, requires_grad=True)
e = c * d
e.backward()
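Running the toy example and printing the gradients shows what backward() produced. This is the same code with print statements added; the expected values follow from the chain rule with a = 2, b = 3, d = 4.

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a + b                       # c = 5
d = torch.tensor(4.0, requires_grad=True)
e = c * d                       # e = 20
e.backward()

# de/da = de/dc * dc/da = d * 1 = 4
# de/db = de/dc * dc/db = d * 1 = 4
# de/dd = c = 5
print(a.grad, b.grad, d.grad)   # tensor(4.) tensor(4.) tensor(5.)
print(c.is_leaf, a.is_leaf)     # False True: c is computed, a is an original input
```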

A calculation graph is generated as the forward pass runs. The corresponding graph for this model looks like this:

[Figure: forward calculation graph of the toy example]

The nodes represent tensors and the circles represent operations. More importantly, when the forward calculation graph is generated, a backward graph for calculating the gradients is generated simultaneously. The whole picture looks like this:

[Figure: forward and backward graphs of the toy example]

Some items appear in yellow rectangles. They implement the backpropagation operations, and we will talk about them later.
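Because the graph is recorded while the forward code runs, ordinary Python control flow decides which operations end up in it. The function forward and its use_mul flag below are hypothetical, used only to illustrate this; the last part shows that operations run under torch.no_grad() are not recorded at all.

```python
import torch

def forward(x, use_mul):
    # Hypothetical toy "model": plain Python control flow picks the operation
    if use_mul:
        return x * x
    return x + x

x = torch.tensor(1.0, requires_grad=True)

y_mul = forward(x, use_mul=True)
y_add = forward(x, use_mul=False)
print(type(y_mul.grad_fn).__name__)  # MulBackward0
print(type(y_add.grad_fn).__name__)  # AddBackward0

# Nothing is recorded inside torch.no_grad(), so there is no grad_fn
with torch.no_grad():
    y_plain = forward(x, use_mul=True)
print(y_plain.grad_fn)               # None
```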

Gradient Calculation Process

We use the example above to explain. First, let's look at the tensor e:

tensor(20., grad_fn=<MulBackward0>)

You will find an attribute called grad_fn. According to the documentation, grad_fn stores a reference to the backward propagation function of the tensor. That description can be confusing.

Put simply, grad_fn stores the backpropagation function that corresponds to how the tensor (e here) was calculated in the forward pass. In this case e = c * d, so e is produced by a multiplication, and its grad_fn is MulBackward0, the backpropagation operation for multiplication.
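A quick way to see which tensors carry a grad_fn is to print it for every tensor in the example. This minimal sketch rebuilds the toy graph and uses type(...).__name__ to get each node's class name and is_leaf to tell original inputs apart from computed tensors.

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a + b
d = torch.tensor(4.0, requires_grad=True)
e = c * d

for name, t in [("a", a), ("b", b), ("c", c), ("d", d), ("e", e)]:
    fn = t.grad_fn
    fn_name = None if fn is None else type(fn).__name__
    print(name, t.is_leaf, fn_name)
# a True None
# b True None
# c False AddBackward0
# d True None
# e False MulBackward0
```

Original inputs have no grad_fn of their own; they only show up later as AccumulateGrad nodes inside the backward graph.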

grad_fn has an attribute called next_functions. Checking e.grad_fn.next_functions returns a tuple of tuples:

((<AddBackward0 at 0x268c6d3e668>, 0), (<AccumulateGrad at 0x268c6d3e588>, 0))

Remembering what MulBackward0 means, you will notice that AddBackward0 here represents an addition in the forward pass. Since e = (a + b) * d, the pattern is clear: each grad_fn links to the backward functions of its inputs through next_functions, forming a chain that the gradient calculation traverses.

In this case, to calculate the gradient of e with respect to the input a, PyTorch needs to calculate the gradient of the multiplication first and then that of the addition.
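The chain can be printed by following next_functions recursively from e.grad_fn. The helper walk below is a hypothetical sketch written for this post; for simplicity it does not de-duplicate nodes reached along several paths, which PyTorch's C++ engine does by tracking dependencies so that each node runs once.

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a + b
d = torch.tensor(4.0, requires_grad=True)
e = c * d

def walk(fn, depth=0, out=None):
    """Depth-first walk over grad_fn.next_functions, collecting node names."""
    if out is None:
        out = []
    if fn is None:          # inputs that do not require grad appear as None
        return out
    out.append("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1, out)
    return out

print("\n".join(walk(e.grad_fn)))
# MulBackward0
#   AddBackward0
#     AccumulateGrad
#     AccumulateGrad
#   AccumulateGrad
```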

AccumulateGrad is a node like MulBackward0 and AddBackward0, but it belongs to an original input. In this case d is an original input that is not calculated from other tensors, unlike c = a + b.

AccumulateGrad has two attributes of interest: next_functions and variable.

If we go further to c = a + b, we pick AddBackward0 and take a look at its next_functions:

ag = e.grad_fn.next_functions[0][0] #<AddBackward0 at 0x268c6d3e668>
ag.next_functions

# ((<AccumulateGrad at 0x268c6d3e978>, 0),
# (<AccumulateGrad at 0x268c6d3e898>, 0))

Since c = a + b, and a and b are original inputs, the next functions are AccumulateGrad nodes for a and b, just like the one for d.

We then check its variable:

ag.variable

It raises an error: AttributeError: 'AddBackward0' object has no attribute 'variable'. The reason is that c = a + b is not an original input, so it is not stored as a variable.

However, if we take the AccumulateGrad node (for d):

ag = e.grad_fn.next_functions[1][0] #<AccumulateGrad at 0x268c6d3e588>
ag.variable

It returns the variable:

tensor(4., requires_grad=True)

This should be familiar! It's just the definition of d (d = torch.tensor(4.0, requires_grad=True)).

Similarly, the AccumulateGrad nodes for a and b hold the familiar tensors a and b.
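The node-by-node inspection can be collected into one script. This sketch rebuilds the toy example, follows next_functions by index (index 0 is c's branch and index 1 is d's, matching the tuple printed earlier), and reads .variable from each AccumulateGrad node. The names add_node, acc_a, acc_b and acc_d are just local variables chosen for readability.

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a + b
d = torch.tensor(4.0, requires_grad=True)
e = c * d

add_node = e.grad_fn.next_functions[0][0]   # AddBackward0 (for c)
acc_a = add_node.next_functions[0][0]       # AccumulateGrad for a
acc_b = add_node.next_functions[1][0]       # AccumulateGrad for b
acc_d = e.grad_fn.next_functions[1][0]      # AccumulateGrad for d

print(hasattr(add_node, "variable"))        # False: c is not an original input
print(acc_a.variable)                       # tensor(2., requires_grad=True)
print(acc_b.variable)                       # tensor(3., requires_grad=True)
print(acc_d.variable)                       # tensor(4., requires_grad=True)
```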

Now we have seen the whole path from the output e to the inputs a, b, and d.

To sum up, when we call e.backward(), PyTorch calculates the derivatives of e by traversing next_functions, starting from e.grad_fn. When it reaches an AccumulateGrad node, that node belongs to an original input, and the calculated gradient is stored in that variable's .grad attribute.

PyTorch keeps traversing the computation graph until it reaches all the original inputs, at which point the gradients of all inputs have been updated. The chain rule is applied along the way: first the derivative of the multiplication (de/dc) is calculated, then the derivative of the addition (dc/da), and their product gives a's gradient (de/da).
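Written out for the toy example with a = 2, b = 3, d = 4 (so c = 5), the chain rule gives exactly the values stored in a.grad, b.grad and d.grad:

```latex
\begin{aligned}
\frac{\partial e}{\partial a} &= \frac{\partial e}{\partial c}\,\frac{\partial c}{\partial a} = d \cdot 1 = 4, \\
\frac{\partial e}{\partial b} &= \frac{\partial e}{\partial c}\,\frac{\partial c}{\partial b} = d \cdot 1 = 4, \\
\frac{\partial e}{\partial d} &= c = a + b = 5.
\end{aligned}
```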

Derivative Calculation

So far, we have understood the process of gradient calculation. But one question remains: how is each individual derivative calculated? The answer is extremely simple: the derivative formula of each operation is built into its backward class, such as MulBackward0 and AddBackward0. For example, for the multiplication e = c * d, the derivative is de/dc = d. MulBackward0 saves c and d during the forward pass, so in the backward pass it simply multiplies the incoming gradient by the saved d. No symbolic differentiation is performed here.
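The idea of "a known formula applied to saved values" can be mimicked in plain Python. The classes below (ToyAccumulateGrad, ToyAddBackward, ToyMulBackward) and the function run_backward are hypothetical toys written for this post, not PyTorch's real classes; they only reproduce the toy example e = (a + b) * d.

```python
# Hypothetical toy classes that mimic the idea of PyTorch's backward nodes.
# They are NOT PyTorch's real MulBackward0 / AddBackward0 / AccumulateGrad.

class ToyAccumulateGrad:
    """Leaf node: stores (accumulates) the gradient of an original input."""
    def __init__(self):
        self.grad = 0.0

    def backward(self, grad_out):
        self.grad += grad_out
        return []  # nothing further to traverse

class ToyAddBackward:
    """Backward of x + y: both local derivatives are 1."""
    def __init__(self, next_x, next_y):
        self.next_functions = (next_x, next_y)

    def backward(self, grad_out):
        return [(self.next_functions[0], grad_out * 1.0),
                (self.next_functions[1], grad_out * 1.0)]

class ToyMulBackward:
    """Backward of x * y: d/dx = y and d/dy = x, read from saved values."""
    def __init__(self, x_val, y_val, next_x, next_y):
        self.saved_x, self.saved_y = x_val, y_val  # remembered in the forward pass
        self.next_functions = (next_x, next_y)

    def backward(self, grad_out):
        return [(self.next_functions[0], grad_out * self.saved_y),
                (self.next_functions[1], grad_out * self.saved_x)]

def run_backward(root):
    # Depth-first propagation from the output; enough for this tree-shaped graph
    stack = [(root, 1.0)]  # de/de = 1
    while stack:
        node, grad = stack.pop()
        stack.extend(node.backward(grad))

# Forward pass of e = (a + b) * d, building the backward nodes alongside
a, b, d = 2.0, 3.0, 4.0
acc_a, acc_b, acc_d = ToyAccumulateGrad(), ToyAccumulateGrad(), ToyAccumulateGrad()
c = a + b
add_node = ToyAddBackward(acc_a, acc_b)
e = c * d
mul_node = ToyMulBackward(c, d, add_node, acc_d)

run_backward(mul_node)
print(acc_a.grad, acc_b.grad, acc_d.grad)  # 4.0 4.0 5.0
```

The real engine is written in C++ and, unlike this toy, waits until all gradients flowing into a node have been summed before running it, so graphs where a tensor is used more than once are handled correctly.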

A related question is how the derivative of an activation function is calculated. It works just like MulBackward0: the backward class applies a known formula to values saved in the forward pass. We use the sin() activation as an example. Here is an image of its C++ implementation:

[Figure: C++ implementation of the backward function of sin()]

We can see that auto grad_result = grad * self.cos(): the incoming gradient (grad) is multiplied by the cosine of the saved input, since cos is the known derivative of sin. No symbolic calculation, no intelligence, just memory.
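This can be checked from Python: after backpropagating through torch.sin, the input's gradient should equal cos of the input. The sample points from torch.linspace are arbitrary; note that the exact node name (SinBackward0 in recent versions) may differ between PyTorch releases.

```python
import torch

x = torch.linspace(-3.0, 3.0, steps=7, requires_grad=True)
y = torch.sin(x)
print(type(y.grad_fn).__name__)   # e.g. SinBackward0

# Summing makes the output a scalar; d(sum sin x_i)/dx_i = cos(x_i)
y.sum().backward()
print(torch.allclose(x.grad, torch.cos(x.detach())))  # True
```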

Conclusion

In this post, we introduced how gradients are calculated in PyTorch from a programming perspective. We also showed how each individual derivative is calculated.

References

To write this post, I read several posts related to this topic. They are really helpful. You may refer to them to acquire a better understanding.

  1. AutoGrad Official Documentation

  2. PyTorch Autograd Explained - In-depth Tutorial (video)

  3. PyTorch自动求导(Autograd)原理解析 (An Analysis of How PyTorch Automatic Differentiation (Autograd) Works)

  4. Getting Started with PyTorch Part 1: Understanding how Automatic Differentiation works

  5. PyTorch Internals 5: Autograd的实现 (The Implementation of Autograd)
