当前位置: 首页 > 科技观察

2022年,我应该使用JAX吗?GitHub上有16k星,这个年轻的工具并不完美

时间:2023-03-15 19:33:23 科技观察

JAX自2018年底推出以来,人气一直在稳步上升。2020年,DeepMind宣布使用JAX来加速其研究。来自GoogleBrain和其他机构的越来越多的项目也在使用JAX。目前,在JAX的GitHub项目首页,Star数已达16.3k。项目地址:https://github.com/google/jaxJAX是一个非常有前途的项目,用户一直在稳步增长。JAX已广泛应用于深度学习、机器人/控制系统、贝叶斯方法和科学模拟等众多领域。那么,这是否意味着JAX也将成为下一个重要的深度学习框架?近日,在AssemblyAI博客上发表的文章《Why You Should (or Shouldn't) Be Using JAX in 2022》中,作者RyanO'Connor为我们深入解读了JAX的概念,使用JAX的原因,以及是否应该使用JAX。JAX简介JAX不是深度学习框架或库,也不是设计成这样的。简而言之,JAX是一个包含可组合函数转换的数值计算库。我们可以看到,深度学习只是JAX能力的一小部分:JAX的定位是科学计算(ScientificComputing)和函数转换(FunctionTransformations)的交叉融合,拥有训练深度学习模型以外的一系列能力,包括以下:即时编译自动并行化自动矢量化自动微分使用JAX的原因是什么?简而言之,速度。这是与任何用例相关的JAX的一般功能。让我们使用NumPy和JAX对矩阵的前三个幂(按元素)求和。首先是NumPy实现。我们发现这个计算大约需要851毫秒。然后使用JAX实现计算:JAX执行计算仅需5.54毫秒,比NumPy快150多倍。JAX比NumPy快N个数量级。注意JAX使用TPU而NumPy使用CPU,强调JAX的速度上限比NumPy高得多。作者列出了您可能想要使用JAX的以下六个原因:NumPyAccelerator。NumPy是使用Python进行科学计算的基础包之一,但它只与CPU兼容。JAX提供了一个非常容易在GPU和TPU上运行的NumPy实现(具有几乎相同的API)。对于许多用户来说,仅此一项功能就足以证明使用JAX是合理的;XLA。XLA(AcceleratedLinearAlgebra)是专为线性代数设计的全程序优化编译器。JAX建立在XLA之上,显着提高了计算速度的上限;准时制。JAX允许用户使用XLA将他们自己的函数转换为即时编译(JIT)版本。这意味着通过在计算函数中添加一个简单的函数装饰器,可以将计算速度提高几个数量级;自动微分。JAX结合了Autograd(自动区分原生Python代码和NumPy代码)和XLA,其自动区分能力在科学计算的许多领域都至关重要。JAX提供了几个强大的自动微分工具;深度学习。虽然JAX本身不是深度学习框架,但它确实为深度学习提供了良好的基础。许多基于JAX构建的库旨在提供深度学习功能,包括Flax、Haiku和Elegy。甚至一些最近的PyTorch和TensorFlow文章也强调JAX是一个值得关注的“框架”,并推荐它用于基于TPU的深度学习研究。JAX对Hessian的高效计算也与深度学习相关,因为它们使高阶优化技术更加可行;一般可微分编程范式。虽然我们可以使用JAX来构建和训练深度学习模型,但它也为通用的可微分编程提供了一个框架。这意味着JAX可以通过使用基于模型的机器学习方法来解决问题,该方法可以利用数十年来研究积累的给定领域的先验知识。JAX转换到目前为止,我们已经讨论了XLA以及它如何允许JAX在加速器上实现NumPy;但请记住,这只是JAX定义的一半。JAX不仅为强大的科学计算提供了工具,还为可组合的函数转换提供了工具。例如,如果我们将梯度函数变换应用于标量值函数f(x),那么我们将得到一个向量值函数f'(x),它给出函数在f(X)。在函数上使用grad()可以让我们获得域中任意点的梯度。JAX包含一个可扩展的系统来实现这种功能转换。典型的方式有四种:Grad()用于自动微分;Vmap()用于自动矢量化;Pmap()并行化计算;Jit()将函数转换为即时编译版本。使用grad()自动微分训练机器学习模型需要反向传播。在JAX中,就像在Autograd中一样,用户可以使用grad()函数来计算梯度。例如,下面是函数f(x)=abs(x^3)的导数。我们可以看到,在x=2和x=-3处评估函数及其导数时,我们得到了预期的结果。那么grad()到底能区分到什么程度呢?JAX通过重复应用grad()使微分变得容易,正如我们在下面的程序中看到的,输出函数的三阶导数给出恒定的预期输出f'''(x)=6。可能有人会问,grad()有什么用呢?标量值函数:grad()采用标量值函数的梯度,将标量/向量映射到标量函数。还有向量值函数:对于将向量映射到向量的向量值函数,梯度的类比是雅可比行列式。使用jacfwd()和jacrev(),JAX返回一个函数,当在域中的一个点计算时,生成雅可比行列式。从深度学习的角度来看,JAX使计算Hessians变得非常简单和高效。由于XLA,JAX可以比PyTorch更快地计算Hessians,这使得实现像AdaHessian这样的高阶优化更快。下面的代码在PyTorch中对输入的简单求和进行Hessian:如我们所见,上述计算大约需要16.3毫秒,在JAX中尝试相同的计算:使用JAX,计算仅需1.55毫秒,比PyTorch10x+快:JAX可以非常快速地计算Hessians,使高阶优化更加可行。使用vmap()自动矢量化JAX在其API中有另一个转换:vmap()自动矢量化。下面是向量化向量加法的演示:使用pmap()自动并行化分布式计算越来越重要,尤其是在深度学习中。如下图所示,SOTA模型已经发展到非常大的规模。得益于XLA,JAX可以轻松地在加速器上执行计算,但JAX也可以轻松地使用多个加速器进行计算,即通过单个命令-pmap()来执行SPMD程序的分布式训练。让我们以向量矩阵乘法为例,非并行向量矩阵乘法如下:使用JAX,我们可以通过简单地将操作包装在pmap()中,轻松地将这些计算分布在4个TPU上。这允许用户同时在每个TPU上执行点积,从而显着提高计算速度(对于大型计算)。使用jit()来加速功能性JIT编译是一种执行介于解释和AoT(提前)编译之间的代码的方法。重要的是,JIT编译器在运行时将代码编译成快速的可执行文件,代价是首次运行速度较慢。JIT不是一次将一个操作分配给GPU内核,而是使用XLA将一系列操作编译到单个内核中,从而为函数的端到端编译提供高效的XLA实现。例如,在下图中,代码定义了一个以三种方式计算5000x5000矩阵的函数-一次使用NumPy,一次使用JAX,一次在函数的JIT编译版本上使用JAX。我们首先在CPU上进行了实验:JAX在元素计算方面明显更快,尤其是在使用jit.我们看到JAX比NumPy快2.3倍以上,而当我们JIT函数时,JAX比NumPy快30倍。这些结果已经令人印象深刻,但让我们继续让JAX在TPU上进行计算:当在TPU上执行相同的计算时,JAX的相对性能会进一步提高(NumPy计算仍然在CPU上执行,因为它不支持TPU计算)在这种情况下,我们可以看到JAX比NumPy快13倍,如果我们同时在TPU上JIT函数和计算,我们会发现JAX比NumPy快80倍。当然,这种巨大的速度提升是有代价的。JAX对JIT允许的函数施加了限制,尽管通常只允许涉及上述NumPy操作的函数。此外,通过Python控制流进行JITing也有一些限制,因此在编写函数时必须牢记这一点。现在是2022年,我应该使用JAX吗?不幸的是,这个问题的答案仍然是“视情况而定”。是否迁移到JAX取决于您的情况和目标。为了具体分析2022年是否应该(或不应该)使用JAX,建议汇总到下面的流程图中,不同的图表针对不同的兴趣领域。科学计算如果您对通用计算中的JAX感兴趣,首先要问的问题是——您只是想在加速器上运行NumPy吗?如果答案是肯定的,那么您显然应该开始迁移到JAX。如果您不仅要处理数字,还要参与建模动态计算,那么您是否应该使用JAX将取决于具体用例。如果您的大部分工作是在Python中完成的,并带有大量自定义代码,那么开始学习JAX以增强您的工作流程是值得的。如果大部分工作不是用Python完成的,但你想要构建的是某种基于混合模型/神经网络的系统,那么使用JAX可能是值得的。如果你的大部分工作不使用Python,或者你正在使用一些专门的软件进行研究(热力学、半导体等),那么JAX可能不是正确的工具,除非你想从这些程序中导出数据来做自定义计算。如果您的兴趣领域更接近物理学/数学并且包括计算方法(动力系统、微分几何、统计物理学)并且您的大部分工作都在例如在大型自定义代码库的情况下。深度学习虽然我们强调了JAX不是为深度学习而构建的通用框架,但是JAX速度很快,并且具有自动微分性。您一定想知道使用JAX进行深度学习是什么感觉。如果你想在TPU上训练,那么你应该开始使用JAX,特别是如果你目前正在使用PyTorch。虽然有PyTorch-XLA,但使用JAX进行TPU训练绝对是更好的体验。如果您正在研究“非标准”架构/建模,例如SDE-Nets,您绝对应该也尝试一下JAX。此外,如果您想利用高级优化技术,JAX也是可以尝试的。如果你不是在构建临时架构,只是在GPU上训练通用架构,那么你现在应该坚持使用PyTorch或TensorFlow。然而,这一建议可能会在未来一两年内迅速改变。虽然PyTorch仍然主导着研究领域,但使用JAX的论文数量一直在稳步增长。随着重量级DeepMind和谷歌继续为JAX开发高级深度学习API,JAX可能会在几年内出现爆炸式增长。这意味着您至少应该对JAX有一点熟悉,特别是如果您是一名研究人员。初学者的深度学习但是如果我只是初学者怎么办?情况会有所不同。如果你有兴趣了解深度学习和实现一些想法,你应该使用JAX或PyTorch。如果你想自上而下地学习深度学习,或者对Python软件有一些经验,你应该从PyTorch开始。如果您想自下而上地学习深度学习,或者具有数学背景,您可能会发现JAX很直观。在这种情况下,请确保在开始任何大型项目之前了解如何使用JAX。如果你对深度学习感兴趣,想转行,那么你需要用到PyTorch或TensorFlow。虽然最好熟悉这两个框架,但你必须知道TensorFlow被广泛认为是“行业”框架,不同框架的职位发布数量证明了这一点:初学者,那么你不想使用JAX。相反,Keras是更好的选择。您不应该使用JAX的四个原因虽然上面已经讨论了对JAX的积极反馈,但它有可能大大提高用户程序的性能。但作者也列举了以下四个不应该使用JAX的理由:JAX仍然被官方认为是一个实验性框架。JAX是一个相对“年轻”的项目。目前,JAX仍被视为一个研究项目,而不是成熟的Google产品,所以如果您正在考虑迁移到JAX,请记住这一点;勤奋地使用JAX。调试的时间成本,或者更严重的是,未跟踪的副作用(untrackedsideeffects)的风险可能使JAX不适合没有扎实掌握函数式编程的用户。在开始将它用于一个严肃的项目之前,请确保您了解使用JAX的常见陷阱;JAX并未针对CPU计算进行优化。鉴于JAX是以“加速器优先”的方式开发的,每个操作的分派并未针对JAX进行完全优化。在某些情况下,NumPy实际上可能比JAX更快,特别是对于小程序,因为JAX引入了开销;JAX与Windows不兼容。Windows当前不支持JAX。如果您使用的是Windows但仍想尝试JAX,则可以使用Colab或将其安装在虚拟机(VM)上。