SimGAN-Captcha代碼閱讀與復(fù)現(xiàn)

項目介紹

項目地址:戳這里
大概的講一下這個項目的起因是大神要參加HackMIT,需要他們在15000張驗證碼中識別出10000張或者每個字符的識別準(zhǔn)確率要到90%球切。然后他不想標(biāo)注數(shù)據(jù)(就是這么任性~)齐莲。于是決定先自己生成一批驗證碼(synthesizer合成器)诗力,然后把這些驗證碼用一個refiner(GAN)去對這批合成的驗證碼做一些調(diào)整讓它們看起來和真實的訓(xùn)練樣本的樣式差不多夺衍。這樣他就相當(dāng)于有了一批標(biāo)注好的驗證碼徙硅,用這部分的標(biāo)注驗證碼去訓(xùn)練一個分類器鞠鲜,然后對需要hack的15000張圖片做分類麻裁。借鑒的paper是Apple在2016年發(fā)的,戳這里嗤锉。但是呢渔欢,他發(fā)現(xiàn)他的這批數(shù)據(jù)訓(xùn)練出來的模型對真實樣本的準(zhǔn)確率只有55%,于是他讓一個同學(xué)標(biāo)注了4000張要hack的圖片(這個同學(xué)原本打算標(biāo)注10000張)瘟忱,最后開開心心的一張圖片都沒標(biāo)注的有了參加這個比賽的資格奥额。

下面如果不想關(guān)注paper細(xì)節(jié)可以跳過這部分,直接到項目代碼這一塊就可以访诱。

Overview

下圖是paper中的總體結(jié)構(gòu)垫挨。paper中是要合成和訓(xùn)練集相似的眼睛圖片。

Overview.jpg

模擬器先合成一些圖片(Synthetic),然后用一個Refiner對這個圖片進(jìn)行refine(改善触菜,調(diào)整)九榔,再用一個判別器(discriminator)去判別refine之后的圖片和真實的但沒有標(biāo)注的圖片。目標(biāo)是讓判別器沒有辦法區(qū)分真實圖片和refine出來的圖片。那么我們就可以用模擬器生成一批有標(biāo)注的數(shù)據(jù)帚屉,然后用refiner去進(jìn)行修正谜诫,得到的圖片就和原始的訓(xùn)練數(shù)據(jù)集很相近了。

Objective

這里簡要的概述一下模型需要用到的損失函數(shù)攻旦。
Simulated+Unsupervised learning要用一些沒有標(biāo)注的的真實圖片Y來學(xué)習(xí)一個Refiner喻旷,這個Refiner進(jìn)一步用來refine我們的合成圖片X。
關(guān)鍵點是需要讓合成的圖片x'看起來和真實的圖片差不多牢屋,并且還要保留標(biāo)注的信息且预。比如你要讓你的合成圖片的紋理和真實圖片的紋理是一樣的,同時你不能丟失合成圖片的內(nèi)容信息(realism)(驗證碼上面的數(shù)字字母)烙无。因此有兩個loss需要Refiner去優(yōu)化:

x_refined.png
loss_function.png

上圖中的l_real指的是refine之后的合成圖片(x_i')和真實圖片Y之間的loss锋谐。l_reg是原始合成圖片x_i和被refine之后的合成圖片的x_i'之間的loss。lambda是一個高參截酷。

Refiner的目標(biāo)就是盡可能的糊弄判別器D涮拗,讓判別器沒有辦法區(qū)分一個圖片是real還是合成的。判別器D的目標(biāo)正好相反迂苛,是盡可能的能夠區(qū)分出來三热。那么判別器的loss是這樣的:


discriminator_loss.png

這個是一個二分類的交叉熵,D(.)是輸入圖片是合成圖片的概率三幻,1-D(.)就是輸入圖片是真實圖片的概率就漾。換句話說,如果輸入的圖片是合成圖片念搬,那么loss就是前半部分抑堡,如果輸入是真實圖片,loss就是后半部分朗徊。在實現(xiàn)的細(xì)節(jié)里面首妖,當(dāng)輸入是合成圖片x_i那么label就是1,反之為0荣倾。并且每個mini-batch當(dāng)中悯搔,我們會隨機采樣一部分的真實圖片和一部分的合成圖片。模型方面用了ConvNet舌仍,最后一層輸出是sample是合成圖片的概率。最后用SGD來更新參數(shù)通危。(這里的判別器就是用了一個卷積網(wǎng)絡(luò)铸豁,然后加了一個binary_categorical_crossentropy,再用SGD降低loss)菊碟。

那么和判別器目標(biāo)相反节芥,refiner應(yīng)該是迫使判別器沒有辦法區(qū)分refine之后的合成圖片。所以它的l_real是醬紫的:

l_real.png

接下來是l_reg, 為了保留原始圖片的內(nèi)容信息,我們需要一個loss來迫使模型不要把圖片修改的和原始圖片差異很大头镊,這里引入了self-regularization loss。這個loss就是讓refine之后的圖片像素點和原始的圖片的像素點之間的差不要太大。

綜合起來refiner的loss如下:


refiner_loss.png

在訓(xùn)練過程中苛秕,我們分別減小refiner和discriminator的loss姜凄。在更新refiner的時候就把discriminator的參數(shù)固定住不更新,在更新discriminator的參數(shù)的時候就固定refiner的參數(shù)坛芽。

這里有兩個tricks留储。

  1. local adversarial loss
    refiner在學(xué)習(xí)為真實圖片建模的時候不應(yīng)該引入artifacts, 當(dāng)我們訓(xùn)練一個強判別器的時候咙轩,refiner會傾向于強調(diào)一些圖片特征來fool當(dāng)前的判別器获讳,從而導(dǎo)致生成了一些artifacts。那么怎么解決呢活喊?我可以可以觀察到如果我們從refine的合成圖片上挖出一塊(patch)丐膝,這一塊的統(tǒng)計信息(statistics)應(yīng)該是和真實圖片的統(tǒng)計信息應(yīng)該是相似的。因此钾菊,我們可以不用定義一個全局的判別器(對整張圖片判斷合成Or真實)尤误,我們可以對圖片上的每一塊都判別一下。這樣的話结缚,不僅僅是限定了接收域(receptive field)损晤,也為訓(xùn)練判別器提供了更多的樣本。
    判別器是一個全卷積網(wǎng)絡(luò)红竭,它的輸出是w*h個patches是合成圖片的概率尤勋。所以在更新refiner的時候,我們可以把這些w*h個patches的交叉熵loss相加茵宪。
local_patch.png

比如上面這張圖最冰,輸出就是2*3的矩陣,每個值表示的是這塊patch是合成圖片的概率值稀火。算loss的時候把這6塊圖片的交叉熵都加起來暖哨。

2.用refined的歷史圖片來更新判別器
對抗訓(xùn)練的一個問題是判別器只關(guān)注最近的refined圖片,這會引起兩個問題-對抗訓(xùn)練的分散和refiner網(wǎng)絡(luò)又引進(jìn)了判別器早就忘掉的artifacts凰狞。因此通過用refined的歷史圖片作為一個buffer而不單單是當(dāng)前的mini-batch來更新分類器篇裁。具體方法是,在每一輪分類器的訓(xùn)練中赡若,我們先從當(dāng)前的batch中采樣b/2張圖片达布,然后從大小為B的buffer中采樣b/2張圖片,合在一起來更新判別器的參數(shù)逾冬。然后這一輪之后黍聂,用新生成的b/2張圖片來替換掉B中的b/2張圖片躺苦。

image_buffer.png

參數(shù)細(xì)節(jié)

實現(xiàn)細(xì)節(jié):
Refiner:
輸入圖片55*35=> 64個3*3的filter => 4個resnet block => 1個1*1的fitler => 輸出作為合成的圖片(黑白的,所以1個通道)
1個resnet block是醬紫的:


resnet_block.png

Discriminator:
96個3*3filter, stride=2 => 64個3*3filter, stride = 2 => max_pool: 3*3, stride=1 => 32個3*3filter产还,stride=1 => 32個1*1的filter, stride=1 => 2個1*1的filter, stride=1 => softmax

我們的網(wǎng)絡(luò)都是全卷積網(wǎng)絡(luò)的匹厘,Refiner和Disriminator的最后層是很相似的(refiner的輸出是和原圖一樣大小的, discriminator要把原圖縮一下變成比如W/4 * H/4來表示這么多個patch的概率值)。 首先只用self-regularization loss來訓(xùn)練Refiner網(wǎng)絡(luò)1000步脐区, 然后訓(xùn)練Discriminator 200步愈诚。接著每次更新一次判別器,我們都更新Refiner兩次坡椒。

算法具體細(xì)節(jié)如下:

algorithm.png

項目代碼Overview

challenges:需要預(yù)測的數(shù)據(jù)樣本文件夾
imgs: 從challenges解壓之后的圖片文件夾
SimGAN-Captcha.ipynb: 整個項目的流程notebook
arial-extra.otf: 模擬器生成驗證碼的字體類型
avg.png: 比賽主辦方根據(jù)每個人的信息做了一些加密生成的一些線條扰路,訓(xùn)練的時候需要去掉這些線條。
image_history_buffer.py:

項目代碼.png

預(yù)處理

這部分原本作者是寫了需要從某個地址把圖片對應(yīng)的base64加密的圖片下載下來倔叼,但是因為這個是去年的比賽汗唱,url已經(jīng)不管用了。所以作者把對應(yīng)的文件直接放到了challenges里面丈攒。我們直接從第二步解壓開始就可以了哩罪。因為python2和python3不太一樣,作者應(yīng)該用的是Python2巡验, 我這里給出python3版本的代碼际插。

解壓

每個challenges文件下下的文件都是一個json文件,包含了1000個base64加密的jpg圖片文件显设,所以對每一個文件框弛,我們把base64的str解壓成一個jpeg,然后把他們放到orig文件夾下捕捂。

import requests
import threading
URL = "https://captcha.delorean.codes/u/rickyhan/challenge"
DIR = "challenges/"
NUM_CHALLENGES = 20
lock = threading.Lock()

import json, base64, os
IMG_DIR = "./orig"
fnames = ["{}/challenge-{}".format(DIR, i) for i in range(NUM_CHALLENGES)]
if not os.path.exists(IMG_DIR):
    os.mkdir(IMG_DIR)
def save_imgs(fname):
    with open(fname,'r') as f:
        l = json.loads(f.read(), encoding="latin-1")
    for image in l['images']:
        byte_image = bytes(map(ord,image['jpg_base64']))
        b = base64.decodebytes(byte_image)
        name = image['name']
        with open(IMG_DIR+"/{}.jpg".format(name), 'wb') as f:
            f.write(b)

for fname in fnames:
    save_imgs(fname)
assert len(os.listdir(IMG_DIR)) == 1000 * NUM_CHALLENGES

解壓之后的圖片長這個樣子:

from PIL import Image
imgpath = IMG_DIR + "/"+ os.listdir(IMG_DIR)[0]
imgpath2 = IMG_DIR + "/"+ os.listdir(IMG_DIR)[3]
im = Image.open(example_image_path)
im2 = Image.open(example_image_path2)
IMG_FNAMES = [IMG_DIR + '/' + p for p in os.listdir(IMG_DIR)]
im
im.png
img2
im2.png

轉(zhuǎn)換成黑白圖片

二值圖會節(jié)省很大的計算瑟枫,所以我們這里設(shè)置了一個閾值,然后把圖片一張張轉(zhuǎn)換成相應(yīng)的二值圖指攒。(這里采用的轉(zhuǎn)換方式見下面的注釋慷妙。)

def gray(img_path):
    # convert to grayscale, then binarize
    #L = R * 299/1000 + G * 587/1000 + B * 114/1000
    img = Image.open(img_path).convert("L") # convert to gray scale, one 8-bit byte per pixel
    img = img.point(lambda x: 255 if x > 200 or x == 0 else x) # value found through T&E
    img = img.point(lambda x: 0 if x < 255 else 255, "1") # convert to binary image
    img.save(img_path)

for img_path in IMG_FNAMES:
    gray(img_path)
im = Image.open(example_image_path)
im
binarized.png

抽取mask

可以看到這些圖片上面都有相同的水平的線,前面講過允悦,因為是比賽膝擂,所以這些captcha上的線都是根據(jù)參賽者的名字生成的。在現(xiàn)實生活中隙弛,我們可以用openCV的一些 形態(tài)轉(zhuǎn)換函數(shù)(morphological transformation)來把這些噪音給過濾掉架馋。這里作者用的是把所有圖片相加取平均得到了mask。他也推薦大家可以用bit mask(&=)來過濾掉驶鹉。

mask = np.ones((height, width))
for im in ims:
    mask &= im

這里是把所有圖片相加取平均:

import numpy as np
WIDTH, HEIGHT = im.size
MASK_DIR = "avg.png"
def generateMask():
    N=1000*NUM_CHALLENGES
    arr=np.zeros((HEIGHT, WIDTH),np.float)
    for fname in IMG_FNAMES:
        imarr=np.array(Image.open(fname),dtype=np.float)
        arr=arr+imarr/N
    arr=np.array(np.round(arr),dtype=np.uint8)
    out=Image.fromarray(arr,mode="L")  # save as gray scale
    out.save(MASK_DIR)

generateMask()
im = Image.open(MASK_DIR) # ok this can be done with binary mask: &=
im
mask_before.png

再修正一下

im = Image.open(MASK_DIR)
im = im.point(lambda x:255 if x > 230 else x)
im = im.point(lambda x:0 if x<255 else 255, "1") # 1-bit bilevel, stored with the leftmost pixel in the most significant bit. 0 means black, 1 means white.
im.save(MASK_DIR)
im
mask_after.png

真實圖片的生成器

我們在訓(xùn)練的時候也需要把真實的圖片丟進(jìn)去绩蜻,所以這里直接用keras的flow_from_directory來自動生成圖片并且把圖片做一些預(yù)處理。

from keras import models
from keras import layers
from keras import optimizers
from keras import applications
from keras.preprocessing import image
import tensorflow as tf
# Real data generator

datagen = image.ImageDataGenerator(
    preprocessing_function=applications.xception.preprocess_input
    #  調(diào)用imagenet_utils的preoprocess input函數(shù)
    #  tf: will scale pixels between -1 and 1,sample-wise.
)

flow_from_directory_params = {'target_size': (HEIGHT, WIDTH),
                              'color_mode': 'grayscale',
                              'class_mode': None,
                              'batch_size': BATCH_SIZE}

real_generator = datagen.flow_from_directory(
        directory=".",
        **flow_from_directory_params
)

(Dumb)生成器(模擬器Simulator)

接著我們需要定義個生成器來幫我們生成(驗證碼室埋,標(biāo)注label)對,這些生成的驗證碼應(yīng)該盡可能的和真實圖片的那些比較像。

# Synthetic captcha generator
from PIL import ImageFont, ImageDraw
from random import choice, random
from string import ascii_lowercase, digits
alphanumeric = ascii_lowercase + digits


def fuzzy_loc(locs):
    acc = []
    for i,loc in enumerate(locs[:-1]):
        if locs[i+1] - loc < 8:
            continue
        else:
            acc.append(loc)
    return acc

def seg(img):
    arr = np.array(img, dtype=np.float)
    arr = arr.transpose()
    # arr = np.mean(arr, axis=2)
    arr = np.sum(arr, axis=1)
    locs = np.where(arr < arr.min() + 2)[0].tolist()
    locs = fuzzy_loc(locs)
    return locs

def is_well_formed(img_path):
    original_img = Image.open(img_path)
    img = original_img.convert('1')
    return len(seg(img)) == 4

noiseimg = np.array(Image.open("avg.png").convert("1"))
# noiseimg = np.bitwise_not(noiseimg)
fnt = ImageFont.truetype('./arial-extra.otf', 26)
def gen_one():
    og = Image.new("1", (100,50))
    text = ''.join([choice(alphanumeric) for _ in range(4)])
    draw = ImageDraw.Draw(og)
    for i, t in enumerate(text):
        txt=Image.new('L', (40,40))
        d = ImageDraw.Draw(txt)
        d.text( (0, 0), t,  font=fnt, fill=255)
        if random() > 0.5:
            w=txt.rotate(-20*(random()-1),  expand=1)
            og.paste( w, (i*20 + int(25*random()), int(25+30*(random()-1))),  w)
        else:
            w=txt.rotate(20*(random()-1),  expand=1)
            og.paste( w, (i*20 + int(25*random()), int(20*random())),  w)
    segments = seg(og)
    if len(segments) != 4:
        return gen_one()
    ogarr = np.array(og)
    ogarr = np.bitwise_or(noiseimg, ogarr)
    ogarr = np.expand_dims(ogarr, axis=2).astype(float)
    ogarr = np.random.random(size=(50,100,1)) * ogarr
    ogarr = (ogarr > 0.0).astype(float) # add noise
    return ogarr, text


def synth_generator():
    arrs = []
    while True:
        for _ in range(BATCH_SIZE):
            img, text = gen_one()
            arrs.append(img)
        yield np.array(arrs)
        arrs = []

上面這段代碼主要是隨機產(chǎn)生了不同的字符數(shù)字姚淆,然后進(jìn)行旋轉(zhuǎn)孕蝉,之后把字符貼在一起,把原來的那個噪音圖片avg.png加上去腌逢,把一些重合的字符的驗證碼給去掉降淮。這里如果發(fā)現(xiàn)有問題,強烈建議先升級一下PILLOW搏讶,debug了好久....sigh~

def get_image_batch(generator):
    """keras generators may generate an incomplete batch for the last batch"""
    #img_batch = generator.next()
    img_batch = next(generator)
    if len(img_batch) != BATCH_SIZE:
        img_batch = generator.next()

    assert len(img_batch) == BATCH_SIZE

    return img_batch

看一下真實的圖片長什么樣子

import matplotlib.pyplot as plt
%matplotlib inline
imarr = get_image_batch(real_generator)
imarr = imarr[0, :, :, 0]
plt.imshow(imarr)
real_image.png

我們生成的圖片長什么樣子

imarr = get_image_batch(synth_generator())[0, :, :, 0]
print imarr.shape
plt.imshow(imarr)
synthesized_image.png

注意上面的圖片之所以顯示的有顏色是因為用了plt.imshow, 實際上是灰白的二值圖佳鳖。

這部分生成的代碼,我個人覺得讀者可以直接在github上下載一個驗證碼生成器就好媒惕,然后把圖片根據(jù)之前的步驟搞成二值圖就行系吩,而且可以盡可能的選擇跟自己需要預(yù)測的驗證碼比較相近的字體。

模型定義

整個網(wǎng)絡(luò)一共有三個部分

  1. Refiner
    Refiner,Rθ,是一個RestNet, 它在像素維度上去修改我們生成的圖片妒蔚,而不是整體的修改圖片內(nèi)容穿挨,這樣才可以保留整體圖片的結(jié)構(gòu)和標(biāo)注。(要不然就尷尬了肴盏,萬一把字母a都變成別的字母標(biāo)注就不準(zhǔn)確了)
  2. Discriminator
    判別器科盛,Dφ,是一個簡單的ConvNet, 包含了5個卷積層和2個max-pooling層菜皂,是一個二分類器贞绵,區(qū)分一個驗證碼是我們合成的還是真實的樣本集。
  3. 把他們合在一起
    把refined的圖片合到判別器里面

Refiner

主要是4個resnet_block疊加在一起恍飘,最后再用一個1*1的filter來構(gòu)造一個feature_map作為生成的圖片榨崩。可以看到全部的border_mode都是same常侣,也就是說當(dāng)中任何一步的輸出都和原始的圖片長寬保持一致(fully convolution)蜡饵。
一個resnet_block是醬紫的:

resnet_block.png

我們先把輸入圖片用64個3*3的filter去conv一下,得到的結(jié)果(input_features)再把它丟到4個resnet_block中去胳施。

def refiner_network(input_image_tensor):
    """
    :param input_image_tensor: Input tensor that corresponds to a synthetic image.
    :return: Output tensor that corresponds to a refined synthetic image.
    """
    def resnet_block(input_features, nb_features=64, nb_kernel_rows=3, nb_kernel_cols=3):
        """
        A ResNet block with two `nb_kernel_rows` x `nb_kernel_cols` convolutional layers,
        each with `nb_features` feature maps.
        See Figure 6 in https://arxiv.org/pdf/1612.07828v1.pdf.
        :param input_features: Input tensor to ResNet block.
        :return: Output tensor from ResNet block.
        """
        y = layers.Convolution2D(nb_features, nb_kernel_rows, nb_kernel_cols, border_mode='same')(input_features)
        y = layers.Activation('relu')(y)
        y = layers.Convolution2D(nb_features, nb_kernel_rows, nb_kernel_cols, border_mode='same')(y)

        y = layers.merge([input_features, y], mode='sum')
        return layers.Activation('relu')(y)

    # an input image of size w × h is convolved with 3 × 3 filters that output 64 feature maps
    x = layers.Convolution2D(64, 3, 3, border_mode='same', activation='relu')(input_image_tensor)

    # the output is passed through 4 ResNet blocks
    for _ in range(4):
        x = resnet_block(x)

    # the output of the last ResNet block is passed to a 1 × 1 convolutional layer producing 1 feature map
    # corresponding to the refined synthetic image
    return layers.Convolution2D(1, 1, 1, border_mode='same', activation='tanh')(x)

Discriminator

這里注意一下subsample就是strides, 由于subsample=(2,2)所以會把圖片長寬減半,因為有兩個溯祸,所以最后的圖片會變成原來的1/16左右。比如一開始圖片大小是10050, 經(jīng)過一次變換之后是5025舞肆,再經(jīng)過一次變換之后是25*13焦辅。

Discriminator_detail.png

最后生成了兩個feature_map,一個是用來判斷是不是real還有一個用來判斷是不是refined的椿胯。

def discriminator_network(input_image_tensor):
    """
    :param input_image_tensor: Input tensor corresponding to an image, either real or refined.
    :return: Output tensor that corresponds to the probability of whether an image is real or refined.
    """
    x = layers.Convolution2D(96, 3, 3, border_mode='same', subsample=(2, 2), activation='relu')(input_image_tensor)
    x = layers.Convolution2D(64, 3, 3, border_mode='same', subsample=(2, 2), activation='relu')(x)
    x = layers.MaxPooling2D(pool_size=(3, 3), border_mode='same', strides=(1, 1))(x)
    x = layers.Convolution2D(32, 3, 3, border_mode='same', subsample=(1, 1), activation='relu')(x)
    x = layers.Convolution2D(32, 1, 1, border_mode='same', subsample=(1, 1), activation='relu')(x)
    x = layers.Convolution2D(2, 1, 1, border_mode='same', subsample=(1, 1), activation='relu')(x)

    # here one feature map corresponds to `is_real` and the other to `is_refined`,
    # and the custom loss function is then `tf.nn.sparse_softmax_cross_entropy_with_logits`
    return layers.Reshape((-1, 2))(x)    # (batch_size, # of local patches, 2)

把它們合起來

refiner 加到discriminator中去筷登。這里有兩個loss:

  1. self_regularization_loss
    論文中是這么寫的: The self-regularization term minimizes the image difference
    between the synthetic and the refined images.
    就是用來控制refine的圖片不至于跟原來的圖片差別太大,由于paper中沒有具體寫公式哩盲,但是大致就是讓生成的像素值和原始圖片的像素值之間的距離不要太大前方。這里項目的原作者是用了:
def self_regularization_loss(y_true, y_pred):
    delta = 0.0001  # FIXME: need to figure out an appropriate value for this
    return tf.multiply(delta, tf.reduce_sum(tf.abs(y_pred - y_true)))

y_true: 丟到refiner里面的input_image_tensor
y_pred: refiner的output
這里的delta是用來控制這個loss的權(quán)重狈醉,論文里面是lambda。
整個loss就是把refiner的輸入圖片和輸出圖片的每個像素點值相減取絕對值惠险,最后把整張圖片的差值都相加起來再乘以delta苗傅。

  1. local_adversarial_loss
    為了讓refiner能夠?qū)W習(xí)到真實圖片的特征而不是一些artifacts來欺騙判別器,我們認(rèn)為我們從refined的圖片中sample出來的patch, 應(yīng)該是和真實圖片的patch的statistics是相似的班巩。所以我們在所有的local patches上定義判別器而不是學(xué)習(xí)一個全局的判別器渣慕。
def local_adversarial_loss(y_true, y_pred):
    # y_true and y_pred have shape (batch_size, # of local patches, 2), but really we just want to average over
    # the local patches and batch size so we can reshape to (batch_size * # of local patches, 2)
    y_true = tf.reshape(y_true, (-1, 2))
    y_pred = tf.reshape(y_pred, (-1, 2))
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)

    return tf.reduce_mean(loss)

合起來如下:

# Refiner
synthetic_image_tensor = layers.Input(shape=(HEIGHT, WIDTH, 1)) #合成的圖片
refined_image_tensor = refiner_network(synthetic_image_tensor)
refiner_model = models.Model(input=synthetic_image_tensor, output=refined_image_tensor, name='refiner') 

# Discriminator
refined_or_real_image_tensor = layers.Input(shape=(HEIGHT, WIDTH, 1)) #真實的圖片
discriminator_output = discriminator_network(refined_or_real_image_tensor)
discriminator_model = models.Model(input=refined_or_real_image_tensor, output=discriminator_output,
                                   name='discriminator')

# Combined
refiner_model_output = refiner_model(synthetic_image_tensor)
combined_output = discriminator_model(refiner_model_output)
combined_model = models.Model(input=synthetic_image_tensor, output=[refiner_model_output, combined_output],
                              name='combined')

def self_regularization_loss(y_true, y_pred):
    delta = 0.0001  # FIXME: need to figure out an appropriate value for this
    return tf.multiply(delta, tf.reduce_sum(tf.abs(y_pred - y_true)))

# define custom local adversarial loss (softmax for each image section) for the discriminator
# the adversarial loss function is the sum of the cross-entropy losses over the local patches
def local_adversarial_loss(y_true, y_pred):
    # y_true and y_pred have shape (batch_size, # of local patches, 2), but really we just want to average over
    # the local patches and batch size so we can reshape to (batch_size * # of local patches, 2)
    y_true = tf.reshape(y_true, (-1, 2))
    y_pred = tf.reshape(y_pred, (-1, 2))
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)

    return tf.reduce_mean(loss)


# compile models
BATCH_SIZE = 512
sgd = optimizers.RMSprop()

refiner_model.compile(optimizer=sgd, loss=self_regularization_loss)
discriminator_model.compile(optimizer=sgd, loss=local_adversarial_loss)
discriminator_model.trainable = False
combined_model.compile(optimizer=sgd, loss=[self_regularization_loss, local_adversarial_loss])

預(yù)訓(xùn)練

預(yù)訓(xùn)練對于GAN來說并不是一定需要的,但是預(yù)訓(xùn)練可以讓GAN收斂的更快一些抱慌。這里我們兩個模型都先預(yù)訓(xùn)練逊桦。
對于真實樣本label標(biāo)注為[1,0], 對于合成的圖片label為[0,1]。

# the target labels for the cross-entropy loss layer are 0 for every yj (real) and 1 for every xi (refined)
# discriminator_model.output_shape = num of local patches
y_real = np.array([[[1.0, 0.0]] * discriminator_model.output_shape[1]] * BATCH_SIZE)
y_refined = np.array([[[0.0, 1.0]] * discriminator_model.output_shape[1]] * BATCH_SIZE)
assert y_real.shape == (BATCH_SIZE, discriminator_model.output_shape[1], 2)

對于refiner, 我們根據(jù)self_regularization_loss來預(yù)訓(xùn)練抑进,也就是說對于refiner的輸入和輸出都是同一張圖(類似于auto-encoder)强经。

LOG_INTERVAL = 10
MODEL_DIR = "./model/"
print('pre-training the refiner network...')
gen_loss = np.zeros(shape=len(refiner_model.metrics_names))

for i in range(100):
    synthetic_image_batch = get_image_batch(synth_generator())
    gen_loss = np.add(refiner_model.train_on_batch(synthetic_image_batch, synthetic_image_batch), gen_loss)

    # log every `log_interval` steps
    if not i % LOG_INTERVAL:
        print('Refiner model self regularization loss: {}.'.format(gen_loss / LOG_INTERVAL))
        gen_loss = np.zeros(shape=len(refiner_model.metrics_names))

refiner_model.save(os.path.join(MODEL_DIR, 'refiner_model_pre_trained.h5'))··

對于判別器,我們用一個batch的真實圖片來訓(xùn)練单匣,再用另一個batch的合成圖片來交替訓(xùn)練夕凝。

from tqdm import tqdm
print('pre-training the discriminator network...')
disc_loss = np.zeros(shape=len(discriminator_model.metrics_names))

for _ in tqdm(range(100)):
    real_image_batch = get_image_batch(real_generator)
    disc_loss = np.add(discriminator_model.train_on_batch(real_image_batch, y_real), disc_loss)

    synthetic_image_batch = get_image_batch(synth_generator())
    refined_image_batch = refiner_model.predict_on_batch(synthetic_image_batch)
    disc_loss = np.add(discriminator_model.train_on_batch(refined_image_batch, y_refined), disc_loss)

discriminator_model.save(os.path.join(MODEL_DIR, 'discriminator_model_pre_trained.h5'))

# hard-coded for now
print('Discriminator model loss: {}.'.format(disc_loss / (100 * 2)))

訓(xùn)練

這里有兩個點1)用refined的歷史圖片來更新判別器,2)訓(xùn)練的整體流程
1)用refined的歷史圖片來更新判別器
對抗訓(xùn)練的一個問題是判別器只關(guān)注最近的refined圖片户秤,這會引起兩個問題-對抗訓(xùn)練的分散和refiner網(wǎng)絡(luò)又引進(jìn)了判別器早就忘掉的artifacts码秉。因此通過用refined的歷史圖片作為一個buffer而不單單是當(dāng)前的mini-batch來更新分類器。具體方法是鸡号,在每一輪分類器的訓(xùn)練中转砖,我們先從當(dāng)前的batch中采樣b/2張圖片,然后從大小為B的buffer中采樣b/2張圖片鲸伴,合在一起來更新判別器的參數(shù)府蔗。然后這一輪之后,用新生成的b/2張圖片來替換掉B中的b/2張圖片汞窗。

buffer.png

由于論文中沒有寫B(tài)的大小為多少姓赤,這里作者用了100*batch_size作為buffer的大小。

2)訓(xùn)練流程
xi是合成的的圖片
yj是真實的圖片
T是步數(shù)(steps)
K_d是每個step仲吏,判別器更新的次數(shù)
K_g是每個step不铆,生成網(wǎng)絡(luò)的更新次數(shù)(refiner的更新次數(shù))


Algorithm1.png

這里要注意在判別器更新的每一輪,其中的合成的圖片的minibatch已經(jīng)用1)當(dāng)中的采樣方式來替代了裹唆。

from image_history_buffer import ImageHistoryBuffer


k_d = 1  # number of discriminator updates per step
k_g = 2  # number of generative network updates per step
nb_steps = 1000

# TODO: what is an appropriate size for the image history buffer?
image_history_buffer = ImageHistoryBuffer((0, HEIGHT, WIDTH, 1), BATCH_SIZE * 100, BATCH_SIZE)

combined_loss = np.zeros(shape=len(combined_model.metrics_names))
disc_loss_real = np.zeros(shape=len(discriminator_model.metrics_names))
disc_loss_refined = np.zeros(shape=len(discriminator_model.metrics_names))

# see Algorithm 1 in https://arxiv.org/pdf/1612.07828v1.pdf
for i in range(nb_steps):
    print('Step: {} of {}.'.format(i, nb_steps))

    # train the refiner
    for _ in range(k_g * 2):
        # sample a mini-batch of synthetic images
        synthetic_image_batch = get_image_batch(synth_generator())

        # update θ by taking an SGD step on mini-batch loss LR(θ)
        combined_loss = np.add(combined_model.train_on_batch(synthetic_image_batch,
                                                             [synthetic_image_batch, y_real]), combined_loss) #注意combine模型的local adversarial loss是要用y_real來對抗學(xué)習(xí)誓斥,從而迫使refiner去修改圖片來做到跟真實圖片很像

    for _ in range(k_d):
        # sample a mini-batch of synthetic and real images
        synthetic_image_batch = get_image_batch(synth_generator())
        real_image_batch = get_image_batch(real_generator)

        # refine the synthetic images w/ the current refiner
        refined_image_batch = refiner_model.predict_on_batch(synthetic_image_batch)

        # use a history of refined images
        half_batch_from_image_history = image_history_buffer.get_from_image_history_buffer()
        image_history_buffer.add_to_image_history_buffer(refined_image_batch)

        if len(half_batch_from_image_history):
            refined_image_batch[:batch_size // 2] = half_batch_from_image_history

        # update φ by taking an SGD step on mini-batch loss LD(φ)
        disc_loss_real = np.add(discriminator_model.train_on_batch(real_image_batch, y_real), disc_loss_real)
        disc_loss_refined = np.add(discriminator_model.train_on_batch(refined_image_batch, y_refined),
                                   disc_loss_refined)

    if not i % LOG_INTERVAL:
        # log loss summary
        print('Refiner model loss: {}.'.format(combined_loss / (LOG_INTERVAL * k_g * 2)))
        print('Discriminator model loss real: {}.'.format(disc_loss_real / (LOG_INTERVAL * k_d * 2)))
        print('Discriminator model loss refined: {}.'.format(disc_loss_refined / (LOG_INTERVAL * k_d * 2)))

        combined_loss = np.zeros(shape=len(combined_model.metrics_names))
        disc_loss_real = np.zeros(shape=len(discriminator_model.metrics_names))
        disc_loss_refined = np.zeros(shape=len(discriminator_model.metrics_names))

        # save model checkpoints
        model_checkpoint_base_name = os.path.join(MODEL_DIR, '{}_model_step_{}.h5')
        refiner_model.save(model_checkpoint_base_name.format('refiner', i))
        discriminator_model.save(model_checkpoint_base_name.format('discriminator', i))

SimGAN的結(jié)果

我們從合成圖片的生成器中拿一個batch的圖片,用訓(xùn)練好的refiner去Predict一下许帐,然后顯示其中的一張圖(我運行生成的圖片當(dāng)中是一些點點的和作者的不太一樣劳坑,但是跟真實圖片更像,待補充):

synthetic_image_batch = get_image_batch(synth_generator())
arr = refiner_model.predict_on_batch(synthetic_image_batch)
plt.imshow(arr[200, :, :, 0])
plt.show()
refiner_output.png
plt.imshow(get_image_batch(real_generator)[2,:,:,0])
plt.show()
real_image_output.png

這里作者認(rèn)為生成的圖片中字母的邊都模糊和有噪音的成畦,不那么的平滑了距芬。(我覺得和原始圖片比起來涝开,在refine之前的圖片看起來和真實圖片也很像啊,唯一不同的應(yīng)該是當(dāng)中那些若有若無的點啊蔑穴,讀者可以在生成圖片的時候把噪音給去掉忠寻,再來refine圖片惧浴,看能不能生成字母邊是比較噪音的(noisy)存和,我這邊refine之后的圖片就是當(dāng)中有一點一點的,圖片待補充)

開始運用到實際的驗證碼識別

那么有了可以很好的生成和要預(yù)測的圖片很像的refiner之后衷旅,我們就可以構(gòu)造我們的驗證碼分類模型了捐腿,這里作者用了多輸出的模型,就是給定一張圖片柿顶,有固定的輸出(這里是4茄袖,因為要預(yù)測4個字母)。

我們先用之前的合成圖片的生成器(gen_one)來構(gòu)造一個生成器嘁锯,接著用refiner_model來預(yù)測一下作為這個generator的輸出圖片宪祥。由于分類模型的輸出要用categorical_crossentropy,所以我們需要把輸出的字母變成one-hot形式家乘。

n_class = len(alphanumeric)
def mnist_generator(batch_size=128):
    X = np.zeros((batch_size, HEIGHT, WIDTH, 1), dtype=np.uint8)
    y = [np.zeros((batch_size, n_class), dtype=np.uint8) for _ in range(4)] # 4 chars
    while True:
        for i in range(batch_size):
            im, random_str = gen_one()
            X[i] = im
            for j, ch in enumerate(random_str):
                y[j][i, :] = 0
                y[j][i, alphanumeric.find(ch)] = 1   # one_hot形式蝗羊,讓當(dāng)前字母的index為1
        yield refiner_model.predict(np.array(X)), y

mg = mnist_generator().next()

建模

from keras.layers import *

input_tensor = Input((HEIGHT, WIDTH, 1))
x = input_tensor
x = Conv2D(32, kernel_size=(3, 3),
                 activation='relu')(x)
# 4個conv-max_polling
for _ in range(4):
    x = Conv2D(128, (3, 3), activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
x = Dropout(0.25)(x)
x = Flatten()(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.5)(x)
x = [Dense(n_class, activation='softmax', name='c%d'%(i+1))(x) for i in range(4)] # 4個輸出

model = models.Model(inputs=input_tensor, outputs=x)
model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

from keras.callbacks import History
history = History()  # history call back現(xiàn)在已經(jīng)是每個模型在訓(xùn)練的時候都會自帶的了,fit函數(shù)會返回仁锯,主要用于記錄事件耀找,比如loss之類的
model.fit_generator(mnist_generator(), steps_per_epoch=1000, epochs=20, callbacks=[history])

測試模型

先看一下在合成圖片上的預(yù)測:

def decode(y):
    y = np.argmax(np.array(y), axis=2)[:,0]
    return ''.join([alphanumeric[x] for x in y])

X, y = next(mnist_generator(1))
y_pred = model.predict(X)
plt.title('real: %s\npred:%s'%(decode(y), decode(y_pred)))
plt.imshow(X[0, :, :, 0], cmap='gray')
plt.axis('off')
synthetic_predict.png

看一下對于要預(yù)測的圖片的預(yù)測:

X = next(real_generator)
X = refiner_model.predict(X) 
 # 不確定作者為什么要用refiner來predict,應(yīng)該是可以省去這一步的
# 事實證明是不可以的业崖,后面會分析
y_pred = model.predict(X)
plt.title('pred:%s'%(decode(y_pred)))
plt.imshow(X[0,:,:,0], cmap='gray')
plt.axis('off')
real_predict.png

后續(xù)補充

  1. 將預(yù)測模型這里的圖片替換掉野芒,改成實際操作時候生成的圖片
    在訓(xùn)練過程中可以發(fā)現(xiàn)判別器的loss下降的非常快双炕,并且到后面很難讓refine的和real的loss都變高狞悲。有的時候運氣好的話也許可以。我在訓(xùn)練的時候出現(xiàn)了兩種情況:
    第一種情況:
    合成前:


    syn_before.png

    合成后:


    syn_after.png

    可以看到合成之后的圖片中也是有一點一點的妇斤。拿這種圖片去做訓(xùn)練摇锋,后面對真實圖片做預(yù)測的時候就可以直接丟進(jìn)分類器訓(xùn)練了。

第二種情況(作者notebook中展示的):
也就是前面寫到的情況趟济。
類似于下面這樣乱投,看起來refiner之后沒什么變化的感覺:


syn2_after.png

這個看起來并沒有感覺和真實圖片很像啊G瓯唷F蒽拧!
可是神奇的是媳纬,作者在預(yù)測真實的圖片的時候双肤,他居然用refiner去predict真實的圖片施掏!
真實的圖片之前是長這個樣子的:


real_before_refiner.png

refiner之后居然長成了這樣:
real_after_refiner.png

無語了呢!它居然把那些噪聲點給去掉了一大半........他這波反向的操作讓我很措手不及茅糜。于是他用refine之后的真實圖片丟到分類器去做預(yù)測.....效果居然還不錯.....

反正我已經(jīng)凌亂了呢..............................

不過如何讓模型能夠?qū)W到我們?nèi)四X做識別的過程是件非常重要的事情呢...這里如果你想用合成的圖片直接當(dāng)作訓(xùn)練集去訓(xùn)練然后預(yù)測真實圖片七芭,準(zhǔn)確率應(yīng)該會非常低(我試了一下),也就是說模型在學(xué)習(xí)的過程中還是沒有學(xué)習(xí)到字符的輪廓概念蔑赘,但是我們又沒辦法控制教會它去學(xué)習(xí)怎么"識別"物體狸驳,應(yīng)該學(xué)習(xí)哪些特征,最近發(fā)布的論文(戳這里)大家可以去看看(我還沒有看...)缩赛。

未完待續(xù)

  1. 評估準(zhǔn)確率
  2. 修改驗證碼生成器耙箍,改成其他任意的生成器
  3. 將模型用到更復(fù)雜的背景的驗證碼上,評估準(zhǔn)確率
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末酥馍,一起剝皮案震驚了整個濱河市辩昆,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌旨袒,老刑警劉巖汁针,帶你破解...
    沈念sama閱讀 219,110評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異砚尽,居然都是意外死亡施无,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,443評論 3 395
  • 文/潘曉璐 我一進(jìn)店門尉辑,熙熙樓的掌柜王于貴愁眉苦臉地迎上來帆精,“玉大人,你說我怎么就攤上這事隧魄∽苛罚” “怎么了?”我有些...
    開封第一講書人閱讀 165,474評論 0 356
  • 文/不壞的土叔 我叫張陵购啄,是天一觀的道長襟企。 經(jīng)常有香客問我,道長狮含,這世上最難降的妖魔是什么顽悼? 我笑而不...
    開封第一講書人閱讀 58,881評論 1 295
  • 正文 為了忘掉前任,我火速辦了婚禮几迄,結(jié)果婚禮上蔚龙,老公的妹妹穿的比我還像新娘。我一直安慰自己映胁,他們只是感情好木羹,可當(dāng)我...
    茶點故事閱讀 67,902評論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般坑填。 火紅的嫁衣襯著肌膚如雪抛人。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,698評論 1 305
  • 那天脐瑰,我揣著相機與錄音妖枚,去河邊找鬼。 笑死苍在,一個胖子當(dāng)著我的面吹牛绝页,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播忌穿,決...
    沈念sama閱讀 40,418評論 3 419
  • 文/蒼蘭香墨 我猛地睜開眼抒寂,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了掠剑?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,332評論 0 276
  • 序言:老撾萬榮一對情侶失蹤郊愧,失蹤者是張志新(化名)和其女友劉穎朴译,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體属铁,經(jīng)...
    沈念sama閱讀 45,796評論 1 316
  • 正文 獨居荒郊野嶺守林人離奇死亡眠寿,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,968評論 3 337
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了焦蘑。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片盯拱。...
    茶點故事閱讀 40,110評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖例嘱,靈堂內(nèi)的尸體忽然破棺而出狡逢,到底是詐尸還是另有隱情,我是刑警寧澤拼卵,帶...
    沈念sama閱讀 35,792評論 5 346
  • 正文 年R本政府宣布奢浑,位于F島的核電站,受9級特大地震影響腋腮,放射性物質(zhì)發(fā)生泄漏雀彼。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,455評論 3 331
  • 文/蒙蒙 一即寡、第九天 我趴在偏房一處隱蔽的房頂上張望徊哑。 院中可真熱鬧,春花似錦聪富、人聲如沸莺丑。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,003評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽窒盐。三九已至草则,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間蟹漓,已是汗流浹背炕横。 一陣腳步聲響...
    開封第一講書人閱讀 33,130評論 1 272
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留葡粒,地道東北人份殿。 一個月前我還...
    沈念sama閱讀 48,348評論 3 373
  • 正文 我出身青樓,卻偏偏與公主長得像嗽交,于是被迫代替她去往敵國和親卿嘲。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 45,047評論 2 355

推薦閱讀更多精彩內(nèi)容