契機:該函數(shù)負責保存檢查點,而默認設置只保留五個checkpoint,所以此處需要追蹤第六個檢查點產(chǎn)生時系統(tǒng)是如何工作的;除此之外,還要摸清底層檢查點的實現(xiàn)細節(jié),包括I/O的事宜等等。
- 第一步平匈,最上層用戶創(chuàng)建框沟,每1000輪保存一次檢查點
# Top-level training-loop hook: persist a checkpoint every 1000 steps and at
# the final step, writing files prefixed "model.ckpt" under FLAGS.train_dir.
if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
    checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=step)
- 追溯saver.save()
This method runs the ops added by the constructor for saving variables.
It requires a session in which the graph was launched. The variables to
save must also have been initialized.
def save(self,
         sess,
         save_path,
         global_step=None,
         latest_filename=None,
         meta_graph_suffix="meta",
         write_meta_graph=True,
         write_state=True,
         strip_default_attrs=False,
         save_debug_info=False):
  # pylint: disable=line-too-long
  """Saves variables.

  This method runs the ops added by the constructor for saving variables.
  It requires a session in which the graph was launched. The variables to
  save must also have been initialized.

  The method returns the path prefix of the newly created checkpoint files.
  This string can be passed directly to a call to `restore()`.

  Args:
    sess: A Session to use to save the variables.
    save_path: String. Prefix of filenames created for the checkpoint.
    global_step: If provided the global step number is appended to `save_path`
      to create the checkpoint filenames. The optional argument can be a
      `Tensor`, a `Tensor` name or an integer.
    latest_filename: Optional name for the protocol buffer file that will
      contains the list of most recent checkpoints. That file, kept in the
      same directory as the checkpoint files, is automatically managed by the
      saver to keep track of recent checkpoints. Defaults to 'checkpoint'.
    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
    write_meta_graph: `Boolean` indicating whether or not to write the meta
      graph file.
    write_state: `Boolean` indicating whether or not to write the
      `CheckpointStateProto`.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
      removed from the NodeDefs. For a detailed guide, see
      [Stripping Default-Valued
      Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
      which in the same directory of save_path and with `_debug` added before
      the file extension. This is only enabled when `write_meta_graph` is
      `True`

  Returns:
    A string: path prefix used for the checkpoint files. If the saver is
      sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
      is the number of shards created.
    If the saver is empty, returns None.

  Raises:
    TypeError: If `sess` is not a `Session`.
    ValueError: If `latest_filename` contains path components, or if it
      collides with `save_path`.
    RuntimeError: If save and restore ops weren't built.
  """
  # pylint: enable=line-too-long
  # Graph-mode savers must have their save/restore ops built before use.
  if not self._is_built and not context.executing_eagerly():
    raise RuntimeError(
        "`build()` should be called before save if defer_build==True")
  if latest_filename is None:
    latest_filename = "checkpoint"
  if self._write_version != saver_pb2.SaverDef.V2:
    logging.warning("*******************************************************")
    logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
    logging.warning("Consider switching to the more efficient V2 format:")
    logging.warning(" `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
    logging.warning("now on by default.")
    logging.warning("*******************************************************")

  if os.path.split(latest_filename)[0]:
    raise ValueError("'latest_filename' must not contain path components")

  # Derive the concrete checkpoint filename, "<save_path>-<global_step>".
  if global_step is not None:
    # A Tensor / tensor-name step is resolved to an int via the session.
    if not isinstance(global_step, compat.integral_types):
      global_step = training_util.global_step(sess, global_step)
    checkpoint_file = "%s-%d" % (save_path, global_step)
    if self._pad_step_number:
      # Zero-pads the step numbers, so that they are sorted when listed.
      checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
  else:
    checkpoint_file = save_path
    if os.path.basename(save_path) == latest_filename and not self._sharded:
      # Guard against collision between data file and checkpoint state file.
      raise ValueError(
          "'latest_filename' collides with 'save_path': '%s' and '%s'" %
          (latest_filename, save_path))

  if (not context.executing_eagerly() and
      not isinstance(sess, session.SessionInterface)):
    raise TypeError("'sess' must be a Session; %s" % sess)

  save_path_parent = os.path.dirname(save_path)
  if not self._is_empty:
    try:
      if context.executing_eagerly():
        self._build_eager(
            checkpoint_file, build_save=True, build_restore=False)
        model_checkpoint_path = self.saver_def.save_tensor_name
      else:
        # Running the save op performs the actual serialization of the
        # variables to disk; the op evaluates to the checkpoint path prefix.
        model_checkpoint_path = sess.run(
            self.saver_def.save_tensor_name,
            {self.saver_def.filename_tensor_name: checkpoint_file})

      model_checkpoint_path = compat.as_str(model_checkpoint_path)
      if write_state:
        ## The code above only shaped the checkpoint name; the next line
        ## records the newly written checkpoint as the most recent one.
        self._RecordLastCheckpoint(model_checkpoint_path)
        ## Rewrite the "checkpoint" state file; this call performs the
        ## actual file write.
        checkpoint_management.update_checkpoint_state_internal(
            save_dir=save_path_parent,
            model_checkpoint_path=model_checkpoint_path,
            all_model_checkpoint_paths=self.last_checkpoints,
            latest_filename=latest_filename,
            save_relative_paths=self._save_relative_paths)
        ## Enforce the retention policy (default keeps 5 checkpoints):
        ## possibly delete the oldest checkpoint files.
        self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
    except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
      # Give a clearer error when the failure is just a missing directory.
      if not gfile.IsDirectory(save_path_parent):
        exc = ValueError(
            "Parent directory of {} doesn't exist, can't save.".format(
                save_path))
      raise exc

  if write_meta_graph:
    # Export the MetaGraphDef alongside the checkpoint ("<prefix>.meta").
    meta_graph_filename = checkpoint_management.meta_graph_filename(
        checkpoint_file, meta_graph_suffix=meta_graph_suffix)
    if not context.executing_eagerly():
      with sess.graph.as_default():
        self.export_meta_graph(
            meta_graph_filename,
            strip_default_attrs=strip_default_attrs,
            save_debug_info=save_debug_info)

  if self._is_empty:
    return None
  else:
    return model_checkpoint_path
這里代碼中反復提到了context.executing_eagerly():
這里針對該內(nèi)容進行總結:
TensorFlow 引入了「Eager Execution」辣卒,它是一個命令式踱稍、由運行定義的接口,一旦從 Python 被調(diào)用灾前,其操作立即被執(zhí)行防症。這使得入門 TensorFlow 變的更簡單,也使研發(fā)更直觀哎甲。
簡單來說蔫敲,這是用戶可定義的一個機制,可選擇是否開啟炭玫。這里關心I/O層面的問題奈嘿,所以這里不詳述。
進入該函數(shù)吞加,進行判斷:
## global_step非空則進入
if global_step is not None:
## 查看源碼發(fā)現(xiàn)都是false,所以可以直接進入
if not isinstance(global_step, compat.integral_types):
## 獲取當前步數(shù)
global_step = training_util.global_step(sess, global_step)
## 打印出來:“checkpoint路徑-步數(shù)”
checkpoint_file = "%s-%d" % (save_path, global_step)
之后進入創(chuàng)建checkpoint部分:
if not self._is_empty:
try:
if context.executing_eagerly():
self._build_eager(
checkpoint_file, build_save=True, build_restore=False)
## Store the tensor values to the tensor_names.
model_checkpoint_path = self.saver_def.save_tensor_name
else:
model_checkpoint_path = sess.run(
self.saver_def.save_tensor_name,
{self.saver_def.filename_tensor_name: checkpoint_file})
model_checkpoint_path = compat.as_str(model_checkpoint_path)
其中save_tensor_name
=save_tensor.numpy() if build_save else ""
# 用tf.train.Saver()創(chuàng)建一個Saver來管理模型中的所有變量
saver = tf.train.Saver(tf.all_variables())
在看下面函數(shù):
checkpoint_management.update_checkpoint_state_internal(
save_dir=save_path_parent,
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=self.last_checkpoints,
latest_filename=latest_filename,
save_relative_paths=self._save_relative_paths)
def update_checkpoint_state_internal(save_dir,  ## model save path
                                     model_checkpoint_path,  ## checkpoint
                                     all_model_checkpoint_paths=None,  ## from old to new
                                     latest_filename=None,
                                     save_relative_paths=False,
                                     all_model_checkpoint_timestamps=None,
                                     last_preserved_timestamp=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest. If this is a non-empty list,
      the last element must be equal to model_checkpoint_path. These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file. Default to
      'checkpoint'.
    save_relative_paths: If `True`, will write relative paths to the checkpoint
      state file.
    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
      seconds since the Epoch) indicating when the checkpoints in
      `all_model_checkpoint_paths` were created.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.contrib.checkpoint.CheckpointManager` for an implementation).

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointSate.
  """
  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
  if save_relative_paths:
    # Convert absolute checkpoint paths into paths relative to save_dir
    # before serializing them into the state proto.
    if os.path.isabs(model_checkpoint_path):
      rel_model_checkpoint_path = os.path.relpath(
          model_checkpoint_path, save_dir)
    else:
      rel_model_checkpoint_path = model_checkpoint_path
    rel_all_model_checkpoint_paths = []
    for p in all_model_checkpoint_paths:
      if os.path.isabs(p):
        rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
      else:
        rel_all_model_checkpoint_paths.append(p)
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        rel_model_checkpoint_path,
        all_model_checkpoint_paths=rel_all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)
  else:
    ## Build the CheckpointState proto with the paths exactly as given.
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)

  # The state file itself must not collide with a checkpoint data file.
  if coord_checkpoint_filename == ckpt.model_checkpoint_path:
    raise RuntimeError("Save path '%s' conflicts with path used for "
                       "checkpoint state. Please use a different save path." %
                       model_checkpoint_path)

  # Preventing potential read/write race condition by *atomically* writing to a
  # file.
  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                      text_format.MessageToString(ckpt))
該函數(shù)比較關鍵(對checkpoint這個文件進行寫入),需要詳細分析:
該函數(shù)目的為更新checkpoint文件中的內(nèi)容衔憨,更新包含CheckpointState原型的檢查點文件叶圃。
傳入?yún)?shù)介紹:
save_dir: Directory where the model was saved.(模型存儲的目錄)
model_checkpoint_path: 檢查點文件名字(/tmp/cifar10_train/model.ckpt-400)
all_model_checkpoint_paths: 字符串列表,即當前沒有刪除的所有checkpoint內(nèi)容践图〔艄冢“checkpoint”這個文件中的內(nèi)容:
CP_DIY_last_checkpoints: [('/tmp/cifar10_train/model.ckpt-0', 1567741570.985191), ('/tmp/cifar10_train/model.ckpt-100', 1567741597.823174), ('/tmp/cifar10_train/model.ckpt-200', 1567741622.994202), ('/tmp/cifar10_train/model.ckpt-300', 1567741648.138035), ('/tmp/cifar10_train/model.ckpt-400', 1567741674.720897)]
latest_filename: checkpoint這個文件的名字,默認叫做 'checkpoint'
save_relative_paths: 相對码党、絕對路徑
在該函數(shù)中打印ckpt:
得到:
CP_DIY_ckpt: model_checkpoint_path: "/tmp/cifar10_train/model.ckpt-0"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-0"
CP_DIY_ckpt: model_checkpoint_path: "/tmp/cifar10_train/model.ckpt-100"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-0"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-100"
說明存儲的內(nèi)容包括當前的model_checkpoint_path
以及all_model_checkpoint_paths
德崭。
file_io.atomic_write_string_to_file(coord_checkpoint_filename,
text_format.MessageToString(ckpt))
最后一句話表明:將ckpt中的內(nèi)容寫入coord_checkpoint_filename
,
即將
model_checkpoint_path: "/tmp/cifar10_train/model.ckpt-999999"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-998000"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-998500"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-999000"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-999500"
all_model_checkpoint_paths: "/tmp/cifar10_train/model.ckpt-999999"
寫入checkpoint這個文件中揖盘。
下面更新meta計算圖checkpoint文件的名字
meta_graph_filename = checkpoint_management.meta_graph_filename(
checkpoint_file, meta_graph_suffix=meta_graph_suffix)
得到xxxx.meta文件眉厨。
為了在該文件中寫入,我們查看export_meta_graph
函數(shù):
@tf_export(v1=["train.export_meta_graph"])
def export_meta_graph(filename=None,
                      meta_info_def=None,
                      graph_def=None,
                      saver_def=None,
                      collection_list=None,
                      as_text=False,
                      graph=None,
                      export_scope=None,
                      clear_devices=False,
                      clear_extraneous_savers=False,
                      strip_default_attrs=False,
                      save_debug_info=False,
                      **kwargs):
  # pylint: disable=line-too-long
  """Returns `MetaGraphDef` proto.

  Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the generated
      `MetaGraphDef` protocol buffer.
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract the
      subgraph. The scope name will be striped from the node definitions for
      easy import later into new name scopes. If `None`, the whole graph is
      exported. graph_def and export_scope cannot both be specified.
    clear_devices: Whether or not to clear the device field for an `Operation`
      or `Tensor` during export.
    clear_extraneous_savers: Remove any Saver-related information from the graph
      (both Save/Restore ops and SaverDefs) that are not associated with the
      provided SaverDef.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
      removed from the NodeDefs. For a detailed guide, see
      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
      which in the same directory of filename and with `_debug` added before the
      file extension.
    **kwargs: Optional keyed arguments.

  Returns:
    A `MetaGraphDef` proto.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Exporting/importing meta graphs is not supported unless both `graph_def` and
  `graph` are provided. No graph exists when eager execution is enabled.
  @end_compatibility
  """
  # pylint: enable=line-too-long
  # Meta-graph export needs a graph; in eager mode there is none unless the
  # caller supplies both graph_def and graph explicitly.
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise RuntimeError("Exporting/importing meta graphs is not supported when "
                       "eager execution is enabled. No graph exists when eager "
                       "execution is enabled.")
  # Delegate the serialization (and optional file write) to
  # meta_graph.export_scoped_meta_graph; discard the returned var_list.
  meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
      filename=filename,
      meta_info_def=meta_info_def,
      graph_def=graph_def,
      saver_def=saver_def,
      collection_list=collection_list,
      as_text=as_text,
      graph=graph,
      export_scope=export_scope,
      clear_devices=clear_devices,
      clear_extraneous_savers=clear_extraneous_savers,
      strip_default_attrs=strip_default_attrs,
      save_debug_info=save_debug_info,
      **kwargs)
  return meta_graph_def
那么系統(tǒng)如何將計算圖導出呢兽狭?
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
      graph (both Save/Restore ops and SaverDefs) that are not associated
      with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
      which in the same directory of filename and with `_debug` added before the
      file extension.
    **kwargs: Optional keyed arguments, including meta_info_def and
      collection_list.

  Returns:
    A `MetaGraphDef` proto and dictionary of `Variables` in the exported
    name scope.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is undefined.
  """
  # Eager mode has no implicit graph to export from.
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  unbound_inputs = []
  # A filtered/rewritten GraphDef is only built when scoping, saver cleanup
  # or device clearing is requested; otherwise the graph is exported as-is.
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      # Rebuild a new GraphDef from the supplied one, copying only nodes
      # that survive the scope/saver filters.
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # Only do this complicated work if we want to remove a name scope.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      # Walk the graph's nodes in id order so the export is deterministic.
      for key in sorted(graph._nodes_by_id):
        if _should_include_node(graph._nodes_by_id[key].name,
                                export_scope,
                                exclude_nodes):
          value = graph._nodes_by_id[key]
          # pylint: enable=protected-access
          node_def = _node_def(value.node_def, export_scope, unbound_inputs,
                               clear_devices=clear_devices)
          graph_def.node.extend([node_def])
          if value.outputs:
            assert "_output_shapes" not in graph_def.node[-1].attr
            graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
                output.get_shape().as_proto() for output in value.outputs])
          bytesize += value.node_def.ByteSize()
          # Protocol buffers cannot exceed 2GB.
          if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError("GraphDef cannot be larger than 2GB.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Clears the unbound_inputs collections.
      graph.clear_collection(unbound_inputs_col_name)
      for k in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, k)

  # Collect the global variables that fall inside the exported scope,
  # keyed by their name with the scope prefix stripped.
  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=export_scope)
  for v in variables:
    if _should_include_node(v, export_scope, exclude_nodes):
      var_list[ops.strip_name_scope(v.name, export_scope)] = v

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    # The on-disk write of the .meta file happens here.
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)
    if save_debug_info:
      name, _ = os.path.splitext(filename)
      debug_filename = "{name}{ext}".format(name=name, ext=".debug")

      # Gets the operation from the graph by the name. Excludes variable nodes,
      # so only the nodes in the frozen models are included.
      ops_to_export = []
      for node in scoped_meta_graph_def.graph_def.node:
        scoped_op_name = ops.prepend_name_scope(node.name, export_scope)
        ops_to_export.append(graph.get_operation_by_name(scoped_op_name))

      graph_debug_info = create_graph_debug_info_def(ops_to_export)

      graph_io.write_graph(
          graph_debug_info,
          os.path.dirname(debug_filename),
          os.path.basename(debug_filename),
          as_text=as_text)

  return scoped_meta_graph_def, var_list
該函數(shù)有些復雜憾股,后續(xù)繼續(xù)分析。那么如何將內(nèi)容寫入該文件呢椭符?
def write_graph(graph_or_graph_def, logdir, name, as_text=True):
  """Writes a graph proto to a file.

  The graph is written as a text proto unless `as_text` is `False`.

  ```python
  v = tf.Variable(0, name='my_variable')
  sess = tf.compat.v1.Session()
  tf.io.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')
  ```

  or

  ```python
  v = tf.Variable(0, name='my_variable')
  sess = tf.compat.v1.Session()
  tf.io.write_graph(sess.graph, '/tmp/my-model', 'train.pbtxt')
  ```

  Args:
    graph_or_graph_def: A `Graph` or a `GraphDef` protocol buffer.
    logdir: Directory where to write the graph. This can refer to remote
      filesystems, such as Google Cloud Storage (GCS).
    name: Filename for the graph.
    as_text: If `True`, writes the graph as an ASCII proto.

  Returns:
    The path of the output proto file.
  """
  # Accept either a Graph (converted to its GraphDef) or a GraphDef directly.
  if isinstance(graph_or_graph_def, ops.Graph):
    graph_def = graph_or_graph_def.as_graph_def()
  else:
    graph_def = graph_or_graph_def

  # gcs does not have the concept of directory at the moment.
  if not file_io.file_exists(logdir) and not logdir.startswith('gs:'):
    file_io.recursive_create_dir(logdir)
  path = os.path.join(logdir, name)
  if as_text:
    # Human-readable text proto; written atomically to avoid torn reads.
    file_io.atomic_write_string_to_file(path,
                                        text_format.MessageToString(
                                            graph_def, float_format=''))
  else:
    # Binary serialized proto.
    file_io.atomic_write_string_to_file(path, graph_def.SerializeToString())
  return path
首先判斷是否存在文件夾荔燎,如果不存在則:file_io.recursive_create_dir(logdir)
耻姥。
繼續(xù)向下走销钝,我們提取出
def write_string_to_file(filename, file_content):
  """Writes a string to a given file.

  Args:
    filename: string, path to a file
    file_content: string, contents that need to be written to the file

  Raises:
    errors.OpError: If there are errors during the operation.
  """
  # Mode "w" truncates/creates the file; the context manager guarantees the
  # underlying handle is flushed and closed even if write() raises.
  with FileIO(filename, mode="w") as f:
    f.write(file_content)
對于FileIO:
The constructor takes the following arguments:
name: name of the file
mode: one of 'r', 'w', 'a', 'r+', 'w+', 'a+'. Append 'b' for bytes mode.
Can be used as an iterator to iterate over lines in the file.
The default buffer size used for the BufferedInputStream used for reading
the file line by line is 1024 * 512 bytes.
"""
def write(self, file_content):
  """Writes file_content to the file. Appends to the end of the file."""
  # Lazily opens the underlying native writable file on first use and
  # verifies the FileIO object was opened in a writable mode.
  self._prewrite_check()
  # Delegate the actual write to the native (C++) file implementation.
  pywrap_tensorflow.AppendToFile(
      compat.as_bytes(file_content), self._writable_file)
這里調(diào)用write
函數(shù)分為兩步:
- 第一步
self._prewrite_check()
我們查看底層相關代碼:
def _prewrite_check(self):
  """Opens the underlying native writable file on first write."""
  if not self._writable_file:
    if not self._write_check_passed:
      # The object was constructed with a read-only mode.
      raise errors.PermissionDeniedError(None, None,
                                         "File isn't open for writing")
    # Create (or open for append, depending on mode) the native file handle.
    self._writable_file = pywrap_tensorflow.CreateWritableFile(
        compat.as_bytes(self.__name), compat.as_bytes(self.__mode))
前面是判斷,是否輸入的mode是合適的琐簇,然后直接進入:
self._writable_file = pywrap_tensorflow.CreateWritableFile(
compat.as_bytes(self.__name), compat.as_bytes(self.__mode))
# SWIG-generated shim: the def forwards to the native implementation, and is
# then replaced by a direct binding to the native symbol for speed.
def CreateWritableFile(filename, mode):
  return _pywrap_tensorflow_internal.CreateWritableFile(filename, mode)
CreateWritableFile = _pywrap_tensorflow_internal.CreateWritableFile
這里關鍵的函數(shù)是:CreateWritableFile
定位到底層C:在file_io.i文件中
// Opens `filename` for writing and returns an owned WritableFile*.
// A mode string containing 'a' opens in append mode; any other mode
// creates/truncates the file. On failure, `status` is set and null returned.
tensorflow::WritableFile* CreateWritableFile(
    const string& filename, const string& mode, TF_Status* status) {
  std::unique_ptr<tensorflow::WritableFile> file;
  tensorflow::Status s;
  if (mode.find("a") != std::string::npos) {
    s = tensorflow::Env::Default()->NewAppendableFile(filename, &file);
  } else {
    s = tensorflow::Env::Default()->NewWritableFile(filename, &file);
  }
  if (!s.ok()) {
    // Propagate the failure through the out-param status.
    Set_TF_Status_from_Status(status, s);
    return nullptr;
  }
  // Ownership of the file handle is transferred to the caller.
  return file.release();
}
使用a
進入s = tensorflow::Env::Default()->NewAppendableFile(filename, &file);
// POSIX backend: open `fname` with fopen(..., "a") (append; file is created
// if missing) and wrap the FILE* in a PosixWritableFile.
Status PosixFileSystem::NewAppendableFile(
    const string& fname, std::unique_ptr<WritableFile>* result) {
  string translated_fname = TranslateName(fname);
  Status s;
  FILE* f = fopen(translated_fname.c_str(), "a");
  if (f == nullptr) {
    // errno carries the reason for the fopen failure.
    s = IOError(fname, errno);
  } else {
    result->reset(new PosixWritableFile(translated_fname, f));
  }
  return s;
}
其他的進入s = tensorflow::Env::Default()->NewWritableFile(filename, &file);
// POSIX backend: open `fname` with fopen(..., "w") (truncate existing
// contents; create if missing) and wrap the FILE* in a PosixWritableFile.
Status PosixFileSystem::NewWritableFile(const string& fname,
                                        std::unique_ptr<WritableFile>* result) {
  string translated_fname = TranslateName(fname);
  Status s;
  FILE* f = fopen(translated_fname.c_str(), "w");
  if (f == nullptr) {
    // errno carries the reason for the fopen failure.
    s = IOError(fname, errno);
  } else {
    result->reset(new PosixWritableFile(translated_fname, f));
  }
  return s;
}
w 打開只寫文件蒸健,若文件存在則文件長度清為0座享,即該文件內(nèi)容會消失。若文件不存在則建立該文件似忧。
w+ 打開可讀寫文件渣叛,若文件存在則文件長度清為零,即該文件內(nèi)容會消失盯捌。若文件不存在則建立該文件淳衙。
用w不能讀只能寫,w+能力強一點饺著。
均在PosixFileSystem
中箫攀。
- 第二步
pywrap_tensorflow.AppendToFile(compat.as_bytes(file_content), self._writable_file)
第一步中將self._writable_file = None
參數(shù)更新為當前所操作的文件,之后將file_content與_writable_file作為參數(shù)傳入:
# SWIG-generated shim: forwards to the native AppendToFile, then is replaced
# by a direct binding to the native symbol.
def AppendToFile(file_content, file):
  return _pywrap_tensorflow_internal.AppendToFile(file_content, file)
AppendToFile = _pywrap_tensorflow_internal.AppendToFile
在file_io.i
中存在AppendToFile
底層源碼:
// Appends `file_content` to an already-open writable file; any failure is
// reported through the out-param `status`.
void AppendToFile(const string& file_content, tensorflow::WritableFile* file,
                  TF_Status* status) {
  tensorflow::Status s = file->Append(file_content);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
  }
}
tensorflow::Status s = file->Append(file_content);
// PosixWritableFile::Append: buffered fwrite of the whole chunk; a short
// write is surfaced as an IOError built from errno.
Status Append(StringPiece data) override {
  size_t r = fwrite(data.data(), 1, data.size(), file_);
  if (r != data.size()) {
    return IOError(filename_, errno);
  }
  return Status::OK();
}