package ws

import (
	"encoding/json"
	"log/slog"
	"sync"

	"github.com/gofiber/contrib/v3/websocket"
	"github.com/gofiber/fiber/v3"
)

// PasteFunc is called when the server should paste text into the focused app.
type PasteFunc func(text string) error

// ASRFactory creates an ASR session. Messages produced by the session are
// delivered on resultCh, which is supplied by the caller. It returns a
// function to send audio frames and a cleanup function.
// The cleanup function must be called when the session ends.
type ASRFactory func(resultCh chan<- ServerMsg) (sendAudio func(pcm []byte), cleanup func(), err error)

// Handler holds dependencies for WebSocket connections.
type Handler struct {
	token      string
	pasteFunc  PasteFunc
	asrFactory ASRFactory
}

// NewHandler creates a WS handler with the given dependencies.
func NewHandler(token string, pasteFn PasteFunc, asrFn ASRFactory) *Handler {
	return &Handler{
		token:      token,
		pasteFunc:  pasteFn,
		asrFactory: asrFn,
	}
}

// Register adds the /ws route to the Fiber app.
//
// A middleware runs before the upgrade: non-upgrade requests are rejected
// with fiber.ErrUpgradeRequired, and when the handler has a non-empty token
// the "token" query parameter must match it or the request gets a 401.
func (h *Handler) Register(app *fiber.App) {
	// Token check middleware (before upgrade).
	app.Use("/ws", func(c fiber.Ctx) error {
		if !websocket.IsWebSocketUpgrade(c) {
			return fiber.ErrUpgradeRequired
		}
		if h.token != "" && c.Query("token") != h.token {
			return c.Status(fiber.StatusUnauthorized).SendString("invalid token")
		}
		return c.Next()
	})

	app.Get("/ws", websocket.New(h.handleConn))
}

// handleConn serves one WebSocket connection until a read fails.
//
// Binary messages are forwarded as audio to the active ASR session. Text
// messages are JSON-decoded ClientMsg commands: MsgStart opens an ASR session,
// MsgStop ends it and pastes the accumulated text, and MsgPaste pastes the
// text carried by the message. All writes to the connection go through a
// single writer goroutine fed by resultCh.
func (h *Handler) handleConn(c *websocket.Conn) {
	log := slog.With("remote", c.RemoteAddr().String())
	log.Info("ws connected")
	defer log.Info("ws disconnected")

	// Result channel for ASR -> phone. It is closed in the teardown defer
	// below, not in a defer of its own: defers run last-in-first-out, so a
	// separate `defer close(resultCh)` registered here would run only after
	// wg.Wait(), and wg.Wait() would never return because the writer exits
	// only when the channel is closed.
	resultCh := make(chan ServerMsg, 32)

	// Writer goroutine: single writer to avoid concurrent writes.
	// Accumulates all result texts; paste is triggered by stop, not by ASR final.
	var (
		wg      sync.WaitGroup
		accMu   sync.Mutex
		accText string
	)
	wg.Add(1)
	go func() {
		defer wg.Done()
		writeFailed := false
		for msg := range resultCh {
			out := msg
			// Accumulate text from both partial and final results and send
			// the accumulated preview; other messages (error, pasted) are
			// forwarded as-is.
			if msg.Type == MsgPartial || msg.Type == MsgFinal {
				accMu.Lock()
				accText += msg.Text
				out = ServerMsg{Type: msg.Type, Text: accText}
				accMu.Unlock()
			}
			if writeFailed {
				// Keep consuming after a failed write so senders (the read
				// loop and the ASR session) never block on a full channel.
				continue
			}
			if err := c.WriteMessage(websocket.TextMessage, out.Bytes()); err != nil {
				log.Warn("ws write error", "err", err)
				writeFailed = true
			}
		}
	}()

	// ASR session state.
	var (
		sendAudio func([]byte)
		cleanup   func()
		active    bool
	)
	defer func() {
		if cleanup != nil {
			cleanup()
		}
		// Assumes the ASR session no longer sends on resultCh once cleanup
		// has returned; verify against the ASRFactory implementation.
		close(resultCh)
		wg.Wait()
	}()

	for {
		mt, data, err := c.ReadMessage()
		if err != nil {
			break
		}

		switch mt {
		case websocket.BinaryMessage:
			// Audio frame.
			if active && sendAudio != nil {
				sendAudio(data)
			}

		case websocket.TextMessage:
			var msg ClientMsg
			if err := json.Unmarshal(data, &msg); err != nil {
				log.Warn("invalid json", "err", err)
				continue
			}

			switch msg.Type {
			case MsgStart:
				if active {
					continue
				}
				// Reset accumulated text for new session.
				accMu.Lock()
				accText = ""
				accMu.Unlock()

				sa, cl, err := h.asrFactory(resultCh)
				if err != nil {
					log.Error("asr start failed", "err", err)
					resultCh <- ServerMsg{Type: MsgError, Message: "ASR start failed"}
					continue
				}
				sendAudio, cleanup, active = sa, cl, true
				log.Info("recording started")

			case MsgStop:
				if !active {
					continue
				}
				// Finish the ASR session before reading the accumulated text.
				if cleanup != nil {
					cleanup()
					cleanup = nil
				}
				sendAudio = nil
				active = false

				// Now paste the accumulated text.
				// NOTE(review): results queued on resultCh that the writer
				// goroutine has not consumed yet are not part of accText at
				// this point - confirm whether that can drop the final text.
				accMu.Lock()
				finalText := accText
				accText = ""
				accMu.Unlock()

				if finalText != "" && h.pasteFunc != nil {
					if err := h.pasteFunc(finalText); err != nil {
						log.Error("auto-paste failed", "err", err)
					} else {
						resultCh <- ServerMsg{Type: MsgPasted}
					}
				}
				log.Info("recording stopped")

			case MsgPaste:
				if msg.Text == "" || h.pasteFunc == nil {
					continue
				}
				if err := h.pasteFunc(msg.Text); err != nil {
					log.Error("paste failed", "err", err)
					resultCh <- ServerMsg{Type: MsgError, Message: "paste failed"}
				} else {
					resultCh <- ServerMsg{Type: MsgPasted}
				}
			}
		}
	}
}