Datawhale AI 夏令營(yíng) AI for Science 催化反應(yīng)產(chǎn)率預(yù)測(cè) Task2 Baseline 學(xué)習(xí)

簡(jiǎn)介

這一次課程使用RNN進(jìn)行催化反應(yīng)產(chǎn)率預(yù)測(cè)购岗。

總體想法是把數(shù)據(jù)集中的反應(yīng)物和產(chǎn)物通過(guò)SMILES字符串表示出來(lái)铣缠,然后根據(jù)基本的原子揭糕、連接鍵等將化學(xué)反應(yīng)對(duì)應(yīng)的SMILES字符串轉(zhuǎn)化為整數(shù)序列柱蟀,再通過(guò)RNN進(jìn)行訓(xùn)練和預(yù)測(cè)击费。

下面的代碼是對(duì)DataWhale提供的baseline的學(xué)習(xí)卿拴。

過(guò)程

1. 導(dǎo)入必要的庫(kù)

import re
import time
import pandas as pd
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

re模塊是為了從化學(xué)反應(yīng)的SMILES字符串中提取出反應(yīng)的基本組成元素。

RNN主要使用pytorch中的RNN畜眨。

2. 定義RNN模型

# 定義RNN模型
class RNNModel(nn.Module):
    def __init__(self, num_embed, input_size, hidden_size, output_size, num_layers, dropout, device):
        super(RNNModel, self).__init__()
        self.embed = nn.Embedding(num_embed, input_size)
        self.rnn = nn.RNN(input_size, hidden_size, num_layers=num_layers, 
                          batch_first=True, dropout=dropout, bidirectional=True)
        self.fc = nn.Sequential(nn.Linear(2 * num_layers * hidden_size, output_size),
                                nn.Sigmoid(),
                                nn.Linear(output_size, 1),
                                nn.Sigmoid())

    def forward(self, x):
        # x : [bs, seq_len]
        x = self.embed(x)
        # x : [bs, seq_len, input_size]
        _, hn = self.rnn(x) # hn : [2*num_layers, bs, h_dim]
        hn = hn.transpose(0,1)
        z = hn.reshape(hn.shape[0], -1) # z shape: [bs, 2*num_layers*h_dim]
        output = self.fc(z).squeeze(-1) # output shape: [bs, 1]
        return output
  1. 這個(gè)RNNModel包括了三層:

    • 嵌入層:nn.Embedding(num_embeddings, embedding_dim)鸣驱,將整數(shù)索引的序列轉(zhuǎn)換為稠密向量表示:

      • num_embeddings:詞匯表的大小泛鸟,由于vocab_full.txt有294行,這里填寫294
      • embedding_dim:每個(gè)嵌入向量的維度大小踊东,也就是輸出向量(即RNN層輸入向量)的大小北滥,這里是input_size
      • 輸入nn.Embedding(num_embeddings, embedding_dim)(x):這里面x的shape是(batch_size, seq_len),seq_len是每個(gè)樣本的元素?cái)?shù)量
      • 輸出:即RNN層的輸入闸翅,shape是(batch_size, seq_len, input_size)
    • RNN層:nn.rnn(input_size, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout, bidirectional=True):

      • input_size:輸入張量的特征數(shù)再芋,即輸入RNN的單元的向量的大小。輸入300是因?yàn)?/li>
      • hidden_size:RNN 單元的隱藏狀態(tài)向量的大小坚冀,也是輸出向量的大小
      • num_layers:RNN 層的數(shù)量
      • dropout:浮點(diǎn)數(shù)济赎,表示在訓(xùn)練過(guò)程中 RNN 層之間的輸出被丟棄的概率,可防止過(guò)擬合
      • bidirectional:布爾值遗菠,指示是否使用雙向 RNN联喘。若為True,則是雙向 RNN
    • 全連接層:

      • nn.Linear辙纬,包括兩個(gè)線性層和兩個(gè) Sigmoid 激活函數(shù),將RNN層的輸出轉(zhuǎn)化為一個(gè)預(yù)測(cè)結(jié)果
  2. 前向傳播self.forward(x)方法:

    輸入的數(shù)據(jù)x是DataLoader劃分后得數(shù)據(jù)叭喜。x的shape是(batch_size, seq_len)贺拣。

    這個(gè) forward方法首先通過(guò)嵌入層將輸入 x 轉(zhuǎn)換為嵌入表示,然后將其饋送到 RNN 層中。之后譬涡,它處理 RNN 的最后一個(gè)隱藏狀態(tài) (hn)闪幽,并將其通過(guò)全連接層產(chǎn)生最終的輸出。
    由于hn的shape是(2*num_layers, batch_size, hidden_size)涡匀,通過(guò)hn.transpose(0,1)將其轉(zhuǎn)換為( batch_size, 2*num_layers, hidden_size)盯腌,以便作為全連接層的輸入。

    z = hn.reshape(hn.shape[0], -1)陨瘩,將hn的shape變?yōu)?bs, 2*num_layers*h_dim)腕够。

    output = self.fc(z).squeeze(-1) ,將張量z經(jīng)過(guò)全連接層處理舌劳,最終得到模型的輸出帚湘。

    這幾步算是前向傳播的流水線操作,必須記住這些步驟甚淡。

3. 數(shù)據(jù)預(yù)處理

# tokenizer大诸,鑒于SMILES的特性,這里需要自己定義tokenizer和vocab
# 這里直接將smiles str按字符拆分贯卦,并替換為詞匯表中的序號(hào)
class Smiles_tokenizer():
    def __init__(self, pad_token, regex, vocab_file, max_length):
        self.pad_token = pad_token
        self.regex = regex
        self.vocab_file = vocab_file
        self.max_length = max_length

        with open(self.vocab_file, "r") as f:
            lines = f.readlines()
        lines = [line.strip("\n") for line in lines]
        vocab_dic = {}
        for index, token in enumerate(lines):
            vocab_dic[token] = index
        self.vocab_dic = vocab_dic

    def _regex_match(self, smiles):
        regex_string = r"(" + self.regex + r"|"
        regex_string += r".)"
        prog = re.compile(regex_string)

        tokenised = []
        for smi in smiles:
            tokens = prog.findall(smi)
            if len(tokens) > self.max_length:
                tokens = tokens[:self.max_length]
            tokenised.append(tokens) # 返回一個(gè)所有的字符串列表
        return tokenised
    
    def tokenize(self, smiles):
        tokens = self._regex_match(smiles)
        # 添加上表示開(kāi)始和結(jié)束的token:<cls>, <end>
        tokens = [["<CLS>"] + token + ["<SEP>"] for token in tokens]
        tokens = self._pad_seqs(tokens, self.pad_token)
        token_idx = self._pad_token_to_idx(tokens)
        return tokens, token_idx

    def _pad_seqs(self, seqs, pad_token):
        pad_length = max([len(seq) for seq in seqs])
        padded = [seq + ([pad_token] * (pad_length - len(seq))) for seq in seqs]
        return padded

    def _pad_token_to_idx(self, tokens):
        idx_list = []
        for token in tokens:
            tokens_idx = []
            for i in token:
                if i in self.vocab_dic.keys():
                    tokens_idx.append(self.vocab_dic[i])
                else:
                    self.vocab_dic[i] = max(self.vocab_dic.values()) + 1
                    tokens_idx.append(self.vocab_dic[i])
            idx_list.append(tokens_idx)
        
        return idx_list

# 讀數(shù)據(jù)并處理
def read_data(file_path, train=True):
    df = pd.read_csv(file_path)
    reactant1 = df["Reactant1"].tolist()
    reactant2 = df["Reactant2"].tolist()
    product = df["Product"].tolist()
    additive = df["Additive"].tolist()
    solvent = df["Solvent"].tolist()
    if train:
        react_yield = df["Yield"].tolist()
    else:
        react_yield = [0 for i in range(len(reactant1))]
    
    # 將reactant拼到一起资柔,之間用.分開(kāi)。product也拼到一起撵割,用>分開(kāi)
    input_data_list = []
    for react1, react2, prod, addi, sol in zip(reactant1, reactant2, product, additive, solvent):
        input_info = ".".join([react1, react2])
        input_info = ">".join([input_info, prod])
        input_data_list.append(input_info)
    output = [(react, y) for react, y in zip(input_data_list, react_yield)]

    return output

class ReactionDataset(Dataset):
    def __init__(self, data: List[Tuple[List[str], float]]):
        self.data = data
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
def collate_fn(batch):
    REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
    tokenizer = Smiles_tokenizer("<PAD>", REGEX, "../vocab_full.txt", max_length=300)
    smi_list = []
    yield_list = []
    for i in batch:
        smi_list.append(i[0])
        yield_list.append(i[1])
    tokenizer_batch = torch.tensor(tokenizer.tokenize(smi_list)[1])
    yield_list = torch.tensor(yield_list)
    return tokenizer_batch, yield_list

這里面有兩個(gè)類贿堰,一個(gè)是對(duì)數(shù)據(jù)集進(jìn)行處理的,一個(gè)是自定義的反應(yīng)數(shù)據(jù)集睁枕。詳細(xì)看一下Smiles_tokenizer類的主要作用:

3.1 從read_data函數(shù)開(kāi)始
def read_data(file_path, train=True):
    df = pd.read_csv(file_path)
    reactant1 = df["Reactant1"].tolist()
    reactant2 = df["Reactant2"].tolist()
    product = df["Product"].tolist()
    additive = df["Additive"].tolist()
    solvent = df["Solvent"].tolist()
    if train:
        react_yield = df["Yield"].tolist()
    else:
        react_yield = [0 for i in range(len(reactant1))]
    input_data_list = []

這一部分代碼是講數(shù)據(jù)集中的各列提取出來(lái)并轉(zhuǎn)化為列表官边,其中反應(yīng)物、產(chǎn)物外遇、催化劑和溶劑均為SMILES字符串組成的列表注簿,對(duì)于訓(xùn)練集,react_yield是各反應(yīng)對(duì)應(yīng)的產(chǎn)率組成的列表跳仿,對(duì)于測(cè)試集诡渴,react_yield都是0構(gòu)成的列表。經(jīng)過(guò)這一步菲语,我們可以得到下面各個(gè)列表:

reactant1:['c1ccc2c(c1)Nc1ccccc1O2','c1ccc2c(c1)Nc1ccccc1O2', ...]
reactant2:['Brc1ccccc1I', 'Brc1ccccc1I', ...]
product:['Brc1ccccc1N1c2ccccc2Oc2ccccc21','Brc1ccccc1N1c2ccccc2Oc2ccccc21', ...]
additive:['CC(C)(C)[O-].CC(C)(C)[PH+](C(C)(C)C)C(C)(C)C.F[B-](F)(F)F.F[B-](F)(F)F.O=C(C=Cc1ccccc1)C=Cc1ccccc1.O=C(C=Cc1ccccc1)C=Cc1ccccc1.[H+].[Na+].[Pd]','C1COCCOCCOCCOCCOCCO1.O=C([O-])[O-].[Cu+].[I-].[K+].[K+]',]
solvent:['Cc1ccccc1', 'Clc1ccccc1Cl', ...]
react_yield:[0.78, 0.9, ...]
def read_data(file_path, train=True):
    ...
    input_data_list = []
    for react1, react2, prod, addi, sol in zip(reactant1, reactant2, product, additive, solvent):
        input_info = ".".join([react1, react2])
        input_info = ">".join([input_info, prod])
        input_data_list.append(input_info)
        output = [(react, y) for react, y in zip(input_data_list, react_yield)]
    return output

這一部分是將上面得到的各列表轉(zhuǎn)化為“反應(yīng)物1.反應(yīng)物2>產(chǎn)物的SMILES字符串”妄辩,然后輸出“(反應(yīng)物1.反應(yīng)物2>產(chǎn)物的SMILES字符串, 產(chǎn)率)”組成的列表:

[('c1ccc2c(c1)Nc1ccccc1O2.Brc1ccccc1I>Brc1ccccc1N1c2ccccc2Oc2ccccc21', 0.78),
 ('c1ccc2c(c1)Nc1ccccc1O2.Brc1ccccc1I>Brc1ccccc1N1c2ccccc2Oc2ccccc21', 0.9),
...
]
3.2 Smiles_tokenizer類

首先從構(gòu)造函數(shù)開(kāi)始:

class Smiles_tokenizer():
    def __init__(self, pad_token, regex, vocab_file, max_length):
        self.pad_token = pad_token
        self.regex = regex
        self.vocab_file = vocab_file
        self.max_length = max_length

        with open(self.vocab_file, "r") as f:
            lines = f.readlines()
        lines = [line.strip("\n") for line in lines]
        vocab_dic = {}
        for index, token in enumerate(lines):
            vocab_dic[token] = index
        self.vocab_dic = vocab_dic

這個(gè)構(gòu)造函數(shù)主要形成一個(gè)由各基本原子、基團(tuán)以及化學(xué)鍵等為鍵組成的字典山上。其中每個(gè)項(xiàng)都由“基礎(chǔ) : 整數(shù)序號(hào)”組成:

vocab_dic:
{'<PAD>': 0,
 '<CLS>': 1,
 '<MASK>': 2,
 '<SEP>': 3,
 '[UNK]': 4,
 '>': 5,
 'C': 6,
...
}

然后是_regex_match方法:

class Smiles_tokenizer():
    ...
    def _regex_match(self, smiles):
        regex_string = r"(" + self.regex + r"|"
        regex_string += r".)"
        prog = re.compile(regex_string)

        tokenised = []
        for smi in smiles:
            tokens = prog.findall(smi)
            if len(tokens) > self.max_length:
                tokens = tokens[:self.max_length]
            tokenised.append(tokens) # 返回一個(gè)所有的字符串列表
        return tokenised

上面的正則表達(dá)式是:

REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
self.regex = REGEX

這個(gè)正則表達(dá)式匹配的內(nèi)容包括:方括號(hào)里面的內(nèi)容(例如:[NH4+])眼耀、N、P佩憾、Br哮伟、.干花、@、=等符號(hào)楞黄。

_regex_match傳入的smiles參數(shù)池凄,是一個(gè)SMILES字符串表示的反應(yīng)方程式構(gòu)成的列表。通過(guò)for循環(huán)鬼廓,不斷用這個(gè)正則表達(dá)式將SMILES字符串表示的反應(yīng)方程式中的各個(gè)元素提起出來(lái)肿仑,作為tokens,也就是說(shuō)tokens是反應(yīng)方程式中的各個(gè)元素組成的列表(不妨稱之為:特征元素列表)碎税,而tokenised就是整個(gè)smiles對(duì)應(yīng)的特征元素列表:

# 下面是一個(gè)第四個(gè)反應(yīng)的tokens的例子尤慰,是個(gè)一維列表
tokens:['C', '1', 'C', 'O', 'C', 'C', 'N', '1', '.', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'Cl', '>', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'N', '1', 'C', 'C', 'O', 'C', 'C', '1']
# tokenised是嵌套列表
tokenised:[
    ['c', '1', 'c', 'c', 'c', '2', 'c', '(', 'c', '1', ')', 'N', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1', 'O', '2', '.', 'Br', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1', 'I', '>', 'Br', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1', 'N', '1', 'c', '2', 'c', 'c', 'c', 'c', 'c', '2', 'O', 'c', '2', 'c', 'c', 'c', 'c', 'c', '2', '1']
    ...
]

接著看_pad_seqs和_pad_token_to_idx方法:

class Smiles_tokenizer():
    ...
    def _pad_seqs(self, seqs, pad_token):
        pad_length = max([len(seq) for seq in seqs])
        padded = [seq + ([pad_token] * (pad_length - len(seq))) for seq in seqs]
        return padded

    def _pad_token_to_idx(self, tokens):
        idx_list = []
        for token in tokens:
            tokens_idx = []
            for i in token:
                if i in self.vocab_dic.keys():
                    tokens_idx.append(self.vocab_dic[i])
                else:
                    self.vocab_dic[i] = max(self.vocab_dic.values()) + 1
                    tokens_idx.append(self.vocab_dic[i])
            idx_list.append(tokens_idx)
        
        return idx_list

這兩個(gè)方法,第一個(gè)是把所有反應(yīng)方程蚣录,通過(guò)填充<PAD>(也就是0)割择,變成相同的長(zhǎng)度,也就是把每個(gè)方程對(duì)應(yīng)的tokens變成擁有相同的元素?cái)?shù)量萎河,便于后續(xù)的處理荔泳。

第二個(gè)方法是根據(jù)前面的vocab_dic,把tokenised中的所有特征元素變成整數(shù):

# 下面是tokenised在padded后的形式虐杯,其中展示的列表是填充后的第四個(gè)反應(yīng)
padded:[
    ...
  ['C', '1', 'C', 'O', 'C', 'C', 'N', '1', '.', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'Cl', '>', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'N', '1', 'C', 'C', 'O', 'C', 'C', '1', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
    ...
]
# 下面是tokenised在padded后對(duì)應(yīng)的idx_list玛歌,其中展示的列表是填充后的第四個(gè)反應(yīng)
idx_list:[
    ...
    [6, 8, 6, 19, 6, 6, 13, 8, 35, 39, 7, 8, 7, 16, 7, 9, 40, 15, 16, 7, 8, 40, 5, 39, 7, 8, 7, 16, 7, 9, 40, 15, 16, 7, 8, 13, 8, 6, 6, 19, 6, 6, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ...
]

最后一下 tokenize方法:

class Smiles_tokenizer():
    ...    
    def tokenize(self, smiles):
        tokens = self._regex_match(smiles)
        # 添加上表示開(kāi)始和結(jié)束的token:<cls>, <end>
        tokens = [["<CLS>"] + token + ["<SEP>"] for token in tokens]
        tokens = self._pad_seqs(tokens, self.pad_token)
        token_idx = self._pad_token_to_idx(tokens)
        return tokens, token_idx

這個(gè)方法是把前面的幾個(gè)方法匯總到一起,也就是最后返回的tokens就是上面padded的形式(前后添加了"<CLS>"和"<SEP>")擎椰,token_idx就是idx_list的形式:

tokens:[
    ...
  ['<CLS>', 'C', '1', 'C', 'O', 'C', 'C', 'N', '1', '.', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'Cl', '>', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'N', '1', 'C', 'C', 'O', 'C', 'C', '1', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<SEP>']
    ...
]

token_idx:[
    ...
    [1, 6, 8, 6, 19, 6, 6, 13, 8, 35, 39, 7, 8, 7, 16, 7, 9, 40, 15, 16, 7, 8, 40, 5, 39, 7, 8, 7, 16, 7, 9, 40, 15, 16, 7, 8, 13, 8, 6, 6, 19, 6, 6, 8, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ...
]

至此支子,就把文本型的數(shù)據(jù)轉(zhuǎn)化為了由整數(shù)構(gòu)成的數(shù)據(jù)了,而collate_fn函數(shù)就是通過(guò)Smiles_tokenizer類獲得的token_idx轉(zhuǎn)化成張量达舒。

def collate_fn(batch):
    REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
    tokenizer = Smiles_tokenizer("<PAD>", REGEX, "../vocab_full.txt", max_length=300)
    smi_list = []
    yield_list = []
    for i in batch:
        smi_list.append(i[0])
        yield_list.append(i[1])
    tokenizer_batch = torch.tensor(tokenizer.tokenize(smi_list)[1])
    yield_list = torch.tensor(yield_list)
    return tokenizer_batch, yield_list
tokenizer_batch:
tensor([[ 1,  7,  8,  7,  7,  7, 11,  7,  9,  7,  8, 15, 13,  7,  8,  7,  7,  7,
          7,  7,  8, 19, 11, 35, 17,  7,  8,  7,  7,  7,  7,  7,  8, 38,  5, 17,
          7,  8,  7,  7,  7,  7,  7,  8, 13,  8,  7, 11,  7,  7,  7,  7,  7, 11,
         19,  7, 11,  7,  7,  7,  7,  7, 11,  8,  3],
        [ 1,  7,  8,  7,  7,  7, 11,  7,  9,  7,  8, 15, 13,  7,  8,  7,  7,  7,
          7,  7,  8, 19, 11, 35, 17,  7,  8,  7,  7,  7,  7,  7,  8, 38,  5, 17,
          7,  8,  7,  7,  7,  7,  7,  8, 13,  8,  7, 11,  7,  7,  7,  7,  7, 11,
         19,  7, 11,  7,  7,  7,  7,  7, 11,  8,  3],
        ...
       ])
yield_list:
tensor([0.7800, 0.9000, ...])

4. 訓(xùn)練數(shù)據(jù)獲得模型

def train():
    ## super param
    N = 10  #int / int(len(dataset) * 1)  # 或者你可以設(shè)置為數(shù)據(jù)集大小的一定比例值朋,如 int(len(dataset) * 0.1)
    NUM_EMBED = 294 # nn.Embedding()
    INPUT_SIZE = 300 # src length
    HIDDEN_SIZE = 512
    OUTPUT_SIZE = 512
    NUM_LAYERS = 10
    DROPOUT = 0.2
    CLIP = 1 # CLIP value
    N_EPOCHS = 100
    LR = 0.0001
    
    start_time = time.time()  # 開(kāi)始計(jì)時(shí)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = 'cpu'
    data = read_data("../dataset/round1_train_data.csv")
    dataset = ReactionDataset(data)
    subset_indices = list(range(N))
    subset_dataset = Subset(dataset, subset_indices)
    train_loader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)

    model = RNNModel(NUM_EMBED, INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, NUM_LAYERS, DROPOUT, device).to(device)
    model.train()
    
    optimizer = optim.Adam(model.parameters(), lr=LR)
    # criterion = nn.MSELoss() # MSE
    criterion = nn.L1Loss() # MAE

    best_loss = 10
    for epoch in range(N_EPOCHS):
        epoch_loss = 0
        for i, (src, y) in enumerate(train_loader):
            src, y = src.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(src)
            loss = criterion(output, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)  # 使用范數(shù)裁剪梯度
            optimizer.step()
            epoch_loss += loss.item()
            loss_in_a_epoch = epoch_loss / len(train_loader)
        print(f'Epoch: {epoch+1:02} | Train Loss: {loss_in_a_epoch:.3f}')
        if loss_in_a_epoch < best_loss:
            # 在訓(xùn)練循環(huán)結(jié)束后保存模型
            torch.save(model.state_dict(), '../model/RNN.pth')
    end_time = time.time()  # 結(jié)束計(jì)時(shí)
    # 計(jì)算并打印運(yùn)行時(shí)間
    elapsed_time_minute = (end_time - start_time)/60
    print(f"Total running time: {elapsed_time_minute:.2f} minutes")

if __name__ == '__main__':
    train()

前面幾個(gè)大寫的字母是超參數(shù)。

通過(guò)DataLoader劃分?jǐn)?shù)據(jù)集巩搏,一共23538/128=184個(gè)batch昨登。

訓(xùn)練模型包括以下幾步:

  1. 實(shí)例化RNNModel并以train模式運(yùn)行;
  2. 選擇優(yōu)化器為 Adam
  3. 使用平均絕對(duì)誤差 (MAE) 作為損失函數(shù)
  4. 循環(huán)訓(xùn)練100次:
    • 遍歷每個(gè)批次的數(shù)據(jù)
    • 使用GPU設(shè)備進(jìn)行運(yùn)算
    • 清除梯度
    • 前向傳播計(jì)算輸出
    • 計(jì)算損失
    • 反向傳播計(jì)算梯度
    • 應(yīng)用梯度裁剪
    • 更新模型參數(shù)
    • 累加每個(gè)批次的損失
    • 計(jì)算整個(gè)周期的平均損失

5. 使用測(cè)試集驗(yàn)證模型

# 生成結(jié)果文件
def predicit_and_make_submit_file(model_file, output_file):
    NUM_EMBED = 294
    INPUT_SIZE = 300
    HIDDEN_SIZE = 512
    OUTPUT_SIZE = 512
    NUM_LAYERS = 10
    DROPOUT = 0.2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_data = read_data("../dataset/round1_test_data.csv", train=False)
    test_dataset = ReactionDataset(test_data)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn) 

    model = RNNModel(NUM_EMBED, INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, NUM_LAYERS, DROPOUT, device).to(device)
    # 加載最佳模型
    model.load_state_dict(torch.load(model_file))
    model.eval()
    output_list = []
    for i, (src, y) in enumerate(test_loader):
        src, y = src.to(device), y.to(device)
        with torch.no_grad():
            output = model(src)
            output_list += output.detach().tolist()
    ans_str_lst = ['rxnid,Yield']
    for idx,y in enumerate(output_list):
        ans_str_lst.append(f'test{idx+1},{y:.4f}')
    with open(output_file,'w') as fw:
        fw.writelines('\n'.join(ans_str_lst))

    print("done!!!")
    
predicit_and_make_submit_file("../model/RNN.pth",
                              "../output/RNN_submit.txt")

上一步已經(jīng)訓(xùn)練好了模型贯底,這一步使用這個(gè)模型對(duì)測(cè)試集進(jìn)行預(yù)測(cè)丰辣。

各超參數(shù)還是那些參數(shù),數(shù)據(jù)集換成測(cè)試集禽捆。

實(shí)例化RNNModel笙什,加載上一步保存的訓(xùn)練好的模型,然后設(shè)置模型為評(píng)估模式胚想。

遍歷測(cè)試DataLoader中的每個(gè)批次的數(shù)據(jù)琐凭,獲取輸入數(shù)據(jù)和標(biāo)簽。需要使用torch.no_grad()上下文管理器來(lái)禁用梯度計(jì)算浊服,提高預(yù)測(cè)速度淘正。

最后將預(yù)測(cè)結(jié)果寫入txt文件摆马。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末臼闻,一起剝皮案震驚了整個(gè)濱河市鸿吆,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌述呐,老刑警劉巖惩淳,帶你破解...
    沈念sama閱讀 216,496評(píng)論 6 501
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異乓搬,居然都是意外死亡思犁,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,407評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門进肯,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)激蹲,“玉大人,你說(shuō)我怎么就攤上這事江掩⊙瑁” “怎么了?”我有些...
    開(kāi)封第一講書(shū)人閱讀 162,632評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵环形,是天一觀的道長(zhǎng)策泣。 經(jīng)常有香客問(wèn)我,道長(zhǎng)抬吟,這世上最難降的妖魔是什么萨咕? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 58,180評(píng)論 1 292
  • 正文 為了忘掉前任,我火速辦了婚禮火本,結(jié)果婚禮上危队,老公的妹妹穿的比我還像新娘。我一直安慰自己钙畔,他們只是感情好茫陆,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,198評(píng)論 6 388
  • 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著刃鳄,像睡著了一般盅弛。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上叔锐,一...
    開(kāi)封第一講書(shū)人閱讀 51,165評(píng)論 1 299
  • 那天挪鹏,我揣著相機(jī)與錄音,去河邊找鬼愉烙。 笑死讨盒,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的步责。 我是一名探鬼主播返顺,決...
    沈念sama閱讀 40,052評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼禀苦,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了遂鹊?” 一聲冷哼從身側(cè)響起振乏,我...
    開(kāi)封第一講書(shū)人閱讀 38,910評(píng)論 0 274
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎秉扑,沒(méi)想到半個(gè)月后慧邮,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,324評(píng)論 1 310
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡舟陆,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,542評(píng)論 2 332
  • 正文 我和宋清朗相戀三年误澳,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片秦躯。...
    茶點(diǎn)故事閱讀 39,711評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡忆谓,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出踱承,到底是詐尸還是另有隱情倡缠,我是刑警寧澤,帶...
    沈念sama閱讀 35,424評(píng)論 5 343
  • 正文 年R本政府宣布勾扭,位于F島的核電站毡琉,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏妙色。R本人自食惡果不足惜桅滋,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,017評(píng)論 3 326
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望身辨。 院中可真熱鬧丐谋,春花似錦、人聲如沸煌珊。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 31,668評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)定庵。三九已至吏饿,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間蔬浙,已是汗流浹背猪落。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 32,823評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留畴博,地道東北人笨忌。 一個(gè)月前我還...
    沈念sama閱讀 47,722評(píng)論 2 368
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像俱病,于是被迫代替她去往敵國(guó)和親官疲。 傳聞我的和親對(duì)象是個(gè)殘疾皇子袱结,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,611評(píng)論 2 353

推薦閱讀更多精彩內(nèi)容