寫在前面
SSE是LLM進行流式通信常用的技術(shù)方案, 下圖是 kimi 的示例
SSE 簡介
Server-Sent Events(SSE)是一種允許服務(wù)器向客戶端實時推送數(shù)據(jù)的技術(shù)平窘。它基于HTTP協(xié)議叽粹,允許服務(wù)器通過一個持久的HTTP連接向客戶端發(fā)送事件流币砂。以下是SSE的一些關(guān)鍵點:
SSE的本質(zhì):SSE利用HTTP協(xié)議的流信息(streaming)特性,實現(xiàn)服務(wù)器向客戶端的單向通信笼踩。客戶端保持連接打開,等待服務(wù)器發(fā)送新的數(shù)據(jù)流。
-
SSE的特點:
- 使用HTTP協(xié)議冕碟,現(xiàn)有的服務(wù)器軟件都支持。
- 輕量級匆浙,使用簡單安寺,與WebSocket相比,協(xié)議相對簡單首尼。
- 默認支持斷線重連挑庶,而WebSocket需要自己實現(xiàn)。
- 一般只用來傳送文本數(shù)據(jù)软能,二進制數(shù)據(jù)需要編碼后傳送迎捺。
- 支持自定義發(fā)送的消息類型。
-
客戶端API:
-
EventSource
對象用于創(chuàng)建與服務(wù)器的連接并接收事件查排。 - 通過監(jiān)聽
message
事件接收服務(wù)器發(fā)送的消息凳枝。 - 可以監(jiān)聽自定義事件,不僅限于
message
事件跋核。
-
-
服務(wù)器端發(fā)送事件:
- 服務(wù)器端腳本需要使用
text/event-stream
MIME類型響應(yīng)內(nèi)容岖瑰。 - 每個通知以文本塊形式發(fā)送叛买,并以一對換行符結(jié)尾。
- 消息由字段組成锭环,包括
event
聪全、data
、id
和retry
等辅辩。
- 服務(wù)器端腳本需要使用
-
事件流格式:
- 事件流是一個簡單的文本數(shù)據(jù)流难礼,使用UTF-8編碼。
- 消息由一對換行符分開玫锋,以冒號開頭的行為注釋行蛾茉,會被忽略。
- 每條消息由一行或多行文字組成撩鹿,列出該消息的字段谦炬。
-
瀏覽器兼容性:
- SSE在現(xiàn)代瀏覽器中得到了廣泛支持,除了IE/Edge外节沦,其他瀏覽器如Firefox键思、Chrome、Safari等都支持SSE甫贯。
SSE適用于需要服務(wù)器向客戶端單向?qū)崟r推送數(shù)據(jù)的場景吼鳞,如實時通知、股票行情叫搁、新聞推送等赔桌。它是一種有效降低服務(wù)器負載和網(wǎng)絡(luò)資源消耗的技術(shù),通過服務(wù)器主動向客戶端發(fā)送更新事件渴逻,實現(xiàn)實時通信疾党。
py 中使用 SSE
- py 中異步:
async + await
- py 中流式接收 SSE:
httpx
包- py 中流式返回 SSE:
from fastapi.responses import StreamingResponse as FastapiStreamingResponse
- 路由定義
@router.post("/stream", tags=["chat"])
async def streaming_chat(
params: QuestionParams, current_user: TokenData = Depends(get_current_user)
):
if not params.user_id:
params.user_id = current_user.uid
async_generator = RetrievalController().stream_answer(params)
return StreamingResponse(async_generator)
- 流式輸出定義
from typing import Mapping
from fastapi.responses import StreamingResponse as FastapiStreamingResponse
from starlette.background import BackgroundTask
from starlette.responses import ContentStream
class StreamingResponse(FastapiStreamingResponse):
def __init__(
self,
content: ContentStream,
status_code: int = 200,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
default_headers = {"Content-Type": "text/event-stream", "Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
default_headers.update(headers or {})
super().__init__(content, status_code, default_headers, media_type, background)
- 流式接收并流式返回
@LogDecorate(
func_name="retrieval_controller::process_stream_answer", raise_exc=True
)
async def stream_answer(self, params: QuestionParams, model: int = 1):
"""
:param model: 1-8B 2-32B
"""
session_id = params.session_id
if params.new_session:
session_id = str(uuid.uuid1()).replace("-", "")
request_body = dict(
messages=msgs,
user_id=params.user_id,
)
stream_answer_api = f"{AI_DOMAIN}{STREAM_ANSWER_API}"
answer = ""
# 流式接收
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
stream_answer_api,
json=request_body,
timeout=60,
headers=dict(trace_id=get_req_ctx("trace_id")),
) as response:
async for chunk in response.aiter_text():
answer += chunk
yield self.get_yield_data(
{"content": chunk, "create_at": int(time.time() * 1000)}
)
yield self.get_yield_data("[DONE]")
yield self.get_yield_data({"session_id": session_id})
yield self.get_yield_data("[END]")
# 落庫
await user_qa_dao.save_user_qa(params.q, answer, session_id, params.user_id)
Go中使用SSE
使用
https://github.com/hertz-contrib/sse
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/common/hlog"
"github.com/google/uuid"
"github.com/hertz-contrib/sse"
"github.com/spf13/cast"
)
func ChatStream(ctx context.Context, c *app.RequestContext) {
u := ctl.CtxUser(c)
var req struct {
Query string `form:"query" json:"query"`
Model int `form:"model" json:"model"`
Sid string `form:"sid" json:"sid"` // session id
}
if err := c.BindAndValidate(&req); err != nil {
utils.RespErr(c, err)
return
}
// 聊天消息支持多輪對話
var sid string
if req.Sid != "" {
sid = req.Sid
} else {
sid = uuid.New().String()
}
msg := chat.SaveUserMsg(ctx, sid, req.Query)
content := &chat.Content{
Messages: msg,
UserId: cast.ToString(u.ID),
UserName: u.Name,
}
b, _ := json.Marshal(content)
// https://github.com/hertz-contrib/sse/blob/main/examples/client/quickstart/main.go
cli := sse.NewClient(conf.GetConf().Dev.AIDomain + "xxx")
cli.SetMethod("POST")
cli.SetHeaders(map[string]string{"Content-Type": "application/json", "trace_id": httpx.TraceId()})
cli.SetBody(b)
var ans, allAns string // AI 返回內(nèi)容
var flag bool // reply正文標識
events := make(chan *sse.Event)
errChan := make(chan error)
s := sse.NewStream(c)
go func() {
cErr := cli.Subscribe(func(msg *sse.Event) {
if msg != nil && msg.Data != nil {
events <- msg
return
}
})
errChan <- cErr
}()
for {
select {
case e := <-events:
m := map[string]any{}
_ = json.Unmarshal(e.Data, &m)
if v, ok := m["content"]; ok {
allAns += v.(string)
if flag {
ans += v.(string)
}
if v == "__REPLY_START__" {
flag = true
}
da := map[string]any{
"content": v,
"create_at": time.Now().Unix(),
}
jsonData, _ := json.Marshal(da)
hlog.Info("publish event data = %s", string(jsonData))
_ = s.Publish(&sse.Event{Data: jsonData})
} else {
hlog.Info("invalid event data = %s", string(e.Data))
}
case err := <-errChan:
if err != nil {
hlog.CtxErrorf(context.Background(), "err = %s", err.Error())
}
chat.SaveAssistantMsg(ctx, sid, ans, msg)
chat.SaveQA(u.ID, sid, req.Query, allAns)
_ = s.Publish(&sse.Event{Data: []byte("[DONE]")})
_ = s.Publish(&sse.Event{Data: []byte(fmt.Sprintf(`{"session_id": "%s"}`, sid))})
_ = s.Publish(&sse.Event{Data: []byte("[END]")})
hlog.Info("cli get all event")
return
}
}
}
寫在最后
需要注意的點
- py 使用
httpx
接收 SSE 流式數(shù)據(jù), 對數(shù)據(jù)結(jié)構(gòu)沒有要求, 比如 SSE event 常見的data: xxx
, 可以不帶data
標識返回 - go 中使用
https://github.com/hertz-contrib/sse
接收 SSE 流式數(shù)據(jù)- 底層會解析 SSE 數(shù)據(jù)格式, 需要判斷
data
標識, 如果沒有, 會導致解析失敗 - 如果數(shù)據(jù)包含
\n
換行, 也會導致數(shù)據(jù)解析失敗, 比較簡單的做法data: json 格式數(shù)據(jù)
- 底層會解析 SSE 數(shù)據(jù)格式, 需要判斷
// go 中對應(yīng) SSE 庫數(shù)據(jù)解析源碼
func (c *Client) processEvent(msg []byte) (event *Event, err error) {
var e Event
if len(msg) < 1 {
return nil, fmt.Errorf("event message was empty")
}
// Normalize the crlf to lf to make it easier to split the lines.
// Split the line by "\n" or "\r", per the spec.
for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) {
switch {
case bytes.HasPrefix(line, headerID):
e.ID = string(append([]byte(nil), trimHeader(len(headerID), line)...))
case bytes.HasPrefix(line, headerData):
// The spec allows for multiple data fields per event, concatenated them with "\n".
e.Data = append(e.Data[:], append(trimHeader(len(headerData), line), byte('\n'))...)
// The spec says that a line that simply contains the string "data" should be treated as a data field with an empty body.
case bytes.Equal(line, bytes.TrimSuffix(headerData, []byte(":"))):
e.Data = append(e.Data, byte('\n'))
case bytes.HasPrefix(line, headerEvent):
e.Event = string(append([]byte(nil), trimHeader(len(headerEvent), line)...))
case bytes.HasPrefix(line, headerRetry):
e.Retry, err = strconv.ParseUint(b2s(append([]byte(nil), trimHeader(len(headerRetry), line)...)), 10, 64)
if err != nil {
return nil, fmt.Errorf("process message `retry` failed, err is %s", err)
}
default:
// Ignore any garbage that doesn't match what we're looking for.
}
}
// Trim the last "\n" per the spec.
e.Data = bytes.TrimSuffix(e.Data, []byte("\n"))
if c.encodingBase64 {
buf := make([]byte, base64.StdEncoding.DecodedLen(len(e.Data)))
n, err := base64.StdEncoding.Decode(buf, e.Data)
if err != nil {
err = fmt.Errorf("failed to decode event message: %s", err)
return &e, err
}
e.Data = buf[:n]
}
return &e, err
}