WordpieceTokenizer && BertTokenizer
1 WordpieceTokenizer
class WordpieceTokenizer(TokenizerWithOffsets):
  """Tokenizes a tensor of UTF-8 string tokens into subword pieces."""
  def __init__(self,
               vocab_lookup_table,
               suffix_indicator='##',
               max_bytes_per_word=100,
               max_chars_per_token=None,
               token_out_type=dtypes.int64,
               unknown_token='[UNK]',
               split_unknown_characters=False):
    """Initializes the WordpieceTokenizer.

    NOTE(review): this listing reproduces only the constructor signature; the
    real body is omitted. The notes below are inferred from parameter names,
    defaults, and the example test cases, and should be confirmed against the
    full tensorflow_text source.

    Args:
      vocab_lookup_table: The vocabulary used to look up word pieces
        (presumably a lookup table or vocabulary source; confirm).
      suffix_indicator: Marker prefixed to pieces that continue a word, e.g.
        b"don't" -> b"don", b"##'", b"##t" in the example cases.
      max_bytes_per_word: Presumably the longest word (in bytes) that is
        segmented; longer words are likely mapped to `unknown_token`.
      max_chars_per_token: Presumably an upper bound on piece length in
        characters (None = no explicit bound); confirm.
      token_out_type: Presumably the output dtype (int64 ids vs. string
        pieces); confirm.
      unknown_token: Token emitted for a word that cannot be fully covered by
        vocabulary pieces (the example cases expect b"[UNK]").
      split_unknown_characters: Presumably controls whether unknown characters
        are split out individually instead of replacing the whole word.
    """
# Toy WordPiece vocabulary for the example test cases. Entries starting with
# "##" are continuation pieces; they are grouped here by the word they build.
_ENGLISH_VOCAB = [
    b"don", b"##'", b"##t",
    b"tread", b"##ness",
    b"hel", b"##lo",
    b"there",
    b"my",
    b"na", b"##me",
    b"is",
    b"ter", b"##ry",
    b"what", b"##cha", b"##ma", b"##call", b"##it?",
    b"you",
    b"said",
]
# NOTE(review): the dict(...) entries below appear to be items of a
# parameterized test-case list whose enclosing brackets are not part of this
# excerpt. Each entry gives the input `tokens`, the `expected_subwords` for
# every token, and the `vocab` to tokenize against.

# Continuation pieces ("##'", "##t", "##ness") extend the preceding piece.
dict(
    tokens=[[b"don't", b"treadness"]],
    expected_subwords=[[[b"don", b"##'", b"##t"], [b"tread", b"##ness"]]],
    vocab=_ENGLISH_VOCAB,
),
# A batch of two sentences in which every word splits into vocabulary pieces.
dict(
    tokens=[[b"hello", b"there", b"my", b"name", b"is", b"terry"],
            [b"whatchamacallit?", b"you", b"said"]],
    expected_subwords=[[[b"hel", b"##lo"], [b"there"], [b"my"],
                        [b"na", b"##me"], [b"is"], [b"ter", b"##ry"]],
                       [[b"what", b"##cha", b"##ma", b"##call", b"##it?"],
                        [b"you"], [b"said"]]],
    vocab=_ENGLISH_VOCAB,
),
# Basic case w/ unknown token
# A word that cannot be fully covered ("cantfindme", "treadcantfindme") maps
# to a single b"[UNK]", even when a prefix such as "tread" is in the vocab.
dict(
    tokens=[[b"don't", b"tread", b"cantfindme", b"treadcantfindme"]],
    expected_subwords=[[[b"don", b"##'", b"##t"], [b"tread"], [b"[UNK]"],
                        [b"[UNK]"]]],
    vocab=_ENGLISH_VOCAB,
),
2 BasicTokenizer
class BasicTokenizer(TokenizerWithOffsets):
  """Basic tokenizer for tokenizing text.

  A basic tokenizer that tokenizes using some deterministic rules:
  - For most languages, this tokenizer will split on whitespace.
  - For Chinese, Japanese, and Korean characters, this tokenizer will split on
    Unicode characters.

  Attributes:
    lower_case: bool - If true, a preprocessing step is added to case-fold the
      text, apply NFD normalization, and strip accents (nonspacing marks,
      Unicode category Mn).
    keep_whitespace: bool - If true, preserves whitespace characters instead of
      stripping them away.
    normalization_form: str or None - If not None and lower_case=False, the
      input text is normalized to `normalization_form` ('NFC', 'NFKC', 'NFD'
      or 'NFKD'; see normalize_utf8()). Ignored when lower_case=True.
  """

  def __init__(self,
               lower_case=False,
               keep_whitespace=False,
               normalization_form=None):
    self._lower_case = lower_case
    # Pattern selecting which matched delimiters are kept as tokens; keeping
    # whitespace presumably means keeping every delimiter match.
    if keep_whitespace:
      self._keep_delim_regex_pattern = _DELIM_REGEX_PATTERN
    else:
      self._keep_delim_regex_pattern = _KEEP_DELIM_NO_WHITESPACE_PATTERN
    self._normalization_form = normalization_form

  def tokenize(self, text_input):
    """Tokenizes `text_input` and returns only the tokens (offsets dropped)."""
    tokens, _, _ = self.tokenize_with_offsets(text_input)
    return tokens

  def tokenize_with_offsets(self, text_input):
    """Performs basic word tokenization for BERT.

    Args:
      text_input: A `Tensor` or `RaggedTensor` of untokenized UTF-8 strings.

    Returns:
      A tuple `(tokens, begin_offsets, end_offsets)` as produced by
      `regex_split_with_offsets`: the tokenized strings plus each token's
      start and end offsets. The offsets index into the preprocessed text
      (after case folding / normalization and control-character replacement),
      which may differ from the original input.
    """
    if self._lower_case:
      # Case-fold, decompose with NFD, then drop nonspacing marks so accented
      # characters reduce to their base letters.
      text_input = case_fold_utf8(text_input)
      text_input = normalize_utf8(text_input, "NFD")
      text_input = string_ops.regex_replace(text_input, r"\p{Mn}", "")
    elif self._normalization_form is not None:
      # Optional UTF-8 normalization; only applied when not lower-casing.
      text_input = normalize_utf8(text_input, self._normalization_form)

    # Replace control (Cc) and format (Cf) characters with a space.
    text_input = string_ops.regex_replace(text_input, r"\p{Cc}|\p{Cf}", " ")

    return regex_split_ops.regex_split_with_offsets(
        text_input, _DELIM_REGEX_PATTERN, self._keep_delim_regex_pattern,
        "BertBasicTokenizer")
2.1 CaseFoldUTF8Op
# pylint: disable=redefined-builtin
def case_fold_utf8(input, name=None):
  """Applies case folding to every UTF-8 string in the input.

  The input is a `Tensor` or `RaggedTensor` of any shape, and the resulting
  output has the same shape as the input. Note that NFKC normalization is
  implicitly applied to the strings (the kernel uses ICU's NFKC_Casefold).

  For example:

  >>> print(case_fold_utf8(['The Quick-Brown',
  ...                       'CAT jumped over',
  ...                       'the lazy dog !! ']))
  tf.Tensor([b'the quick-brown' b'cat jumped over' b'the lazy dog !! '], shape=(3,), dtype=string)

  Args:
    input: A `Tensor` or `RaggedTensor` of UTF-8 encoded strings.
    name: The name for this op (optional).

  Returns:
    A `Tensor` or `RaggedTensor` of type string, with case-folded contents.
  """
  with ops.name_scope(name, "CaseFoldUTF8", [input]):
    input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, dtype=dtypes.string)
    if ragged_tensor.is_ragged(input_tensor):
      # Fold only the flat string values and reuse the ragged row structure.
      result = gen_normalize_ops.case_fold_utf8(input_tensor.flat_values)
      return input_tensor.with_flat_values(result)
    else:
      return gen_normalize_ops.case_fold_utf8(input_tensor)
```c++
// Kernel for the CaseFoldUTF8 op: applies ICU's NFKC_Casefold normalization to
// every string of the input tensor, producing an output of the same shape.
class CaseFoldUTF8Op : public tensorflow::OpKernel {
 public:
  explicit CaseFoldUTF8Op(tensorflow::OpKernelConstruction* context)
      : tensorflow::OpKernel(context) {}

  void Compute(tensorflow::OpKernelContext* context) override {
    const tensorflow::Tensor* input_tensor;
    OP_REQUIRES_OK(context, context->input("input", &input_tensor));
    const auto& input_vec = input_tensor->flat<tstring>();

    // TODO(gregbillock): support forwarding
    tensorflow::Tensor* output_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor->shape(),
                                                     &output_tensor));
    auto output_vec = output_tensor->flat<tstring>();

    icu::ErrorCode icu_error;
    const icu::Normalizer2* nfkc_cf =
        icu::Normalizer2::getNFKCCasefoldInstance(icu_error);
    OP_REQUIRES(context, icu_error.isSuccess(),
                errors::Internal(absl::StrCat(
                    icu_error.errorName(),
                    ": Could not retrieve ICU NFKC_CaseFold normalizer")));

    for (int64 i = 0; i < input_vec.size(); ++i) {
      string output_text;
      icu::StringByteSink<string> byte_sink(&output_text);
      const auto& input = input_vec(i);
      nfkc_cf->normalizeUTF8(0, icu::StringPiece(input.data(), input.size()),
                             byte_sink, nullptr, icu_error);
      // Build the message with absl::StrCat over a string_view instead of
      // `const char* + tstring` (which relies on tstring acting like
      // std::string), and report ICU's error name for diagnosis.
      OP_REQUIRES(
          context, !U_FAILURE(icu_error),
          errors::Internal(absl::StrCat(icu_error.errorName(),
                                        ": Could not normalize input string: ",
                                        absl::string_view(input_vec(i)))));
      output_vec(i) = output_text;
    }
  }
};
2.2 NormalizeUTF8Op
def normalize_utf8(input, normalization_form="NFKC", name=None):
  """Normalizes each UTF-8 string in `input` with the given Unicode form.

  See http://unicode.org/reports/tr15/ for the normalization forms.

  Args:
    input: A `Tensor` or `RaggedTensor` of UTF-8 encoded strings.
    normalization_form: Which form to apply: 'NFC', 'NFKC', 'NFD' or 'NFKD'.
      Defaults to 'NFKC'.
    name: Optional name for the op.

  Returns:
    A `Tensor` or `RaggedTensor` of strings shaped like `input`, holding the
    normalized text.
  """
  with ops.name_scope(name, "NormalizeUTF8", [input]):
    converted = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, dtype=dtypes.string)
    if not ragged_tensor.is_ragged(converted):
      return gen_normalize_ops.normalize_utf8(converted, normalization_form)
    # Ragged input: normalize the flat values and keep the row partitioning.
    normalized_values = gen_normalize_ops.normalize_utf8(
        converted.flat_values, normalization_form)
    return converted.with_flat_values(normalized_values)
// Kernel for the NormalizeUTF8 op: applies the ICU normalizer selected by the
// op's normalization form (NFC, NFKC, NFD or NFKD) to every string of the
// input tensor, producing an output of the same shape.
class NormalizeUTF8Op : public tensorflow::OpKernel {
 public:
  explicit NormalizeUTF8Op(tensorflow::OpKernelConstruction* context)
      : tensorflow::OpKernel(context),
        normalization_form_(GetNormalizationForm(context)) {}

  void Compute(tensorflow::OpKernelContext* context) override {
    const tensorflow::Tensor* input_tensor;
    OP_REQUIRES_OK(context, context->input("input", &input_tensor));
    const auto& input_vec = input_tensor->flat<tstring>();

    tensorflow::Tensor* output_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor->shape(),
                                                     &output_tensor));
    auto output_vec = output_tensor->flat<tstring>();

    // Select the ICU normalizer for the requested form; an unknown form is
    // rejected as InvalidArgument before any retrieval error is checked.
    icu::ErrorCode icu_error;
    const icu::Normalizer2* normalizer = nullptr;
    if (normalization_form_ == "NFKC") {
      normalizer = icu::Normalizer2::getNFKCInstance(icu_error);
    } else if (normalization_form_ == "NFC") {
      normalizer = icu::Normalizer2::getNFCInstance(icu_error);
    } else if (normalization_form_ == "NFD") {
      normalizer = icu::Normalizer2::getNFDInstance(icu_error);
    } else if (normalization_form_ == "NFKD") {
      normalizer = icu::Normalizer2::getNFKDInstance(icu_error);
    } else {
      OP_REQUIRES(
          context, false,
          errors::InvalidArgument(absl::StrCat(
              "Unknown normalization form requested: ", normalization_form_)));
    }
    OP_REQUIRES(context, icu_error.isSuccess(),
                errors::Internal(absl::StrCat(
                    icu_error.errorName(), ": Could not retrieve ICU ",
                    normalization_form_, " normalizer")));

    for (int64 i = 0; i < input_vec.size(); ++i) {
      string output_text;
      icu::StringByteSink<string> byte_sink(&output_text);
      const auto& input = input_vec(i);
      normalizer->normalizeUTF8(0, icu::StringPiece(input.data(), input.size()),
                                byte_sink, nullptr, icu_error);
      OP_REQUIRES(
          context, !U_FAILURE(icu_error),
          errors::Internal(absl::StrCat(icu_error.errorName(),
                                        ": Could not normalize input string: ",
                                        absl::string_view(input_vec(i)))));
      output_vec(i) = output_text;
    }
  }

 private:
  string normalization_form_;
};

REGISTER_KERNEL_BUILDER(Name("NormalizeUTF8").Device(tensorflow::DEVICE_CPU),
                        NormalizeUTF8Op);
2.3 regex_replace 依賴開源庫 github.com/google/re2 (depends on the open-source RE2 library)
def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
  r"""Replaces matches of the `pattern` regular expression in `input` with the
  replacement string provided in `rewrite`.

  It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax).

  NOTE(review): only the signature and docstring of this wrapper are
  reproduced here; the implementation (presumably dispatching to the
  RegexReplace kernel) is omitted from this listing.

  Args:
    input: A `Tensor` of type `string`. The text to be processed.
    pattern: A scalar `Tensor` of type `string`. The regular expression to be
      matched in the `input` strings. The kernel rejects non-scalar patterns.
    rewrite: A scalar `Tensor` of type `string`. The rewrite string to be
      substituted for the `pattern` expression where it is matched in the
      `input` strings. The kernel rejects non-scalar rewrites.
    replace_global: An optional `bool`. Defaults to `True`. If True, the
      replacement is global (that is, all matches of the `pattern` regular
      expression in each input string are rewritten); otherwise the `rewrite`
      substitution is only made for the first `pattern` match.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `string`.
  """
// Kernel for the RegexReplace op. Validates that the `pattern` and `rewrite`
// inputs are scalar strings, compiles `pattern` with RE2, and hands the actual
// replacement work to InternalCompute (defined elsewhere) along with the
// `replace_global` attribute.
class RegexReplaceOp : public OpKernel {
 public:
  explicit RegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    // Read once at construction; presumably selects replace-all vs.
    // replace-first semantics inside InternalCompute.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
  }
  void Compute(OpKernelContext* ctx) override {
    // `pattern` must be a scalar string tensor.
    const Tensor* pattern_tensor;
    OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
                errors::InvalidArgument("Pattern must be scalar, but received ",
                                        pattern_tensor->shape().DebugString()));
    const string& pattern = pattern_tensor->scalar<tstring>()();
    // The regex is compiled on every Compute call; a pattern RE2 rejects is
    // reported as InvalidArgument together with RE2's error text.
    const RE2 match(pattern);
    OP_REQUIRES(ctx, match.ok(),
                errors::InvalidArgument("Invalid pattern: ", pattern,
                                        ", error: ", match.error()));
    // `rewrite` must also be a scalar string tensor.
    const Tensor* rewrite_tensor;
    OP_REQUIRES_OK(ctx, ctx->input("rewrite", &rewrite_tensor));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rewrite_tensor->shape()),
                errors::InvalidArgument("Rewrite must be scalar, but received ",
                                        rewrite_tensor->shape().DebugString()));
    const string& rewrite = rewrite_tensor->scalar<tstring>()();
    OP_REQUIRES_OK(ctx, InternalCompute(match, rewrite, replace_global_, ctx));
  }
 private:
  bool replace_global_;  // Value of the "replace_global" attribute.
};
目錄:/home/henry/.local/miniconda3/envs/tf3.7.5/lib/python3.7/site-packages/tensorflow_core/include/external/com_googlesource_code_re2/re2
// Copyright 2003-2009 The RE2 Authors. All Rights Reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#ifndef RE2_RE2_H_
#define RE2_RE2_H_
// C++ interface to the re2 regular-expression library.
// RE2 supports Perl-style regular expressions (with extensions like
// \d, \w, \s, ...).
//
// -----------------------------------------------------------------------
// REGEXP SYNTAX:
//
// This module uses the re2 library and hence supports
// its syntax for regular expressions, which is similar to Perl's with
// some of the more complicated things thrown away. In particular,
// backreferences and generalized assertions are not available, nor is \Z.
//
// See https://github.com/google/re2/wiki/Syntax for the syntax
// supported by RE2, and a comparison with PCRE and PERL regexps.
//
// For those not familiar with Perl's regular expressions,
// here are some examples of the most commonly used extensions:
//
// "hello (\\w+) world" -- \w matches a "word" character
// "version (\\d+)" -- \d matches a digit
// "hello\\s+world" -- \s matches any whitespace character
// "\\b(\\w+)\\b" -- \b matches non-empty string at word boundary
// "(?i)hello" -- (?i) turns on case-insensitive matching
// "/\\*(.*?)\\*/" -- .*? matches . minimum no. of times possible
//
// -----------------------------------------------------------------------
2.4 regex_split_with_offsets
// Kernel for RegexSplit: splits every string of the flattened `input` tensor
// via RegexSplit() using `delim_regex_pattern` (optionally keeping delimiters
// selected by `keep_delim_regex_pattern`) and emits the flat components of
// RaggedTensors: `tokens`, `begin_offsets`, `end_offsets` and `row_splits`.
class RegexSplitOp : public tensorflow::OpKernel {
 public:
  explicit RegexSplitOp(tensorflow::OpKernelConstruction* ctx)
      : tensorflow::OpKernel(ctx) {}

  void Compute(tensorflow::OpKernelContext* ctx) override {
    std::unique_ptr<RE2> delim_re;
    GetRegexFromInput(ctx, "delim_regex_pattern", &delim_re);
    // GetRegexFromInput takes `ctx` and presumably reports failures through it
    // (possibly leaving the pointer unset); stop before dereferencing it.
    if (!ctx->status().ok()) return;
    std::unique_ptr<RE2> keep_delim_re;
    GetRegexFromInput(ctx, "keep_delim_regex_pattern", &keep_delim_re);
    if (!ctx->status().ok()) return;
    // An empty keep-pattern disables keeping delimiters.
    const bool should_keep_delim = !keep_delim_re->pattern().empty();

    const Tensor* input_tensor;
    OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
    const auto& input_flat = input_tensor->flat<tstring>();

    std::vector<int64> begin_offsets;
    std::vector<int64> end_offsets;
    std::vector<absl::string_view> tokens;
    // row_splits[i]..row_splits[i + 1] bound the tokens of input string i
    // (counted by the number of begin offsets produced so far).
    std::vector<int64> row_splits;
    row_splits.push_back(0);
    for (size_t i = 0; i < input_flat.size(); ++i) {
      const auto& input = input_flat(i);
      // Pass the length explicitly: a view built from data() alone is sized
      // with strlen and would stop at the first NUL byte, truncating the
      // input and corrupting the offsets.
      RegexSplit(absl::string_view(input.data(), input.size()), *delim_re,
                 should_keep_delim, *keep_delim_re, &tokens, &begin_offsets,
                 &end_offsets);
      row_splits.push_back(begin_offsets.size());
    }

    // Emit the flat Tensors needed to construct RaggedTensors for tokens,
    // start, end offsets.
    std::vector<int64> tokens_shape;
    tokens_shape.push_back(tokens.size());
    std::vector<int64> offsets_shape;
    offsets_shape.push_back(begin_offsets.size());
    std::vector<int64> row_splits_shape;
    row_splits_shape.push_back(row_splits.size());

    Tensor* output_tokens_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("tokens", TensorShape(tokens_shape),
                                        &output_tokens_tensor));
    auto output_tokens = output_tokens_tensor->flat<tstring>();

    Tensor* output_begin_offsets_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("begin_offsets", TensorShape(offsets_shape),
                                  &output_begin_offsets_tensor));
    auto output_begin_offsets = output_begin_offsets_tensor->flat<int64>();

    Tensor* output_end_offsets_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("end_offsets", TensorShape(offsets_shape),
                                  &output_end_offsets_tensor));
    auto output_end_offsets = output_end_offsets_tensor->flat<int64>();

    Tensor* output_row_splits_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("row_splits", TensorShape(row_splits_shape),
                                  &output_row_splits_tensor));
    auto output_row_splits = output_row_splits_tensor->flat<int64>();

    // Copy outputs to Tensors. `tokens` views point into the input tensor,
    // so each one is copied into an owned tstring here.
    for (size_t i = 0; i < tokens.size(); ++i) {
      const auto& token = tokens[i];
      output_tokens(i) = tstring(token.data(), token.length());
    }
    for (size_t i = 0; i < begin_offsets.size(); ++i) {
      output_begin_offsets(i) = begin_offsets[i];
    }
    for (size_t i = 0; i < end_offsets.size(); ++i) {
      output_end_offsets(i) = end_offsets[i];
    }
    for (size_t i = 0; i < row_splits.size(); ++i) {
      output_row_splits(i) = row_splits[i];
    }
  }
};
SplitMergeTokenizer
def test_split_merge_tokenizer():
  """Prints SplitMergeTokenizer results for several label patterns.

  Each label corresponds to one character of the input: a 0 starts a new token
  at that character and a 1 appends the character to the current token.
  """
  tokenizer = text.SplitMergeTokenizer()
  text_input = "WelcomeToChina"  # 14 characters, so 14 labels per call.

  # All 1s: nothing is split -> [b'WelcomeToChina'].
  labels = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
  print(tokenizer.tokenize(text_input, labels))
  # A 0 at 'e' (index 1) -> [b'W', b'elcomeToChina'].
  labels = [1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
  print(tokenizer.tokenize(text_input, labels))
  # 0s at 'T' (index 7) and 'C' (index 9) -> [b'Welcome', b'To', b'China'].
  labels = [1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1]
  print(tokenizer.tokenize(text_input, labels))
  # All 0s: every character becomes its own token.
  labels = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  print(tokenizer.tokenize(text_input, labels))

  # Batched input. Keep `strings` a plain list: a trailing comma here would
  # wrap it in a 1-tuple and add an extra outer dimension to the result.
  strings = ["HelloMonday", "DearFriday"]
  # Label rows share one length so they form a dense batch; the trailing 0 of
  # the second row has no matching character in the 10-character "DearFriday".
  labels = [[0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1],
            [0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0]]
  # Third positional argument: presumably `force_split_at_break_character`.
  print(tokenizer.tokenize(strings, labels, False))
運行結果 (output):
tf.Tensor([b'WelcomeToChina'], shape=(1,), dtype=string)
tf.Tensor([b'W' b'elcomeToChina'], shape=(2,), dtype=string)
tf.Tensor([b'Welcome' b'To' b'China'], shape=(3,), dtype=string)
tf.Tensor([b'W' b'e' b'l' b'c' b'o' b'm' b'e' b'T' b'o' b'C' b'h' b'i' b'n' b'a'], shape=(14,), dtype=string)
<tf.RaggedTensor [[[b'Hello', b'Monday'], [b'Dear', b'Friday']]]>