当前位置: 首页 > 科技观察

在PyTorch中使用数据集和DataLoader自定义数据

时间:2023-03-19 21:16:43 科技观察

有时,在处理大型数据集时,一次将整个数据加载到内存中变得非常困难。因此,唯一的办法就是将数据批量加载到内存中进行处理,这需要额外编写代码来完成。为此,PyTorch已经提供了Dataloader功能。PyTorch库中DataLoader函数的语法及其参数信息如下所示DataLoader。DataLoader(数据集,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collat??e_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,*,prefetch_factor=2,persistent_workers=False)几个重要的参数dataset:首先要使用dataset来构造DataLoader类。Shuffle:是否重新排列数据。Sampler:指的是一个可选的torch.utils.data.Sampler类实例。采样器定义了一种策略,用于按顺序或随机或任何其他方式检索样本。使用采样器时应将Shuffle设置为false。Batch_Sampler:批处理级别。num_workers:加载数据所需的子进程数。collat??e_fn:将样本整理成批次。在Torch中可以进行自定义整理。加载内置的MNIST数据集MNIST是一个众所周知的包含手写数字的数据集。下面介绍如何使用DataLoader功能来处理PyTorch的内置MNIST数据集。importtorchiimportmatplotlib.pyplotaspltfromtorchvisionimportdatasets,transforms以上代码导入了torchvision的torch计算机视觉模块。通常在处理图像数据集时使用,可以帮助规范化、调整大小和裁剪图像。对于MNIST数据集,下面使用了归一化技术。ToTensor()可以将灰度范围从0-255变换到0-1。transform=transforms.Compose([transforms.ToTensor()])下面的代码用于加载所需的数据集。使用PyTorchDataLoader通过给定batch_size=64来加载数据。shuffle=True随机播放数据。trainset=datasets.MNIST('~/.pytorch/MNIST_data/',download=True,train=True,transform=transform)trainloader=torch.utils.data.DataLoader(trainset,batch_size=64,shuffle=True)得到数据集的所有图像,一般使用iter函数和数据加载器DataLoader。dataiter=iter(trainloader)图像,labels=dataiter.next()print(images.shape)print(labels.shape)plt.imshow(images[1].numpy().squeeze(),cmap='Greys_r')自定义数据集下面的代码创建了一个包含1000个随机数的自定义数据集。fromtorch.utils.dataimportDatasetimportrandomclassSampleDataset(数据集):def__init__(self,r1,r2):randomlist=[]foriinrange(120):n=random.randint(r1,r2)randomlist.append(n)self.samples=randomlistdef__len__(self):returnlen(self.samples)def__getitem__(self,idx):return(self.samples[idx])dataset=SampleDataset(1,100)dataset[100:120]这里插入图片描述最后会在自定义数据集上使用数据加载器功能。将batch_size设置为12,并使用num_workers=2启用并行多进程数据加载。fromtorch.utils.dataimportDataLoaderloader=DataLoader(dataset,batch_size=12,shuffle=True,num_workers=2)fori,batchinenumerate(loader):print(i,batch)是后面写的通过几个例子,了解到PyTorchDataloader是将一个大量数据批量加载到内存中的效果。