From 7cf48246f27fbe3d749597bbbae9a0672bfb6a71 Mon Sep 17 00:00:00 2001 From: imbytecat Date: Fri, 6 Mar 2026 06:54:07 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E5=90=8E?= =?UTF-8?q?=E7=AB=AF=E4=BC=9A=E8=AF=9D=E7=8A=B6=E6=80=81=E6=9C=BA=E5=B9=B6?= =?UTF-8?q?=E6=8E=A5=E5=85=A5=20MessagePack=20=E9=9F=B3=E9=A2=91=E8=A7=A3?= =?UTF-8?q?=E5=8C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 2 + go.sum | 4 + internal/ws/handler.go | 361 ++++++++++++++++++++++++++++++++--------- 3 files changed, 289 insertions(+), 78 deletions(-) diff --git a/go.mod b/go.mod index 11b4329..5cba83d 100644 --- a/go.mod +++ b/go.mod @@ -54,6 +54,8 @@ require ( github.com/vcaesar/keycode v0.10.1 // indirect github.com/vcaesar/screenshot v0.11.1 // indirect github.com/vcaesar/tt v0.20.1 // indirect + github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/crypto v0.48.0 // indirect diff --git a/go.sum b/go.sum index a193601..dc40c14 100644 --- a/go.sum +++ b/go.sum @@ -120,6 +120,10 @@ github.com/vcaesar/screenshot v0.11.1 h1:GgPuN89XC4Yh38dLx4quPlSo3YiWWhwIria/j3L github.com/vcaesar/screenshot v0.11.1/go.mod h1:gJNwHBiP1v1v7i8TQ4yV1XJtcyn2I/OJL7OziVQkwjs= github.com/vcaesar/tt v0.20.1 h1:D/jUeeVCNbq3ad8M7hhtB3J9x5RZ6I1n1eZ0BJp7M+4= github.com/vcaesar/tt v0.20.1/go.mod h1:cH2+AwGAJm19Wa6xvEa+0r+sXDJBT0QgNQey6mwqLeU= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= diff --git a/internal/ws/handler.go b/internal/ws/handler.go index c1fd683..ed38eb3 100644 --- a/internal/ws/handler.go +++ b/internal/ws/handler.go @@ -2,40 +2,45 @@ package ws import ( "encoding/json" + "fmt" "log/slog" "sync" + "time" "github.com/gofiber/contrib/v3/websocket" "github.com/gofiber/fiber/v3" + "github.com/vmihailenco/msgpack/v5" ) -// session holds the state for a single WebSocket connection. +const wsReadTimeout = 75 * time.Second + type session struct { - conn *websocket.Conn - log *slog.Logger - resultCh chan ServerMsg - previewMu sync.Mutex - previewText string - sendAudio func([]byte) - cleanup func() - active bool + conn *websocket.Conn + log *slog.Logger + resultCh chan ServerMsg + + stateMu sync.Mutex + writeMu sync.Mutex + previewMu sync.Mutex + state string + sessionID string + seq int64 + lastAudioSeq uint64 + previewText string + sendAudio func([]byte) + cleanup func() } -// PasteFunc is called when the server should paste text into the focused app. + type PasteFunc func(text string) error -// ASRFactory creates an ASR session. It returns a channel that receives -// partial/final results, and a function to send audio frames. -// The cleanup function must be called when the session ends. type ASRFactory func(resultCh chan<- ServerMsg) (sendAudio func(pcm []byte), cleanup func(), err error) -// Handler holds dependencies for WebSocket connections. type Handler struct { token string pasteFunc PasteFunc asrFactory ASRFactory } -// NewHandler creates a WS handler with the given dependencies. func NewHandler(token string, pasteFn PasteFunc, asrFn ASRFactory) *Handler { return &Handler{ token: token, @@ -44,150 +49,350 @@ func NewHandler(token string, pasteFn PasteFunc, asrFn ASRFactory) *Handler { } } -// Register adds the /ws route to the Fiber app. func (h *Handler) Register(app *fiber.App) { - // Token check middleware (before upgrade) - app.Use("/ws", func(c fiber.Ctx) error { - if !websocket.IsWebSocketUpgrade(c) { - return fiber.ErrUpgradeRequired - } - if h.token != "" { - q := c.Query("token") - if q != h.token { - return c.Status(fiber.StatusUnauthorized).SendString("invalid token") - } - } - return c.Next() - }) + app.Use("/ws", func(c fiber.Ctx) error { + if !websocket.IsWebSocketUpgrade(c) { + return fiber.ErrUpgradeRequired + } + if h.token != "" { + q := c.Query("token") + if q != h.token { + return c.Status(fiber.StatusUnauthorized).SendString("invalid token") + } + } + return c.Next() + }) - app.Get("/ws", websocket.New(h.handleConn)) + app.Get("/ws", websocket.New(h.handleConn)) } func (h *Handler) handleConn(c *websocket.Conn) { sess := &session{ - conn: c, - log: slog.With("remote", c.RemoteAddr().String()), - resultCh: make(chan ServerMsg, 32), + conn: c, + log: slog.With("remote", c.RemoteAddr().String()), + resultCh: make(chan ServerMsg, 64), + state: StateIdle, + sessionID: "", + seq: 1, + lastAudioSeq: 0, } + sess.log.Info("ws connected") defer sess.log.Info("ws disconnected") defer close(sess.resultCh) - defer sess.cleanupASR() + defer sess.forceStop() + var wg sync.WaitGroup wg.Add(1) go sess.writerLoop(&wg) defer wg.Wait() + + _ = sess.writeJSON(ServerMsg{Type: MsgReady, Message: "ok"}) + _ = sess.writeJSON(ServerMsg{Type: MsgState, State: StateIdle}) + for { + _ = c.SetReadDeadline(time.Now().Add(wsReadTimeout)) mt, data, err := c.ReadMessage() if err != nil { break } - if mt == websocket.BinaryMessage { + switch mt { + case websocket.BinaryMessage: sess.handleAudioFrame(data) - } else if mt == websocket.TextMessage { + case websocket.TextMessage: h.handleTextMessage(sess, data) } } } + func (s *session) writerLoop(wg *sync.WaitGroup) { defer wg.Done() for msg := range s.resultCh { if msg.Type == MsgPartial || msg.Type == MsgFinal { s.previewMu.Lock() s.previewText = msg.Text - preview := ServerMsg{Type: msg.Type, Text: s.previewText} s.previewMu.Unlock() - if err := s.conn.WriteMessage(websocket.TextMessage, preview.Bytes()); err != nil { - s.log.Warn("ws write error", "err", err) - return - } - continue + + s.stateMu.Lock() + msg.SessionID = s.sessionID + msg.Seq = s.seq + s.seq++ + s.stateMu.Unlock() } - if err := s.conn.WriteMessage(websocket.TextMessage, msg.Bytes()); err != nil { + + if err := s.writeJSON(msg); err != nil { s.log.Warn("ws write error", "err", err) return } } } -func (s *session) handleAudioFrame(data []byte) { - if s.active && s.sendAudio != nil { - s.sendAudio(data) - } + +func (s *session) writeJSON(msg ServerMsg) error { + s.writeMu.Lock() + defer s.writeMu.Unlock() + return s.conn.WriteMessage(websocket.TextMessage, msg.Bytes()) } + +func (s *session) handleAudioFrame(data []byte) { + packet, err := decodeAudioPacket(data) + if err != nil { + s.log.Warn("invalid audio packet", "err", err) + return + } + + s.stateMu.Lock() + if s.state != StateRecording || s.sendAudio == nil { + s.stateMu.Unlock() + return + } + if packet.SessionID != s.sessionID { + s.stateMu.Unlock() + return + } + if packet.Seq <= s.lastAudioSeq { + s.stateMu.Unlock() + return + } + + lastSeq := s.lastAudioSeq + if lastSeq > 0 && packet.Seq > lastSeq+1 { + s.log.Warn( + "audio seq gap", + "session_id", + s.sessionID, + "last", + lastSeq, + "curr", + packet.Seq, + ) + } + s.lastAudioSeq = packet.Seq + sendAudio := s.sendAudio + s.stateMu.Unlock() + + sendAudio(packet.PCM) +} + +type audioPacket struct { + Version int `msgpack:"v"` + SessionID string `msgpack:"sessionId"` + Seq uint64 `msgpack:"seq"` + PCM []byte `msgpack:"pcm"` +} + +func decodeAudioPacket(data []byte) (audioPacket, error) { + var packet audioPacket + if err := msgpack.Unmarshal(data, &packet); err != nil { + return audioPacket{}, fmt.Errorf("msgpack decode failed: %w", err) + } + if packet.Version != 1 { + return audioPacket{}, fmt.Errorf("unsupported version: %d", packet.Version) + } + if packet.SessionID == "" || len(packet.SessionID) > 96 { + return audioPacket{}, fmt.Errorf("invalid session id") + } + if packet.Seq == 0 { + return audioPacket{}, fmt.Errorf("invalid seq") + } + if len(packet.PCM) == 0 || len(packet.PCM)%2 != 0 { + return audioPacket{}, fmt.Errorf("invalid pcm size") + } + + return packet, nil +} + func (h *Handler) handleTextMessage(s *session, data []byte) { var msg ClientMsg if err := json.Unmarshal(data, &msg); err != nil { s.log.Warn("invalid json", "err", err) return } + switch msg.Type { + case MsgHello: + _ = s.writeJSON(ServerMsg{Type: MsgReady, Message: "ok"}) + s.stateMu.Lock() + state := s.state + sid := s.sessionID + s.stateMu.Unlock() + _ = s.writeJSON(ServerMsg{Type: MsgState, State: state, SessionID: sid}) + case MsgPing: + _ = s.writeJSON(ServerMsg{Type: MsgPong, TS: msg.TS}) case MsgStart: - h.handleStart(s) + h.handleStart(s, msg) case MsgStop: - h.handleStop(s) + h.handleStop(s, msg) case MsgPaste: - h.handlePaste(s, msg.Text) + h.handlePaste(s, msg.Text, msg.SessionID) + default: + h.sendError(s, "bad_message", "unsupported message type", true, "") } } -func (h *Handler) handleStart(s *session) { - if s.active { + +func (h *Handler) handleStart(s *session, msg ClientMsg) { + if msg.SessionID == "" || len(msg.SessionID) > 96 { + h.sendError(s, "invalid_session", "invalid sessionId", false, "") return } - // Future extension: support runtime dynamic hotwords (Phase 2) - // if len(msg.Hotwords) > 0 { - // // Priority: runtime hotwords > config.yaml hotwords - // // Need to modify asrFactory signature to pass msg.Hotwords - // } + + s.stateMu.Lock() + if s.state != StateIdle { + current := s.sessionID + s.stateMu.Unlock() + h.sendError(s, "busy", "session is not idle", true, current) + return + } + s.state = StateRecording + s.sessionID = msg.SessionID + s.seq = 1 + s.lastAudioSeq = 0 + s.sendAudio = nil + s.cleanup = nil + s.stateMu.Unlock() + s.previewMu.Lock() s.previewText = "" s.previewMu.Unlock() + sa, cl, err := h.asrFactory(s.resultCh) if err != nil { - s.log.Error("asr start failed", "err", err) - s.resultCh <- ServerMsg{Type: MsgError, Message: "ASR start failed"} + s.stateMu.Lock() + s.state = StateIdle + s.sessionID = "" + s.sendAudio = nil + s.cleanup = nil + s.stateMu.Unlock() + h.sendError(s, "start_failed", "ASR start failed", true, "") + _ = s.writeJSON(ServerMsg{Type: MsgState, State: StateIdle}) + return + } + + s.stateMu.Lock() + if s.sessionID != msg.SessionID || s.state != StateRecording { + s.stateMu.Unlock() + cl() return } s.sendAudio = sa s.cleanup = cl - s.active = true - s.log.Info("recording started") + s.stateMu.Unlock() + + _ = s.writeJSON(ServerMsg{Type: MsgStartAck, SessionID: msg.SessionID}) + _ = s.writeJSON(ServerMsg{Type: MsgState, State: StateRecording, SessionID: msg.SessionID}) + s.log.Info("recording started", "session_id", msg.SessionID) } -func (h *Handler) handleStop(s *session) { - if !s.active { + +func (h *Handler) handleStop(s *session, msg ClientMsg) { + s.stateMu.Lock() + if s.state == StateIdle { + s.stateMu.Unlock() return } - s.cleanupASR() - s.sendAudio = nil - s.active = false + if s.state == StateStopping { + sid := s.sessionID + s.stateMu.Unlock() + _ = s.writeJSON(ServerMsg{Type: MsgStopAck, SessionID: sid}) + return + } + + if msg.SessionID != "" && msg.SessionID != s.sessionID { + current := s.sessionID + s.stateMu.Unlock() + h.sendError(s, "session_mismatch", "stop sessionId mismatch", false, current) + return + } + + s.state = StateStopping + sid := s.sessionID + s.stateMu.Unlock() + + _ = s.writeJSON(ServerMsg{Type: MsgStopAck, SessionID: sid}) + _ = s.writeJSON(ServerMsg{Type: MsgState, State: StateStopping, SessionID: sid}) + + go h.finalizeStop(s, sid) +} + +func (h *Handler) finalizeStop(s *session, sid string) { + cleanup := s.detachCleanup(sid) + if cleanup != nil { + cleanup() + } + s.previewMu.Lock() finalText := s.previewText s.previewText = "" s.previewMu.Unlock() + if finalText != "" && h.pasteFunc != nil { if err := h.pasteFunc(finalText); err != nil { - s.log.Error("auto-paste failed", "err", err) + h.sendError(s, "paste_failed", "auto-paste failed", true, sid) } else { - s.resultCh <- ServerMsg{Type: MsgPasted} + _ = s.writeJSON(ServerMsg{Type: MsgPasted, SessionID: sid}) } } - s.log.Info("recording stopped") + + s.stateMu.Lock() + if s.sessionID == sid { + s.state = StateIdle + s.sessionID = "" + s.seq = 1 + s.lastAudioSeq = 0 + s.sendAudio = nil + s.cleanup = nil + } + s.stateMu.Unlock() + + _ = s.writeJSON(ServerMsg{Type: MsgState, State: StateIdle}) + s.log.Info("recording stopped", "session_id", sid) } -func (h *Handler) handlePaste(s *session, text string) { + +func (h *Handler) handlePaste(s *session, text, sid string) { if text == "" { return } if h.pasteFunc != nil { if err := h.pasteFunc(text); err != nil { - s.log.Error("paste failed", "err", err) - s.resultCh <- ServerMsg{Type: MsgError, Message: "paste failed"} + h.sendError(s, "paste_failed", "paste failed", true, sid) } else { - s.resultCh <- ServerMsg{Type: MsgPasted} + _ = s.writeJSON(ServerMsg{Type: MsgPasted, SessionID: sid}) } } } -func (s *session) cleanupASR() { - if s.cleanup != nil { - s.cleanup() - s.cleanup = nil + +func (h *Handler) sendError(s *session, code, message string, retryable bool, sid string) { + err := s.writeJSON(ServerMsg{ + Type: MsgError, + Code: code, + Message: message, + Retryable: retryable, + SessionID: sid, + }) + if err != nil { + s.log.Warn("send error message failed", "err", err) } -} \ No newline at end of file +} + +func (s *session) detachCleanup(sid string) func() { + s.stateMu.Lock() + defer s.stateMu.Unlock() + if sid != "" && s.sessionID != sid { + return nil + } + cleanup := s.cleanup + s.cleanup = nil + s.sendAudio = nil + return cleanup +} + +func (s *session) forceStop() { + cleanup := s.detachCleanup("") + if cleanup != nil { + cleanup() + } + s.stateMu.Lock() + s.state = StateIdle + s.sessionID = "" + s.seq = 1 + s.lastAudioSeq = 0 + s.stateMu.Unlock() +}