(一) Data Augmentation (Image Augmentation)
(1) Why do data augmentation?
One reason is that your dataset may simply be small, so simple transformations of the existing images are used to enlarge it. The second comes from a real-world case: a company building smart vending machines tuned and trained its model inside the office, but when the machine was taken to a showroom for testing, the previously very accurate model suddenly failed to recognize anything. The cause was the different lighting in the showroom, which made the test data differ substantially from the training data. Applying data augmentation appropriately during product development therefore makes the model more robust.
(2) What data augmentation methods are there?
Augmentation works by adding various kinds of background noise to an image or by changing the image's color and shape. Common methods (each demonstrated with torchvision.transforms in the code section below) include:
- Flipping: horizontal (left-right) and vertical (upside-down) flips
- Cropping: cut a patch out of the image, then resize it to a fixed shape
- Color: change the hue, saturation, brightness, etc.
(二) Code Implementation
%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import matplotlib.image as mpimg  # aliased as mpimg so it is not shadowed by the `img` variable below
import matplotlib.pyplot as plt
# An alternative way to open and display an image with matplotlib
# image = mpimg.imread('../img/cat1.jpg')
# plt.title("cat.jpg")
# plt.axis("off")
# plt.imshow(image)
# plt.show()
d2l.set_figsize()
img = d2l.Image.open('../img/cat1.jpg')
d2l.plt.imshow(img)
# Arguments: (image, augmentation to apply, number of rows, number of columns, scale factor)
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
y = [aug(img) for _ in range(num_rows*num_cols)]
d2l.show_images(y, num_rows, num_cols, scale=scale)
# Flip the image left-right (horizontal flip)
apply(img, torchvision.transforms.RandomHorizontalFlip())
# Flip the image upside-down (vertical flip)
apply(img, torchvision.transforms.RandomVerticalFlip())
# Random resized crop
shape_aug = torchvision.transforms.RandomResizedCrop(
    # (output size, range of area fraction to keep, range of aspect ratios)
    size=(200, 200), scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)
# Randomly change the brightness, contrast, saturation and hue of the image
apply(img, torchvision.transforms.ColorJitter(
    # (brightness range, contrast range, saturation range, hue range)
    brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5
))
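# A minimal variation (not in the original notes): jitter only the brightness
# by setting the other three factors to 0
apply(img, torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0, saturation=0, hue=0))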
# Combine multiple augmentation methods with Compose (the transforms are applied in order)
augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    shape_aug,
    torchvision.transforms.ColorJitter(
        brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
])
apply(img, augs)
# If the automatic download fails, fetch the archive manually from:
# https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
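# Load the CIFAR-10 training set (no transforms yet) and preview the first 32 images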
all_images = torchvision.datasets.CIFAR10(
train=True,
root="../data",
download=False
)
d2l.show_images([all_images[i][0] for i in range(32)],4,8,scale=0.8)
# d2l.show_images([all_images.data[i] for i in range(32)],4,8,scale=0.8)
train_augs = torchvision.transforms.Compose([
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor()
])
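# At test time only convert to tensors: random augmentation is applied during
# training only, so evaluation stays deterministic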
test_augs = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
def load_cifar10(is_train,augs, batch_size):
dataset = torchvision.datasets.CIFAR10(
train=is_train,
root="../data",
download=False,
transform=augs,
)
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0
)
return data_loader
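# Optional sanity check (a sketch, not in the original notes): fetch one batch
# and confirm the shapes for CIFAR-10's 3x32x32 RGB images
# train_iter = load_cifar10(True, train_augs, 256)
# X, y = next(iter(train_iter))
# print(X.shape, y.shape)  # torch.Size([256, 3, 32, 32]) torch.Size([256])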
def train_batch_ch13(net, X, y, loss, trainer, devices):
"""用多GPU進(jìn)行小批量訓(xùn)練"""
if isinstance(X, list):
        # Required when fine-tuning BERT (discussed later)
X = [x.to(devices[0]) for x in X]
else:
X = X.to(devices[0])
y = y.to(devices[0])
net.train()
trainer.zero_grad()
pred = net(X)
l = loss(pred, y)
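    # the loss is built with reduction="none", so l is a per-example vector; sum it before backward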
l.sum().backward()
trainer.step()
train_loss_sum = l.sum()
train_acc_sum = d2l.accuracy(pred, y)
return train_loss_sum, train_acc_sum
#@save
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
devices=d2l.try_all_gpus()):
"""用多GPU進(jìn)行模型訓(xùn)練"""
timer, num_batches = d2l.Timer(), len(train_iter)
animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
legend=['train loss', 'train acc', 'test acc'])
net = nn.DataParallel(net, device_ids=devices).to(devices[0])
for epoch in range(num_epochs):
        # Accumulator with 4 slots: sum of training loss, sum of training accuracy, no. of examples, no. of predictions
metric = d2l.Accumulator(4)
for i, (features, labels) in enumerate(train_iter):
timer.start()
l, acc = train_batch_ch13(
net, features, labels, loss, trainer, devices)
metric.add(l, acc, labels.shape[0], labels.numel())
timer.stop()
if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
animator.add(epoch + (i + 1) / num_batches,
(metric[0] / metric[2], metric[1] / metric[3],
None))
test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
animator.add(epoch + 1, (None, None, test_acc))
print(f'loss {metric[0] / metric[2]:.3f}, train acc '
f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
f'{str(devices)}')
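# ResNet-18 from the d2l library: 10 output classes, 3 input channels (CIFAR-10 RGB images)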
batch_size, devices, net = 256, d2l.try_all_gpus(), d2l.resnet18(10, 3)
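# Initialize fully connected and convolutional layers with Xavier uniform initialization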
def init_weights(m):
if type(m) in [nn.Linear, nn.Conv2d]:
nn.init.xavier_uniform_(m.weight)
net.apply(init_weights)
def train_with_data_aug(train_augs, test_augs, net, lr=0.001):
train_iter = load_cifar10(True, train_augs, batch_size)
test_iter = load_cifar10(False, test_augs, batch_size)
loss = nn.CrossEntropyLoss(reduction="none")
trainer = torch.optim.Adam(net.parameters(), lr=lr)
train_ch13(net, train_iter, test_iter, loss, trainer, 10, devices)
train_with_data_aug(train_augs, test_augs, net)
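# After training, a minimal prediction sketch (an assumption, not part of the
# original notes): classify one batch of test images with the trained net
# net.eval()
# X, y = next(iter(load_cifar10(False, test_augs, batch_size)))
# with torch.no_grad():
#     preds = net(X.to(devices[0])).argmax(dim=1)
# print(preds[:10].cpu(), y[:10])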