文章作者:Tyan
博客:noahsnail.com ?|? CSDN ?|? 簡書
本文主要是關(guān)于PyTorch的一些用法伪很。
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
# 許多沒解釋的東西可以去查文檔, 文檔中都有, 已查過
# pytorch文檔: http://pytorch.org/docs/master/index.html
# matplotlib文檔: https://matplotlib.org/
# 隨機算法的生成種子
torch.manual_seed(1)
# 生成數(shù)據(jù)
n_data = torch.ones(100, 2)
# 類別一的數(shù)據(jù)
x0 = torch.normal(2 * n_data, 1)
# 類別一的標(biāo)簽
y0 = torch.zeros(100)
# 類別二的數(shù)據(jù)
x1 = torch.normal(-2 * n_data, 1)
# 類別二的標(biāo)簽
y1 = torch.ones(100)
# x0, x1連接起來, 按維度0連接, 并指定數(shù)據(jù)的類型
x = torch.cat((x0, x1), 0).type(torch.FloatTensor)
# y0, y1連接, 由于只有一維, 因此沒有指定維度, torch中標(biāo)簽類型必須為LongTensor
y = torch.cat((y0, y1), ).type(torch.LongTensor)
# x,y 轉(zhuǎn)為變量, torch只支持變量的訓(xùn)練, 因為Variable中有g(shù)rad
x, y = Variable(x), Variable(y)
# 繪制數(shù)據(jù)散點圖
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c = y.data.numpy(), s = 100, lw = 0, cmap = 'RdYlGn')
plt.show()
png
# 定義分類網(wǎng)絡(luò)
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.prediction = torch.nn.Linear(n_hidden, n_output)
def forward(self, x)
x = F.relu(self.hidden(x))
x = self.prediction(x)
return x
# 定義網(wǎng)絡(luò)
net = Net(n_feature = 2, n_hidden = 10, n_output = 2)
print(net)
Net (
(hidden): Linear (2 -> 10)
(prediction): Linear (10 -> 2)
)
# 定義優(yōu)化方法
optimizer = torch.optim.SGD(net.parameters(), lr = 0.02)
# 定義損失函數(shù)
loss_func = torch.nn.CrossEntropyLoss()
plt.ion()
# 訓(xùn)練過程
for i in xrange(100):
prediction = net(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i % 2 == 0:
plt.cla()
# 獲取概率最大的類別的索引
prediction = torch.max(F.softmax(prediction), 1)[1]
# 將輸出結(jié)果變?yōu)橐痪S
pred_y = prediction.data.numpy().squeeze()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c = pred_y, s = 100, lw = 0, cmap = 'RdYlGn')
# 計算準(zhǔn)確率
accuracy = sum(pred_y == target_y) / 200.0
plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict = {'size': 10, 'color': 'red'})
plt.pause(0.1)
plt.ioff()
plt.show()
png
# torch.max用法
a = torch.randn(4, 4)
print a
print torch.max(a, 1)
-1.8524 -1.0491 0.5382 -0.5129
0.1233 -0.1821 2.1519 -1.4547
-1.0267 0.2644 -0.8832 -0.2647
0.3944 -1.2512 -0.1158 0.5071
[torch.FloatTensor of size 4x4]
(
0.5382
2.1519
0.2644
0.5071
[torch.FloatTensor of size 4]
,
2
2
1
3
[torch.LongTensor of size 4]
)