当前位置: 首页 > 科技观察

你玩的是农药大王,可是有人在用iPhone训练神经网络_0

时间:2023-03-12 19:46:25 科技观察

,你知道吗?LeNet卷积神经网络也可以直接在iOS设备上训练,性能一点都不差,iPhone和iPad也可以变成真正的生产力。为了在移动端应用机器学习,一般分为以下两个阶段。第一阶段是训练模型,第二阶段是部署模型。常规的做法是在强大的GPU或TPU上训练模型,然后通过一系列的模型压缩方法,将其转化为可以在移动端运行的模型,连接到APP上。CoreML主要解决最终的模型部署。它为开发者提供了一个方便的模型转换工具,可以方便地将训练好的模型转换成CoreML类型的模型文件,实现模型和APP数据。互通。以上是正常操作。不过随着iOS设备运算性能的提升,有传言称iPadPro的运算能力超过了普通笔记本。于是乎,就有这样一位“勇者”开源了一个可以直接在iOS设备上训练神经网络的项目。项目作者在macOS、iOS模拟器和真实iOS设备上进行了测试。使用60,000个MNIST样本训练了10个时期。在模型架构和训练参数完全一致的前提下,在iPhone11上用CoreML训练大约需要248秒,在i7MacBookPro上用TensorFlow2.0训练大约需要158秒(只用CPU),但是准确率超过0.98。当然,248秒和158秒之间还是有非常大的差距,但是这个实验的目的不是比较速度,而是探索用移动设备或者可穿戴设备进行本地训练的可行性,因为这些数据中的设备往往敏感且涉及隐私,本地培训可以提供更好的安全性。项目地址:https://github.com/JacopoMangiavacchi/MNIST-CoreML-TrainingMNIST数据集在这篇文章中,作者介绍了如何使用MNIST数据集部署图像分类模型。值得注意的是,这个CoreML模型是直接在iOS设备上训练的,没有事先在其他ML框架中训练。作者在这里使用了一个著名的数据集——MNIST手写数字数据集。提供了60000个训练样本和10000个测试样本,都是手写数字0到9的28x28黑白图像。初始点。LeNetCNN+MNIST数据集的组合是机器学习“训练”的标准组合,相当于深度学习图像分类的“Hello,World”。这篇文章主要介绍如何直接在iOS设备上为MNIST数据集构建和训练LeNetCNN模型。接下来,研究人员将其与基于TensorFlow等知名ML框架的经典“Python”实现进行比较。在Swift中为CoreML训练准备数据在讨论如何在CoreML中创建和训练LeNetCNN网络之前,让我们看一下如何准备MNIST训练数据,以便它可以正确地批处理到CoreML运行中。在下面的Swift代码中,这批训练数据是专门为MNIST数据集准备的,只需将每张图像的“像素”值从初始范围0到255标准化为0到1之间的“可理解”值."范围。为CoreML模型(CNN)训练做准备随着训练数据批次的处理和归一化,现在是时候使用SwiftCoreMLTools库在Swift的CNNCoreML模型中执行一系列本地化准备工作了。在下面的SwiftCoreMLToolsDSL中functionbuilder代码,还可以看到同样的情况是如何传入到CoreML模型中的,同时里面还包含了基本的训练信息,超参数等,比如损失函数,优化器,学习率,epoch数,batchsize等使用Adam优化器训练神经网络,具体参数如下:接下来是构建CNN网络,卷积层、激活层和池化层定义如下:然后使用相同的和之前一样设置卷积、激活和池化操作,然后输入Flatten层,经过两个全连接层后用Softmax输出结果。由此产生的CNN模型新构建的CoreML模型有两个嵌套的卷积层和最大池化层。将所有数据展平后,连接一个隐藏层,最后连接一个全连接层,输出Softmax激活后的结果。对TensorFlow2.0模型进行基准测试为了对结果进行基准测试,尤其是在运行时方面的训练效果,作者还使用TensorFlow2.0重新创建了相同CNN模型的精确副本。下面的Python代码显示了TF中相同的模型架构和每一层的OutPutShape:如您所见,这里的层、层形状、卷积过滤器和池大小与使用SwiftCoreMLTools库模型完全一样。比较结果在查看训练执行时间性能之前,首先要确保CoreML和TensorFlow模型都训练了相同的时期数(10),使用相同的超参数在相同的10,000个测试样本上获得非常相似的准确度指标图片。从下面的Python代码可以看出,TensorFlow模型是使用Adam优化器和分类交叉熵损失函数训练的,测试用例的最终精度结果大于0.98。CoreML模型的结果如下图所示。它使用与TensorFlow相同的优化器、损失函数、训练集和测试集。可以看出识别准确率超过0.98。