Files
voicepaste/internal/ws/handler.go

160 lines
4.0 KiB
Go

package ws
import (
	"crypto/subtle"
	"encoding/json"
	"log/slog"
	"sync"

	"github.com/gofiber/contrib/v3/websocket"
	"github.com/gofiber/fiber/v3"
)
// PasteFunc is called when the server should paste text into the focused app.
// A non-nil error means the paste did not happen; the handler logs it and
// reports a failure to the client.
type PasteFunc func(text string) error

// ASRFactory starts a new speech-recognition session for one connection.
//
// The session delivers its partial and final results as ServerMsg values on
// resultCh. The channel is created and owned by the connection handler; it is
// not returned by the factory. The factory returns:
//   - sendAudio, which forwards one binary WebSocket frame of audio (pcm) to
//     the session;
//   - cleanup, which ends the session. The handler calls it at most once per
//     session, either on a stop command or when the connection closes;
//   - err, non-nil if the session could not be started (sendAudio and cleanup
//     are then not used).
//
// Implementations must not send on resultCh after cleanup returns: the
// handler closes the channel once the connection is torn down.
type ASRFactory func(resultCh chan<- ServerMsg) (sendAudio func(pcm []byte), cleanup func(), err error)

// Handler holds the dependencies shared by all WebSocket connections.
type Handler struct {
	token      string     // required ?token= value; empty disables the check
	pasteFunc  PasteFunc  // may be nil, in which case paste requests are ignored
	asrFactory ASRFactory // starts one ASR session per "start" command
}
// NewHandler bundles the access token, the paste callback and the ASR
// session factory into a Handler whose Register method mounts /ws.
// An empty token disables authentication; pasteFn may be nil.
func NewHandler(token string, pasteFn PasteFunc, asrFn ASRFactory) *Handler {
	h := new(Handler)
	h.token = token
	h.pasteFunc = pasteFn
	h.asrFactory = asrFn
	return h
}
// Register adds the /ws route to the Fiber app.
//
// A middleware runs before the WebSocket upgrade:
//   - requests that are not WebSocket upgrades fail with fiber.ErrUpgradeRequired;
//   - when a token is configured, the "token" query parameter must equal it,
//     otherwise the request gets 401 "invalid token";
//   - an empty token disables the check.
//
// Accepted requests are upgraded and served by handleConn.
func (h *Handler) Register(app *fiber.App) {
	app.Use("/ws", func(c fiber.Ctx) error {
		if !websocket.IsWebSocketUpgrade(c) {
			return fiber.ErrUpgradeRequired
		}
		if h.token != "" {
			// The token arrives from the network. Compare in constant time so
			// response timing does not reveal how much of a guess was correct.
			// Only the length can still leak.
			got := c.Query("token")
			if subtle.ConstantTimeCompare([]byte(got), []byte(h.token)) != 1 {
				return c.Status(fiber.StatusUnauthorized).SendString("invalid token")
			}
		}
		return c.Next()
	})
	app.Get("/ws", websocket.New(h.handleConn))
}
// handleConn serves one upgraded WebSocket connection.
//
// Binary frames are audio and are forwarded to the active ASR session. Text
// frames are JSON ClientMsg commands (start, stop, paste). Every message sent
// back to the client goes through resultCh and one writer goroutine, so the
// connection is never written from two goroutines at once.
func (h *Handler) handleConn(c *websocket.Conn) {
	log := slog.With("remote", c.RemoteAddr().String())
	log.Info("ws connected")
	defer log.Info("ws disconnected")

	// Result channel for ASR → phone. Fed by the read loop below and by the
	// ASR session.
	resultCh := make(chan ServerMsg, 32)

	// Writer goroutine: single writer to avoid concurrent writes.
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		writeFailed := false
		for msg := range resultCh {
			if writeFailed {
				// The connection is unusable, but keep draining so producers
				// (read loop, ASR session) never block on a full channel. The
				// loop ends when resultCh is closed during teardown.
				continue
			}
			if err := c.WriteMessage(websocket.TextMessage, msg.Bytes()); err != nil {
				log.Warn("ws write error", "err", err)
				writeFailed = true
				continue
			}
			// Auto-paste on final result.
			if msg.Type == MsgFinal && msg.Text != "" && h.pasteFunc != nil {
				if err := h.pasteFunc(msg.Text); err != nil {
					log.Error("auto-paste failed", "err", err)
				} else if err := c.WriteMessage(websocket.TextMessage, ServerMsg{Type: MsgPasted}.Bytes()); err != nil {
					log.Warn("ws write error", "err", err)
					writeFailed = true
				}
			}
		}
	}()

	// ASR session state; only touched by this goroutine.
	var (
		sendAudio func([]byte)
		cleanup   func()
		active    bool
	)

	// Teardown order matters. Stop the ASR session first so it no longer
	// sends on resultCh. Then close resultCh so the writer's range loop ends.
	// Then wait for the writer. Waiting before closing would block forever,
	// because the writer only exits once resultCh is closed.
	defer func() {
		if cleanup != nil {
			cleanup()
		}
		close(resultCh)
		wg.Wait()
	}()

	for {
		mt, data, err := c.ReadMessage()
		if err != nil {
			// The peer closed the connection or it broke; end the session.
			return
		}
		switch mt {
		case websocket.BinaryMessage:
			// Audio frame; dropped unless a session is recording.
			if active && sendAudio != nil {
				sendAudio(data)
			}
		case websocket.TextMessage:
			var msg ClientMsg
			if err := json.Unmarshal(data, &msg); err != nil {
				log.Warn("invalid json", "err", err)
				continue
			}
			switch msg.Type {
			case MsgStart:
				if active {
					continue // already recording; ignore duplicate start
				}
				sa, cl, err := h.asrFactory(resultCh)
				if err != nil {
					log.Error("asr start failed", "err", err)
					resultCh <- ServerMsg{Type: MsgError, Message: "ASR start failed"}
					continue
				}
				sendAudio = sa
				cleanup = cl
				active = true
				log.Info("recording started")
			case MsgStop:
				if !active {
					continue
				}
				if cleanup != nil {
					cleanup()
					cleanup = nil // already ended; skip it during teardown
				}
				sendAudio = nil
				active = false
				log.Info("recording stopped")
			case MsgPaste:
				if msg.Text == "" {
					continue
				}
				if h.pasteFunc != nil {
					if err := h.pasteFunc(msg.Text); err != nil {
						log.Error("paste failed", "err", err)
						resultCh <- ServerMsg{Type: MsgError, Message: "paste failed"}
					} else {
						resultCh <- ServerMsg{Type: MsgPasted}
					}
				}
			}
		}
	}
}