簡單GAN網(wǎng)絡(luò) matlab實現(xiàn)

說明

此代碼在matlab上搭建了簡單的生成對抗性網(wǎng)絡(luò)聋亡,用來生成手寫數(shù)字圖像汰蓉。
網(wǎng)絡(luò)中生成器和鑒別器的隱藏層均為2層珊佣,且都是全連接層康铭,是一個比較簡單的網(wǎng)絡(luò)結(jié)構(gòu)惯退。主要用來說明怎么在matlab上搭建GAN(Generative Adversarial Net)網(wǎng)絡(luò)。

網(wǎng)絡(luò)模型

如圖1所示从藤,是生成器網(wǎng)絡(luò)模型催跪,一個輸入層,兩個全連接層夷野。輸入的數(shù)據(jù)是100×1的噪聲懊蒸,輸出是784×1的向量,將輸出進行reshape之后悯搔,就可以得到一張28×28的手寫數(shù)字圖像榛鼎。

圖1 生成器

如圖2所示,是將生成器和鑒別器連接起來之后的模型。這里主要想說明一點者娱,在進行網(wǎng)絡(luò)參數(shù)更新的時候,為了得到生成器參數(shù)的偏導(dǎo)數(shù)苏揣,bp過程需要鑒別器黄鳍,再傳到生成器。

到這里就產(chǎn)生了一個疑問:不是說更新生成器的時候平匈,鑒別網(wǎng)絡(luò)的參數(shù)需要固定住不變嗎框沟,如果bp過程需要經(jīng)過鑒別網(wǎng)絡(luò),那應(yīng)該怎么保持鑒別網(wǎng)絡(luò)的參數(shù)不變呢增炭?
其實bp的時候忍燥,只是算出來了各個網(wǎng)絡(luò)層參數(shù)對于loss的偏導(dǎo)數(shù),在求生成器的參數(shù)的偏導(dǎo)數(shù)的時候隙姿,鑒別網(wǎng)絡(luò)的參數(shù)的偏導(dǎo)數(shù)也被求出來了梅垄。但是求出來了偏導(dǎo)數(shù),不一定就要對網(wǎng)絡(luò)進行更新输玷。也就是求出來了網(wǎng)絡(luò)loss對生成器和鑒別器的偏導(dǎo)數(shù)队丝,但是只使用到了生成器的偏導(dǎo)數(shù)來更新生成器。

圖2 將generator和discriminator看成一個整體

More

實例

實例1

??????代碼在githubgan_adam.m
這里使用上面提到的網(wǎng)絡(luò)結(jié)構(gòu)來生成手寫數(shù)字圖片,使用到了Adam算法作為優(yōu)化器來更新GAN網(wǎng)絡(luò)尤误。

clear;
clc;
% -----------加載數(shù)據(jù)
load('mnist_uint8', 'train_x');
train_x = double(reshape(train_x, 60000, 28, 28))/255;
train_x = permute(train_x,[1,3,2]);
train_x = reshape(train_x, 60000, 784);
% -----------------定義模型
generator = nnsetup([100, 512, 784]);
discriminator = nnsetup([784, 512, 1]);
% -----------開始訓(xùn)練
batch_size = 60;
epoch = 100;
images_num = 60000;
batch_num = ceil(images_num / batch_size);
learning_rate = 0.001;
for e=1:epoch
    kk = randperm(images_num);
    for t=1:batch_num
        % 準(zhǔn)備數(shù)據(jù)
        images_real = train_x(kk((t - 1) * batch_size + 1:t * batch_size), :, :);
        noise = unifrnd(-1, 1, batch_size, 100);
        % 開始訓(xùn)練
        % -----------更新generator侠畔,固定discriminator
        generator = nnff(generator, noise);
        images_fake = generator.layers{generator.layers_count}.a;
        discriminator = nnff(discriminator, images_fake);
        logits_fake = discriminator.layers{discriminator.layers_count}.z;
        discriminator = nnbp_d(discriminator, logits_fake, ones(batch_size, 1));
        generator = nnbp_g(generator, discriminator);
        generator = nnapplygrade(generator, learning_rate);
        % -----------更新discriminator,固定generator
        generator = nnff(generator, noise);
        images_fake = generator.layers{generator.layers_count}.a;
        images = [images_fake;images_real];
        discriminator = nnff(discriminator, images);
        logits = discriminator.layers{discriminator.layers_count}.z;
        labels = [zeros(batch_size,1);ones(batch_size,1)];
        discriminator = nnbp_d(discriminator, logits, labels);
        discriminator = nnapplygrade(discriminator, learning_rate);
        % ----------------輸出loss
        if t == batch_num
            c_loss = sigmoid_cross_entropy(logits(1:batch_size), ones(batch_size, 1));
            d_loss = sigmoid_cross_entropy(logits, labels);
            fprintf('c_loss:"%f",d_loss:"%f"\n',c_loss, d_loss);
        end
        if t == batch_num
            path = ['./pics/epoch_',int2str(e),'_t_',int2str(t),'.png'];
            save_images(images_fake, [4, 4], path);
            fprintf('save_sample:%s\n', path);
        end
    end
end
% sigmoid激活函數(shù)
function output = sigmoid(x)
    output =1./(1+exp(-x));
end
% relu
function output = relu(x)
    output = max(x, 0);
end
% relu對x的導(dǎo)數(shù)
function output = delta_relu(x)
    output = max(x,0);
    output(output>0) = 1;
end
% 交叉熵?fù)p失函數(shù)袄膏,此處的logits是未經(jīng)過sigmoid激活的
% https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
function result = sigmoid_cross_entropy(logits, labels)
    result = max(logits, 0) - logits .* labels + log(1 + exp(-abs(logits)));
    result = mean(result);
end
% sigmoid_cross_entropy對logits的導(dǎo)數(shù)践图,此處的logits是未經(jīng)過sigmoid激活的
function result = delta_sigmoid_cross_entropy(logits, labels)
    temp1 = max(logits, 0);
    temp1(temp1>0) = 1;
    temp2 = logits;
    temp2(temp2>0) = -1;
    temp2(temp2<0) = 1;
    result = temp1 - labels + exp(-abs(logits))./(1+exp(-abs(logits))) .* temp2;
end
% 根據(jù)所給的結(jié)構(gòu)建立網(wǎng)絡(luò)
function nn = nnsetup(architecture)
    nn.architecture   = architecture;
    nn.layers_count = numel(nn.architecture);
    % t,beta1,beta2,epsilon,nn.layers{i}.w_m,nn.layers{i}.w_v,nn.layers{i}.b_m,nn.layers{i}.b_v是應(yīng)用adam算法更新網(wǎng)絡(luò)所需的變量
    nn.t = 0;
    nn.beta1 = 0.9;
    nn.beta2 = 0.999;
    nn.epsilon = 10^(-8);
    % 假設(shè)結(jié)構(gòu)為[100, 512, 784],則有3層沉馆,輸入層100码党,兩個隱藏層:100*512,512*784, 輸出為最后一層的a值(激活值)
    for i = 2 : nn.layers_count   
        nn.layers{i}.w = normrnd(0, 0.02, nn.architecture(i-1), nn.architecture(i));
        nn.layers{i}.b = normrnd(0, 0.02, 1, nn.architecture(i));
        nn.layers{i}.w_m = 0;
        nn.layers{i}.w_v = 0;
        nn.layers{i}.b_m = 0;
        nn.layers{i}.b_v = 0;
    end
end
% 前向傳遞
function nn = nnff(nn, x)
    nn.layers{1}.a = x;
    for i = 2 : nn.layers_count
        input = nn.layers{i-1}.a;
        w = nn.layers{i}.w;
        b = nn.layers{i}.b;
        nn.layers{i}.z = input*w + repmat(b, size(input, 1), 1);
        if i ~= nn.layers_count
            nn.layers{i}.a = relu(nn.layers{i}.z);
        else
            nn.layers{i}.a = sigmoid(nn.layers{i}.z);
        end
    end
end
% discriminator的bp斥黑,下面的bp涉及到對各個參數(shù)的求導(dǎo)
% 如果更改網(wǎng)絡(luò)結(jié)構(gòu)(激活函數(shù)等)則涉及到bp的更改揖盘,更改weights,biases的個數(shù)則不需要更改bp
% 為了更新w,b锌奴,就是要求最終的loss對w兽狭,b的偏導(dǎo)數(shù),殘差就是在求w,b偏導(dǎo)數(shù)的中間計算過程的結(jié)果
function nn = nnbp_d(nn, y_h, y)
    % d表示殘差箕慧,殘差就是最終的loss對各層未激活值(z)的偏導(dǎo)服球,偏導(dǎo)數(shù)的計算需要采用鏈?zhǔn)角髮?dǎo)法則-自己手動推出來
    n = nn.layers_count;
    % 最后一層的殘差
    nn.layers{n}.d = delta_sigmoid_cross_entropy(y_h, y);
    for i = n-1:-1:2
        d = nn.layers{i+1}.d;
        w = nn.layers{i+1}.w;
        z = nn.layers{i}.z;
        % 每一層的殘差是對每一層的未激活值求偏導(dǎo)數(shù),所以是后一層的殘差乘上w,再乘上對激活值對未激活值的偏導(dǎo)數(shù)
        nn.layers{i}.d = d*w' .* delta_relu(z);    
    end
    % 求出各層的殘差之后颠焦,就可以根據(jù)殘差求出最終loss對weights和biases的偏導(dǎo)數(shù)
    for i = 2:n
        d = nn.layers{i}.d;
        a = nn.layers{i-1}.a;
        % dw是對每層的weights進行偏導(dǎo)數(shù)的求解
        nn.layers{i}.dw = a'*d / size(d, 1);
        nn.layers{i}.db = mean(d, 1);
    end
end
% generator的bp
function g_net = nnbp_g(g_net, d_net)
    n = g_net.layers_count;
    a = g_net.layers{n}.a;
    % generator的loss是由label_fake得到的斩熊,(images_fake過discriminator得到label_fake)
    % 對g進行bp的時候,可以將g和d看成是一個整體
    % g最后一層的殘差等于d第2層的殘差乘上(a .* (a_o))
    g_net.layers{n}.d = d_net.layers{2}.d * d_net.layers{2}.w' .* (a .* (1-a));
    for i = n-1:-1:2
        d = g_net.layers{i+1}.d;
        w = g_net.layers{i+1}.w;
        z = g_net.layers{i}.z;
        % 每一層的殘差是對每一層的未激活值求偏導(dǎo)數(shù)伐庭,所以是后一層的殘差乘上w,再乘上對激活值對未激活值的偏導(dǎo)數(shù)
        g_net.layers{i}.d = d*w' .* delta_relu(z);    
    end
    % 求出各層的殘差之后粉渠,就可以根據(jù)殘差求出最終loss對weights和biases的偏導(dǎo)數(shù)
    for i = 2:n
        d = g_net.layers{i}.d;
        a = g_net.layers{i-1}.a;
        % dw是對每層的weights進行偏導(dǎo)數(shù)的求解
        g_net.layers{i}.dw = a'*d / size(d, 1);
        g_net.layers{i}.db = mean(d, 1);
    end
end
% 應(yīng)用梯度
% 使用adam算法更新變量,可以參考:
% https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
function nn = nnapplygrade(nn, learning_rate)
    n = nn.layers_count;
    nn.t = nn.t+1;
    beta1 = nn.beta1;
    beta2 = nn.beta2;
    lr = learning_rate * sqrt(1-nn.beta2^nn.t) / (1-nn.beta1^nn.t);
    for i = 2:n
        dw = nn.layers{i}.dw;
        db = nn.layers{i}.db;
        % 下面的6行代碼是使用adam更新weights與biases
        nn.layers{i}.w_m = beta1 * nn.layers{i}.w_m + (1-beta1) * dw;
        nn.layers{i}.w_v = beta2 * nn.layers{i}.w_v + (1-beta2) * (dw.*dw);
        nn.layers{i}.w = nn.layers{i}.w - lr * nn.layers{i}.w_m ./ (sqrt(nn.layers{i}.w_v) + nn.epsilon);
        nn.layers{i}.b_m = beta1 * nn.layers{i}.b_m + (1-beta1) * db;
        nn.layers{i}.b_v = beta2 * nn.layers{i}.b_v + (1-beta2) * (db.*db);
        nn.layers{i}.b = nn.layers{i}.b - lr * nn.layers{i}.b_m ./ (sqrt(nn.layers{i}.b_v) + nn.epsilon); 
    end
end
% 保存圖片圾另,便于觀察generator生成的images_fake
function save_images(images, count, path)
    n = size(images, 1);
    row = count(1);
    col = count(2);
    I = zeros(row*28, col*28);
    for i = 1:row
        for j = 1:col
            r_s = (i-1)*28+1;
            c_s = (j-1)*28+1;
            index = (i-1)*col + j;
            pic = reshape(images(index, :), 28, 28);
            I(r_s:r_s+27, c_s:c_s+27) = pic;
        end
    end
    imwrite(I, path);
end

結(jié)果


epoch_5_t_1000.png

epoch_13_t_1000.png
實例2霸株,Mini-batch Gradient Descent

??????代碼在githubgan_mbgd.m

這里使用的網(wǎng)絡(luò)結(jié)構(gòu)與實例1的一樣,只是換了一種優(yōu)化器集乔。使用Mini-batch Gradient Descent算法更新GAN網(wǎng)絡(luò)去件,下面是部分代碼。

% 應(yīng)用梯度
function nn = nnapplygrade(nn, learning_rate)
    n = nn.layers_count;
    for i = 2:n
        dw = nn.layers{i}.dw;
        db = nn.layers{i}.db;
        nn.layers{i}.w = nn.layers{i}.w - learning_rate * dw;
        nn.layers{i}.b = nn.layers{i}.b - learning_rate * db;
    end
end

結(jié)果:


epoch_11_t_1000.png
epoch_13_t_1000.png
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末饺著,一起剝皮案震驚了整個濱河市箫攀,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌幼衰,老刑警劉巖靴跛,帶你破解...
    沈念sama閱讀 217,084評論 6 503
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異渡嚣,居然都是意外死亡梢睛,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,623評論 3 392
  • 文/潘曉璐 我一進店門识椰,熙熙樓的掌柜王于貴愁眉苦臉地迎上來绝葡,“玉大人,你說我怎么就攤上這事腹鹉〔爻” “怎么了?”我有些...
    開封第一講書人閱讀 163,450評論 0 353
  • 文/不壞的土叔 我叫張陵功咒,是天一觀的道長愉阎。 經(jīng)常有香客問我,道長力奋,這世上最難降的妖魔是什么榜旦? 我笑而不...
    開封第一講書人閱讀 58,322評論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮景殷,結(jié)果婚禮上溅呢,老公的妹妹穿的比我還像新娘澡屡。我一直安慰自己,他們只是感情好咐旧,可當(dāng)我...
    茶點故事閱讀 67,370評論 6 390
  • 文/花漫 我一把揭開白布驶鹉。 她就那樣靜靜地躺著,像睡著了一般休偶。 火紅的嫁衣襯著肌膚如雪梁厉。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,274評論 1 300
  • 那天踏兜,我揣著相機與錄音,去河邊找鬼八秃。 笑死碱妆,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的昔驱。 我是一名探鬼主播疹尾,決...
    沈念sama閱讀 40,126評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼骤肛!你這毒婦竟也來了纳本?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 38,980評論 0 275
  • 序言:老撾萬榮一對情侶失蹤腋颠,失蹤者是張志新(化名)和其女友劉穎繁成,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體淑玫,經(jīng)...
    沈念sama閱讀 45,414評論 1 313
  • 正文 獨居荒郊野嶺守林人離奇死亡巾腕,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,599評論 3 334
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了絮蒿。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片尊搬。...
    茶點故事閱讀 39,773評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖土涝,靈堂內(nèi)的尸體忽然破棺而出佛寿,到底是詐尸還是另有隱情,我是刑警寧澤但壮,帶...
    沈念sama閱讀 35,470評論 5 344
  • 正文 年R本政府宣布冀泻,位于F島的核電站,受9級特大地震影響茵肃,放射性物質(zhì)發(fā)生泄漏腔长。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,080評論 3 327
  • 文/蒙蒙 一验残、第九天 我趴在偏房一處隱蔽的房頂上張望捞附。 院中可真熱鬧,春花似錦、人聲如沸鸟召。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,713評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽欧募。三九已至压状,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間跟继,已是汗流浹背种冬。 一陣腳步聲響...
    開封第一講書人閱讀 32,852評論 1 269
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留舔糖,地道東北人娱两。 一個月前我還...
    沈念sama閱讀 47,865評論 2 370
  • 正文 我出身青樓,卻偏偏與公主長得像金吗,于是被迫代替她去往敵國和親十兢。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,689評論 2 354

推薦閱讀更多精彩內(nèi)容