From 39e56e5acc85902a117faa7268fcfb74aa01cbcb Mon Sep 17 00:00:00 2001 From: imbytecat Date: Sun, 1 Mar 2026 03:03:34 +0800 Subject: [PATCH] feat: add WebSocket handler with token auth and session management --- internal/ws/handler.go | 160 ++++++++++++++++++++++++++++++++++++++++ internal/ws/protocol.go | 41 ++++++++++ 2 files changed, 201 insertions(+) create mode 100644 internal/ws/handler.go create mode 100644 internal/ws/protocol.go diff --git a/internal/ws/handler.go b/internal/ws/handler.go new file mode 100644 index 0000000..3130c70 --- /dev/null +++ b/internal/ws/handler.go @@ -0,0 +1,160 @@ +package ws + +import ( + "encoding/json" + "log/slog" + "sync" + + "github.com/gofiber/contrib/v3/websocket" + "github.com/gofiber/fiber/v3" +) + +// PasteFunc is called when the server should paste text into the focused app. +type PasteFunc func(text string) error + +// ASRFactory creates an ASR session. It returns a channel that receives +// partial/final results, and a function to send audio frames. +// The cleanup function must be called when the session ends. +type ASRFactory func(resultCh chan<- ServerMsg) (sendAudio func(pcm []byte), cleanup func(), err error) + +// Handler holds dependencies for WebSocket connections. +type Handler struct { + token string + pasteFunc PasteFunc + asrFactory ASRFactory +} + +// NewHandler creates a WS handler with the given dependencies. +func NewHandler(token string, pasteFn PasteFunc, asrFn ASRFactory) *Handler { + return &Handler{ + token: token, + pasteFunc: pasteFn, + asrFactory: asrFn, + } +} + +// Register adds the /ws route to the Fiber app. +func (h *Handler) Register(app *fiber.App) { + // Token check middleware (before upgrade) + app.Use("/ws", func(c fiber.Ctx) error { + if !websocket.IsWebSocketUpgrade(c) { + return fiber.ErrUpgradeRequired + } + if h.token != "" { + q := c.Query("token") + if q != h.token { + return c.Status(fiber.StatusUnauthorized).SendString("invalid token") + } + } + return c.Next() + }) + + app.Get("/ws", websocket.New(h.handleConn)) +} + +func (h *Handler) handleConn(c *websocket.Conn) { + log := slog.With("remote", c.RemoteAddr().String()) + log.Info("ws connected") + defer log.Info("ws disconnected") + + // Result channel for ASR → phone + resultCh := make(chan ServerMsg, 32) + defer close(resultCh) + + // Writer goroutine: single writer to avoid concurrent writes + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for msg := range resultCh { + if err := c.WriteMessage(websocket.TextMessage, msg.Bytes()); err != nil { + log.Warn("ws write error", "err", err) + return + } + // Auto-paste on final result + if msg.Type == MsgFinal && msg.Text != "" && h.pasteFunc != nil { + if err := h.pasteFunc(msg.Text); err != nil { + log.Error("auto-paste failed", "err", err) + } else { + _ = c.WriteMessage(websocket.TextMessage, ServerMsg{Type: MsgPasted}.Bytes()) + } + } + } + }() + + // ASR session state + var ( + sendAudio func([]byte) + cleanup func() + active bool + ) + defer func() { + if cleanup != nil { + cleanup() + } + wg.Wait() + }() + + for { + mt, data, err := c.ReadMessage() + if err != nil { + break + } + + switch mt { + case websocket.BinaryMessage: + // Audio frame + if active && sendAudio != nil { + sendAudio(data) + } + + case websocket.TextMessage: + var msg ClientMsg + if err := json.Unmarshal(data, &msg); err != nil { + log.Warn("invalid json", "err", err) + continue + } + switch msg.Type { + case MsgStart: + if active { + continue + } + sa, cl, err := h.asrFactory(resultCh) + if err != nil { + log.Error("asr start failed", "err", err) + resultCh <- ServerMsg{Type: MsgError, Message: "ASR start failed"} + continue + } + sendAudio = sa + cleanup = cl + active = true + log.Info("recording started") + + case MsgStop: + if !active { + continue + } + if cleanup != nil { + cleanup() + cleanup = nil + } + sendAudio = nil + active = false + log.Info("recording stopped") + + case MsgPaste: + if msg.Text == "" { + continue + } + if h.pasteFunc != nil { + if err := h.pasteFunc(msg.Text); err != nil { + log.Error("paste failed", "err", err) + resultCh <- ServerMsg{Type: MsgError, Message: "paste failed"} + } else { + resultCh <- ServerMsg{Type: MsgPasted} + } + } + } + } + } +} \ No newline at end of file diff --git a/internal/ws/protocol.go b/internal/ws/protocol.go new file mode 100644 index 0000000..45826c9 --- /dev/null +++ b/internal/ws/protocol.go @@ -0,0 +1,41 @@ +package ws + +import "encoding/json" + +// ── Client → Server messages ── + +// MsgType enumerates known message types. +type MsgType string + +const ( + MsgStart MsgType = "start" // Begin recording session + MsgStop MsgType = "stop" // End recording session + MsgPaste MsgType = "paste" // Re-paste a history item +) + +// ClientMsg is a JSON control message from the phone. +type ClientMsg struct { + Type MsgType `json:"type"` + Text string `json:"text,omitempty"` // Only for "paste" +} + +// ── Server → Client messages ── + +const ( + MsgPartial MsgType = "partial" // Interim ASR result + MsgFinal MsgType = "final" // Final ASR result + MsgPasted MsgType = "pasted" // Paste confirmed + MsgError MsgType = "error" // Error notification +) + +// ServerMsg is a JSON message sent to the phone. +type ServerMsg struct { + Type MsgType `json:"type"` + Text string `json:"text,omitempty"` + Message string `json:"message,omitempty"` // For errors +} + +func (m ServerMsg) Bytes() []byte { + b, _ := json.Marshal(m) + return b +}