pytorch自己的數(shù)據(jù)如何處理

參考 pytorch: 準備、訓練和測試自己的圖片數(shù)據(jù)

1及汉、下載數(shù)據(jù)

下載原始地址fashion-mnist
轉(zhuǎn)換圖片代碼如下(應該是版本問題,對原代碼label讀取作了輕微調(diào)整):

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="D:/data/fashion_mnist/"#數(shù)據(jù)集合地址 數(shù)據(jù)要先解壓縮
train_set = (
    mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
        )
test_set = (
    mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
        )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
    if(train):
        f=open(root+'train.txt','w')
        data_path=root+'/train/'
        if(not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
            img_path=data_path+str(i)+'.jpg'
            io.imsave(img_path,img.numpy())
            f.write(img_path+' '+str(label.item())+'\n')
        f.close()
    else:
        f = open(root + 'test.txt', 'w')
        data_path = root + '/test/'
        if (not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
            img_path = data_path+ str(i) + '.jpg'
            io.imsave(img_path, img.numpy())
            f.write(img_path + ' ' + str(label.item()) + '\n')
        f.close()

convert_to_img(True)
convert_to_img(False)

2、數(shù)據(jù)讀取及分類任務

其中數(shù)據(jù)讀取部分采用了參考網(wǎng)站的代碼桑腮,網(wǎng)絡訓練部分則用了mnist簡易分類網(wǎng)絡

import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn.functional as F
root="D:/data/fashion_mnist/"

#----數(shù)據(jù)處理階段
def default_loader(path):
    return Image.open(path).convert('RGB')
class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.strip('\n')
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0],int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img,label

    def __len__(self):
        return len(self.imgs)

train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)

#----構(gòu)建網(wǎng)絡及訓練部分
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 3, 1, 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2))
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, 3, 1, 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2)
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 64, 3, 1, 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2)
        )
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(64 * 3 * 3, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 10)
        )

    def forward(self, x):
        conv1_out = self.conv1(x)
        conv2_out = self.conv2(conv1_out)
        conv3_out = self.conv3(conv2_out)
        res = conv3_out.view(conv3_out.size(0), -1)
        out = self.dense(res)
        return out

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(DEVICE)
print(model)
EPOCHS=15

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if(batch_idx+1)%200 == 0: 
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() 
            pred = output.max(1, keepdim=True)[1] 
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

for epoch in range(1, EPOCHS + 1):
    train(model, DEVICE, train_loader, optimizer, epoch)
    test(model, DEVICE, test_loader)

3痘绎、實驗結(jié)果

訓練了15個epoch津函,測試準確率為91.46%


訓練結(jié)果
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市孤页,隨后出現(xiàn)的幾起案子尔苦,更是在濱河造成了極大的恐慌,老刑警劉巖行施,帶你破解...
    沈念sama閱讀 218,122評論 6 505
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件允坚,死亡現(xiàn)場離奇詭異,居然都是意外死亡蛾号,警方通過查閱死者的電腦和手機稠项,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,070評論 3 395
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來须教,“玉大人皿渗,你說我怎么就攤上這事∏嵯伲” “怎么了乐疆?”我有些...
    開封第一講書人閱讀 164,491評論 0 354
  • 文/不壞的土叔 我叫張陵,是天一觀的道長贬养。 經(jīng)常有香客問我挤土,道長,這世上最難降的妖魔是什么误算? 我笑而不...
    開封第一講書人閱讀 58,636評論 1 293
  • 正文 為了忘掉前任仰美,我火速辦了婚禮,結(jié)果婚禮上儿礼,老公的妹妹穿的比我還像新娘咖杂。我一直安慰自己,他們只是感情好蚊夫,可當我...
    茶點故事閱讀 67,676評論 6 392
  • 文/花漫 我一把揭開白布诉字。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪壤圃。 梳的紋絲不亂的頭發(fā)上陵霉,一...
    開封第一講書人閱讀 51,541評論 1 305
  • 那天,我揣著相機與錄音伍绳,去河邊找鬼踊挠。 笑死,一個胖子當著我的面吹牛冲杀,可吹牛的內(nèi)容都是我干的效床。 我是一名探鬼主播,決...
    沈念sama閱讀 40,292評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼漠趁,長吁一口氣:“原來是場噩夢啊……” “哼扁凛!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起闯传,我...
    開封第一講書人閱讀 39,211評論 0 276
  • 序言:老撾萬榮一對情侶失蹤谨朝,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后甥绿,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體字币,經(jīng)...
    沈念sama閱讀 45,655評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,846評論 3 336
  • 正文 我和宋清朗相戀三年共缕,在試婚紗的時候發(fā)現(xiàn)自己被綠了洗出。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 39,965評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡图谷,死狀恐怖翩活,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情便贵,我是刑警寧澤菠镇,帶...
    沈念sama閱讀 35,684評論 5 347
  • 正文 年R本政府宣布,位于F島的核電站承璃,受9級特大地震影響利耍,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜盔粹,卻給世界環(huán)境...
    茶點故事閱讀 41,295評論 3 329
  • 文/蒙蒙 一隘梨、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧舷嗡,春花似錦轴猎、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,894評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽烦秩。三九已至,卻和暖如春郎仆,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背兜蠕。 一陣腳步聲響...
    開封第一講書人閱讀 33,012評論 1 269
  • 我被黑心中介騙來泰國打工扰肌, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人熊杨。 一個月前我還...
    沈念sama閱讀 48,126評論 3 370
  • 正文 我出身青樓曙旭,卻偏偏與公主長得像,于是被迫代替她去往敵國和親晶府。 傳聞我的和親對象是個殘疾皇子桂躏,可洞房花燭夜當晚...
    茶點故事閱讀 44,914評論 2 355

推薦閱讀更多精彩內(nèi)容