Project repository: https://github.com/PyTorchLightning/pytorch-lightning
The notes below are organized from the project author's walkthrough video: Converting from PyTorch to PyTorch Lightning (hosted on YouTube; viewers in mainland China will need a VPN).
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.cross_entropy below
import torch.optim as optim
import pytorch_lightning as pl
class Net(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # define your layers/submodules here

    def forward(self, x):
        # forward can stay minimal; the per-step logic can live in training_step
        pass

    def loss_func(self, y_hat, y):
        return F.cross_entropy(y_hat, y)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)
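    # Not in the video: configure_optimizers can also return optimizer and
    # scheduler lists; a minimal sketch, assuming a StepLR schedule:
    #   optimizer = optim.Adam(self.parameters(), lr=1e-3)
    #   scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)
    #   return [optimizer], [scheduler]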
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # return {'loss': F.cross_entropy(y_hat, y)}
        loss = self.loss_func(y_hat, y)
        return {'loss': loss}
        ################################
        # log = {'train_loss': loss}
        # return {'loss': loss, 'log': log}
        # with the 'log' key, the train_loss curve shows up in TensorBoard

    def log_func(self):
        # do whatever you want: print, file operations, etc.
        pass
    def validation_step(self, batch, batch_idx):
        # !!! the val data should NOT be shuffled
        x, y = batch
        y_hat = self(x)
        val_loss = self.loss_func(y_hat, y)
        if batch_idx == 0:
            n = x.size(0)  # batch size of the first batch, e.g. for logging
            self.log_func()
        return {'val_loss': val_loss}
    ##############################################################
    ### with the dataloaders defined here, fit() no longer needs them passed in
    ################################
    def train_dataloader(self):
        # placeholder: pass your own Dataset here
        loader = torch.utils.data.DataLoader()
        return loader

    def val_dataloader(self):
        # placeholder: remember shuffle=False for validation
        loader = torch.utils.data.DataLoader()
        return loader
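    # Not in the video: a concrete, runnable version of the two loaders; a
    # minimal sketch assuming torchvision's MNIST stands in for the real dataset:
    #   from torchvision import datasets, transforms
    #   def train_dataloader(self):
    #       ds = datasets.MNIST('data', train=True, download=True,
    #                           transform=transforms.ToTensor())
    #       return torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True)
    #   def val_dataloader(self):
    #       ds = datasets.MNIST('data', train=False, download=True,
    #                           transform=transforms.ToTensor())
    #       return torch.utils.data.DataLoader(ds, batch_size=32, shuffle=False)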
    ################################
    # use TensorBoard or another logger here instead of the log_func part of validation_step
    ################################
    def validation_epoch_end(self, outputs):
        # average the loss over all batches; outputs is the list of dicts returned by validation_step
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        # other data can be passed through as well, e.g. images reconstructed by a VAE:
        # x_hat = outputs[0]['x_hat']
        # grid = torchvision.utils.make_grid(x_hat)
        # self.logger.experiment is the TensorBoard SummaryWriter:
        # self.logger.experiment.add_image('images', grid, 0)
        log = {'avg_val_loss': val_loss}
        return {'log': log}
    ################################
    # if the returned dict contains the key 'val_loss', model checkpointing is triggered automatically
    # return {'val_loss': val_loss}
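    ################################
    # Not in the video: checkpointing can also be configured explicitly; a sketch,
    # assuming the 0.x-era ModelCheckpoint/Trainer API used throughout this post:
    #   from pytorch_lightning.callbacks import ModelCheckpoint
    #   checkpoint_cb = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)
    #   trainer = pl.Trainer(checkpoint_callback=checkpoint_cb)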
if __name__ == '__main__':
    # alternatively, the dataloaders can live inside the module (see above)
    train_loader = torch.utils.data.DataLoader()  # placeholder: pass your own Dataset
    val_loader = torch.utils.data.DataLoader()    # shuffle=False
    net = Net()
    # fast_dev_run runs a single train batch and a single val batch
    # to verify that the whole pipeline works
    trainer = pl.Trainer(fast_dev_run=True)
    # for the full training run, a plain Trainer() is enough
    # train_percent_check=0.1 trains on only 10% of the data
    trainer.fit(net,
                train_dataloader=train_loader,
                val_dataloaders=val_loader)
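################################
# Not in the video: a named TensorBoard logger can be attached explicitly;
# a minimal sketch using pytorch_lightning.loggers.TensorBoardLogger
# (the save_dir 'tb_logs' and name 'net' are illustrative):
#   from pytorch_lightning.loggers import TensorBoardLogger
#   logger = TensorBoardLogger('tb_logs', name='net')
#   trainer = pl.Trainer(logger=logger, fast_dev_run=True)
#   trainer.fit(net)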
################################
# using argparse
from argparse import ArgumentParser

parser = ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument('--batch_size', default=32, type=int, help='batch size')
parser.add_argument('--learning_rate', default=1e-3, type=float)
args = parser.parse_args()

net = Net()
trainer = pl.Trainer.from_argparse_args(args, fast_dev_run=True)
trainer.fit(net)
################################
# single-GPU training
# terminal: python main.py --gpus 1 --batch_size 256
# multi-GPU training:
# defaults to DP (DataParallel); DDP (DistributedDataParallel) is usually the better choice
# terminal: python main.py --gpus 2 --distributed_backend ddp
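################################
# the same in code rather than on the command line; a sketch assuming the
# 0.x-era Trainer flags shown above:
#   trainer = pl.Trainer(gpus=2, distributed_backend='ddp')
#   trainer.fit(net)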
################################
# 16-bit training: native in PyTorch 1.6; earlier versions need NVIDIA apex
# it may require small code changes, e.g. switching the loss
# from F.binary_cross_entropy to
# F.binary_cross_entropy_with_logits(y_hat, y, reduction='sum')
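################################
# a sketch of turning 16-bit training on, assuming the Trainer's precision flag:
#   trainer = pl.Trainer(gpus=1, precision=16)
#   trainer.fit(net)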