【深度学习】神经正切核(NTK)理论
本文来自于《Theory of Deep Learning》,主要是对神经正切核(NTK)理论进行介绍。这里主要是补充了一些基本概念以及部分推导过程。作为软件工程出身,数学不是特别好,有些基础知识和推导步骤没办法一次补足。若有机会,后续会逐步补全缺失的部分。设X1,…,XnX_1,\dots,X_nX1,…,Xn为nnn个独立的随机变量,且XiX_iXi的边界为[ai,bi][a_i,b
本文来自于《Theory of Deep Learning》,主要是对神经正切核(NTK)理论进行介绍。这里主要是补充了一些基本概念以及部分推导过程。作为软件工程出身,数学不是特别好,有些基础知识和推导步骤没办法一次补足。若有机会,后续会逐步补全缺失的部分。
一、基础知识
1. Hoeffding不等式
设X1,…,XnX_1,\dots,X_nX1,…,Xn为nnn个独立的随机变量,且XiX_iXi的边界为[ai,bi][a_i,b_i][ai,bi]。令Xˉ=1n∑i=1nXi\bar{X}=\frac{1}{n}\sum_{i=1}^n X_iXˉ=n1∑i=1nXi,则有
P(∣Xˉ−E(Xˉ)∣≥t)≤exp(−2n2t2∑i=1n(bi−ai)2) P(|\bar{X}-E(\bar{X})|\geq t)\leq \exp\Big(-\frac{2n^2t^2}{\sum_{i=1}^n(b_i-a_i)^2}\Big) \\ P(∣Xˉ−E(Xˉ)∣≥t)≤exp(−∑i=1n(bi−ai)22n2t2)
2. Boole不等式
令AiA_iAi表达第iii个随机事件,那么有
P(∪iAi)≤∑iP(Ai) P\Big(\cup_i A_i\Big)\leq\sum_i P(A_i) \\ P(∪iAi)≤i∑P(Ai)
即至少一个事件发生的概率不大于单独事件发生概率之和。
3. 核函数与核回归
核函数。 设X\mathcal{X}X是输入空间,H\mathcal{H}H是特征空间,若存在一个从X\mathcal{X}X至H\mathcal{H}H的映射
ϕ(x):X→H \phi(\textbf{x}):\mathcal{X}\rightarrow\mathcal{H} \\ ϕ(x):X→H
使得对所有的x,z∈X\textbf{x},\textbf{z}\in\mathcal{X}x,z∈X,函数k(x,z)k(\textbf{x},\textbf{z})k(x,z)均满足
k(x,z)=⟨ϕ(x),ϕ(z)⟩ k(\textbf{x},\textbf{z})=\langle \phi(\textbf{x}),\phi(\textbf{z})\rangle \\ k(x,z)=⟨ϕ(x),ϕ(z)⟩
则称k(x,z)k(\textbf{x},\textbf{z})k(x,z)是核函数,ϕ(x)\phi(\textbf{x})ϕ(x)是映射函数,⟨ϕ(x),ϕ(z)⟩\langle \phi(\textbf{x}),\phi(\textbf{z}) \rangle⟨ϕ(x),ϕ(z)⟩表示ϕ(x)\phi(\textbf{x})ϕ(x)和ϕ(z)\phi(\textbf{z})ϕ(z)的内积。核函数的作用是特征映射后求内积,但是不一定需要显示进行映射。
高斯核是一种常见的核函数,定义为
k(x,z)=exp(−γ∥x−z∥2) k(\textbf{x},\textbf{z})=\exp(-\gamma\parallel \textbf{x}-\textbf{z}\parallel^2) \\ k(x,z)=exp(−γ∥x−z∥2)
其可以将特征映射至无穷维,因此
exp(−∥x−z∥2)=exp(−x⊤x−z⊤z+2x⊤z)=exp(−x⊤x)exp(z⊤z)exp(2x⊤z)=exp(−x⊤x)exp(z⊤z)(∑k=0∞(2x⊤z)kk!)=∑k=0∞[exp(−x⊤x)exp(−z⊤z)2kk!2kk!(xk)⊤(zk)]=ϕ(x)⊤ϕ(z) \begin{align} \exp(-\parallel \textbf{x}-\textbf{z}\parallel^2)&=\exp(-\textbf{x}^\top\textbf{x}-\textbf{z}^\top\textbf{z}+2\textbf{x}^\top\textbf{z}) \\ &=\exp(-\textbf{x}^\top\textbf{x})\exp(\textbf{z}^\top\textbf{z})\exp(2\textbf{x}^\top\textbf{z}) \\ &=\exp(-\textbf{x}^\top\textbf{x})\exp(\textbf{z}^\top\textbf{z})\Big(\sum_{k=0}^{\infty}\frac{(2\textbf{x}^\top\textbf{z})^k}{k!}\Big) \\ &=\sum_{k=0}^{\infty}\Big[ \exp(-\textbf{x}^\top\textbf{x})\exp(-\textbf{z}^\top \textbf{z})\sqrt{\frac{2^k}{k!}}\sqrt{\frac{2^k}{k!}}(\textbf{x}^k)^\top(\textbf{z}^k) \Big] \\ &=\phi(\textbf{x})^\top\phi(\textbf{z}) \end{align} \\ exp(−∥x−z∥2)=exp(−x⊤x−z⊤z+2x⊤z)=exp(−x⊤x)exp(z⊤z)exp(2x⊤z)=exp(−x⊤x)exp(z⊤z)(k=0∑∞k!(2x⊤z)k)=k=0∑∞[exp(−x⊤x)exp(−z⊤z)k!2kk!2k(xk)⊤(zk)]=ϕ(x)⊤ϕ(z)
(上式第三等号使用了Taylor展开exp(2x⊤z)=∑0∞(2x⊤z)kk!\exp(2\textbf{x}^\top\textbf{z})=\sum_{0}^{\infty}\frac{(2\textbf{x}^\top\textbf{z})^k}{k!}exp(2x⊤z)=∑0∞k!(2x⊤z)k)
基于上式可以得到高斯核的映射函数为
ϕ(x)=exp(−x⊤x)(1,211!x1,222!x2,…,2kk!xk,…) \phi(\textbf{x})=\exp(-\textbf{x}^\top\textbf{x})\Big( 1,\sqrt{\frac{2^1}{1!}}\textbf{x}^1,\sqrt{\frac{2^2}{2!}}\textbf{x}^2,\dots,\sqrt{\frac{2^k}{k!}}\textbf{x}^k,\dots \Big) \\ ϕ(x)=exp(−x⊤x)(1,1!21x1,2!22x2,…,k!2kxk,…)
核回归。核回归是经典的非线性回归算法。给定训练集(X,y)={(xi,yi)}i=1n(\textbf{X},\textbf{y})=\{(\textbf{x}_i,y_i)\}_{i=1}^n(X,y)={(xi,yi)}i=1n,其中xi\textbf{x}_ixi是输入数据,yi=f(xi)y_i=f(\textbf{x}_i)yi=f(xi)是对应的标量标签,核回归的目标是构建一个估计函数
f^(x)=∑i=1n(K−1y)ik(xi,x) \hat{f}(\textbf{x})=\sum_{i=1}^n(\textbf{K}^{-1}\textbf{y})_i k(\textbf{x}_i,\textbf{x}) \\ f^(x)=i=1∑n(K−1y)ik(xi,x)
其中K\textbf{K}K是n×nn\times nn×n的核矩阵,该矩阵的每个分量为Kij=k(xi,xj)\textbf{K}_{ij}=k(\textbf{x}_i,\textbf{x}_j)Kij=k(xi,xj),kkk是对称半正定核函数。
直觉上,核回归对于任意数据点x\textbf{x}x的估计值可以看做是训练数据xi\textbf{x}_ixi与x\textbf{x}x的相似性作为权重,然后对训练标签yiy_iyi进行加权求和。
二、预测的演化方程
设神经网络的输出表示为f(w,x)∈Rf(w,x)\in\mathbb{R}f(w,x)∈R,其中w∈RNw\in\mathbb{R}^Nw∈RN是网络中的所有参数,x∈Rdx\in\mathbb{R}^dx∈Rd是输入。给定训练数据{(xi,yi)}i=1n⊂Rd×R\{(x_i,y_i)\}_{i=1}^n\subset\mathbb{R}^d\times\mathbb{R}{(xi,yi)}i=1n⊂Rd×R,通过最小化训练数据上的均方误差来训练神经网络:
l(w)=12∑i=1n(f(w,xi)−yi)2(1) \mathcal{l}(w)=\frac{1}{2}\sum_{i=1}^n(f(w,x_i)-y_i)^2 \tag{1} \\ l(w)=21i=1∑n(f(w,xi)−yi)2(1)
这里主要研究梯度流(gradient flow),也就是极小学习率的梯度下降。在上面的例子中,预测的动力学可以描述为常微分方程:
dw(t)dt=−∇l(w(t))(2) \frac{d w(t)}{dt}=-\nabla\mathcal{l}(w(t)) \tag{2} \\ dtdw(t)=−∇l(w(t))(2)
引理1
令u(t)=(f(w(t),xi))i∈[n]∈Rnu(t)=(f(w(t),x_i))_{i\in[n]}\in\mathbb{R}^nu(t)=(f(w(t),xi))i∈[n]∈Rn表示神经网络在时刻ttt的所有输出xi′x_i'xi′,y=(yi)i∈[n]y=(y_i)_{i\in[n]}y=(yi)i∈[n]是标签。u(t)u(t)u(t)的演化遵循
du(t)dt=−H(t)⋅(u(t)−y)(3) \frac{du(t)}{dt}=-H(t)\cdot(u(t)-y) \tag{3} \\ dtdu(t)=−H(t)⋅(u(t)−y)(3)
其中,H(t)H(t)H(t)是n×nn\times nn×n的半正定矩阵,其第(i,j)(i,j)(i,j)个元素是⟨∂f(w(t),xi)∂w,∂f(w(t),xj)∂w⟩\langle\frac{\partial f(w(t),x_i)}{\partial w},\frac{\partial f(w(t),x_j)}{\partial w}\rangle⟨∂w∂f(w(t),xi),∂w∂f(w(t),xj)⟩。 证明。参数www的演化是基于下面的微分方程
dw(t)dt=−∇l(w(t))=−∑i=1n(f(w(t),xi)−yi)∂f(w(t),xi)∂w(4) \frac{dw(t)}{dt}=-\nabla\mathcal{l}(w(t))=-\sum_{i=1}^n(f(w(t),x_i)-y_i)\frac{\partial f(w(t),x_i)}{\partial w} \tag{4} \\ dtdw(t)=−∇l(w(t))=−i=1∑n(f(w(t),xi)−yi)∂w∂f(w(t),xi)(4)
其中t≥0t\geq 0t≥0是连续的时间坐标。基于等式(4),网络输出f(w(t),xi)f(w(t),x_i)f(w(t),xi)的演化可以写作
df(w(t),xi)dt=⟨∂f(w(t),xi)∂w(t),∂w(t)∂t⟩=⟨∂f(w(t),xi)∂w(t),−∑j=1n(f(w(t),xj)−yj)∂f(w(t),xj)∂w⟩=−∑j=1n(f(w(t),xj),yj)⟨∂f(w(t),xi)∂w,∂f(w(t),xj)∂w⟩(5) \begin{align} \frac{df(w(t),x_i)}{dt}&=\Big\langle\frac{\partial f(w(t),x_i)}{\partial w(t)},\frac{\partial w(t)}{\partial t}\Big\rangle \\ &=\Big\langle \frac{\partial f(w(t),x_i)}{\partial w(t)}, -\sum_{j=1}^n(f(w(t),x_j)-y_j)\frac{\partial f(w(t),x_j)}{\partial w} \Big\rangle \\ &=-\sum_{j=1}^n(f(w(t),x_j),y_j)\Big\langle \frac{\partial f(w(t),x_i)}{\partial w}, \frac{\partial f(w(t),x_j)}{\partial w}\Big\rangle \\ \end{align} \tag{5} \\ dtdf(w(t),xi)=⟨∂w(t)∂f(w(t),xi),∂t∂w(t)⟩=⟨∂w(t)∂f(w(t),xi),−j=1∑n(f(w(t),xj)−yj)∂w∂f(w(t),xj)⟩=−j=1∑n(f(w(t),xj),yj)⟨∂w∂f(w(t),xi),∂w∂f(w(t),xj)⟩(5)
因为u(t)=(f(w(t),xi))i∈[n]∈Rnu(t)=(f(w(t),x_i))_{i\in[n]}\in\mathbb{R}^nu(t)=(f(w(t),xi))i∈[n]∈Rn是神经网络ttt时刻在所有xix_ixi上的输出,y=(yi)i∈[n]y=(y_i)_{i\in[n]}y=(yi)i∈[n]是标签。等式(5)可以紧凑的写作
du(t)dt=−H(t)⋅(u(t)−y)(6) \frac{du(t)}{dt}=-H(t)\cdot(u(t)-y) \tag{6} \\ dtdu(t)=−H(t)⋅(u(t)−y)(6)
其中H(t)∈Rn×nH(t)\in\mathbb{R}^{n\times n}H(t)∈Rn×n是定义为[H(t)]i,j=⟨∂f(w(t),xi)∂w,∂f(w(t),xj)∂w⟩(∀i,j∈[n])[H(t)]_{i,j}=\langle\frac{\partial f(w(t),x_i)}{\partial w},\frac{\partial f(w(t),x_j)}{\partial w} \rangle(\forall i,j\in[n])[H(t)]i,j=⟨∂w∂f(w(t),xi),∂w∂f(w(t),xj)⟩(∀i,j∈[n])。
上面引理涉及到矩阵H(t)H(t)H(t)。下面将会定义一个无限宽的神经网络,并固定训练数据。在这种限制下,训练过程中的矩阵H(t)H(t)H(t)为常数,即H(t)H(t)H(t)的等于H(0)H(0)H(0)。此外,对于随机初始化参数,当网络宽度为无限时,随机矩阵H(0)H(0)H(0)概率收敛至某个确定的核矩阵H∗H^*H∗,该矩阵就是通过训练数据估计出的神经正切核(Neural Tangent Kernel, NTK) k(⋅,⋅)k(\cdot,\cdot)k(⋅,⋅)。若对于所有ttt均有H(t)=H∗H(t)=H^*H(t)=H∗,那么等式(3)就变成
du(t)dt=−H∗⋅(u(t)−y)(7) \frac{d u(t)}{dt}=-H^*\cdot(u(t)-y) \tag{7} \\ dtdu(t)=−H∗⋅(u(t)−y)(7)
可以发现上述公式的动力学与梯度流下的核回归一致,那么当t→∞t\rightarrow\inftyt→∞时最终的预测函数为
f∗(x)=(k(x,x1),…,k(x,xn))⋅(H∗)−1y(8) f^*(x)=(k(x,x_1),\dots,k(x,x_n))\cdot(H^*)^{-1}y\tag{8} \\ f∗(x)=(k(x,x1),…,k(x,xn))⋅(H∗)−1y(8)
三、无限宽网络与神经正切核(NTK)
下面是一个简单的两层神经网络
f(a,W,x)=1m∑r=1marσ(wrTx)(9) f(a,W,x)=\frac{1}{\sqrt{m}}\sum_{r=1}^m a_r\sigma(w_r^Tx) \tag{9} \\ f(a,W,x)=m1r=1∑marσ(wrTx)(9)
其中mmm是网络的宽度,σ(⋅)\sigma(\cdot)σ(⋅)是激活函数。这里假设对于所有的z∈Rz\in\mathbb{R}z∈R,∣σ′(z)∣|\sigma'(z)|∣σ′(z)∣和∣σ′′(z)∣|\sigma''(z)|∣σ′′(z)∣的上界均为1,例如σ(z)=log(1+exp(z))\sigma(z)=\log(1+\exp(z))σ(z)=log(1+exp(z))就满足这个假设。假设所有的输入xxx的Euclidean范数均为1,即∥x∥2=1\parallel x\parallel_2=1∥x∥2=1。缩放因子1m\frac{1}{\sqrt{m}}m1在证明H(t)H(t)H(t)接近于固定核H∗H^*H∗上扮演者重要的角色。使用范式∥⋅∥2\parallel\cdot\parallel_2∥⋅∥2来衡量两个矩阵AAA和BBB的接近程度。
先计算H(0)H(0)H(0),并展示m→∞m\rightarrow\inftym→∞时H(0)H(0)H(0)收敛至固定矩阵H∗H^*H∗。 注意,∂f(a,W,xi)∂wr=1marxiσ′(wr⊤xi)\frac{\partial f(a,W,x_i)}{\partial w_r}=\frac{1}{\sqrt{m}}a_r x_i\sigma'(w_r^\top x_i)∂wr∂f(a,W,xi)=m1arxiσ′(wr⊤xi)。因此,H(0)H(0)H(0)中的每个元素为
[H(0)]ij=∑r=1m⟨∂f(a,W(0),xi)∂wr(0),∂f(a,W(0),xj)∂wr(0)⟩=∑r=1m⟨1marxiσ′(wr(0)⊤xi),1marxjσ′(wr(0)⊤xi)⟩=xi⊤xj⋅∑r=1mσ′(wr(0)⊤xi)σ′(wr(0)⊤xj)m(8) \begin{align} [H(0)]_{ij}&=\sum_{r=1}^m\Big\langle \frac{\partial f(a,W(0),x_i)}{\partial w_r(0)},\frac{\partial f(a,W(0),x_j)}{\partial w_r(0)} \Big\rangle \\ &=\sum_{r=1}^m\Big\langle\frac{1}{\sqrt{m}}a_rx_i\sigma'(w_r(0)^\top x_i),\frac{1}{\sqrt{m}}a_rx_j\sigma'(w_r(0)^\top x_i)\Big\rangle \\ &=x_i^\top x_j\cdot\frac{\sum_{r=1}^m\sigma'(w_r(0)^\top x_i)\sigma'(w_r(0)^\top x_j)}{m} \\ \end{align} \tag{8} \\ [H(0)]ij=r=1∑m⟨∂wr(0)∂f(a,W(0),xi),∂wr(0)∂f(a,W(0),xj)⟩=r=1∑m⟨m1arxiσ′(wr(0)⊤xi),m1arxjσ′(wr(0)⊤xi)⟩=xi⊤xj⋅m∑r=1mσ′(wr(0)⊤xi)σ′(wr(0)⊤xj)(8)
最后一步,由于ar∼Unif[{−1,1}]a_r\sim\text{Unif}[\{-1,1\}]ar∼Unif[{−1,1}],因此对于所有的r=1,…,mr=1,\dots,mr=1,…,m,有ar2=1a_r^2=1ar2=1。对于所有的wr(0)w_r(0)wr(0)都是从标准高斯分布中独立同分布采样出来的。因此,可以将[H(0)]ij[H(0)]_{ij}[H(0)]ij看做是m个独立同分布随机变量的平均值。若mmm很大,那么基于大数定律,这个平均值接近于随机变量的期望。在xix_ixi和xjx_jxj上由NTK评估的期望为:
Hij∗≜xi⊤xj⋅Ew∼N(0,I)[σ′(w⊤xi)σ′(wTxj)](9) H_{ij}^*\triangleq x_i^\top x_j\cdot\mathbb{E}_{w\sim N(0,I)}[\sigma'(w^\top x_i)\sigma'(w^T x_j)] \tag{9} \\ Hij∗≜xi⊤xj⋅Ew∼N(0,I)[σ′(w⊤xi)σ′(wTxj)](9)
基于Hoeffding不等式和Boole不等式,可以容易得知H(0)H(0)H(0)逼近于H∗H^*H∗。
引理2
对于某个ϵ>0\epsilon>0ϵ>0。若m=Ω(n4log(n/δ)ϵ2)m=\Omega(\frac{n^4\log(n/\delta)}{\epsilon^2})m=Ω(ϵ2n4log(n/δ)),那么w1(0),…,wm(0)w_1(0),\dots,w_m(0)w1(0),…,wm(0)至少以概率1−δ1-\delta1−δ满足
∥H(0)−H∗∥2≤ϵ \parallel H(0)-H^*\parallel_2\leq\epsilon \\ ∥H(0)−H∗∥2≤ϵ
证明。对于分量(i,j)(i,j)(i,j),由于∣σ′(z)∣≤1|\sigma'(z)|\leq 1∣σ′(z)∣≤1且∥x∥=1\parallel x\parallel=1∥x∥=1,那么有
∣xi⊤xjσ′(wt(0)⊤xi)σ′(wr(0)⊤xj)∣≤1 |x_i^\top x_j\sigma'(w_t(0)^\top x_i)\sigma'(w_r(0)^\top x_j)|\leq 1 \\ ∣xi⊤xjσ′(wt(0)⊤xi)σ′(wr(0)⊤xj)∣≤1
因此,[H(0)]ij[H(0)]_{ij}[H(0)]ij的边界为[0,1][0,1][0,1]。应用Hoeffding不等式,有
P(∣[H(0)]ij−Hij∗∣≥ϵn2)≤exp(−2m2(ϵn2)2∑i=1m(1−0)2)=exp(−2mϵ2n4)≤exp(−2ϵ2n4n4log(n/δ)ϵ2)=exp(−2log(n/δ))=δ2n2≤δn2 \begin{align} P\Big(|[H(0)]_{ij}-H_{ij}^*|\geq \frac{\epsilon}{n^2}\Big)&\leq \exp(-\frac{2m^2(\frac{\epsilon}{n^2})^2}{\sum_{i=1}^m(1-0)^2}) \\ &=\exp(-\frac{2m\epsilon^2}{n^4}) \\ &\leq\exp(-\frac{2\epsilon^2}{n^4}\frac{n^4\log(n/\delta)}{\epsilon^2}) \\ &=\exp(-2\log(n/\delta)) \\ &=\frac{\delta^2}{n^2}\leq\frac{\delta}{n^2} \\ \end{align} \\ P(∣[H(0)]ij−Hij∗∣≥n2ϵ)≤exp(−∑i=1m(1−0)22m2(n2ϵ)2)=exp(−n42mϵ2)≤exp(−n42ϵ2ϵ2n4log(n/δ))=exp(−2log(n/δ))=n2δ2≤n2δ
(注:nnn是训练样本数,mmm是网络宽度)那么有
P(∣[H(0)]ij−Hij∗∣≤ϵn2)=1−P(∣[H(0)]ij−Hij∗∣≥ϵn2)≥1−δn2 \begin{align} P\Big(|[H(0)]_{ij}-H_{ij}^*|\leq \frac{\epsilon}{n^2}\Big)&=1-P\Big(|[H(0)]_{ij}-H_{ij}^*|\geq \frac{\epsilon}{n^2}\Big)\geq 1-\frac{\delta}{n^2} \\ \end{align} \\ P(∣[H(0)]ij−Hij∗∣≤n2ϵ)=1−P(∣[H(0)]ij−Hij∗∣≥n2ϵ)≥1−n2δ
将上面的结论应用在所有(i,j)∈[n]×[n](i,j)\in[n]\times[n](i,j)∈[n]×[n],并使用Boole不等式
∥H(0)−H∗∥2≤∥H(0)−H∗∥F≤∑ij∣[H(0)]ij−Hij∗∣≤n2⋅ϵn2=ϵ \parallel H(0)-H^* \parallel_2\leq\parallel H(0)-H^* \parallel_F\leq\sum_{ij}|[H(0)]_{ij}-H_{ij}^*|\leq n^2\cdot\frac{\epsilon}{n^2}=\epsilon \\ ∥H(0)−H∗∥2≤∥H(0)−H∗∥F≤ij∑∣[H(0)]ij−Hij∗∣≤n2⋅n2ϵ=ϵ
接下来证明在训练过程中,H(t)H(t)H(t)逼近H(0)H(0)H(0)。
引理3
假设对于所有的i=1,…,ni=1,\dots,ni=1,…,n都有yi=O(1)y_i=O(1)yi=O(1)。给定t>0t>0t>0,对任意的0≤τ≤t0\leq\tau\leq t0≤τ≤t,所有的i=1,…,ni=1,\dots,ni=1,…,n都有ui(τ)=O(1)u_i(\tau)=O(1)ui(τ)=O(1)。若m=Ω(n6t2ϵ2)m=\Omega(\frac{n^6t^2}{\epsilon^2})m=Ω(ϵ2n6t2),有
∥H(t)−H(0)∥2≤ϵ \parallel H(t)-H(0) \parallel_2\leq\epsilon \\ ∥H(t)−H(0)∥2≤ϵ
(直观解释:若所有样本的标签值均不大于1,且0到ttt时刻中的任意时刻τ\tauτ,模型的预测值也不大于1。那么当网络宽度mmm大于n6t2ϵ2\frac{n^6t^2}{\epsilon^2}ϵ2n6t2时,ttt时刻的NTK核逼近于初始的NTK核)。 证明。第一个关键思想是:当mmm很大时,每个权重向量变化量很小。下面是单个权重向量的变化
∥wr(t)−wr(0)∥2=∥∫0tdwr(τ)dτdτ∥2=∥∫0t∑i=1n(ui(τ)−yi)∂ui(τ)∂wdτ∥2=∥∫0t∑i=1n(ui(τ)−yi)1marxiσ′(wr(τ)⊤xi)dτ∥2≤1m∫∥∑i=1n(ui(τ)−yi)arxiσ′(wr(τ)⊤xi)∥2dτ≤1m∑i=1n∫0t∥ui(τ)−yiarxiσ′(wr(τ)⊤xi)∥2dτ≤1m∑i=1n∫0tO(1)dτ=O(tnm) \begin{align} \parallel w_r(t)-w_r(0) \parallel_2&=\Big\| \int_{0}^t\frac{dw_r(\tau)}{d\tau}d\tau \Big\|_2 \\ &=\Big\|\int_{0}^t \sum_{i=1}^n(u_i(\tau)-y_i)\frac{\partial u_i(\tau)}{\partial w} d\tau \Big\|_2 \\ &=\Big\| \int_{0}^t\sum_{i=1}^n(u_i(\tau)-y_i)\frac{1}{\sqrt{m}}a_rx_i\sigma'(w_r(\tau)^\top x_i) d\tau \Big\|_2 \\ &\leq\frac{1}{\sqrt{m}}\int\Big\|\sum_{i=1}^n(u_i(\tau)-y_i)a_rx_i\sigma'(w_r(\tau)^\top x_i) \Big\|_2d\tau \\ &\leq\frac{1}{\sqrt{m}}\sum_{i=1}^n\int_{0}^t\| u_i(\tau)-y_ia_rx_i\sigma'(w_r(\tau)^\top x_i) \|_2 d\tau \\ &\leq\frac{1}{\sqrt{m}}\sum_{i=1}^n\int_{0}^t O(1) d\tau=O(\frac{tn}{\sqrt{m}}) \\ \end{align} \\ ∥wr(t)−wr(0)∥2= ∫0tdτdwr(τ)dτ 2= ∫0ti=1∑n(ui(τ)−yi)∂w∂ui(τ)dτ 2= ∫0ti=1∑n(ui(τ)−yi)m1arxiσ′(wr(τ)⊤xi)dτ 2≤m1∫ i=1∑n(ui(τ)−yi)arxiσ′(wr(τ)⊤xi) 2dτ≤m1i=1∑n∫0t∥ui(τ)−yiarxiσ′(wr(τ)⊤xi)∥2dτ≤m1i=1∑n∫0tO(1)dτ=O(mtn)
上面的结果表明:给定任意ttt,只要mmm足够大,则wr(t)w_r(t)wr(t)就接近于wr(0)w_r(0)wr(0)。下面将证明这意味着核矩阵H(t)H(t)H(t)接近于H(0)H(0)H(0)。这里证明单个分量的差距
[H(t)]ij−[H(0)]ij=∣1m∑r=1m(σ′(wr(t)⊤xi)σ′(wr(t)⊤xj)−σ′(wr(0)⊤xi)σ′(wr(0)⊤xj))∣≤1m∑r=1m∣σ′(wr(t)⊤xi)(σ′(wr(t)⊤xj)−σ′(wr(0)⊤xj))∣+1m∑r=1m∣σ′(wr(0)⊤xj)(σ′(wr(t)⊤xj)−σ′(wr(0)⊤xi))∣≤1m∑r=1m∣maxrσ′(wr(t)⊤xi)∥xi∥2∥wr(t)−wr(0)∥2∣+1m∑r=1m∣maxrσ′(wr(t)⊤xi)∥xi∥2∥wr(t)−wr(0)∥2∣=1m∑r=1mO(tnm) \begin{align} &[H(t)]_{ij}-[H(0)]_{ij} \\ =&\Big| \frac{1}{m}\sum_{r=1}^m\Big( \sigma'(w_r(t)^\top x_i)\sigma'(w_r(t)^\top x_j)- \sigma'(w_r(0)^\top x_i)\sigma'(w_r(0)^\top x_j)\Big) \Big| \\ \leq&\frac{1}{m}\sum_{r=1}^m\Big|\sigma'(w_r(t)^\top x_i)(\sigma'(w_r(t)^\top x_j)-\sigma'(w_r(0)^\top x_j)) \Big| \\ &+\frac{1}{m}\sum_{r=1}^m\Big|\sigma'(w_r(0)^\top x_j)(\sigma'(w_r(t)^\top x_j)-\sigma'(w_r(0)^\top x_i)) \Big| \\ \leq&\frac{1}{m}\sum_{r=1}^m\Big|\max_r \sigma'(w_r(t)^\top x_i)\|x_i\|_2\| w_r(t)-w_r(0) \|_2 \Big| \\ &+\frac{1}{m}\sum_{r=1}^m\Big|\max_r \sigma'(w_r(t)^\top x_i)\|x_i\|_2\| w_r(t)-w_r(0) \|_2 \Big| \\ =&\frac{1}{m}\sum_{r=1}^m O(\frac{tn}{\sqrt{m}}) \\ \end{align} \\ =≤≤=[H(t)]ij−[H(0)]ij m1r=1∑m(σ′(wr(t)⊤xi)σ′(wr(t)⊤xj)−σ′(wr(0)⊤xi)σ′(wr(0)⊤xj)) m1r=1∑m σ′(wr(t)⊤xi)(σ′(wr(t)⊤xj)−σ′(wr(0)⊤xj)) +m1r=1∑m σ′(wr(0)⊤xj)(σ′(wr(t)⊤xj)−σ′(wr(0)⊤xi)) m1r=1∑m rmaxσ′(wr(t)⊤xi)∥xi∥2∥wr(t)−wr(0)∥2 +m1r=1∑m rmaxσ′(wr(t)⊤xi)∥xi∥2∥wr(t)−wr(0)∥2 m1r=1∑mO(mtn)
因此,有
∥H(t)−H(0)∥2≤∑i,j∣[H(t)]ij−[H(0)]ij∣=O(tn3m) \| H(t)-H(0)\|_2\leq\sum_{i,j}\Big|[H(t)]_{ij}-[H(0)]_{ij} \Big|=O\Big(\frac{tn^3}{\sqrt{m}}\Big) \\ ∥H(t)−H(0)∥2≤i,j∑ [H(t)]ij−[H(0)]ij =O(mtn3)
四、用NTK解释无限宽网络的优化和泛化
基于上面的结论有
du(t)dt≈−H∗⋅(u(t)−y)(10) \frac{du(t)}{d_t}\approx -H^*\cdot(u(t)-y) \tag{10}\\ dtdu(t)≈−H∗⋅(u(t)−y)(10)
其中H∗H^*H∗是NTK矩阵。接下来基于该近似分析无限宽神经网络的优化和泛化。
1. 优化
U(t)U(t)U(t)的动力学遵循
du(t)dt=−H∗⋅(u(t)−y)(11) \frac{du(t)}{d_t}= -H^*\cdot(u(t)-y) \tag{11}\\ dtdu(t)=−H∗⋅(u(t)−y)(11)
本质上是线性动力系统。对H∗H^*H∗进行特征值分解的
H∗=∑i=1nλivivi⊤(12) H^*=\sum_{i=1}^n\lambda_i v_i v_i^\top \tag{12}\\ H∗=i=1∑nλivivi⊤(12)
其中λ1≥⋯≥λn≥0\lambda_1\geq\dots\geq\lambda_n\geq 0λ1≥⋯≥λn≥0是特征值,v1,…,vnv_1,\dots,v_nv1,…,vn是特征向量。基于该分解可以分别研究u(t)u(t)u(t)在每个特征向量上的动力学。对等式(12)两边同时乘以viv_ivi得,得到u(t)u(t)u(t)在特征向量viv_ivi上的动力学
dvi⊤u(t)dt=−vi⊤H∗⋅(u(t)−y)=−vi⊤∑i=1nλivivi⊤⋅(u(t)−y)=−λi(vi⊤(u(t)−y))(13) \begin{align} \frac{dv_i^\top u(t)}{dt}&=-v_i^\top H^*\cdot(u(t)-y) \\ &=-v_i^\top\sum_{i=1}^n\lambda_i v_i v_i^\top\cdot(u(t)-y) \\ &=-\lambda_i(v_i^\top(u(t)-y)) \\ \end{align} \tag{13}\\ dtdvi⊤u(t)=−vi⊤H∗⋅(u(t)−y)=−vi⊤i=1∑nλivivi⊤⋅(u(t)−y)=−λi(vi⊤(u(t)−y))(13)
可以看到vi⊤u(t)v_i^\top u(t)vi⊤u(t)的动力学仅依赖于其本身和λi\lambda_iλi,这其实是一个常微分方程。该常微分方程的一个解析解为
vi⊤(u(t)−y)=exp(−λit)(vi⊤(u(0)−y))(14) v_i^\top(u(t)-y)=\exp(-\lambda_i t)\Big(v_i^\top(u(0)-y) \Big) \tag{14}\\ vi⊤(u(t)−y)=exp(−λit)(vi⊤(u(0)−y))(14)
现在使用上面的等式来解释为什么可以找到0训练误差解。假设对于所有的i=1,…,ni=1,\dots,ni=1,…,n均有λi>0\lambda_i>0λi>0,即核矩阵的所有特征值均严格为正。
(u(t)−y)(u(t)-y)(u(t)−y)表示ttt时刻预测值和训练标签之间的差值。若当t→∞t\rightarrow\inftyt→∞,有u(t)−y→0u(t)-y\rightarrow 0u(t)−y→0时,表示存在一个训练误差为0的算法。等式(14)表示该差值的分量,由于项exp(−λit)\exp(-\lambda_i t)exp(−λit),所以vi⊤(u(t)−y)v_i^\top(u(t)-y)vi⊤(u(t)−y)会以指数级的速度收敛至0。此外,由于{v1,…,vn}\{v_1,\dots,v_n\}{v1,…,vn}是Rn\mathbb{R}^nRn上的一个正交基,因此(u(t)−y)=∑i=1nvi⊤(u(t)−y)(u(t)-y)=\sum_{i=1}^nv_i^\top(u(t)-y)(u(t)−y)=∑i=1nvi⊤(u(t)−y)。因此,当每个vi⊤(ui(t)−y)→0v_i^\top(u_i(t)-y)\rightarrow 0vi⊤(ui(t)−y)→0,可以得到(u(t)−y)→0(u(t)-y)\rightarrow 0(u(t)−y)→0。
等式(14)本质上给出了关于收敛相关的信息,即每个分量vi⊤(u(t)−y)v_i^\top(u(t)-y)vi⊤(u(t)−y)以不同的速率收敛至0。较大的λi\lambda_iλi对应的分量收敛到0的速度快于较小的λi\lambda_iλi。若期望在给定标签下能够更快的收敛,那么yyy投影至顶部的特征应该更大。因此,可以通过下面直观的来定性比较收敛速度
- 若标签集合yyy对齐至顶部特征,即(vi⊤y)(v_i^\top y)(vi⊤y)对应较大的特征值,那么梯度下降收敛较快;
- 若标签集合yyy投影至特征向量{(vi⊤y)}i=1n\{(v_i^\top y)\}_{i=1}^n{(vi⊤y)}i=1n是均匀分布,那么梯度下降的收敛速度就较慢;
2. 泛化
等式(10)中的近似意味着无限宽神经网络最终预测的函数近似于等式(8)的核预测函数。因此,可以使用核的泛化理论来分析无限宽神经网络的泛化行为。等式(8)中定义的核预测函数,使用Rademacher复杂度边界来推断下面1-Lipschitz损失函数的泛化边界
2y⊤(H∗)−1y⋅tr(H∗)n(15) \frac{\sqrt{2y^\top(H^*)^{-1}y\cdot tr(H^*)}}{n} \tag{15}\\ n2y⊤(H∗)−1y⋅tr(H∗)(15)
这是一个依赖于数据的复杂度度量的泛化误差上界。
五、多层全连接神经网络的NTK形式
先来定义全连接神经网络。令x∈Rdx\in\mathbb{R}^dx∈Rd表示输入,为了方便令g(0)(x)=xg^{(0)}(x)=xg(0)(x)=x且d0=dd_0=dd0=d。那么LLL层全连接神经网络表示为
f(h)(x)=W(h)g(h−1)(x)∈Rdh,g(h)(x)=cσdhσ(f(h)(x))∈Rdh(16) f^{(h)}(x)=W^{(h)}g^{(h-1)}(x)\in\mathbb{R}^{d_h},g^{(h)}(x)=\sqrt{\frac{c_{\sigma}}{d_h}}\sigma\Big(f^{(h)}(x)\Big)\in\mathbb{R}^{d_h} \tag{16}\\ f(h)(x)=W(h)g(h−1)(x)∈Rdh,g(h)(x)=dhcσσ(f(h)(x))∈Rdh(16)
其中h=1,2,…,Lh=1,2,\dots,Lh=1,2,…,L,W(h)∈Rdh×dh−1W^{(h)}\in\mathbb{R}^{d_h\times d_{h-1}}W(h)∈Rdh×dh−1表示第hhh层的权重矩阵,σ:R→R\sigma:\mathbb{R}\rightarrow\mathbb{R}σ:R→R是激活函数,cσ=(Ez∼N(0,1)[σz2])−1c_{\sigma}=\Big(E_{z\sim\mathcal{N}(0,1)}[\sigma z^2]\Big)^{-1}cσ=(Ez∼N(0,1)[σz2])−1。神经网络的最后一层来自于
f(w,x)=f(L+1)(x)=W(L+1)⋅g(L)(x)=W(L+1)⋅cσdLσW(L)⋅cσdL−1σW(L−1)⋯⋅cσd1σW(1)x(17) \begin{align} f(w,x)&=f^{(L+1)}(x)=W^{(L+1)}\cdot g^{(L)}(x) \\ &=W^{(L+1)}\cdot\sqrt{\frac{c_{\sigma}}{d_L}}\sigma W^{(L)}\cdot\sqrt{\frac{c_{\sigma}}{d_{L-1}}}\sigma W^{(L-1)}\dots \cdot\sqrt{\frac{c_{\sigma}}{d_1}}\sigma W^{(1)}x \end{align} \tag{17}\\ f(w,x)=f(L+1)(x)=W(L+1)⋅g(L)(x)=W(L+1)⋅dLcσσW(L)⋅dL−1cσσW(L−1)⋯⋅d1cσσW(1)x(17)
其中W(L+1)∈R1×dLW^{(L+1)}\in\mathbb{R}^{1\times d_L}W(L+1)∈R1×dL表示最后一层的权重,w=(W(1),…,W(L+1))w=\Big(W^{(1)},\dots,W^{(L+1)}\Big)w=(W(1),…,W(L+1))表示神经网络的所有权重。
使用标准正态分布来初始化权重并考虑hidden宽度的极限为:d1,d2,…,dL→∞d_1,d_2,\dots,d_L\rightarrow\inftyd1,d2,…,dL→∞。缩放因子cσ/dh\sqrt{c_{\sigma}/d_h}cσ/dh用于确保g(h)(x)g^{(h)}(x)g(h)(x)近似于初始化。对于ReLU集合函数,有
E[∥g(h)(x)∥22]=∥x∥22(∀h∈[L])(18) E\Big[\Big\| g^{(h)}(x) \Big\|_2^2\Big]=\|x\|_2^2(\forall h\in[L]) \tag{18} \\ E[
g(h)(x)
22]=∥x∥22(∀h∈[L])(18)
正如引理1中需要计算⟨∂f(w(t),x)∂w,∂f(w(t),x′)∂w⟩\langle\frac{\partial f(w(t),x)}{\partial w},\frac{\partial f(w(t),x')}{\partial w}\rangle⟨∂w∂f(w(t),x),∂w∂f(w(t),x′)⟩在无限宽下收敛至随机初始化。可以将关于特定权重矩阵W(h)W^{(h)}W(h)的偏导数写作
∂f(w,x)∂W(h)=b(h)(x)⋅(g(h−1)(x))⊤,h=1,2,…,L+1(19) \frac{\partial f(w,x)}{\partial W^{(h)}}=b^{(h)}(x)\cdot\Big(g^{(h-1)}(x)\Big)^\top,\quad h=1,2,\dots,L+1 \tag{19} \\ ∂W(h)∂f(w,x)=b(h)(x)⋅(g(h−1)(x))⊤,h=1,2,…,L+1(19)
其中
b(h)(x)={1∈R,h=L+1cσdhD(h)(x)(W(h+1))⊤b(h+1)(x)∈Rdh,h=1,…,L(20) b^{(h)}(x)=\begin{cases} 1\in\mathbb{R},& h=L+1 \\ \sqrt{\frac{c_\sigma}{d_h}}D^{(h)}(x)\Big(W^{(h+1)} \Big)^\top b^{(h+1)}(x)\in\mathbb{R}^{d_h},& h=1,\dots,L \end{cases} \tag{20} \\ b(h)(x)=⎩
⎨
⎧1∈R,dhcσD(h)(x)(W(h+1))⊤b(h+1)(x)∈Rdh,h=L+1h=1,…,L(20)
KaTeX parse error: Expected 'EOF', got '&' at position 93: …d_h\times d_h},&̲h=1,\dots,L \ta…
对于两个任意的输入xxx和x′x'x′,任意的h∈[L+1]h\in[L+1]h∈[L+1],可以计算
⟨∂f(w,x)∂W(h),∂f(w,x′)∂W(h)⟩=⟨b(h)(x)⋅(g(h−1)(x))⊤,b(h)(x′)⋅(g(h−1)(x′))⊤⟩=⟨g(h−1)(x),g(h−1)(x′)⟩⋅⟨b(h)(x),b(h)(x′)⟩(22) \begin{align} &\Big\langle\frac{\partial f(w,x)}{\partial W^{(h)}},\frac{\partial f(w,x')}{\partial W^{(h)}}\Big\rangle \\ =&\Big\langle b^{(h)}(x)\cdot\Big(g^{(h-1)}(x)\Big)^\top, b^{(h)}(x')\cdot\Big(g^{(h-1)}(x')\Big)^\top\Big\rangle \\ =&\langle g^{(h-1)}(x),g^{(h-1)}(x') \rangle\cdot\langle b^{(h)}(x),b^{(h)}(x') \rangle \\ \end{align} \tag{22}\\ ==⟨∂W(h)∂f(w,x),∂W(h)∂f(w,x′)⟩⟨b(h)(x)⋅(g(h−1)(x))⊤,b(h)(x′)⋅(g(h−1)(x′))⊤⟩⟨g(h−1)(x),g(h−1)(x′)⟩⋅⟨b(h)(x),b(h)(x′)⟩(22)
第一项⟨g(h−1)(x),g(h−1)(x′)⟩\langle g^{(h-1)}(x),g^{(h-1)}(x') \rangle⟨g(h−1)(x),g(h−1)(x′)⟩是xxx和x′x'x′在第hhh层的协方差。当宽度趋于无穷时,⟨g(h−1)(x),g(h−1)(x′)⟩\langle g^{(h-1)}(x),g^{(h-1)}(x') \rangle⟨g(h−1)(x),g(h−1)(x′)⟩收敛至固定的数,这里表示为Σ(h−1)(x,x′)\Sigma^{(h-1)}(x,x')Σ(h−1)(x,x′)。对于h∈[L]h\in[L]h∈[L],该协方差的递归形式为
Σ(0)(x,x′)=x⊤x′Λ(h)(x,x′)=(Σ(h−1)(x,x)Σ(h−1)(x,x′)Σ(h−1)(x′,x)Σ(h−1)(x′,x′))∈R2×2Σ(h)(x,x′)=cσE(u,v)∼N(0,Λ(h))[σ(u)σ(v)](23) \begin{align} \Sigma^{(0)}(x,x')&=x^\top x' \\ \Lambda^{(h)}(x,x')&= \begin{pmatrix} \Sigma^{(h-1)}(x,x)&\Sigma^{(h-1)}(x,x') \\ \Sigma^{(h-1)}(x',x)&\Sigma^{(h-1)}(x',x') \\ \end{pmatrix}\in\mathbb{R}^{2\times 2} \\ \Sigma^{(h)}(x,x')&=c_\sigma E_{(u,v)\sim\mathcal{N}(0,\Lambda^{(h)})}[\sigma(u)\sigma(v)] \end{align}\tag{23} \\ Σ(0)(x,x′)Λ(h)(x,x′)Σ(h)(x,x′)=x⊤x′=(Σ(h−1)(x,x)Σ(h−1)(x′,x)Σ(h−1)(x,x′)Σ(h−1)(x′,x′))∈R2×2=cσE(u,v)∼N(0,Λ(h))[σ(u)σ(v)](23)
更多推荐
所有评论(0)