Background
We noticed that when fastAPI and pytorch are used together, defining an endpoint without async leads to what looks like a memory leak. Let's walk through the fastAPI code to see where the difference actually lies. The related GitHub issue is https://github.com/tiangolo/fastapi/issues/596
Walking through the fastAPI / uvicorn code
When a REST endpoint is called, execution reaches the __call__() method of class Router in starlette.routing.py, which performs URL matching. If the request goes through the default URL matching, these few lines are all we need; the rest below is unimportant.
starlette.routing.py class Router
```python
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
    """
    The main entry point to the Router class.
    """
    assert scope["type"] in ("http", "websocket", "lifespan")

    if "router" not in scope:
        scope["router"] = self

    # Lifespan controls server startup and shutdown; not relevant here.
    if scope["type"] == "lifespan":
        await self.lifespan(scope, receive, send)
        return

    partial = None

    for route in self.routes:
        # Determine if any route matches the incoming scope,
        # and hand over to the matching route if found.
        match, child_scope = route.matches(scope)
        if match == Match.FULL:
            scope.update(child_scope)
            # A full match ends up here, calling the endpoint
            # implementation and wrapping the HTTP request.
            await route.handle(scope, receive, send)
            return
        elif match == Match.PARTIAL and partial is None:
            partial = route
            partial_scope = child_scope
```
The route instance here is actually an instance of class APIRoute from fastapi.routing.py, but APIRoute does not override this dispatch method, so the entries in self.routes, registered through decorators when the ASGI app was initialized, behave just like plain starlette Route instances. The corresponding handle implementation is as follows.
starlette.routing.py class Route
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.methods and scope["method"] not in self.methods:
if "app" in scope:
raise HTTPException(status_code=405)
else:
response = PlainTextResponse("Method Not Allowed", status_code=405)
await response(scope, receive, send)
else:
await self.app(scope, receive, send)
In fastAPI, the route object is class APIRoute(routing.Route) in fastapi.routing.py, a subclass of starlette's Route; its app attribute is initialized as follows.
fastapi.routing.py class APIRoute
```python
class APIRoute(routing.Route):
    def __init__(self, ...):
        # ... other attribute initialization omitted ...
        self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
        self.app = request_response(self.get_route_handler())

    def get_route_handler(self) -> Callable:
        return get_request_handler(
            dependant=self.dependant,
            body_field=self.body_field,
            status_code=self.status_code,
            response_class=self.response_class or JSONResponse,
            response_field=self.secure_cloned_response_field,
            response_model_include=self.response_model_include,
            response_model_exclude=self.response_model_exclude,
            response_model_by_alias=self.response_model_by_alias,
            response_model_exclude_unset=self.response_model_exclude_unset,
            response_model_exclude_defaults=self.response_model_exclude_defaults,
            response_model_exclude_none=self.response_model_exclude_none,
            dependency_overrides_provider=self.dependency_overrides_provider,
        )
```
The following is the actual HTTP request handling:
fastapi.routing.py
```python
def get_request_handler(
    dependant: Dependant,
    body_field: ModelField = None,
    status_code: int = 200,
    response_class: Type[Response] = JSONResponse,
    response_field: ModelField = None,
    response_model_include: Union[SetIntStr, DictIntStrAny] = None,
    response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    dependency_overrides_provider: Any = None,
) -> Callable:
    assert dependant.call is not None, "dependant.call must be a function"
    is_coroutine = asyncio.iscoroutinefunction(dependant.call)
    is_body_form = body_field and isinstance(get_field_info(body_field), params.Form)

    async def app(request: Request) -> Response:
        try:
            body = None
            if body_field:
                if is_body_form:
                    body = await request.form()
                else:
                    body_bytes = await request.body()
                    if body_bytes:
                        body = await request.json()
        except Exception as e:
            logger.error(f"Error getting request body: {e}")
            raise HTTPException(
                status_code=400, detail="There was an error parsing the body"
            ) from e
        solved_result = await solve_dependencies(
            request=request,
            dependant=dependant,
            body=body,
            dependency_overrides_provider=dependency_overrides_provider,
        )
        values, errors, background_tasks, sub_response, _ = solved_result
        if errors:
            raise RequestValidationError(errors, body=body)
        else:
            # Your REST endpoint implementation is invoked here.
            raw_response = await run_endpoint_function(
                dependant=dependant, values=values, is_coroutine=is_coroutine
            )

            if isinstance(raw_response, Response):
                if raw_response.background is None:
                    raw_response.background = background_tasks
                return raw_response
            response_data = await serialize_response(
                field=response_field,
                response_content=raw_response,
                include=response_model_include,
                exclude=response_model_exclude,
                by_alias=response_model_by_alias,
                exclude_unset=response_model_exclude_unset,
                exclude_defaults=response_model_exclude_defaults,
                exclude_none=response_model_exclude_none,
                is_coroutine=is_coroutine,
            )
            response = response_class(
                content=response_data,
                status_code=status_code,
                background=background_tasks,
            )
            response.headers.raw.extend(sub_response.headers.raw)
            if sub_response.status_code:
                response.status_code = sub_response.status_code
            return response

    return app
```
starlette.routing.py
```python
def request_response(func: typing.Callable) -> ASGIApp:
    """
    Takes a function or coroutine `func(request) -> response`,
    and returns an ASGI application.
    """
    is_coroutine = asyncio.iscoroutinefunction(func)

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        request = Request(scope, receive=receive, send=send)
        # In fastAPI, func is the `app` coroutine function returned by
        # get_request_handler, so is_coroutine is always True here.
        if is_coroutine:
            response = await func(request)
        else:
            response = await run_in_threadpool(func, request)
        await response(scope, receive, send)

    return app
```
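Note that is_coroutine is decided once, at wrap time, purely by how the handler is defined. A quick stdlib check:

```python
import asyncio

async def async_handler(request):
    return "ok"

def sync_handler(request):
    return "ok"

# iscoroutinefunction inspects how the function is defined, not how it is called.
print(asyncio.iscoroutinefunction(async_handler))  # True
print(asyncio.iscoroutinefunction(sync_handler))   # False
```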
We have now seen that fastAPI drives the endpoint implementation through the dependant object, so next let's look at how the dependant object is initialized.
fastapi.dependencies.utils.py
```python
def get_dependant(
    *,
    path: str,
    call: Callable,
    name: str = None,
    security_scopes: List[str] = None,
    use_cache: bool = True,
) -> Dependant:
    path_param_names = get_path_param_names(path)
    endpoint_signature = get_typed_signature(call)
    signature_params = endpoint_signature.parameters
    if inspect.isgeneratorfunction(call) or inspect.isasyncgenfunction(call):
        check_dependency_contextmanagers()
    dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
    for param_name, param in signature_params.items():
        if isinstance(param.default, params.Depends):
            sub_dependant = get_param_sub_dependant(
                param=param, path=path, security_scopes=security_scopes
            )
            dependant.dependencies.append(sub_dependant)
    for param_name, param in signature_params.items():
        if isinstance(param.default, params.Depends):
            continue
        if add_non_field_param_to_dependency(param=param, dependant=dependant):
            continue
        param_field = get_param_field(
            param=param, default_field_info=params.Query, param_name=param_name
        )
        if param_name in path_param_names:
            assert is_scalar_field(
                field=param_field
            ), f"Path params must be of one of the supported types"
            if isinstance(param.default, params.Path):
                ignore_default = False
            else:
                ignore_default = True
            param_field = get_param_field(
                param=param,
                param_name=param_name,
                default_field_info=params.Path,
                force_type=params.ParamTypes.path,
                ignore_default=ignore_default,
            )
            add_param_to_fields(field=param_field, dependant=dependant)
        elif is_scalar_field(field=param_field):
            add_param_to_fields(field=param_field, dependant=dependant)
        elif isinstance(
            param.default, (params.Query, params.Header)
        ) and is_scalar_sequence_field(param_field):
            add_param_to_fields(field=param_field, dependant=dependant)
        else:
            field_info = get_field_info(param_field)
            assert isinstance(
                field_info, params.Body
            ), f"Param: {param_field.name} can only be a request body, using Body(...)"
            dependant.body_params.append(param_field)
    return dependant
```
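The core of the scan above is just signature introspection. A toy sketch (classify and endpoint are illustrative names of mine, not FastAPI's):

```python
import inspect

def classify(func, path_param_names):
    # Split an endpoint's parameters into path params and everything else,
    # mimicking the `param_name in path_param_names` branch above.
    sig = inspect.signature(func)
    path_params, other = [], []
    for name in sig.parameters:
        (path_params if name in path_param_names else other).append(name)
    return path_params, other

def endpoint(item_id: int, q: str = None):
    return item_id, q

print(classify(endpoint, {"item_id"}))  # (['item_id'], ['q'])
```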
So this is just initialization and validation of path parameters and the like; nothing special here. Let's go straight to the invocation logic.
```python
async def run_endpoint_function(
    *, dependant: Dependant, values: Dict[str, Any], is_coroutine: bool
) -> Any:
    # Only called by get_request_handler. Has been split into its own function to
    # facilitate profiling endpoints, since inner functions are harder to profile.
    assert dependant.call is not None, "dependant.call must be a function"
    if is_coroutine:
        return await dependant.call(**values)
    else:
        return await run_in_threadpool(dependant.call, **values)
```
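The effect of this branch can be reproduced with a stdlib-only sketch (call_endpoint, sync_ep, and async_ep are illustrative names of mine, not FastAPI's):

```python
import asyncio
from functools import partial

async def call_endpoint(func, **values):
    # Mirrors run_endpoint_function: await coroutine functions directly,
    # push plain functions onto the default thread pool.
    if asyncio.iscoroutinefunction(func):
        return await func(**values)
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, partial(func, **values))

def sync_ep(x):
    return x * 2

async def async_ep(x):
    return x * 2

async def main():
    return await call_endpoint(sync_ep, x=21), await call_endpoint(async_ep, x=21)

print(asyncio.run(main()))  # (42, 42)
```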
OK置谦,這里就可以知道fastAPI定義rest接口加不加async有什么區(qū)別了堂鲤,一個是直接協(xié)程調(diào)用,不加async走了run_in_threadpool
```python
async def run_in_threadpool(
    func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
) -> T:
    loop = asyncio.get_event_loop()
    if contextvars is not None:  # pragma: no cover
        # Ensure we run in the same context
        child = functools.partial(func, *args, **kwargs)
        context = contextvars.copy_context()
        func = context.run
        args = (child,)
    elif kwargs:  # pragma: no cover
        # loop.run_in_executor doesn't accept 'kwargs', so bind them in here
        func = functools.partial(func, **kwargs)
    return await loop.run_in_executor(None, func, *args)
```
Here we can see that execution ultimately still runs on the uvloop event loop via loop.run_in_executor(None, func, *args), so this is the step to start from when checking whether pytorch leaks memory when run together with uvloop.
Current conclusion: when run_in_executor is called on the event loop without specifying an executor, the default executor's worker count is the CPU count × 5. Worker threads do not release their resources after finishing a task, but once the pool is full, memory should in theory stop growing.
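The "threads are kept alive" part is easy to observe with the standard library alone (the exact worker cap depends on the Python version: cpu_count() × 5 before 3.8, min(32, cpu_count() + 4) from 3.8 on):

```python
import asyncio
import threading
import time

def task():
    time.sleep(0.05)

async def main():
    loop = asyncio.get_event_loop()
    before = threading.active_count()
    # Submit several tasks at once so the default executor spawns workers.
    await asyncio.gather(*(loop.run_in_executor(None, task) for _ in range(8)))
    return before, threading.active_count()

before, after = asyncio.run(main())
print(after > before)  # True: workers stay alive after their tasks finish
```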
Here is the test code I used:
```python
import asyncio
import cv2 as cv
import gc
from pympler import tracker
from concurrent import futures

executor = futures.ThreadPoolExecutor(max_workers=1)
memory_tracker = tracker.SummaryTracker()

def mm():
    img = cv.imread("cap.jpg", 0)
    detector = cv.AKAZE_create()
    kpts, desc = detector.detectAndCompute(img, None)
    gc.collect()
    memory_tracker.print_diff()
    return None

async def main():
    while True:
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(executor, mm)

if __name__ == '__main__':
    loop = asyncio.get_event_loop()
    loop.run_until_complete(main())
```
My test machine has 40 CPUs, so in theory the worker cap of the default thread pool is 200. If you cap the executor size explicitly (as in the code above), memory stays stable with no leak. But if you do what fastAPI does, memory keeps climbing for roughly the first 200 iterations and then levels off. If what you run in the thread pool is a very large model, a cap of 200 is far too high and can consume an enormous amount of memory.
Conclusion: if you use fastAPI to serve very large deep-learning models and deploy on a machine with many CPUs, it really will consume a lot of memory. This is not a memory leak, though, since there is an upper bound. Still, it would be good if starlette made the thread-pool size configurable; otherwise too much memory gets eaten. For now, allocating only a small number of CPUs to the service when packaging it into a container works around the problem.
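Besides limiting container CPUs, another workaround worth considering (my suggestion, not a starlette feature; infer is a hypothetical stand-in for a heavy model call) is to install a bounded default executor on the event loop before serving, so that every run_in_executor(None, ...) call, including starlette's, is capped:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

# Hypothetical stand-in for a heavy model inference call.
def infer():
    return "done"

loop = asyncio.new_event_loop()
# Cap the loop's default executor at 4 worker threads.
loop.set_default_executor(ThreadPoolExecutor(max_workers=4))
result = loop.run_until_complete(loop.run_in_executor(None, infer))
loop.close()
print(result)  # done
```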
Also note that Python 3.8 already caps the default thread-pool size, as shown below, so if you are on Python 3.8 you don't need to worry about this.
```python
if max_workers is None:
    # ThreadPoolExecutor is often used to:
    # * CPU bound task which releases GIL
    # * I/O bound task (which releases GIL, of course)
    #
    # We use cpu_count + 4 for both types of tasks.
    # But we limit it to 32 to avoid consuming surprisingly large resource
    # on many core machine.
    max_workers = min(32, (os.cpu_count() or 1) + 4)
if max_workers <= 0:
    raise ValueError("max_workers must be greater than 0")
```
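A quick way to see the capped default on your own interpreter (_max_workers is a private attribute, used here only for illustration):

```python
from concurrent.futures import ThreadPoolExecutor

ex = ThreadPoolExecutor()  # max_workers=None -> the capped default on 3.8+
print(0 < ex._max_workers <= 32)  # True on Python 3.8+
ex.shutdown()
```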