
Automatic Differentiation

Differentiating by hand is tedious and error-prone, symbolic differentiation is hard to apply to functions expressed as code, and numerical differentiation suffers from large errors. The mainstream approach today is automatic differentiation, in its Forward and Backward propagation forms.

Dual number

A dual number can be written as \(a+b \varepsilon\), where \(\varepsilon\) is an infinitesimal satisfying \(\varepsilon^2=0\) and \(\varepsilon \ne 0\).

Some of its arithmetic rules:

  • \(\lambda (a+b \varepsilon)=\lambda a + \lambda b \varepsilon\)
  • \((a+b\varepsilon)+(c+d\varepsilon)=(a+c)+(b+d)\varepsilon\)
  • \((a+b\varepsilon)(c+d\varepsilon)=ac+(ad+bc)\varepsilon\)

Division can be simplified by multiplying numerator and denominator by the conjugate of the denominator:

\[ \begin{align} \frac{a+b\varepsilon}{c+d\varepsilon}&=\frac{(a+b\varepsilon)(c-d\varepsilon)}{(c+d\varepsilon)(c-d\varepsilon)}\\ &=\frac{ac+(bc-ad)\varepsilon}{c^2}\\ &=\frac{a}{c}+\frac{bc-ad}{c^2}\varepsilon \end{align} \]

Taylor expanding \(f(a+b\varepsilon)\) at \(a\) gives

\[ f(a+b\varepsilon)=\sum_{n=0}^\infty \frac{f^{(n)}(a)}{n!}b^n\varepsilon^n=f(a)+bf'(a)\varepsilon \]

Every term with \(n \ge 2\) vanishes because \(\varepsilon^2=0\), so evaluating \(f(a+b\varepsilon)\) yields \(f(a)\) and \(f'(a)\) simultaneously. This is the core of Forward automatic differentiation.
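As a sanity check outside Slang, here is a minimal dual-number sketch in C++ (the Dual struct and operator names are hypothetical illustration code, not any particular library), encoding the arithmetic rules and the \(f(a)+bf'(a)\varepsilon\) identity above:

#include <cmath>
#include <cstdio>

// a + b*eps with eps^2 = 0: a carries the value, b carries the derivative.
struct Dual {
    double a;
    double b;
};

Dual operator+(Dual l, Dual r) { return { l.a + r.a, l.b + r.b }; }
Dual operator*(Dual l, Dual r) { return { l.a * r.a, l.a * r.b + l.b * r.a }; }
// The division rule derived above: (bc - ad) / c^2 for the eps part.
Dual operator/(Dual l, Dual r) { return { l.a / r.a, (l.b * r.a - l.a * r.b) / (r.a * r.a) }; }

// f(a + b*eps) = f(a) + b*f'(a)*eps, instantiated for f = exp.
Dual exp(Dual x) { return { std::exp(x.a), x.b * std::exp(x.a) }; }

int main() {
    Dual x = { 3.0, 1.0 };  // seed b = 1 to differentiate w.r.t. x
    Dual g = x * x;         // g(x) = x^2
    printf("g(3) = %f, g'(3) = %f\n", g.a, g.b); // prints 9 and 6
}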

Forward

For a multivariate function, Taylor expanding \(f(\mathbf{a}+\mathbf{b}\varepsilon)\) gives

\[ f(\mathbf{a}+\mathbf{b}\varepsilon)=f(\mathbf{a}) + \varepsilon \cdot (\nabla f(\mathbf{a}) )^T \mathbf{b} \]

where \(\mathbf{a}=(a_0,a_1,\dots)^T\) and \(\mathbf{b}=(b_0,b_1,\dots)^T\). When \(\mathbf{b}=(1,0,\dots)^T\),

\[ f(\mathbf{a}+\mathbf{b}\varepsilon)=f(\mathbf{a}) + f'_x(\mathbf{a}) \varepsilon \]

which gives the partial derivative with respect to \(x\). Similarly, when \(\mathbf{b}=(0,1,\dots)^T\), we get the partial derivative with respect to \(y\). Take \(f(x,y)=\exp(x^2+2y+2)\) as an example:

\[ \begin{align} f(a+b\varepsilon,c+d\varepsilon)&=\exp \left((a+b\varepsilon)^2+2(c+d\varepsilon)+2 \right)\\ &=\exp \left((a^2+2c+2)+2(ab+d)\varepsilon \right)\\ &=\exp \left(a^2+2c+2 \right)+2(ab+d) \exp \left(a^2+2c+2 \right) \varepsilon\\ &=\exp \left(a^2+2c+2 \right) \left(1 + 2(ab+d) \varepsilon \right) \end{align} \]

To find the partial derivatives at \(x=3,y=4\), from

\[ \begin{align} f(3+\varepsilon,4)&=e^{19}(1+6\varepsilon)\\ f(3,4+\varepsilon)&=e^{19}(1+2\varepsilon) \end{align} \]

we obtain

\[ \begin{align} f'_x(3,4)&=6 e^{19}\\ f'_y(3,4)&=2 e^{19} \end{align} \]
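The same numbers fall out of the Dual sketch above, seeding one input at a time (this reuses the hypothetical Dual struct and its exp overload, and is not Slang):

// Reuses Dual, its operators, and exp(Dual) from the sketch above.
Dual f(Dual x, Dual y) {
    Dual two = { 2.0, 0.0 };  // constants carry a zero eps part
    return exp(x * x + two * y + two);
}

int main() { // replaces the demo main from the previous sketch
    Dual fx = f({ 3.0, 1.0 }, { 4.0, 0.0 }); // seed b = (1, 0): d/dx
    Dual fy = f({ 3.0, 0.0 }, { 4.0, 1.0 }); // seed b = (0, 1): d/dy
    printf("df/dx = %e (6 * e^19)\n", fx.b);
    printf("df/dy = %e (2 * e^19)\n", fy.b);
}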

In Slang, DifferentialPair<T> represents \(p+d\varepsilon\):

struct DifferentialPair<T : IDifferentiable> : IDifferentiable
{
    typealias Differential = DifferentialPair<T.Differential>;
    property T p { get; }              // the initial primal value
    property T.Differential d { get; } // the partial derivative
    static Differential dzero();
    static Differential dadd(Differential a, Differential b);
}

The fwd_diff operator verifies the computation above:

[Differentiable]
float myFunc(float x, float y)
{
    return exp(x * x + 2 * y + 2);
}

void printMain()
{
    // Use forward differentiation to compute the gradient of the output w.r.t. x only.
    let diffX = fwd_diff(myFunc)(diffPair(3.0, 1.0), diffPair(4.0, 0.0));
    printf("dF wrt x: %f\n", diffX.d);

    // Use forward differentiation to propagate the gradient from input parameters to output value.
    let diffXY = fwd_diff(myFunc)(diffPair(3.0, 1.0), diffPair(4.0, 1.0));
    printf("dF wrt x and y: %f\n", diffXY.d);
}

Note that the second call seeds both inputs with 1, so diffXY.d is the directional derivative \(f'_x+f'_y=8e^{19}\) rather than two separate partials. The generated HLSL code is:

struct DiffPair_float_0
{
    float primal_0;
    float differential_0;
};

DiffPair_float_0 _d_exp_1(DiffPair_float_0 dpx_1)
{
    float _S2 = exp(dpx_1.primal_0);
    DiffPair_float_0 _S3 = { _S2, _S2 * dpx_1.differential_0 };
    return _S3;
}

DiffPair_float_0 s_fwd_myFunc_0(DiffPair_float_0 dpx_2, DiffPair_float_0 dpy_0)
{
    float _S4 = dpx_2.primal_0;
    float _S5 = dpx_2.differential_0 * dpx_2.primal_0;
    DiffPair_float_0 _S6 = { _S4 * _S4 + 2.0 * dpy_0.primal_0 + 2.0, _S5 + _S5 + dpy_0.differential_0 * 2.0 };
    DiffPair_float_0 _S7 = _d_exp_1(_S6);
    DiffPair_float_0 _S8 = { _S7.primal_0, _S7.differential_0 };
    return _S8;
}

With this approach, each variable's partial derivative requires re-running the whole differentiated function, so \(n\) inputs need \(n\) forward passes; this is inefficient when there are many inputs.

Backward

Backward mode is built on the chain rule for composite functions. Take \(f(x,y)=x^2y+y+2\) at \(x=3,y=4\) as an example.

(Figure: Backward Derivative Propagation — the computation graph of \(f\), with intermediate nodes \(n_i\))

Since \(n_7\) is \(f\) itself, \(\dfrac{\partial f}{\partial n_7}=1\). Then, since \(n_7=n_5+n_6\), we have \(\dfrac{\partial n_7}{\partial n_5}=1\), and therefore

\[ \frac{\partial f}{\partial n_5}=\frac{\partial f}{\partial n_7} \times \frac{\partial n_7}{\partial n_5}=1 \times 1 = 1 \]

The other nodes follow in the same way, and in the end we obtain \(\dfrac{\partial f}{\partial x}\) and \(\dfrac{\partial f}{\partial y}\).
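Concretely, since \(f(x,y)=x^2y+y+2\), the accumulated partials are \(\partial f/\partial x=2xy\) and \(\partial f/\partial y=x^2+1\), so at \(x=3,y=4\)

\[ \frac{\partial f}{\partial x}=2\times 3\times 4=24,\qquad \frac{\partial f}{\partial y}=3^2+1=10 \]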

In Slang, the bwd_diff operator performs Backward differentiation. Its parameters are inout DifferentialPair<T> values: p holds the primal value of the parameter, and d receives the back-propagated partial derivative; under bwd_diff the pair does not represent a dual number.

[Differentiable]
float myFunc(float x, float y)
{
    return x * x * y + y + 2;
}

void printMain()
{
    // Create a differentiable pair to pass in the primal value and to receive the gradient.
    DifferentialPair<float> dpX = diffPair(3.0, 0.0);
    DifferentialPair<float> dpY = diffPair(4.0, 0.0);
    // Propagate the gradient of the output (1.0f) to the input parameters.
    bwd_diff(myFunc)(dpX, dpY, 1.0);
    printf("dF wrt x computed using backward differentiation: %f\n", dpX.d);
    printf("dF wrt y computed using backward differentiation: %f\n", dpY.d);
}

The final 1.0 argument corresponds to \(\dfrac{\partial f}{\partial n_7}\) in the derivation above, and can be set to other values as needed. The two printf calls should print 24 and 10, matching the hand computation.

The generated HLSL code:

struct DiffPair_float_0
{
    float primal_0;
    float differential_0;
};

DiffPair_float_0 s_fwd_myFunc_0(DiffPair_float_0 dpx_0, DiffPair_float_0 dpy_0)
{
    float _S1 = dpx_0.primal_0;
    float _S2 = _S1 * _S1;
    float _S3 = dpx_0.differential_0 * dpx_0.primal_0;
    DiffPair_float_0 _S4 = { _S2 * dpy_0.primal_0 + dpy_0.primal_0 + 2.0, (_S3 + _S3) * dpy_0.primal_0 + dpy_0.differential_0 * _S2 + dpy_0.differential_0 };
    return _S4;
}

void s_bwd_myFunc_0(inout DiffPair_float_0 _S8, inout DiffPair_float_0 _S9, float _S10)
{
    s_bwd_prop_myFunc_0(_S8, _S9, _S10);
    return;
}

With one forward pass to compute the values of the intermediate nodes and one backward pass to propagate gradients, the partial derivatives with respect to all parameters are obtained at once.
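The same idea can be spelled out by hand in C++ (an illustrative sketch of the mechanism, not what the Slang compiler actually emits): one forward pass stores the intermediate values, one backward pass accumulates the adjoints, and both partials come out together.

#include <cstdio>

int main() {
    float x = 3.0f, y = 4.0f;

    // Forward pass: evaluate and keep every intermediate node.
    float t1 = x * x;           // t1 = x^2
    float t2 = t1 * y;          // t2 = x^2 * y
    float f  = t2 + y + 2.0f;   // f  = x^2*y + y + 2

    // Backward pass: propagate df/d(node) from the output to the leaves.
    float df  = 1.0f;           // seed, the 1.0 passed to bwd_diff above
    float dt2 = df;             // f = t2 + y + 2  =>  df/dt2 = 1
    float dy  = df;             //                 =>  df/dy accumulates 1
    float dt1 = dt2 * y;        // t2 = t1 * y     =>  dt2/dt1 = y
    dy += dt2 * t1;             //                 =>  dt2/dy  = t1
    float dx  = dt1 * 2.0f * x; // t1 = x * x      =>  dt1/dx  = 2x

    printf("f = %f, df/dx = %f, df/dy = %f\n", f, dx, dy); // 42, 24, 10
}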
