大模型学习笔记2:链式法则、矩阵与反向传播

从链式法则到反向传播,理解梯度计算为何能高效扩展

本文用于解决上篇文章的遗留问题

简单回顾一下

如果想要知道期望输出 $y$ 和模型的真实输出 $\hat{y}$ 差,就使用如下公式,这个叫做 Loss

$$ Loss = (y - \hat{y})^2 $$

如果想要知道某个参数应该往哪个方向变,变多少,即计算变化率(坡度),则使用求导的方式,得到

$$ \frac{\partial Loss}{\partial w} $$

如果想要知道哪个参数影响最大,就把所有参数的排排站,就可以知道谁的影响最大

$$ \frac{\partial Loss}{\partial w_1}, \frac{\partial Loss}{\partial w_2}, ..., \frac{\partial Loss}{\partial w_n} $$

这个排序用向量的形式理解,就叫做梯度

$$ \nabla Loss = \begin{bmatrix} \frac{\partial Loss}{\partial w_1} \\ \frac{\partial Loss}{\partial w_2} \\ \vdots \\ \frac{\partial Loss}{\partial w_n} \end{bmatrix} $$

但是为什么要是向量,为什么要用矩阵,什么是链式法则,如何数学证明梯度下降方向是最优解?

本章为读者娓娓道来

链式法则

直接上例子

举个极端的例子来对比

假设我们有一个 3 层的极简网络(实际上大模型有几百层):

  • 第一层:$h_1 = w_1 \cdot x$
  • 第二层:$h_2 = w_2 \cdot h_1$
  • 第三层:$y = w_3 \cdot h_2$
  • Loss:$L = (y - 1)^2$

我们的目标是求:Loss 对 $w_1$ 的导数。

手动展开(解析解法/符号求导)

我们可以先手动求导试试

如果我们手动展开,把所有中间变量替换掉,$L$ 的公式会变成:

$$ L = (w_3 \cdot (w_2 \cdot (w_1 \cdot x)) - 1)^2 $$

好,现在我们对这个极其冗长的公式直接对 $w_1$ 求导(利用高数里的基本求导公式)。 算出来的导数结果是:

$$ \frac{\partial L}{\partial w_1} = 2 \cdot (w_3 \cdot w_2 \cdot w_1 \cdot x - 1) \cdot (w_3 \cdot w_2 \cdot x) $$

此时可以注意到,求导结果的右边。在算导数时,计算机需要重新计算 $w_3 \cdot w_2 \cdot w_1 \cdot x$ !!!

但是实际上,$w_3 \cdot w_2 \cdot w_1 \cdot x$ 就是预测值 $\hat{y}$,在计算Loss的时候,该值已经存在了,这个情况下,其实只需要计算$w_3 \cdot w_2 \cdot x$ 就行了

$x$ 由于是整个知识库的输入端,所以是不断变化的,但是对于单轮训练来说,$w_3$ 和 $w_2$ 是确定的,此时就可以看做

$$ w_3 \cdot w_2 \cdot x = (w_3 \cdot w_2) \cdot x $$

所以如果要计算 $w_1$ 的导数,就只需要计算 $(w_3 \cdot w_2) \cdot x$ 就行了

同理:

如果要计算 $w_2$ 的导数,就只需要计算 $(w_3 \cdot w_1) \cdot x$ 就行了

如果要计算 $w_3$ 的导数,就只需要计算 $(w_2 \cdot w_1) \cdot x$ 就行了

这个计算过程,其实和链式法则是一体的,只不过链式法则告诉我们不需要展开。

链式法则的优雅推导

如果我们用链式法则重新审视上面的 3 层网络,求 $\frac{\partial L}{\partial w_1}$ 的过程可以拆解为:

$$ \frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial h_2} \cdot \frac{\partial h_2}{\partial h_1} \cdot \frac{\partial h_1}{\partial w_1} $$

将每一项单独拎出来看:

  • $\frac{\partial L}{\partial y} = 2(y - 1)$ (Loss 对预测值的导数)
  • $\frac{\partial y}{\partial h_2} = w_3$ (第三层对第二层输出的导数)
  • $\frac{\partial h_2}{\partial h_1} = w_2$ (第二层对第一层输出的导数)
  • $\frac{\partial h_1}{\partial w_1} = x$ (第一层对权重的导数)

把它们乘起来:

$$ \frac{\partial L}{\partial w_1} = 2(y - 1) \cdot w_3 \cdot w_2 \cdot x $$

你会惊奇地发现,这和我们手动展开算出来的结果完全一致

但是,链式法则提供了一个极其重要的工程视角:局部性(Locality)。 在计算 $w_1$ 的导数时,我们不需要知道整个网络的全局公式,只需要知道:

  1. 上一层传回来的“压力”(梯度)。
  2. 本层自己的输入和权重。

他证明了:只需要知道上一层的计算结果,本层的输入权重,即可完成本层的求导计算

矩阵

为什么是矩阵?其实工作量上来之后就能看到规律了

刚才的例子让我们了解到了一些基本条件,但是如果参数量再多一点会怎么?

假设有5个参数 $w$

$$ w_n = w_1 , w_2 , w_3 , w_4 , w_5 $$

假设整个方程是:

$$ \begin{aligned} h_1 &= w_1 \cdot x + b_1 \\ h_2 &= w_2 \cdot h_1 - b_2 \\ h_3 &= w_3 \cdot h_2 - b_3 \\ h_4 &= w_4 \cdot h_3 + b_4 \\ y &= w_5 \cdot h_4 + b_5 \end{aligned} $$

那么手动展开后,这个式子就变成了:

$$ y = w_5 \cdot (w_4 \cdot (w_3 \cdot (w_2 \cdot (w_1 \cdot x + b_1) - b_2) - b_3) + b_4) + b_5 $$

整理一下得到

$$ y = (w_5 w_4 w_3 w_2 w_1) \cdot x + (w_5 w_4 w_3 w_2) \cdot b_1 - (w_5 w_4 w_3) \cdot b_2 - (w_5 w_4) \cdot b_3 + w_5 \cdot b_4 + b_5 $$

$$ y = (w_5 w_4 w_3 w_2 w_1) \cdot x + C $$

对$w_1$求偏导的话得到

$$ \frac{\partial y}{\partial w_1} = w_5 w_4 w_3 w_2 x = (w_5 w_4 w_3 w_2) \cdot x $$

矩阵化的终极发现:从“暴力推导”中寻找规律

抛开所谓的中间变量 $h$,抛开抽象的链式法则。我们直接面对最原始、最庞大的展开式。

已知展开后的公式为:

$$ y = (w_5 w_4 w_3 w_2 w_1) \cdot x + (w_5 w_4 w_3 w_2) \cdot b_1 + (w_5 w_4 w_3) \cdot b_2 + (w_5 w_4) \cdot b_3 + w_5 \cdot b_4 + b_5 $$

暴力求偏导(纯体力劳动)

我们对每一个 $w$ 逐一求导,看看结果长什么样:

  • 对 $w_1$ 求导: $\frac{\partial y}{\partial w_1} = (w_5 w_4 w_3 w_2) \cdot x + C$
  • 对 $w_2$ 求导: $\frac{\partial y}{\partial w_2} = (w_5 w_4 w_3 \cdot w_1) \cdot x + C$
  • 对 $w_3$ 求导: $\frac{\partial y}{\partial w_3} = (w_5 w_4 \cdot w_2 w_1) \cdot x + C$
  • 对 $w_4$ 求导: $\frac{\partial y}{\partial w_4} = (w_5 \cdot w_3 w_2 w_1) \cdot x + C$
  • 对 $w_5$ 求导: $\frac{\partial y}{\partial w_5} = (w_4 w_3 w_2 w_1) \cdot x + C$

那如果要对 $w_1$ 进行调整,假设学习系数是 $\eta$,那么调整的计算表达式就是:

$$ w_1' = w_1 - \eta \cdot (w_5 w_4 w_3 w_2 \cdot x) $$

做一下移项和拓展,就可以得到:

$$ \begin{aligned} w_1 - w_1' &= \eta \cdot (w_5 w_4 w_3 w_2) \cdot x \\ w_2 - w_2' &= \eta \cdot (w_5 w_4 w_3 w_1) \cdot x \\ w_3 - w_3' &= \eta \cdot (w_5 w_4 w_2 w_1) \cdot x \\ w_4 - w_4' &= \eta \cdot (w_5 w_3 w_2 w_1) \cdot x \\ w_5 - w_5' &= \eta \cdot (w_4 w_3 w_2 w_1) \cdot x \end{aligned} $$

诶?是不是有一点点矩阵的味道了?

如果把 $w_n$的原始值写作向量 $\mathbf{w}$,变化后的结果 $w'$ 写作 $\mathbf{w}'$:

$$ \mathbf{w} = \begin{bmatrix} w_1 \\ w_2 \\ w_3 \\ w_4 \\ w_5 \end{bmatrix}, \quad \mathbf{w}' = \begin{bmatrix} w_1' \\ w_2' \\ w_3' \\ w_4' \\ w_5' \end{bmatrix} $$

那么易得(实际就是凑出来的)梯度向量 $\mathbf{G}$ 就是:

$$ \mathbf{G} = \begin{bmatrix} w_5 w_4 w_3 w_2 \\ w_5 w_4 w_3 w_1 \\ w_5 w_4 w_2 w_1 \\ w_5 w_3 w_2 w_1 \\ w_4 w_3 w_2 w_1 \end{bmatrix} $$

最终的权重更新公式(计算方式)就是:

$$ \mathbf{w} - \mathbf{w}' = \eta \cdot \mathbf{G} \cdot x $$

那什么是梯度呢?

如果严谨一些,此时$\mathbf{G}$只能算作输出对权重的偏导数部分,完整的梯度的定义就是

$$ \nabla L(\mathbf{w}) = \mathbf{G} \cdot x $$

然后更新参数的表达式就变成了

$$ \mathbf{w}' = \mathbf{w} - \eta \cdot \nabla L(\mathbf{w}) $$

这便是,最开始学到的梯度以及梯度参与的计算式子

矩阵和链式法则的交汇

你可能会问:“链式法则”又起到了什么作用?

提取公因式

逻辑就在你“凑”出来的这个 $\mathbf{G}$ 向量里:

$$ \mathbf{G} = \begin{bmatrix} \mathbf{w_5 w_4 w_3} \cdot w_2 \\ \mathbf{w_5 w_4 w_3} \cdot w_1 \\ \mathbf{w_5 w_4} \cdot w_2 w_1 \\ \mathbf{w_5} \cdot w_3 w_2 w_1 \\ w_4 w_3 w_2 w_1 \end{bmatrix} $$

1. 发现冗余

仔细看前两行:为了算 $w_1$ 和 $w_2$ 的梯度,你都重复计算了 $\mathbf{w_5 \cdot w_4 \cdot w_3}$。 在 5 层网络里这不算什么,但如果是 1000 层网络:

  • 算第 1 层的梯度要乘 999 次。
  • 算第 2 层的梯度又要乘 999 次。
  • 总计算量是 $O(N^2)$,这是灾难性的。

2. 提取公因子的逻辑

如果我们换个思路,定义一个中间变量 $\delta$(Delta):

  • $\delta_5 = 1$
  • $\delta_4 = \delta_5 \cdot w_5$
  • $\delta_3 = \delta_4 \cdot w_4 = w_5 w_4$
  • $\delta_2 = \delta_3 \cdot w_3 = w_5 w_4 w_3$

现在,你的 $\mathbf{G}$ 向量变成了:

$$ \mathbf{G} = \begin{bmatrix} \delta_2 \cdot w_2 \\ \delta_2 \cdot w_1 \\ \delta_3 \cdot w_2 w_1 \\ \delta_4 \cdot w_3 w_2 w_1 \\ \delta_5 \cdot w_4 w_3 w_2 w_1 \end{bmatrix} $$

反向传播

但是实际计算的顺序,因为要借助缓存,所以计算顺序是从下往上计算的;

按照 $\mathbf{G}$ 向量的例子,过程如下,先计算最下方的算式,这一步结束后获取了$\delta_5 = 1$ (初始值)

$$ w_5 = \delta_5 \cdot w_4 w_3 w_2 w_1 $$

同样的,第二次计算的式子如下,这一步结束后 获取了$\delta_4 = \delta_5 \cdot w_5 = w_5 $ (其中 $\delta_5$ 复用缓存结果)

$$ w_4 = \delta_4 \cdot w_3 w_2 w_1 $$

同样的,第三次计算的式子如下,这一步结束后 获取了$\delta_3 = \delta_4 \cdot w_4 = w_5 w_4 $ (其中 $\delta_4$ 复用缓存结果)

$$ w_3 = \delta_3 \cdot w_2 w_1 $$

同样的,第四次计算的式子如下,这一步结束后 获取了$\delta_2 = \delta_3 \cdot w_3 = w_5 w_4 w_3$ (其中 $\delta_3$ 复用缓存结果)

$$ w_2 = \delta_2 \cdot w_1 $$

最后,第五次计算的式子如下,这一步直接复用缓存结果 $\delta_2 = w_5 w_4 w_3$

$$ w_1 = \delta_2 \cdot w_2 $$

此刻,我们已经计算出了所有的 $\delta$,也就计算出了所有的 $\mathbf{G}$ 向量的元素

Why called Backpropagation - 反向传播

由于这个计算过程是反向的,所以叫做反向传播

既然链式法则只是简单的乘法,为什么 AI 领域要专门发明“反向传播”这个词?

这是因为在计算机实现中,顺序决定了效率

想象一下,如果你有 1000 层网络:

  • 从左往右算(前向模式求导):每计算一个参数的导数,都要从头往后走一遍。对于大模型的亿万级参数,这会算到天荒地老。
  • 从右往左算(反向模式求导):我们从 Loss 开始,把梯度像接力棒一样往回传。每一层收到的梯度,都是前面所有层累积的结果。

这就是反向传播:它是链式法则在神经网络上的高效实现。

这就是链式法则的本质

这种“先算出一部分乘积并存起来,供后面所有人复用”的逻辑,在数学上写出来就是:

$$ \frac{\partial y}{\partial w_1} = \underbrace{\frac{\partial y}{\partial h_4} \cdot \frac{\partial h_4}{\partial h_3} \cdot \frac{\partial h_3}{\partial h_2}}_{\text{提前算好的 } \delta} \cdot \frac{\partial h_2}{\partial h_1} \cdot \frac{\partial h_1}{\partial w_1} $$

“凑”出来的 $\mathbf{G}$ 是结果,而“链式法则”是达到这个结果的最快路径(即实操中类似缓存思想)

  • 链式法则:负责“提取公因子”(让大家少算重复逻辑)

双向缓存:极致的“中间算子化”

为了追求极致的计算速度,我们不仅缓存后向的 $\delta$,也可以缓存前向的中间积 $H$:

1. 定义前向缓存 H(从左往右积):

  • $H_0 = 1$
  • $H_1 = w_1 \cdot H_0 = w_1$
  • $H_2 = w_2 \cdot H_1 = w_2 w_1$
  • $H_3 = w_3 \cdot H_2 = w_3 w_2 w_1$
  • $H_4 = w_4 \cdot H_3 = w_4 w_3 w_2 w_1$

2. 最终的 G 矩阵(全中间算子化):

$$ \mathbf{G} = \begin{bmatrix} \delta_1 \cdot H_0 \\ \delta_2 \cdot H_1 \\ \delta_3 \cdot H_2 \\ \delta_4 \cdot H_3 \\ \delta_5 \cdot H_4 \end{bmatrix} $$

(注:$\delta_1 = \delta_2 \cdot w_2$,代表从输出一路传到第一层的完整误差链)

而且中间算子之间存在着紧密的迭代关系

  • 前向迭代:$H_i = w_i \cdot H_{i-1}$ (每一层的前向信号,都是“上一层的前向信号” $\times$ “当前层权重”)
  • 后向迭代:$\delta_i = \delta_{i+1} \cdot w_{i+1}$ (每一层的后向误差,都是“后一层的后向误差” $\times$ “后一层权重”)

这种“邻居传邻居”的递推逻辑,才是**链式法则(Chain Rule)**在工程上能够如此高效的根本原因。

这就是空间换时间的终极体现: 计算每一层的梯度时,我们不再进行任何重复的连乘,只需取出**“左边存好的 $H$”“右边传回的 $\delta$”**直接相乘。计算量从 $O(N^2)$ 瞬间降到了 $O(N)$。

小结

  • $H_i$:前向接力(Forward Cache)

时机:发生在从 $x$ 计算到 $y$ 的过程中(前向传播)。 目的:为了算出最终的预测值 $y$。 缓存价值:在算 $y$ 的路上,每一层算出的结果(比如 $H_2 = w_2 w_1$)都被悄悄存了下来。因为稍后的反向传播必须用到这些来计算权重对输出的影响。

  • $\delta_i$:后向接力(Backward Cache)

时机:发生在计算出 Loss 后,从后往前推的过程中(反向传播)。 目的:为了实现链式法则缓存价值:它存储了“误差信号”的累积(比如 $\delta_3 = w_5 w_4$)。每一层都直接复用后一层的累积结果,而不需要从最末尾重新连乘一遍。

再小结

直到现在,利用一些算数方法,获得了$\mathbf{G}$,就可以利用如下公式进行参数调整了

$$ \mathbf{w}' = \mathbf{w} - \eta \cdot \mathbf{G} \cdot x $$

新的参数 $\mathbf{w}'$ 又会开启新一轮的计算。通过不断重复“前向计算 $\to$ 计算损失 $\to$ 反向传播 $\to$ 参数更新”的过程,模型会像在山坡上行走一样,一步步挪向误差最小的谷底。

巩固理解

那到此,参数$w_n$合集,输入$x$,以及输出$y$在高维空间的意义是什么呢?

这要分两个视角,并且为了方便理解,暂时不考虑$y$为向量的情况

那么此时,我们可以从两个截然不同的视角来观察这个世界:

视角 A:前向推理(人在画中,x 是主角)

  • 世界观:你构造了一个由参数 $\mathbf{w}$ 定义的“数学世界” $f(\mathbf{x}; \mathbf{w})$。
  • 主角:当你在这个世界的“地面”上投入一个多维向量 $\mathbf{x}$,这个世界会反馈给你一个对应的“海拔高度”(输出 $y$)。
  • 目标:我们希望这个世界给出的高度 $y'$,能够完美对齐真理中的那个目标高度 $y$。
  • 参数 $\mathbf{w}$ 的意义:它们是这个世界的“造物主”。调整 $\mathbf{w}$,就是在重塑这个世界的山川地貌(函数曲面),直到能在大多数情况下,输出的$y'$距离和期望的$y$的高度差是最小的(包括已知的版图和未知的版图)

视角 B:反向训练(人在画外,w 是主角)

  • 世界观:这不再是预测的世界,而是“检讨错误”的世界。此时的数学函数变成了损失函数 $Loss = L(\mathbf{w}; \mathbf{x}, y)$。

  • 主角:权重向量 $\mathbf{w}$,它现在变成了需要不断调整位置的“旅行者”。

  • 场景:已知的真理数据 $(\mathbf{x}, y)$(常量)像全息投影一样,在 $\mathbf{w}$ 的多维空间里投射出了一片起伏不平的“误差地形”。

  • 目标:根据脚下地形的坡度(梯度 $\nabla L$),一步步往下走,寻找那个让“误差海拔(Loss)”最低的经纬度坐标。

  • 权重向量 $\mathbf{w}$:就是你的“经纬度”,决定了你在地图上的位置。

  • 输入数据 $(x, y)$:就是“全息投影仪”,它在权重空间里投射出了一片起伏不平的地形(损失函数曲面)。

  • 梯度 $\nabla L(\mathbf{w})$:就是你脚下那块地最陡峭的方向。它结合了模型的结构(系数 $\mathbf{G}$)和当前的输入信号 $x$。

  • 学习率 $\eta$:就是你的“步长”,决定了你这一步迈多大。

  • 前向缓存(Activations):是你在“吸气”时留下的案底,记录了 $x$ 走过的痕迹。

  • 后向传播(Backprop):是你在“呼气”时利用案底,即时算出每一步该如何修正。

结论: 只要我们朝着梯度 $\nabla L(\mathbf{w})$ 的反方向迈出一小步,就能确保我们在当前这一刻,是以最快的速度在“下山”。

梯度反方向是“最快”下降方向?

局部最快下降方向可以通过数学推导来看看,为什么偏偏要沿着梯度的反方向走?

1. 泰勒展开:微观世界的线性近似

假设我们的损失函数是 $L(\mathbf{w})$。当我们给权重一个极其微小的变化 $\Delta \mathbf{w}$ 时,新的损失值可以用一阶泰勒展开来近似:

$$ L(\mathbf{w} + \Delta \mathbf{w}) \approx L(\mathbf{w}) + \nabla L(\mathbf{w})^T \cdot \Delta \mathbf{w} $$

这里 $\nabla L(\mathbf{w})$ 是梯度向量,$\Delta \mathbf{w}$ 是我们的步进方向。

2. 寻找下降最快的方向

为了让损失函数减小得最快,我们需要让 $L(\mathbf{w} + \Delta \mathbf{w}) - L(\mathbf{w})$ 这个差值尽可能小(也就是负得越多越好)。

根据上面的公式,这个差值等于两个向量的点积:

$$ \text{Change} = \nabla L(\mathbf{w})^T \cdot \Delta \mathbf{w} = \|\nabla L(\mathbf{w})\| \cdot \|\Delta \mathbf{w}\| \cdot \cos \theta $$

其中 $\theta$ 是梯度向量和步进方向之间的夹角。

  • 当 $\cos \theta = 1$ 时($\theta = 0^\circ$),增加最快。
  • 当 $\cos \theta = -1$ 时($\theta = 180^\circ$),减少最快。

结论: 当步进方向 $\Delta \mathbf{w}$ 与梯度方向完全相反时,函数值下降最快。这就是为什么更新公式里要写成 $\mathbf{w} - \eta \nabla L$。

3. 局部最优 vs 全局最优

梯度下降就像是一个“近视眼”在下山:

  • 贪心性:它只看脚下那一寸土地哪里最陡,就往哪里走。
  • 局限性:如果山谷里有一个小坑(局部最优解),它一旦掉进去,周围都是向上的坡,它就会觉得已经到了最低点,从而停下来。
  • 凸函数:只有当整个山坡是一个完美的“碗”状(凸函数,Convex Function)时,梯度下降才能保证找到唯一的全局最低点。

而在深度学习中,地形往往极其复杂,布满了数以亿计的局部最优解和鞍点。这也是为什么我们需要 动量(Momentum)自适应学习率(Adam) 等技巧来帮助模型跳出这些“小坑”。

不过这都是些后话了

Dan❤Anan
Built with Hugo
主题 StackJimmy 设计