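Differentiating by hand is tedious and error-prone. Symbolic differentiation copes poorly with functions written as code. Numerical differentiation suffers from large errors. The mainstream approaches to automatic differentiation today are forward-mode and backward-mode (reverse-mode) propagation.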
Dual number
A dual number can be written as \(a+b \varepsilon\), where \(\varepsilon\) is an infinitesimal satisfying \(\varepsilon^2=0\) but \(\varepsilon \ne 0\). Dual numbers obey the following rules:
- \(\lambda (a+b \varepsilon)=\lambda a + \lambda b \varepsilon\)
- \((a+b\varepsilon)+(c+d\varepsilon)=(a+c)+(b+d)\varepsilon\)
- \((a+b\varepsilon)(c+d\varepsilon)=ac+(ad+bc)\varepsilon\)
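One way to sanity-check these rules is to represent \(a+b\varepsilon\) as the upper-triangular matrix \(\begin{pmatrix}a&b\\0&a\end{pmatrix}\). This representation is a standard device but does not come from the original text. Under it, \(\varepsilon\) is the nonzero matrix \(\begin{pmatrix}0&1\\0&0\end{pmatrix}\), and its square is the zero matrix. The helper names below (`dual_matrix`, `matmul2`) are hypothetical. The sketch multiplies two such matrices and recovers the product rule.

```python
def dual_matrix(a, b):
    """Represent a + b*eps as the 2x2 matrix [[a, b], [0, a]]."""
    return [[a, b], [0, a]]


def matmul2(m, n):
    """Multiply two 2x2 matrices stored as nested lists."""
    return [
        [m[0][0] * n[0][0] + m[0][1] * n[1][0], m[0][0] * n[0][1] + m[0][1] * n[1][1]],
        [m[1][0] * n[0][0] + m[1][1] * n[1][0], m[1][0] * n[0][1] + m[1][1] * n[1][1]],
    ]


eps = dual_matrix(0, 1)          # eps itself is not the zero matrix
eps_squared = matmul2(eps, eps)  # but eps^2 is

# (a + b eps)(c + d eps) with a=2, b=3, c=5, d=7
# The rule predicts ac + (ad + bc) eps = 10 + 29 eps
product = matmul2(dual_matrix(2, 3), dual_matrix(5, 7))
print(eps_squared)  # [[0, 0], [0, 0]]
print(product)      # [[10, 29], [0, 10]]
```

Scaling by \(\lambda\) and addition correspond to ordinary matrix scaling and matrix addition, so the first two rules hold as well.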
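Expand \(f(a+b\varepsilon)\) as a Taylor series around \(a\). Every term of order \(\varepsilon^2\) or higher vanishes:

\[
f(a+b\varepsilon)=f(a)+f'(a)b\varepsilon+\frac{f''(a)}{2!}b^2\varepsilon^2+\cdots=f(a)+f'(a)b\varepsilon
\]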
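Evaluating \(f(a+b\varepsilon)\) therefore yields \(f(a)\) and \(f'(a)\) at the same time. With \(b=1\), the coefficient of \(\varepsilon\) is exactly \(f'(a)\). This is the core of forward-mode automatic differentiation.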
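A minimal Python sketch of this idea follows. The `Dual` class and the `exp` helper are illustrative names, not part of any library. Arithmetic implements the rules above, `exp` applies \(f(a+b\varepsilon)=f(a)+f'(a)b\varepsilon\), and seeding \(b=1\) lets us read off the derivative.

```python
import math


class Dual:
    """A dual number a + b*eps with eps**2 == 0."""

    def __init__(self, a, b=0.0):
        self.a = a  # primal (real) part
        self.b = b  # coefficient of eps

    @staticmethod
    def lift(x):
        # Treat plain numbers as constants: x + 0*eps
        return x if isinstance(x, Dual) else Dual(x, 0.0)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.a + other.a, self.b + other.b)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        # (a + b eps)(c + d eps) = ac + (ad + bc) eps
        return Dual(self.a * other.a, self.a * other.b + self.b * other.a)

    __rmul__ = __mul__


def exp(x):
    # exp(a + b eps) = exp(a) + exp(a) * b eps
    x = Dual.lift(x)
    return Dual(math.exp(x.a), math.exp(x.a) * x.b)


def f(x):
    return x * x * x + 2 * x  # f'(x) = 3x^2 + 2


result = f(Dual(2.0, 1.0))  # seed b = 1 to obtain f'(2)
print(result.a, result.b)   # 12.0 14.0

x1 = Dual(1.0, 1.0)
g = exp(x1 * x1)            # g(x) = exp(x^2), so g(1) = e and g'(1) = 2e
print(g.a, g.b)
```

Each operation carries the derivative alongside the value, so a single evaluation produces both.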
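For a multivariate function, the Taylor expansion of \(f(\mathbf{a}+\mathbf{b}\varepsilon)\) gives

\[
f(\mathbf{a}+\mathbf{b}\varepsilon)=f(\mathbf{a})+\nabla f(\mathbf{a})\cdot\mathbf{b}\,\varepsilon
\]

where \(\mathbf{a}=(a_0,a_1,\dots)^T\) and \(\mathbf{b}=(b_0,b_1,\dots)^T\). When \(\mathbf{b}=(1,0,\dots)^T\),

\[
f(\mathbf{a}+\mathbf{b}\varepsilon)=f(\mathbf{a})+\frac{\partial f}{\partial x}(\mathbf{a})\,\varepsilon
\]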
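This yields the partial derivative with respect to \(x\). Similarly, \(\mathbf{b}=(0,1,\dots)^T\) yields the partial derivative with respect to \(y\). Take \(f(x,y)=\exp(x^2+2y+2)\) as an example and find its partial derivatives at \(x=3,y=4\). From

\[
f(3+\varepsilon,4)=\exp\left((3+\varepsilon)^2+2\cdot4+2\right)=\exp(19+6\varepsilon)=e^{19}+6e^{19}\varepsilon
\]

\[
f(3,4+\varepsilon)=\exp\left(3^2+2(4+\varepsilon)+2\right)=\exp(19+2\varepsilon)=e^{19}+2e^{19}\varepsilon
\]

we get \(\dfrac{\partial f}{\partial x}\Big|_{(3,4)}=6e^{19}\) and \(\dfrac{\partial f}{\partial y}\Big|_{(3,4)}=2e^{19}\).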
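The same computation can be sketched in Python, with a dual number stored as a `(primal, tangent)` tuple. All function names here are hypothetical. `directional` evaluates \(f(\mathbf{a}+\mathbf{b}\varepsilon)\) for a chosen seed vector \(\mathbf{b}\). The seeds \((1,0)^T\) and \((0,1)^T\) give the two partial derivatives. The seed \((1,1)^T\) gives their sum, \(\nabla f\cdot\mathbf{b}\).

```python
import math

# A dual number a + b*eps stored as a (primal, tangent) tuple.
def d_add(u, v):
    return (u[0] + v[0], u[1] + v[1])

def d_mul(u, v):
    # (a + b eps)(c + d eps) = ac + (ad + bc) eps
    return (u[0] * v[0], u[0] * v[1] + u[1] * v[0])

def d_const(c):
    return (c, 0.0)

def d_exp(u):
    # exp(a + b eps) = e^a + e^a * b eps
    return (math.exp(u[0]), math.exp(u[0]) * u[1])

def f(x, y):
    # f(x, y) = exp(x^2 + 2y + 2) evaluated on dual inputs
    inner = d_add(d_add(d_mul(x, x), d_mul(d_const(2.0), y)), d_const(2.0))
    return d_exp(inner)

def directional(func, a, b):
    """Evaluate func(a + b eps); returns (f(a), grad f(a) . b)."""
    return func(*[(ai, bi) for ai, bi in zip(a, b)])

a = (3.0, 4.0)
value, df_dx = directional(f, a, (1.0, 0.0))  # b = (1, 0)^T -> df/dx
_, df_dy = directional(f, a, (0.0, 1.0))      # b = (0, 1)^T -> df/dy
_, df_dxy = directional(f, a, (1.0, 1.0))     # b = (1, 1)^T -> df/dx + df/dy
e19 = math.exp(19.0)
print(round(df_dx / e19, 6), round(df_dy / e19, 6), round(df_dxy / e19, 6))  # 6.0 2.0 8.0
```

Each pass yields one directional derivative. A function with \(n\) inputs therefore needs \(n\) forward passes to recover the full gradient.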
In Slang, \(p+d\varepsilon\) is represented by `DifferentialPair<T>`:
```slang
struct DifferentialPair<T : IDifferentiable> : IDifferentiable
{
    typealias Differential = DifferentialPair<T.Differential>;

    property T p { get; }               // the initial primal value
    property T.Differential d { get; }  // the partial derivative

    static Differential dzero();                               // zero of the differential type
    static Differential dadd(Differential a, Differential b);  // sum of two differentials
}
```
Forward-mode differentiation uses the `fwd_diff` operator:
```slang
float myFunc(float x, float y)
{
    return exp(x * x + 2 * y + 2);
}

void printMain()
{
    // Use forward differentiation to compute the gradient of the output w.r.t. x only.
    let diffX = fwd_diff(myFunc)(diffPair(3.0, 1.0), diffPair(4.0, 0.0));
    printf("dF wrt x: %f\n", diffX.d);

    // Use forward differentiation to propagate the gradient from input parameters to output value.
    // With both differentials set to 1.0, diffXY.d equals df/dx + df/dy.
    let diffXY = fwd_diff(myFunc)(diffPair(3.0, 1.0), diffPair(4.0, 1.0));
    printf("dF wrt x and y: %f\n", diffXY.d);
}
```
The generated HLSL code (excerpt) is:
```hlsl
struct DiffPair_float_0
{
    float primal_0;
    float differential_0;
};

DiffPair_float_0 _d_exp_1(DiffPair_float_0 dpx_1)
{
    float _S2 = exp(dpx_1.primal_0);
    DiffPair_float_0 _S3 = { _S2, _S2 * dpx_1.differential_0 };
    return _S3;
}

DiffPair_float_0 s_fwd_myFunc_0(DiffPair_float_0 dpx_2, DiffPair_float_0 dpy_0)
{
    float _S4 = dpx_2.primal_0;
    float _S5 = dpx_2.differential_0 * dpx_2.primal_0;
    DiffPair_float_0 _S6 = { _S4 * _S4 + 2.0 * dpy_0.primal_0 + 2.0, _S5 + _S5 + dpy_0.differential_0 * 2.0 };
    DiffPair_float_0 _S7 = _d_exp_1(_S6);
    DiffPair_float_0 _S8 = { _S7.primal_0, _S7.differential_0 };
    return _S8;
}
```
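To help read the generated code, here is a line-by-line Python transliteration of `_d_exp_1` and `s_fwd_myFunc_0`. A named tuple stands in for `DiffPair_float_0`. This is only a reading aid written by hand, not something Slang emits, and the Python names are hypothetical. Feeding it the same seeds as `printMain` should reproduce the two printed values, \(6e^{19}\) and \(8e^{19}\), up to floating-point precision. Python uses doubles while HLSL uses `float`.

```python
import math
from collections import namedtuple

# Stand-in for the generated struct DiffPair_float_0 { primal_0; differential_0; }
DiffPair = namedtuple("DiffPair", ["primal_0", "differential_0"])


def d_exp(dpx):
    # _d_exp_1: exp(p + d eps) = exp(p) + exp(p) * d eps
    s2 = math.exp(dpx.primal_0)
    return DiffPair(s2, s2 * dpx.differential_0)


def s_fwd_my_func(dpx, dpy):
    # s_fwd_myFunc_0: forward derivative of exp(x*x + 2*y + 2)
    s4 = dpx.primal_0
    s5 = dpx.differential_0 * dpx.primal_0  # dx * x; d(x*x) = s5 + s5
    s6 = DiffPair(s4 * s4 + 2.0 * dpy.primal_0 + 2.0,
                  s5 + s5 + dpy.differential_0 * 2.0)
    return d_exp(s6)  # chain rule through exp


diff_x = s_fwd_my_func(DiffPair(3.0, 1.0), DiffPair(4.0, 0.0))
diff_xy = s_fwd_my_func(DiffPair(3.0, 1.0), DiffPair(4.0, 1.0))
print(diff_x.differential_0, diff_xy.differential_0)
```

Note how the product rule for `x * x` shows up as the duplicated `_S5 + _S5`.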
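Backward-mode differentiation is implemented with the chain rule for composite functions. Take \(f(x,y)=x^2y+y+2\) at \(x=3,y=4\) as an example, with the computation broken into intermediate nodes \(n_i\).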
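Since \(n_7\) is \(f\) itself, \(\dfrac{\partial f}{\partial n_7}=1\). Next, because \(n_7=n_5+n_6\), we have \(\dfrac{\partial n_7}{\partial n_5}=1\), and therefore

\[
\frac{\partial f}{\partial n_5}=\frac{\partial f}{\partial n_7}\cdot\frac{\partial n_7}{\partial n_5}=1\times1=1
\]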
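The remaining nodes are handled the same way, working backward until \(\dfrac{\partial f}{\partial x}\) and \(\dfrac{\partial f}{\partial y}\) are reached. Here they are \(2xy=24\) and \(x^2+1=10\).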
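The node graph itself is not reproduced here. The sketch below assumes the decomposition \(n_1=x\), \(n_2=y\), \(n_3=2\), \(n_4=n_1^2\), \(n_5=n_4n_2\), \(n_6=n_2+n_3\), \(n_7=n_5+n_6\). This is consistent with \(n_7=n_5+n_6\) above, but the other labels are an assumption. The backward pass can then be sketched in Python. A forward sweep stores each node's value, and a backward sweep accumulates \(\partial f/\partial n_k\) from \(n_7\) down to the inputs. The function name `reverse_mode` is hypothetical.

```python
def reverse_mode(x, y, seed=1.0):
    # Forward sweep: evaluate and store every node (assumed labelling).
    n1, n2, n3 = x, y, 2.0
    n4 = n1 * n1          # x^2
    n5 = n4 * n2          # x^2 * y
    n6 = n2 + n3          # y + 2
    n7 = n5 + n6          # f

    # Backward sweep: g[k] holds df/dn_k, starting from df/dn7 = seed.
    g = {k: 0.0 for k in range(1, 8)}
    g[7] = seed
    # n7 = n5 + n6
    g[5] += g[7] * 1.0
    g[6] += g[7] * 1.0
    # n6 = n2 + n3
    g[2] += g[6] * 1.0
    g[3] += g[6] * 1.0
    # n5 = n4 * n2 (n2 = y is used twice, so its gradient accumulates)
    g[4] += g[5] * n2
    g[2] += g[5] * n4
    # n4 = n1^2
    g[1] += g[4] * 2.0 * n1
    return n7, g[1], g[2]


f_val, df_dx, df_dy = reverse_mode(3.0, 4.0)
print(f_val, df_dx, df_dy)  # 42.0 24.0 10.0
```

A single backward sweep produces the derivative with respect to every input at once. This is why reverse mode suits functions with many inputs and one output.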
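In Slang, backward-mode differentiation uses the `bwd_diff` operator. Here a `DifferentialPair<T>` passed as an `inout` parameter does not represent a dual number. Its `p` field supplies the input's primal value, and after the call its `d` field holds the gradient of the output with respect to that input.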
```slang
float myFunc(float x, float y)
{
    return x * x * y + y + 2;
}

void printMain()
{
    // Create a differentiable pair to pass in the primal value and to receive the gradient.
    DifferentialPair<float> dpX = diffPair(3.0, 0.0);
    DifferentialPair<float> dpY = diffPair(4.0, 0.0);

    // Propagate the gradient of the output (1.0f) to the input parameters.
    bwd_diff(myFunc)(dpX, dpY, 1.0);

    printf("dF wrt x computed using backward differentiation: %f\n", dpX.d);
    printf("dF wrt y computed using backward differentiation: %f\n", dpY.d);
}
```
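The final argument `1.0` corresponds to \(\dfrac{\partial f}{\partial n_7}\) in the derivation above. It can be set to another value as needed, for example the upstream gradient when `myFunc` is part of a larger computation. The resulting gradients scale linearly with it.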
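To make the role of that last argument concrete, here is a hypothetical Python mirror of the call `bwd_diff(myFunc)(dpX, dpY, dOut)`. A mutable `Pair` stands in for the `inout DifferentialPair<float>`: its `p` is read as input, and in this sketch its `d` is simply overwritten with the gradient. `Pair` and `bwd_my_func` are illustrative names, not Slang API. The derivative formulas are written by hand for \(f=x^2y+y+2\), not taken from `s_bwd_prop_myFunc_0`.

```python
from dataclasses import dataclass


@dataclass
class Pair:
    p: float        # primal value (input)
    d: float = 0.0  # receives the gradient (output)


def bwd_my_func(dp_x: Pair, dp_y: Pair, d_out: float) -> None:
    # f(x, y) = x*x*y + y + 2; d_out plays the role of df/dn7
    x, y = dp_x.p, dp_y.p
    dp_x.d = d_out * 2.0 * x * y    # df/dx = 2xy
    dp_y.d = d_out * (x * x + 1.0)  # df/dy = x^2 + 1


dp_x, dp_y = Pair(3.0), Pair(4.0)
bwd_my_func(dp_x, dp_y, 1.0)
print(dp_x.d, dp_y.d)  # 24.0 10.0

bwd_my_func(dp_x, dp_y, 0.5)  # a different upstream gradient scales the result
print(dp_x.d, dp_y.d)  # 12.0 5.0
```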
The generated HLSL code (excerpt; the body of `s_bwd_prop_myFunc_0` is not shown):
```hlsl
struct DiffPair_float_0
{
    float primal_0;
    float differential_0;
};

DiffPair_float_0 s_fwd_myFunc_0(DiffPair_float_0 dpx_0, DiffPair_float_0 dpy_0)
{
    float _S1 = dpx_0.primal_0;
    float _S2 = _S1 * _S1;
    float _S3 = dpx_0.differential_0 * dpx_0.primal_0;
    DiffPair_float_0 _S4 = { _S2 * dpy_0.primal_0 + dpy_0.primal_0 + 2.0, (_S3 + _S3) * dpy_0.primal_0 + dpy_0.differential_0 * _S2 + dpy_0.differential_0 };
    return _S4;
}

void s_bwd_myFunc_0(inout DiffPair_float_0 _S8, inout DiffPair_float_0 _S9, float _S10)
{
    s_bwd_prop_myFunc_0(_S8, _S9, _S10);
}
```