当前位置: 首页 > 科技观察

使用不到256KB内存实现边缘训练,开销不到PyTorch的千分之一

时间:2023-03-14 16:56:29 科技观察

说到神经网络训练,大家的第一印象就是GPU+服务器+云平台。由于内存开销巨大,传统训练往往在云端进行,边缘平台只负责推理。然而,这样的设计让AI模型很难适应新的数据:毕竟现实世界是一个动态的、变化的、发展的场景。一次培训如何涵盖所有场景?为了让模型不断适应新的数据,是否可以在边缘进行on-device训练,让设备不断地自我学习?在这项工作中,我们使用不到256KB的内存实现了设备上训练,开销不到PyTorch的1/1000,同时在VisualWakeWord任务(VWW)上实现了云端训练的准确性。这种技术使模型能够适应新的传感器数据。用户无需将数据上传至云端即可享受定制服务,从而保护隐私。网站:https://tinytraining.mit.edu/论文:https://arxiv.org/abs/2206.15472Demo:https://www.bilibili.com/video/BV1qv4y1d7MV代码:https://github.com/mit-han-lab/tiny-training背景On-deviceTraining允许预先训练的模型在部署后适应新环境。通过移动端的本地训练和适配,模型可以不断改进其结果并为用户定制模型。例如,微调语言模型可以让它从输入历史中学习;调整视觉模型可以让智能相机持续识别新物体。通过让训练更靠近终端而不是云端,我们可以在保护用户隐私的同时有效提高模型质量,尤其是在处理医疗数据和输入历史等隐私信息时。然而,在小型物联网设备上进行训练与云端训练有着根本的不同,并且非常具有挑战性。首先,AIoT设备(MCU)的SRAM大小通常有限(256KB)。这种程度的记忆对于推理来说都非常困难,更不用说训练了。再者,现有的低成本、高效率的迁移学习算法,比如只训练最后一个分类器(lastFC),只学习biasitems,往往精度不理想,无法在实践中使用,更不用说目前的一些深度学习框架了未能将这些算法的理论数字转化为测量的节省。最后,现代深度训练框架(PyTorch、TensorFlow)通常是为云服务器设计的,即使将batch-size设置为1,训练小模型(MobileNetV2-w0.35)也需要大量内存占用。因此,我们需要协同设计算法和系统,以实现在智能终端设备上的训练。方法和结果我们发现设备上训练有两个独特的挑战:(1)模型在边缘设备上量化。由于低精度张量和缺少批量归一化层,真正量化的图(如下所示)很难优化;(2)小型硬件有限的硬件资源(内存和计算)不允许完全反向传播,这内存使用很容易超过微控制器SRAM的限制(超过一个数量级),但如果只有最后一层是更新了,最后的精度难免会差强人意。为了解决优化的困难,我们提出了量化感知缩放(QAS)来自动缩放具有不同位精度的张量的梯度(如下左图所示)。QAS可以自动匹配梯度和参数尺度并稳定训练,而无需额外的超参数。在8个数据集上,QAS都能达到与浮点训练一致的性能(如下右图)。为了减少反向传播所需的内存占用,我们提出稀疏更新来跳过不太重要的层和子表的梯度计算。我们开发了一种基于贡献分析的自动方法来寻找最佳更新方案。与之前的仅偏置、最后k层更新相比,我们搜索的稀疏更新方案节省了4.5到7.5倍的内存,并且在8个下游数据集上的平均精度更高。为了将算法的理论减少转化为实数,我们设计了TinyTrainingEngine(TTE):它将自动微分的工作卸载到编译时,并使用代码生成来减少运行时开销。它还支持图形修剪和重新排序,以实现真正的节省和加速。与FullUpdate相比,SparseUpdate有效减少了7-9倍的峰值内存,并且可以通过重新排序进一步提高到20-21倍的总内存节省。与TF-Lite相比,TTE中优化的内核和稀疏更新使整体训练速度提升了23-25倍。结论在本文中,我们提出了第一个在微控制器(仅256KB内存和1MB闪存)上实施训练的解决方案。我们的算法系统协同设计(System-AlgorithmCo-design)大大减少了训练所需的内存(1000次vsPyTorch)和训练时间(20次vsTF-Lite),并在下游任务上达到更高的准确率。TinyTraining可以启用许多有趣的应用程序。例如,手机可以根据用户邮件/输入历史自定义语言模型,智能相机可以不断识别新面孔/物体,一些无法联网的AI场景也可以继续学习(如农业、海洋、工业流水线)。通过我们的工作,小型终端设备不仅可以进行推理,还可以进行训练。在这个过程中,个人数据永远不会上传到云端,因此不存在隐私风险,同时AI模型可以不断自我学习,以适应动态变化的世界!