Introduction
This lesson uses an RNN to predict catalytic reaction yields.
The overall idea: represent the reactants and products in the dataset as SMILES strings, convert each reaction's SMILES string into a sequence of integers according to its basic atoms, bonds, and other tokens, and then train an RNN on these sequences to make predictions.
The code below is a study of the baseline provided by DataWhale.
Walkthrough
1. Import the required libraries
import re
import time
import pandas as pd
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
The re module is used to extract a reaction's basic building blocks from its SMILES string.
The RNN itself is PyTorch's built-in nn.RNN.
2. Define the RNN model
# Define the RNN model
class RNNModel(nn.Module):
    def __init__(self, num_embed, input_size, hidden_size, output_size, num_layers, dropout, device):
        super(RNNModel, self).__init__()
        self.embed = nn.Embedding(num_embed, input_size)
        self.rnn = nn.RNN(input_size, hidden_size, num_layers=num_layers,
                          batch_first=True, dropout=dropout, bidirectional=True)
        self.fc = nn.Sequential(nn.Linear(2 * num_layers * hidden_size, output_size),
                                nn.Sigmoid(),
                                nn.Linear(output_size, 1),
                                nn.Sigmoid())

    def forward(self, x):
        # x : [bs, seq_len]
        x = self.embed(x)
        # x : [bs, seq_len, input_size]
        _, hn = self.rnn(x)  # hn : [2*num_layers, bs, h_dim]
        hn = hn.transpose(0, 1)
        z = hn.reshape(hn.shape[0], -1)  # z : [bs, 2*num_layers*h_dim]
        output = self.fc(z).squeeze(-1)  # output : [bs] (fc gives [bs, 1]; squeeze removes the trailing dim)
        return output
The RNNModel consists of three parts:

Embedding layer: nn.Embedding(num_embeddings, embedding_dim) converts a sequence of integer indices into dense vector representations:
- num_embeddings: the vocabulary size; since vocab_full.txt has 294 lines, 294 is used here
- embedding_dim: the dimension of each embedding vector, i.e., the size of the output vectors (and therefore of the RNN's input vectors); here it is input_size
- input: calling nn.Embedding(num_embeddings, embedding_dim)(x), where x has shape (batch_size, seq_len) and seq_len is the number of tokens in each sample
- output: the input to the RNN layer, with shape (batch_size, seq_len, input_size)
RNN layer: nn.RNN(input_size, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout, bidirectional=True):
- input_size: the number of features of the input tensor, i.e., the size of each vector fed into the RNN cells. It is 300 here because the embedding dimension (INPUT_SIZE in train()) is set to 300, and the RNN input must match the embedding output
- hidden_size: the size of the RNN cell's hidden state vector, which is also the size of its output vector
- num_layers: the number of stacked RNN layers
- dropout: a float giving the probability that outputs between RNN layers are dropped during training, which helps prevent overfitting
- bidirectional: a boolean indicating whether to use a bidirectional RNN; if True, the RNN is bidirectional
Fully connected layers:
- nn.Linear: two linear layers interleaved with two Sigmoid activations, mapping the RNN output to a single prediction; the final Sigmoid constrains the output to (0, 1), which matches the range of the yield labels
The forward pass, self.forward(x):
The input x is a batch produced by the DataLoader, with shape (batch_size, seq_len).
The forward method first converts the input x into embeddings via the embedding layer and feeds them into the RNN layer. It then takes the RNN's final hidden state hn and passes it through the fully connected layers to produce the final output.
Since hn has shape (2*num_layers, batch_size, hidden_size), hn.transpose(0, 1) rearranges it to (batch_size, 2*num_layers, hidden_size) so the batch dimension comes first; z = hn.reshape(hn.shape[0], -1) then flattens it to (bs, 2*num_layers*h_dim), ready for the fully connected layers.
output = self.fc(z).squeeze(-1) passes z through the fully connected layers to give the model output.
These steps form the forward-pass pipeline and are worth remembering; a quick shape check follows.
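To make the shape bookkeeping concrete, here is a minimal sketch that feeds a dummy batch through the model defined above and checks the output shape (the sizes are illustrative, not the baseline's hyperparameters):

import torch

model = RNNModel(num_embed=294, input_size=300, hidden_size=64,
                 output_size=64, num_layers=2, dropout=0.2, device="cpu")
x = torch.randint(0, 294, (8, 50))  # dummy token indices: [bs=8, seq_len=50]
out = model(x)
print(out.shape)  # torch.Size([8]) -- one predicted yield per sample, each in (0, 1)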
3. Data preprocessing
# Tokenizer: given the nature of SMILES, we define our own tokenizer and vocab.
# The SMILES string is split into tokens, each replaced by its index in the vocabulary.
class Smiles_tokenizer():
    def __init__(self, pad_token, regex, vocab_file, max_length):
        self.pad_token = pad_token
        self.regex = regex
        self.vocab_file = vocab_file
        self.max_length = max_length

        with open(self.vocab_file, "r") as f:
            lines = f.readlines()
        lines = [line.strip("\n") for line in lines]
        vocab_dic = {}
        for index, token in enumerate(lines):
            vocab_dic[token] = index
        self.vocab_dic = vocab_dic

    def _regex_match(self, smiles):
        regex_string = r"(" + self.regex + r"|"
        regex_string += r".)"
        prog = re.compile(regex_string)

        tokenised = []
        for smi in smiles:
            tokens = prog.findall(smi)
            if len(tokens) > self.max_length:
                tokens = tokens[:self.max_length]
            tokenised.append(tokens)  # one token list per input string
        return tokenised

    def tokenize(self, smiles):
        tokens = self._regex_match(smiles)
        # add the start and end tokens: <CLS>, <SEP>
        tokens = [["<CLS>"] + token + ["<SEP>"] for token in tokens]
        tokens = self._pad_seqs(tokens, self.pad_token)
        token_idx = self._pad_token_to_idx(tokens)
        return tokens, token_idx

    def _pad_seqs(self, seqs, pad_token):
        pad_length = max([len(seq) for seq in seqs])
        padded = [seq + ([pad_token] * (pad_length - len(seq))) for seq in seqs]
        return padded

    def _pad_token_to_idx(self, tokens):
        idx_list = []
        for token in tokens:
            tokens_idx = []
            for i in token:
                if i in self.vocab_dic.keys():
                    tokens_idx.append(self.vocab_dic[i])
                else:
                    self.vocab_dic[i] = max(self.vocab_dic.values()) + 1
                    tokens_idx.append(self.vocab_dic[i])
            idx_list.append(tokens_idx)
        return idx_list
# Read and preprocess the data
def read_data(file_path, train=True):
    df = pd.read_csv(file_path)
    reactant1 = df["Reactant1"].tolist()
    reactant2 = df["Reactant2"].tolist()
    product = df["Product"].tolist()
    additive = df["Additive"].tolist()
    solvent = df["Solvent"].tolist()
    if train:
        react_yield = df["Yield"].tolist()
    else:
        react_yield = [0 for i in range(len(reactant1))]

    # Join the reactants with "." between them, then join the product with ">"
    input_data_list = []
    for react1, react2, prod, addi, sol in zip(reactant1, reactant2, product, additive, solvent):
        input_info = ".".join([react1, react2])
        input_info = ">".join([input_info, prod])
        input_data_list.append(input_info)
    output = [(react, y) for react, y in zip(input_data_list, react_yield)]
    return output

class ReactionDataset(Dataset):
    def __init__(self, data: List[Tuple[str, float]]):  # each item is (smiles_string, yield)
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
    tokenizer = Smiles_tokenizer("<PAD>", REGEX, "../vocab_full.txt", max_length=300)
    smi_list = []
    yield_list = []
    for i in batch:
        smi_list.append(i[0])
        yield_list.append(i[1])
    tokenizer_batch = torch.tensor(tokenizer.tokenize(smi_list)[1])
    yield_list = torch.tensor(yield_list)
    return tokenizer_batch, yield_list
This snippet contains two classes: one that processes the dataset (the tokenizer) and one that defines the custom reaction dataset. Let's walk through the pieces in detail, starting with read_data and then the main job of the Smiles_tokenizer class:
3.1 Starting from the read_data function
def read_data(file_path, train=True):
    df = pd.read_csv(file_path)
    reactant1 = df["Reactant1"].tolist()
    reactant2 = df["Reactant2"].tolist()
    product = df["Product"].tolist()
    additive = df["Additive"].tolist()
    solvent = df["Solvent"].tolist()
    if train:
        react_yield = df["Yield"].tolist()
    else:
        react_yield = [0 for i in range(len(reactant1))]

    input_data_list = []
This part extracts each column of the dataset into a list. The reactants, products, additives (catalysts), and solvents are lists of SMILES strings; for the training set, react_yield is the list of yields of the reactions, while for the test set it is a list of zeros. After this step we have the following lists:
reactant1: ['c1ccc2c(c1)Nc1ccccc1O2', 'c1ccc2c(c1)Nc1ccccc1O2', ...]
reactant2: ['Brc1ccccc1I', 'Brc1ccccc1I', ...]
product: ['Brc1ccccc1N1c2ccccc2Oc2ccccc21', 'Brc1ccccc1N1c2ccccc2Oc2ccccc21', ...]
additive: ['CC(C)(C)[O-].CC(C)(C)[PH+](C(C)(C)C)C(C)(C)C.F[B-](F)(F)F.F[B-](F)(F)F.O=C(C=Cc1ccccc1)C=Cc1ccccc1.O=C(C=Cc1ccccc1)C=Cc1ccccc1.[H+].[Na+].[Pd]', 'C1COCCOCCOCCOCCOCCO1.O=C([O-])[O-].[Cu+].[I-].[K+].[K+]', ...]
solvent: ['Cc1ccccc1', 'Clc1ccccc1Cl', ...]
react_yield: [0.78, 0.9, ...]
def read_data(file_path, train=True):
    ...
    input_data_list = []
    for react1, react2, prod, addi, sol in zip(reactant1, reactant2, product, additive, solvent):
        input_info = ".".join([react1, react2])
        input_info = ">".join([input_info, prod])
        input_data_list.append(input_info)
    output = [(react, y) for react, y in zip(input_data_list, react_yield)]
    return output
This part joins the lists above into strings of the form "reactant1.reactant2>product", and returns a list of tuples "(reactant1.reactant2>product SMILES, yield)". (Note that addi and sol are unpacked in the loop but never included in the input string; the baseline only encodes reactants and product. A short sketch of the join logic follows the example.)
[('c1ccc2c(c1)Nc1ccccc1O2.Brc1ccccc1I>Brc1ccccc1N1c2ccccc2Oc2ccccc21', 0.78),
 ('c1ccc2c(c1)Nc1ccccc1O2.Brc1ccccc1I>Brc1ccccc1N1c2ccccc2Oc2ccccc21', 0.9),
 ...
]
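The joining logic is easy to replay in isolation; a minimal sketch using the first reaction shown above:

react1 = "c1ccc2c(c1)Nc1ccccc1O2"
react2 = "Brc1ccccc1I"
prod = "Brc1ccccc1N1c2ccccc2Oc2ccccc21"

input_info = ".".join([react1, react2])    # "reactant1.reactant2"
input_info = ">".join([input_info, prod])  # "reactant1.reactant2>product"
print(input_info)
# c1ccc2c(c1)Nc1ccccc1O2.Brc1ccccc1I>Brc1ccccc1N1c2ccccc2Oc2ccccc21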
3.2 The Smiles_tokenizer class
Start with the constructor:
class Smiles_tokenizer():
    def __init__(self, pad_token, regex, vocab_file, max_length):
        self.pad_token = pad_token
        self.regex = regex
        self.vocab_file = vocab_file
        self.max_length = max_length

        with open(self.vocab_file, "r") as f:
            lines = f.readlines()
        lines = [line.strip("\n") for line in lines]
        vocab_dic = {}
        for index, token in enumerate(lines):
            vocab_dic[token] = index
        self.vocab_dic = vocab_dic
The constructor builds a dictionary whose keys are the basic atoms, groups, bond symbols, and special tokens, with each entry of the form "token : integer index":
vocab_dic:
{'<PAD>': 0,
'<CLS>': 1,
'<MASK>': 2,
'<SEP>': 3,
'[UNK]': 4,
'>': 5,
'C': 6,
...
}
Next, the _regex_match method:
class Smiles_tokenizer():
    ...
    def _regex_match(self, smiles):
        regex_string = r"(" + self.regex + r"|"
        regex_string += r".)"
        prog = re.compile(regex_string)

        tokenised = []
        for smi in smiles:
            tokens = prog.findall(smi)
            if len(tokens) > self.max_length:
                tokens = tokens[:self.max_length]
            tokenised.append(tokens)  # one token list per input string
        return tokenised
The regular expression in question is:
REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
self.regex = REGEX
It matches, among other things: anything inside square brackets (e.g. [NH4+]), the two-letter halogens Br and Cl, single-letter atoms such as N and P, and symbols such as ., @, =, and >.
The smiles parameter passed to _regex_match is a list of reaction equations written as SMILES strings. The for loop applies the regex to each SMILES string to extract its individual elements as tokens. That is, tokens is the list of elements making up one reaction equation (call it a token list), and tokenised is the collection of token lists for the whole smiles input:
# The tokens of the fourth reaction, as an example: a flat list
tokens: ['C', '1', 'C', 'O', 'C', 'C', 'N', '1', '.', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'Cl', '>', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'N', '1', 'C', 'C', 'O', 'C', 'C', '1']
# tokenised is a nested list
tokenised: [
['c', '1', 'c', 'c', 'c', '2', 'c', '(', 'c', '1', ')', 'N', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1', 'O', '2', '.', 'Br', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1', 'I', '>', 'Br', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1', 'N', '1', 'c', '2', 'c', 'c', 'c', 'c', 'c', '2', 'O', 'c', '2', 'c', 'c', 'c', 'c', 'c', '2', '1']
...
]
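As a quick sanity check, the regex can be exercised on its own; a minimal sketch on one of the reactant SMILES above:

import re

REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
prog = re.compile(r"(" + REGEX + r"|.)")  # same construction as in _regex_match

print(prog.findall("Brc1ccccc1I"))
# ['Br', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1', 'I']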
Next, look at the _pad_seqs and _pad_token_to_idx methods:
class Smiles_tokenizer():
    ...
    def _pad_seqs(self, seqs, pad_token):
        pad_length = max([len(seq) for seq in seqs])
        padded = [seq + ([pad_token] * (pad_length - len(seq))) for seq in seqs]
        return padded

    def _pad_token_to_idx(self, tokens):
        idx_list = []
        for token in tokens:
            tokens_idx = []
            for i in token:
                if i in self.vocab_dic.keys():
                    tokens_idx.append(self.vocab_dic[i])
                else:
                    self.vocab_dic[i] = max(self.vocab_dic.values()) + 1
                    tokens_idx.append(self.vocab_dic[i])
            idx_list.append(tokens_idx)
        return idx_list
Of these two methods, the first pads every reaction equation with <PAD> tokens (index 0) so that all token lists have the same length, which makes the downstream processing uniform.
The second uses the vocab_dic built earlier to convert every token in tokenised into an integer (a token not yet in the vocabulary is assigned a new index on the fly):
# tokenised after padding; the list shown is the padded fourth reaction
padded: [
...
['C', '1', 'C', 'O', 'C', 'C', 'N', '1', '.', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'Cl', '>', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'N', '1', 'C', 'C', 'O', 'C', 'C', '1', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
...
]
# the idx_list corresponding to the padded tokenised; the list shown is the padded fourth reaction
idx_list: [
...
[6, 8, 6, 19, 6, 6, 13, 8, 35, 39, 7, 8, 7, 16, 7, 9, 40, 15, 16, 7, 8, 40, 5, 39, 7, 8, 7, 16, 7, 9, 40, 15, 16, 7, 8, 13, 8, 6, 6, 19, 6, 6, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
...
]
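The padding step in isolation; a minimal sketch with two toy token lists:

seqs = [["C", "C", "O"], ["C", "1", "C", "C", "C", "1"]]
pad_length = max([len(seq) for seq in seqs])  # 6, the longest sequence in the batch
padded = [seq + (["<PAD>"] * (pad_length - len(seq))) for seq in seqs]
print(padded)
# [['C', 'C', 'O', '<PAD>', '<PAD>', '<PAD>'], ['C', '1', 'C', 'C', 'C', '1']]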
Finally, look at the tokenize method:
class Smiles_tokenizer():
    ...
    def tokenize(self, smiles):
        tokens = self._regex_match(smiles)
        # add the start and end tokens: <CLS>, <SEP>
        tokens = [["<CLS>"] + token + ["<SEP>"] for token in tokens]
        tokens = self._pad_seqs(tokens, self.pad_token)
        token_idx = self._pad_token_to_idx(tokens)
        return tokens, token_idx
This method ties the previous methods together: the returned tokens is the padded form shown above (with "<CLS>" and "<SEP>" added at the start and end), and token_idx has the form of idx_list:
tokens:[
...
['<CLS>', 'C', '1', 'C', 'O', 'C', 'C', 'N', '1', '.', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'Cl', '>', 'F', 'c', '1', 'c', 'n', 'c', '(', 'Cl', ')', 'n', 'c', '1', 'N', '1', 'C', 'C', 'O', 'C', 'C', '1', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<SEP>']
...
]
token_idx:[
...
[1, 6, 8, 6, 19, 6, 6, 13, 8, 35, 39, 7, 8, 7, 16, 7, 9, 40, 15, 16, 7, 8, 40, 5, 39, 7, 8, 7, 16, 7, 9, 40, 15, 16, 7, 8, 13, 8, 6, 6, 19, 6, 6, 8, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
...
]
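Putting it together, an end-to-end call looks like the following minimal sketch (it assumes vocab_full.txt is present at the baseline's relative path):

REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
tokenizer = Smiles_tokenizer("<PAD>", REGEX, "../vocab_full.txt", max_length=300)

smiles = ["c1ccc2c(c1)Nc1ccccc1O2.Brc1ccccc1I>Brc1ccccc1N1c2ccccc2Oc2ccccc21"]
tokens, token_idx = tokenizer.tokenize(smiles)
print(tokens[0][:5])     # ['<CLS>', 'c', '1', 'c', 'c']
print(token_idx[0][:5])  # the corresponding integer indices from vocab_dic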
At this point the text data has been converted into data made of integers; the collate_fn function then turns the token_idx obtained via the Smiles_tokenizer class into a tensor.
def collate_fn(batch):
    REGEX = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
    tokenizer = Smiles_tokenizer("<PAD>", REGEX, "../vocab_full.txt", max_length=300)
    smi_list = []
    yield_list = []
    for i in batch:
        smi_list.append(i[0])
        yield_list.append(i[1])
    tokenizer_batch = torch.tensor(tokenizer.tokenize(smi_list)[1])
    yield_list = torch.tensor(yield_list)
    return tokenizer_batch, yield_list
tokenizer_batch:
tensor([[ 1, 7, 8, 7, 7, 7, 11, 7, 9, 7, 8, 15, 13, 7, 8, 7, 7, 7,
7, 7, 8, 19, 11, 35, 17, 7, 8, 7, 7, 7, 7, 7, 8, 38, 5, 17,
7, 8, 7, 7, 7, 7, 7, 8, 13, 8, 7, 11, 7, 7, 7, 7, 7, 11,
19, 7, 11, 7, 7, 7, 7, 7, 11, 8, 3],
[ 1, 7, 8, 7, 7, 7, 11, 7, 9, 7, 8, 15, 13, 7, 8, 7, 7, 7,
7, 7, 8, 19, 11, 35, 17, 7, 8, 7, 7, 7, 7, 7, 8, 38, 5, 17,
7, 8, 7, 7, 7, 7, 7, 8, 13, 8, 7, 11, 7, 7, 7, 7, 7, 11,
19, 7, 11, 7, 7, 7, 7, 7, 11, 8, 3],
...
])
yield_list:
tensor([0.7800, 0.9000, ...])
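These tensors fall out of the DataLoader once collate_fn is plugged in; a minimal sketch, assuming the baseline's CSV path and the definitions above:

data = read_data("../dataset/round1_train_data.csv")
dataset = ReactionDataset(data)
loader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)

src, y = next(iter(loader))
print(src.shape)  # [128, padded_seq_len] -- the padded length varies per batch
print(y.shape)    # [128]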
4. Training the model on the data
def train():
    ## hyperparameters
    N = 10  # int / int(len(dataset) * 1); alternatively a fraction of the dataset size, e.g. int(len(dataset) * 0.1)
    NUM_EMBED = 294   # nn.Embedding()
    INPUT_SIZE = 300  # src length
    HIDDEN_SIZE = 512
    OUTPUT_SIZE = 512
    NUM_LAYERS = 10
    DROPOUT = 0.2
    CLIP = 1  # CLIP value
    N_EPOCHS = 100
    LR = 0.0001

    start_time = time.time()  # start timing
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = 'cpu'
    data = read_data("../dataset/round1_train_data.csv")
    dataset = ReactionDataset(data)
    subset_indices = list(range(N))
    subset_dataset = Subset(dataset, subset_indices)  # built but unused: the full dataset is passed to the DataLoader below
    train_loader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)

    model = RNNModel(NUM_EMBED, INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, NUM_LAYERS, DROPOUT, device).to(device)
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=LR)
    # criterion = nn.MSELoss()  # MSE
    criterion = nn.L1Loss()  # MAE

    best_loss = 10
    for epoch in range(N_EPOCHS):
        epoch_loss = 0
        for i, (src, y) in enumerate(train_loader):
            src, y = src.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(src)
            loss = criterion(output, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)  # clip gradients by norm
            optimizer.step()
            epoch_loss += loss.item()

        loss_in_a_epoch = epoch_loss / len(train_loader)
        print(f'Epoch: {epoch+1:02} | Train Loss: {loss_in_a_epoch:.3f}')
        if loss_in_a_epoch < best_loss:
            best_loss = loss_in_a_epoch  # bug fix: update best_loss; the baseline omits this, so every epoch below the initial threshold overwrites the checkpoint
            # save the model whenever the epoch loss improves
            torch.save(model.state_dict(), '../model/RNN.pth')
    end_time = time.time()  # stop timing
    # compute and print the running time
    elapsed_time_minute = (end_time - start_time) / 60
    print(f"Total running time: {elapsed_time_minute:.2f} minutes")

if __name__ == '__main__':
    train()
The uppercase names at the top are the hyperparameters.
The DataLoader splits the dataset into batches; with 23,538 training samples and a batch size of 128, that is 23538/128 ≈ 184 batches in total. (The unused Subset is discussed after the step list below.)
Training the model involves the following steps:
- Instantiate RNNModel and put it in train mode;
- Choose Adam as the optimizer
- Use mean absolute error (MAE) as the loss function
- Loop over 100 epochs:
  - Iterate over each batch of data
  - Move the tensors to the GPU device
  - Clear the gradients
  - Forward pass to compute the output
  - Compute the loss
  - Backward pass to compute the gradients
  - Apply gradient clipping
  - Update the model parameters
  - Accumulate the loss of each batch
  - Compute the average loss over the whole epoch
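Note that train() builds subset_dataset from the first N indices but then passes the full dataset to the DataLoader, so the subset is dead code. If you wanted to smoke-test the pipeline on a small slice, a minimal sketch inside train() would be:

subset_dataset = Subset(dataset, list(range(N)))  # first N samples only
train_loader = DataLoader(subset_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)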
5. Validating the model on the test set
# Generate the submission file
def predicit_and_make_submit_file(model_file, output_file):
    NUM_EMBED = 294
    INPUT_SIZE = 300
    HIDDEN_SIZE = 512
    OUTPUT_SIZE = 512
    NUM_LAYERS = 10
    DROPOUT = 0.2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_data = read_data("../dataset/round1_test_data.csv", train=False)
    test_dataset = ReactionDataset(test_data)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)

    model = RNNModel(NUM_EMBED, INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, NUM_LAYERS, DROPOUT, device).to(device)
    # load the best model
    model.load_state_dict(torch.load(model_file))
    model.eval()
    output_list = []
    for i, (src, y) in enumerate(test_loader):
        src, y = src.to(device), y.to(device)
        with torch.no_grad():
            output = model(src)
            output_list += output.detach().tolist()
    ans_str_lst = ['rxnid,Yield']
    for idx, y in enumerate(output_list):
        ans_str_lst.append(f'test{idx+1},{y:.4f}')
    with open(output_file, 'w') as fw:
        fw.writelines('\n'.join(ans_str_lst))
    print("done!!!")

predicit_and_make_submit_file("../model/RNN.pth",
                              "../output/RNN_submit.txt")
The model was trained in the previous step; this step uses it to predict on the test set.
The hyperparameters stay the same; only the dataset changes to the test set.
Instantiate RNNModel, load the trained weights saved in the previous step, and switch the model to evaluation mode.
Iterate over each batch from the test DataLoader to obtain the inputs and labels. The torch.no_grad() context manager disables gradient computation, which speeds up inference.
Finally, write the predictions to a txt file.