新增的功能
在《簡易協程-2》的基礎上增加協程同步等待、IO超時的支持。
增加一個新類JoinAction支持協程同步等待：yield這個類的對象會讓協程進入等待狀態，直到目標協程退出或者超時。使用示例如下。
# 生成并運行另外一個協(xié)程c
c = cf1()
Scheduler.add(c)
t1 = time()
# 等待c完成，超時時間0.5秒；結果是is_timeout，如果為True則表示等待超時了
is_timeout = yield JoinAction(c, 0.5)
IO超時的實現是在SocketIO中增加了超時時間參數timeout，單位也是秒（負數表示不設超時）。如果請求的事件未能在給定時間內到達，則調度器會在協程內拋出一個socket.timeout異常。示例如下。
# require to write data, timeout is 5s
yield SocketIO(sock.fileno(), read=False, timeout=5)
sock.send("data")
yield SocketIO(sock.fileno(), read=True, timeout=5)
data = sock.recv(1024)
完整代碼
以下是詳細代碼。
#!/usr/bin/env python
# coding: utf-8
from collections import deque
from errno import ETIMEDOUT
from heapq import heappop
from heapq import heappush
from itertools import chain
from select import select
from socket import timeout as SocketTimeoutError
from sys import exc_info
from sys import maxint
from time import sleep
from time import time
from types import GeneratorType
class Sleep(object):
    """Yield an instance of this class to suspend the current coroutine.

    The scheduler resumes the coroutine after roughly ``seconds`` seconds.
    """
    __slots__ = ("seconds",)

    def __init__(self, seconds):
        # type: (float) -> None
        # NOTE(review): negative values are not rejected here
        self.seconds = seconds
class SocketIO(object):
    """Yield an instance of this class to wait until a socket fd is ready.

    :param sock_fd: file descriptor, e.g. ``sock.fileno()``
    :param read: True to wait for readability, False for writability
    :param timeout: seconds to wait; a negative value means no timeout
    """
    __slots__ = ("sock_fd", "read", "timeout")

    def __init__(self, sock_fd, read=True, timeout=-1):
        self.sock_fd, self.read, self.timeout = sock_fd, read, timeout
class JoinAction(object):
    """Yield an instance of this class to wait for another coroutine to exit.

    The ``yield`` expression evaluates to True if the wait timed out and
    to False if the target exited (or had already exited).
    """
    __slots__ = ("target_coroutine", "timeout")

    def __init__(self, target_coroutine, timeout=-1.0):
        # type: (GeneratorType, float) -> None
        """
        :param target_coroutine: the generator object that was handed to the scheduler
        :param timeout: seconds to wait; a negative value waits forever
        """
        self.target_coroutine, self.timeout = target_coroutine, timeout
class Coroutine(object):
    """Scheduler-side wrapper around a generator.

    ``init_value`` is what the next ``run()`` sends into the generator; after
    ``run()`` it holds the value the generator yielded.  A non-empty
    ``exception_info`` (an ``exc_info()`` triple) is thrown into the generator
    instead of sending a value.  ``parent`` is the coroutine that yielded this
    one as a sub-coroutine, or None for a top-level coroutine.
    """
    __slots__ = ["generator", "parent", "init_value", "exception_info", "name"]

    def __init__(self, generator, parent=None, init_value=None, exception_info=(), name=""):
        # type: (GeneratorType, Coroutine, object, tuple, str) -> None
        self.generator = generator
        self.parent = parent
        self.init_value = init_value
        self.exception_info = exception_info
        # fall back to the generator function's name
        self.name = name or generator.gi_code.co_name

    def __str__(self):
        return "%s.%s" % (self.name, self.cid())

    __repr__ = __str__

    def cid(self):
        """Return the coroutine id: ``id()`` of the wrapped generator."""
        return id(self.generator)

    def reset_input(self, value=None, exception_info=()):
        """Set what the next ``run()`` sends (or throws) into the generator."""
        self.init_value, self.exception_info = value, exception_info

    def run(self):
        """Resume the generator once and return the value it yields.

        Any exception raised by the generator, including StopIteration,
        propagates to the caller.
        """
        pending = self.exception_info
        if not pending:
            yielded = self.generator.send(self.init_value)
        else:
            yielded = self.generator.throw(*pending)
            self.exception_info = ()
        self.init_value = yielded
        return yielded
class CoroutineError(Exception):
    """Scheduler misuse, e.g. a coroutine joining itself or a nested run()."""
class FakeSocket(object):
    """In-memory stand-in for a socket, intended for use with fake_select().

    ``send`` records the last payload; ``recv`` always returns the same
    minimal HTTP response with an empty body.
    """
    __slots__ = ("data",)

    def __init__(self):
        self.data = ""

    def fileno(self):
        # the object's identity doubles as a unique fake descriptor
        return id(self)

    def send(self, data):
        self.data = data
        return len(self.data)

    def recv(self, _):
        canned = "HTTP/1.1 200 OK\r\nContent-Length:0\r\n\r\n"
        return canned
from random import random
# NOTE(review): next_time is not referenced anywhere in the visible code.
next_time = {}


def fake_select(rlist, wlist, xlist, timeout):
    """Pretend every watched descriptor is immediately ready.

    Same call shape as ``select.select``; ``xlist`` and ``timeout`` are
    ignored and no exceptional descriptors are ever reported.
    """
    return [fd for fd in rlist], [fd for fd in wlist], []
# Kinds of pending timer, stored in _TimeoutItem.wait_type:
#   WAIT_CANCELED - skipped when its millisecond slot expires
#   WAIT_SOCKET   - I/O timeout; arg is the socket fd
#   WAIT_JOIN     - JoinAction timeout; arg is (waiter_cid, target_cid)
#   WAIT_SLEEP    - Sleep wake-up; arg is the sleeping Coroutine
WAIT_CANCELED, WAIT_SOCKET, WAIT_JOIN, WAIT_SLEEP = range(4)
class _TimeoutItem(object):
    """One-shot timer stored in Scheduler.timer_slots_map.

    :param till: deadline in whole milliseconds since the scheduler's start_time
    :param wait_type: one of the WAIT_* constants
    :param arg: payload whose meaning depends on wait_type
    """

    def __init__(self, till, wait_type, arg):
        # type: (int, int, object) -> None
        self.till, self.wait_type, self.arg = till, wait_type, arg
        # identity-based key inside a millisecond slot
        self.id = id(self)
class Scheduler(object):
    """Single-threaded cooperative scheduler for generator-based coroutines.

    A coroutine is a generator registered with ``Scheduler.add()``.  The
    value it yields tells the scheduler what to do next:

    * ``None``       -- let other coroutines run, then resume this one
    * a generator    -- run it as a sub-coroutine; the parent is resumed
                        with its result, or has its exception thrown in
    * ``SocketIO``   -- wait until the fd is readable / writable
    * ``Sleep``      -- wait for the given number of seconds
    * ``JoinAction`` -- wait for another coroutine to exit; the yield
                        evaluates to True on timeout, False otherwise
    * anything else  -- the coroutine is finished; for a sub-coroutine the
                        value is delivered to its parent

    Timers have millisecond granularity: ``millisecond_heap`` is a min-heap
    of deadlines and ``timer_slots_map`` maps each deadline to the
    ``_TimeoutItem`` objects due at that millisecond.  Cancelling a timer
    only removes it from its slot; a slot left empty is dropped lazily when
    its deadline is popped from the heap.
    """
    _instance = None

    def __init__(self, ignore_exception=True, debug=False):
        """
        :param ignore_exception: ignore a top-level coroutine's uncaught
            exception; if False the exception propagates out of run()
        :param debug: output running detail
        """
        self.ignore_exception = ignore_exception
        self.debug = debug
        # if true, append debug logs to _debug_logs; else, print them to stdout
        self.collect_debug_logs = False
        self._debug_logs = []
        # use fake_select() to test performance or simulation, work with FakeSocket
        self.use_fake_select = False
        # reference point of every millisecond deadline
        self.start_time = time()
        # map: coroutine_id -> Coroutine (top-level and sub-coroutines)
        self.cid2coroutine = {}
        # coroutines to resume in the next round
        self.queue = deque()
        # map: sock_fd -> [coroutine, timeout_item]
        self.sock_map = {}
        # fds watched for readability / writability
        self.io_read_queue = set()
        self.io_write_queue = set()
        # map: target_cid -> {waiter_cid: timeout_item}, filled by JoinAction
        self.waiting_map = {}
        # map: deadline millisecond (int) -> {item_id: _TimeoutItem}
        self.timer_slots_map = {}
        # min-heap of the deadlines present in timer_slots_map
        self.millisecond_heap = []
        # number of top-level coroutines that have not exited yet
        self.alive_coroutine_num = 0
        # coroutine currently being resumed
        self.current = None
        # whether run() is executing
        self.running = False

    @classmethod
    def get_instance(cls):
        # type: () -> Scheduler
        """Return the shared scheduler, creating it on first use."""
        if not cls._instance:
            cls._instance = cls()
        return cls._instance

    def _debug_output(self, msg, *args):
        if not self.debug:
            return
        if self.collect_debug_logs:
            self._debug_logs.append(("%.6f" % time(), msg % args))
        else:
            # a single pre-formatted argument prints identically on Python 2 and 3
            print("%.6f %s" % (time(), msg % args))

    def _add(self, generator):
        """Register ``generator`` as a new top-level coroutine and queue it."""
        co = Coroutine(generator)
        cid = co.cid()
        self.cid2coroutine[cid] = co
        self.alive_coroutine_num += 1
        self._debug_output("add new coroutine %d, alive_coroutine_num=%d",
                           cid, self.alive_coroutine_num)
        self.queue.append(co)
        return self

    def _coroutine_exit(self, coroutine, is_error):
        # type: (Coroutine, bool) -> None
        """Unregister a finished coroutine and resume whoever waits for it.

        Every JoinAction waiter of ``coroutine`` is resumed with False (not
        timed out) and its join timer is cancelled.  A sub-coroutine's parent
        is resumed with the sub-coroutine's last value or, if ``is_error``,
        has the exception currently being handled thrown into it.
        """
        cid = coroutine.cid()
        assert cid in self.cid2coroutine
        # wake waiters of any coroutine, including a sub-coroutine target
        waiters = self.waiting_map.pop(cid, None)
        if waiters:
            self._debug_output("%s wake up %d waiters", cid, len(waiters))
            for wcid, timeout_item in waiters.items():
                waiter = self.cid2coroutine[wcid]
                waiter.reset_input(False)
                self.queue.append(waiter)
                self._cancel_timeout(timeout_item)
        parent = coroutine.parent
        if parent is None:
            self.alive_coroutine_num -= 1
        elif is_error:
            parent.reset_input(None, exc_info())
            self.queue.append(parent)
        else:
            parent.reset_input(coroutine.init_value, ())
            self.queue.append(parent)
        self.cid2coroutine.pop(cid)
        self._debug_output("coroutine %d exited, alive_coroutine_num=%d",
                           cid, self.alive_coroutine_num)

    def _current_coroutine(self):
        # type: () -> Coroutine
        return self.current

    @classmethod
    def current_id(cls):
        return cls.get_instance()._current_coroutine().cid()

    @classmethod
    def current_name(cls):
        return cls.get_instance()._current_coroutine().name

    def _add_timeout(self, seconds, wait_type, arg):
        # type: (float, int, object) -> _TimeoutItem
        """Register a one-shot timer due ``seconds`` from now.

        A negative ``seconds`` puts the item in the ``maxint`` slot, so it
        effectively never fires and can only be cancelled.
        """
        if seconds >= 0:
            # round to the nearest millisecond since start_time
            till = int(1000 * (time() - self.start_time + seconds + 0.0005))
        else:
            till = maxint
        self._debug_output('coroutine add a timeout task at %sms from start', till)
        timeout_item = _TimeoutItem(till, wait_type, arg)
        slot = self.timer_slots_map.get(till)
        if slot is None:
            # new deadline: it must also enter the heap
            self.timer_slots_map[till] = {timeout_item.id: timeout_item}
            heappush(self.millisecond_heap, till)
        else:
            slot[timeout_item.id] = timeout_item
        return timeout_item

    def _cancel_timeout(self, timeout_item):
        # type: (_TimeoutItem) -> None
        """Remove a pending timer from its millisecond slot.

        The deadline stays in millisecond_heap; if the slot ends up empty it
        is discarded when _process_sleep_queue pops that deadline.
        """
        slot = self.timer_slots_map.get(timeout_item.till)
        if slot is not None:
            slot.pop(timeout_item.id, None)

    def _do_coroutine_io(self, coroutine, event):
        # type: (Coroutine, SocketIO) -> None
        """Suspend ``coroutine`` until ``event.sock_fd`` is ready or times out."""
        coroutine.reset_input()
        sock_fd = event.sock_fd
        if event.read:
            self.io_read_queue.add(sock_fd)
        else:
            self.io_write_queue.add(sock_fd)
        timeout_item = self._add_timeout(event.timeout, WAIT_SOCKET, sock_fd)
        self.sock_map[sock_fd] = [coroutine, timeout_item]

    def _do_coroutine_sleep(self, coroutine, seconds):
        """Suspend ``coroutine`` for ``seconds`` seconds."""
        coroutine.reset_input()
        # a negative duration would map to the "never" slot and sleep forever
        timeout_item = self._add_timeout(max(seconds, 0), WAIT_SLEEP, coroutine)
        self._debug_output('coroutine go to sleep until %s', timeout_item.till)

    def _do_coroutine_join(self, coroutine, event):
        # type: (Coroutine, JoinAction) -> None
        """Suspend ``coroutine`` until the join target exits or the join times out."""
        target_cid = id(event.target_coroutine)
        timeout = event.timeout
        cid = coroutine.cid()
        if cid == target_cid:
            try:
                raise CoroutineError("can't join self")
            except CoroutineError:
                coroutine.reset_input(None, exc_info())
            self.queue.append(coroutine)
        elif target_cid not in self.cid2coroutine:
            # target already exited: the join ends at once, not timed out
            coroutine.reset_input(False)
            self.queue.append(coroutine)
        elif 0 <= timeout < 0.001:
            # below timer resolution: report a timeout right away
            coroutine.reset_input(True)
            self.queue.append(coroutine)
        else:
            self._debug_output("coroutine %s try to join %s, timeout=%f",
                               cid, target_cid, timeout)
            timeout_item = self._add_timeout(timeout, WAIT_JOIN, (cid, target_cid))
            self.waiting_map.setdefault(target_cid, {})[cid] = timeout_item

    # noinspection PyBroadException
    def _process_running_queue(self):
        """Resume every coroutine queued so far once, dispatching on what it yields."""
        old_queue = self.queue
        self.queue = deque()
        append = self.queue.append
        while old_queue:
            coroutine = old_queue.popleft()
            self.current = coroutine
            try:
                value = coroutine.run()
            except StopIteration:
                self._coroutine_exit(coroutine, False)
                continue
            except:
                # any exception (not only Exception subclasses) is forwarded
                # to the parent of a sub-coroutine
                self._coroutine_exit(coroutine, True)
                if coroutine.parent is None and not self.ignore_exception:
                    self._debug_output("%s raise uncaught exception", coroutine.cid())
                    # keep the not-yet-resumed coroutines schedulable
                    self.queue.extend(old_queue)
                    raise
                continue
            if value is None:
                # yield to other coroutines
                append(coroutine)
            elif isinstance(value, GeneratorType):
                sub = Coroutine(value, coroutine)
                self.cid2coroutine[sub.cid()] = sub
                append(sub)
            elif isinstance(value, SocketIO):
                self._do_coroutine_io(coroutine, value)
            elif isinstance(value, Sleep):
                self._do_coroutine_sleep(coroutine, value.seconds)
            elif isinstance(value, JoinAction):
                self._do_coroutine_join(coroutine, value)
            else:
                # any other value ends the coroutine; it is its result
                self._coroutine_exit(coroutine, False)
        self.current = None

    def _join_timed_out(self, arg):
        """Resume a JoinAction waiter with True after its timer fired."""
        waiting_cid, target_cid = arg
        waiters = self.waiting_map[target_cid]
        self._debug_output("coroutine %s join %s time out", waiting_cid, target_cid)
        del waiters[waiting_cid]
        if not waiters:
            del self.waiting_map[target_cid]
        waiter = self.cid2coroutine[waiting_cid]
        waiter.reset_input(True)
        self.queue.append(waiter)
        self._debug_output("%s timeout on join", waiter)

    def _socket_timed_out(self, timeout_item):
        """Stop watching the fd and throw socket.timeout into its owner."""
        sock_fd = timeout_item.arg
        self._debug_output("socket %s io timeout", sock_fd)
        entry = self.sock_map.get(sock_fd)
        # fd no longer watched, or re-registered under a different timer
        if entry is None or entry[1] is not timeout_item:
            return
        del self.sock_map[sock_fd]
        self.io_read_queue.discard(sock_fd)
        self.io_write_queue.discard(sock_fd)
        coroutine = entry[0]
        if coroutine.cid() not in self.cid2coroutine:
            return
        try:
            raise SocketTimeoutError(ETIMEDOUT, "timeout")
        except SocketTimeoutError:
            coroutine.reset_input(None, exc_info())
        self.queue.append(coroutine)
        self._debug_output("%s timeout on socket", coroutine)

    def _process_sleep_queue(self):
        """Fire every timer whose deadline has passed.

        :return: seconds until the next live timer (capped at 1.0), or 0.0
            when no timer remains
        """
        from_start_ms = int(1000 * (time() - self.start_time))
        heap = self.millisecond_heap
        while heap:
            till = heappop(heap)
            item_map = self.timer_slots_map.pop(till)
            if till > from_start_ms:
                if item_map:
                    # earliest live slot is in the future: put it back, stop
                    self.timer_slots_map[till] = item_map
                    heappush(heap, till)
                    return min(1.0, 0.001 * (till - from_start_ms))
                # slot emptied by cancellations: drop it lazily
                continue
            for timeout_item in item_map.values():
                wait_type = timeout_item.wait_type
                if wait_type == WAIT_CANCELED:
                    continue
                assert timeout_item.till == till
                if wait_type == WAIT_JOIN:
                    self._join_timed_out(timeout_item.arg)
                elif wait_type == WAIT_SOCKET:
                    self._socket_timed_out(timeout_item)
                else:
                    # sleep finished: arg is the sleeping coroutine
                    assert wait_type == WAIT_SLEEP
                    self.queue.append(timeout_item.arg)
                    self._debug_output("%s wake up from sleep", timeout_item.arg)
        return 0.0

    def _process_io(self, sleep_seconds):
        """Poll watched fds for up to ``sleep_seconds`` and resume their owners."""
        io_read_queue = self.io_read_queue
        io_write_queue = self.io_write_queue
        selector = fake_select if self.use_fake_select else select
        rxlist, wxlist, exlist = selector(io_read_queue, io_write_queue, [],
                                          sleep_seconds)
        # ready fds are no longer watched (sets are updated in place)
        io_read_queue -= set(rxlist)
        io_write_queue -= set(wxlist)
        if exlist:
            exset = set(exlist)
            io_read_queue -= exset
            io_write_queue -= exset
        for sock_fd in chain(rxlist, wxlist, exlist):
            self._debug_output("socket %s become ready", sock_fd)
            entry = self.sock_map.pop(sock_fd, None)
            if entry is None:
                # fd reported more than once; its owner is already resumed
                continue
            coroutine, timeout_item = entry
            assert coroutine.cid() in self.cid2coroutine
            assert timeout_item.wait_type == WAIT_SOCKET
            self._cancel_timeout(timeout_item)
            self.queue.append(coroutine)

    def _loop(self):
        """Alternate running coroutines, firing timers and polling I/O until all exit."""
        io_read_queue = self.io_read_queue
        io_write_queue = self.io_write_queue
        while self.alive_coroutine_num > 0:
            self._process_running_queue()
            sleep_seconds = self._process_sleep_queue()
            if io_read_queue or io_write_queue:
                if self.queue:
                    # runnable coroutines are waiting: only poll
                    sleep_seconds = 0
                elif sleep_seconds > 1:
                    sleep_seconds = 1
                self._process_io(sleep_seconds)
            elif sleep_seconds > 0 and not self.queue and self.millisecond_heap:
                self._debug_output("try to sleep for %.6fs", sleep_seconds)
                sleep(sleep_seconds)
                self._debug_output("wake up at %.6f", time())

    def _run(self):
        """Run until every top-level coroutine has exited.

        :raises CoroutineError: if called while already running
        """
        if self.running:
            raise CoroutineError("already running")
        self.running = True
        self._debug_logs = []
        try:
            self._loop()
        finally:
            # an exception escaping a coroutine must not leave the scheduler
            # stuck in the "already running" state
            self.running = False
            self.current = None
        assert not self.queue
        assert not self.io_read_queue
        assert not self.io_write_queue
        assert not self.millisecond_heap
        assert not self.timer_slots_map
        assert not self.waiting_map
        assert not self.cid2coroutine
        assert self.current is None
        self.sock_map.clear()

    @classmethod
    def add(cls, coroutine):
        return cls.get_instance()._add(coroutine)

    @classmethod
    def add_many(cls, coroutine_list):
        """
        add many coroutines to scheduler
        :param coroutine_list: coroutine array
        :return: scheduler
        """
        for coroutine in coroutine_list:
            cls.get_instance()._add(coroutine)
        return cls.get_instance()

    @classmethod
    def run(cls):
        return cls.get_instance()._run()

    @classmethod
    def set_debug(cls, debug=True, collect_logs=False):
        cls.get_instance().debug = debug
        cls.get_instance().collect_debug_logs = collect_logs

    @classmethod
    def get_debug_logs(cls):
        return cls.get_instance()._debug_logs

    @classmethod
    def set_use_fake_select(cls, use_fake_select=True):
        cls.get_instance().use_fake_select = use_fake_select
def async_urlopen(sock, url, method="GET", headers=(), data=""):
    """
    Async HTTP/1.1 request over an already-connected socket.

    Use it as a sub-coroutine: ``resp = yield async_urlopen(sock, "/")``.
    Its last yielded value is the response tuple, which the scheduler hands
    to the parent coroutine.

    :param sock: connected socket-like object providing fileno/send/recv
    :param url: request target, e.g. "/index.html"
    :param method: HTTP method
    :param headers: (head, value) request headers list
    :param data: request body
    :return response: (code, reason, headers, body); headers is a list of
        [name, value] pairs with surrounding whitespace stripped

    The body is delimited by Content-Length.  Without that header, any body
    bytes are read until the peer closes the connection.  If the peer closes
    early, whatever was received is returned as the body; if the header
    block never completed, code/reason stay 400 / "bad request" and headers
    is empty.  Chunked transfer encoding is not decoded.
    """
    pieces = [method, ' ', url, ' HTTP/1.1\r\n']
    for head, val in headers:
        pieces.extend((head, ':', val, '\r\n'))
    pieces.extend(('Content-Length:', str(len(data)), '\r\n'))
    pieces.append('Connection: keep-alive\r\n\r\n')
    pieces.append(data)
    req_bin = ''.join(pieces)
    while req_bin:
        yield SocketIO(sock.fileno(), read=False)
        sent = sock.send(req_bin)
        req_bin = req_bin[sent:]

    resp_bin = ""
    # -1 until the header block is parsed, then the expected body length
    resp_len = -1
    code = 400
    reason = "bad request"
    resp_headers = []
    while resp_len != len(resp_bin):
        yield SocketIO(sock.fileno(), read=True)
        chunk = sock.recv(32 << 10)
        if not chunk:
            # peer closed the connection; recv() would keep returning ""
            break
        resp_bin += chunk
        if resp_len >= 0:
            # headers already parsed (even when Content-Length is 0):
            # never re-parse body bytes as a header block
            continue
        parts = resp_bin.split('\r\n\r\n', 1)
        if len(parts) != 2:
            continue  # header block not complete yet
        head_bin, resp_bin = parts
        lines = head_bin.split('\r\n')
        status = lines[0].split(' ', 2)
        code = int(status[1])
        # the reason phrase may be absent ("HTTP/1.1 204")
        reason = status[2] if len(status) > 2 else ""
        # head_bin has no trailing CRLF, so the last line is a real header
        resp_headers = [[part.strip() for part in line.split(':', 1)]
                        for line in lines[1:] if ':' in line]
        if method == 'HEAD':
            break
        resp_len = 0
        for head, val in resp_headers:
            if head.lower() == 'content-length':
                resp_len = int(val)
                break
    yield (code, reason, resp_headers, resp_bin)
超時的實現原理
超時用於三個功能：休眠、IO超時、JoinAction超時。這三者具有一定的相似性，都需要計算一段時間，到達指定時間再用不同的方式處理。歸結到一起就是，都需要創建一個一次性定時任務。到達指定時間後，對於休眠任務，則喚醒協程，加入到可運行隊列；對於IO任務，則喚醒協程，向協程拋出超時異常；對於JoinAction任務，則喚醒等待的協程，並把超時結果傳遞給這個協程。後兩者有一點不同：這兩處的定時任務可能中途會被取消。如果IO及時到達，超時任務必須取消；如果目標協程及時退出，JoinAction超時任務也必須取消。
為了簡化實現，超時任務的精度只取到毫秒級，這樣就可以用整數來表示毫秒。
先說一下設計的主要數據結構：timer_slots_map和millisecond_heap。
millisecond_heap如名字所示，是一個毫秒整數的小根堆。毫秒數是當前時間減去調度器創建時間（start_time）的毫秒數。每個數字表示這個時間段內可能存在著超時任務。示例如下。
+++++++++++++
|100|200|280|
+++++++++++++
堆中有三個時間：100、200、280，也就是說，在99-100毫秒、199-200毫秒、279-280毫秒這些時間段可能存在超時任務。
使用堆這種數據結構，我們可以快速地得到最近的時間，也可以快速地插入新的時間。由於堆本身的高效性（插入和彈出都是O(log n)），以及Python的heapq模塊由C語言實現，所以即使長度很大，添加、刪除的耗時依然會很小。
接下來是timer_slots_map，這是一個稍微複雜的數據結構，功能是保存所有定時任務。這個結構的第一級是一個字典，把毫秒時間映射到對應的定時任務列表。定時任務列表也是一個字典結構，每個定時任務用timeout_item表示，列表的映射方式是id(timeout_item) -> timeout_item。
以下是一個timer_slots_map的示例。
{
100 => { id1 => timeout_item1 , id2 => timeout_item2 },
280 => { id3 => timeout_item3},
200 => {}
}
如上所示，有三個時間點有定時任務，其中100這個時間點有兩個任務，而200這個時間點則沒有任務。這是個正常的現象，當定時任務被取消時就會出現。
現在來說一下幾個需要實現的定時任務接口：
- 增加定時任務
- 取消定時任務
- 獲取時間最近的定時任務
1. 增加定時任務
功能就是將定時任務timeout_item加入到隊列中。定時任務包含具體的類型、參數等，這裡我們只關注時間。
首先是計算時間，得到一個毫秒整數till。檢查till是否在timer_slots_map中：如果是，則till必然已經在millisecond_heap中；否則需要把till加入millisecond_heap，使用heappush()自動維護堆的結構。最後將timeout_item插入到timer_slots_map[till]這個字典中。
timer_slots_map[till][id(timeout_item)] = timeout_item
2. 取消定時任務
輸入參數timeout_item。
首先取出這個定時任務的超時時間till（添加時已保存在timeout_item.till中，無需重新計算），再從timer_slots_map[till]這個字典中刪除timeout_item。由於我們採用timeout_item的id作為鍵，所以只需要用timeout_item的id刪除即可。這也就要求，被刪除的timeout_item必須是先前增加定時任務時使用的同一個對象。
del timer_slots_map[till][id(timeout_item)]
注意一點：添加的時候millisecond_heap可能加入了till，但是刪除的時候，卻沒有從millisecond_heap刪除till這個時間。這麼做是有原因的：堆本質是數組，從數組中間刪除元素的代價很大。保留till在原處並不會有多大影響——當這個時間之後被heappop()彈出、而對應的字典已經為空時，直接丟棄即可。而且由於我們採用毫秒作為時間單位，同一毫秒內的任務共用一個till，這也就限制了millisecond_heap的長度。如果採用精確的雙精度浮點數表示時間，則millisecond_heap必然會膨脹到無法承受的長度。
3. 獲取時間最近的定時任務
millisecond_heap是小根堆，第一個元素就是最近的時間。使用heappop()函數可以方便地從millisecond_heap彈出首個時間，再根據這個時間去timer_slots_map查找對應的定時任務列表。如果這個時間還沒到而且列表不為空，就把它重新放回堆中，並據此計算調度器可以休眠多久。
總結
使用堆和字典兩個數據結構，高效而簡潔地實現了定時任務。
在IO很多的時候，定時任務可能會快速增加。為了減少millisecond_heap的長度，可以將超時時間取整到如10毫秒甚至100毫秒。
JoinAction的實現原理
主要依賴的數據結構是waiting_map。這是一個字典結構，鍵是被等待協程的id，值是等待這個協程的所有協程列表；這個列表也是一個字典結構，鍵是等待者協程的id，值是對應的定時任務。
示例如下。
{
c1 => { waiter1 => timeout_item1, waiter2 => timeout_item2 },
c2 => { waiter3 => timeout_item3, waiter4 => timeout_item4 }
}
waiter1 和 waiter2 都在等待協程c1，並分別設有超時任務。
當協程c1退出時，遍歷c1對應的等待列表，喚醒所有等待協程（它們的yield得到False），並刪除它們的超時任務。反之，如果某個等待者先超時，則把它從c1的等待列表中移除並喚醒（yield得到True）；列表為空時刪除c1這一項。