Automatic Differentiation in PyTorch
Introduction
Calculating gradients manually is tedious and error-prone. Autodiff allows us to automatically compute gradients of computations defined in a programming language like Python.
PyTorch uses reverse-mode autodiff to efficiently calculate gradients. It records operations performed on tensors to build up a computational graph, then applies the chain rule to backpropagate gradients.
We'll start with a brief background on calculating derivatives mathematically, then explain how autodiff builds on these concepts. We'll go through examples of using PyTorch's autodiff and discuss how it works under the hood.
Mathematical Background
Let's quickly review some mathematical basics for computing derivatives.
The derivative of a function $f(x)$ is defined as:
$$f'(x) = \lim_{h \to 0} \frac{f(x + h) - f(x)}{h}$$
The derivative tells us the rate of change of $f(x)$ with respect to $x$. It allows us to find the tangent line approximation at a point $x$:
$$f(x + h) \approx f(x) + f'(x)h$$
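As a quick sanity check of the limit definition, we can compare a finite-difference approximation against the exact derivative. This is a small illustrative script; the function $f(x) = x^2$ and the step size h are arbitrary choices here:
def f(x):
    return x**2

def numeric_derivative(f, x, h=1e-5):
    # Finite-difference approximation of f'(x), following the limit definition
    return (f(x + h) - f(x)) / h

print(numeric_derivative(f, 3.0))  # ~6.0, matching the exact derivative f'(3) = 2*3 = 6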
Using the limit definition directly is impractical for complicated functions. Instead, we can use rules of calculus like the chain rule to break down derivatives.
The chain rule allows us to compute derivatives of nested functions:
$$(f(g(x)))' = f'(g(x)) \cdot g'(x)$$
For example, if $f(x) = x^2$ and $g(x) = x + 1$, then:
\begin{align*}
(f(g(x)))' &= f'(g(x)) \cdot g'(x) \\
&= 2(x + 1) \cdot 1 \\
&= 2(x + 1)
\end{align*}
We applied the chain rule to break down the derivative into simpler pieces. This forms the basis for how automatic differentiation works.
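We can double-check this numerically. The sketch below compares the chain-rule result $2(x + 1)$ against a finite-difference estimate at an arbitrary test point $x = 2$:
def composed(x):
    return (x + 1)**2   # f(g(x)) with f(x) = x**2 and g(x) = x + 1

x, h = 2.0, 1e-5
print((composed(x + h) - composed(x)) / h)  # ~6.0, the numerical estimate
print(2 * (x + 1))                          # 6.0, the chain-rule result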
Automatic Differentiation
Autodiff tools like PyTorch's reverse-mode autograd allow us to compute derivatives of arbitrary programs. The key ideas are:
- Treat programs as compositions of elementary operations like +, *, sin(), etc.
- Apply chain rule to backpropagate derivatives through these operations.
- Store intermediate values during the forward pass to efficiently compute gradients on the backward pass.
Let's see a simple example:
import torch
x = torch.tensor(3.0, requires_grad=True)
y = x**2 + x*2 + 1
y.backward()
print(x.grad)  # tensor(8.)
By setting requires_grad=True, we tell PyTorch to track computations involving x.
The line y = x**2 + x*2 + 1 performs the "forward pass" and builds up a computational graph tracking each operation.
Calling backward() kicks off backpropagation, with PyTorch applying the chain rule to compute the gradient $\frac{dy}{dx}$.
So how does PyTorch accomplish this behind the scenes?
Recording Operations with Graphs
PyTorch uses a technique called reverse mode differentiation.
As each operation is applied, PyTorch records them to build up a computational graph. Nodes in the graph represent operations, and edges represent data dependencies.
For example, the graph for y = x**2 + x*2 + 1 looks roughly like:
        x
       / \
  [**2]   [*2]
       \ /
       [+]
        |
      [+ 1]
        |
        y
The graph keeps track of how y was computed from the input x. We can now apply the chain rule to backpropagate gradients along this graph.
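We can peek at the recorded graph through the grad_fn attribute of each result tensor. The node names shown in the comments are indicative and may vary between PyTorch versions:
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x**2 + x*2 + 1

print(y.grad_fn)                 # e.g. <AddBackward0 ...>, the node that produced y
print(y.grad_fn.next_functions)  # the upstream nodes y depends on, and so on back to x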
Backpropagation with Chain Rule
To backpropagate gradients, PyTorch recursively applies the chain rule starting from the final node:
- Initialize $\frac{dy}{dy} = 1$, since $y$ is the final output.
- Walk backward through each operation, multiplying the incoming gradient by that operation's local derivative:
  - For an addition $v = a + b$: $\frac{dy}{da} = \frac{dy}{dv}$ and $\frac{dy}{db} = \frac{dy}{dv}$.
  - For a multiplication $v = ab$: $\frac{dy}{da} = b\,\frac{dy}{dv}$ and $\frac{dy}{db} = a\,\frac{dy}{dv}$.
  - For a power $v = a^n$ with constant $n$: $\frac{dy}{da} = na^{n-1}\,\frac{dy}{dv}$.
- Store the intermediate gradients $\frac{dy}{dv}$ at each node along the way.
- The overall gradient $\frac{dy}{dx}$ accumulates at the input node $x$.
Unrolling this whole process for our example gives:
\begin{align*}
\frac{dy}{dx} &= \frac{d(x^2 + 2x + 1)}{dx} \\
&= \frac{d(x^2)}{dx} + \frac{d(2x)}{dx} + \frac{d(1)}{dx} \\
&= 2x\frac{dx}{dx} + 2\frac{dx}{dx} + 0 \\
&= 2x + 2 = 8 \quad \text{at } x = 3
\end{align*}
By recursively applying chain rule, we can compute gradients automatically!
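We can carry out this reverse sweep by hand for $y = x^2 + 2x + 1$ at $x = 3$ and confirm it agrees with autograd. This is a toy sketch specific to this one expression:
import torch

x = 3.0

# Forward pass: compute and keep the intermediate values
a = x**2        # a = x^2
b = 2 * x       # b = 2x
y = a + b + 1

# Reverse sweep: seed with dy/dy = 1 and apply the local rules
dy_dy = 1.0
dy_da = dy_dy                        # addition passes the gradient through unchanged
dy_db = dy_dy
dy_dx = dy_da * (2 * x) + dy_db * 2  # power rule on a, constant-multiple rule on b

print(dy_dx)  # 8.0

# The same gradient from PyTorch's autograd
t = torch.tensor(3.0, requires_grad=True)
(t**2 + 2*t + 1).backward()
print(t.grad)  # tensor(8.)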
Implementing Autograd in PyTorch
Let's now dive into how PyTorch implements autodiff under the hood. The key components are:
Forward pass - Records operations on tensors to build the computational graph.
Backward pass - Backpropagates gradients along the graph using the chain rule.
Function - Abstraction representing primitive operations.
Tensor - Wraps data and gradients.
autograd - The automatic differentiation engine.
The autograd package provides all the necessary bookkeeping for backward passes. But let's see how we could build a simple autograd implementation ourselves.
Forward Pass
During the forward pass, our implementation wraps each value in a Variable class that records the operation applied to compute it:
class Variable:
    def __init__(self, value, grad_fn=None, requires_grad=False):
        self.value = value             # the underlying data
        self.grad_fn = grad_fn         # the Function that produced this variable, if any
        self.requires_grad = requires_grad
        self.grad = None               # filled in during the backward pass
The grad_fn attribute tracks the operation that produced this variable, allowing us to walk the graph backwards later.
Each operation is wrapped in a Function class:
class Function:
    def __init__(self, *inputs):
        self.inputs = inputs  # the input Variables, kept for the backward pass

    def forward(self, *args):
        # Forward logic, implemented by each concrete operation
        raise NotImplementedError

    def backward(self, grad_output):
        # Backward logic: return gradients with respect to each input
        raise NotImplementedError
For example, an Add function would look like:
class Add(Function):
    def forward(self, x, y):
        # x and y are the raw values held by the input Variables
        return x + y

    def backward(self, grad_output):
        # d(x + y)/dx = 1 and d(x + y)/dy = 1, so the gradient passes straight through
        return grad_output, grad_output
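Following the same pattern, a hypothetical Mul function (not part of the original example) would apply the product rule on the backward pass:
class Mul(Function):
    def forward(self, x, y):
        return x * y

    def backward(self, grad_output):
        # Product rule: d(x*y)/dx = y and d(x*y)/dy = x
        x, y = self.inputs
        return grad_output * y.value, grad_output * x.value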
To hook everything together:
x = Variable(torch.ones(2), requires_grad=True)
y = Variable(torch.randn(2))

add = Add(x, y)                              # record the inputs for the backward pass
z = Variable(add.forward(x.value, y.value),  # run the forward computation
             grad_fn=add)                    # store the operation for backward
This builds up the computational graph tracking each operation.
Backward Pass
The backward pass kicks off from the final variable z and recursively applies the chain rule:
def backward(z, grad_output=1.0):
    if z.grad_fn is None:                    # base case: a leaf variable
        return
    grads = z.grad_fn.backward(grad_output)  # apply the local chain rule
    for var, grad in zip(z.grad_fn.inputs, grads):
        var.grad = grad if var.grad is None else var.grad + grad  # accumulate
        backward(var, grad)                  # recursively backprop into each input
We seed backpropagation with $\frac{dz}{dz} = 1$. Each Function multiplies the incoming gradient by its local derivatives to produce gradients for its inputs, and these are stored on the input variables (e.g. x.grad), giving the overall gradient at each node.
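Continuing the toy sketch above (not PyTorch's real API), we can run the backward pass on z and read off the gradients:
backward(z)    # seeds the sweep with dz/dz = 1
print(x.grad)  # 1.0, since d(x + y)/dx = 1
print(y.grad)  # 1.0, since d(x + y)/dy = 1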
And there we have a basic autograd implementation! The full PyTorch autograd is more complex, but implements the same principles.
Benefits of Autograd
There are several key benefits of using autograd:
Automatic - No need to manually implement backpropagation. Cleaner and less error prone.
Efficient - Performs all computations in single forward/backward passes. Stores intermediate values instead of recomputing.
Flexible - Works on any code, not just neural networks. Supports dynamic graphs.
Debuggable - Easy to inspect forward graphs and backward gradients.
Autograd allows us to focus on high-level model logic rather than error-prone gradient calculations.
Disabling Gradient Tracking
Earlier we saw how PyTorch stores intermediate values during the forward pass to reuse during the backward pass. This approach is called "define-by-run" - the graph is defined on the fly.
An alternative approach is "define-and-run", where we first define the computational graph, then run forward and backward passes later.
PyTorch sticks with define-by-run, but it also lets us turn off graph recording entirely when gradients are not needed, using torch.no_grad():
x = torch.tensor(..., requires_grad=True)

with torch.no_grad():
    # Temporarily disable autograd tracking
    y = x * 2

# y does not require grad since it was computed inside no_grad
y.backward()  # Error!
Here, y is computed without autograd, so it cannot be used in a backward pass later.
Define-and-run graphs can be optimized ahead of time as a whole, but define-by-run is more flexible for neural nets with dynamic control flow.
PyTorch also supports gradient checkpointing (torch.utils.checkpoint), which reduces memory usage during training by recomputing some intermediate values in the backward pass instead of storing them.
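Here is a minimal sketch of gradient checkpointing, assuming a small two-layer module purely for illustration; depending on your PyTorch version, checkpoint may also accept a use_reentrant flag:
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
x = torch.randn(8, 16, requires_grad=True)

# Activations inside `layer` are not stored during the forward pass;
# they are recomputed during backward, trading compute for memory.
y = checkpoint(layer, x)
y.sum().backward()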
Support for Non-Differentiable Operations
Autograd requires operations to be differentiable to compute gradients. But some operations like ReLUs have non-differentiable kinks:
def relu(x):
    return max(0, x)
PyTorch handles non-differentiable functions with subgradients. The backward pass defines a subgradient for points where the gradient is not defined.
For ReLU:
def backward(grad_output, x):
    # x is the input saved from the forward pass
    subgrad = 1.0 if x > 0 else 0.0  # pick the subgradient 0 at the kink x == 0
    return subgrad * grad_output
This allows backpropagation to continue past the kinks with appropriate subgradients. When a fully smooth function is needed, a smooth approximation such as softplus can be used instead.
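We can check how PyTorch's built-in ReLU behaves at and around the kink. To my understanding it takes the subgradient at exactly zero to be zero, but verify the printed values on your own version:
import torch

x = torch.tensor([-1.0, 0.0, 2.0], requires_grad=True)
torch.relu(x).sum().backward()
print(x.grad)  # expected: tensor([0., 0., 1.])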
Accelerating Autograd
While autodiff is convenient, it can have high overhead from tracking computations. Some ways PyTorch optimizes performance:
Just-in-Time Compilation - Compile and cache functions with TorchScript (torch.jit) to cut Python overhead (see the sketch after this list).
In-place operations - Avoid unnecessary allocations during intermediate steps.
Asynchronous execution - Process multiple graph nodes simultaneously.
CUDA support - Offload backward pass to GPU when possible.
Bulk computations - Vectorize calculations when operating on batches.
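As a small illustration of the JIT point above, TorchScript can compile a function ahead of time while autograd still works through it. This is a hedged sketch; any actual speedup depends on the model and hardware:
import torch

@torch.jit.script
def fused(x: torch.Tensor) -> torch.Tensor:
    # TorchScript may fuse these pointwise ops into fewer kernels
    return torch.relu(x * 2 + 1)

x = torch.randn(4, requires_grad=True)
fused(x).sum().backward()
print(x.grad)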
Efficient implementation is key to making autodiff practical for deep networks.
Applications Beyond Deep Learning
While autograd was pioneered for deep learning, it enables new applications like:
Probabilistic modeling - Easily train complex Bayesian networks with variational inference.
Differential equation solving - Optimize parameters of ODE/PDE solvers.
Model analysis - Compute saliency maps and other insights by analyzing gradients.
Hyperparameter optimization - Compute gradients with respect to hyperparameters to tune them automatically.
Meta learning - Learn-to-learn approaches like MAML use gradients over gradients.
The flexibility of imperative differentiation allows models to be treated just like any other Python program.
Conclusion
We covered a lot of ground on how automatic differentiation works in PyTorch. The key takeaways are:
- Autograd computes gradients by recording operations and then replaying them in reverse with the chain rule.
- This allows any Python code to be differentiated, enabling powerful applications.
- PyTorch uses dynamic define-by-run graphs along with optimizations like CUDA and JIT to make it fast.
- Extensions like subgradients and higher derivatives increase the flexibility of autograd.
Automatic differentiation has revolutionized deep learning by allowing models to be efficiently trained. And its support in frameworks like PyTorch makes the powerful technique easily accessible to all Python programmers.
Looking forward, we can expect more applications to be unlocked as autograd spreads beyond deep learning models like neural networks. The ability to work with gradients programmatically opens up exciting possibilities in scientific computing and beyond.