\( \newcommand{\pd}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\RR}{\mathbb{R}} \newcommand{\ZZ}{\mathbb{Z}} \newcommand{\eps}{\varepsilon} \)

In these notes we will explicitly derive the equations to use when backpropagating through a linear layer, using minibatches. Following a similar thought process can help you backpropagate through other types of computations involving matrices and tensors.

Forward Pass

During the forward pass, the linear layer takes an input $X$ of shape $N\times D$ and a weight matrix $W$ of shape $D\times M$, and computes their matrix product $Y=XW$, which has shape $N\times M$. To make things even more concrete, we will consider the case $N=2$, $D=2$, $M=3$.

We can then write out the forward pass in terms of the elements of the inputs:

\[ X = \begin{pmatrix} x_{1,1} & x_{1,2} \\ x_{2,1} & x_{2,2} \end{pmatrix} \hspace{2pc} W = \begin{pmatrix} w_{1,1} & w_{1,2} & w_{1,3} \\ w_{2,1} & w_{2,2} & w_{2,3} \end{pmatrix} \]

\[ Y = XW = \begin{pmatrix} x_{1,1}w_{1,1} + x_{1,2}w_{2,1} & x_{1,1}w_{1,2} + x_{1,2}w_{2,2} & x_{1,1}w_{1,3} + x_{1,2}w_{2,3} \\ x_{2,1}w_{1,1} + x_{2,2}w_{2,1} & x_{2,1}w_{1,2} + x_{2,2}w_{2,2} & x_{2,1}w_{1,3} + x_{2,2}w_{2,3} \end{pmatrix} \]
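As a quick check of the expansion above, here is a minimal pure-Python sketch of the forward pass. It assumes matrices are stored as lists of rows; the helper name `linear_forward` and the integer entries are illustrative choices, not part of these notes.

```python
def linear_forward(X, W):
    """Compute Y = XW for X of shape N x D and W of shape D x M (lists of rows)."""
    N, D, M = len(X), len(W), len(W[0])
    assert all(len(row) == D for row in X), "inner dimensions must match"
    # y[i][j] = sum_k x[i][k] * w[k][j], exactly the entries written out above
    return [[sum(X[i][k] * W[k][j] for k in range(D)) for j in range(M)]
            for i in range(N)]

# The N=2, D=2, M=3 case with arbitrary integer entries
X = [[1, 2],
     [3, 4]]
W = [[1, 0, -1],
     [2, 1, 0]]
Y = linear_forward(X, W)
print(Y)   # [[5, 2, -1], [11, 4, -3]]
```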

After the forward pass, we assume that the output will be used in other parts of the network, and will eventually be used to compute a scalar loss $L$.

Backward Pass

During the backward pass through the linear layer, we assume that the upstream gradient $\pd{L}{Y}$ has already been computed. For example, if the linear layer is part of a linear classifier, then the matrix $Y$ gives class scores; these scores are fed to a loss function (such as the softmax or multiclass SVM loss), which computes the scalar loss $L$ and the derivative $\pd{L}{Y}$ of the loss with respect to the scores.

Since $L$ is a scalar and $Y$ is a matrix of shape $N\times M$, the gradient $\pd{L}{Y}$ will be a matrix with the same shape as $Y$, where each element of $\pd{L}{Y}$ gives the derivative of the loss $L$ with respect to one element of $Y$:

\[ \pd{L}{Y} = \begin{pmatrix} \pd{L}{y_{1,1}} & \pd{L}{y_{1,2}} & \pd{L}{y_{1,3}} \\ \pd{L}{y_{2,1}} & \pd{L}{y_{2,2}} & \pd{L}{y_{2,3}} \end{pmatrix} \]

During the backward pass our goal is to use $\pd{L}{Y}$ in order to compute the downstream gradients $\pd{L}{X}$ and $\pd{L}{W}$. Again, since $L$ is a scalar we know that $\pd{L}{X}$ must have the same shape as $X$ ($N\times D$) and $\pd{L}{W}$ must have the same shape as $W$ ($D\times M$).

By the chain rule, we know that:

\[\pd{L}{X} = \pd{Y}{X}\pd{L}{Y} \hspace{4pc} \pd{L}{W} = \pd{Y}{W}\pd{L}{Y}\]

The terms $\pd{Y}{X}$ and $\pd{Y}{W}$ in this equation are Jacobians containing the partial derivatives of each element of $Y$ with respect to each element of the inputs $X$ and $W$. If we view each matrix as a flattened vector, these equations tell us that the downstream gradients can be computed as matrix-vector products between the Jacobians and the upstream gradient.
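For the tiny $N=2$, $D=2$, $M=3$ case it is still feasible to build $\pd{Y}{X}$ explicitly, which makes the matrix-vector product concrete. The sketch below is pure Python with hypothetical helper names and arbitrary values. It flattens $X$ and $Y$ in row-major order, so the Jacobian is a $4\times 6$ matrix whose entry for the pair $(x_{a,b}, y_{i,j})$ is $\pd{y_{i,j}}{x_{a,b}}$; since $y_{i,j}=\sum_k x_{i,k}w_{k,j}$, that entry is $w_{b,j}$ when $i=a$ and $0$ otherwise. Indices in the code are 0-based.

```python
def jacobian_Y_wrt_X(W, N):
    """Explicit Jacobian dY/dX for Y = XW, with X (N x D) and Y (N x M)
    flattened row-major; row index = (a, b) of x, column index = (i, j) of y."""
    D, M = len(W), len(W[0])
    J = []
    for a in range(N):
        for b in range(D):
            # dy[i][j]/dx[a][b] = w[b][j] if i == a else 0
            J.append([W[b][j] if i == a else 0
                      for i in range(N) for j in range(M)])
    return J

W = [[1, 0, -1],
     [2, 1, 0]]
dLdY = [[1, 2, 3],
        [4, 5, 6]]                       # an arbitrary upstream gradient
J = jacobian_Y_wrt_X(W, N=2)
g = [v for row in dLdY for v in row]     # flatten dL/dY row-major
# Jacobian-vector product: one entry of (flattened) dL/dX per row of J
dLdX_flat = [sum(J_row[k] * g[k] for k in range(len(g))) for J_row in J]
print(len(J), len(J[0]))   # 4 6
print(dLdX_flat)           # [-2, 4, -2, 13]
```

For these values the result matches $\pd{L}{Y}W^T$, the formula derived later in these notes; the rest of the derivation shows how to reach it without ever building `J`.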

However, we do not want to form the Jacobian matrices $\pd{Y}{X}$ and $\pd{Y}{W}$ explicitly, because they will be very large. In a typical neural network we might have $N=64$ and $M=D=4096$. Then $\pd{Y}{X}$ consists of $64\cdot4096\cdot64\cdot4096$ scalar values. This is more than 68 billion numbers: using 32-bit floating point numbers, this Jacobian matrix will take 256 GiB of memory to store. It is therefore hopeless to try to store and manipulate the Jacobian matrix explicitly.
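The arithmetic in the previous paragraph can be reproduced in a few lines of Python (the variable names are ours). The 256 figure is in binary gigabytes, i.e. $256\cdot 2^{30}=2^{38}$ bytes.

```python
N, D, M = 64, 4096, 4096
# dY/dX has one entry per (element of Y, element of X) pair
num_entries = (N * M) * (N * D)
bytes_fp32 = 4 * num_entries           # a 32-bit float takes 4 bytes
print(num_entries)                     # 68719476736
print(bytes_fp32 / 2**30, "GiB")       # 256.0 GiB
# For comparison: the number of entries in X and W combined
print(N * D + D * M)                   # 17039360
```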

Fortunately, it turns out that for most common neural network layers, we can derive expressions that compute the product $\pd{Y}{X}\pd{L}{Y}$ without explicitly forming the Jacobian $\pd{Y}{X}$. Even better, we can typically derive this expression without even computing an explicit expression for the Jacobian $\pd{Y}{X}$; in many cases we can work out a small case on paper and then infer the general formula.

Let’s see how this works out for our specific case of $N=2$, $D=2$, $M=3$. We first tackle $\pd{L}{X}$. Again, we know that $\pd{L}{X}$ must have the same shape as $X$:

\begin{equation} \label{eq:dLdX} X = \begin{pmatrix} x_{1,1} & x_{1,2} \\ x_{2,1} & x_{2,2} \end{pmatrix} \implies \pd{L}{X} = \begin{pmatrix} \pd{L}{x_{1,1}} & \pd{L}{x_{1,2}} \\ \pd{L}{x_{2,1}} & \pd{L}{x_{2,2}} \end{pmatrix} \end{equation}

Computing $\pd{L}{x _ {1,1}}$

We can proceed one element at a time. First we will compute $\pd{L}{x _ {1,1}}$. By the chain rule, we know that

\begin{equation} \label{eq:dLdX11-chain} \pd{L}{x _ {1,1}} = \sum _ {i=1}^N\sum _ {j=1}^M \pd{L}{y _ {i,j}}\pd{y _ {i,j}}{x _ {1,1}} = \pd{L}{Y} \cdot \pd{Y}{x _ {1,1}} \end{equation}

In the above equation $L$ and $x _ {1,1}$ are scalars so $\pd{L}{x _ {1,1}}$ is also a scalar. If we view $Y$ not as a matrix but as a collection of intermediate scalar variables, then we can use the chain rule to write $\pd{L}{x _ {1,1}}$ solely in terms of scalar derivatives.

To avoid working with sums, it is convenient to collect all terms $\pd{L}{y _ {i,j}}$ into a single matrix $\pd{L}{Y}$; here $L$ is a scalar and $Y$ is a matrix, so $\pd{L}{Y}$ has the same shape as $Y$ $(N\times M)$, where each element of $\pd{L}{Y}$ gives the derivative of $L$ with respect to one element of $Y$. We similarly collect all terms $\pd{y _ {i,j}}{x _ {1,1}}$ into a single matrix $\pd{Y}{x _ {1,1}}$; since $Y$ is a matrix and $x _ {1,1}$ is a scalar, $\pd{Y}{x _ {1,1}}$ is a matrix with the same shape as $Y$ ($N\times M$).

Since $\pd{L}{x _ {1,1}}$ is a scalar, we know that the product of $\pd{L}{Y}$ and $\pd{Y}{x _ {1,1}}$ must be a scalar; by inspecting the expression using only scalar derivatives, it is clear that in this context the product of $\pd{L}{Y}$ and $\pd{Y}{x _ {1,1}}$ must be a dot product.

In the backward pass we are already given $\pd{L}{Y}$, so we only need to compute $\pd{Y}{x _ {1,1}}$.

Recall that $Y$ is defined as:

\[ Y = XW = \begin{pmatrix} x_{1,1}w_{1,1} + x_{1,2}w_{2,1} & x_{1,1}w_{1,2} + x_{1,2}w_{2,2} & x_{1,1}w_{1,3} + x_{1,2}w_{2,3} \\ x_{2,1}w_{1,1} + x_{2,2}w_{2,1} & x_{2,1}w_{1,2} + x_{2,2}w_{2,2} & x_{2,1}w_{1,3} + x_{2,2}w_{2,3} \end{pmatrix} \]

Therefore we can easily compute $\pd{Y}{x _ {1,1}}$:

\begin{equation} \label{eq:dYdX11} \pd{Y}{x _ {1,1}} = \begin{pmatrix} w_{1,1} & w_{1,2} & w_{1,3} \\ 0 & 0 & 0 \end{pmatrix} \end{equation}

Combining Equations \ref{eq:dLdX11-chain} and \ref{eq:dYdX11} now gives:

\begin{align} \pd{L}{x_{1,1}} &= \pd{L}{Y}\cdot\pd{Y}{x_{1,1}} \\ &= \begin{pmatrix} \pd{L}{y_{1,1}} & \pd{L}{y_{1,2}} & \pd{L}{y_{1,3}} \\ \pd{L}{y_{2,1}} & \pd{L}{y_{2,2}} & \pd{L}{y_{2,3}} \end{pmatrix}\cdot \begin{pmatrix} w_{1,1} & w_{1,2} & w_{1,3} \\ 0 & 0 & 0 \end{pmatrix} \\ &= \pd{L}{y_{1,1}}w_{1,1} + \pd{L}{y_{1,2}}w_{1,2} + \pd{L}{y_{1,3}}w_{1,3} \label{eq:dLdX11} \end{align}
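The "dot product" in Equation \ref{eq:dLdX11-chain} is the elementwise (Frobenius) inner product: multiply matching entries of two $N\times M$ matrices and sum. Below is a minimal pure-Python sketch for $\pd{L}{x_{1,1}}$, assuming an arbitrary upstream gradient and a hypothetical helper `frobenius_dot`. Code indices are 0-based, so `W[0]` is the first row $(w_{1,1}, w_{1,2}, w_{1,3})$.

```python
def frobenius_dot(A, B):
    """Sum of elementwise products of two same-shape matrices (lists of rows)."""
    return sum(a * b for row_a, row_b in zip(A, B) for a, b in zip(row_a, row_b))

W = [[1, 0, -1],
     [2, 1, 0]]
dLdY = [[1, 2, 3],
        [4, 5, 6]]            # arbitrary upstream gradient
# dY/dx_{1,1}: the first row of W in the first row, zeros elsewhere
dYdx11 = [W[0][:], [0, 0, 0]]
dLdx11 = frobenius_dot(dLdY, dYdx11)
# Same value written out term by term, as in the last line of the derivation
assert dLdx11 == dLdY[0][0] * W[0][0] + dLdY[0][1] * W[0][1] + dLdY[0][2] * W[0][2]
print(dLdx11)   # -2
```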

Computing $\pd{L}{X}$

We can now repeat the same computation to derive expressions for the other entries of $\pd{L}{X}$.

For $\pd{L}{x _ {1,2}}$ we compute:

\begin{align} \pd{L}{x_{1,2}} &= \pd{L}{Y}\cdot\pd{Y}{x_{1,2}} \\ &= \begin{pmatrix} \pd{L}{y_{1,1}} & \pd{L}{y_{1,2}} & \pd{L}{y_{1,3}} \\ \pd{L}{y_{2,1}} & \pd{L}{y_{2,2}} & \pd{L}{y_{2,3}} \end{pmatrix}\cdot \begin{pmatrix} w_{2,1} & w_{2,2} & w_{2,3} \\ 0 & 0 & 0 \end{pmatrix} \\ &= \pd{L}{y_{1,1}}w_{2,1} + \pd{L}{y_{1,2}}w_{2,2} + \pd{L}{y_{1,3}}w_{2,3} \label{eq:dLdX12} \end{align}

For $\pd{L}{x _ {2,1}}$ we compute:

\begin{align} \pd{L}{x_{2,1}} &= \pd{L}{Y}\cdot\pd{Y}{x_{2,1}} \\ &= \begin{pmatrix} \pd{L}{y_{1,1}} & \pd{L}{y_{1,2}} & \pd{L}{y_{1,3}} \\ \pd{L}{y_{2,1}} & \pd{L}{y_{2,2}} & \pd{L}{y_{2,3}} \end{pmatrix}\cdot \begin{pmatrix} 0 & 0 & 0 \\ w_{1,1} & w_{1,2} & w_{1,3} \end{pmatrix} \\ &= \pd{L}{y_{2,1}}w_{1,1} + \pd{L}{y_{2,2}}w_{1,2} + \pd{L}{y_{2,3}}w_{1,3} \label{eq:dLdX21} \end{align}

For $\pd{L}{x _ {2, 2}}$ we compute:

\begin{align} \pd{L}{x_{2,2}} &= \pd{L}{Y}\cdot\pd{Y}{x_{2,2}} \\ &= \begin{pmatrix} \pd{L}{y_{1,1}} & \pd{L}{y_{1,2}} & \pd{L}{y_{1,3}} \\ \pd{L}{y_{2,1}} & \pd{L}{y_{2,2}} & \pd{L}{y_{2,3}} \end{pmatrix}\cdot \begin{pmatrix} 0 & 0 & 0 \\ w_{2,1} & w_{2,2} & w_{2,3} \end{pmatrix} \\ &= \pd{L}{y_{2,1}}w_{2,1} + \pd{L}{y_{2,2}}w_{2,2} + \pd{L}{y_{2,3}}w_{2,3} \label{eq:dLdX22} \end{align}

We can now combine Equations \ref{eq:dLdX11}, \ref{eq:dLdX12}, \ref{eq:dLdX21}, and \ref{eq:dLdX22} to give a single expression for $\pd{L}{X}$ in terms of $W$ and $\pd{L}{Y}$:

\begin{align} \pd{L}{X} &= \begin{pmatrix} \pd{L}{x_{1,1}} & \pd{L}{x_{1,2}} \\\pd{L}{x_{2,1}} & \pd{L}{x_{2,2}} \end{pmatrix} \\ &= \begin{pmatrix} \pd{L}{y_{1,1}} w_{1,1}+ \pd{L}{y_{1,2}} w_{1,2}+ \pd{L}{y_{1,3}} w_{1,3}& \pd{L}{y_{1,1}} w_{2,1}+ \pd{L}{y_{1,2}} w_{2,2}+ \pd{L}{y_{1,3}} w_{2,3} \\\pd{L}{y_{2,1}} w_{1,1}+ \pd{L}{y_{2,2}} w_{1,2}+ \pd{L}{y_{2,3}} w_{1,3}& \pd{L}{y_{2,1}} w_{2,1}+ \pd{L}{y_{2,2}} w_{2,2}+ \pd{L}{y_{2,3}} w_{2,3} \end{pmatrix} \\ &= \begin{pmatrix} \pd{L}{y_{1,1}} & \pd{L}{y_{1,2}} & \pd{L}{y_{1,3}} \\\pd{L}{y_{2,1}} & \pd{L}{y_{2,2}} & \pd{L}{y_{2,3}} \end{pmatrix} \begin{pmatrix} w_{1,1} & w_{2,1} \\w_{1,2} & w_{2,2} \\w_{1,3} & w_{2,3} \end{pmatrix} \\ &= \boxed{\pd{L}{Y}W^T} \label{eq:dLdX-backprop} \end{align}

In Equation \ref{eq:dLdX-backprop}, recall that $\pd{L}{Y}$ is a matrix of shape $N\times M$ and $W$ is a matrix of shape $D\times M$, so $W^T$ has shape $M\times D$; thus $\pd{L}{X}=\pd{L}{Y}W^T$ has shape $N\times D$, which is the same shape as $X$.

The expression $\pd{L}{X}=\pd{L}{Y}W^T$ allows us to backpropagate without explicitly forming the Jacobian $\pd{Y}{X}$!
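A pure-Python sketch of this backward rule, assuming lists-of-rows matrices and hypothetical helpers `matmul`, `transpose`, and `linear_backward_X`. It compares the matrix form $\pd{L}{Y}W^T$ against the elementwise formulas $\pd{L}{x_{a,b}}=\sum_j \pd{L}{y_{a,j}}w_{b,j}$ from Equations \ref{eq:dLdX11} through \ref{eq:dLdX22}, using arbitrary integer values.

```python
def matmul(A, B):
    """Matrix product of lists-of-rows matrices A (n x k) and B (k x m)."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]

def transpose(A):
    return [list(col) for col in zip(*A)]

def linear_backward_X(dLdY, W):
    """dL/dX = dL/dY @ W^T: (N x M) @ (M x D) -> N x D."""
    return matmul(dLdY, transpose(W))

W = [[1, 0, -1],
     [2, 1, 0]]
dLdY = [[1, 2, 3],
        [4, 5, 6]]
dLdX = linear_backward_X(dLdY, W)
# Each entry (a, b) should match dL/dx_{a,b} = sum_j dL/dy_{a,j} * w_{b,j}
for a in range(2):
    for b in range(2):
        assert dLdX[a][b] == sum(dLdY[a][j] * W[b][j] for j in range(3))
print(dLdX)   # [[-2, 4], [-2, 13]]
```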

Computing $\pd{L}{W}$

We can follow a similar strategy to derive an expression for $\pd{L}{W}$ without forming the Jacobian $\pd{Y}{W}$.

Recall that $X$ has shape $N\times D$, $W$ has shape $D\times M$, and $Y=XW$ has shape $N\times M$.

The gradient $\pd{L}{W}$ gives the effect of each element of $W$ on the scalar loss $L$, so it should be a matrix of partial derivatives with the same shape as $W$:

\[ W = \begin{pmatrix} w _ {1,1} & w _ {1,2} & w _ {1,3} \\ w _ {2,1} & w _ {2,2} & w _ {2,3} \end{pmatrix} \implies \pd{L}{W} = \begin{pmatrix} \pd{L}{w _ {1,1}} & \pd{L}{w _ {1,2}} & \pd{L}{w _ {1,3}} \\ \pd{L}{w _ {2,1}} & \pd{L}{w _ {2,2}} & \pd{L}{w _ {2,3}} \end{pmatrix} \]

Similar to $\pd{L}{X}$, we can proceed one element at a time. We need to compute $\pd{L}{w _ {i,j}}$. We can use the chain rule to derive an expression similar to Equation \ref{eq:dLdX11-chain}:

\begin{equation} \pd{L}{w _ {i,j}} = \sum _ {i'=1}^N\sum _ {j'=1}^M\pd{L}{y _ {i',j'}}\pd{y _ {i',j'}}{w _ {i,j}} = \pd{L}{Y} \cdot \pd{Y}{w _ {i,j}} \label{eq:dLdW-chain} \end{equation}

As before, the term $\pd{L}{Y}$ is the upstream gradient of the loss with respect to the layer outputs, which we assume has already been computed. To compute the $\pd{Y}{w _ {i,j}}$ terms we recall that

\[ Y = XW = \begin{pmatrix} x_{1,1}w_{1,1} + x_{1,2}w_{2,1} & x_{1,1}w_{1,2} + x_{1,2}w_{2,2} & x_{1,1}w_{1,3} + x_{1,2}w_{2,3} \\ x_{2,1}w_{1,1} + x_{2,2}w_{2,1} & x_{2,1}w_{1,2} + x_{2,2}w_{2,2} & x_{2,1}w_{1,3} + x_{2,2}w_{2,3} \end{pmatrix} \]

We can then compute the required partial derivatives:

\[ \pd{Y}{w _ {1,1}} = \begin{pmatrix} x _ {1,1} & 0 & 0 \\ x _ {2, 1} & 0 & 0 \end{pmatrix} \hspace{4pc} \pd{Y}{w _ {1,2}} = \begin{pmatrix} 0 & x _ {1,1} & 0 \\ 0 & x _ {2,1} & 0 \end{pmatrix} \hspace{4pc} \pd{Y}{w _ {1,3}} = \begin{pmatrix} 0 & 0 & x _ {1,1} \\ 0 & 0 & x _ {2,1} \end{pmatrix} \]

\[ \pd{Y}{w _ {2,1}} = \begin{pmatrix} x _ {1,2} & 0 & 0 \\ x _ {2,2} & 0 & 0 \end{pmatrix} \hspace{4pc} \pd{Y}{w _ {2,2}} = \begin{pmatrix} 0 & x _ {1,2} & 0 \\ 0 & x _ {2,2} & 0 \end{pmatrix} \hspace{4pc} \pd{Y}{w _ {2,3}} = \begin{pmatrix} 0 & 0 & x _ {1,2} \\ 0 & 0 & x _ {2,2} \end{pmatrix} \]
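All six matrices follow a single pattern: $w_{i,j}$ appears only in column $j$ of $Y$, where it multiplies column $i$ of $X$, so $\pd{Y}{w_{i,j}}$ is zero except for column $j$, which equals column $i$ of $X$. A pure-Python sketch of that pattern, with a hypothetical helper `dY_dw` that uses 0-based indices and arbitrary values for $X$:

```python
def dY_dw(X, i, j, M):
    """Partial derivative of Y = XW with respect to the single weight w[i][j].

    Y[r][c] = sum_k X[r][k] * W[k][c], so dY[r][c]/dw[i][j] = X[r][i] if c == j else 0.
    """
    return [[X[r][i] if c == j else 0 for c in range(M)] for r in range(len(X))]

X = [[1, 2],
     [3, 4]]
# 0-based (0, 0) is w_{1,1} in the text: column 0 of X placed in column 0
print(dY_dw(X, 0, 0, M=3))   # [[1, 0, 0], [3, 0, 0]]
# 0-based (1, 2) is w_{2,3}: column 1 of X placed in column 2
print(dY_dw(X, 1, 2, M=3))   # [[0, 0, 2], [0, 0, 4]]
```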

We can now apply Equation \ref{eq:dLdW-chain} to each of these six terms to derive expressions for each element of $\pd{L}{W}$:

\[ \pd{L}{w _ {1,1}} = \pd{L}{Y}\cdot\pd{Y}{w _ {1,1}} = \begin{pmatrix} \pd{L}{y _ {1,1}} & \pd{L}{y _ {1,2}} & \pd{L}{y _ {1,3}} \\ \pd{L}{y _ {2,1}} & \pd{L}{y _ {2,2}} & \pd{L}{y _ {2,3}} \end{pmatrix} \cdot \begin{pmatrix} x _ {1,1} & 0 & 0 \\ x _ {2,1} & 0 & 0 \end{pmatrix} = \pd{L}{y _ {1,1}}x _ {1,1} + \pd{L}{y _ {2,1}}x _ {2,1} \]

\[ \pd{L}{w _ {1,2}} = \pd{L}{Y}\cdot\pd{Y}{w _ {1,2}} = \begin{pmatrix} \pd{L}{y _ {1,1}} & \pd{L}{y _ {1,2}} & \pd{L}{y _ {1,3}} \\ \pd{L}{y _ {2,1}} & \pd{L}{y _ {2,2}} & \pd{L}{y _ {2,3}} \end{pmatrix} \cdot \begin{pmatrix} 0 & x _ {1,1} & 0 \\ 0 & x _ {2,1} & 0 \end{pmatrix} = \pd{L}{y _ {1,2}}x _ {1,1} + \pd{L}{y _ {2,2}}x _ {2,1} \]

\[ \pd{L}{w _ {1,3}} = \pd{L}{Y}\cdot\pd{Y}{w _ {1,3}} = \begin{pmatrix} \pd{L}{y _ {1,1}} & \pd{L}{y _ {1,2}} & \pd{L}{y _ {1,3}} \\ \pd{L}{y _ {2,1}} & \pd{L}{y _ {2,2}} & \pd{L}{y _ {2,3}} \end{pmatrix} \cdot \begin{pmatrix} 0 & 0 & x _ {1,1} \\ 0 & 0 & x _ {2,1} \end{pmatrix} = \pd{L}{y _ {1,3}}x _ {1,1} + \pd{L}{y _ {2,3}}x _ {2,1} \]

\[ \pd{L}{w _ {2,1}} = \pd{L}{Y}\cdot\pd{Y}{w _ {2,1}} = \begin{pmatrix} \pd{L}{y _ {1,1}} & \pd{L}{y _ {1,2}} & \pd{L}{y _ {1,3}} \\ \pd{L}{y _ {2,1}} & \pd{L}{y _ {2,2}} & \pd{L}{y _ {2,3}} \end{pmatrix} \cdot \begin{pmatrix} x _ {1,2} & 0 & 0 \\ x _ {2,2} & 0 & 0 \end{pmatrix} = \pd{L}{y _ {1,1}}x _ {1,2} + \pd{L}{y _ {2,1}}x _ {2,2} \]

\[ \pd{L}{w _ {2,2}} = \pd{L}{Y}\cdot\pd{Y}{w _ {2,2}} = \begin{pmatrix} \pd{L}{y _ {1,1}} & \pd{L}{y _ {1,2}} & \pd{L}{y _ {1,3}} \\ \pd{L}{y _ {2,1}} & \pd{L}{y _ {2,2}} & \pd{L}{y _ {2,3}} \end{pmatrix} \cdot \begin{pmatrix} 0 & x _ {1,2} & 0 \\ 0 & x _ {2,2} & 0 \end{pmatrix} = \pd{L}{y _ {1,2}}x _ {1,2} + \pd{L}{y _ {2,2}}x _ {2,2} \]

\[ \pd{L}{w _ {2,3}} = \pd{L}{Y}\cdot\pd{Y}{w _ {2,3}} = \begin{pmatrix} \pd{L}{y _ {1,1}} & \pd{L}{y _ {1,2}} & \pd{L}{y _ {1,3}} \\ \pd{L}{y _ {2,1}} & \pd{L}{y _ {2,2}} & \pd{L}{y _ {2,3}} \end{pmatrix} \cdot \begin{pmatrix} 0 & 0 & x _ {1,2} \\ 0 & 0 & x _ {2,2} \end{pmatrix} = \pd{L}{y _ {1,3}}x _ {1,2} + \pd{L}{y _ {2,3}}x _ {2,2} \]

We can combine these six results into a single expression for $\pd{L}{W}$:

\begin{align} \pd{L}{W} &= \begin{pmatrix} \pd{L}{w _ {1,1}} & \pd{L}{w _ {1,2}} & \pd{L}{w _ {1,3}} \\ \pd{L}{w _ {2,1}} & \pd{L}{w _ {2,2}} & \pd{L}{w _ {2,3}} \end{pmatrix} \\ &= \begin{pmatrix} \pd{L}{y _ {1,1}}x _ {1,1} + \pd{L}{y _ {2,1}}x _ {2,1} & \pd{L}{y _ {1,2}}x _ {1,1} + \pd{L}{y _ {2,2}}x _ {2,1} & \pd{L}{y _ {1,3}}x _ {1,1} + \pd{L}{y _ {2,3}}x _ {2,1} \\ \pd{L}{y _ {1,1}}x _ {1,2} + \pd{L}{y _ {2,1}}x _ {2,2} & \pd{L}{y _ {1,2}}x _ {1,2} + \pd{L}{y _ {2,2}}x _ {2,2} & \pd{L}{y _ {1,3}}x _ {1,2} + \pd{L}{y _ {2,3}}x _ {2,2} \end{pmatrix} \\ &= \begin{pmatrix} x _ {1,1} & x _ {2,1} \\ x _ {1,2} & x _ {2,2} \end{pmatrix} \begin{pmatrix} \pd{L}{y _ {1,1}} & \pd{L}{y _ {1,2}} & \pd{L}{y _ {1,3}} \\ \pd{L}{y _ {2,1}} & \pd{L}{y _ {2,2}} & \pd{L}{y _ {2,3}} \end{pmatrix} \\ &= \boxed{X^T\pd{L}{Y}} \end{align}

This allows us to compute the downstream gradient $\pd{L}{W}$ as the matrix-matrix product of $X^T$ and the upstream gradient $\pd{L}{Y}$. We never need to explicitly form the Jacobian $\pd{Y}{W}$!
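A matching pure-Python sketch for $\pd{L}{W}=X^T\pd{L}{Y}$, again with hypothetical helper names and arbitrary integer values. It checks each entry of the matrix form against the elementwise results $\pd{L}{w_{i,j}}=\sum_n \pd{L}{y_{n,j}}x_{n,i}$ derived above (0-based indices in the code).

```python
def matmul(A, B):
    """Matrix product of lists-of-rows matrices A (n x k) and B (k x m)."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]

def transpose(A):
    return [list(col) for col in zip(*A)]

def linear_backward_W(X, dLdY):
    """dL/dW = X^T @ dL/dY: (D x N) @ (N x M) -> D x M."""
    return matmul(transpose(X), dLdY)

X = [[1, 2],
     [3, 4]]
dLdY = [[1, 2, 3],
        [4, 5, 6]]
dLdW = linear_backward_W(X, dLdY)
# Each entry (i, j) should match dL/dw_{i,j} = sum_n dL/dy_{n,j} * x_{n,i}
for i in range(2):
    for j in range(3):
        assert dLdW[i][j] == sum(dLdY[n][j] * X[n][i] for n in range(2))
print(dLdW)   # [[13, 17, 21], [18, 24, 30]]
```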

Summary

In summary, we have derived the backpropagation expressions for the matrix-matrix product $Y=XW$:

\begin{equation} \boxed{\pd{L}{X} = \pd{L}{Y}W^T} \hspace{6pc} \boxed{\pd{L}{W} = X^T\pd{L}{Y}} \label{eq:final} \end{equation}

Our derivation covers only the specific case where $X$ has shape $2\times2$ and $W$ has shape $2\times3$, but the expressions in Equation \ref{eq:final} in fact hold in general for any $N$, $D$, and $M$, that is, for matrices of any compatible shapes.
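One way to spot-check the general claim is a numerical gradient check. The sketch below is pure Python (standard library only); the shapes, random seed, step size, and helper names are arbitrary assumptions. It uses a hypothetical loss $L=\sum_{i,j} g_{i,j}\,y_{i,j}$ for a fixed random matrix $G$, chosen because its upstream gradient is exactly $\pd{L}{Y}=G$, and compares both formulas against central finite differences.

```python
import random

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]

def transpose(A):
    return [list(col) for col in zip(*A)]

def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

def loss(X, W, G):
    """L = sum_ij G[i][j] * Y[i][j] with Y = XW, so dL/dY = G exactly."""
    Y = matmul(X, W)
    return sum(g * y for g_row, y_row in zip(G, Y) for g, y in zip(g_row, y_row))

def numerical_grad(f, A, h=1e-6):
    """Central-difference gradient of scalar f() w.r.t. each entry of A.
    A is perturbed in place and every entry is restored afterwards."""
    grad = [[0.0] * len(A[0]) for _ in A]
    for r in range(len(A)):
        for c in range(len(A[0])):
            old = A[r][c]
            A[r][c] = old + h
            f_plus = f()
            A[r][c] = old - h
            f_minus = f()
            A[r][c] = old
            grad[r][c] = (f_plus - f_minus) / (2 * h)
    return grad

def max_abs_diff(A, B):
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))

rng = random.Random(0)
N, D, M = 4, 3, 5                               # arbitrary shapes
X, W, G = rand_matrix(N, D, rng), rand_matrix(D, M, rng), rand_matrix(N, M, rng)

dX = matmul(G, transpose(W))                    # dL/dX = dL/dY W^T
dW = matmul(transpose(X), G)                    # dL/dW = X^T dL/dY
err_X = max_abs_diff(dX, numerical_grad(lambda: loss(X, W, G), X))
err_W = max_abs_diff(dW, numerical_grad(lambda: loss(X, W, G), W))
print(err_X, err_W)   # both tiny, on the order of floating-point round-off
```

Because this particular $L$ is linear in $X$ (and in $W$), central differences are exact up to rounding, so any mismatch well above round-off would point to an error in the formulas rather than in the check.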