本文主要介紹如何使用tensorflow serving官方提供的hashmap sourceadapter游两。其實理解了如何使用這個hashmapsourceadapter
垃环,也就真正能對serving進行二次開發(fā)缩滨,網(wǎng)上幾乎找不到任何相關(guān)資料(stackoverflow上的問題Tensorflow-serving: Serving a hashmap翘县,而為了真正讓這個hashmapsourceadapter
有用帜羊,我把serving的代碼全部看了一遍晨横。
這里一共分為6個步驟洋腮,第7步為接口測試。
首先手形,官方的代碼只是定義了 HashmapSourceAdapter
啥供。沒有注冊。
1 定義 HashmapSourceAdapterCreator
如下
// register this source adapter
class HashmapSourceAdapterCreator {
public:
static Status Create(
const HashmapSourceAdapterConfig& config,
std::unique_ptr<SourceAdapter<StoragePath, std::unique_ptr<Loader>>>*
adapter) {
adapter->reset(new HashmapSourceAdapter(config));
return Status::OK();
}
};
同時库糠,把這個類加為 HashmapSourceAdapter
的友元伙狐。
private:
friend class HashmapSourceAdapterCreator;
2 注冊 HashmapSourceAdapter
REGISTER_STORAGE_PATH_SOURCE_ADAPTER(HashmapSourceAdapterCreator,
HashmapSourceAdapterConfig);
3 使用hashmap servable
這一步非常重要,當(dāng)然這一步不是非要我這么做,但這是最簡單的方法贷屎。添加這個步驟的原因在于罢防,標(biāo)準(zhǔn)的C++編譯程序時,如果一個文件中的代碼如果沒有被調(diào)用唉侄,它就會被編譯器優(yōu)化掉咒吐。所以,這其實是一個hack属划。
第一步恬叹,在 hashmap_source_adapter.h 定義一個函數(shù)
void loadHashmapServable();
第二步,在hashmap_source_adapter.cc 中實現(xiàn)這個函數(shù)同眯。
void loadHashmapServable() {
LOG(INFO) << "load hashmap servable...";
}
第三步妄呕,在main函數(shù)的開頭調(diào)用這個函數(shù)
tensorflow::serving::loadHashmapServable();
4 http中添加使用hashmap servable的接口
這一步,我們會在原有http接口的基礎(chǔ)上嗽测,添加一個lookup接口绪励。
一, 在ProcessRequest
中添加分支lookup
} else if (method == "lookup") {
status = ProcessLookupRequest(model_name, model_version, request_body,
output);
}
二唠粥,定義函數(shù) ProcessLookupRequest
Status HttpRestApiHandler::ProcessLookupRequest(
const absl::string_view model_name,
const absl::optional<int64>& model_version,
const absl::string_view request_body, string* output) {
ModelSpec model_spec;
model_spec.set_name(string(model_name));
if (model_version.has_value()) {
model_spec.mutable_version()->set_value(model_version.value());
}
ServableHandle<std::unordered_map<string, string>> bundle;
TF_RETURN_IF_ERROR(core_->GetServableHandle(model_spec, &bundle));
std::unordered_map<std::string, std::string>::const_iterator got = bundle->find(request_body.data());
if (got == bundle->end()) {
output->assign(string("None"));
} else {
output->assign(got->second);
}
return Status::OK();
}
三疏魏,放開URL正則匹配的限制
prediction_api_regex_(
R"((?i)/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict|lookup))"),
5 給hashmap servable加載的文件添加一個文件名
我們程序啟動時會去模型目錄下加載一個名為 data.csv
的文件。
const string fpath = io::JoinPath(path, "data.csv");
std::unique_ptr<RandomAccessFile> file;
TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(fpath, &file));
該文件的格式如下:
key0,value0
key1,value1
tom,jerry
pete,henry
hello,world
good,bye
6 從配置文件啟動TF serving
因為加入了hashmap servable晤愧,tensorflow serving不止支持一個platform大莫,當(dāng)tensorflow serving支持多個platform的時候需要從配置文件啟動,命令如下:
tensorflow_model_server --port=8500 --rest_api_port=8501 --platform_config_file=./etc/platform.conf --model_config_file=./etc/models.conf
其中官份,platform.conf內(nèi)容如下:
platform_configs {
key: "hashmap"
value {
source_adapter_config {
type_url: "type.googleapis.com/tensorflow.serving.HashmapSourceAdapterConfig"
}
}
}
platform_configs {
key: "tensorflow"
value {
source_adapter_config {
type_url: "type.googleapis.com/tensorflow.serving.SavedModelBundleSourceAdapterConfig"
value: "\302>\002\022\000"
}
}
}
models.conf內(nèi)容如下:
model_config_list: {
config: {
name: "tensorflow",
base_path: "/data/models/tensorflow",
model_platform: "tensorflow",
model_version_policy: {
all: {}
}
}
config: {
name: "hash",
base_path: "/data/models/hash",
model_platform: "hashmap",
model_version_policy: {
all: {}
}
}
}
platform.conf的編寫可以參考這個issue
how to write a tensorflow serving platform_config_file
7 測試
訪問接口
curl -d 'hello' -X POST http://localhost:8501/v1/models/hash/versions/1:lookup
輸出
world