當(dāng)我們想要使用torchvision中自帶的數(shù)據(jù)集時瑰谜,應(yīng)該怎么做呢羔巢?
1 導(dǎo)包
import torchvision
2 下載
train_set = torchvision.datasets.CIFAR10(root="../dataset",train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="../dataset",train=False,download=True)
參數(shù)解釋:
root:數(shù)據(jù)集要存放的地址
train:值為True時下載訓(xùn)練集蚜点,值為False時下載測試集
download:一般均設(shè)置為True
3 使用
print(test_set[0])#打印test數(shù)據(jù)集的第一張圖片的所有信息
print(test_set.classes)#打印數(shù)據(jù)集所有類別信息
img ,target = test_set[0]
print(img)#打印test數(shù)據(jù)集第一張圖片的圖片信息
print(target)#打印test數(shù)據(jù)集第一張圖片的類別信息
img.show()#顯示圖片(因為這張圖片的格式為PIL躲舌,故可以直接.show())
使用結(jié)果
—————————————————————————————————————————
4 擴展
和上節(jié)課的tensorboard震肮、transforms等知識相結(jié)合后的代碼及結(jié)果如下:
完整代碼:
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="../dataset",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="../dataset",train=False,transform=dataset_transform,download=True)
# print(test_set[0])
# print(test_set.classes)
#
# img ,target = test_set[0]
# print(img)
# print(target)
# img.show()
print(test_set[0])
writer = SummaryWriter("../logs/P10_logs")
for i in range(10):
img , target = test_set[i]
writer.add_image("test_set",img ,i)
writer.close()
結(jié)果:
torchvision