How PyTorch backward() Propagates Gradients to Params
Technical explanation of how PyTorch autograd and backward() build the dynamic graph and accumulate gradients into linear layer weights and biases for SGD.
How does PyTorch’s backward() function propagate gradients from a loss tensor back to model parameters like the linear layer’s weights and biases?
Consider this Softmax Regression model implementation:
import time

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets

class SoftmaxRegission(torch.nn.Module):
    linear: torch.nn.Linear

    def __init__(self, num_features: int, num_classes: int):
        super(SoftmaxRegission, self).__init__()
        self.linear = torch.nn.Linear(num_features, num_classes)
        # Zero the parameters in place without recording the ops in autograd.
        self.linear.weight.detach().zero_()
        self.linear.bias.detach().zero_()

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        logits: torch.Tensor = self.linear(x)
        probas = F.softmax(logits, dim=1)
        return logits, probas

model = SoftmaxRegission(num_features=num_features, num_classes=num_classes)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
torch.manual_seed(random_seed)

def compute_accuracy(model: SoftmaxRegission, data_loader: DataLoader[datasets.MNIST]):
    correct_pred: torch.Tensor
    correct_pred, num_examples = torch.tensor(0), 0  # torch.Tensor(0) would create an empty tensor
    for features, targets in data_loader:
        features: torch.Tensor = features.view(-1, 28 * 28).to(device)
        targets = targets.to(device)
        logits, probas = model(features)
        _, predicted_labels = torch.max(probas, 1)
        num_examples += targets.size(0)
        correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float() / num_examples * 100

start_time = time.time()
epoch_costs: list[int] = []
for epoch in range(num_epochs):
    avg_cost = 0
    for batch_idx, (features, targets) in enumerate(train_loader):
        features = features.view(-1, 28 * 28).to(device)
        targets = targets.to(device)

        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        # backward() populates .grad on the model's linear weight and bias,
        # which optimizer.step() then uses to update them.
        cost.backward()
        avg_cost += cost
        optimizer.step()
Initially, tensors seemed like simple values, yet cost.backward() somehow updates the gradients of the SoftmaxRegission model's linear layer. I have read the PyTorch autograd documentation and am looking for deeper detail on the internal implementation: how does PyTorch achieve this? Coming from C++ and Rust, I wonder whether tensors hold pointers to the model parameters or register callbacks when they are created. I looked through the Python Tensor source but found no clear mechanism; it extends a C-based class, and I struggled to follow the CPython C code.
PyTorch’s backward() function leverages the autograd engine to propagate gradients from a scalar loss tensor—like your cross-entropy cost—back through a dynamic computation graph to model parameters such as the linear layer’s weights and biases. Every operation in the forward pass, from self.linear(x) producing logits to F.softmax and finally F.cross_entropy, attaches a grad_fn to output tensors, forming a chain that encodes the chain rule for differentiation. Calling cost.backward() kicks off a reverse traversal of this graph, computing partial derivatives step-by-step and accumulating them directly into model.linear.weight.grad and model.linear.bias.grad for the optimizer to use.
Contents
- Building the Computation Graph in PyTorch Autograd
- What Happens Inside backward()
- grad_fn: The Key to Gradient Tracking
- Gradient Flow to Linear Layer Parameters
- Tracing Propagation in Your Softmax Regression Example
- Customizing and Debugging Backward Passes
- Sources
- Conclusion
Building the Computation Graph in PyTorch Autograd
Ever wonder why your simple forward pass suddenly “knows” how to update weights? It starts here, in the forward pass. PyTorch’s autograd doesn’t just compute values—it simultaneously constructs a directed acyclic graph (DAG) representing every tensor operation.
Take your model: logits = self.linear(features). The Linear module computes features @ weight.T + bias, and under the hood this records AddmmBackward (matrix multiply plus add) as the grad_fn of logits. Then probas = F.softmax(logits, dim=1) attaches SoftmaxBackward, and cost = F.cross_entropy(logits, targets), which combines log-softmax and NLL loss, tags on NllLossBackward.
This graph gets recreated fresh every iteration, which is why loops or conditionals work seamlessly. As the PyTorch autograd tutorial explains, autograd traces dynamically at runtime: “When computing the forward pass, autograd simultaneously performs the requested computations and builds up a graph.” No static compilation needed. Tensors don’t hold pointers to parameters directly; instead, the graph links outputs back to inputs via Function objects.
And here’s the clever part—parameters like self.linear.weight are “leaf” tensors (requires_grad=True by default in modules). They have no grad_fn themselves but get gradients populated later.
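Here is a minimal, self-contained sketch (a toy layer, not your model) that shows the grad_fn chain autograd records during a forward pass; the exact printed class names (AddmmBackward0, NllLossBackward0, ...) can vary slightly across PyTorch versions:

import torch
import torch.nn.functional as F

linear = torch.nn.Linear(4, 3)                        # weight/bias are leaf tensors (requires_grad=True)
x = torch.randn(2, 4)                                 # plain input, no tracking needed
logits = linear(x)                                    # grad_fn: AddmmBackward0
probas = F.softmax(logits, dim=1)                     # grad_fn: SoftmaxBackward0
cost = F.cross_entropy(logits, torch.tensor([0, 2]))  # grad_fn: NllLossBackward0

print(linear.weight.grad_fn)   # None: parameters are graph leaves
print(logits.grad_fn)          # <AddmmBackward0 object at ...>
print(cost.grad_fn)            # <NllLossBackward0 object at ...>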
What Happens Inside backward()
You hit cost.backward(). Boom: propagation begins. Since cost is a scalar (F.cross_entropy reduces to the batch mean by default), the engine seeds it with a gradient of 1.0. PyTorch then calls the backward method on cost.grad_fn, which is typically NllLossBackward.
Each Function's backward computes a vector-Jacobian product (VJP): it multiplies the incoming gradient by the local derivative of its operation. It's the chain rule in action: for a composite L = f(g(x)), dL/dx = dL/dg * dg/dx.
From the official Tensor.backward docs: “The graph is differentiated using the chain rule… Computes the gradient of current tensor wrt graph leaves.” It traverses backward: NllLossBackward passes grads to LogSoftmaxBackward (cross_entropy uses log-softmax internally), which feeds AddmmBackward for the linear op. The separate probas = F.softmax(...) branch receives no gradient here, because cost never depends on it.
Non-scalar tensors? You can pass a gradient arg to backward(gradient=torch.ones_like(tensor)), but losses are usually scalar. Streams? CUDA ops respect them too, but that’s advanced.
Short version: no callbacks on tensor creation. The graph drives everything, evaluated topologically in reverse.
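To make the traversal concrete, here is a small sketch (toy shapes, not your MNIST model) of both the scalar case and the non-scalar case that needs an explicit seed gradient:

import torch
import torch.nn.functional as F

linear = torch.nn.Linear(4, 3)
x = torch.randn(2, 4)
cost = F.cross_entropy(linear(x), torch.tensor([0, 2]))  # scalar loss (mean reduction)

cost.backward()                      # seeds gradient 1.0 at cost, walks the graph in reverse
print(linear.weight.grad.shape)      # torch.Size([3, 4]), filled by the Addmm backward
print(linear.bias.grad.shape)        # torch.Size([3])

v = torch.randn(3, requires_grad=True)
out = v * 2                          # non-scalar output
out.backward(gradient=torch.ones_like(out))  # explicit seed required here
print(v.grad)                        # tensor([2., 2., 2.])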
grad_fn: The Key to Gradient Tracking
Peeking at cost.grad_fn? You'll see something like <NllLossBackward0 object at 0x...>. This isn't magic: it's a graph node (a Function subclass) holding the saved inputs/outputs needed for the backward computation. Every op (add, matmul, softmax) registers a twin backward function.
As this Stack Overflow breakdown notes: “grad_fn references the operation used to obtain the tensor.” It stores enough to reconstruct derivatives: saved tensors (like intermediate activations) avoid recompute.
In your code, model.linear.weight.grad_fn is None (it's a leaf), but logits.grad_fn points to AddmmBackward, which links back to weight through its saved tensors. During backward, AddmmBackward computes (with PyTorch's convention logits = features @ weight.T + bias):
- dL/dweight = dL/dlogits.T @ features
- dL/dbias = dL/dlogits.sum(dim=0)
Gradients accumulate in-place into .grad; with retain_graph=False (the default), the graph's saved buffers are freed as soon as the pass finishes.
Rust/C++ angle? PyTorch's core is C++ (ATen), and the Python torch.Tensor is a thin wrapper (THPVariable) around the C++ tensor. The graph lives in the C++ autograd engine (torch::autograd::Engine), which dispatches Node objects, one per backward op. Tensors hold no raw pointers to parameters; the association is purely through graph edges, and Python only sees the ergonomic surface.
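You can see this "graph-referenced, not pointer-from-tensor" design from Python: walking next_functions from logits leads to AccumulateGrad nodes, and it is those nodes, not the logits tensor, that hold the reference back to each parameter. A sketch, assuming a toy linear layer:

import torch

linear = torch.nn.Linear(4, 3)
logits = linear(torch.randn(2, 4))

addmm = logits.grad_fn                      # AddmmBackward0
for node, _ in addmm.next_functions:
    # AccumulateGrad nodes expose .variable, the leaf tensor they write into.
    if node is not None and hasattr(node, "variable"):
        print(node, node.variable is linear.bias or node.variable is linear.weight)
    else:
        print(node)                         # None (input needs no grad) or TBackward0 for weight.T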
Gradient Flow to Linear Layer Parameters
How does it hit self.linear.weight.grad specifically? The linear op's Function saves what its backward needs (the input and the weight). Backward fetches them and computes:
- Incoming dL/dlogits arrives from upstream (the loss and log-softmax derivatives).
- dL/dweight = dL/dlogits.T @ input → accumulated into weight.grad.
- dL/dbias = dL/dlogits.sum(dim=0) → bias.grad.
- dL/dinput = dL/dlogits @ weight → propagated on to earlier layers (see the sketch at the end of this section).
From autograd mechanics docs: “We evaluate this graph in the backwards pass to compute the gradients.” Leaves (params) get .grad filled; non-leaves too if create_graph=True.
Your optimizer.zero_grad() clears them first—essential, or they’d accumulate forever. Then step() uses them for SGD updates.
Pitfall: detaching the weight? Your detach().zero_() zeroes the parameter's storage without recording the in-place op in the graph; since it runs once at init time, the parameters remain ordinary leaves with requires_grad=True and keep receiving gradients.
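As a sanity check, you can reproduce those formulas by hand and compare against what autograd writes into .grad. A sketch with made-up shapes, using PyTorch's logits = x @ weight.T + bias convention, so the manual weight gradient is dL/dlogits.T @ x:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
linear = torch.nn.Linear(4, 3)
x = torch.randn(8, 4)
targets = torch.randint(0, 3, (8,))

logits = linear(x)
cost = F.cross_entropy(logits, targets)
cost.backward()

with torch.no_grad():
    # Mean cross-entropy over the batch: dL/dlogits = (softmax - one_hot) / N
    dlogits = (F.softmax(logits, dim=1) - F.one_hot(targets, 3).float()) / x.size(0)
    dweight = dlogits.T @ x            # shape [3, 4], same as linear.weight
    dbias = dlogits.sum(dim=0)         # shape [3]

print(torch.allclose(dweight, linear.weight.grad, atol=1e-6))  # True
print(torch.allclose(dbias, linear.bias.grad, atol=1e-6))      # True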
Tracing Propagation in Your Softmax Regression Example
Let’s walk your training loop. features → model.linear → logits (graph: AddmmBackward). F.cross_entropy(logits, targets) → cost (NllLossBackward + LogSoftmaxBackward).
cost.backward():
- Starts at cost.grad_fn with an incoming gradient of 1.0.
- NllLossBackward: writes -1/N into each target position (mean reduction over a batch of size N).
- LogSoftmaxBackward turns that into (softmax(logits) - one_hot(targets)) / N, the familiar softmax-cross-entropy gradient.
- Hits AddmmBackward: computes the weight/bias grads as above.
Print model.linear.weight.grad post-backward—you’ll see non-zero values proportional to errors. No direct tensor-to-param pointers; the saved_tensors in Functions reference them indirectly.
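A sketch of a single step, assuming the model and optimizer from your snippet are in scope (with num_features = 28 * 28 and num_classes = 10, everything on CPU) and a stand-in batch instead of a real DataLoader:

import torch
import torch.nn.functional as F

features = torch.randn(32, 28 * 28)       # fake MNIST-sized batch
targets = torch.randint(0, 10, (32,))

logits, probas = model(features)
cost = F.cross_entropy(logits, targets)

optimizer.zero_grad()
print(model.linear.weight.grad)              # None (or zeros) right after zero_grad()
cost.backward()                              # NllLossBackward -> LogSoftmaxBackward -> AddmmBackward
print(model.linear.weight.grad.abs().sum())  # non-zero: gradients accumulated into the leaf
optimizer.step()                             # SGD: weight -= lr * weight.grad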
Why is no C code visible from the Python side? The C++ TensorBase/TensorImpl holds the data and metadata, and the Python torch.Tensor is a thin binding over it. Graph evaluation is C++-heavy, per the PyTorch autograd engine blog.
Customizing and Debugging Backward Passes
Want control? Subclass torch.autograd.Function:
class CustomLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias):
        # Stash what backward will need (weight here is [in_features, out_features]).
        ctx.save_for_backward(input, weight, bias)
        return input @ weight + bias

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        # One gradient per forward input: d/dinput, d/dweight, d/dbias.
        return grad_output @ weight.T, input.T @ grad_output, grad_output.sum(0)
As this forum post shows, you wrap it in your module's forward: return CustomLinear.apply(x, self.weight, self.bias). (If you reuse nn.Linear's weight, note it is stored as [out_features, in_features], so pass its transpose to match the input @ weight + bias convention above.)
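A sketch that exercises the custom Function and checks its analytic backward against numerical gradients with torch.autograd.gradcheck; the weight here is laid out as [in_features, out_features] to match input @ weight + bias above:

import torch

inp = torch.randn(5, 4, dtype=torch.double, requires_grad=True)
w = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(3, dtype=torch.double, requires_grad=True)

out = CustomLinear.apply(inp, w, b)
out.sum().backward()
print(w.grad.shape)                                               # torch.Size([4, 3])

# gradcheck perturbs the double-precision inputs numerically and compares.
print(torch.autograd.gradcheck(CustomLinear.apply, (inp, w, b)))  # True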
Debug? Check tensor.requires_grad and that grad_fn is not None. An error like "element 0 of tensors does not require grad and does not have a grad_fn"? The graph got detached somewhere (e.g., a .detach() or a torch.no_grad() block). Hooks registered with register_hook() let you inspect gradients mid-flow.
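For instance, a hook on a non-leaf tensor fires during the backward pass, right when its gradient flows through. A small sketch with the toy layer again:

import torch
import torch.nn.functional as F

linear = torch.nn.Linear(4, 3)
x = torch.randn(2, 4)
logits = linear(x)

# Called during backward with dL/dlogits, before AddmmBackward consumes it.
logits.register_hook(lambda grad: print("dL/dlogits:", grad.shape))

cost = F.cross_entropy(logits, torch.tensor([0, 2]))
cost.backward()
print(linear.weight.grad_fn, linear.weight.requires_grad)  # None True (a trainable leaf)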
From Stack Overflow on backprop: Upstream scalar loss provides the seed gradient.
Sources
- torch.Tensor.backward — PyTorch 2.9 documentation
- Implementing backward function nn.Module - PyTorch Forums
- How does PyTorch module do the back prop - Stack Overflow
- How backward() works in PyTorch? - Stack Overflow
- Backward function in PyTorch - Stack Overflow
- A Gentle Introduction to torch.autograd — PyTorch Tutorials
- The Fundamentals of Autograd — PyTorch Tutorials
- Autograd mechanics — PyTorch 2.9 documentation
- Overview of PyTorch Autograd Engine – PyTorch
- In PyTorch, what exactly does the grad_fn attribute store? - Stack Overflow
Conclusion
PyTorch backward propagation shines through its dynamic autograd graph: no manual derivatives, just forward ops building the reverse path automatically. In your SoftmaxRegission model, cost.backward() chains from the loss to the linear parameters via grad_fns like AddmmBackward, filling .grad for SGD. Master this, and tweaking custom layers or debugging gradients becomes second nature. Experiment with torchviz to visualize graphs; it will click fast.