当前位置: 首页 > 科技观察

深度学习利器:TensorFlow在智能终端的应用

时间:2023-03-22 11:13:15 科技观察

前言深度学习在图像处理、语音识别、自然语言处理等领域取得了巨大的成功,但通常在强大的服务器上进行计算。如果智能手机通过网络远程连接服务器,也可以使用深度学习技术,但这可能会很慢,而且只有在设备处于良好的网络连接环境下才会起作用,这就需要深度学习的迁移学习模型到智能终端。由于智能终端的CPU和内存资源有限,为了提高计算性能和内存利用率,需要对服务端模型进行量化,支持低精度算法。TensorFlow版本增加了对Android、iOS和RaspberryPi硬件平台的支持,使其可以在这些设备上执行图像分类等操作。这使得创建在智能手机上运行的机器学习模型成为可能,并且不需要云来全天候24/7支持它们,从而产生新的应用程序。本文主要以看花知命APP应用为主,讲解TensorFlow模型如何应用于Android系统;在服务端训练TensorFlow模型,并将模型文件迁移到智能终端;构建TensorFlowAndroid开发环境和应用程序开发API。看花知命APP利用AlexNet模型、Flowers数据和Android平台打造“看花知命”APP。TensorFlow模型在五种花卉数据上进行了训练。如下图:Daisy:DaisyDandelion:DandelionRoses:RoseSunflowers:SunflowerTulips:Tulip在服务器端训练好模型后,将模型文件迁移到安卓平台,手机端安装APP.使用效果如下图所示。界面上方显示模型识别的置信度,界面中间显示待识别的花:如何将TensorFlow模型应用到看花知命APP中,主要包括以下几个关键步骤:模型选择和应用、模型文件转换和Android开发。如下图所示:模型训练和模型文件本章使用AlexNet模型对Flowers数据进行训练。AlexNet在2012年取得了ImageNet最好的成绩,top5准确率达到80.2%。这对于传统的机器学习分类算法来说已经相当不错了。模型结构如下:本文使用官方的TensorFlowSlim(https://github.com/tensorflow/models/tree/master/slim)AlexNet模型进行训练。首先下载Flowers数据,转换成TFRecord格式:DATA_DIR=/tmp/data/flowerspythondownload_and_convert_data.py--dataset_name=flowers--dataset_dir="${DATA_DIR}"执行模型训练,经过36618次迭代,模型精度达到85%TRAIN_DIR=/tmp/data/trainpythontrain_image_classifier.py--train_dir=${TRAIN_DIR}--dataset_dir=${DATASET_DIR}--dataset_name=flowers--dataset_split_name=train--model_name=alexnet_v2--preprocessing_name=vgg生成推理GraphPB文件pythonexport_inference_graph.py--alsologtostderr--model_name=alexnet_v2--dataset_name=flowers--dataset_dir=${DATASET_DIR}--output_file=alexnet_v2_inf_graph.pb结合CheckPoint文件和InferenceGraphPB文件生成PB文件pythonfreeze_graph冻结图。py--input_graph=alexnet_v2_inf_graph.pb--input_checkpoint=${TRAIN_DIR}/model.ckpt-36618--input_binary=true--output_graph=frozen_alexnet_v2.pb--output_node_names=alexnet_v2/fc8/freezeGraph数据量化的压缩PB文件处理,减小模型文件的大小,生成的quantized_alexnet_v2_graph.pb为智能终端中应用的模型文件bazel-bin/tensorflow/tools/graph_transforms/transform_graph--in_graph=frozen_alexnet_v2.pb--outputs="alexnet_v2/fc8/squeezed"--out_graph=quantized_alexnet_v2_graph.pb--transforms='add_default_attributesstrip_unused_nodes(type=float,shape="1,224,224,3")remove_nodes(op=Identity,op=CheckNumerics)fold_constants(ignore_errors=true)fold_batch_normsfold_old_batch_normsquantize_weightsquantize_nodesstrip_unused_nodessort_by_execution_order'为了在智能终端上减小模型文件的大小,TensorFlow中常用的方法是对模型文件进行量化。本文对AlexNetCheckPoint文件进行Freeze和Quantize,文件大小变化如下图所示:量化操作的主要思想是将32位浮点数操作替换成一个等效的8位整数模型推理阶段的操作。被替换的运算包括:卷积运算、矩阵乘法、激活函数池化、池化运算等量化节点的输入输出为浮点数,但内部运算会转换为8位整数(范围为0~255)通过量化计算。浮点数与8位量化整数的对应关系示例如下图所示:量化Relu运算的基本思路如下图:基于Android搭建的TensorFlowAndroid应用开发环境系统使用TensorFlow模型做Inference依赖于两个文件libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar。下载TensorFlow源码后可以用bazel编译这两个文件,如下:下载TensorFlow源码gitclone–recurse-submoduleshttps://github.com/tensorflow/tensorflow.git下载并安装AndroidNDK下载并安装安装AndroidSDK在tensorflow/WORKSPACE配置android开发工具路径android_sdk_repository(name="androidsdk",api_level=23,build_tools_version="25.0.2",path="/opt/android",)android_ndk_repository(name="androidndk",path="/opt/android/android-ndk-r12b",api_level=14)编译libtensorflow_inference.sobazelbuild-copt//tensorflow/contrib/android:libtensorflow_inference.so--crosstool_top=//external:android/crosstool--host_crosstool_top=@bazel_tools//tools/cpp:toolchain--cpu=armeabi-v7a编译libandroid_tensorflow_inference_java.jarbazelbuild//tensorflow/contrib/android:android_tensorflow_inference_javaTensorFlow提供了Android开发的示例框架。下面是基于AlexNet模型的。修改源码,编译生成Android安装包:基于AlexNet模型,修改Inference的输入输出Tensor名称privatestaticfinalStringINPUT_NAME="input";privatestaticfinalStringOUTPUT_NAME="alexnet_v2/fc8/squeezed";将quantized_alexnet_v2_graph.pb和对应的labels.txt文件放在assets目录下,修改Android文件路径privatestaticfinalStringMODEL_FILE="file:///android_asset/quantized_alexnet_v2_graph.pb";privatestaticfinalStringLABEL_FILE="file:///android_asset/labels.txt";编译生成安装包bazelbuild-copt//tensorflow/examples/android:tensorflow_demo将tensorflow_demo.apk复制到手机上执行安装向日葵识别效果如下图所示:(点击放大图片)TensorFlow移动应用开发API在Android系统中执行TensorFlowInference操作,需要调用libandroid_tensorflow_inference_java.jar中的JNI接口,主要界面如下:构造TensorFlowInference对象,构造对象时会加载TensorFlow动态链接The库libtensorflow_inference.so添加到系统中;参数assetManager是android资源管理器;参数modelFilename为TensorFlow模型文件在android_asset中的路径TensorFlowInferenceInterfaceinferenceInterface=newTensorFlowInferenceInterface(assetManager,modelFilename);将输入数据加载到TensorFlow图中。本App的输入数据为摄像头拍摄的图像;参数inputName为TensorFlowInference中输入数据Tensor的名称;参数floatValues为输入图像的像素数据,预处理后的浮点值;[1,inputSize,inputSize,3]为裁剪后的图像尺寸,比如224*224*3的RGB图像。inferenceInterface.feed(inputName,floatValues,1,inputSize,inputSize,3);执行模型推理;outputNames是TensorFlowInference模型中要计算的Tensor的名称,是本APP中分类的Logist值。inferenceInterface.run(输出名称);获取模型Inference的运算结果,其中outputName为Tensor的名称,参数outputs存放的是Tensor的运算结果。在这个APP中,outputs就是计算出来的Logist浮点数组。inferenceInterface.fetch(输出名称,输出);总结本文基于看花知命APP讲解TensorFlow在安卓智能终端的应用技术。一、回顾AlexNet模型结构,基于AlexNet的slim模型训练Flowers数据;对训练好的CheckPoint数据进行Freeze和Quantized处理,生成智能终端的Inference模型。然后介绍了TensorFlowAndroid应用开发环境的搭建,在Android上编译生成TensorFlow的动态链接库和java开发包;文章最后介绍了InferenceAPI的使用。参考资料http://www.tensorflow.org深度学习工具:分布式TensorFlow与实例分析深度学习工具:TensorFlow使用实战深度学习工具:TensorFlow系统架构与高性能编程网络深度学习工具:TensorFlow与NLP模型