当前位置: 首页 > 科技观察

Github13000stars,JAX相对于TensorFlow、PyTorch的快速发展

时间:2023-03-11 20:59:21 科技观察

在机器学习领域,大家可能对TensorFlow和PyTorch比较熟悉,但是除了这两个框架之外,一些新的力量也不容小觑,它是谷歌推出的JAX。很多研究者对它寄予厚望,希望它能够替代TensorFlow等众多机器学习框架。JAX最初由GoogleBrain团队的MattJohnson、RoyFrostig、DougalMaclaurin和ChrisLeary等人发起。目前,JAX在GitHub上已经积累了13700颗星。项目地址:https://github.com/google/jaxJAX的快速发展,JAX的前身是Autograd,它使用Autograd的更新版本,结合XLA进行Python程序自动微分和NumPy运算,支持循环和分支,递归和闭包函数推导,三阶导数也可以得到;依赖XLA,JAX可以在GPU和TPU上编译运行NumPy程序;通过grad,可以支持自动模式反向传播和前向传播,并且两者可以任意顺序组合。开发JAX的出发点是什么?说到这里,就不得不提一下NumPy。NumPy是Python中的基础数值运算库,应用广泛。但是,numpy不支持GPU或其他硬件加速器,也没有内置的反向传播支持。此外,Python本身的速度限制阻碍了NumPy的使用,因此很少有研究人员直接使用numpy在生产环境中训练或部署深度学习模型。.在这种情况下,出现了众多的深度学习框架,例如PyTorch、TensorFlow等。但是,numpy具有灵活性、调试方便、API稳定等独特优势。JAX的主要出发点就是将numpy的上述优势与硬件加速结合起来。目前,有很多基于JAX的优秀开源项目。例如,Google的神经网络库团队开发了Haiku,这是一个用于Jax的深度学习代码库。通过Haiku,用户可以对Jax进行面向对象的开发;另一个例子是RLax,这是一个基于Jax的强化学习库,用户可以使用RLax来构建和训练Q-learning模型;此外,它还包括基于JAX的深度学习库JAXnet,可以用一行代码定义计算图,并可以进行GPU加速。可以说,在过去的几年里,JAX掀起了一场深度学习研究的风暴,推动了科学研究的快速发展。JAX安装如何使用JAX?首先需要在Python环境或者Googlecolab中安装JAX,使用pip安装:$pipinstall--upgradejaxjaxlib注意上面的安装方式只支持在CPU上运行。如果要在GPU上执行程序,首先需要有CUDA、cuDNN,然后运行以下命令(确保将jaxlib版本映射到CUDA版本):$pipinstall--upgradejaxjaxlib==0.1.61+cuda110-fhttps://storage.googleapis.com/jax-releases/jax_releases.html现在integrateJAXwithNumpy一起导入:importjaximportjax.numpyasjnpimportnumpyasnpJAX的一些特性使用grad()函数进行自动微分:这个对深度学习很有用应用程序,以便可以轻松运行反向传播,下面是一个简单的二次函数和点1.0推导示例:fromjaximportgraddeff(x):return3*x**2+2*x+5deff_prime(x):return6*x+2grad(f)(1.0)#DeviceArray(8.,dtype=float32)f_prime(1.0)#8.0jit(Justintime):为了利用XLA的强大特性,必须将代码编译到XLA内核中。这就是jit发挥作用的地方。要将XLA与jit一起使用,用户可以使用jit()函数或@jit注释。fromjaximportjitx=np.random.rand(1000,1000)y=jn??p.array(x)deff(x):for_inrange(10):x=0.5*x+0.1*jnp.sin(x)返回xg=jit(f)%timeit-n5-r5f(y).block_until_ready()#5loops,bestof5:10.8msperloop%timeit-n5-r5g(y).block_until_ready()#5loops,bestof5:341μsperlooppmap:自动将计算分配给所有当前设备,并且处理它们之间的所有通信。JAX通过pmap转换支持大规模数据并行,从而处理单个处理器无法处理的大数据。要检查可用设备,您可以运行jax.devices():fromjaximportpmapdeff(x):returnjnp.sin(x)+x**2f(np.arange(4))#DeviceArray([0.,1.841471,4.9092975,9.14112],dtype=float32)pmap(f)(np.arange(4))#ShardedDeviceArray([0.,1.841471,4.9092975,9.14112],dtype=float32)vmap:是一个函数转换,JAX大大提供了自动Vectorization算法简化这种类型的计算,使研究人员在研究新算法时不必处理批处理。示例如下:fromjaximportvmapdeff(x):returnjnp.square(x)f(jnp.arange(10))#DeviceArray([0,1,4,9,16,25,36,49,64,81],dtype=int32)vmap(f)(jnp.arange(10))#DeviceArray([0,1,4,9,16,25,36,49,64,81],dtype=int32)TensorFlowvsPyTorchvsJax深入学习领域有几家巨头公司,他们提出的框架被大量研究人员使用。比如谷歌的TensorFlow、Facebook的PyTorch、微软的CNTK、亚马逊AWS的MXnet等,每一个框架都有优缺点,选择的时候需要根据自己的需求来选择。我们以Python中的3大深度学习框架——TensorFlow、PyTorch和Jax为例进行比较。尽管这些框架不同,但它们有两个共同点:它们都是开源的。这意味着如果库中存在错误,用户可以在GitHub中发布问题(并修复),而且您可以将自己的功能添加到库中;由于全局解释器锁,Python在内部运行缓慢。所以这些框架使用C/C++作为后端来处理所有的计算和并行过程。那么它们有什么区别呢?如下表所示,是TensorFlow、PyTorch、JAX这三个框架的对比。TensorFlowTensorFlow是由谷歌开发的。最初的版本可以追溯到2015年开源的TensorFlow0.1,之后稳步发展,拥有强大的用户群,成为最流行的深度学习框架。然而,用户在使用的过程中,也暴露出TensorFlow的不足,比如API稳定性不足、静态计算图编程复杂等。因此,在TensorFlow2.0版本中,Google收录了Keras,成为tf.keras。目前,TensorFlow的主要特点包括:这是一个非常友好的框架,先进的API-Keras的可用性使得模型层定义、损失函数和模型创建变得非常容易;TensorFlow2.0有EagerExecution(动态图机制),使库更加人性化,是对之前版本的重大升级;Keras的高层接口有一定的缺点,因为TensorFlow抽象了很多底层机制(只是为了方便最终用户),这使得研究人员处理模型变得困难。自由度较少;Tensorflow提供了TensorBoard,其实就是Tensorflow可视化工具包。它允许研究人员可视化损失函数、模型图、模型分析等。PyTorchPyTorch(Python-Torch)是来自Facebook的机器学习库。TensorFlow还是PyTorch?一年前,这个问题没有争议,大多数研究人员会选择TensorFlow。但现在情况大不相同,越来越多的研究人员使用PyTorch。PyTorch的一些最重要的特性包括:与TensorFlow不同,PyTorch使用动态类型图,这意味着执行图是动态创建的。它允许我们随时修改和检查图的内部结构;除了用户友好的高级API之外,PyTorch还包括精心构建的低级API,允许对机器学习模型进行越来越多的控制。我们可以在训练过程中检查和修改模型的前向和后向传递过程中的输出。这已被证明对梯度裁剪和神经风格迁移非常有效;PyTorch允许用户扩展代码,可以轻松添加新的损失函数和用户定义的层。PyTorch的Autograd模块实现了深度学习算法中的反向传播导数。对于Tensor类的所有操作,Autograd可以自动提供微分,简化了手动计算导数的复杂过程;PyTorch具有使用数据并行性和GPU的优势。广泛支持;PyTorch比TensorFlow更Pythonic。PyTorch非常适合Python生态系统,允许使用类似Python的调试器工具调试PyTorch代码。JAXJAX是来自Google的一个相对较新的机器学习库。它更像是一个区分原生python和NumPy代码的autograd库。JAX的一些特性主要包括:如官网所述,JAX能够对Python+NumPy程序进行可组合的转换:向量化、JIT到GPU/TPU等;与PyTorch相比,JAX最重要的方面是什么?计算梯度。在Torch中,图形是在正向传递过程中创建的,梯度是在反向传递过程中计算的,另一方面,在JAX中,计算表示为函数。在函数上使用grad()返回一个梯度函数,该函数直接计算给定输入的函数的梯度;JAX是一个autograd工具,不建议单独使用。有各种基于JAX的机器学习库,其中值得注意的是ObJax、Flax和Elegy。由于它们都使用相同的核心并且接口只是JAX库的包装器,因此它们可以放在同一个支架下;Flax最初是在PyTorch生态下开发的,更注重使用的灵活性。另一方面,Elegy的灵感来自Keras。ObJAX主要是为研究目的而设计的,更加注重简单性和可理解性。