When using PyTorch to train a neural network model, an important step is backpropagation, which typically looks like this:

```python
loss = criterion(y_pred, y)
loss.backward()
```
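
To make these two lines concrete, here is a minimal, runnable training step built around them. The model, data and optimizer are assumptions for illustration only: a single `nn.Linear` layer, random tensors, and plain SGD with `lr=0.1`. The sketch shows that `model.weight.grad` is empty before `loss.backward()` and filled in afterwards, even though the weights never appear in the backward call.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: one linear layer trained on random data.
model = nn.Linear(3, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 3)
y = torch.randn(8, 1)

print(model.weight.grad)        # None: no backward pass has run yet

optimizer.zero_grad()           # clear gradients left over from a previous step
y_pred = model(x)
loss = criterion(y_pred, y)
loss.backward()                 # autograd writes d(loss)/d(weight) into model.weight.grad
print(model.weight.grad.shape)  # torch.Size([1, 3]), the same shape as the weight

optimizer.step()                # SGD reads .grad to update model.weight in place
```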

The gradients of the weight tensors are calculated here, even though neither the gradients nor the weight matrices appear in the code. The mathematical formulation of backpropagation is clear. In this post, we try to understand how PyTorch calculates gradients with this `loss.backward()` function from a programming perspective.

Part of PyTorch's source code is written in C++. We will mostly focus on the Python side and give one C++ example to help us understand.

## Background Knowledge

### `backward()` method

PyTorch uses the `autograd` package for automatic differentiation. For a tensor `y`, we can calculate the gradients with respect to its inputs in two equivalent ways:

```python
y.backward()
torch.autograd.backward(y)
```

After calling `.backward()`, we can check the gradient accumulated in an input tensor `x` using:

```python
x.grad
```
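
A quick way to check that the two calls are equivalent is to run each one on an identical copy of a tiny graph, `y = x ** 2`, and then compare the gradients. The names `x1`, `x2`, `y1` and `y2` are made up for this sketch. Note that `.grad` is read as an attribute, without parentheses.

```python
import torch

# Two independent copies of the same tiny graph, y = x ** 2.
x1 = torch.tensor(3.0, requires_grad=True)
y1 = x1 ** 2
y1.backward()                  # method form

x2 = torch.tensor(3.0, requires_grad=True)
y2 = x2 ** 2
torch.autograd.backward(y2)    # function form

print(x1.grad, x2.grad)        # both hold dy/dx = 2 * x = 6
```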

### Calculation Graph

PyTorch builds a dynamic computation graph when the forward function of the network is called. We borrow a toy example from here.

Suppose our model looks like this, and we actually run it rather than just defining it:

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a + b
d = torch.tensor(4.0, requires_grad=True)
e = c * d
e.backward()
```

A computation graph is generated as the forward pass runs. In the corresponding graph of this model, the nodes represent tensors and the circles represent operations. More importantly, while the forward graph is being built, a backward graph for calculating the gradients is generated at the same time. In the full picture, some names appear in yellow rectangles. These objects actually implement the backpropagation operations, and we will talk about them later.
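
"Dynamic" means the graph is recorded again on every forward call, so ordinary Python control flow can change it. The hypothetical `forward` function below is not part of the original example. It takes a different branch depending on the sign of its input, and the `grad_fn` of the result changes to match.

```python
import torch

def forward(x):
    # Plain Python control flow decides which operation runs,
    # so the recorded graph can differ from one call to the next.
    if x.item() > 0:
        return x * 2
    return x + 1

pos = forward(torch.tensor(1.0, requires_grad=True))
neg = forward(torch.tensor(-1.0, requires_grad=True))

print(pos.grad_fn)  # a MulBackward0 node: this run recorded a multiplication
print(neg.grad_fn)  # an AddBackward0 node: this run recorded an addition
```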

We use the example above to explain. First, let's look at the tensor `e`:

```
tensor(20., grad_fn=<MulBackward0>)
```

You will find an attribute called `grad_fn`. According to the documentation, `grad_fn` stores a reference to the backward propagation function of the tensor. That's confusing.

To put it plainly, `grad_fn` stores the backpropagation function that matches how the tensor (`e` here) was computed in the forward pass. Here `e = c * d`, so `e` is produced by a multiplication. Its `grad_fn` is therefore `MulBackward0`, which is the backpropagation operation for multiplication.
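
To see that `grad_fn` simply mirrors the forward operation, we can rebuild the example and add a `torch.sin` on top. The example uses `b = 3.0`, as implied by `e` being 20, and the extra `sin` step is not part of the original model. Tensors we create directly are leaves with no `grad_fn`. Every computed tensor records the backward node of the operation that produced it. Recent PyTorch versions name the sine node `SinBackward0`, and older ones may omit the trailing `0`.

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
d = torch.tensor(4.0, requires_grad=True)
c = a + b
e = c * d
s = torch.sin(e)   # extra step added only for this illustration

# Tensors created directly by the user are leaves: nothing computed them.
print(a.is_leaf, a.grad_fn)  # True None

# Every computed tensor records the backward node of the op that produced it.
for name, t in [("c", c), ("e", e), ("s", s)]:
    print(name, t.is_leaf, type(t.grad_fn).__name__)
```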

`grad_fn` has an attribute called `next_functions`. If we check `e.grad_fn.next_functions`, it returns a tuple of tuples:

```
((<AddBackward0 at 0x268c6d3e668>, 0), (<AccumulateGrad at 0x268c6d3e588>, 0))
```

If you remember the meaning of `MulBackward0`, you will notice that `AddBackward0` here represents an addition in the forward pass. Since `e = (a + b) * d`, the pattern is clear. Through its `next_functions`, each `grad_fn` links to the backward functions of the tensors it was computed from. This forms the chain that the gradient calculation follows.
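
The chain structure becomes visible if we follow `next_functions` recursively from `e.grad_fn`. The helper `walk` below is a hypothetical debugging aid, not a PyTorch API. It prints one indented line per node. Because it does no de-duplication, a node reachable along several paths would be printed once per path.

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
d = torch.tensor(4.0, requires_grad=True)
e = (a + b) * d

def walk(fn, depth=0, seen=None):
    """Print the backward graph below `fn` by following next_functions."""
    if seen is None:
        seen = []
    if fn is None:  # an input that does not require grad has no node
        return seen
    seen.append(type(fn).__name__)
    print("  " * depth + type(fn).__name__)
    # Each entry is a (node, index) pair; AccumulateGrad has no entries.
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1, seen)
    return seen

names = walk(e.grad_fn)
```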

In this case, to calculate the gradient of `e` with respect to the input `a`, PyTorch needs to calculate the gradient of the multiplication first and then that of the addition.

`AccumulateGrad` plays a similar role to `MulBackward0` and `AddBackward0`, but it belongs to an original input. In this case it belongs to `d`. Unlike `c`, which is calculated as `a + b`, `d` is an original input.

`AccumulateGrad` also has two attributes, `next_functions` and `variable`.

Going one step further to `c = a + b`, we pick the `AddBackward0` node and look at its `next_functions`:

```python
ag = e.grad_fn.next_functions[0][0]  # <AddBackward0 at 0x268c6d3e668>
ag.next_functions
```

Since `c = a + b` and both `a` and `b` are original inputs, both entries of its `next_functions` are `AccumulateGrad` nodes, one for `a` and one for `b`, just like the one for `d`.

We then check its `variable`:

```python
ag.variable
```

It raises an error: `AttributeError: 'AddBackward0' object has no attribute 'variable'`. The reason is that `c = a + b` is not an original input, so its backward node does not store a `variable`.

However, if we pick the `AccumulateGrad` node that belongs to `d`:

```python
ag = e.grad_fn.next_functions[1][0]  # <AccumulateGrad at 0x268c6d3e588>
ag.variable
```

It returns the variable:

```
tensor(4., requires_grad=True)
```

This should be familiar! It's just the definition of `d` (`d = torch.tensor(4.0, requires_grad=True)`).
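
Putting the last few steps together, the sketch below reaches all three `AccumulateGrad` nodes starting from `e` and reads their `variable` attributes. It also confirms that the `AddBackward0` node has no such attribute. The variable names `add_node`, `acc_a`, `acc_b` and `acc_d` are our own.

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
d = torch.tensor(4.0, requires_grad=True)
e = (a + b) * d

add_node, _ = e.grad_fn.next_functions[0]  # AddBackward0 for c = a + b
acc_d, _ = e.grad_fn.next_functions[1]     # AccumulateGrad for d
acc_a, _ = add_node.next_functions[0]      # AccumulateGrad for a
acc_b, _ = add_node.next_functions[1]      # AccumulateGrad for b

print(hasattr(add_node, "variable"))       # False: c is not an original input
print(acc_a.variable, acc_b.variable, acc_d.variable)
```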

Similarly, if we go with the `AccumulateGrad` in `a` and `b`, we should also see familiar tensors. We omit that here.

Now we have seen the whole path from the output `e` to the inputs `a`, `b`, and `d`.

To sum up, when we call `e.backward()`, PyTorch computes the derivatives of `e` by traversing the graph through `next_functions`. When it reaches an `AccumulateGrad` node, that node belongs to an original input. The computed gradient is then accumulated into that variable's `.grad` attribute.
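
The traversal just described can be imitated in a few lines of plain Python. This is a deliberately simplified toy, not PyTorch's implementation. The classes `Value`, `AddNode`, `MulNode` and `AccumulateGradNode` and the function `run_backward` are hypothetical stand-ins that only handle scalar `+` and `*`. The real engine is written in C++ and, roughly, runs each node once after summing all of its incoming gradients. This recursive version instead pushes each path's gradient down separately. That still gives correct totals at the leaves, but it can be slow on large graphs.

```python
class AccumulateGradNode:
    """Toy stand-in for AccumulateGrad: adds the incoming gradient to a leaf."""
    def __init__(self, variable):
        self.variable = variable
        self.next_functions = ()

    def apply(self, grad):
        self.variable.grad += grad


class AddNode:
    """Toy stand-in for AddBackward0: d(x + y)/dx = d(x + y)/dy = 1."""
    def __init__(self, x, y):
        self.next_functions = (x.node, y.node)

    def apply(self, grad):
        return (grad, grad)


class MulNode:
    """Toy stand-in for MulBackward0: saves x and y; d(x*y)/dx = y, d(x*y)/dy = x."""
    def __init__(self, x, y):
        self.x, self.y = x.data, y.data
        self.next_functions = (x.node, y.node)

    def apply(self, grad):
        return (grad * self.y, grad * self.x)


class Value:
    """Hypothetical scalar that remembers which node produced it."""
    def __init__(self, data, node=None):
        self.data = data
        self.grad = 0.0
        # Leaves own an accumulator; computed values point at their op node.
        self.node = node if node is not None else AccumulateGradNode(self)

    def __add__(self, other):
        return Value(self.data + other.data, AddNode(self, other))

    def __mul__(self, other):
        return Value(self.data * other.data, MulNode(self, other))

    def backward(self):
        run_backward(self.node, 1.0)  # seed: de/de = 1


def run_backward(node, grad):
    if isinstance(node, AccumulateGradNode):
        node.apply(grad)  # reached an original input: store its gradient
        return
    # Chain rule: each local derivative times the incoming grad goes one level down.
    for next_node, g in zip(node.next_functions, node.apply(grad)):
        run_backward(next_node, g)


a, b, d = Value(2.0), Value(3.0), Value(4.0)
e = (a + b) * d
e.backward()
print(a.grad, b.grad, d.grad)  # 4.0 4.0 5.0
```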

PyTorch keeps traversing the computation graph until it has reached all the original inputs. At that point, the gradients of all inputs have been updated. Throughout this process, the chain rule is applied. First the derivative of the multiplication (de/dc) is calculated, then the derivative of the addition (dc/da). Multiplying the two gives the gradient of `a` (de/da).
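
We can check this chain-rule bookkeeping numerically on the same example, again assuming `b = 3.0`. By hand, de/da = de/dc * dc/da = d * 1 = 4, and de/dd = c = 5. The `.grad` values computed by PyTorch should agree.

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
d = torch.tensor(4.0, requires_grad=True)
c = a + b
e = c * d
e.backward()

# Chain rule by hand: de/da = de/dc * dc/da = d * 1
de_dc = d.item()  # derivative of the multiplication e = c * d w.r.t. c
dc_da = 1.0       # derivative of the addition c = a + b w.r.t. a
print(a.grad.item(), de_dc * dc_da)  # 4.0 4.0
print(d.grad.item(), c.item())       # 5.0 5.0 (de/dd = c)
```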

### Derivative Calculation

So far, we have understood the overall process of gradient calculation. One question remains: how is each individual derivative calculated? The answer is surprisingly simple. The derivative formula of each operation is hard-coded in its backward class, such as `MulBackward0` and `AddBackward0`. For example, the derivative of the multiplication `e = c * d` is `de/dc = d`. During the forward pass, `MulBackward0` saves the inputs `c` and `d`. In the backward pass, it simply multiplies the incoming gradient by `d` (and by `c` for `de/dd`). No symbolic differentiation happens here.
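
The same idea can be reproduced with `torch.autograd.Function`, PyTorch's public API for custom operations. The class `MyMul` below is a hypothetical re-implementation for illustration, not the real `MulBackward0`. Its `forward` saves the inputs with `ctx.save_for_backward`, and its `backward` applies the known formulas de/dc = d and de/dd = c to the incoming gradient.

```python
import torch

class MyMul(torch.autograd.Function):
    """Hypothetical multiplication whose backward mimics MulBackward0."""

    @staticmethod
    def forward(ctx, c, d):
        ctx.save_for_backward(c, d)  # remember the inputs, not any derivative
        return c * d

    @staticmethod
    def backward(ctx, grad_output):
        c, d = ctx.saved_tensors
        # Hard-coded formulas: de/dc = d and de/dd = c, scaled by the incoming grad.
        return grad_output * d, grad_output * c

c = torch.tensor(5.0, requires_grad=True)
d = torch.tensor(4.0, requires_grad=True)
e = MyMul.apply(c, d)
e.backward()
print(c.grad, d.grad)  # tensor(4.) tensor(5.)
```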

A related question is how the derivative of an activation function is calculated. It works the same way as `MulBackward0`: the backward class applies a known formula instead of deriving one. Take the `sin()` activation as an example. In its C++ implementation, we can see `auto grad_result = grad * self.cos()`, which is the incoming gradient multiplied by the cosine of the input. The only work done is evaluating that stored formula. There is no symbolic differentiation and no intelligence, just memory.
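
Similarly, a sine with a hand-written backward that mirrors `grad * self.cos()` produces the same gradients as the built-in `torch.sin`. `MySin` is a hypothetical name, and the sample points from `torch.linspace` are arbitrary.

```python
import torch

class MySin(torch.autograd.Function):
    """Hypothetical sin whose backward mirrors `grad * self.cos()`."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad):
        (x,) = ctx.saved_tensors
        return grad * x.cos()  # d sin(x)/dx = cos(x), times the incoming grad

x1 = torch.linspace(-3.0, 3.0, steps=7, requires_grad=True)
MySin.apply(x1).sum().backward()

x2 = x1.detach().clone().requires_grad_(True)
torch.sin(x2).sum().backward()  # the built-in backward path

print(torch.allclose(x1.grad, x2.grad))  # True
```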

## Conclusion

In this post, we introduced how gradients are computed in PyTorch from a programming perspective. We also showed how each individual derivative is calculated.

## References

To write this post, I read several posts related to this topic. They were really helpful, and you may refer to them for a better understanding.