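1. Data preparation
Basic steps: create a Dataset (or subclass) object -> pass it to a DataLoader (an iterable object that can be traversed with a for loop)
1.1 The Dataset class
Dataset is an abstract class.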
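- Note
  - Data read straight from the raw storage attribute of a Dataset (train_data below, renamed data in newer torchvision releases) has not gone through transform. The transform runs inside __getitem__, so it only takes effect when samples are fetched by indexing (training_data[0]) or loaded through a DataLoader, which calls __getitem__ internally.

import torchvision

training_data = torchvision.datasets.MNIST(
    root="./mnist",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
# Pixel values are still in the range 0-255, not 0-1
print(training_data.train_data[0])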
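The behaviour described in the note can be checked without downloading MNIST. The following is a minimal sketch, assuming a hypothetical ScaledDataset class (not part of PyTorch or torchvision) that keeps raw values in the 0-255 range and applies its transform only when a sample is accessed:

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ScaledDataset(Dataset):
    """Hypothetical dataset: raw values are stored as-is, transform runs on access."""

    def __init__(self, raw, transform=None):
        self.raw = raw              # raw storage, never transformed
        self.transform = transform

    def __getitem__(self, index):
        sample = self.raw[index]
        if self.transform is not None:
            sample = self.transform(sample)   # applied per sample, on access
        return sample

    def __len__(self):
        return self.raw.size(0)


raw = torch.tensor([0.0, 51.0, 255.0])
dataset = ScaledDataset(raw, transform=lambda v: v / 255.0)

raw_max = dataset.raw.max().item()   # read from raw storage: still 0-255
indexed_max = dataset[2].item()      # read via __getitem__: scaled to 0-1
batch = next(iter(DataLoader(dataset, batch_size=3)))  # DataLoader calls __getitem__
```

Reading dataset.raw skips the transform, while dataset[i] and the DataLoader both go through __getitem__ and therefore return scaled values.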
1.1.1 Dataset subclass: TensorDataset
- Reading the source
from torch.utils.data import Dataset


class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.

    Example code:
        x = torch.linspace(1, 10, 10)
        y = torch.linspace(10, 1, 10)
        dataset = TensorDataset(x, y)
    """

    def __init__(self, *tensors):
        """
        &1 tensors[0] is x; tensors[1] is y. x and y must have the same
        batch size (size of the first dimension), hence the assert.
        TensorDataset(x, y, z...) accepts any number of tensors.
        """
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        """
        &2 Equivalent to overloading the [] operator.
        """
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
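The three methods in the source can be exercised directly. The sketch below uses the TensorDataset shipped with torch.utils.data and the same x and y as the docstring; the variable names n_samples, first and mismatch_rejected are only illustrative:

```python
import torch
from torch.utils.data import TensorDataset

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
dataset = TensorDataset(x, y)

n_samples = len(dataset)     # __len__: size of the first dimension
first = dataset[0]           # __getitem__: a tuple (x[0], y[0])

# Tensors whose first dimensions differ trip the assert in __init__.
try:
    TensorDataset(torch.zeros(10), torch.zeros(9))
    mismatch_rejected = False
except AssertionError:
    mismatch_rejected = True
```

Indexing returns one element per wrapped tensor, which is why a DataLoader built on this dataset yields (x_batch, y_batch) pairs.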
- Example code
import torch
import torch.utils.data as Data

if __name__ == "__main__":
    x = torch.linspace(1, 10, 10)
    y = torch.linspace(10, 1, 10)
    dataset = Data.TensorDataset(x, y)  # &1
    # If the last step has fewer than 5 samples left (say only 2), only those 2 are returned
    # shuffle: set to True during training to shuffle the dataset
    # num_workers: number of worker subprocesses used for loading
    dataloader = Data.DataLoader(dataset=dataset, batch_size=5, shuffle=True, num_workers=2)
    for epoch in range(3):
        for step, input_data in enumerate(dataloader):
            print(f"{epoch}-{step}:\n{input_data}")
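With 10 samples and batch_size=5 the example never actually produces a short final batch. The sketch below, which assumes 12 samples chosen only for illustration, makes the partial batch visible and contrasts it with the DataLoader option drop_last=True. It uses num_workers=0 and shuffle=False so that it runs in a single process with a fixed order:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

x = torch.arange(12, dtype=torch.float32)
y = x * 2
dataset = TensorDataset(x, y)

# 12 samples with batch_size=5 -> batches of 5, 5 and a final partial batch of 2.
loader = DataLoader(dataset, batch_size=5, shuffle=False, num_workers=0)
batch_sizes = [xb.size(0) for xb, yb in loader]

# drop_last=True discards the final partial batch instead of returning it.
loader_drop = DataLoader(dataset, batch_size=5, shuffle=False, drop_last=True)
batch_sizes_drop = [xb.size(0) for xb, yb in loader_drop]
```

Dropping the last batch can be useful when later code assumes a fixed batch size, at the cost of skipping a few samples per epoch.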
2. Building the network
2.1 Class style
2.2 Sequential style
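The notes give no code for this style. As a rough sketch, and assuming the same 2 -> 10 -> 2 layer sizes as the Sequential example in 2.2, a class-style network subclasses torch.nn.Module, declares its layers in __init__ and wires them together in forward. The class name Net and the attribute names hidden and out are hypothetical:

```python
import torch


class Net(torch.nn.Module):
    """Hypothetical class-style network: 2 inputs -> 10 hidden units -> 2 outputs."""

    def __init__(self):
        super().__init__()
        self.hidden = torch.nn.Linear(2, 10)
        self.out = torch.nn.Linear(10, 2)

    def forward(self, x):
        x = torch.relu(self.hidden(x))   # hidden layer followed by ReLU
        return self.out(x)               # raw scores, no activation on the output


net = Net()
output = net(torch.randn(4, 2))   # batch of 4 samples, 2 features each
```

Writing forward by hand takes more lines than Sequential but leaves room for branching, skip connections or multiple inputs.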
net = torch.nn.Sequential(
torch.nn.Linear(2, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 2)
)
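A short usage sketch for the network above, assuming a random input batch of 4 samples with 2 features each. The names x, output, first_layer and n_params are only illustrative:

```python
import torch

net = torch.nn.Sequential(
    torch.nn.Linear(2, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 2),
)

x = torch.randn(4, 2)      # batch of 4 samples, 2 features each
output = net(x)            # layers are applied in the order they were listed
first_layer = net[0]       # layers are indexable by position
n_params = sum(p.numel() for p in net.parameters())
```

Sequential needs no forward method because the data simply flows through the layers in order, which suits plain feed-forward stacks like this one.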