当前位置: 首页 > 科技观察

深度学习框架Flash是如何仅用几行代码就构建出图像分类器的?

时间:2023-03-17 22:50:01 科技观察

【.com快译】1.简介图像分类是我们要预测图像属于哪个类别的任务。由于图像表示,这项任务很困难。如果我们展平图像,它会创建一个长的一维向量。此外,这种表示会丢失邻居信息。因此,我们需要深度学习来提取特征和预测结果。有时构建深度学习模型可能是一项艰巨的任务。虽然我们创建了一个图像分类的基本模型,但创建代码还是花了很多时间。我们必须准备代码来准备数据、训练模型、测试模型并将模型部署到服务器。这就是Flash的用武之地!Flash是一种高级深度学习框架,用于快速构建、训练和测试深度学习模型。Flash基于PyTorch框架。所以如果你了解PyTorch,就会熟悉Flash。与PyTorch和Lighting相比,Flash易于使用,但不如以前的库灵活。如果你想构建更复杂的模型,你可以使用Lightning或直接使用PyTorch。使用Flash,只需几行代码即可构建深度学习模型!所以,如果您是深度学习的新手,请不要被吓倒。Flash可以帮助您构建深度学习模型,而不会与代码混淆。本文介绍如何使用Flash构建图像分类器。2.实施安装库要安装库,您可以使用pip命令,如下所示:pipinstalllightning-flash如果该命令不起作用,您可以使用其GitHub存储库安装库。命令如下所示:pipinstallgit+https://github.com/PyTorchLightning/lightning-flash.git在我们成功下载包之后,现在可以加载库了。我们还设置了种子编号42。这是执行此操作的代码:frompytorch_lightningimportseed_everythingimportflashfromflash.core.classificationimportLabelsfromflash.core.data.utilsimportdownload_datafromflash.imageimportImageClassificationData,ImageClassifier#settherandomseeds.seed_everything(42)Globalseedsetto4242DownloadData现在您已经安装了库,是时候获取数据了。出于演示目的,我们将使用名为Cat和Dog数据集的数据集。该数据集包含两类:猫和狗的图像。要访问数据集,您可以在Kaggle找到它。可以在此处访问数据集。加载数据下载数据后,让我们将数据集加载到一个对象中。我们将使用from_folders方法将数据放入ImageClassification对象中。这是执行此操作的代码:datamodule=ImageClassificationData.from_folders(train_folder="cat_and_dog/training_set",val_folder="cat_and_dog/validation_set",)加载模型加载数据后,下一步就是加载模型。由于我们不会从头开始构建自己的架构,因此我们将使用基于现有卷积神经网络架构的预训练模型。我们将使用预训练的ResNet-50模型。此外,我们根据数据集设置类别数。这是执行此操作的代码:model=ImageClassifier(backbone="resnet50",num_classes=datamodule.num_classes)训练模型加载模型后,就可以训练模型了。我们需要先初始化Trainer对象。我们将用3个epoch训练模型。此外,我们启用GPU来训练模型。这是执行此操作的代码:trainer=flash.Trainer(max_epochs=3,gpus=1)GPUavailable:True,used:TrueTPUavailable:False,using:0TPUcores初始化对象后,让我们训练模型。为了训练模型,我们可以使用一个叫做finetune的函数。在函数内部,我们设置了模型和数据。此外,我们将训练策略设置为冻结,这表明我们不想训练特征提取器。换句话说,我们只训练分类器部分。这是执行此操作的代码:trainer.finetune(model,datamodule=datamodule,strategy="freeze")LOCAL_RANK:0-CUDA_VISIBLE_DEVICES:[0]|Name|Type|Params------------------------------------------0|metrics|ModuleDict|01|backbone|Sequential|23.5M2|head|Sequential|4.1K------------------------------------------57.2KTrainableparams23.5MNon-trainableparams23。5MTotalparams94.049Totalestimatedmodelparamssize(MB)Validationsanitycheck:0it[00:00,?it/s]Globalseedsetto42Training:0it[00:00,?it/s]验证:0it[00:00,?it/s]验证:0it[00:00,?it/s]Validating:0it[00:00,?it/s]这是评估结果:从结果可以看出,我们的模型准确率在97%左右。不错!现在让我们在一些新数据上测试模型。为了测试模型,我们将使用未训练模型的样本数据。这是我们将测试的模型示例:importmatplotlib.pyplotaspltfromPILimportImagefig,ax=plt.subplots(1,5,figsize=(40,8))foriinrange(5):ax[i].imshow(Image.open(f'cat_and_dog/testing/{i+1}.jpg'))plt.show()来测试模型,我们可以使用flash库中的predict方法。这是执行此操作的代码:model.serializer=Labels()predictions=model.predict(["cat_and_dog/testing/1.jpg","cat_and_dog/testing/2.jpg","cat_and_dog/testing/3.jpg","cat_and_dog/testing/4.jpg","cat_and_dog/testing/5.jpg"])print(predictions)['dogs','dogs','cats','cats','dogs']从上面从结果可以看出,模型预测了带有正确标签的样本。伟大的!让我们保存模型以备后用。保存模型我们已经训练和测试了模型。不妨使用save_checkpoint方法保存模型。下面是执行此操作的代码:trainer.save_checkpoint("cat_dog_classifier.pt")如果要针对其他代码加载模型,可以使用load_from_checkpoint方法。这是执行此操作的代码:model=ImageClassifier.load_from_checkpoint("cat_dog_classifier.pt")3.干得好!您学习了如何使用Flash构建图像分类器。正如文章开头所述,只需要几行代码!是不是很酷?希望本文能帮助您针对您的情况构建自己的深度学习模型。如果你想实现更复杂的模型,希望你可以开始学习PyTorch。原标题:HowtoBuildanImageClassifierinFewLinesofCodewithFlash,作者:IrfanAlghaniKhalid