当前位置: 首页 > 科技观察

BAIR最新的RL算法超越了GoogleDreamer,性能提升了2.8倍

时间:2023-03-16 23:53:39 科技观察

基于像素的RL算法。BAIR提出了一种结合了对比学习和RL的算法,其样本效率与基于状态的RL相当。这项研究的本质是回答一个问题——使用图像作为观察(基于像素)的RL能否与使用坐标状态作为观察的RL一样有效?传统上,人们普遍认为以图像作为观察的RL数据效率低下,通常需要1亿个交互步骤来解决像Atari游戏这样的基准任务。研究人员介绍了CURL:强化学习的无监督对比表示。CURL使用对比学习从原始像素中提取高阶特征,并在提取的特征之上进行离策略控制。CURL在DeepMindControlSuite和AtariGames中的复杂任务上优于以前基于像素的方法(基于模型和无模型),性能分别提高了2.8倍和1.6倍。在DeepMindControlSuite上,CURL是第一个基于图像的算法,其样本效率和性能几乎与基于状态的基于特征的方法相匹配。论文链接:https://arxiv.org/abs/2004.04136网址:https://mishalaskin.github.io/curl/GitHub链接:https://github.com/MishaLaskin/curl背景介绍CURL是比较组合学习和结合RL的通用框架。理论上,任何RL算法,无论是相同策略还是偏离策略,都可以在CURL管道中使用。对于连续控制基准(DMControl),研究团队使用了著名的SoftActor-Critic(SAC)(Haarnojaetal.,2018);对于离散控制基准(Atari),研究团队使用了RainbowDQN(Hesseletal.,2017))。下面,我们简要回顾一下SAC、RainbowDQN和对比学习。SoftActorCriticSAC是一种off-policyRL算法,可优化随机策略以最大化预期轨迹奖励。与其他最先进的端到端RL算法一样,SAC在解决状态观察任务方面非常有效,但无法从像素中学习有效策略。Rainbow将RainbowDQN(Hesseletal.,2017)总结为对NatureDQN(Mnihetal.,2015)原始应用的多项改进。具体来说,深度Q网络(DQN)(Mnih等人,2015年)将离策略Q学习算法与卷积神经网络相结合作为函数逼近器,将原始像素映射为动作值函数。除此之外,价值分布强化学习(Bellemareetal.,2017)提出了一种通过C51算法预测可能价值函数bins分布的技术。RainbowDQN将上述所有技术结合在一个单一的离策略算法中,以在Atari基准测试中实现最先进的样本效率。此外,Rainbow使用多步回归(Suttonetal.,1998)。对比学习CURL的一个关键部分是能够使用对比无监督学习来学习高维数据的丰富表示。对比学习可以理解为一种可微分的字典查找任务。给定查询q,键K={k_0,k_1,...},和显式K(关于q)P(K)=({k+},K{k+})分区,对比学习的目标是确保q比K{k+}中的任何键更匹配k+。在对比学习中,q、K、k+和K{k+}也分别称为anchors、targets、positives和negatives。CURL实现CURL通过在批量更新期间使用训练比较目标作为辅助损失函数,对底层RL算法的改变最小。在实验中,研究人员使用两种无模型RL算法训练CURL——用于DMControl实验的SAC和用于Atari实验的RainbowDQN。一般框架概述CURL使用类似于SimCLR、MoC和CPC的实例区分。大多数深度强化学习框架将一系列堆叠在一起的图像作为输入。因此,该算法在多个堆叠帧而不是单个图像实例中执行实例区分。研究人员发现,使用类似MoCo的动量编码管道(momentumencoding)来处理目标在RL中表现更好。最后,研究人员使用类似于CPC的双线性内积来处理InfoNCE分数方程,研究人员发现效果优于MoCo和SimCLR中的单位范数向量积。对比表示与RL算法一起训练,同时从对比目标和Q函数获得梯度。整体框架如下图所示。图2:CURL整体框架示意图。关于锚点的正负样本的判别目标选择是对比表示学习的关键组成部分之一。与同一图像上的图像补丁不同,经过判别式转换的图像实例使用InfoNCE损失项优化了简化的实例判别目标,并且需要对结构进行最少的调整。在RL设置中,选择更简化的判别目标有两个主要原因:鉴于RL算法的脆弱性,复杂的判别目标可能导致RL目标的不稳定。RL算法是在动态生成的数据集上训练的,复杂的判别目标会显着增加训练所需的时间。因此,CURL使用实例区分而不是补丁区分。我们可以将SimCLR和MoCo等对比实例判别设置视为最大化图像与其相应增强版本之间的公共信息。query-key-value对的生成类似于图像设置中的实例判别,其中anchors和正观察值是来自同一图像的两个不同的增强值,而负观察值来自其他图像。CURL主要依靠随机裁剪数据增强方法从原始渲染图像中随机裁剪出一个正方形补丁。研究人员对批次使用随机数据增强,但在同一堆栈的帧之间保持一致,以保留有关观察时间结构的信息。数据增强过程如图3所示。图3:使用随机裁剪生成锚点及其正样本的过程的可视化表示。相似性度量的区分目标的另一个决定因素是用于度量查询键对的内部产品。CURL使用双线性内积sim(q,k)=q^TW_k,其中W是学习的参数矩阵。研究小组发现,这种相似性度量优于最近在计算机视觉领域最先进的对比学习方法(如MoCo和SimCLR)中使用的归一化点积。MomentumGoalEncoding在CURL中使用对比学习的目标是训练一个编码器,从高维像素映射到更多语义隐藏状态。InfoNCE是一种无监督损失,它通过学习编码器f_q和f_k将原始锚点(查询)x_q和目标(关键字)x_k映射到潜在值q=f_q(x_q)和k=f_k(x_k),其中团队应用相似性点积。通常在锚图和目标图之间共享相同的编码器,即f_q=f_k。CURL将帧堆栈实例识别与目标的动量编码相结合,而RL在编码器特征之上执行。CURL对比学习伪代码(PyTorch风格)实验。研究人员评估(i)样本效率,通过测量最佳性能基线需要多少交互步骤才能匹配100k交互步骤的CURL性能,以及(ii)通过测量100k步的性能水平,通过测量循环返回值实现CURL与性能最佳的基线相比。换句话说,在谈论数据或样本效率时,您指的是(i),而在谈论性能时,您指的是(ii)。DMControl实验中DMControl的主要发现:CURL是一种SOTAImageBasedRL算法,我们在每个DMControl环境上进行基准测试,以针对现有的基于图像的基准进行采样效率测试。在DMControl100k上,CURL比Dreamer(Hafner等人,2019年)高出2.8倍,这是一种领先的基于模型的方法,数据效率高出9.9倍。从图7所示的16个DMControl环境中的大多数状态开始,仅像素CURL几乎匹配(有时超过)SAC的采样效率。它是基于模型的,无模型的,有或没有辅助任务。在500,000步内,CURL解决了16个DMControl实验中的大部分(收敛到接近1000的最佳分数)。它仅需100,000步就可与类似SOTA的性能相媲美,并且大大优于该方案中的其他方法。表1.使用500k(DMControl500k)和100k(DMControl100k)环境步长基准在CURL和DMControl基准上获得的基线分数。图4.10个种子的CURL耦合SAC性能相对于SLAC、PlaNet、PixelSAC和StateSAC基线的平均值。图6.领先的基于像素的方法Dreamer获得与CURL在100k训练步骤中得分相同的分数所需的步骤数。图7.将CURL与基于状态的SAC进行比较,在16个选定的DMControl环境中的每一个上运行2个种子。Atari在Atari实验中的主要发现:就26个Atari100k实验中的大多数数据效率而言,CURL是SOTAPixelBasedRL算法。平均而言,在Atari100k上,CURL比SimPLe高出1.6倍,比EfficientRainbowDQN高出2.5倍。CURL的人类归一化分数(HNS)达到24%,而SimPLe和EfficientRainbowDQN分别达到13.5%和14.7%。CURL、SimPLe和EfficientRainbowDQN的平均HNS分别为37.3%、39%和23.8%。CURL在JamesBond(98.4%HNS)、Freeway(94.2%HNS)和RoadRunner(86.5%HNS)这三款游戏上的效率几乎与人类相当,在所有基于像素的RL算法中排名第一。表2.通过CURL和100k时间步(Atari100k)获得的分数。CURL在26个环境中的14个上实现了SOTA。项目介绍安装所有相关项目都在conda_env.yml文件中。它们可以手动安装或使用以下命令安装:condaenvcreate-fconda_env.yml说明要根据基于图像的观察训练CURL代理以完成cartpoleswingup任务,请从该目录的根目录运行bashscript/run.sh。run.sh文件包含以下命令,也可以修改这些命令以尝试不同的环境/超参数。CUDA_VISIBLE_DEVICES=0pythontrain.py--domain_namecartpole--task_nameswingup--encoder_typepixel--action_repeat8--save_tb--pre_transform_image_size100--image_size84--work_dir./tmp--agentcurl_sac--frame_stack3--seed-1--critic_lr1e-3--actor_lr1e-3--eval_freq10000--batch_size128--num_train_steps1000000在控制台中,您应该看到如下输出:|train|E:221|S:28000|D:18.1s|R:785.2634|BR:3.8815|A_LOSS:-305.7328|CR_LOSS:190.9854|CU_LOSS:0.0000|火车|E:225|S:28500|D:18.6s|R:832.4937|BR:3.9644|A_LOSS:-308.7789|CR_LOSS:126.0638|CU_LOSS:0.0000|火车|E:229|S:29000|D:18.8s|R:683.6702|BR:3.7384|A_LOSS:-311.3941|CR_LOSS:140.2573|CU_LOSS:0.0000|train|E:233|S:29500|D:19.6s|R:838.0947|BR:3.7254|A_LOSS:-316.9415|CR_LOSS:136.5304|CU_LOSS:0.0000cartpole摆动的最高分数约为845分。此外,CURL如何在不到50k步内解决视觉推车问题。根据用户的GPU,训练大约需要一个小时。同样作为参考,最新的端到端方法D4PG需要50M时间步来解决同样的问题。Logabbreviationmapping:train-trainingepisodeE-totalnumberofepisodesS-totalnumberofenvironmentstepsD-durationinsecondstotrain1episodeR-meanepisoderrewardBR-averagewardofsampledbatchA_LOSS-averagelossofactorCR_LOSS-averagelossofcriticCU_LOSS-averagelossoftheCURLencoder存储在指定运行相关的所有数据。要启用模型或视频保存,请使用--save_model或--save_video。对于所有可用的标志,请检查train.py。使用tensorboard运行可视化:tensorboard--logdirlog--port6006,同时在浏览器中转到localhost:6006。如果它不起作用,请尝试使用ssh进行端口转发。要使用GPU加速进行渲染,请确保您的计算机上安装了EGL并设置exportMUJOCO_GL=egl。