Session()方法
tensorflow的內(nèi)核使用更加高效的C++作為后臺,以支撐它的密集計(jì)算。tensorflow把前臺(即python程序)與后臺程序之間的連接稱為"會話(Session)"
Session
作為會話,主要功能是指定操作對象的執(zhí)行環(huán)境址遇,Session
類構(gòu)造函數(shù)有3個(gè)可選參數(shù)。
-
target
(可選):指定連接的執(zhí)行引擎,多用于分布式場景屋灌。 -
graph
(可選):指定要在Session對象中參與計(jì)算的圖(graph)。 -
config
(可選):輔助配置Session對象所需的參數(shù)(限制CPU或GPU使用數(shù)目应狱,設(shè)置優(yōu)化參數(shù)以及設(shè)置日志選項(xiàng)等)共郭。
run()方法
Session對象創(chuàng)建完畢,便可以使用它最重要的方法run()來啟動所需要的數(shù)據(jù)流圖進(jìn)行計(jì)算疾呻。
run()方法有4個(gè)參數(shù):
run(
fetches,
feed_dict=None
options=None,
run_metadata=None
)
(1).fetches參數(shù)
- '取得之物'除嘹,表示數(shù)據(jù)流圖中能接收的任意數(shù)據(jù)流圖元素,各類Op/Tensor對象岸蜗。Op,run()將返回None尉咕;Tensor,rnu()將返回Numpy數(shù)組。
import tensorflow as tf
from collections import namedtuple
a = tf.constant([10, 20])
b = tf.constant([1.0, 2.0])
session = tf.Session()
v1 = session.run(a) #fetches參數(shù)為單個(gè)張量值璃岳,返回值為Numpy數(shù)組
print(v1)
v2 = session.run([a, b]) #fetches參數(shù)為python類表年缎,包括兩個(gè)numpy的1維矩陣
print(v2)
v3 = session.run(tf.global_variables_initializer()) #fetches 為Op類型
print(v3)
session.close()
[10 20]
[array([10, 20], dtype=int32), array([ 1., 2.], dtype=float32)]
None
(2). feed_dict參數(shù)
- 可選項(xiàng)悔捶,給數(shù)據(jù)流圖提供運(yùn)行時(shí)數(shù)據(jù)。
feed_dict
的數(shù)據(jù)結(jié)構(gòu)為python中的字典晦款,其元素為各種鍵值對炎功。"key"為各種Tensor對象的句柄;"value"很廣泛缓溅,但必須和“鍵”的類型相匹配蛇损,或能轉(zhuǎn)換為同一類型。
import tensorflow as tf
a = tf.add(1, 2)
b = tf.multiply(a, 2)
session = tf.Session()
v1 = session.run(b)
print(v1)
replace_dict = {a:20}
v2 = session.run(b, feed_dict = replace_dict)
print(v2)
6
40