当前位置: 首页 > 科技观察

只知道TF和PyTorch还不够,让我们看看如何从PyTorch切换到自动微分神器JAX

时间:2023-03-22 11:11:06 科技观察

说到现在的深度学习框架,我们往往绕不开TensorFlow和PyTorch。但是除了这两个框架之外,一些新的力量也不容小觑,其中之一就是JAX。它具有正向和反向自动微分,非常擅长计算高阶导数。这个崭露头角的框架有多大用处?如何用它来演示神经网络内部复杂的梯度更新和反向传播?本文是一篇教程贴,教你理解Jax的底层逻辑,让你更容易从PyTorch等学习迁移。Jax是谷歌开发的用于机器学习和数学计算的Python库。启动后,Jax将其定义为Python+NumPy包。它在TPU和GPU上具有微分、向量化、JIT语言等特性。简而言之,这是GPU版本的numpy,具有自动微分功能。甚至一些研究人员,如SkyeWanderman-Milne,也在去年的NeurlPS2019会议上介绍了Jax。然而,对于开发者来说,从已经熟悉的PyTorch或TensorFlow2.X转向Jax无疑是一个巨大的变化:两者在构建计算和反向传播的方式上有着根本的不同。PyTorch构建计算图并计算正向和反向传递。结果节点上的梯度由中间节点的梯度累加而成。另一方面,Jax允许您将计算表达为Python函数,并且grad()将它们转换为梯度函数,从而允许您对它们进行评估。但是它不是给出结果,而是给出结果的梯度。两者的对比如下:这样一来,你编程和构建模型的方式就不同了。因此,您可以使用基于磁带的自动微分方法并使用有状态对象。但是Jax可能会让您感到惊讶,因为它使微分过程在运行grad()函数时表现得像一个函数。也许您已经决定研究基于Jax的工具,例如flax、trax或haiku。查看ResNet等示例时,您会发现它与其他框架中的代码不同。除了定义层和运行训练之外,底层逻辑是什么样的?这些小的numpy程序如何训练一个庞大的架构?本文是一篇介绍Jax搭建模型的教程,机器之心选了其中的两个Sections:LSTM-LM在PyTorch上的应用速览;查看PyTorch风格的代码(基于变异状态),并查看纯函数如何构建模型(Jax);PyTorch上的LSTM语言模型我们首先在PyTorchModel中实现LSTM语言,代码如下:.weight_ih=torch.nn.Parameter(torch.rand(4*out_dim,in_dim))self.weight_hh=torch.nn.Parameter(torch.rand(4*out_dim,out_dim))self.bias=torch.nn.Parameter(torch.zeros(4*out_dim,))defforward(self,inputs,h,c):ifgo=self.weight_ih@inputs+self.weight_hh@h+self.biasi,f,g,o=torch.chunk(ifgo,4)i=torch.sigmoid(i)f=torch.sigmoid(f)g=torch.tanh(g)o=torch.sigmoid(o)new_c=f*c+i*gnew_h=o*torch。tanh(new_c)return(new_h,new_c)然后,我们基于这个LSTM神经元构建一个单层网络。将有一个嵌入层,它和可学习的(h,c)0将显示单个参数如何变化。classLSTMLM(torch.nn.Module):def__init__(self,vocab_size,dim=17):super().__init__()self.cell=LSTMCell(dim,dim)self.embeddings=torch.nn.Parameter(torch.rand(vocab_size,dim))self.c_0=torch.nn.Parameter(torch.zeros(dim))@propertydefhc_0(self):return(torch.tanh(self.c_0),self.c_0)defforward(self,seq,hc):loss=torch.tensor(0.)foridxinseq:loss-=torch.log_softmax(self.embeddings@hc[0],dim=-1)[idx]hc=self.cell(self.embeddings[idx,:],*hc)returnloss,hcdefgreedy_argmax(self,hc,length=6):withtorch.no_grad():idxs=[]foriinrange(length):idx=torch.argmax(self.embeddings@hc[0])idxs回复.append(idx.item())hc=self.cell(self.embeddings[idx,:],*hc)returnidxs构建后,进行训练:torch.manual_seed(0)#Astrainingdata,wewillhaveindicesofwords/wordpieces/characters,#我们假设它们被标记化和整数化(显然是玩具示例)。importjax.numpyasjnpvocab_size=43#primetric!:)training_data=jnp.array([4,8,15,16,23,42])lm=LSTMLM(vocab_sizevocab_size=vocab_size)print("Samplebefore:",升m.greedy_argmax(lm.hc_0))bptt_length=3#toillustratehc.detach-ingforepochinrange(101):hc=lm.hc_0totalloss=0.forstartinrange(0,len(training_data),bptt_length):batch=training_data[开始:开始+bptt_length]loss,(h,c)=lm(batch,hc)hc=(h.detach(),c.detach())ifepoch%50==0:totalloss+=loss.item()loss.backward()forname,paraminlm.named_pa??rameters():ifparam.gradisnotNone:param.data-=0.1*param.graddelparam.gradiftotalloss:print("Loss:",totalloss)print("Sampleafter:",lm.greedy_argmax(lm.hc_0))Samplebefore:[42,34,34,34,34,34]Loss:25.953862190246582Loss:3.7642268538475037Loss:1.9537211656570435Sampleafter:[4,8,15,16,23,42]可以看到,PyTorch的代码比较清晰,但仍然存在一些问题。虽然我很细心,但还是要注意计算图中的节点数量。那些中间节点需要适时清除。纯函数要理解JAX如何处理这个问题,我们首先需要理解纯函数的概念。如果您以前做过函数式编程,您可能熟悉纯函数就像数学中的函数或公式这样的概念。它定义了如何从某个输入值中获取输出值。重要的是,它没有“副作用”,即函数的任何部分都不会访问或更改任何全局状态。当我们在Pytorch中编写代码时,它充满了中间变量或状态,并且这些状态经常变化,这使得推理和优化非常棘手。因此,JAX选择将程序员限制在纯函数的范围内,防止上述情况的发生。在深入JAX之前,最好先看几个纯函数示例。纯函数必须满足以下条件:执行函数的时间和时间不应该影响输出——只要输入不改变,输出也不应该改变;无论我们执行函数0次、1次还是多次,之后应该是无法区分的。以下非常纯粹的数据都至少违反了上面描述的条款中的一条:importrandomimporttimenr_executions=0defpure_fn_1(x):return2*xdefpure_fn_2(xs):ys=[]forxinxs:#Mutatingstatefulvariables*inside*thefunctionisfine!ysdef_s.append(2)(xs):#Mutatingargumentshaslastingconsequencesoutsidethefunction!:(xs.append(sum(xs))returnxsdefimpure_fn_2(x):#Veryobviouslymutatingglobalstateisbad...globalnr_executionsnr_executions+=1return2*xdefimpure_fn_3(x):#...但也只是访问它,因为现在函数执行依赖于#executionrreturn!*xdefimpure_fn_4(x):#ThingslikeIOareclassicexamplesofimpurity.#Allthreeofthefollowinglinesareviolationsofpurity:print("Hello!")user_input=input()execution_time=time.time()return2*xdefimpure_fn_5(x):#Whichconstraintdoesthisviolate?两者,实际上!你访问当前#stateofrandomness*和*advancethenumbergenerator!p=random.random()returnp*x让我们看看JAX操作的一个纯函数:简介中的例子。#(几乎)1-D线性回归deff(w,x):returnw*xprint(f(13.,42.))546.0到目前为止什么都没有发生JAX现在允许您将以下函数转换为另一个函数,而不是返回结果,返回函数结果梯度为函数的第一个参数。importjaximportjax.numpyasjnp#Gradient:withrespecttoweights!JAX默认使用第一个参数.df_dw=jax.grad(f)defmanual_df_dw(w,x):returnxassertdf_dw(13.,42.)==manual_df_dw(13.,42.)print(df.dw(13).,42.)42.))42.0到此为止,JAX的README文档中前面的内容你大概都看过了,内容也很合理。但是如何跳转到像PyTorch代码中那样的大模块呢?首先,让我们添加一个偏差项,并尝试将一维线性回归变量包装到我们习惯的对象中——一个线性回归“层”(LinearRegressor"layer"):classLinearRegressor():def__init__(self,w,b):self.w=wself.b=bdefpredict(self,x):returnsself.w*x+self.bdefrms(self,xs:jnp.ndarray,ys:jnp.ndarray):返回jnp.sqrt(jnp.sum(jnp.square(self.w*xs+self.b-ys)))my_regressor=LinearRegressor(13.,0.)#Akindoflossfuction,usedfortrainingxs=jnp.array([42.0])ys=jnp.array([500.0])print(my_regressor.rms(xs,ys))#Predictionfortestdataprint(my_regressor.predict(42.))46.0546.0接下来怎么用梯度训练呢?我们需要一个纯函数,它将我们的模型权重作为函数的输入参数,它可能看起来像这样:defloss_fn(w,b,xs,ys):my_regressor=LinearRegressor(w,b)returnmy_regressor.rms(xsxs=xs,ysys=ys)#Weuseargnums=(0,1)totellJAXtogiveus#gradientswrtfirstandsecondparameter.grad_fn=jax.grad(loss_fn,argnums=(0,1))print(loss_fn(13.,0.,xs,ys))print(grad_fn(13.,0.,xs,ys))46.0(DeviceArray(42.,dtype=float32),DeviceArray(1.,dtype=float32))你必须说服自己这是对的。现在,这是可行的,但显然在loss_fn的定义部分枚举所有参数是不可行的。幸运的是,JAX不仅可以区分标量、向量、矩阵,还可以区分许多树状数据结构。Thisstructureiscalledapytree,includingpythondicts:defloss_fn(params,xs,ys):my_regressor=LinearRegressor(params['w'],params['b'])returnmy_regressor.rms(xsxs=xs,ysys=ys)grad_fn=jax.grad(loss_fn)print(loss_fn({'w':13.,'b':0.},xs,ys))print(grad_fn({'w':13.,'b':0.},xs,ys))46.0{'b':DeviceArray(1.,dtype=float32),'w':DeviceArray(42.,dtype=float32)}Youcanwriteatraininglooplikethis:params={'w':13.,'b':0.}for_inrange(15):print(loss_fn(params,xs,ys))grads=grad_fn(params,xs,ys)fornameinparams.keys():params[name]-=0.002*grads[name]#Now,predict:LinearRegressor(params['w'],params['b']).predict(42.)46.042.4700338.94000235.41003431.88006628.35009824.82006821.290117.76013214.23016410.7001657.1701663.64019780.1101989753.4197998DeviceArray(500.1102,dtype=float32)注意,现在已经可以使用更多的JAXhelper来进行自我更新:由于参数和梯度拥有共同的(类似树的)结构,Wecanimagineputtingthemontop,creatinganewtreewhosevalueiseverywherea"combination"ofthesetwotrees,likethis:defupdate_combiner(param,grad,lr=0.002):returnparam-lr*gradparams=jax.tree_multimap(update_combiner,params,grads)#insteadof:#fornameinparams.keys():#params[name]-=0.1*grads[name]参考链接:https://sjmielke.com/jax-purify.htm【本文是《机器之心》专栏的原文翻译,微信公众号《机器之心(id:almosthuman2014)》】点此查看作者更多好文