For categorical features, TensorFlow first converts each value to a string and then hashes it. The code is as follows:
template <uint64 hash(StringPiece)>
class StringToHashBucketOp : public OpKernel {
 public:
  explicit StringToHashBucketOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor* input_tensor;
    OP_REQUIRES_OK(context, context->input("input", &input_tensor));
    const auto& input_flat = input_tensor->flat<string>();

    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("output", input_tensor->shape(),
                                            &output_tensor));
    auto output_flat = output_tensor->flat<int64>();

    typedef decltype(input_flat.size()) Index;
    for (Index i = 0; i < input_flat.size(); ++i) {
      const uint64 input_hash = hash(input_flat(i));
      const uint64 bucket_id = input_hash % num_buckets_;
      // The number of buckets is always in the positive range of int64 so is
      // the resulting bucket_id. Casting the bucket_id from uint64 to int64 is
      // safe.
      output_flat(i) = static_cast<int64>(bucket_id);
    }
  }

 private:
  int64 num_buckets_;

  TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp);
};
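Note that this defines a class template. Its template parameter hash is not a type name. It is a non-type template parameter: a function that takes a StringPiece and returns a uint64. The template is instantiated as follows: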
REGISTER_KERNEL_BUILDER(Name("StringToHashBucketFast").Device(DEVICE_CPU),
                        StringToHashBucketOp<Fingerprint64>);
The class template above is defined in the header file string_to_hash_bucket_op.h, and the kernel registration is in string_to_hash_bucket_op.cc.
So the hash function used here is Fingerprint64, which comes from FarmHash, Google's open-source hash library.
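The following two lines are therefore equivalent, where x stands for the string input_flat(i). Fingerprint64 simply passes the string's data() and size() to FarmHash:
const uint64 input_hash = hash(input_flat(i));
and
const uint64_t input_hash = NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(x.data(), x.size());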