自动微分¶
手算微分麻烦且容易出错,符号微分处理代码中的函数不方便,数值微分的误差大,目前主流的自动微分方法是 Forward 和 Backward 传播法。
Dual number¶
Dual number 可以表示为 \(a+b \varepsilon\),其中 \(\varepsilon\) 是一个无穷小的数,满足 \(\varepsilon^2=0\) 但 \(\varepsilon \ne 0\)。
它的一些运算性质
- \(\lambda (a+b \varepsilon)=\lambda a + \lambda b \varepsilon\)
- \((a+b\varepsilon)+(c+d\varepsilon)=(a+c)+(b+d)\varepsilon\)
- \((a+b\varepsilon)(c+d\varepsilon)=ac+(ad+bc)\varepsilon\)
除法可以通过上下同乘分母共轭化简
将 \(f(a+b\varepsilon)\) 在 \(a\) 处泰勒展开
计算 \(f(a+b\varepsilon)\) 可以同时得到 \(f(a),f'(a)\),这是 Forward 自动微分的核心。
Forward¶
对于多元函数 \(f(\mathbf{a}+\mathbf{b}\varepsilon)\) 泰勒展开后得
其中 \(\mathbf{a}=(a_0,a_1,\dots)^T,\mathbf{b}=(b_0,b_1,\dots)^T\),当 \(\mathbf{b}=(1,0,\dots)^T\) 时
可以得到关于 \(x\) 的偏导。类似地,当 \(\mathbf{b}=(0,1,\dots)^T\) 时,可以得到关于 \(y\) 的偏导。以 \(f(x,y)=\exp(x^2+2y+2)\) 为例
求 \(x=3,y=4\) 处的偏导,由
得
在 slang 中,用 DifferentialPair<T>
表示 \(p+d\varepsilon\)
struct DifferentialPair<T : IDifferentiable> : IDifferentiable
{
typealias Differential = DifferentialPair<T.Differential>;
property T p { get; } // the initial primal value
property T.Differential d { get; } // the partial derivative
static Differential dzero();
static Differential dadd(Differential a, Differential b);
}
用 fwd_diff
算子验证刚才的计算
[Differentiable]
float myFunc(float x, float y)
{
return exp(x * x + 2 * y + 2);
}
void printMain()
{
// Use forward differentiation to compute the gradient of the output w.r.t. x only.
let diffX = fwd_diff(myFunc)(diffPair(3.0, 1.0), diffPair(4.0, 0.0));
printf("dF wrt x: %f\n", diffX.d);
// Use forward differentiation to propagate the gradient from input parameters to output value.
let diffXY = fwd_diff(myFunc)(diffPair(3.0, 1.0), diffPair(4.0, 1.0));
printf("dF wrt x and y: %f\n", diffXY.d);
}
生成的 hlsl 代码为
struct DiffPair_float_0
{
float primal_0;
float differential_0;
};
DiffPair_float_0 _d_exp_1(DiffPair_float_0 dpx_1)
{
float _S2 = exp(dpx_1.primal_0);
DiffPair_float_0 _S3 = { _S2, _S2 * dpx_1.differential_0 };
return _S3;
}
DiffPair_float_0 s_fwd_myFunc_0(DiffPair_float_0 dpx_2, DiffPair_float_0 dpy_0)
{
float _S4 = dpx_2.primal_0;
float _S5 = dpx_2.differential_0 * dpx_2.primal_0;
DiffPair_float_0 _S6 = { _S4 * _S4 + 2.0 * dpy_0.primal_0 + 2.0, _S5 + _S5 + dpy_0.differential_0 * 2.0 };
DiffPair_float_0 _S7 = _d_exp_1(_S6);
DiffPair_float_0 _S8 = { _S7.primal_0, _S7.differential_0 };
return _S8;
}
这种方法每计算一个变量的偏导,就要重新执行整个微分函数,当输入的变量较多时效率不高。
Backward¶
使用复合函数求导的链式法则实现。以 \(f(x,y)=x^2y+y+2\),当 \(x=3,y=4\) 时为例
因为 \(n_7\) 就是 \(f\),所以 \(\dfrac{\partial f}{\partial n_7}=1\)。然后,因为 \(n_7=n_5+n_6\),所以 \(\dfrac{\partial n_7}{\partial n_5}=1\),进而
其他的以此类推,最后得到 \(\dfrac{\partial f}{\partial x},\dfrac{\partial f}{\partial y}\)。
在 slang 中用 bwd_diff
算子进行 Backward 微分,参数是 inout
的 DifferentialPair<T>
,p
中存放参数的原始值,d
用于接收反向传播过来的偏导,bwd_diff
时不表示 Dual number。
[Differentiable]
float myFunc(float x, float y)
{
return x * x * y + y + 2;
}
void printMain()
{
// Create a differentiable pair to pass in the primal value and to receive the gradient.
DifferentialPair<float> dpX = diffPair(3.0, 0.0);
DifferentialPair<float> dpY = diffPair(4.0, 0.0);
// Propagate the gradient of the output (1.0f) to the input parameters.
bwd_diff(myFunc)(dpX, dpY, 1.0);
printf("dF wrt x computed using backward differentiation: %f\n", dpX.d);
printf("dF wrt y computed using backward differentiation: %f\n", dpY.d);
}
最后的 1.0
参数,相当于前面推导时的 \(\dfrac{\partial f}{\partial n_7}\),也可以根据需要改成其他值。
生成的 hlsl 代码
struct DiffPair_float_0
{
float primal_0;
float differential_0;
};
DiffPair_float_0 s_fwd_myFunc_0(DiffPair_float_0 dpx_0, DiffPair_float_0 dpy_0)
{
float _S1 = dpx_0.primal_0;
float _S2 = _S1 * _S1;
float _S3 = dpx_0.differential_0 * dpx_0.primal_0;
DiffPair_float_0 _S4 = { _S2 * dpy_0.primal_0 + dpy_0.primal_0 + 2.0, (_S3 + _S3) * dpy_0.primal_0 + dpy_0.differential_0 * _S2 + dpy_0.differential_0 };
return _S4;
}
void s_bwd_myFunc_0(inout DiffPair_float_0 _S8, inout DiffPair_float_0 _S9, float _S10)
{
s_bwd_prop_myFunc_0(_S8, _S9, _S10);
return;
}
通过一次前向计算中间结点的值,再一次反向传播梯度,就能把关于所有参数的偏导都求出来。