本文经AI新媒体量子比特(公众号ID:QbitAI)授权转载,转载请联系出处。收获近16.6万颗Star、见证深度学习崛起的TensorFlow岌岌可危。而这一次,冲击不是来自老对手PyTorch,而是来自自家的菜鸟JAX。在最近一波AI圈的热议中,就连fast.ai的创始人JeremyHoward也表示:JAX正在逐步取代早已广为人知的TensorFlow。它现在正在发生(至少在谷歌内部)。LeCun甚至认为,深度学习框架之间的激烈竞争已经进入了一个新的阶段。LeCun表示,谷歌的TensorFlow确实比Torch更受欢迎。但是,在Meta的PyTorch出现之后,其受欢迎程度现在已经超过了TensorFlow。现在,包括GoogleBrain、DeepMind和许多外部项目,都开始使用JAX。一个典型的例子就是最近流行的DALL·EMini。为了充分利用TPU,作者使用JAX进行编程。有人用过后感叹:这比PyTorch快多了。据《商业内幕》介绍,预计在未来几年内,JAX将覆盖谷歌所有使用机器学习技术的产品。从这个角度来看,现在在内部大力推广JAX,更像是谷歌在框架上发起的“自救”。JAX从何而来?关于JAX,谷歌其实是有备而来的。早在2018年,它就由GoogleBrain的一个三人小团队打造。研究成果发表在论文《Compilingmachinelearningprogramsviahigh-leveltracing:JaxisaPythonlibraryforhigh-performancenumericalcomputing,anddeeplearningisonlyanfunctions》。它的受欢迎程度自成立以来一直在上升。最大的特点是快。一个例子来感受一下。例如,要求矩阵的前三次幂之和,使用NumPy,计算大约需要478毫秒。使用JAX只需5.54毫秒,比NumPy快86倍。为什么这么快有很多原因,包括:1.NumPy加速器。NumPy的重要性不用多说,没有人能用Python做科学计算和机器学习,只是它还没有原生支持GPU等硬件加速。JAX的计算函数API全部基于NumPy,可以让模型轻松跑在GPU和TPU上。这一点抓住了很多人。2.XLA。XLA(AcceleratedLinearAlgebra)是加速线性代数,一个优化编译器。JAX建立在XLA之上,大大提高了JAX计算速度的上限。3.准时??制。研究人员可以使用XLA将自己的函数转换为即时编译(JIT)版本,这相当于通过在计算函数中添加一个简单的函数修饰符来将计算速度提高几个数量级。此外,JAX完全兼容Autograd,支持自动微分,可以通过grad、hessian、jacfwd、jacrev等函数进行转换,支持反向模式和正向模式微分,两者可以任意顺序组合。当然,JAX也有一些缺点。例如:1.JAX虽然号称是加速器,但并没有针对CPU计算中的每一个操作都进行充分的优化。2.JAX还是太新了,无法像TensorFlow那样形成完整的基础生态。所以谷歌还没有以成品的形式推出。3、调试所需的时间和费用不确定,“副作用”不完全清楚。4、不支持Windows系统,只能在上述虚拟环境下运行。5.如果没有dataloader,就得借用TensorFlow或者PyTorch。……尽管如此,简单、灵活、易用的JAX还是率先在DeepMind流行起来。2020年诞生的一些深度学习库Haiku和RLax就是基于它开发的。这一年,PyTorch的原作者之一AdamPaszke也全职加入了JAX团队。目前,JAX的开源项目在GitHub上有18.4kstars,远高于TensorFlow。值得注意的是,在此期间,有不少声音表示它很有可能取代TensorFlow。一方面是因为JAX的强大,另一方面主要和TensorFlow本身的很多原因有关。为什么Google转向JAX?诞生于2015年的TensorFlow一度风靡一时。上线后,很快就超越了Torch、Theano、Caffe,成为最流行的机器学习框架。然而在2017年,焕然一新的PyTorch“卷土重来”。这是Meta基于Torch构建的机器学习库。由于其简单易用、易于理解,迅速受到众多研究人员的青睐,甚至有超越TensorFlow的趋势。相比之下,TensorFlow越来越臃肿,更新频繁,界面迭代频繁,逐渐失去了开发者的信任。(从StackOverflow上的题目占比来看,PyTorch逐年上升,而TensorFlow却停滞不前。)在比赛中,TensorFlow的缺点逐渐暴露出来,API不稳定,实现复杂,学习成本高.无论更新解决多少,结构都会变得更加复杂。相比之下,TensorFlow并没有继续发挥其更强大的“运行效率”优势。在学术界,PyTorch的使用率正在逐渐超越TensorFlow。尤其是在ACL、ICLR等大型会议中,PyTorch实现的算法框架近年来占比超过80%。相比之下,TensorFlow的使用率仍在下降。也正是因为如此,谷歌也坐不住了,试图利用JAX重新夺回机器学习框架的“霸主地位”。虽然JAX名义上并不是“为深度学习构建的通用框架”,但谷歌的资源自其成立以来就一直倾向于JAX。一方面,GoogleBrain和DeepMind正在逐渐在JAX上构建更多的库。包括GoogleBrain的Trax、Flax、Jax-md,以及DeepMind的神经网络库Haiku和强化学习库RLax等,都是基于JAX的。按照谷歌官方的说法:在JAX生态系统的发展中,也会考虑确保其与现有TensorFlow库(如Sonnet和TRFL)的设计(尽可能)保持一致。另一方面,越来越多的项目开始基于JAX实现,最近流行的DALL·E迷你项目就是其中之一。由于更好地利用了谷歌TPU的优势,JAX在运行性能上远优于PyTorch,更多之前基于TensorFlow构建的工业项目也转而使用JAX。甚至有网友调侃JAX之所以如此火爆:可能是TensorFlow的用户实在受不了这个框架吧。那么,JAX是否有希望取代TensorFlow,成为与PyTorch竞争的生力军呢?你更喜欢哪个框架?总体来说,很多人还是坚定地站在PyTorch上。他们似乎不喜欢谷歌每年发布新框架的速度。“虽然JAX很吸引人,但它还没有‘革命性’到让大家放弃PyTorch而使用它。”但是看好JAX的人也不在少数。有人说PyTorch是完美的,但JAX也在缩小差距。甚至有人疯狂称呼JAX,说它比PyTorch强大10倍,并表示:如果Meta不继续努力,谷歌会赢。(手动狗头)不过,总有人不太在意谁输谁赢。他们的愿景非常长远:没有最好,只有更好。最重要的是更多的玩家和好点子都参与进来,这样才能把开源和真正优秀的创新划上等号。项目地址:https://github.com/google/jax
