当前位置: 首页 > 科技观察

被PyTorch震撼!谷歌放弃了TensorFlow,押注于JAX

时间:2023-03-20 19:53:31 科技观察

很喜欢有网友的话:“这小子真不行,再养一个吧。”谷歌确实做到了这一点。养了七年的TensorFlow,终于在一定程度上被Meta的PyTorch打下了基础。谷歌见不对劲,赶紧又要了一个——“JAX”,一个全新的机器学习框架。大家都知道最近超火的DALL·EMini。其模型基于JAX编程,从而充分利用了谷歌TPU带来的优势。TensorFlow的暮光与PyTorch的崛起2015年,谷歌开发的机器学习框架TensorFlow面世。当时,TensorFlow还只是GoogleBrain的一个小项目。没有人想到TensorFlow一出来就大受欢迎。Uber和Airbnb等大公司以及NASA等国家机构都在使用它。它们都用于最复杂的项目。截至2020年11月,TensorFlow已被下载1.6亿次。不过,谷歌似乎并不太在乎这么多用户的感受。千奇百怪的界面和频繁的更新让TensorFlow越来越不友好,也越来越难操作。甚至,即使在谷歌内部,也感觉这个框架正在走下坡路。其实谷歌这么频繁的更新真的很无奈。毕竟只有这样才能赶上机器学习领域的快速迭代。于是,加入项目的人越来越多,导致整个团队慢慢失去了重心。而那些原本让TensorFlow成为首选工具的闪光点,也已经被淹没在浩瀚的元素之中,不再被人们所重视。这种现象被Insider描述为“猫捉老鼠的游戏”。企业就像猫,不断涌现的新需求就像老鼠。猫要时刻保持警惕,随时扑向老鼠。对于最先进入某个市场的企业来说,这种困境是无法避免的。例如,说到搜索引擎,谷歌并不是第一。所以谷歌可以从其前辈(AltaVista、雅虎等)的失败中吸取教训,并将其应用到自己的发展中。遗憾的是,说到TensorFlow,被坑的是谷歌。也正是因为以上原因,让原本在谷歌工作的开发者逐渐对自己的老东家失去了信心。过去风靡一时的TensorFlow逐渐没落,输给了后起之秀Meta-PyTorch。2017年,PyTorch的测试版开源。2018年,Facebook的人工智能研究实验室发布了完整版的PyTorch。值得一提的是,PyTorch和TensorFlow都是基于Python开发的,而Meta更注重维护开源社区,甚至投入了大量资源。而且,Meta已经关注到谷歌的问题,认为不能重蹈覆辙。他们专注于一小部分功能,并使它们成为最好的。Meta并没有追随谷歌的脚步。这个框架最初是在Facebook开发的,后来逐渐成为行业标杆。一家机器学习初创公司的研究工程师说,“我们基本上用的是PyTorch,它的社区和开源是最好的,它不仅能解答问题,还能给出实际的例子。”面对这种情况,谷歌的开发者、硬件专家、云提供商,以及所有参与谷歌机器学习的人在采访中都说了同样的话,他们认为TensorFlow已经失去了开发者的心。经过一系列的内斗,Meta终于占了上风。有专家表示,谷歌未来继续引领机器学习的机会正在慢慢消失。PyTorch逐渐成为普通开发人员和研究人员的首选工具。从StackOverflow提供的交互数据来看,开发者论坛上关于PyTorch的问题越来越多,而TensorFlow近几年一直停滞不前。甚至文章开头提到的Uber等公司也转向了PyTorch。甚至,PyTorch的每一次后续更新似乎都在打TensorFlow的脸。谷歌机器学习的未来——JAX就在TensorFlow和PyTorch如火如荼的时候,谷歌内部的一个“小黑马研究团队”开始着手开发一个可以更方便地利用TPU的新框架。2018年,RoyFrostig、MatthewJamesJohnson和ChrisLeary发表了一篇名为《Compiling machine learning programs via high-level tracing》的论文,使JAX项目浮出水面。从左到右分别是三位大神,然后是PyTorch的原作者之一AdamPaszke,他也在2020年初全职加入了JAX团队。JAX为机器学习中最复杂的问题之一提供了更直接的方法:多核处理器调度。根据应用程序,JAX会自动将多个芯片组合成一个小组,而不是让一个芯片单独运行。这样做带来的好处是可以在短时间内响应尽可能多的TPU,从而燃烧我们的“炼金小宇宙”。最终,相比臃肿的TensorFlow,JAX解决了Google内部的一个重大难题:如何快速接入TPU。下面简单介绍一下构成JAX的Autograd和XLA。Autograd主要用于基于梯度的优化,可以自动区分Python和Numpy代码。它既可以用于处理Python的一个子集,包括循环、递归和闭包,也可以用于求导数的导数。此外,Autograd支持梯度的反向传播,这意味着它可以高效地获取标量值函数相对于数组值参数的梯度,以及前向模式微分,两者可以任意组合。XLA(AcceleratedLinearAlgebra)可以在不改变源代码的情况下加速TensorFlow模型。当一个程序运行时,所有的操作都是由执行者单独完成的。每个操作都有一个预编译的GPU内核实现,执行程序被分派到该内核实现。例如:defmodel_fn(x,y,z):returntf.reduce_sum(x+y*z)在没有XLA的情况下运行,这部分启动三个核心:一个用于乘法,一个用于加法,一个用于减法。另一方面,XLA通过将加法、乘法和减法“融合”到单个GPU内核中来实现优化。这种融合操作不会将内存产生的中间值写入y*z内存x+y*z;相反,它将这些中间计算的结果直接“流”给用户,同时将它们完全保存在GPU中。在实践中,XLA可以实现约7倍的性能提升和约5倍的批量大小提升。此外,XLA和Autograd可以任意组合,甚至可以使用pmap方法同时使用多个GPU或TPU内核进行编程。如果将JAX与Autograd和Numpy结合使用,您可以获得一个易于编程的高性能机器学习系统,适用于CPU、GPU和TPU。显然,谷歌这次吸取了教训。除了在国内全面铺开,还特别积极推动开源生态的建设。2020年,DeepMind正式投入JAX的怀抱,而这也宣告了谷歌自己的终结。此后,各种开源库层出不穷。纵观整个“明争暗斗”,贾扬清表示,在批判TensorFlow的过程中,AI系统认为Pythonic的科学研究是它所需要的。但一方面,纯Python无法实现高效的软硬件协同设计,另一方面,上层的分布式系统仍然需要高效的抽象。而JAX正在寻求更好的平衡。谷歌甘于自我颠覆的实用主义值得我们学习。causactR包和相关贝叶斯分析教科书的作者说,他很高兴看到谷歌从TF过渡到JAX,这是一种更清洁的解决方案。Google的挑战作为菜鸟,Jax虽然可以借鉴PyTorch和TensorFlow这两位老前辈的优点,但有时也可能带来缺点。首先,JAX还是太“年轻”了。作为一个实验性的框架,它远没有达到成熟的谷歌产品的标准。除了各种隐藏的bug,JAX在一些问题上还依赖于其他框架。对于加载和预处理数据,TensorFlow或PyTorch将处理大部分设置。显然,这远非理想的“一站式”框架。其次,JAX主要针对TPU进行了高度优化,但在GPU和CPU方面就差很多了。一方面,谷歌在2018年到2021年的组织和战略混乱,导致GPU支持的研发资金不足,相关问题处理的优先级不高。同时,它大概也过于专注于让自家的TPU在AI加速上分享更多的蛋糕,与Nvidia的合作自然是十分稀少,更别说提升GPU支持的细节了。另一方面,谷歌自己的内部研究,不用说,都集中在TPU上,这导致谷歌失去了对GPU使用的良好反馈循环。此外,更长的调试时间、不兼容Windows、未跟踪副作用的风险等,都增加了Jax的使用门槛和友好性。现在,PyTorch已经快6岁了,但它并没有TensorFlow当年表现出的衰退。从这一点来看,贾克斯要想赶上后来者,还有很长的路要走。