当前位置: 首页 > 科技观察

PyTorch中的数据集Torchvision和Torchtext

时间:2023-03-15 22:19:31 科技观察

PyTorch为了加载和处理不同类型的数据,官方提供了torchvision和torchtext。之前使用torchDataLoader类直接加载图片并转化为tensor。现在结合torchvision和torchtext来介绍torch中内置的数据集。TorchvisionMNIST中的数据集MNIST是由归一化和中心裁剪的手写图像组成的数据集。它有超过60,000张训练图像和10,000张测试图像。这是用于学习和实验目的的最常用的数据集之一。要加载和使用数据集,请使用以下语法导入:torchvision.datasets.MNIST()。FashionMNISTFashionMNIST数据集与MNIST类似,但该数据集包含T恤、裤子、包等服装项目,而不是手写数字,训练和测试样本的数量分别为60,000和10,000。要加载和使用数据集,请使用以下语法导入:torchvision.datasets.FashionMNIST()CIFARCIFAR数据集有两个版本,CIFAR10和CIFAR100。CIFAR10由具有10个不同标签的图像组成,而CIFAR100有100个不同的类别。其中包括卡车、青蛙、船、汽车、鹿等常见图像。torchvision.datasets.CIFAR10()torchvision.datasets.CIFAR100()COCOCOCO数据集包含超过100,000个日常物体,例如人、瓶子、文具、书籍等。该图像数据集广泛用于对象检测和图像字幕应用。这里是可以加载COCO的地方:torchvision.datasets.CocoCaptions()EMNISTEMNIST数据集是MNIST数据集的高级版本。它由包含数字和字母的图像组成。如果您正在处理基于从图像中识别文本的问题,EMNIST是一个不错的选择。您可以在此处加载EMNIST::torchvision.datasets.EMNIST()IMAGE-NETImageNet是训练高端神经网络的旗舰数据集之一。它由分布在10,000个类别中的超过120万张图像组成。通常,此数据集加载在高端硬件系统上,因为单个CPU无法处理如此大的数据集。这是加载ImageNet数据集的类:torchvision.datasets.ImageNet()Torchtext中的数据集IMDBIMDB是一个用于情感分类的数据集,包含一组25,000条高度极端的电影评论用于训练,另外25,000条用于测试。使用以下类加载这些数据torchtext:torchtext.datasets.IMDB()WikiText2WikiText2语言建模数据集是超过1亿个标记的集合。它摘自维基百科并保留标点符号和实际字母大小写。它广泛用于涉及长期依赖关系的应用程序。这些数据可以从torchtext加载:torchtext.datasets.WikiText2()除了上面提到的两个流行的数据集之外,torchtext库中还有更多可用的数据集,例如SST、TREC、SNLI、MultiNLI、WikiText-2,WikiText103,PennTreebank,Multi30k,etc.深入了解MNIST数据集MNIST是最流行的数据集之一。现在我们将看到PyTorch如何从pytorch/vision存储库加载MNIST数据集。让我们首先下载数据集并将其加载到名为data_trainfromtorchvision.datasetsimportMNIST#DownloadMNISTdata_train=MNIST('~/mnist_data',train=True,download=True)importmatplotlib.pyplotaspltrandom_image=data_train[0][0]random_image_label=data_train[0]的变量中][1]#PrinttheImageusingMatplotlibplt.imshow(random_image)print("Thelabeloftheimageis:",random_image_label)DataLoader加载MNIST接下来,我们使用DataLoader类加载数据集,如下图。importtorchfromtorchvisionimporttransformsdata_train=torch.utils.data.DataLoader(MNIST('~/mnist_data',train=True,download=True,transform=transforms.Compose([transforms.ToTensor()])),batch_size=64,shuffle=True)forbatch_idx,samplesinnumerate(data_train):print(batch_idx,samples)CUDAload我们可以让GPU更快地训练我们的模型。现在让我们使用一个可以在使用CUDA加载数据时使用的配置(GPU支持PyTorch)。device="cuda"iftorch.cuda.is_available()else"cpu"kwargs={'num_workers':1,'pin_memory':True}ifdevice=='cuda'else{}train_loader=torch.utils.data.DataLoader(torchvision.datasets.MNIST('/files/',train=True,download=True),batch_size=batch_size_train,**kwargs)test_loader=torch.utils.data.DataLoader(torchvision.datasets.MNIST('files/',train=False,download=True),batch_size=batch_size,**kwargs)ImageFolderImageFolder是一个通用数据加载器类torchvision,有助于加载自己的图像数据集。以分类问题为例,构建一个神经网络来识别给定图像是苹果还是橙子。要在PyTorch中执行此操作,第一步是按以下默认文件夹结构排列图像:root├──orange├──orange_image1.png│└──orange_image1.png├──apple│└──apple_image1。png│└──apple_image2.png│└──apple_image3.png所有这些图像都可以使用ImageLoader类加载。torchvision.datasets.ImageFolder(root,transform)transformsPyTorch转换定义了简单的图像转换技术,可将整个数据集转换为一种独特的格式。如果它是一个包含不同分辨率的不同汽车图片的数据集,那么在训练时,我们训练数据集中的所有图像应该具有相同的分辨率大小。如果我们手动将所有图像转换为所需的输入尺寸会很耗时,因此我们可以使用变换;使用几行PyTorch代码,我们数据集中的所有图像都可以转换为所需的输入大小和分辨率。现在让我们加载CIFAR10torchvision.datasets并应用以下转换:将所有图像调整为32×32将中心裁剪转换应用于图像将裁剪图像转换为张量归一化图像32),#center-crop裁剪变换transforms.CenterCrop(32),#to-tensortransforms.ToTensor(),#normalizenormalizetransforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])trainset=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=False)创建一个PyTorch中的自定义数据集下面将创建一个由数字和文本组成的简单自定义数据集。需要封装Dataset类中的__getitem__()和__len__()方法。__getitem__()方法按索引返回数据集中选定的样本。__len__()方法返回数据集的总大小。以下是用于封装FruitImagesDataset数据集的代码,它基本上是在PyTorch中创建自定义数据集的更好模板。importosimportnumpyasnpimportcv2importtorchimportmatplotlib.patchesaspatchesimportalbumentationsasAfromalbumentations.pytorch.transformsimportToTensorV2frommatplotlibimportpyplotaspltfromtorch.utils.dataimportDatasetfromxml.etreeimportElementTreeasetfromtorchvisionimporttransformsastorchtransclassFruitImagesDataset(torch.utils.data.Dataset):def__init__(self,files_dir,width,height,transforms=None):self.transforms=transformsself.files_dir=files_dirself.height=heightself.width=widthself.imgs=[imageforimageinsorted(os.listdir(files_dir))ifimage[-4:]=='.jpg']self.classes=['_','apple','banana','orange']def__getitem__(self,idx):img_name=self.imgs[idx]image_path=os.path.join(self.files_dir,img_name)#readingtheimagesandconvertingthemtocorrectsizeandcolorimg=cv2.imread(image_path)img_rgb=cv2.cvtColor(img,cv2.COLOR_BGR2RGB).astype(np.float32)img_res=cv2.resize(img_rgb,(self.width,self.height),cv2.INTER_AREA)#divingby255img_res/=255.0#annotationfileannot_filename=img_name[:-4]+'.xml'annot_file_path=os.path.join(self.files_dir,annot_filename)boxes=[]labels=[]tree=et.parse(annot_file_path)root=tree.getroot()#cv2imagegivessizeaheightxwidthwt=img.shape[1]ht=img.shape[0]#boxcoordinatesforxmlfilesareextractedandcorrectedforimagesizegivenformemberinroot.findall('object'):labels.append(self.classes.index(member.find('name').text))#boundingboxxmin=int(member.find('bndbox').find('xmin').text)xmax=int(member.find('bndbox').find('xmax').text)ymin=int(member.查找('bndbox').find('ymin').text)ymax=int(member.find('bndbox').find('ymax').text)xmin_corr=(xmin/wt)*self.widthxmax_corr=(xmax/wt)*self.widthymin_corr=(ymin/ht)*self.heightymax_corr=(ymax/ht)*self.heightboxes.append([xmin_corr,ymin_corr,xmax_corr,ymax_corr])#convertboxesintoatorch.Tensorboxes=torch.as_tensor(boxes,dtype=torch.float32)#gettingtheareasoftheboxesarea=(boxes[:,3]-boxes[:,1])*(boxes[:,2]-boxes[:,0])#supposeallinstancesarenotcrowdiscrowd=torch.zeros((boxes.shape[0],),dtype=torch.int64)labels=torch.as_tensor(labels,dtype=torch.int64)target={}目标["boxes"]=boxestarget["labels"]=labelstarget["area"]=areatarget["iscrowd"]=iscrowd#image_idimage_id=torch.tensor([idx])target["image_id"]=image_idifself.transforms:sample=self.transforms(image=img_res,bboxes=target['boxes'],labels=labels)img_res=sample['image']target['boxes']=torch.Tensor(样本['bboxes'])returnimg_res,targetdef__len__(self):returnlen(self.imgs)defget_transform(火车):iftrain:returnA.Compose([A.Horizo??ntalFlip(0.5),ToTensorV2(p=1.0)],bbox_params={'format':'pascal_voc','label_fields':['labels']})else:returnA.Compose([ToTensorV2(p=1.0)],bbox_params={'format':'pascal_voc','label_fields':['labels']})files_dir='../input/fruit-images-for-object-检测/train_zip/train'test_dir='../input/fruit-images-for-object-detection/test_zip/test'dataset=FruitImagesDataset(train_dir,480,480)