使用 PyTorch 進(jìn)行簡單的回歸與分類

最近正在學(xué) PyTorch,作為一個(gè)初學(xué)者僻澎,在此分享下自己所學(xué)到的一些皮毛吧捧挺。



這里不對(duì)神經(jīng)網(wǎng)絡(luò)的概念作太多的介紹,重點(diǎn)是如何使用 Pytorch 搭建一個(gè)簡單的神經(jīng)網(wǎng)絡(luò)來作一些簡單的回歸與分類冶共。

神經(jīng)網(wǎng)絡(luò)簡單介紹

這是一個(gè)包含三個(gè)層次的神經(jīng)網(wǎng)絡(luò)。紅色的是輸入層每界,紫色的是中間層(也叫隱藏層)捅僵,綠色的是輸出層。
輸入層有3個(gè)輸入單元眨层,隱藏層有4個(gè)單元命咐,輸出層有2個(gè)單元。

image.png

這里就不對(duì)神經(jīng)網(wǎng)絡(luò)作太多的介紹谐岁,網(wǎng)上有很多很好的博客醋奠,大家可以去查閱查閱榛臼。

使用 PyTorch 進(jìn)行簡單的回歸

你所需要安裝的 python 庫是 pytorch 和 matplotlib。如果你正確安裝了這兩個(gè)庫窜司,并且使用的是python3沛善,那么理論上你就可以使用它了,兩個(gè)程序參考了 github 上一位莫煩大神的教程(https://github.com/MorvanZhou
)塞祈,我對(duì)其做了一些“魔改”金刁,簡單封裝了下,使其更容易“調(diào)參”议薪。

激勵(lì)函數(shù)使用的是 relu


image.png

以上是常見的激勵(lì)函數(shù)尤蛮,具體什么場景使用怎么的激勵(lì)函數(shù)這里就不作贅述了。
程序里有詳細(xì)的注釋斯议,還是見程序吧:

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import matplotlib.animation as animation


class MineNet(torch.nn.Module):
    def __init__(self, _net):
        super(MineNet, self).__init__()
        self.n_feature = _net[0]
        self.n_hidden = _net[1]
        self.n_output = _net[2]
        """torch.nn.Linear(self, in_features, out_features, bias=True)
        in_features : 前一層網(wǎng)絡(luò)神經(jīng)元的個(gè)數(shù)
        out_features : 該網(wǎng)絡(luò)層神經(jīng)元的個(gè)數(shù)
        """
        '''隱藏層線性輸出'''
        self.hidden = torch.nn.Linear(self.n_feature, self.n_hidden)
        '''輸出層線性輸出'''
        self.predict = torch.nn.Linear(self.n_hidden, self.n_output)

    def forward(self, values):
        """
        正向傳播輸入值, 神經(jīng)網(wǎng)絡(luò)分析出輸出值
        :param values:
        :return:  輸出值
        """
        '''激勵(lì)函數(shù)(隱藏層的線性值)'''
        '''relu: x<=0 y=0;x>0 y=x'''
        values = torch.relu(self.hidden(values))
        return self.predict(values)


class Net(object):
    def __init__(self, x, y, count, lr, mine_net):
        """
        :param x: 自變量
        :param y: 因變量
        :param count: 訓(xùn)練次數(shù)
        :param lr: 學(xué)習(xí)效率
        :param mine_net: MineNet對(duì)象
        """
        self.x = Variable(x)
        self.y = Variable(y)
        self.count = count
        self.lr = lr
        self.net = mine_net
        '''net 的所有參數(shù),學(xué)習(xí)率'''
        self.optimizer = torch.optim.SGD(self.net.parameters(), lr=self.lr)
        '''預(yù)測值和真實(shí)值的誤差計(jì)算公式 (均方差)'''
        self.loss_fun = torch.nn.MSELoss()

    def train_show(self):
        """
        訓(xùn)練與可視化
        :return:
        """
        plt.ion()
        plt.show()
        for t in range(self.count):
            '''給net訓(xùn)練數(shù)據(jù), 輸出預(yù)測值'''
            prediction = self.net(self.x)
            '''計(jì)算兩者的誤差'''
            loss = self.loss_fun(prediction, self.y)
            '''清空上一步的殘余更新參數(shù)值'''
            self.optimizer.zero_grad()
            '''誤差反向傳播, 計(jì)算參數(shù)更新值'''
            loss.backward()
            '''將參數(shù)更新值施加到net的parameters上'''
            self.optimizer.step()
            '''作圖'''
            plt.cla()
            plt.scatter(self.x.data.numpy(), self.y.data.numpy())
            plt.plot(self.x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
            plt.text(0.5, 0, 'Count =%.d\nLoss=%.4f' % (t + 1, loss.data.numpy()),
                     fontdict={'size': 14, 'color': 'red'})
            plt.pause(0.1)

        plt.ioff()
        plt.show()


if __name__ == '__main__':
    '''自變量'''
    _x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
    '''因變量'''
    _y = _x.pow(2) + 0.2 * torch.rand(_x.size())
    '''學(xué)習(xí)次數(shù)'''
    c = 300
    '''學(xué)習(xí)效率'''
    _lr = 0.5
    net_list = [1, 10, 1]
    net = MineNet(net_list)
    n = Net(_x, _y, c, _lr, net)
    n.train_show()


呃产捞,由于我暫時(shí)還不太會(huì)把這個(gè)過程保存為動(dòng)圖,這里暫時(shí)把過程的結(jié)果展示下哼御,之后再放動(dòng)圖坯临。
結(jié)果如下:


image.png

注:Count 是學(xué)習(xí)次數(shù),Loss 是誤差恋昼。

使用 Pytorch 進(jìn)行簡單的分類

分類在回歸的基礎(chǔ)上作一些修改就可以了看靠。
程序中有詳細(xì)的注釋:

import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F


class MineNet(torch.nn.Module):
    def __init__(self, _net):
        super(MineNet, self).__init__()
        self.n_feature = _net[0]
        self.n_hidden = _net[1]
        self.n_output = _net[2]
        """torch.nn.Linear(self, in_features, out_features, bias=True)
        in_features : 前一層網(wǎng)絡(luò)神經(jīng)元的個(gè)數(shù)
        out_features : 該網(wǎng)絡(luò)層神經(jīng)元的個(gè)數(shù)
        """
        '''隱藏層線性輸出'''
        self.hidden = torch.nn.Linear(self.n_feature, self.n_hidden)
        '''輸出層線性輸出'''
        self.output = torch.nn.Linear(self.n_hidden, self.n_output)

    def forward(self, values):
        """
        正向傳播輸入值, 神經(jīng)網(wǎng)絡(luò)分析出輸出值
        :param values:
        :return:  輸出值
        """
        '''激勵(lì)函數(shù)(隱藏層的線性值)'''
        '''relu: x<=0 y=0;x>0 y=x'''
        values = F.relu(self.hidden(values))
        return self.output(values)


class Net(object):
    def __init__(self, x, y, count, lr, mine_net):
        """
        :param x: 自變量
        :param y: 因變量
        :param count: 訓(xùn)練次數(shù)
        :param lr: 學(xué)習(xí)效率
        :param mine_net: MineNet對(duì)象
        """
        self.x = torch.cat((x[0], x[1]), 0).type(torch.FloatTensor)  # FloatTensor = 32-bit floating
        self.y = torch.cat((y[0], y[1]), ).type(torch.LongTensor)  # LongTensor = 64-bit integer
        self.count = count
        self.lr = lr
        self.net = mine_net
        '''net 的所有參數(shù),學(xué)習(xí)率'''
        self.optimizer = torch.optim.SGD(self.net.parameters(), lr=self.lr)
        '''預(yù)測值和真實(shí)值的誤差計(jì)算公式 (均方差)'''
        self.loss_fun = torch.nn.CrossEntropyLoss()

    def train_show(self):
        """
        訓(xùn)練與可視化
        :return:
        """
        plt.ion()
        plt.show()
        for t in range(self.count):
            '''給net訓(xùn)練數(shù)據(jù), 輸出預(yù)測值'''
            out = self.net(self.x)
            '''計(jì)算兩者的誤差'''
            loss = self.loss_fun(out, self.y)
            '''清空上一步的殘余更新參數(shù)值'''
            self.optimizer.zero_grad()
            '''誤差反向傳播, 計(jì)算參數(shù)更新值'''
            loss.backward()
            '''將參數(shù)更新值施加到net的parameters上'''
            self.optimizer.step()
            '''作圖'''
            plt.cla()
            '''經(jīng)過 softmax 的激勵(lì)函數(shù)后的最大概率才是預(yù)測值'''
            prediction = torch.max(F.softmax(out, dim=1), 1)[1]
            predict_y = prediction.data.numpy().squeeze()
            target_y = self.y.data.numpy()
            plt.scatter(self.x.data.numpy()[:, 0], self.x.data.numpy()[:, 1], c=predict_y, s=100, lw=0, cmap='RdYlGn')
            ''''預(yù)測中有多少和真實(shí)值一樣'''
            accuracy = sum(predict_y == target_y) / 200.
            plt.text(1.5, -4, 'Count =%.d\nAccuracy=%.2f' % (t + 1, accuracy), fontdict={'size': 14, 'color': 'red'})
            plt.pause(0.1)

        plt.ioff()
        plt.show()


if __name__ == '__main__':
    n = torch.ones(100, 2)
    '''自變量'''
    _x = [torch.normal(2 * n, 1), torch.normal(-2 * n, 1)]
    '''因變量'''
    _y = [torch.zeros(100), torch.ones(100)]
    '''學(xué)習(xí)次數(shù)'''
    c = 60
    '''學(xué)習(xí)效率'''
    _lr = 0.02
    net_list = [2, 10, 2]
    net = MineNet(net_list)
    n = Net(_x, _y, c, _lr, net)
    n.train_show()



結(jié)果如下:


image.png

注:Count 是學(xué)習(xí)次數(shù),Accuracy 是準(zhǔn)確率液肌。

可以在 if name == 'main': 中調(diào)節(jié)部分參數(shù)挟炬。
就本人的使用來看,鑒于隱藏層和輸出層用的是 torch.nn.Linear嗦哆,如果在作回歸時(shí)谤祖,x 與 y 是多項(xiàng)式關(guān)系,擬合效果還不錯(cuò)吝秕,但是對(duì)于其他關(guān)系,效果可能就不太好了空幻,當(dāng)然你也可以調(diào)節(jié)其他參數(shù)烁峭,比如學(xué)習(xí)次數(shù),學(xué)習(xí)效率等等秕铛,得到的結(jié)果每次都會(huì)有差別约郁。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市但两,隨后出現(xiàn)的幾起案子鬓梅,更是在濱河造成了極大的恐慌,老刑警劉巖谨湘,帶你破解...
    沈念sama閱讀 218,386評(píng)論 6 506
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件绽快,死亡現(xiàn)場離奇詭異芥丧,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)坊罢,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,142評(píng)論 3 394
  • 文/潘曉璐 我一進(jìn)店門续担,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人活孩,你說我怎么就攤上這事物遇。” “怎么了憾儒?”我有些...
    開封第一講書人閱讀 164,704評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵询兴,是天一觀的道長。 經(jīng)常有香客問我起趾,道長诗舰,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,702評(píng)論 1 294
  • 正文 為了忘掉前任阳掐,我火速辦了婚禮始衅,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘缭保。我一直安慰自己汛闸,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,716評(píng)論 6 392
  • 文/花漫 我一把揭開白布艺骂。 她就那樣靜靜地躺著诸老,像睡著了一般。 火紅的嫁衣襯著肌膚如雪钳恕。 梳的紋絲不亂的頭發(fā)上别伏,一...
    開封第一講書人閱讀 51,573評(píng)論 1 305
  • 那天,我揣著相機(jī)與錄音忧额,去河邊找鬼厘肮。 笑死,一個(gè)胖子當(dāng)著我的面吹牛睦番,可吹牛的內(nèi)容都是我干的类茂。 我是一名探鬼主播,決...
    沈念sama閱讀 40,314評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼托嚣,長吁一口氣:“原來是場噩夢(mèng)啊……” “哼巩检!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起示启,我...
    開封第一講書人閱讀 39,230評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤兢哭,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后夫嗓,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體迟螺,經(jīng)...
    沈念sama閱讀 45,680評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡冲秽,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,873評(píng)論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了煮仇。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片劳跃。...
    茶點(diǎn)故事閱讀 39,991評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖浙垫,靈堂內(nèi)的尸體忽然破棺而出刨仑,到底是詐尸還是另有隱情,我是刑警寧澤夹姥,帶...
    沈念sama閱讀 35,706評(píng)論 5 346
  • 正文 年R本政府宣布杉武,位于F島的核電站,受9級(jí)特大地震影響辙售,放射性物質(zhì)發(fā)生泄漏轻抱。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,329評(píng)論 3 330
  • 文/蒙蒙 一旦部、第九天 我趴在偏房一處隱蔽的房頂上張望祈搜。 院中可真熱鬧,春花似錦士八、人聲如沸容燕。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,910評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽蘸秘。三九已至,卻和暖如春蝗茁,著一層夾襖步出監(jiān)牢的瞬間醋虏,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,038評(píng)論 1 270
  • 我被黑心中介騙來泰國打工哮翘, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留颈嚼,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 48,158評(píng)論 3 370
  • 正文 我出身青樓饭寺,卻偏偏與公主長得像阻课,于是被迫代替她去往敵國和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子佩研,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,941評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容