当前位置: 首页 > 科技观察

秒秒钟揪出张量形状错误,这个工具能防止ML模型训练白忙一场

时间:2023-03-21 16:57:16 科技观察

在几秒钟内找出张量形状错误。这个工具可以防止ML模型训练被浪费时间。模型训练了半天,结果发现是张量形状定义错误,一定把你逼疯了。那么对于这种情况,有没有更好的解决办法呢?不久前,韩国首尔国立大学的研究人员研制出了一种“利器”——PyTea。据研究人员称,它可以帮助您在训练模型之前的几秒钟内静态分析潜在的张量形状错误。那么PyTea是怎么做到的,靠谱与否,让我们一起来了解一下。PyTea怎么玩?为什么张量形状误差如此重要?神经网络涉及一系列矩阵计算。前矩阵中的列数必须与后矩阵中的行数匹配。如果尺寸不匹配,则后续操作将不起作用。上面的代码是一个典型的张量形状错误,[Bx120]*[80x10]不能进行矩阵运算。不管是PyTorch、TensorFlow还是Keras,在训练神经网络的时候,大部分都是按照图上的流程进行的。先定义一系列神经网络层(也就是矩阵),然后合成神经网络模块……那么为什么需要PyTea呢?以往,当模型读取大量数据,开始训练,代码运行到错误的张量时,我们只能发现张量形状定义错误。由于模型可能非常复杂,训练数据非常庞大,因此查找错误的时间成本会非常高。有时候代码是放在后台训练的,出问题了你不知道……PyTea可以有效的帮助我们避免这个问题,因为它可以在运行模型代码之前,帮助我们分析形状错误。网友们已经在热烈讨论了。PyTea是如何工作的,它能有效地检测错误吗?受各种约束的影响,代码可能的运行路径有很多,不同的数据会走不同的路径。所以PyTea需要静态扫描所有可能的运行路径,跟踪张量变化,推导出每个张量形状的精确保守范围。上图展示了PyTea的整体结构,分为翻译语言、收集约束、求解器判断和反馈四个步骤。首先PyTea将原始Python代码翻译成内核语言。PyTea内部表示(PyTeaIR)。PyTea然后跟踪PyTeaIR的每个可能执行路径并收集对张量形状的约束。判断是否满足约束分为在线分析和离线分析两步:在线分析node.js(TypeScript/JavaScript):发现tensorshape值不匹配和API函数的误用。如果PyTea发现问题,它会停在当前位置并向用户报告错误。离线分析Z3/Python:如果在线分析没有问题,PyTea会将收集到的约束传递给SMT(SatisfiabilityModuloTheories)求解器Z3,求解器负责检查每条路径的约束是否可以满足,如果不是,则向用户返回第一个错误路径的约束。如果求解器长时间没有响应,PyTea将不知道是否有问题返回。然而,跟踪所有可能的路径是一项指数级的任务。对于复杂的神经网络,肯定会出现路径爆炸的问题。例如,在这个例子中,网络的最终结构是由24个相同的模块组成的(第17行),那么可能的路径就有16M之多。所以必须处理路径爆炸。PyTea是如何做到的?PyTea选择保守的路径剪枝和超时判断来应对这种路径爆炸。什么样的路径可以被修剪?PyTea给出的答案是,如果前馈函数不改变全局值,其输出值不受分支条件的影响,且每条路径都相等,我们可以忽略很多完全一致的路径,以节省计算资源。如果路径剪枝仍然失败,那么只能按照timeout进行处理。原理就这么多,我觉得还是值得一试的,现在代码已经在GitHub上开源了,快去看看吧!使用依赖库:安装方式:运行命令: