操練代碼之LSTM

得不到股票數(shù)據(jù)慧脱,接口過(guò)期了渗柿,先記錄代碼

一瘩欺,代碼

# 使用LSTM神經(jīng)網(wǎng)絡(luò),對(duì)股票進(jìn)行預(yù)測(cè)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import numpy as np
import tushare as ts
from tqdm import tqdm

import torch.utils.data as Data
import matplotlib
import matplotlib.pyplot as plt

from copy import deepcopy as copy


# 獲取數(shù)據(jù)
class GetData:
    def __init__(self, stock_id, save_path):
        self.stock_id = stock_id
        self.save_path = save_path
        self.data = None

    def get_data(self):
        self.data = ts.get_hist_data(self.stock_id).iloc[::-1]
        self.data = self.data[['open', 'close', 'high', 'low', 'volume']]
        self.close_min = self.data['volume'].min()
        self.close_max = self.data['volume'].max()
        self.data = self.data.apply(lambda x: (x-min(x))/(max(x)-min(x)))
        self.data.to_csv(self.save_path)

    def process_data(self, n):
        if self.data is None:
            self.get_data()
        feature = [
            self.data.iloc[i:i+n].values.tolist()
            for i in range(len(self.data) - n + 2)
            if i + n < len(self.data)
        ]
        label = [
            self.data.close.values[i + n]
            for i in range(len(self.data) - n + 2)
            if i + n < len(self.data)
        ]
        train_x = feature[:500]
        test_x = feature[500:]
        train_y = label[:500]
        test_y = label[500:]

        return train_x, test_x, train_y, test_y


# 搭建LSTM模型
class Model(nn.Module):
    def __init__(self, n):
        super(Model, self).__init__()
        self.lstm_layer = nn.LSTM(input_size=n, hidden_size=256, batch_first=True)
        self.linear_layer = nn.Linear(in_features=256, out_features=1, bias=True)

    def forward(self, x):
        out1, (h_n, h_c) = self.lstm_layer(x)
        a, b, c = h_n.shape
        out2 = self.linear_layer(h_n.reshape(a*b, c))
        return out2


# 訓(xùn)練模型赞咙,計(jì)算損失等
def train_module(epoch, train_data_loader, test_data_loader):
    best_model = None
    train_loss = 0
    test_loss = 0
    best_loss = 100
    epoch_cnt = 0
    for _ in range(epoch):
        total_train_loss = 0
        total_train_num = 0
        total_test_loss = 0
        total_test_num = 0
        for x, y in tqdm(train_data_loader, desc='Epoch: {}|Train Losss:{}|TestLoss:{}'.format(
            _, train_loss, test_loss
        )):
            x_num = len(x)
            p = model(x)
            loss = loss_func(p, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
            total_train_num += x_num

        train_loss = total_train_loss / total_train_num
        for x, y, in test_data_loader:
            x_num = len(x)
            p = model(x)
            loss = loss_func(p ,y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_test_loss += loss.item()
            total_test_num += x_num

        test_loss = total_test_loss / total_test_num

        if best_loss > test_loss:
            best_loss = test_loss
            best_model = copy(model)
            epoch_cnt = 0
        else:
            epoch_cnt -= 1

        if epoch_cnt > early_stop:
            torch.save(best_model.state_dict(), './lstm_.pth')
            break


# 測(cè)試模型
def test_model(test_data_loader_):
    pred = []
    label = []
    model_ = Model(5)
    model_.load_state_dict(torch.load('./lstm_.pth'))
    model_.eval()
    total_test_loss = 0
    total_test_num = 0

    for x, y in test_data_loader_:
        x_num = len(x)
        p = model_(x)
        loss = loss_func(p, y)
        total_test_loss += loss.item()
        total_test_num += x_num
        pred.extend(p.data.squeeze(1).tolist())
        label.extend(y.tolist())
    test_loss = total_test_loss / total_test_num
    return pred, label, test_loss

# 繪制折線圖
def plot_img(data, pred):
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.figure(figsize=(12, 7))
    plt.plot(range(len(pred)), pred, color='green')
    plt.plot(range(len(data)), data, color='blue')
    for i in range(0, len(pred)-3, 5):
        price = [data[i]+pred[j]-pred[i] for j in range(i, i+3)]
        plt.plot(range(i, i+3), price, color='red')
    plt.xticks(fontproperties='Times New Roman', size='15')
    plt.yticks(fontproperties='Times New Roman', size='15')
    plt.xlabel('日期', fontsize=18)
    plt.ylabel('日期', fontsize=18)
    plt.show()


if __name__ == '__main__':
    days_num = 5
    epoch = 20
    fea = 5
    batch_size = 20
    early_stop = 5
    model = Model(fea)
    GD = GetData(stock_id='000001', save_path='./data.csv')
    x_train, x_test, y_train, y_test = GD.process_data(days_num)
    x_train = torch.tensor(x_train).float()
    x_test = torch.tensor(x_test).float()
    y_train = torch.tensor(y_train).float()
    y_test = torch.tensor(y_test).float()

    train_data = TensorDataset(x_train, y_train)
    train_data_loader = DataLoader(train_data, batch_size=batch_size)
    test_data = TensorDataset(x_test, y_test)
    test_data_loader = DataLoader(test_data, batch_size=batch_size)

    loss_func = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    train_module(epoch, train_data_loader, test_data_loader)
    p, y, test_loss = test_model(test_data_loader)
    pred = [ele * (GD.close_max - GD.close_min) + GD.close_min for ele in p]
    data = [ele * (GD.close_max - GD.close_min) + GD.close_min for ele in y]
    plot_img(data, pred)
    print('模型損失:', test_loss)
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市糟港,隨后出現(xiàn)的幾起案子人弓,更是在濱河造成了極大的恐慌,老刑警劉巖着逐,帶你破解...
    沈念sama閱讀 219,539評(píng)論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件崔赌,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡耸别,警方通過(guò)查閱死者的電腦和手機(jī)健芭,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,594評(píng)論 3 396
  • 文/潘曉璐 我一進(jìn)店門(mén),熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)秀姐,“玉大人慈迈,你說(shuō)我怎么就攤上這事∈∮校” “怎么了痒留?”我有些...
    開(kāi)封第一講書(shū)人閱讀 165,871評(píng)論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)蠢沿。 經(jīng)常有香客問(wèn)我伸头,道長(zhǎng),這世上最難降的妖魔是什么舷蟀? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 58,963評(píng)論 1 295
  • 正文 為了忘掉前任恤磷,我火速辦了婚禮面哼,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘扫步。我一直安慰自己魔策,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,984評(píng)論 6 393
  • 文/花漫 我一把揭開(kāi)白布河胎。 她就那樣靜靜地躺著闯袒,像睡著了一般。 火紅的嫁衣襯著肌膚如雪游岳。 梳的紋絲不亂的頭發(fā)上搁吓,一...
    開(kāi)封第一講書(shū)人閱讀 51,763評(píng)論 1 307
  • 那天,我揣著相機(jī)與錄音吭历,去河邊找鬼堕仔。 笑死,一個(gè)胖子當(dāng)著我的面吹牛晌区,可吹牛的內(nèi)容都是我干的摩骨。 我是一名探鬼主播,決...
    沈念sama閱讀 40,468評(píng)論 3 420
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼朗若,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼恼五!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起哭懈,我...
    開(kāi)封第一講書(shū)人閱讀 39,357評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤灾馒,失蹤者是張志新(化名)和其女友劉穎,沒(méi)想到半個(gè)月后遣总,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體睬罗,經(jīng)...
    沈念sama閱讀 45,850評(píng)論 1 317
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,002評(píng)論 3 338
  • 正文 我和宋清朗相戀三年旭斥,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了容达。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 40,144評(píng)論 1 351
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡垂券,死狀恐怖花盐,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情菇爪,我是刑警寧澤算芯,帶...
    沈念sama閱讀 35,823評(píng)論 5 346
  • 正文 年R本政府宣布,位于F島的核電站凳宙,受9級(jí)特大地震影響熙揍,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜近速,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,483評(píng)論 3 331
  • 文/蒙蒙 一诈嘿、第九天 我趴在偏房一處隱蔽的房頂上張望堪旧。 院中可真熱鬧削葱,春花似錦奖亚、人聲如沸。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 32,026評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至首繁,卻和暖如春作郭,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背弦疮。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 33,150評(píng)論 1 272
  • 我被黑心中介騙來(lái)泰國(guó)打工夹攒, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人胁塞。 一個(gè)月前我還...
    沈念sama閱讀 48,415評(píng)論 3 373
  • 正文 我出身青樓咏尝,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親啸罢。 傳聞我的和親對(duì)象是個(gè)殘疾皇子编检,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,092評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容