refactor: 重构后端会话状态机并接入 MessagePack 音频解包
This commit is contained in:
2
go.mod
2
go.mod
@@ -54,6 +54,8 @@ require (
|
||||
github.com/vcaesar/keycode v0.10.1 // indirect
|
||||
github.com/vcaesar/screenshot v0.11.1 // indirect
|
||||
github.com/vcaesar/tt v0.20.1 // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.48.0 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@@ -120,6 +120,10 @@ github.com/vcaesar/screenshot v0.11.1 h1:GgPuN89XC4Yh38dLx4quPlSo3YiWWhwIria/j3L
|
||||
github.com/vcaesar/screenshot v0.11.1/go.mod h1:gJNwHBiP1v1v7i8TQ4yV1XJtcyn2I/OJL7OziVQkwjs=
|
||||
github.com/vcaesar/tt v0.20.1 h1:D/jUeeVCNbq3ad8M7hhtB3J9x5RZ6I1n1eZ0BJp7M+4=
|
||||
github.com/vcaesar/tt v0.20.1/go.mod h1:cH2+AwGAJm19Wa6xvEa+0r+sXDJBT0QgNQey6mwqLeU=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
|
||||
@@ -2,40 +2,45 @@ package ws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/contrib/v3/websocket"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
// session holds the state for a single WebSocket connection.
|
||||
const wsReadTimeout = 75 * time.Second
|
||||
|
||||
type session struct {
|
||||
conn *websocket.Conn
|
||||
log *slog.Logger
|
||||
resultCh chan ServerMsg
|
||||
previewMu sync.Mutex
|
||||
previewText string
|
||||
sendAudio func([]byte)
|
||||
cleanup func()
|
||||
active bool
|
||||
conn *websocket.Conn
|
||||
log *slog.Logger
|
||||
resultCh chan ServerMsg
|
||||
|
||||
stateMu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
previewMu sync.Mutex
|
||||
state string
|
||||
sessionID string
|
||||
seq int64
|
||||
lastAudioSeq uint64
|
||||
previewText string
|
||||
sendAudio func([]byte)
|
||||
cleanup func()
|
||||
}
|
||||
// PasteFunc is called when the server should paste text into the focused app.
|
||||
|
||||
type PasteFunc func(text string) error
|
||||
|
||||
// ASRFactory creates an ASR session. It returns a channel that receives
|
||||
// partial/final results, and a function to send audio frames.
|
||||
// The cleanup function must be called when the session ends.
|
||||
type ASRFactory func(resultCh chan<- ServerMsg) (sendAudio func(pcm []byte), cleanup func(), err error)
|
||||
|
||||
// Handler holds dependencies for WebSocket connections.
|
||||
type Handler struct {
|
||||
token string
|
||||
pasteFunc PasteFunc
|
||||
asrFactory ASRFactory
|
||||
}
|
||||
|
||||
// NewHandler creates a WS handler with the given dependencies.
|
||||
func NewHandler(token string, pasteFn PasteFunc, asrFn ASRFactory) *Handler {
|
||||
return &Handler{
|
||||
token: token,
|
||||
@@ -44,150 +49,350 @@ func NewHandler(token string, pasteFn PasteFunc, asrFn ASRFactory) *Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds the /ws route to the Fiber app.
|
||||
func (h *Handler) Register(app *fiber.App) {
|
||||
// Token check middleware (before upgrade)
|
||||
app.Use("/ws", func(c fiber.Ctx) error {
|
||||
if !websocket.IsWebSocketUpgrade(c) {
|
||||
return fiber.ErrUpgradeRequired
|
||||
}
|
||||
if h.token != "" {
|
||||
q := c.Query("token")
|
||||
if q != h.token {
|
||||
return c.Status(fiber.StatusUnauthorized).SendString("invalid token")
|
||||
}
|
||||
}
|
||||
return c.Next()
|
||||
})
|
||||
app.Use("/ws", func(c fiber.Ctx) error {
|
||||
if !websocket.IsWebSocketUpgrade(c) {
|
||||
return fiber.ErrUpgradeRequired
|
||||
}
|
||||
if h.token != "" {
|
||||
q := c.Query("token")
|
||||
if q != h.token {
|
||||
return c.Status(fiber.StatusUnauthorized).SendString("invalid token")
|
||||
}
|
||||
}
|
||||
return c.Next()
|
||||
})
|
||||
|
||||
app.Get("/ws", websocket.New(h.handleConn))
|
||||
app.Get("/ws", websocket.New(h.handleConn))
|
||||
}
|
||||
|
||||
func (h *Handler) handleConn(c *websocket.Conn) {
|
||||
sess := &session{
|
||||
conn: c,
|
||||
log: slog.With("remote", c.RemoteAddr().String()),
|
||||
resultCh: make(chan ServerMsg, 32),
|
||||
conn: c,
|
||||
log: slog.With("remote", c.RemoteAddr().String()),
|
||||
resultCh: make(chan ServerMsg, 64),
|
||||
state: StateIdle,
|
||||
sessionID: "",
|
||||
seq: 1,
|
||||
lastAudioSeq: 0,
|
||||
}
|
||||
|
||||
sess.log.Info("ws connected")
|
||||
defer sess.log.Info("ws disconnected")
|
||||
defer close(sess.resultCh)
|
||||
defer sess.cleanupASR()
|
||||
defer sess.forceStop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go sess.writerLoop(&wg)
|
||||
defer wg.Wait()
|
||||
|
||||
_ = sess.writeJSON(ServerMsg{Type: MsgReady, Message: "ok"})
|
||||
_ = sess.writeJSON(ServerMsg{Type: MsgState, State: StateIdle})
|
||||
|
||||
for {
|
||||
_ = c.SetReadDeadline(time.Now().Add(wsReadTimeout))
|
||||
mt, data, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if mt == websocket.BinaryMessage {
|
||||
switch mt {
|
||||
case websocket.BinaryMessage:
|
||||
sess.handleAudioFrame(data)
|
||||
} else if mt == websocket.TextMessage {
|
||||
case websocket.TextMessage:
|
||||
h.handleTextMessage(sess, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) writerLoop(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
for msg := range s.resultCh {
|
||||
if msg.Type == MsgPartial || msg.Type == MsgFinal {
|
||||
s.previewMu.Lock()
|
||||
s.previewText = msg.Text
|
||||
preview := ServerMsg{Type: msg.Type, Text: s.previewText}
|
||||
s.previewMu.Unlock()
|
||||
if err := s.conn.WriteMessage(websocket.TextMessage, preview.Bytes()); err != nil {
|
||||
s.log.Warn("ws write error", "err", err)
|
||||
return
|
||||
}
|
||||
continue
|
||||
|
||||
s.stateMu.Lock()
|
||||
msg.SessionID = s.sessionID
|
||||
msg.Seq = s.seq
|
||||
s.seq++
|
||||
s.stateMu.Unlock()
|
||||
}
|
||||
if err := s.conn.WriteMessage(websocket.TextMessage, msg.Bytes()); err != nil {
|
||||
|
||||
if err := s.writeJSON(msg); err != nil {
|
||||
s.log.Warn("ws write error", "err", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
func (s *session) handleAudioFrame(data []byte) {
|
||||
if s.active && s.sendAudio != nil {
|
||||
s.sendAudio(data)
|
||||
}
|
||||
|
||||
func (s *session) writeJSON(msg ServerMsg) error {
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
return s.conn.WriteMessage(websocket.TextMessage, msg.Bytes())
|
||||
}
|
||||
|
||||
func (s *session) handleAudioFrame(data []byte) {
|
||||
packet, err := decodeAudioPacket(data)
|
||||
if err != nil {
|
||||
s.log.Warn("invalid audio packet", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
s.stateMu.Lock()
|
||||
if s.state != StateRecording || s.sendAudio == nil {
|
||||
s.stateMu.Unlock()
|
||||
return
|
||||
}
|
||||
if packet.SessionID != s.sessionID {
|
||||
s.stateMu.Unlock()
|
||||
return
|
||||
}
|
||||
if packet.Seq <= s.lastAudioSeq {
|
||||
s.stateMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
lastSeq := s.lastAudioSeq
|
||||
if lastSeq > 0 && packet.Seq > lastSeq+1 {
|
||||
s.log.Warn(
|
||||
"audio seq gap",
|
||||
"session_id",
|
||||
s.sessionID,
|
||||
"last",
|
||||
lastSeq,
|
||||
"curr",
|
||||
packet.Seq,
|
||||
)
|
||||
}
|
||||
s.lastAudioSeq = packet.Seq
|
||||
sendAudio := s.sendAudio
|
||||
s.stateMu.Unlock()
|
||||
|
||||
sendAudio(packet.PCM)
|
||||
}
|
||||
|
||||
type audioPacket struct {
|
||||
Version int `msgpack:"v"`
|
||||
SessionID string `msgpack:"sessionId"`
|
||||
Seq uint64 `msgpack:"seq"`
|
||||
PCM []byte `msgpack:"pcm"`
|
||||
}
|
||||
|
||||
func decodeAudioPacket(data []byte) (audioPacket, error) {
|
||||
var packet audioPacket
|
||||
if err := msgpack.Unmarshal(data, &packet); err != nil {
|
||||
return audioPacket{}, fmt.Errorf("msgpack decode failed: %w", err)
|
||||
}
|
||||
if packet.Version != 1 {
|
||||
return audioPacket{}, fmt.Errorf("unsupported version: %d", packet.Version)
|
||||
}
|
||||
if packet.SessionID == "" || len(packet.SessionID) > 96 {
|
||||
return audioPacket{}, fmt.Errorf("invalid session id")
|
||||
}
|
||||
if packet.Seq == 0 {
|
||||
return audioPacket{}, fmt.Errorf("invalid seq")
|
||||
}
|
||||
if len(packet.PCM) == 0 || len(packet.PCM)%2 != 0 {
|
||||
return audioPacket{}, fmt.Errorf("invalid pcm size")
|
||||
}
|
||||
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func (h *Handler) handleTextMessage(s *session, data []byte) {
|
||||
var msg ClientMsg
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
s.log.Warn("invalid json", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
case MsgHello:
|
||||
_ = s.writeJSON(ServerMsg{Type: MsgReady, Message: "ok"})
|
||||
s.stateMu.Lock()
|
||||
state := s.state
|
||||
sid := s.sessionID
|
||||
s.stateMu.Unlock()
|
||||
_ = s.writeJSON(ServerMsg{Type: MsgState, State: state, SessionID: sid})
|
||||
case MsgPing:
|
||||
_ = s.writeJSON(ServerMsg{Type: MsgPong, TS: msg.TS})
|
||||
case MsgStart:
|
||||
h.handleStart(s)
|
||||
h.handleStart(s, msg)
|
||||
case MsgStop:
|
||||
h.handleStop(s)
|
||||
h.handleStop(s, msg)
|
||||
case MsgPaste:
|
||||
h.handlePaste(s, msg.Text)
|
||||
h.handlePaste(s, msg.Text, msg.SessionID)
|
||||
default:
|
||||
h.sendError(s, "bad_message", "unsupported message type", true, "")
|
||||
}
|
||||
}
|
||||
func (h *Handler) handleStart(s *session) {
|
||||
if s.active {
|
||||
|
||||
func (h *Handler) handleStart(s *session, msg ClientMsg) {
|
||||
if msg.SessionID == "" || len(msg.SessionID) > 96 {
|
||||
h.sendError(s, "invalid_session", "invalid sessionId", false, "")
|
||||
return
|
||||
}
|
||||
// Future extension: support runtime dynamic hotwords (Phase 2)
|
||||
// if len(msg.Hotwords) > 0 {
|
||||
// // Priority: runtime hotwords > config.yaml hotwords
|
||||
// // Need to modify asrFactory signature to pass msg.Hotwords
|
||||
// }
|
||||
|
||||
s.stateMu.Lock()
|
||||
if s.state != StateIdle {
|
||||
current := s.sessionID
|
||||
s.stateMu.Unlock()
|
||||
h.sendError(s, "busy", "session is not idle", true, current)
|
||||
return
|
||||
}
|
||||
s.state = StateRecording
|
||||
s.sessionID = msg.SessionID
|
||||
s.seq = 1
|
||||
s.lastAudioSeq = 0
|
||||
s.sendAudio = nil
|
||||
s.cleanup = nil
|
||||
s.stateMu.Unlock()
|
||||
|
||||
s.previewMu.Lock()
|
||||
s.previewText = ""
|
||||
s.previewMu.Unlock()
|
||||
|
||||
sa, cl, err := h.asrFactory(s.resultCh)
|
||||
if err != nil {
|
||||
s.log.Error("asr start failed", "err", err)
|
||||
s.resultCh <- ServerMsg{Type: MsgError, Message: "ASR start failed"}
|
||||
s.stateMu.Lock()
|
||||
s.state = StateIdle
|
||||
s.sessionID = ""
|
||||
s.sendAudio = nil
|
||||
s.cleanup = nil
|
||||
s.stateMu.Unlock()
|
||||
h.sendError(s, "start_failed", "ASR start failed", true, "")
|
||||
_ = s.writeJSON(ServerMsg{Type: MsgState, State: StateIdle})
|
||||
return
|
||||
}
|
||||
|
||||
s.stateMu.Lock()
|
||||
if s.sessionID != msg.SessionID || s.state != StateRecording {
|
||||
s.stateMu.Unlock()
|
||||
cl()
|
||||
return
|
||||
}
|
||||
s.sendAudio = sa
|
||||
s.cleanup = cl
|
||||
s.active = true
|
||||
s.log.Info("recording started")
|
||||
s.stateMu.Unlock()
|
||||
|
||||
_ = s.writeJSON(ServerMsg{Type: MsgStartAck, SessionID: msg.SessionID})
|
||||
_ = s.writeJSON(ServerMsg{Type: MsgState, State: StateRecording, SessionID: msg.SessionID})
|
||||
s.log.Info("recording started", "session_id", msg.SessionID)
|
||||
}
|
||||
func (h *Handler) handleStop(s *session) {
|
||||
if !s.active {
|
||||
|
||||
func (h *Handler) handleStop(s *session, msg ClientMsg) {
|
||||
s.stateMu.Lock()
|
||||
if s.state == StateIdle {
|
||||
s.stateMu.Unlock()
|
||||
return
|
||||
}
|
||||
s.cleanupASR()
|
||||
s.sendAudio = nil
|
||||
s.active = false
|
||||
if s.state == StateStopping {
|
||||
sid := s.sessionID
|
||||
s.stateMu.Unlock()
|
||||
_ = s.writeJSON(ServerMsg{Type: MsgStopAck, SessionID: sid})
|
||||
return
|
||||
}
|
||||
|
||||
if msg.SessionID != "" && msg.SessionID != s.sessionID {
|
||||
current := s.sessionID
|
||||
s.stateMu.Unlock()
|
||||
h.sendError(s, "session_mismatch", "stop sessionId mismatch", false, current)
|
||||
return
|
||||
}
|
||||
|
||||
s.state = StateStopping
|
||||
sid := s.sessionID
|
||||
s.stateMu.Unlock()
|
||||
|
||||
_ = s.writeJSON(ServerMsg{Type: MsgStopAck, SessionID: sid})
|
||||
_ = s.writeJSON(ServerMsg{Type: MsgState, State: StateStopping, SessionID: sid})
|
||||
|
||||
go h.finalizeStop(s, sid)
|
||||
}
|
||||
|
||||
func (h *Handler) finalizeStop(s *session, sid string) {
|
||||
cleanup := s.detachCleanup(sid)
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
|
||||
s.previewMu.Lock()
|
||||
finalText := s.previewText
|
||||
s.previewText = ""
|
||||
s.previewMu.Unlock()
|
||||
|
||||
if finalText != "" && h.pasteFunc != nil {
|
||||
if err := h.pasteFunc(finalText); err != nil {
|
||||
s.log.Error("auto-paste failed", "err", err)
|
||||
h.sendError(s, "paste_failed", "auto-paste failed", true, sid)
|
||||
} else {
|
||||
s.resultCh <- ServerMsg{Type: MsgPasted}
|
||||
_ = s.writeJSON(ServerMsg{Type: MsgPasted, SessionID: sid})
|
||||
}
|
||||
}
|
||||
s.log.Info("recording stopped")
|
||||
|
||||
s.stateMu.Lock()
|
||||
if s.sessionID == sid {
|
||||
s.state = StateIdle
|
||||
s.sessionID = ""
|
||||
s.seq = 1
|
||||
s.lastAudioSeq = 0
|
||||
s.sendAudio = nil
|
||||
s.cleanup = nil
|
||||
}
|
||||
s.stateMu.Unlock()
|
||||
|
||||
_ = s.writeJSON(ServerMsg{Type: MsgState, State: StateIdle})
|
||||
s.log.Info("recording stopped", "session_id", sid)
|
||||
}
|
||||
func (h *Handler) handlePaste(s *session, text string) {
|
||||
|
||||
func (h *Handler) handlePaste(s *session, text, sid string) {
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
if h.pasteFunc != nil {
|
||||
if err := h.pasteFunc(text); err != nil {
|
||||
s.log.Error("paste failed", "err", err)
|
||||
s.resultCh <- ServerMsg{Type: MsgError, Message: "paste failed"}
|
||||
h.sendError(s, "paste_failed", "paste failed", true, sid)
|
||||
} else {
|
||||
s.resultCh <- ServerMsg{Type: MsgPasted}
|
||||
_ = s.writeJSON(ServerMsg{Type: MsgPasted, SessionID: sid})
|
||||
}
|
||||
}
|
||||
}
|
||||
func (s *session) cleanupASR() {
|
||||
if s.cleanup != nil {
|
||||
s.cleanup()
|
||||
s.cleanup = nil
|
||||
|
||||
func (h *Handler) sendError(s *session, code, message string, retryable bool, sid string) {
|
||||
err := s.writeJSON(ServerMsg{
|
||||
Type: MsgError,
|
||||
Code: code,
|
||||
Message: message,
|
||||
Retryable: retryable,
|
||||
SessionID: sid,
|
||||
})
|
||||
if err != nil {
|
||||
s.log.Warn("send error message failed", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) detachCleanup(sid string) func() {
|
||||
s.stateMu.Lock()
|
||||
defer s.stateMu.Unlock()
|
||||
if sid != "" && s.sessionID != sid {
|
||||
return nil
|
||||
}
|
||||
cleanup := s.cleanup
|
||||
s.cleanup = nil
|
||||
s.sendAudio = nil
|
||||
return cleanup
|
||||
}
|
||||
|
||||
func (s *session) forceStop() {
|
||||
cleanup := s.detachCleanup("")
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
s.stateMu.Lock()
|
||||
s.state = StateIdle
|
||||
s.sessionID = ""
|
||||
s.seq = 1
|
||||
s.lastAudioSeq = 0
|
||||
s.stateMu.Unlock()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user