起因
前段時間完成了吳恩達的深度學習第五專題序列模型,里面一些作業(yè)都很有意思,包括這個Emojify摆出,根據(jù)你輸入的話語判斷你話語的含義掖举,并且用一個表情來表示快骗,
并且把表情放在語句后,這樣就可以實現(xiàn)說話時自動添加最貼切的表情。
具體參考我的github方篮,那是一個keras版本的實現(xiàn)名秀,也是Coursera作業(yè)使用的框架,我稍稍改編了一下藕溅,
里面有些實現(xiàn)的效果以及模型的結(jié)構(gòu)匕得,這里就不多說了,代碼也很簡單巾表,容易理解汁掠。
本來這么簡單感覺沒必要寫個博客,但是正好用pytorch復現(xiàn)一遍集币,此間遇到不少坑考阱,所以我打算好好講解一番。
我將會從制作自己的數(shù)據(jù)集開始惠猿,把這個問題擴展羔砾,分成幾篇博文,按流程介紹我的實現(xiàn)歷程偶妖。
這里我不由不說姜凄,要看你的框架掌握的怎么樣,其實就是看你官方文檔看的怎么樣趾访,更重要的是态秧,你的官方源碼看的怎么樣,所以這幾篇博文我會嵌入一些官方源碼 和官方文檔的內(nèi)容扼鞋,詳細講解如何利用它們解決自己的問題申鱼,因為當你編程的時候你會發(fā)現(xiàn),遇到問題百度是沒什么用的云头,google也不是什么問題都有的捐友。 通過源碼,你也會發(fā)現(xiàn)各種教程里都不太可能說到的東西溃槐。成功運行官方教程里給出的mnist程序匣砖,并不是你就會了這個框架,甚至連入門都說不上昏滴!一定要通讀文檔和源碼:秭辍!
制作數(shù)據(jù)集
這第一篇博文我主要就想講解一下如何在pytorch框架中制作自己的數(shù)據(jù)集谣殊,這里不得不說拂共,源碼大法好。這里的制作數(shù)據(jù)集不是
收集數(shù)據(jù)并且做成對應(yīng)格式姻几,而是已經(jīng)有數(shù)據(jù)了宜狐,怎么裝載到dataset里势告,看完博文你就了解了。
emojify數(shù)據(jù)集構(gòu)造
首先介紹一下emojify的數(shù)據(jù)集是什么樣子的肌厨,其實我建議你去上面我給出的我的github上看一下培慌,上面的介紹很清楚。這里就簡單說一下柑爸,
這個數(shù)據(jù)集還是很簡單的吵护,每一個traindata就是一句話,string類型表鳍,標簽就是0-4五個整形的值馅而,代表五個表情(emoji)。testdata也一樣譬圣,
但是我把所有的testdata和traindata都合在一起了瓮恭,稍微增大訓練集,提升泛化效果厘熟,測試我們只需要手動輸入一些我們想說的話就好了屯蹦,
這個項目中測試集并不是很重要。
所以最后的訓練數(shù)據(jù)就是188個樣本绳姨,但是為了batchsize好選登澜,我刪了8個樣本,變成180個樣本飘庄,存在一個csv文件中脑蠕。csv文件的讀取相信不用我多說,
pandas和csv模塊都可以輕松讀取跪削。
直接加載數(shù)據(jù)
其實我們可以讀取數(shù)據(jù)谴仙,然后保存在列表里,然后用np.random.shuffle打亂一下碾盐,大不了再寫一個get_batch函數(shù)獲取批樣本晃跺,就像下面一樣:
def get_batchs(X,Y,batchsize = 3,batchnum = 0):
if (batchnum*batchsize+batchsize) >= X.shape[0]:
bx = X[batchnum * batchsize:]
by = Y[batchnum * batchsize:]
else:
bx = X[batchnum * batchsize:(batchnum * batchsize + batchsize)]
by = Y[batchnum * batchsize:(batchnum * batchsize + batchsize)]
return np.array(bx),np.array(by)
就可以在每個step時候調(diào)用一下,就可以輸入數(shù)據(jù)了毫玖,那么為什么還要制作數(shù)據(jù)集呢掀虎?
因為簡單的數(shù)據(jù)集自然沒有必要的,這樣處理足夠了孕豹,但是復雜的比較龐大的數(shù)據(jù)集呢?一點就是這樣寫是把所有的數(shù)據(jù)都加載進來十气,
因此需要大量的內(nèi)存励背,第二是要實現(xiàn)更加復雜的操作比較麻煩,又要添加更多的代碼砸西,比如多線程讀取數(shù)據(jù)等等叶眉。更重要的是址儒,數(shù)據(jù)集都不會制作,
如何說掌握了pytorch衅疙。
自定義數(shù)據(jù)集
回憶一下莲趣,你是怎么用pytorch實現(xiàn)mnist分類的,具體一點饱溢,如何加載mnist數(shù)據(jù)集的喧伞。你是直接用了torchvision.datasets.MNIST,而且
可以指定一些參數(shù)绩郎,比如是否從云端下載數(shù)據(jù)集潘鲫,對數(shù)據(jù)采用怎么樣的變換等等,然后調(diào)用 Data.DataLoader肋杖,可以讓數(shù)據(jù)集可迭代溉仑,
并且可以shffule,指定batchsize等等状植,然后訓練時就可以直接調(diào)用了浊竟,測試集也一樣處理,是不是很方便津畸。那么我們也可以這樣嗎振定?
答案是肯定的。如果是圖片數(shù)據(jù)集洼畅,你想做一個分類的話吩案,直接用datasets的ImageFolder函數(shù)就可以了,只需要注意文件和文件夾的格式
就行了帝簇。那么這里的emojify數(shù)據(jù)集呢徘郭?這里就要上源碼了,我們先看官方源碼的mnist是如何編寫的丧肴,這里完全是我自己摸索的残揉,然后發(fā)現(xiàn)
沒什么大問題,所以拿出來說一下芋浮。源碼在哪看呢抱环?github搜索vision和pytorch,分別是兩部分的源碼纸巷。
class MNIST(data.Dataset):
urls = [
'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
]
raw_folder = 'raw'
processed_folder = 'processed'
training_file = 'training.pt'
test_file = 'test.pt'
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set
if download:
self.download()
if not self._check_exists():
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')
if self.train:
self.train_data, self.train_labels = torch.load(
os.path.join(self.root, self.processed_folder, self.training_file))
else:
self.test_data, self.test_labels = torch.load(
os.path.join(self.root, self.processed_folder, self.test_file))
def __getitem__(self, index):
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
if self.train:
return len(self.train_data)
else:
return len(self.test_data)
def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
def download(self):
pass
上面是主要代碼镇草,download被我刪了,這里也不需要瘤旨,接下來我們直接進行修改梯啤,怎么改呢,首先看init方法存哲,賦值部分
都不怎么需要改因宇,雖然本例中幾個都用不到七婴,download不需要,所以上面可以直接改成pass察滑,然后把下面的實現(xiàn)刪去打厘,
train肯定是True,因為沒有測試集贺辰,所以else的部分就pass户盯,if self.train:的部分就讀取我們的csv數(shù)據(jù)就好了,
這里string的數(shù)據(jù)要換成index魂爪,我一開始沒有先舷,直接加載數(shù)據(jù),然后出錯滓侍,查看源碼發(fā)現(xiàn)蒋川,因為數(shù)據(jù)會在源碼里幫你轉(zhuǎn)化成torch的Tensor,
而torch里沒有string的Tensor的撩笆,所以要變換成詞匯表里的索引捺球。_check_exists也不需要,刪去夕冲,len里面train的值改成我們的樣本數(shù)180氮兵,
測試集沒有,就是0嘍歹鱼。這樣就剩getitem了泣栈,顧名思義,就是獲取item的值弥姻,官方在這個函數(shù)下標注
tuple: (image, target) where target is index of the target class南片,也就是獲取index索引的數(shù)據(jù),這也很簡單庭敦。整理一下疼进,改成如下:
class emojiDataset(Data.Dataset):
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.train = train
if download:
pass
if self.train:
traindata,trainlabels = read_csv(os.path.join(self.root,'mytrain.csv'))
self.train_data = sentences_to_indices(traindata, word_to_index, 10)
self.train_labels = trainlabels
# self.train_labels = convert_to_one_hot(trainlabels, C=5)
else:
pass
def __getitem__(self, index):
if self.train:
data, target = self.train_data[index], self.train_labels[index]
else:
pass
if self.transform is not None:
pass
if self.target_transform is not None:
pass
return data, target
def __len__(self):
if self.train:
return 180
else:
return 0
可以看到,其實沒怎么修改秧廉,然后iter一下伞广,print(next(iter))測試,發(fā)現(xiàn)沒有任何問題疼电,就是這么簡單嚼锄。
注意到這個問題比較簡單,所以很多函數(shù)都是pass蔽豺,你當然可以自由編寫区丑,適應(yīng)你的項目要求。
實例化一下茫虽,用dataloader試一下刊苍,看看是不是也可以,結(jié)果自然是可以的濒析。
結(jié)論
這樣數(shù)據(jù)集的準備就ok了正什,剩下就是定義模型和訓練了。