简介Tensorflow在更新1.0版本后增加了很多新功能,发布了很多用tf框架编写的深度网络结构(https://github.com/tensorflow/models),大大降低了开发难度,使用现成的网络结构,无论是fine-tuning还是retraining都方便很多。最近笔者终于跑通了TensorFlowObjectDetectionAPI的ssd_mobilenet_v1模型。这里记录下如何完成从数据准备到模型使用的全过程。相信对自己和一些同学都有帮助。ObjectDetectionAPI提供了5种网络结构的预训练权重,所有这些都使用COCO数据集进行训练。这五个模型分别是SSD+mobilenet、SSD+inception_v2、R-FCN+resnet101、fasterRCNN+resnet101、fasterRCNN+inception+resnet101。每个模型所需的精度和计算时间如下。下面介绍如何使用ObjectDetection来训练自己的模型。TensorFlow的安装这里就不多说了。网上有很多教程,可以找到非常详细的安装TensorFlow的文档。训练前的准备:使用protobuf配置模型和训练参数,所以必须编译API才能正常使用。protobuf库可以在这里(https://github.com/google/protobuf/releases)下载,解压压缩包后,在环境变量中添加protoc:$cdtensorflow/models$protocobject_detection/protos/*.proto--python_out=。(我在环境变量中添加了protoc,遇到找不到*.proto文件的错误,然后将protoc.exe放到models/object_detection目录下,再次执行)然后添加models和slim(tfadvancedframework)到python环境变量:PYTHONPATH=$PYTHONPATH:/your/path/to/tensorflow/models:/your/path/to/tensorflow/models/slim数据准备:数据集需要转成PASCALVOC结构。API提供create_pascal_tf_record.py将VOC结构数据集转换为.record格式。但是我们找到了一种更简单的方法,Datitran提供了一种更简单的方法来生成.record格式。首先,你需要标记图片对应的label,这里可以使用labelImg工具。每次标记样本时,都会生成一个xml标记文件。然后,将这些标注好的xml文件按照训练集和验证集分别放到两个目录下,在Datitran中提供xml_to_csv.py脚本。这里只需要指定标签的目录名即可。接下来,我们需要将对应的csv格式转换成.record格式。defmain():#image_path=os.path.join(os.getcwd(),'annotations')image_path=r'D:\training-sets\object-detection\sunglasses\label\test'xml_df=xml_to_csv(image_path)xml_df.to_csv('sunglasses_test_labels.csv',index=None)print('Successfullyconvertedxmltocsv.')调用generate_tfrecord.py,注意指定两个参数-csv_input和-output_path。执行以下命令:pythongenerate_tfrecord.py--csv_input=sunglasses_test_labels.csv--output_path=sunglass_test.record这将生成用于训练和验证的train.record和test.record。接下来指定label名称,按照models/object_detection/data/pet_label_map.pbtxt,重新创建一个文件,指定label名称。item{id:1name:'sunglasses'}训练:根据自己的需要,选择用coco数据集预训练好的模型,将前缀model.ckpt放在要训练的目录下,meta文件保存graph和元数据。ckpt保存了网络的权重,这些文件代表了预训练模型的初始状态。打开ssd_mobilenet_v1_pets.config文件,修改如下:num_classes:修改为自己的classesnum,PATH_TO_BE_CONFIGURED全部修改为之前设置的路径(共5处),其他参数保持默认。准备好以上文件后,就可以直接调用train文件进行训练了。pythonobject_detection/train.py\--logtostderr\--pipeline_config_path=D:/training-sets/data-translate/training/ssd_mobilenet_v1_pets.config\--train_dir=D:/training-sets/data-translate/trainingTensorBoard监控:通过tensorboard工具可以监控训练过程。输入west命令后,在浏览器中输入localhost:6006(默认)。tensorboard--logdir=D:/training-sets/data-translate/training有很多指标曲线甚至模型网络架构。这里很多指标的含义我都没有弄明白,但是感觉TensorBoard这个工具应该是非常强大的。但是,我们可以通过Total_Loss来看整体的训练情况。整体来看,loss曲线确实是收敛的,整体训练效果还是比较满意的。另外TensorFlow也提供了在训练过程中使用验证集来验证准确率的能力,但是笔者在调用的时候还是遇到了一些问题,这里就不详细说明了。FreezeModel模型导出:在查看模型实际效果之前,我们需要将训练过程文件导出,生成.pb模型文件。本来tensorflow/python/tools/freeze_graph.py提供了freezemodelapi,但是需要提供输出的最终节点名(一般以softmax等最后一层的激活函数命名),以及objectdetectionapiprovided经过预训练的网络,最终的节点名不好找,所以在object_detection目录下也提供了export_inference_graph.py。pythonexport_inference_graph.py\--input_typeimage_tensor--pipeline_config_pathD:/training-sets/data-translate/training/ssd_mobilenet_v1_pets.config\--trained_checkpoint_prefixD:/training-sets/data-translate/training/ssd_mobilenet_v1_pets.config/model.ckpt-*\--output_directoryD:/training-sets/data-translate/training/result导出完成后,在output_directory下,会生成frozen_inference_graph.pb、model.ckpt.data-00000-of-00001、model.ckpt.meta、model.ckpt.data文件。调用生成模型:目录下本身有一个调用的例子,稍微改造如下:importcv2importnumpyasnpimporttensorflowastffromobject_detection.utilsimportlabel_map_utilfromobject_detection.utilsimportvisualization_utilsasvis_utilclassTOD(object):def__init__(self):self.PATH_TO_CKPT=r'D:\lib\tf-model\models-master\object_detection\training\frozen_inference_graph.pb'self.PATH_TO_LABELS=r'D:\lib\tf-model\models-master\object_detection\training\sunglasses_label_map.pbtxt'self.NUM_CLASSES=1self.detection_graph=self._load_model()self.category_index=self._load_label_map()def_load_model(self):detection_graph=tf.Graph()withdetection_graph.as_default():od_graph_def=tf.GraphDef()withtf.gfile.GFile(self.PATH_TO_CKPT,'rb')asfid:serialized_graph=fid.read()od_graph_def.ParseFromString(serialized_graph)tf.import_graph_def(od_graph_def,name='')returndetection_graphdef_load_label_map(self):label_map=label_map_util.load_labelmap(self.PATH_TO_LABELS)类别=label_map_util.convert_label_map_to_categories(label_map,max_num_classes=self.NUM_CLASSES,use_display_name=True)category_index=label_map_util.create_category_index(类别)returncategory_indexdefdetect(self,image):withself.detection_graph.as_default():withtf.Session(graph=self.detection_graph)评估:#Expanddimensionssincethemodelexpectsimagestohaveshape:[1,无,无,3]image_np_expanded=np.expand_dims(image,axis=0)image_tensor=self.detection_graph.get_tensor_by_name('image_tensor:0')boxes=self.detection_graph.get_tensor_by_name('detection_boxes:0')scores=self.detection_graph.get_tensor_by_name('detection_scores:0')classes=self.detection_graph.get_tensor_by_name('detection_classes:0')num_detections=self.detection_graph.get_tensor_by_name('num_detections:0')#Actualdetection.(boxes,scores,classes,num_detections)=sess.run([boxes,scores,classes,num_detections],feed_dict={image_tensor:image_np_expanded})#Visualizationoftheresultsofadetection.vis_util.visualize_boxes_and_labels_on_image_array(image,np.sque)eze(boxes),np.squeeze(classes).astype(np.int32),np.squeeze(scores),self.category_index,use_normalized_coordinates=True,line_thickness=8)cv2.namedWindow(“检测”,cv2.WINDOW_NORMAL)cv2.imshow("detection",image)cv2.waitKey(0)if__name__=='__main__':image=cv2.imread('image.jpg')detecotr=TOD()detecotr.detect(image)下面是一些图片识别效果:结束。
