一種基于樣式的生成器架構(gòu),用于生成對(duì)抗性網(wǎng)絡(luò)
Tero Karras (NVIDIA), Samuli Laine (NVIDIA), Timo Aila (NVIDIA)
http://stylegan.xyz/paper
摘要:我們借鑒風(fēng)格轉(zhuǎn)換文獻(xiàn)髓迎,提出了一種用于生成對(duì)抗網(wǎng)絡(luò)的替代生成器架構(gòu)召川。新的架構(gòu)會(huì)導(dǎo)致自動(dòng)學(xué)習(xí),無監(jiān)督的高級(jí)屬性分離 (例如, 在人臉上訓(xùn)練時(shí)的姿勢(shì)和身份) 和生成的圖像中的隨機(jī)變化 (例如雀斑、頭發(fā)), 并且使它可以直觀地、特定規(guī)模的控制合成刑枝。新生成器在傳統(tǒng)的分布質(zhì)量指標(biāo)方面改進(jìn)了最先進(jìn)的技術(shù), 從而顯著改善了插值特性, 并更好地消除了潛在的變異因素迅腔。為了量化插值質(zhì)量和分離, 我們提出了兩種新的自動(dòng)化方法, 適用于任何生成器架構(gòu)的自動(dòng)化方法装畅。最后, 我們介紹了一個(gè)新的, 高度多樣化和高質(zhì)量的人臉數(shù)據(jù)集。
系統(tǒng)要求
- 支持Linux和Windows沧烈,單出于性能和兼容性要求的考慮掠兄,官方建議使用Linux。
- 64位的Python3锌雀,建議使用Anaconda3蚂夕,且numpy版本1.14.3或更新。
- 支持GPU的Tensorflow版本1.10.0或更新腋逆。
- 一個(gè)或多個(gè)具有至少11GB DRAM的高端NVIDIA GPU婿牍。官方推薦推薦配備8個(gè)Tesla V100 GPU的NVIDIA DGX-1。
- NVIDIA驅(qū)動(dòng)版本391.35或更新惩歉,CUDA工具包9.0或更新等脂,cuDNN7.3.1或更新。
這其中必須項(xiàng)有:
- NVIDIA GPU的電腦(硬件條件)
- NVIDIA驅(qū)動(dòng)(驅(qū)動(dòng)顯卡)
- CUDA(NVIDIA并行計(jì)算框架)撑蚌,cuDNN是深度神經(jīng)網(wǎng)絡(luò)的加速庫(kù)非必須
- GPU版的Tensorflow(深度學(xué)習(xí)框架)
下載運(yùn)行模型的腳本
官方提供了StyleGan的GitHub地址上遥,把代碼下載下來進(jìn)行解壓本地目錄下,同時(shí)你需要將目錄路徑添加到環(huán)境變量PYTHONPATH争涌,為的是導(dǎo)入文件夾下的模塊粉楚。
注意:變量名為PYTHONPATH,沒有就新增一個(gè)亮垫,變量值為路徑解幼。
使用預(yù)訓(xùn)練網(wǎng)絡(luò)
pretrained_example.py有給到使用預(yù)訓(xùn)練StyleGAN生成器的最小示例。執(zhí)行腳本后會(huì)從谷歌網(wǎng)盤下載預(yù)訓(xùn)練StyleGAN生成器并生成一張圖片包警,圖片會(huì)在目錄下的/results/example.png
看到。因?yàn)楣雀杈W(wǎng)盤的緣故我們無法直接下載底靠,需要預(yù)訓(xùn)練模型的可以直接從這里下(提取碼: 9vx8)害晦。下載好的karras2019stylegan-ffhq-1024x1024.pkl
直接放到目錄里就行。
直接在命令行下執(zhí)行 python pretrained_example.py
,如果沒有網(wǎng)絡(luò)問題會(huì)見到下圖的打印信息壹瘟,這里我們直接下載好預(yù)訓(xùn)練生成器鲫剿,所以代碼需要改改,打開pretrained_example.py
改成下面這樣稻轨,即把網(wǎng)絡(luò)下載變成直接讀取本地文件灵莲,并將原代碼行注釋。
# with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
with open('karras2019stylegan-ffhq-1024x1024.pkl', 'rb') as f:
_G, _D, Gs = pickle.load(f)
調(diào)整完之后只要我們運(yùn)行pretrained_example.py
代碼即可生成example.png
圖片殴俱,如果你想生成其他隨機(jī)圖片的話只需要把5修改為其他數(shù)字即可:
rnd = np.random.RandomState(5)
generate_figures.py給出了一個(gè)更加高級(jí)的示例政冻。這個(gè)腳本復(fù)制了論文中的圖形,以說明樣式混合线欲、噪聲輸入和截?cái)?
預(yù)先訓(xùn)練好的網(wǎng)絡(luò)存儲(chǔ)為標(biāo)準(zhǔn)的pickle文件在谷歌網(wǎng)盤上明场,同樣的需要將腳本中dnnlib.util.open_url
函數(shù)改成直接讀取pkl文件:
def load_Gs(file):
if file not in _Gs_cache:
with open(file, 'rb') as f:
_G, _D, Gs = pickle.load(f)
_Gs_cache[file] = Gs
return _Gs_cache[file]
main
主函數(shù)部分中的load_Gs
的參數(shù)調(diào)整為文件路徑:
load_Gs('karras2019stylegan-ffhq-1024x1024.pkl')
下面的代碼將會(huì)生成dnnlib.tflib.Network的3個(gè)實(shí)例。為了生成圖像李丰,您通常需要使用Gs—另外兩個(gè)網(wǎng)絡(luò)是完整的苦锨。為了讓pickle.load()工作,你需要包含dnnlib
的源目錄添加到環(huán)境變量PYTHONPATH中和tf.Session
設(shè)置為默認(rèn)趴泌≈凼妫可以通過調(diào)用dnnlib.tflib.init_tf()
初始化Session。
with open('karras2019stylegan-ffhq-1024x1024.pkl', 'rb') as f:
_G, _D, Gs = pickle.load(f)
# _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
# _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
# Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.
有三種方法使用預(yù)先訓(xùn)練的生成器:
- 使用
Gs.run()
進(jìn)行輸入和輸出為numpy數(shù)組的快速模式操作:
# 選擇特征向量
rnd = np.random.RandomState(5)
latents = rnd.randn(1, Gs.input_shape[1])
# 生成圖像
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)
第一個(gè)參數(shù)是一批形狀為[num, 512]的特征向量嗜憔,第二個(gè)參數(shù)預(yù)留給類別標(biāo)簽(StypeGan并沒有使用秃励,所以參數(shù)為None)。其余的關(guān)鍵字參數(shù)是可選的痹筛,可用于進(jìn)一步修改操作(參見下面)莺治。輸出是一批圖像,其格式由output_transform參數(shù)決定帚稠。
- 使用
Gs.get_output_for()
將生成器合并為一個(gè)更大的TensorFlow表達(dá)式的一部分:
latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True)
images = tflib.convert_images_to_uint8(images)
result_expr.append(inception_clone.get_output_for(images))
前面的代碼來自metrics/frechet_inception_distance.py谣旁。它生成一批隨機(jī)圖像,并將它們直接提供給Inception-v3網(wǎng)絡(luò)滋早,而無需在中間將數(shù)據(jù)轉(zhuǎn)換為numpy數(shù)組榄审。
- 查找
Gs.components.mapping
和Gs.components.synthesis
以訪問生成器的各個(gè)子網(wǎng)絡(luò)。與Gs
類似杆麸,子網(wǎng)絡(luò)表示為dnnlib.tflib.Network
的獨(dú)立實(shí)例:
src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds)
src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]
src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
上面的代碼來自generate_figures.py搁进。首先利用映射網(wǎng)絡(luò)將一批特征向量轉(zhuǎn)化為中間的W空間,然后利用合成網(wǎng)絡(luò)將這些向量轉(zhuǎn)化為一批圖像昔头。dlatents
數(shù)組為合成網(wǎng)絡(luò)的每一層存儲(chǔ)同一w向量的單獨(dú)副本饼问,以方便樣式混合。
為訓(xùn)練準(zhǔn)備數(shù)據(jù)集
訓(xùn)練和評(píng)估腳本對(duì)存儲(chǔ)為多分辨率TFRecords的數(shù)據(jù)集進(jìn)行操作揭斧。每個(gè)數(shù)據(jù)集都由一個(gè)目錄表示莱革,其中包含幾個(gè)分辨率相同的圖像數(shù)據(jù),以支持有效的流。還有一個(gè)每個(gè)分辨率單獨(dú)的*.tfrecords文件盅视,如果數(shù)據(jù)集包含標(biāo)簽捐名,它們也存儲(chǔ)在單獨(dú)的文件中。默認(rèn)情況下闹击,腳本期望在datasets/<NAME>/<NAME>-<RESOLUTION>.tfrecords
中找到數(shù)據(jù)集镶蹋。可以通過編輯config.py
來更改目錄:
result_dir = 'results' # 結(jié)果目錄
data_dir = 'datasets' # 數(shù)據(jù)目錄
cache_dir = 'cache' # 緩存目錄
訓(xùn)練網(wǎng)絡(luò)
設(shè)置好數(shù)據(jù)集后赏半,你就可以訓(xùn)練你自己的StyleGAN網(wǎng)絡(luò):
- 編輯train.py贺归,通過取消注釋或編輯特定行來指定數(shù)據(jù)集和訓(xùn)練配置。
- 使用
python train.py
來運(yùn)行訓(xùn)練腳本除破。 - 結(jié)果被寫入一個(gè)新創(chuàng)建的目錄
results/<ID>-<DESCRIPTION>
牧氮。 - 訓(xùn)練可能需要幾天(或幾周)才能完成,這取決于機(jī)器配置瑰枫。
使用Tesla V100 GPU的默認(rèn)配置的預(yù)計(jì)培訓(xùn)時(shí)間:
GPU | 1024×1024 | 512×512 | 256×256 |
---|---|---|---|
1 | 41 天 4小時(shí) | 24 天 21 小時(shí) | 14 天 22 小時(shí) |
2 | 21 天 22 小時(shí) | 13 天 7 小時(shí) | 9 天 5 小時(shí) |
4 | 11 天 8 小時(shí) | 7 天 0 小時(shí) | 4 天 21 小時(shí) |
8 | 6 天 14 小時(shí) | 4 天 10 小時(shí) | 3 天 8 小時(shí) |
評(píng)估質(zhì)量和分解
使用run_metrics.py
可以評(píng)估本文中使用的質(zhì)量和解糾纏度量踱葛。默認(rèn)情況下,腳本將計(jì)算預(yù)訓(xùn)練的FFHQ生成器的Frechet初始距離(fid50k)光坝,并將結(jié)果寫入results下新創(chuàng)建的目錄尸诽。可以通過取消注釋或編輯run_metrics.py
中的特定行來更改確切的行為盯另。使用Tesla V100 GPU預(yù)訓(xùn)練的FFHQ生成器的預(yù)期評(píng)估時(shí)間和結(jié)果:
度量 | 時(shí)間 | 結(jié)果 | 描述 |
---|---|---|---|
fid50k | 16 分鐘 | 4.4159 | Fréchet Inception Distanc使用50,000張圖像瓢捉。 |
ppl_zfull | 55 分鐘 | 664.8854 | Z 中完整路徑的感知路徑長(zhǎng)度硼讽。 |
ppl_wfull | 55 分鐘 | 233.3059 | W 中完整路徑的感知路徑長(zhǎng)度奄毡。 |
ppl_zend | 55 分鐘 | 666.1057 | Z 中路徑端點(diǎn)的感知路徑長(zhǎng)度出革。 |
ppl_wend | 55 分鐘 | 197.2266 | W 中路徑端點(diǎn)的感知路徑長(zhǎng)度 |
ls | 10 hours | z: 165.0106 w: 3.7447 |
Z 和 W中的線性可分性。 |
請(qǐng)注意芝发,由于TensorFlow的非確定性绪商,每次運(yùn)行的確切結(jié)果可能有所不同。
其他預(yù)訓(xùn)練網(wǎng)絡(luò)生成的圖片
項(xiàng)目地址:https://github.com/NVlabs/stylegan