本文经人工智能新媒体量子比特(公众号ID:QbitAI)授权转载。转载请联系出处。DeepMind今天发布了两个基于JAX的库,Haiku和RLax。JAX由Google提出,是TensorFlow的简化库。将线性代数编译器XLA与自动区分本机Python和Numpy代码的库Autograd相结合,用于高性能机器学习研究。此次发布的两个库,分别针对神经网络和强化学习,大大简化了JAX的使用。Haiku是一个基于JAX的神经网络库,允许用户使用熟悉的面向对象编程模型,完全访问JAX的纯函数转换。RLax是JAX之上的一个库,它为实现强化学习代理提供了有用的构建块。有趣的是,Reddit网友惊讶地发现Haiku库的名字并不以“ax”结尾。当然,也有网友对这两个库表示了肯定:毫无疑问,它们对JAX起到了一定的推动作用。那么,让我们来看看Haiku和RLex的真面目吧。HaikuHaiku是JAX的神经网络库,允许用户使用熟悉的面向对象编程模型,同时允许完全访问JAX的纯函数转换。它提供了两个核心工具:模块抽象hk.Module,和简单的函数转换hk.transform。hk.Module是一个Python对象,包含对其自身参数、其他模块和将函数应用于用户输入的方法的引用。hk.transform允许完全访问JAX纯函数转换。实际上,JAX中有很多神经网络库,那么Haiku有什么特别之处呢?有5分。1.Haiku已经过DeepMind研究人员的大规模测试。DeepMind已经相对轻松地在Haiku和JAX中复制了许多实验。其中包括图像和语言处理、生成模型和强化学习方面的大规模成果。2.Haiku是一个库,而不是一个框架。它旨在简化一些特定的事情,包括管理模型参数和其他模型状态。可以与其他库一起编写并与JAX的其他部分一起使用。3.俳句不是从零开始的。它建立在Sonnet的编程模型和API之上。Sonnet是DeepMind几乎普遍使用的神经网络库。它保留了用于状态管理的Sonnet基于模块的编程模型,同时保留了对JAX函数转换的访问。4.相对容易过渡到Haiku通过精心设计,从TensorFlow和Sonnet过渡到JAX和Haiku相对容易。除了新功能(如hk.transform),Haiku的目标是成为Sonnet2API。5.Haiku简化了JAX,它提供了一个处理随机数的简单模型。在转换后的函数中,hk.next_rng_key()返回一个唯一的rng键。那么,如何安装Haiku?Haiku是用纯Python编写的,但通过JAX依赖于C++代码。首先,按照以下链接中的说明安装具有相关加速器支持的JAX。https://github.com/google/jax#installation然后,只需要一个简单的pip命令就可以完成安装。$pipinstallgit+https://github.com/deepmind/haiku接下来是一个神经网络和损失函数的例子。importhaikuashkimportjax.numpyasjnpdefsoftmax_cross_entropy(logits,labels):one_hot=hk.one_hot(labels,logits.shape[-1])return-jnp.sum(jax.nn.log_softmax(logits)*one_hot,axis=-1)defloss_fn(images),标签):model=hk.Sequential([hk.Linear(1000),jax.nn.relu,hk.Linear(100),jax.nn.relu,hk.Linear(10),])logits=模型(images)returnjnp.mean(softmax_cross_entropy(logits,labels))loss_obj=hk.transform(loss_fn)RLaxRLax是一个基于JAX的库,它为实现强化学习代理提供了有用的构建块。它提供的运算和函数不是完整的算法,而是强化学习具体数学运算的实现。RLax的安装也很简单,一个pip命令就可以搞定。pipinstallgit+git://github.com/deepmind/rlax.git使用了JAX的jax.jit功能,所有RLax代码都可以在不同的硬件上编译。RLax需要注意的是它的命名规则。许多函数在连续的时间步长中考虑策略、动作、奖励和值,以便计算它们的输出。在这种情况下,后缀_t和tm1通常是用来表示每个输入是在哪一步生成的,例如:q_tm1:转换源状态中的操作值。a_tm1:在源状态选择的操作。r_t:在目标状态下收集的最终奖励。q_t:操作在目标状态下的值。Haiku和RLax都已经在GitHub上开源,感兴趣的读者可以通过“传送门”中的链接访问。传送门俳句:https://github.com/deepmind/haikuRLax:https://github.com/deepmind/rlax
