使用optuna對模型的超參數(shù)進行自動優(yōu)化

使用前先安裝optuna.
pip install optuna
optuna適用于多種機器學習框架包括pytorch,tensorflow等稠氮⌒中桑可以optuna examples github查看所有支持的框架的教程茉盏。

使用optuna優(yōu)化pytorch的模型

先有基本的pytorch經(jīng)驗鉴未,可以更快的理解下面的代碼。這里使用了gpu,如果沒有GPU鸠姨,可以修改最開始的DEVICE=torch.device("cuda")DEVICE=torch.device("cpu")

import os
import optuna
from optuna.trial import TrialState
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets
from torchvision import transforms


"""
Optuna example that optimizes multi-layer perceptrons using PyTorch.

In this example, we optimize the validation accuracy of fashion product recognition using
PyTorch and FashionMNIST. We optimize the neural network architecture as well as the optimizer
configuration. As it is too time consuming to use the whole FashionMNIST dataset,
we here use a small subset of it.

"""


DEVICE = torch.device("cuda")
BATCHSIZE = 128 #每次訓練時數(shù)據(jù)被分為小批次的大小
CLASSES = 10 #
DIR = os.getcwd()
EPOCHS = 10 #所有的小批次完成后铜秆,即為1個epoch,此處是設置總的訓練的epoch次數(shù)
N_TRAIN_EXAMPLES = BATCHSIZE * 30 #訓練集的樣本數(shù)量
N_VALID_EXAMPLES = BATCHSIZE * 10 #測試集的樣本數(shù)量

#定義神經(jīng)網(wǎng)絡模型,在里面使用了2個自優(yōu)化超參數(shù)
def define_model(trial):
    # We optimize the number of layers, hidden units and dropout ratio in each layer.
    n_layers = trial.suggest_int("n_layers", 1, 3) #定義一個自由化超參數(shù):訓練的層數(shù)讶迁,從1到3
    layers = []

    in_features = 28 * 28
    #循環(huán)構建卷積塊
    for i in range(n_layers):
        out_features = trial.suggest_int("n_units_l{}".format(i), 4, 128)#定義一個自優(yōu)化超參數(shù)out_features连茧,是輸出feature的維度,從4到128
        layers.append(nn.Linear(in_features, out_features))#使用了上面定義的輸出特征維度超參數(shù)
        layers.append(nn.ReLU())#層添加上激活函數(shù)ReLU
        p = trial.suggest_float("dropout_l{}".format(i), 0.2, 0.5)#定義一個自優(yōu)化超參數(shù)p巍糯,是丟棄率啸驯,從0.2到0.5
        layers.append(nn.Dropout(p))#添加上上面的丟棄層
        in_features = out_features#再把最后的輸出層維度賦值給輸出層
    layers.append(nn.Linear(in_features, CLASSES))#最后添加上線性輸出層
    layers.append(nn.LogSoftmax(dim=1))#再轉為0-1分布函數(shù),此處是分類模型祟峦,所以需要這個函數(shù)罚斗,轉成概率值
    return nn.Sequential(*layers)#返回的就是一個多層神經(jīng)網(wǎng)絡的所有層

#在線加載FashionMNIST數(shù)據(jù)集,pytorch的基本操作
def get_mnist():
    # Load FashionMNIST dataset.
    train_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST(DIR, train=True, download=True, transform=transforms.ToTensor()),
        batch_size=BATCHSIZE,
        shuffle=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST(DIR, train=False, transform=transforms.ToTensor()),
        batch_size=BATCHSIZE,
        shuffle=True,
    )
    return train_loader, valid_loader


def objective(trial):
    # Generate the model.
    model = define_model(trial).to(DEVICE)

    # Generate the optimizers.
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])#定義自優(yōu)化的超參數(shù):模型優(yōu)化器
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True) #定義自由化的超參數(shù),學習率搀愧,從1e-5到1
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)#創(chuàng)建優(yōu)化器對象

    # Get the FashionMNIST dataset.使用在線獲取FashionMNIST數(shù)據(jù)集惰聂,前者是訓練數(shù)據(jù)集,后者是測試數(shù)據(jù)集
    train_loader, valid_loader = get_mnist()

    # Training of the model.
    for epoch in range(EPOCHS):
        model.train()#訓練模型
        for batch_idx, (data, target) in enumerate(train_loader):
            # 當訓練的樣本總數(shù)超過我們最開始設置的訓練樣本數(shù)之后咱筛,就停止訓練
            if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
                break

            #轉換為特定硬件設備的張量搓幌,view是用于改變張量的形狀,data.size(0)表示第1維度的大小迅箩,-1溉愁,表示自動計算其他維度的大小
            data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)

            optimizer.zero_grad()#清空優(yōu)化器的梯度
            output = model(data)#得到模型的輸出結果
            loss = F.nll_loss(output, target)#使用負對數(shù)似然損失(nll_loss)作為損失函數(shù)
            loss.backward()#反向傳播
            optimizer.step()#根據(jù)計算得到的梯度,通過優(yōu)化器更新模型的參數(shù)

        # 測試集饲趋,評估模型
        model.eval()
        correct = 0
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(valid_loader):
                # 如果測試數(shù)據(jù)量大于最開始設置的測試數(shù)據(jù)量就停止測試
                if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
                    break
                data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
                output = model(data)
                # 獲取最大值的索引拐揭,后面的keepdim是在原來的張量相同的維度上保存。
                pred = output.argmax(dim=1, keepdim=True)
                #累加計算結果正確的所有的次數(shù)奕塑,
                correct += pred.eq(target.view_as(pred)).sum().item()
        #這里是計算準確率堂污,
        accuracy = correct / min(len(valid_loader.dataset), N_VALID_EXAMPLES)

        trial.report(accuracy, epoch)#將當前的準確率和迭代次數(shù)傳遞給optuna

        #optuna評估是否需要剪枝,如果需要剪枝(說明模型的參數(shù)性能不佳)龄砰,則拋出TrialPruned異常盟猖。
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    return accuracy

#判斷當前腳本是否是作為主程序運行,只有作為主程序時换棚,才會執(zhí)行下面的代碼式镐,作為模塊被引入時,不會執(zhí)行下面的代碼
if __name__ == "__main__":
    #創(chuàng)建一個optuna研究對象固蚤,名字是maxmize
    study = optuna.create_study(direction="maximize")
    #對目標函數(shù)進行優(yōu)化娘汞,objective是要優(yōu)化的目標函數(shù),n_trials=100表示最多進行100次實驗夕玩,timeout=600表示優(yōu)化過程的超時時間為600秒你弦。
    study.optimize(objective, n_trials=100, timeout=600)
    #獲取所有被提前停止的實驗
    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    #獲取所有成功的實驗惊豺,deepcopy=False表示返回的是原始對象的引用,而不是復制鳖目,這可以節(jié)省內存
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))

    print("Best trial:")
    #獲取結果最佳的實驗
    trial = study.best_trial
    #輸出最佳實驗時的超參數(shù)設定
    print("  Value: ", trial.value)
    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))

程序最終輸出內容如下:

Study statistics: 
  Number of finished trials:  100
  Number of pruned trials:  61
  Number of complete trials:  39
Best trial:
  Value:  0.85078125
  Params: 
    n_layers: 1
    n_units_l0: 81
    dropout_l0: 0.22233836180755426
    optimizer: Adam
    lr: 0.0037253244556814374

即最優(yōu)的超參數(shù)組合的模型的準確率是0.85078125扮叨,最優(yōu)超參數(shù)設置是:

  • 層數(shù)是1缤弦,
  • 輸出維度是81领迈,
  • 損失函數(shù)是0.22233836180755426,
  • 模型優(yōu)化器是Adam,
  • 學習率是 0.0037253244556814374
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
  • 序言:七十年代末碍沐,一起剝皮案震驚了整個濱河市狸捅,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌累提,老刑警劉巖尘喝,帶你破解...
    沈念sama閱讀 218,682評論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異斋陪,居然都是意外死亡朽褪,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,277評論 3 395
  • 文/潘曉璐 我一進店門无虚,熙熙樓的掌柜王于貴愁眉苦臉地迎上來缔赠,“玉大人,你說我怎么就攤上這事友题∴脱撸” “怎么了?”我有些...
    開封第一講書人閱讀 165,083評論 0 355
  • 文/不壞的土叔 我叫張陵度宦,是天一觀的道長踢匣。 經(jīng)常有香客問我,道長戈抄,這世上最難降的妖魔是什么离唬? 我笑而不...
    開封第一講書人閱讀 58,763評論 1 295
  • 正文 為了忘掉前任,我火速辦了婚禮划鸽,結果婚禮上输莺,老公的妹妹穿的比我還像新娘。我一直安慰自己漾稀,他們只是感情好模闲,可當我...
    茶點故事閱讀 67,785評論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著崭捍,像睡著了一般尸折。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上殷蛇,一...
    開封第一講書人閱讀 51,624評論 1 305
  • 那天实夹,我揣著相機與錄音橄浓,去河邊找鬼。 笑死亮航,一個胖子當著我的面吹牛荸实,可吹牛的內容都是我干的。 我是一名探鬼主播缴淋,決...
    沈念sama閱讀 40,358評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼准给,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了重抖?” 一聲冷哼從身側響起露氮,我...
    開封第一講書人閱讀 39,261評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎钟沛,沒想到半個月后畔规,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,722評論 1 315
  • 正文 獨居荒郊野嶺守林人離奇死亡恨统,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 37,900評論 3 336
  • 正文 我和宋清朗相戀三年叁扫,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片畜埋。...
    茶點故事閱讀 40,030評論 1 350
  • 序言:一個原本活蹦亂跳的男人離奇死亡莫绣,死狀恐怖,靈堂內的尸體忽然破棺而出由捎,到底是詐尸還是另有隱情兔综,我是刑警寧澤,帶...
    沈念sama閱讀 35,737評論 5 346
  • 正文 年R本政府宣布狞玛,位于F島的核電站软驰,受9級特大地震影響,放射性物質發(fā)生泄漏心肪。R本人自食惡果不足惜锭亏,卻給世界環(huán)境...
    茶點故事閱讀 41,360評論 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望硬鞍。 院中可真熱鬧慧瘤,春花似錦、人聲如沸固该。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,941評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽伐坏。三九已至怔匣,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間桦沉,已是汗流浹背每瞒。 一陣腳步聲響...
    開封第一講書人閱讀 33,057評論 1 270
  • 我被黑心中介騙來泰國打工金闽, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人剿骨。 一個月前我還...
    沈念sama閱讀 48,237評論 3 371
  • 正文 我出身青樓代芜,卻偏偏與公主長得像,于是被迫代替她去往敵國和親浓利。 傳聞我的和親對象是個殘疾皇子挤庇,可洞房花燭夜當晚...
    茶點故事閱讀 44,976評論 2 355

推薦閱讀更多精彩內容