keras有著很多已經(jīng)與訓(xùn)練好的模型供調(diào)用,因此我們可以基于這些已經(jīng)訓(xùn)練好的模型來做特征提取或者微調(diào)指么,來滿足我們自己的需求。
比如我們要調(diào)用VGG16在imagenet下訓(xùn)練的模型:
from keras.applications import VGG16
conv_base = VGG16(include_top=False, weights='imagenet')
features_batch = conv_base.predict(inputs_batch)
這里是利用預(yù)訓(xùn)練的模型來做特征提取,因此我們不需要頂層的分類器網(wǎng)絡(luò)部分的權(quán)重迷帜,只需要使用到訓(xùn)練好的卷積基。這也就是VGG16參數(shù)中include_top=False的含義色洞,weights='imagenet'的意思就直接是基于imagenet訓(xùn)練的網(wǎng)絡(luò)權(quán)重了戏锹。
但是在服務(wù)器上運(yùn)行的時(shí)候遇到一個(gè)問題,因?yàn)檫@個(gè)模型第一次使用時(shí)需要去下載火诸,而服務(wù)器連接下載的url超時(shí)锦针。。惭蹂。那就只能手動(dòng)離線下載然后放到路徑里去供調(diào)用了伞插。
首先keras提供的模型下載地址是:https://github.com/fchollet/deep-learning-models/releases
其中我們找到vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5下載即可。
從這個(gè)命名也可以看出很多信息了盾碗,比如從tf看出這是基于tensorflow的(th是基于Theano )媚污,notop也就是我們上面說的不要頂層的分類器部分,h5后綴表示keras使用HDF5格式存儲(chǔ)的廷雅,等等耗美。
下好后放在哪呢京髓?我們只能看看keras的代碼是怎么寫的,從報(bào)錯(cuò)信息中可以得到你的機(jī)器中vgg16.py的文件路徑商架,比如:
Traceback (most recent call last):
File "main.py", line 9, in <module>
train.train()
File "/cloudox/cifar10_test/train.py", line 52, in train
conv_base = VGG16(include_top=False, weights='imagenet')
File "/……/keras/applications/__init__.py", line 28, in wrapper
return base_fun(*args, **kwargs)
File "/……/keras/applications/vgg16.py", line 11, in VGG16
return vgg16.VGG16(*args, **kwargs)
File "/……/keras_applications/vgg16.py", line 209, in VGG16
file_hash='6d6bbae143d832006294945121d1f1fc')
File "/……/keras/utils/data_utils.py", line 226, in get_file
raise Exception(error_msg.format(origin, e.errno, e.reason))
Exception: URL fetch failure on https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5: None -- [Errno 110] Connection timed out
從報(bào)錯(cuò)信息中堰怨,第一我們可以知道是下載“https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5”這個(gè)文件超時(shí),這也是我們上面文件下載路徑的由來蛇摸。第二我們可以知道下載的源頭在哪里备图,大致檢查一下,就會(huì)發(fā)現(xiàn)是在"/……/keras_applications/vgg16.py"這個(gè)文件中(“/usr/local/app/anaconda2/envs/tensorflow/lib/python2.7/site-packages/keras_applications/vgg16.py”)赶袄,他的代碼其實(shí)就在這:https://github.com/fchollet/deep-learning-models/blob/master/vgg16.py
好我們看看vgg16.py的代碼揽涮,首先在頂部定義了兩個(gè)下載路徑:
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5'
WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'
然后往下翻會(huì)看到獲取模型權(quán)重文件的代碼:
# load weights
if weights == 'imagenet':
if include_top:
weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels.h5',
WEIGHTS_PATH,
cache_subdir='models')
else:
weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
WEIGHTS_PATH_NO_TOP,
cache_subdir='models')
model.load_weights(weights_path)
還記得我們調(diào)用的時(shí)候傳的參數(shù)吧,VGG16(include_top=False, weights='imagenet')饿肺,所以就是這里蒋困。這里調(diào)用了get_file這個(gè)函數(shù)來從路徑中獲取權(quán)重文件,那我們看看這個(gè)函數(shù)在哪敬辣,代碼中說了在:
from keras.utils.data_utils import get_file
那就去找嘛雪标,既可以在你的文件夾里找,也可以在github找溉跃,因?yàn)関gg16這個(gè)文件屬于一個(gè)單獨(dú)的工程村刨,因此我們從作者的所有倉庫中找到keras工程,然后順著keras.utils.data_utils找到代碼喊积,在這:https://github.com/keras-team/keras/blob/master/keras/utils/data_utils.py
這時(shí)候離我們要的東西就不遠(yuǎn)了烹困,這時(shí)候都不用詳細(xì)看代碼,我們看下注釋:
注釋說乾吻,這個(gè)函數(shù)會(huì)先檢查cache中是否有文件髓梅,如果沒有就從url下載,而這個(gè)cache的路徑在~/.keras
绎签,默認(rèn)存儲(chǔ)文件是datasets枯饿,說明默認(rèn)是下載數(shù)據(jù)集的,還記得vgg16那邊傳的參數(shù)么诡必,cache_subdir='models'奢方,所以這個(gè)文件應(yīng)該在的位置就是~/.keras/models
,這時(shí)候我們直接進(jìn)入該目錄爸舒,發(fā)現(xiàn)果然有個(gè)models文件:
$ cd ~/.keras/
~/.keras]$ ls
datasets keras.json models
那就直接把文件放進(jìn)來就好啦蟋字。
這時(shí)候再去運(yùn)行之前自己的代碼就可以成功啦。