当前位置: 首页 > 科技观察

TensorFlow、PyTorch和JAX:哪种深度学习框架适合您?_0

时间:2023-03-12 04:14:20 科技观察

翻译|审稿人朱宪忠|大多数深度学习每天都以各种形式影响着我们的生活。无论是Siri、Alexa,基于用户语音命令的手机实时翻译应用,还是支持智能拖拉机、仓库机器人和自动驾驶汽车的计算机视觉技术,每个月似乎都在迎来新的发展。几乎所有这些深度学习应用程序都是使用以下三种框架之一编写的:TensorFlow、PyTorch或JAX。那么,您应该使用哪些深度学习框架?在本文中,我们将对TensorFlow、PyTorch和JAX进行高级比较。我们的目标是让您了解发挥其优势的应用程序类型,同时考虑到社区支持和易用性等当然因素。你应该使用TensorFlow吗?“没有人会因为购买IBM而被解雇”是1970年代和80年代计算世界的流行语。在2000年代初期使用TensorFlow进行深度学习也是如此。但众所周知,进入1990年代,IBM就“袖手旁观”。那么,TensorFlow在2015年首次发布7年之后的今天,以及新的十年,是否仍然具有竞争力?当然。TensorFlow并不总是停滞不前。首先,TensorFlow1.x以非Python的方式构建静态图,但在TensorFlow2.x中,还可以使用eager模式构建模型以立即评估操作,这让人感觉更像PyTorch。在高层,TensorFlow提供了Keras,方便开发;在底部,它提供了XLA(AcceleratedLinearAlgebra,加速线性代数)优化编译器来提高速度。XLA在提升GPU性能方面发挥着神奇的作用,它是利用谷歌TPU(张量处理单元,TensorProcessingUnits)能力的主要方式,为大规模模型训练提供无与伦比的性能。其次,多年来,TensorFlow一直试图在所有方面都做到最好。例如,您想在成熟平台上以定义明确且可重复的方式提供模型服务吗?TensorFlow已准备好提供服务。你想将模型部署重新用于低功耗计算,如网络、智能手机或物联网等资源受限的设备吗?在这一点上,TensorFlow.js和TensorFlowLite都非常成熟。显然,考虑到谷歌仍在100%使用TensorFlow来运行其生产部署,你可以肯定TensorFlow将能够满足用户的规模需求。不过,近期的项目中确实有一些因素不容忽视。简而言之,将项目从TensorFlow1.x升级到TensorFlow2.x是残酷的。考虑到更新代码以在新版本上正常工作所需的工作量,一些公司干脆决定将代码移植到PyTorch框架中。此外,TensorFlow在科学领域也失去了动力,几年前开始青睐PyTorch提供的灵活性,这导致研究论文中TensorFlow的使用持续下降。此外,“Keras事件”没有做任何事情。两年前,Keras成为TensorFlow发行版的一个组成部分,但最近被拉回到一个单独的库中,并有自己的发布计划。当然,排除Keras不会影响开发者的日常生活,但框架的小更新版本如此剧烈的变化,并不能激发程序员使用TensorFlow框架的信心。话虽如此,TensorFlow确实是一个坚实的框架,拥有广泛的深度学习生态系统,用户可以在TensorFlow上构建各种规模的应用程序和模型。如果是这样,将会有很多优秀的公司可以合作。但是今天,TensorFlow可能不是首选。你应该使用PyTorch吗?PyTorch不再是继TensorFlow之后的“新贵”,而是当今深度学习领域的主力军,可能主要用于研究,也越来越多地用于生产应用。随着Eager模式成为TensorFlow和PyTorch的默认开发方法,PyTorch的autograd提供的更加Pythonic的方法似乎正在赢得与静态图的战争。与TensorFlow不同的是,PyTorch的核心代码自0.4版本弃用变量API以来并未出现任何重大中断。以前,变量需要使用自动生成的张量,现在,一切都是张量。但这并不是说没有错误无处不在。例如,如果您一直在使用PyTorch跨多个GPU进行训练,您可能会遇到DataParallel和较新的DistributedDataParaller之间的差异。您应该始终使用DistributedDataParallel,但实际上没有什么可以反对使用DataParaller。虽然PyTorch在XLA/TPU支持方面一直落后于TensorFlow和JAX,但截至2022年,情况已经有了很大改善。PyTorch现在支持访问TPU虚拟机,支持遗留TPU节点支持,支持简单的命令行部署,无需更改代码即可在CPU、GPU或TPU上运行代码。如果您不想处理PyTorch经常让您编写的一些样板代码,您可以转向更高级别的扩展,例如PytorcheLightning,它可以让您专注于实际工作,而不是重写训练循环。另一方面,虽然PyTorchMobile的工作仍在继续,但它远不如TensorFlowLite成熟。在生产方面,PyTorch现在可以与框架无关的平台(如Kubeflow)集成,并且TorchServe项目处理部署细节,如缩放、指标和批处理推理——所有MLOps优点都在一个小包中,由PyTorch开发人员自己维护。另一方面,PyTorch是否支持缩放?没问题!Meta多年来一直在生产环境中运行PyTorch;所以任何告诉你PyTorch无法处理大规模工作负载的人都是谎言。不过,在某些情况下,PyTorch可能不如JAX友好,尤其是在需要大量GPU或TPU的非常繁重的训练方面。最后,还有一个让人不愿提及的棘手问题——PyTorch这几年的火爆,离不开HuggingFace的Transformers库的成功。是的,Transformers现在也支持TensorFlow和JAX,但它最初是一个PyTorch项目,仍然与框架紧耦合。随着Transformer架构的兴起、PyTorch在研究方面的灵活性,以及??通过HuggingFace的模型中心在发布后几天或几小时内引入如此多新模型的能力,很容易看出为什么PyTorch在这些领域处于领先地位。如此受欢迎。你应该使用JAX吗?如果您对TensorFlow不感兴趣,Google可能会为您提供其他内容。JAX是由Google构建、维护和使用的深度学习框架,但它不是Google的官方产品。但是,如果您留意过去一年左右的Google/DeepMind论文和产品发布,您会注意到Google的许多研究已经迁移到JAX。因此,虽然JAX不是Google的“官方”产品,但它是Google研究人员用来突破界限的东西。JAX到底是什么?理解JAX的一种简单方法是:想象一个GPU/TPU加速版本的NumPy,它可以用“一根魔杖”神奇地矢量化Python函数,并处理所有这些函数的导数计算。最后,它为XLA(AcceleratedLinearAlgebra,加速线性代数)编译器提供了用于获取代码并进行优化的即时(JIT:Just-In-Time)组件,从而大大提高了TensorFlow和火炬。目前一些代码的执行速度快了四到五倍,只需在JAX中重新实现它即可,而无需任何真正的优化工作。考虑到JAX在NumPy级别工作,JAX代码是在比TensorFlow/Keras(甚至PyTorch)低得多的级别编写的。令人高兴的是,有一个小型但不断发展的生态系统,其中包含一些围绕JAX的扩展。你想使用神经网络库吗?当然。其中包括来自谷歌的Flax,以及来自DeepMind(也是谷歌)的Haiku。此外,Optax可满足您所有的优化器需求,PIX可用于图像处理,等等。一旦使用了Flax之类的东西,构建神经网络就变得相对容易掌握。请注意,仍有一些轻微的扭结。例如,有经验的人经常谈论JAX如何以不同于许多其他框架的方式处理随机数。那么,您是否应该将所有内容都转换为JAX并利用这种尖端技术?这个问题因人而异。如果您深入研究需要大量资源进行训练的大型模型,建议使用此方法。另外,如果你在deterministictraining中关注JAX,以及其他需要上千个TPUPod的项目,那么也值得一试。总结那么,结论是什么?您应该使用哪种深度学习框架?不幸的是,这个问题没有单一的答案,这完全取决于您正在处理的问题类型、您计划部署模型来处理的规模,甚至您正在使用的计算平台。然而,如果你在文本和图像领域工作,并且正在进行中小型研究以在生产中部署这些模型,那么PyTorch可能是目前的最佳选择。从最近的版本来看,它是这类应用程序空间的最佳选择。如果您需要低计算设备的所有性能,我建议使用TensorFlow和极其强大的TensorFlowLite包。最后,如果您正在处理具有数百亿、数千亿或更多参数的训练模型,并且主要出于研究目的训练它们,那么可能是时候尝试一下JAX了。原文链接:https://www.infoworld.com/article/3670114/tensorflow-pytorch-and-jax-choosing-a-deep-learning-framework.html译者介绍朱宪忠,社区编辑,专家博主,讲师,潍坊大学计算机老师,自由编程的老手。