def get_shape_list(tensor, expected_rank=None, name=None)
参数:
tensor:一个需要返回shape的tf.Tensor。
expected_rank:int,或者是一个int的list。表示输入tensor期望的rank(也就是张量的维数);如果输入tensor的rank不等于这个数,或者不是这个list的元素之一,会抛出异常(ValueError)。
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import tensorflow as tf
import six
def get_shape_list(tensor, expected_rank=None, name=None):
    """Returns a list of the shape of tensor, preferring static dimensions.

    Args:
        tensor: A tf.Tensor object to find the shape of.
        expected_rank: (optional) int, or list of ints. The expected rank(s) of
            `tensor`. If this is specified and the rank of `tensor` is not
            (one of) the expected value(s), an exception will be thrown.
        name: Optional name of the tensor for the error message. If None,
            `assert_rank` falls back to `tensor.name`.

    Returns:
        A list of dimensions of the shape of tensor. All static dimensions will
        be returned as python integers, and dynamic dimensions will be returned
        as tf.Tensor scalars.
    """
    # `name` is only used for assert_rank's error message, so it is forwarded
    # as-is instead of reading `tensor.name` up front (which fails for objects
    # without a usable `.name`, e.g. eager tensors).
    if expected_rank is not None:
        assert_rank(tensor, expected_rank, name)

    shape = tensor.shape.as_list()

    # Positions whose size is unknown statically (reported as None).
    non_static_indexes = [index for index, dim in enumerate(shape) if dim is None]
    if not non_static_indexes:
        return shape

    # Replace each unknown dimension with the matching entry of the dynamic
    # shape tensor, which is only resolved at run time.
    dyn_shape = tf.shape(tensor)
    for index in non_static_indexes:
        shape[index] = dyn_shape[index]
    return shape
def assert_rank(tensor, expected_rank, name=None):
    """Raises an exception if the tensor rank is not of the expected rank.

    Args:
        tensor: A tf.Tensor to check the rank of.
        expected_rank: Python integer or list of integers, expected rank.
        name: Optional name of the tensor for the error message. Defaults to
            `tensor.name`.

    Raises:
        ValueError: If the expected shape doesn't match the actual shape.
    """
    if isinstance(expected_rank, six.integer_types):
        expected_ranks = {expected_rank}
    else:
        expected_ranks = set(expected_rank)

    # ndims is None when the rank itself is unknown; None is never an
    # expected rank, so that case falls through to the error below.
    actual_rank = tensor.shape.ndims
    if actual_rank in expected_ranks:
        return

    # Error-message inputs are only resolved on failure. getattr() keeps the
    # documented ValueError even if `tensor.name` is unavailable (it raises
    # AttributeError for eager tensors).
    if name is None:
        name = getattr(tensor, "name", None)
    # tf.compat.v1 is used because the top-level tf.get_variable_scope was
    # removed in TF2; the file already relies on tf.compat.v1 elsewhere.
    scope_name = tf.compat.v1.get_variable_scope().name
    # `%s` (not `%d`) so an unknown rank (None) is formatted instead of raising
    # TypeError; integer ranks render identically either way.
    raise ValueError(
        "For the tensor `%s` in scope `%s`, the actual rank "
        "`%s` (shape = %s) is not equal to the expected rank `%s`" %
        (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
if __name__ == '__main__':
    # Build the demo inside an explicit graph. Under TF2's default eager mode,
    # tf.compat.v1.placeholder raises, and eager tensors have no usable `.name`
    # (which get_shape_list/assert_rank read for error messages). Under TF1
    # this behaves exactly like the default graph.
    with tf.Graph().as_default():
        zero = tf.constant([0.0, 0.1, 0.2])
        zero_shape = get_shape_list(zero, expected_rank=[1, 3])
        print('zero_shape:')
        print(zero_shape)

        one = tf.constant([[0.0, 0.1, 0.2], [0.0, 0.1, 0.2]])
        one_shape = get_shape_list(one, expected_rank=[2, 3])
        print('one_shape:')
        print(one_shape)

        two = tf.compat.v1.placeholder(tf.int32, [32, 512])
        two_shape = get_shape_list(two, expected_rank=2)
        print('two_shape:')
        print(two_shape)
输出结果:
zero_shape:
[3]
one_shape:
[2, 3]
two_shape:
[32, 512]