在使用TensorFlow進行建模扑浸、訓(xùn)練和預(yù)測時粘舟,可以使用estimator這樣的高階函數(shù)方便使用凉驻”从伲基本的套路是:
訓(xùn)練 fit
from tensorflow.contrib import learn
classifier = learn.Estimator(model_fn=lstm_model(stepLength, config['lstm_layers'], config['dense_layers']),
model_dir=config['log_dir'])
validation_monitor = learn.monitors.ValidationMonitor(x['val'], y['val'],
every_n_steps=config['print_steps'],
early_stopping_rounds=1000)
classifier.fit(x['train'], y['train'],monitors=[validation_monitor], batch_size=config['batch_size'], steps=len(x['train'])*config['epochs'])
- 初始化classifer
- 配置monitor
- 使用fit函數(shù)進行訓(xùn)練
預(yù)測 predict
classifier = learn.Estimator(model_fn=lstm_model(stepLength, config['lstm_layers'], config['dense_layers']),
model_dir=config['log_dir'])
pre = classifier.predict(x['test'])
- 初始化classifier,其中model_dir的配置需要指定為和訓(xùn)練相同的目錄
- 使用predict
根據(jù)官方文檔的步驟郭计,這樣predict即可使用訓(xùn)練好的模型和參數(shù)霸琴。
但是實際結(jié)果是這樣predict會報錯,具體錯誤如下:
ValueError: Tried to convert 'values' to a tensor and failed. Error: None values not supported.
修正后的predict
這其實是learn.predict的一個bug昭伸,目前社區(qū)還沒有fix梧乘,不過在issue-3208 給出了work around的方法,在predict之前使用eval庐杨,類似于聲明classifier并沒有實例化选调,在eval的時候,會將其實例化辑莫,這樣在之后使用predict的時候就不會再報Error: None values not supported.的錯誤了学歧。當(dāng)然這樣做的缺點是因為eval會使用TF對模型進行運算至少一次罩引,這樣會造成一定的性能損耗各吨。
work around的方法
classifier = learn.Estimator(model_fn=lstm_model(stepLength, config['lstm_layers'], config['dense_layers']),
model_dir=config['log_dir'])
classifier.evaluate(x=x['val'],y=y['val'],steps=1)
pre = classifier.predict(x['test'])