本文譯自Danijar Hafner的博客Structuring Your TensorFlow Models锌介。
構(gòu)建計算圖
一般來說會對每個模型建立一個class该肴,這個class的接口是什么呢假勿?通常模型會連接一些輸入數(shù)據(jù)和目標(biāo)的placeholders以及提供一些訓(xùn)練、評估和前向傳播的操作(operation)脸候,下面是一個例子获列,展示了一個全連接神經(jīng)網(wǎng)絡(luò):
class Model:
def __init__(self, data, target):
data_size = int(data.get_shape()[1])
target_size = int(target.get_shape()[1])
weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
incoming = tf.matmul(data, weight) + bias
self._prediction = tf.nn.softmax(incoming)
cross_entropy = -tf.reduce_sum(target, tf.log(self._prediction))
self._optimize = tf.train.RMSPropOptimizer(0.03).minimize(cross_entropy)
mistakes = tf.not_equal(
tf.argmax(target, 1), tf.argmax(self._prediction, 1))
self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))
@property
def prediction(self):
return self._prediction
@property
def optimize(self):
return self._optimize
@property
def error(self):
return self._error
這是一個基本結(jié)構(gòu)逗爹。然而這里存在一些問題亡嫌,最顯著的問題是整個計算圖是用單個函數(shù)定義的,這減少了可讀性和可重用性掘而。
使用Property裝飾器
僅僅將代碼分割為不同的函數(shù)不管用挟冠,因為一旦函數(shù)被調(diào)用,計算圖就會增加(這點譯者深有體會袍睡,Tensorflow中的代碼復(fù)用和傳統(tǒng)代碼復(fù)用不一致知染,因為它會為每一行代碼構(gòu)建計算節(jié)點,即使該節(jié)點所使用的參數(shù)是同一套)斑胜。因此控淡,我們需要確保操作(operation)僅在函數(shù)第一次被調(diào)用的時候加入計算圖,這是基本的惰性編程(lazy-coding)思想止潘。
class Model:
def __init__(self, data, target):
self.data = data
self.target = target
self._prediction = None
self._optimize = None
self._error = None
@property
def prediction(self):
if not self._prediction:
data_size = int(self.data.get_shape()[1])
target_size = int(self.target.get_shape()[1])
weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
incoming = tf.matmul(self.data, weight) + bias
self._prediction = tf.nn.softmax(incoming)
return self._prediction
@property
def optimize(self):
if not self._optimize:
cross_entropy = -tf.reduce_sum(self.target, tf.log(self.prediction))
optimizer = tf.train.RMSPropOptimizer(0.03)
self._optimize = optimizer.minimize(cross_entropy)
return self._optimize
@property
def error(self):
if not self._error:
mistakes = tf.not_equal(
tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))
return self._error
這個例子比第一個例子好多了掺炭,現(xiàn)在代碼被劃分成了不同的函數(shù)。然而這個代碼還是有點冗余(因為每個函數(shù)都用了相同的邏輯:if not ……凭戴,這個部分讓代碼看上去嵌套而不扁平竹伸,所以這個部分可用裝飾器重用)。
惰性屬性裝飾器(Lazy Property Decorator)
上面的例子使用了property裝飾器簇宽,它將函數(shù)的返回結(jié)構(gòu)存儲到一個以函數(shù)名為名字的對象屬性中。現(xiàn)在我們還可以將惰性編程的部分加入裝飾器吧享。
import functools
def lazy_property(function):
attribute = '_cache_' + function.__name__
@property
@functools.wraps(function)
def decorator(self):
if not hasattr(self, attribute):
setattr(self, attribute, function(self))
return getattr(self, attribute)
return decorator
現(xiàn)在我們的代碼就可以更佳簡化了魏割,如下所示:
class Model:
def __init__(self, data, target):
self.data = data
self.target = target
self.prediction
self.optimize
self.error
@lazy_property
def prediction(self):
data_size = int(self.data.get_shape()[1])
target_size = int(self.target.get_shape()[1])
weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
incoming = tf.matmul(self.data, weight) + bias
return tf.nn.softmax(incoming)
@lazy_property
def optimize(self):
cross_entropy = -tf.reduce_sum(self.target, tf.log(self.prediction))
optimizer = tf.train.RMSPropOptimizer(0.03)
return optimizer.minimize(cross_entropy)
@lazy_property
def error(self):
mistakes = tf.not_equal(
tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
return tf.reduce_mean(tf.cast(mistakes, tf.float32))
整個計算圖在執(zhí)行tf.initialize_variables()
前需要定義好。
用Scopes組織計算圖
使用上面的例子產(chǎn)生的計算圖依舊非常擁擠钢颂,如果你可視化整個計算圖钞它,那么它會包含很多內(nèi)部的小節(jié)點,一個解決方式是使用tf.name_scope('name')
或者tf.variable_scope('name')
殊鞭。這樣節(jié)點會被分組遭垛,可視化非常直觀。我們可以通過調(diào)整之前的裝飾器操灿,將一個函數(shù)的名字作為其命名空間:
import functools
def define_scope(function):
attribute = '_cache_' + function.__name__
@property
@functools.wraps(function)
def decorator(self):
if not hasattr(self, attribute):
with tf.variable_scope(function.__name):
setattr(self, attribute, function(self))
return getattr(self, attribute)
return decorator
這樣我們就定義了一個緊湊锯仪、可讀性強的代碼。