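Make sure the directory structure exists. Every time a file is created, its parent directory must already exist. ensure_directory accepts a path where all, some, or none of the directories already exist, and creates whichever directories along that path are missing.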
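The download function downloads a file and returns its filename on the local filesystem. If no filename is specified, it is parsed from the URL. Before actually downloading, the function checks whether a file with the target name already exists at the download location; if it does, the download is skipped and the existing path is returned. To force a fresh download, delete the file from the filesystem first.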
import errno
import os
import shutil
from urllib.request import urlopen
def ensure_directory(directory):
    # Expand "~" so paths such as "~/dataset" work.
    directory = os.path.expanduser(directory)
    try:
        # Creates every missing directory along the path.
        os.makedirs(directory)
    except OSError as e:
        # "Already exists" is expected; re-raise any other error.
        if e.errno != errno.EEXIST:
            raise
def download(url, directory, filename=None):
    # Fall back to the last component of the URL as the filename.
    if not filename:
        _, filename = os.path.split(url)
    directory = os.path.expanduser(directory)
    ensure_directory(directory)
    filepath = os.path.join(directory, filename)
    # Skip the download if the file is already present.
    if os.path.isfile(filepath):
        return filepath
    print('Download', filepath)
    with urlopen(url) as response, open(filepath, 'wb') as file_:
        # Stream the response to disk instead of reading it all into memory.
        shutil.copyfileobj(response, file_)
    return filepath
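A minimal Python 3 sketch of the same two helpers, written independently of the code above: os.makedirs(..., exist_ok=True) replaces the errno check, and the helper filename_from_url (an illustrative name, not from the original) takes the filename from the URL path via urllib.parse, so a query string such as ?raw=1 is not included. The usage part assumes a local file served through a file:// URL so that it runs offline; the second call finds the file already present and returns without downloading.

```python
import os
import posixpath
import shutil
import tempfile
from pathlib import Path
from urllib.parse import urlparse
from urllib.request import urlopen

def ensure_directory(directory):
    # exist_ok=True turns "directory already exists" into a no-op.
    os.makedirs(os.path.expanduser(directory), exist_ok=True)

def filename_from_url(url):
    # Only the URL path is used, so a "?query" or "#fragment" is dropped.
    return posixpath.basename(urlparse(url).path)

def download(url, directory, filename=None):
    filename = filename or filename_from_url(url)
    directory = os.path.expanduser(directory)
    ensure_directory(directory)
    filepath = os.path.join(directory, filename)
    if os.path.isfile(filepath):
        return filepath  # already downloaded: skip
    with urlopen(url) as response, open(filepath, 'wb') as file_:
        shutil.copyfileobj(response, file_)
    return filepath

# Usage: a local file served through a file:// URL, so no network is needed.
workdir = tempfile.mkdtemp()
source_dir = os.path.join(workdir, 'source')
ensure_directory(source_dir)
ensure_directory(source_dir)  # a second call does not raise
source = Path(source_dir, 'corpus.txt')
source.write_text('hello')

target_dir = os.path.join(workdir, 'nested', 'downloads')  # both levels created on demand
first = download(source.as_uri(), target_dir)
source.write_text('changed')  # the cached copy is not refreshed
second = download(source.as_uri(), target_dir)
print(first == second, Path(second).read_text())  # True hello
```

The skip-if-present check is what makes repeated runs cheap; the trade-off is that a changed remote file is never picked up until the local copy is deleted.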
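Disk cache decorator. When processing larger datasets, it pays to save intermediate results to a common location on disk; this decorator caches the results of data-loading functions. It uses Python's pickle module to serialize and deserialize the function's return value, so it only suits datasets that fit into main memory. The arguments passed to a function decorated with @disk_cache are forwarded to that function, and they also determine whether a cached result exists for that combination of arguments: they are hashed into a number that becomes part of the filename, so the cache file is 'directory/basename-hash.pickle'. The method=False parameter tells the decorator whether to ignore the first argument; set method=True when decorating a method, so that self is skipped when computing the hash.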
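import functools
import hashlib
import os
import pickle
def disk_cache(basename, directory, method=False):
    directory = os.path.expanduser(directory)
    ensure_directory(directory)  # defined above
    def wrapper(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            # For methods, leave out the first argument (self) when building the key.
            key_args = args[1:] if method else args
            # Sort keyword arguments so f(a=1, b=2) and f(b=2, a=1) share one cache file.
            key = (tuple(key_args), tuple(sorted(kwargs.items())))
            # Built-in hash() of strings is randomized per process (PYTHONHASHSEED),
            # so use a stable digest of repr(key) instead. This assumes the
            # arguments have a deterministic repr (numbers, strings, tuples, ...).
            digest = hashlib.md5(repr(key).encode('utf-8')).hexdigest()
            filename = '{}-{}.pickle'.format(basename, digest)
            filepath = os.path.join(directory, filename)
            if os.path.isfile(filepath):
                with open(filepath, 'rb') as handle:
                    return pickle.load(handle)
            result = func(*args, **kwargs)
            with open(filepath, 'wb') as handle:
                pickle.dump(result, handle)
            return result
        return wrapped
    return wrapper
# Dataset, Tokenize and OneHotEncoding stand for your own loading and
# preprocessing code; they are not defined here.
@disk_cache('dataset', '/home/user/dataset/')
def get_dataset(one_hot=True):
    dataset = Dataset('http://example.com/dataset.bz2')
    dataset = Tokenize(dataset)
    if one_hot:
        dataset = OneHotEncoding(dataset)
    return dataset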
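The cache is only useful if the same arguments map to the same filename in every Python process; otherwise a restarted notebook never hits the cache. The sketch below (cache_filename and hash_in_fresh_process are illustrative names, not from the book) runs the built-in hash() in two fresh interpreters with different PYTHONHASHSEED values, and builds the filename from an md5 digest of the key's repr instead, assuming the arguments have a deterministic repr.

```python
import hashlib
import os
import subprocess
import sys

def cache_filename(basename, args, kwargs, method=False):
    # Skip the first positional argument (self) when caching a method.
    key_args = args[1:] if method else args
    # Sorting keyword arguments makes the key independent of their order.
    key = (tuple(key_args), tuple(sorted(kwargs.items())))
    # md5 of repr(key) is identical in every process, unlike hash() on strings.
    digest = hashlib.md5(repr(key).encode('utf-8')).hexdigest()
    return '{}-{}.pickle'.format(basename, digest)

def hash_in_fresh_process(seed):
    # Evaluate hash('dataset') in a new interpreter with a fixed hash seed.
    env = dict(os.environ, PYTHONHASHSEED=seed)
    result = subprocess.run(
        [sys.executable, '-c', "print(hash('dataset'))"],
        env=env, capture_output=True, text=True, check=True)
    return result.stdout.strip()

# Built-in hash() of the same string differs between the two seeds.
print(hash_in_fresh_process('1'), hash_in_fresh_process('2'))
# The digest-based name is the same for any keyword order.
print(cache_filename('dataset', (), {'one_hot': True, 'size': 10}))
print(cache_filename('dataset', (), {'size': 10, 'one_hot': True}))
```

With method=True the first argument is dropped before hashing, so two different instances calling the same method with the same remaining arguments share one cache file.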
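Attribute dictionary. AttrDict inherits from the built-in dict class and lets you access existing elements with attribute syntax. It is constructed from a standard dictionary (key-value pairs). The built-in function locals returns a mapping from the names of all local variables in the current scope to their values, which makes it a convenient source for an AttrDict.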
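class AttrDict(dict):
    # Called only when normal attribute lookup fails, so dict methods still work.
    def __getattr__(self, key):
        if key not in self:
            raise AttributeError(key)
        return self[key]
    # Only keys that already exist can be assigned through attribute syntax.
    def __setattr__(self, key, value):
        if key not in self:
            raise AttributeError(key)
        self[key] = value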
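A short usage sketch, assuming the AttrDict class as defined above (repeated so the block runs on its own). The function make_params and its values are hypothetical: it bundles its local variables with locals(), after which they can be read as attributes, existing entries can be reassigned, and assigning a new name raises AttributeError.

```python
class AttrDict(dict):
    def __getattr__(self, key):
        if key not in self:
            raise AttributeError(key)
        return self[key]

    def __setattr__(self, key, value):
        if key not in self:
            raise AttributeError(key)
        self[key] = value

def make_params(learning_rate=0.03, batch_size=64):
    # locals() maps every local name in this scope to its value.
    hidden_size = 2 * batch_size
    return AttrDict(locals())

params = make_params()
print(params.learning_rate, params.hidden_size)  # 0.03 128
params.batch_size = 32       # allowed: the key already exists
print(params['batch_size'])  # 32
try:
    params.epochs = 10       # rejected: no such key
except AttributeError as error:
    print('AttributeError:', error)
```

Rejecting unknown names on assignment catches typos such as params.learnig_rate = 0.1, which a plain object would silently accept as a new attribute.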
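Lazy property decorator. When a model is used from the outside, every access to model.optimize would otherwise add a new computation path to the dataflow graph, and every access to model.prediction would create new weights and biases. The decorator defines a property that is computed only once: the result is stored in a member attribute whose name is the decorated function's name with a prefix (_lazy_), and later accesses return the stored value. Lazy properties make it easy to structure a TensorFlow model into well-separated parts.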
import functools
import tensorflow as tf  # TensorFlow 1.x API
def lazy_property(function):
    # Name of the member attribute that stores the computed value.
    attribute = '_lazy_' + function.__name__
    @property
    @functools.wraps(function)
    def wrapper(self):
        # Build the value on first access only; reuse it afterwards.
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)
    return wrapper
class Model:
    def __init__(self, data, target):
        self.data = data
        self.target = target
        # Touch each property once so the whole graph is built in the constructor.
        self.prediction
        self.optimize
        self.error
    @lazy_property
    def prediction(self):
        data_size = int(self.data.get_shape()[1])
        target_size = int(self.target.get_shape()[1])
        weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
        bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
        incoming = tf.matmul(self.data, weight) + bias
        return tf.nn.softmax(incoming)
    @lazy_property
    def optimize(self):
        # Cross-entropy: multiply element-wise, then sum over all entries.
        cross_entropy = -tf.reduce_sum(self.target * tf.log(self.prediction))
        optimizer = tf.train.RMSPropOptimizer(0.03)
        return optimizer.minimize(cross_entropy)
    @lazy_property
    def error(self):
        # Fraction of examples whose predicted class differs from the target class.
        mistakes = tf.not_equal(
            tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
        return tf.reduce_mean(tf.cast(mistakes, tf.float32))
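The compute-once behaviour can be checked without TensorFlow. This sketch reuses lazy_property as defined above and replaces the graph-building code with a hypothetical Counter class that records how often each property body runs. For comparison it also uses functools.cached_property (Python 3.8+), which likewise computes a value once per instance but stores it in the instance __dict__ under the property's own name rather than a _lazy_ prefix.

```python
import functools

def lazy_property(function):
    attribute = '_lazy_' + function.__name__

    @property
    @functools.wraps(function)
    def wrapper(self):
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)
    return wrapper

class Counter:
    def __init__(self):
        self.calls = 0

    @lazy_property
    def prediction(self):
        # Stands in for code that would add new variables to the graph.
        self.calls += 1
        return 'prediction'

    @functools.cached_property
    def error(self):
        self.calls += 1
        return 'error'

model = Counter()
for _ in range(3):
    model.prediction
    model.error
print(model.calls)             # 2
print(model._lazy_prediction)  # prediction
print('error' in vars(model))  # True
```

Each body ran exactly once despite three accesses, which is why Model can touch its properties in __init__ and later accesses do not add new operations to the graph.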
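Overwrite-graph decorator. Unless another dataflow graph is specified explicitly, TensorFlow uses the default graph. In Jupyter Notebook, the interpreter state persists between executions of different cells, so the initial default graph always exists. Re-running a cell that defines graph operations adds them to the already existing graph again. One workaround is to restart the kernel from the menu and run all cells again.
A cleaner approach is to create a custom graph and set it as the default. All operations are added to that graph; running the cell again creates a new graph, and the old one is cleaned up automatically.
Create the graph inside a decorator and use it to decorate the main function. The main function defines the complete dataflow graph: it defines the placeholders and calls a function to create the model.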
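import functools
import tensorflow as tf  # TensorFlow 1.x API
def overwrite_graph(function):
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        # Every call builds its operations in a brand-new default graph.
        with tf.Graph().as_default():
            return function(*args, **kwargs)
    return wrapper
@overwrite_graph
def main():
    # Placeholder arguments (dtype, shape) are elided here.
    data = tf.placeholder(...)
    target = tf.placeholder(...)
    # Model takes the data and target tensors (see its __init__ above).
    model = Model(data, target)
main()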
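To see the effect, the sketch below calls a decorated function twice and checks that each call gets its own graph containing only the operations it defined, so nothing piles up across repeated runs of a notebook cell. It assumes a TensorFlow version that provides tf.compat.v1.get_default_graph (1.13 or later, including 2.x); the function build and the constant named 'a' are illustrative.

```python
import functools
import tensorflow as tf

def overwrite_graph(function):
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        # A fresh graph becomes the default for the duration of the call.
        with tf.Graph().as_default():
            return function(*args, **kwargs)
    return wrapper

@overwrite_graph
def build():
    # Inside the with-block, ops go into the new graph (graph mode even in TF 2.x).
    tf.constant(1.0, name='a')
    return tf.compat.v1.get_default_graph()

first = build()
second = build()  # like re-running the notebook cell
print(first is second)                              # False
print([op.name for op in second.get_operations()])  # ['a']
```

Without the decorator, in TensorFlow 1.x both calls would add their constant to the same default graph, and the second one would be renamed a_1.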
API documentation, for reference while writing code:
https://www.tensorflow.org/versions/master/api_docs/index.html
GitHub repository, for following TensorFlow's newest features and reading pull requests, issues, and release notes:
https://github.com/tensorflow/tensorflow
Distributed TensorFlow:
https://www.tensorflow.org/versions/master/how_tos/distributed/index.html
Adding new TensorFlow functionality (custom ops):
https://www.tensorflow.org/master/how_tos/adding_an_op/index.html
Mailing list:
https://groups.google.com/a/tensorflow.org/d/forum/discuss
StackOverflow:
http://stackoverflow.com/questions/tagged/tensorflow
Code:
https://github.com/backstopmedia/tensorflowbook
References:
《面向機器智能的TensorFlow實踐》 (Chinese edition of TensorFlow for Machine Intelligence)