package ws import ( "encoding/json" "log/slog" "sync" "github.com/gofiber/contrib/v3/websocket" "github.com/gofiber/fiber/v3" ) // session holds the state for a single WebSocket connection. type session struct { conn *websocket.Conn log *slog.Logger resultCh chan ServerMsg previewMu sync.Mutex previewText string sendAudio func([]byte) cleanup func() active bool } // PasteFunc is called when the server should paste text into the focused app. type PasteFunc func(text string) error // ASRFactory creates an ASR session. It returns a channel that receives // partial/final results, and a function to send audio frames. // The cleanup function must be called when the session ends. type ASRFactory func(resultCh chan<- ServerMsg) (sendAudio func(pcm []byte), cleanup func(), err error) // Handler holds dependencies for WebSocket connections. type Handler struct { token string pasteFunc PasteFunc asrFactory ASRFactory } // NewHandler creates a WS handler with the given dependencies. func NewHandler(token string, pasteFn PasteFunc, asrFn ASRFactory) *Handler { return &Handler{ token: token, pasteFunc: pasteFn, asrFactory: asrFn, } } // Register adds the /ws route to the Fiber app. func (h *Handler) Register(app *fiber.App) { // Token check middleware (before upgrade) app.Use("/ws", func(c fiber.Ctx) error { if !websocket.IsWebSocketUpgrade(c) { return fiber.ErrUpgradeRequired } if h.token != "" { q := c.Query("token") if q != h.token { return c.Status(fiber.StatusUnauthorized).SendString("invalid token") } } return c.Next() }) app.Get("/ws", websocket.New(h.handleConn)) } func (h *Handler) handleConn(c *websocket.Conn) { sess := &session{ conn: c, log: slog.With("remote", c.RemoteAddr().String()), resultCh: make(chan ServerMsg, 32), } sess.log.Info("ws connected") defer sess.log.Info("ws disconnected") defer close(sess.resultCh) defer sess.cleanupASR() var wg sync.WaitGroup wg.Add(1) go sess.writerLoop(&wg) defer wg.Wait() for { mt, data, err := c.ReadMessage() if err != nil { break } if mt == websocket.BinaryMessage { sess.handleAudioFrame(data) } else if mt == websocket.TextMessage { h.handleTextMessage(sess, data) } } } func (s *session) writerLoop(wg *sync.WaitGroup) { defer wg.Done() for msg := range s.resultCh { if msg.Type == MsgPartial || msg.Type == MsgFinal { s.previewMu.Lock() s.previewText = msg.Text preview := ServerMsg{Type: msg.Type, Text: s.previewText} s.previewMu.Unlock() if err := s.conn.WriteMessage(websocket.TextMessage, preview.Bytes()); err != nil { s.log.Warn("ws write error", "err", err) return } continue } if err := s.conn.WriteMessage(websocket.TextMessage, msg.Bytes()); err != nil { s.log.Warn("ws write error", "err", err) return } } } func (s *session) handleAudioFrame(data []byte) { if s.active && s.sendAudio != nil { s.sendAudio(data) } } func (h *Handler) handleTextMessage(s *session, data []byte) { var msg ClientMsg if err := json.Unmarshal(data, &msg); err != nil { s.log.Warn("invalid json", "err", err) return } switch msg.Type { case MsgStart: h.handleStart(s) case MsgStop: h.handleStop(s) case MsgPaste: h.handlePaste(s, msg.Text) } } func (h *Handler) handleStart(s *session) { if s.active { return } s.previewMu.Lock() s.previewText = "" s.previewMu.Unlock() sa, cl, err := h.asrFactory(s.resultCh) if err != nil { s.log.Error("asr start failed", "err", err) s.resultCh <- ServerMsg{Type: MsgError, Message: "ASR start failed"} return } s.sendAudio = sa s.cleanup = cl s.active = true s.log.Info("recording started") } func (h *Handler) handleStop(s *session) { if !s.active { return } s.cleanupASR() s.sendAudio = nil s.active = false s.previewMu.Lock() finalText := s.previewText s.previewText = "" s.previewMu.Unlock() if finalText != "" && h.pasteFunc != nil { if err := h.pasteFunc(finalText); err != nil { s.log.Error("auto-paste failed", "err", err) } else { s.resultCh <- ServerMsg{Type: MsgPasted} } } s.log.Info("recording stopped") } func (h *Handler) handlePaste(s *session, text string) { if text == "" { return } if h.pasteFunc != nil { if err := h.pasteFunc(text); err != nil { s.log.Error("paste failed", "err", err) s.resultCh <- ServerMsg{Type: MsgError, Message: "paste failed"} } else { s.resultCh <- ServerMsg{Type: MsgPasted} } } } func (s *session) cleanupASR() { if s.cleanup != nil { s.cleanup() s.cleanup = nil } }