在模型實(shí)際的應(yīng)用中苏研,一般有兩種使用方法,一個(gè)是跑批數(shù)據(jù)欲诺,就像我們之前跑驗(yàn)證集那樣抄谐。比如說(shuō)我們收集到了很多需要去分類(lèi)的圖像,然后一次性的導(dǎo)入并使用我們訓(xùn)練好的模型給出結(jié)果扰法,預(yù)測(cè)完這一批之后程序就自動(dòng)關(guān)閉了蛹含,等到下一次我們有需要的時(shí)候再啟動(dòng)。另外一種就是應(yīng)用于線上服務(wù)塞颁,構(gòu)建一個(gè)服務(wù)等待新的請(qǐng)求浦箱,當(dāng)有請(qǐng)求發(fā)起的時(shí)候就接收數(shù)據(jù),然后給出結(jié)果祠锣,在沒(méi)有請(qǐng)求的時(shí)候酷窥,模型服務(wù)仍然處于運(yùn)行的狀態(tài),只不過(guò)是等待下一個(gè)請(qǐng)求伴网。
Flask框架
關(guān)于一次性處理批數(shù)據(jù)蓬推,我們前面的流程基本可以滿(mǎn)足了,這里介紹一個(gè)在線實(shí)時(shí)服務(wù)澡腾。FLask框架是一個(gè)用Python編寫(xiě)的Web微服務(wù)框架沸伏,F(xiàn)lask的使用十分簡(jiǎn)單,在日常開(kāi)發(fā)中可以快速地實(shí)現(xiàn)一個(gè)Web服務(wù)动分,而且靈活度很高毅糟。
首先安裝Flask。
pip install Flask
等待安裝完之后澜公,就可以編寫(xiě)代碼了姆另,假設(shè)我們寫(xiě)一個(gè)python腳本名字是flask_hello_world.py,內(nèi)容如下
from flask import Flask
app = Flask(__name__)
@app.route("/hello")
def hello():
return "Hello World!"
if __name__ == '__main__':
app.run()
然后在shell里面運(yùn)行它玛瘸,這里我們?cè)趓un方法里面沒(méi)有設(shè)置參數(shù)蜕青,就會(huì)使用默認(rèn)的127.0.0.1 host地址和5000端口苟蹈,啟動(dòng)成功可以看到下面的顯示
這個(gè)時(shí)候在瀏覽器中打開(kāi)它糊渊,輸入127.0.0.1:5000/hello,即可看到輸出的結(jié)果“Hello World!”慧脱,這就完成了一個(gè)最簡(jiǎn)單的web服務(wù)渺绒。
如果要讓它實(shí)現(xiàn)模型運(yùn)算,重點(diǎn)就是去修改hello方法。
import numpy as np
import sys
import os
import torch
from flask import Flask, request, jsonify
import json
from p2ch13.model_cls import LunaModel
app = Flask(__name__)
#加載模型
model = LunaModel()
model.load_state_dict(torch.load(sys.argv[1],
map_location='cpu')['model_state'])
model.eval()
#運(yùn)行推理部分
def run_inference(in_tensor):
with torch.no_grad():
# LunaModel 接收批量數(shù)據(jù)并輸出一個(gè)元組 (scores, probs)
out_tensor = model(in_tensor.unsqueeze(0))[1].squeeze(0)
probs = out_tensor.tolist()
out = {'prob_malignant': probs[1]}
return out
@app.route("/predict", methods=["POST"])
#預(yù)測(cè)方法的邏輯
def predict():
#使用request接收數(shù)據(jù)
meta = json.load(request.files['meta'])
blob = request.files['blob'].read()
#轉(zhuǎn)換成tensor
in_tensor = torch.from_numpy(np.frombuffer(
blob, dtype=np.float32))
in_tensor = in_tensor.view(*meta['shape'])
#推理宗兼,輸出
out = run_inference(in_tensor)
#返回結(jié)果
return jsonify(out)
if __name__ == '__main__':
app.run()
print (sys.argv[1])
這樣就已經(jīng)寫(xiě)好了最簡(jiǎn)單的服務(wù)代碼躏鱼,然后運(yùn)行它
這時(shí)候我們就已經(jīng)啟動(dòng)了web服務(wù),當(dāng)然我們這里處理的比較簡(jiǎn)單殷绍,在真實(shí)場(chǎng)景下通常都是后臺(tái)運(yùn)行染苛,并且要增加日志輸出和報(bào)警系統(tǒng),防止出現(xiàn)各種問(wèn)題而服務(wù)中斷主到。然后模擬客戶(hù)端向服務(wù)端發(fā)送請(qǐng)求茶行,很快就得到了結(jié)果,當(dāng)然這里有一份預(yù)先準(zhǔn)備好的數(shù)據(jù)登钥,不然光數(shù)據(jù)處理就要花好多時(shí)間畔师。
可以看到惡性腫瘤的可能性不大。到這里牧牢,我們就完成了一個(gè)簡(jiǎn)單的模型部署流程淆珊,當(dāng)然台囱,這里只是一個(gè)單一的服務(wù),如果我們?cè)诠ぷ髦行枰玫讲l(fā)服務(wù),異步服務(wù)可以在這個(gè)基礎(chǔ)上進(jìn)行修改慢洋,或者搭配其他的工具。比如說(shuō)要實(shí)現(xiàn)并發(fā)服務(wù)负甸,我們可以在服務(wù)器上啟動(dòng)多個(gè)服務(wù)虹蓄,然后搭配N(xiāo)ginx實(shí)現(xiàn)負(fù)載均衡。
Sanic框架
然后我們?cè)賮?lái)介紹一個(gè)異步處理框架Sanic±校現(xiàn)在是一個(gè)高并發(fā)的時(shí)代蹬蚁,并發(fā)量是在構(gòu)建服務(wù)時(shí)必須考量的一個(gè)指標(biāo)。所以我們自然就想到了 Python 中的異步框架郑兴,Sanic 的表現(xiàn)十分出色犀斋,使用 Sanic 構(gòu)建的應(yīng)用程序足以比肩 Nodejs。如果你再對(duì) Sanic 在路由處理方面使用 C 語(yǔ)言做一些重構(gòu)情连,那么并發(fā)性能可以和 Go 相媲美叽粹。
異步并發(fā)的流程大概像上圖描述的樣子,多個(gè)客戶(hù)端發(fā)起請(qǐng)求却舀,這些請(qǐng)求會(huì)進(jìn)入一個(gè)任務(wù)隊(duì)列虫几,然后這些任務(wù)的數(shù)據(jù)組成一個(gè)批數(shù)據(jù)傳給模型,模型給出預(yù)測(cè)結(jié)果挽拔,然后由請(qǐng)求處理器拆分結(jié)果并分別回傳給不同的客戶(hù)端辆脸。使用這種方式有助于提高我們的模型工作效率。
首先安裝Sanic螃诅。
pip install sanic
接下來(lái)就是使用sanic完成一個(gè)異步服務(wù)啡氢。我們這里使用的是把馬變成斑馬的模型状囱。來(lái)看看代碼,首先是一些引用項(xiàng)倘是。
import sys
import asyncio
import itertools
import functools
from sanic import Sanic
from sanic.response import json, text
from sanic.log import logger
from sanic.exceptions import ServerError
import sanic
import threading
import PIL.Image
import io
import torch
import torchvision
from .cyclegan import get_pretrained_model
定義一些全局變量或者參數(shù)亭枷。
#實(shí)例sanic
app = Sanic(__name__)
#設(shè)置使用的設(shè)備為cpu
device = torch.device('cpu')
# we only run 1 inference run at any time (one could schedule between several runners if desired)
MAX_QUEUE_SIZE = 3 # 隊(duì)列最大長(zhǎng)度
MAX_BATCH_SIZE = 2 # 批數(shù)據(jù)的最大長(zhǎng)度
MAX_WAIT = 1 # 最大等待時(shí)間
異常處理類(lèi)
class HandlingError(Exception):
def __init__(self, msg, code=500):
super().__init__()
self.handling_code = code
self.handling_msg = msg
模型運(yùn)行類(lèi)
class ModelRunner:
def __init__(self, model_name):
#首先是模型運(yùn)行的初始化
self.model_name = model_name
#聲明使用的隊(duì)列
self.queue = []
#聲明隊(duì)列鎖
self.queue_lock = None
#加載模型
self.model = get_pretrained_model(self.model_name,
map_location=device)
#是否運(yùn)行的標(biāo)記
self.needs_processing = None
#是否使用計(jì)時(shí)器
self.needs_processing_timer = None
調(diào)度運(yùn)行信號(hào)處理
def schedule_processing_if_needed(self):
#判斷隊(duì)列長(zhǎng)度是否已經(jīng)超過(guò)批大小
if len(self.queue) >= MAX_BATCH_SIZE:
logger.debug("next batch ready when processing a batch")
#如果隊(duì)列長(zhǎng)度夠長(zhǎng),把運(yùn)行標(biāo)記設(shè)置為需要運(yùn)行
self.needs_processing.set()
#否則判斷搀崭,如果隊(duì)列不為空叨粘,查看計(jì)時(shí)器
elif self.queue:
logger.debug("queue nonempty when processing a batch, setting next timer")
self.needs_processing_timer = app.loop.call_at(self.queue[0]["time"] + MAX_WAIT, self.needs_processing.set)
處理輸入數(shù)據(jù)并判斷是否需要運(yùn)行
async def process_input(self, input):
our_task = {"done_event": asyncio.Event(loop=app.loop),
"input": input,
"time": app.loop.time()}
async with self.queue_lock:
if len(self.queue) >= MAX_QUEUE_SIZE:
raise HandlingError("I'm too busy", code=503)
self.queue.append(our_task)
logger.debug("enqueued task. new queue size {}".format(len(self.queue)))
self.schedule_processing_if_needed()
#等等處理完成
await our_task["done_event"].wait()
return our_task["output"]
運(yùn)行模型
def run_model(self, batch):
return self.model(batch.to(device)).to('cpu')
async def model_runner(self):
self.queue_lock = asyncio.Lock(loop=app.loop)
self.needs_processing = asyncio.Event(loop=app.loop)
logger.info("started model runner for {}".format(self.model_name))
#while True 無(wú)限循環(huán),程序會(huì)處于監(jiān)聽(tīng)狀態(tài)
while True:
#等待有任務(wù)來(lái)
await self.needs_processing.wait()
self.needs_processing.clear()
#清空計(jì)時(shí)器
if self.needs_processing_timer is not None:
self.needs_processing_timer.cancel()
self.needs_processing_timer = None
#處理隊(duì)列都開(kāi)啟鎖
async with self.queue_lock:
#如果隊(duì)列不為空則設(shè)置最長(zhǎng)等待時(shí)間
if self.queue:
longest_wait = app.loop.time() - self.queue[0]["time"]
else: # oops
longest_wait = None
#日志記錄啟動(dòng)處理瘤睹,隊(duì)列大小宣鄙,等待時(shí)間
logger.debug("launching processing. queue size: {}. longest wait: {}".format(len(self.queue), longest_wait))
#獲取一個(gè)批次的數(shù)據(jù)
to_process = self.queue[:MAX_BATCH_SIZE]
#然后把這些數(shù)據(jù)從任務(wù)隊(duì)列中刪除
del self.queue[:len(to_process)]
self.schedule_processing_if_needed()
#生成批數(shù)據(jù)
batch = torch.stack([t["input"] for t in to_process], dim=0)
#在一個(gè)單獨(dú)的線程中運(yùn)行模型,然后返回結(jié)果
result = await app.loop.run_in_executor(
None, functools.partial(self.run_model, batch)
)
#記錄結(jié)果并設(shè)置一個(gè)完成事件
for t, r in zip(to_process, result):
t["output"] = r
t["done_event"].set()
del to_process
類(lèi)實(shí)例化
style_transfer_runner = ModelRunner(sys.argv[1])
最后是處理網(wǎng)絡(luò)交互
#路由策略
@app.route('/image', methods=['PUT'], stream=True)
#處理請(qǐng)求
async def image(request):
try:
#輸出報(bào)頭
print (request.headers)
content_length = int(request.headers.get('content-length', '0'))
#定義接收數(shù)據(jù)最大值
MAX_SIZE = 2**22 # 10MB
#如果接收數(shù)據(jù)超標(biāo)返回異常信息
if content_length:
if content_length > MAX_SIZE:
raise HandlingError("Too large")
#初始化數(shù)據(jù)接收
data = bytearray(content_length)
else:
data = bytearray(MAX_SIZE)
pos = 0
#這里也是True默蚌,一直處于監(jiān)聽(tīng)狀態(tài)
while True:
#讀取數(shù)據(jù)包
data_part = await request.stream.read()
if data_part is None:
break
#數(shù)據(jù)包拼接到data里面
data[pos: len(data_part) + pos] = data_part
pos += len(data_part)
if pos > MAX_SIZE:
raise HandlingError("Too large")
#然后開(kāi)始對(duì)接收的圖像數(shù)據(jù)進(jìn)行預(yù)處理
im = PIL.Image.open(io.BytesIO(data))
im = torchvision.transforms.functional.resize(im, (228, 228))
im = torchvision.transforms.functional.to_tensor(im)
im = im[:3] # drop alpha channel if present
if im.dim() != 3 or im.size(0) < 3 or im.size(0) > 4:
raise HandlingError("need rgb image")
#使用實(shí)例化的模型程序處理圖像
out_im = await style_transfer_runner.process_input(im)
#結(jié)果轉(zhuǎn)化為圖像信息
out_im = torchvision.transforms.functional.to_pil_image(out_im)
imgByteArr = io.BytesIO()
out_im.save(imgByteArr, format='JPEG')
return sanic.response.raw(imgByteArr.getvalue(), status=200,
content_type='image/jpeg')
except HandlingError as e:
# we don't want these to be logged...
return sanic.response.text(e.handling_msg, status=e.handling_code)
啟動(dòng)服務(wù)部分
app.add_task(style_transfer_runner.model_runner())
app.run(host="0.0.0.0", port=8000,debug=True)
看完代碼冻晤,我們把它啟動(dòng)起來(lái)。
使用curl把圖像數(shù)據(jù)傳到web服務(wù)中绸吸,并設(shè)定了輸出結(jié)果到res1.jpg中
去對(duì)應(yīng)的位置查看鼻弧,果然新生成了一張圖片,可見(jiàn)我們的服務(wù)運(yùn)行良好锦茁。
當(dāng)然這里弄的兩個(gè)實(shí)現(xiàn)方案都挺簡(jiǎn)單的攘轩,不過(guò)核心部分基本都介紹到了,在實(shí)際的工作中就是在這個(gè)基礎(chǔ)上修修補(bǔ)補(bǔ)敲敲打打差不多就可以滿(mǎn)足需求码俩。
歷時(shí)一個(gè)半月度帮,終于把這本書(shū)看完了,英文原版寫(xiě)的挺好稿存,由淺入深笨篷,但是這個(gè)翻譯實(shí)在是有點(diǎn)爛,有需要英文原版電子書(shū)的留下郵箱瓣履。