feat: add WebSocket handler with token auth and session management
This commit is contained in:
160
internal/ws/handler.go
Normal file
160
internal/ws/handler.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"log/slog"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/gofiber/contrib/v3/websocket"
|
||||||
|
"github.com/gofiber/fiber/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PasteFunc is called when the server should paste text into the focused app.
|
||||||
|
type PasteFunc func(text string) error
|
||||||
|
|
||||||
|
// ASRFactory creates an ASR session. It returns a channel that receives
|
||||||
|
// partial/final results, and a function to send audio frames.
|
||||||
|
// The cleanup function must be called when the session ends.
|
||||||
|
type ASRFactory func(resultCh chan<- ServerMsg) (sendAudio func(pcm []byte), cleanup func(), err error)
|
||||||
|
|
||||||
|
// Handler holds dependencies for WebSocket connections.
|
||||||
|
type Handler struct {
|
||||||
|
token string
|
||||||
|
pasteFunc PasteFunc
|
||||||
|
asrFactory ASRFactory
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandler creates a WS handler with the given dependencies.
|
||||||
|
func NewHandler(token string, pasteFn PasteFunc, asrFn ASRFactory) *Handler {
|
||||||
|
return &Handler{
|
||||||
|
token: token,
|
||||||
|
pasteFunc: pasteFn,
|
||||||
|
asrFactory: asrFn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register adds the /ws route to the Fiber app.
|
||||||
|
func (h *Handler) Register(app *fiber.App) {
|
||||||
|
// Token check middleware (before upgrade)
|
||||||
|
app.Use("/ws", func(c fiber.Ctx) error {
|
||||||
|
if !websocket.IsWebSocketUpgrade(c) {
|
||||||
|
return fiber.ErrUpgradeRequired
|
||||||
|
}
|
||||||
|
if h.token != "" {
|
||||||
|
q := c.Query("token")
|
||||||
|
if q != h.token {
|
||||||
|
return c.Status(fiber.StatusUnauthorized).SendString("invalid token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.Next()
|
||||||
|
})
|
||||||
|
|
||||||
|
app.Get("/ws", websocket.New(h.handleConn))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleConn(c *websocket.Conn) {
|
||||||
|
log := slog.With("remote", c.RemoteAddr().String())
|
||||||
|
log.Info("ws connected")
|
||||||
|
defer log.Info("ws disconnected")
|
||||||
|
|
||||||
|
// Result channel for ASR → phone
|
||||||
|
resultCh := make(chan ServerMsg, 32)
|
||||||
|
defer close(resultCh)
|
||||||
|
|
||||||
|
// Writer goroutine: single writer to avoid concurrent writes
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for msg := range resultCh {
|
||||||
|
if err := c.WriteMessage(websocket.TextMessage, msg.Bytes()); err != nil {
|
||||||
|
log.Warn("ws write error", "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Auto-paste on final result
|
||||||
|
if msg.Type == MsgFinal && msg.Text != "" && h.pasteFunc != nil {
|
||||||
|
if err := h.pasteFunc(msg.Text); err != nil {
|
||||||
|
log.Error("auto-paste failed", "err", err)
|
||||||
|
} else {
|
||||||
|
_ = c.WriteMessage(websocket.TextMessage, ServerMsg{Type: MsgPasted}.Bytes())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// ASR session state
|
||||||
|
var (
|
||||||
|
sendAudio func([]byte)
|
||||||
|
cleanup func()
|
||||||
|
active bool
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
if cleanup != nil {
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
mt, data, err := c.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
switch mt {
|
||||||
|
case websocket.BinaryMessage:
|
||||||
|
// Audio frame
|
||||||
|
if active && sendAudio != nil {
|
||||||
|
sendAudio(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
case websocket.TextMessage:
|
||||||
|
var msg ClientMsg
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
log.Warn("invalid json", "err", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch msg.Type {
|
||||||
|
case MsgStart:
|
||||||
|
if active {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sa, cl, err := h.asrFactory(resultCh)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("asr start failed", "err", err)
|
||||||
|
resultCh <- ServerMsg{Type: MsgError, Message: "ASR start failed"}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sendAudio = sa
|
||||||
|
cleanup = cl
|
||||||
|
active = true
|
||||||
|
log.Info("recording started")
|
||||||
|
|
||||||
|
case MsgStop:
|
||||||
|
if !active {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if cleanup != nil {
|
||||||
|
cleanup()
|
||||||
|
cleanup = nil
|
||||||
|
}
|
||||||
|
sendAudio = nil
|
||||||
|
active = false
|
||||||
|
log.Info("recording stopped")
|
||||||
|
|
||||||
|
case MsgPaste:
|
||||||
|
if msg.Text == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if h.pasteFunc != nil {
|
||||||
|
if err := h.pasteFunc(msg.Text); err != nil {
|
||||||
|
log.Error("paste failed", "err", err)
|
||||||
|
resultCh <- ServerMsg{Type: MsgError, Message: "paste failed"}
|
||||||
|
} else {
|
||||||
|
resultCh <- ServerMsg{Type: MsgPasted}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
41
internal/ws/protocol.go
Normal file
41
internal/ws/protocol.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
|
// ── Client → Server messages ──
|
||||||
|
|
||||||
|
// MsgType enumerates known message types.
|
||||||
|
type MsgType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MsgStart MsgType = "start" // Begin recording session
|
||||||
|
MsgStop MsgType = "stop" // End recording session
|
||||||
|
MsgPaste MsgType = "paste" // Re-paste a history item
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientMsg is a JSON control message from the phone.
|
||||||
|
type ClientMsg struct {
|
||||||
|
Type MsgType `json:"type"`
|
||||||
|
Text string `json:"text,omitempty"` // Only for "paste"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Server → Client messages ──
|
||||||
|
|
||||||
|
const (
|
||||||
|
MsgPartial MsgType = "partial" // Interim ASR result
|
||||||
|
MsgFinal MsgType = "final" // Final ASR result
|
||||||
|
MsgPasted MsgType = "pasted" // Paste confirmed
|
||||||
|
MsgError MsgType = "error" // Error notification
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServerMsg is a JSON message sent to the phone.
|
||||||
|
type ServerMsg struct {
|
||||||
|
Type MsgType `json:"type"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Message string `json:"message,omitempty"` // For errors
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m ServerMsg) Bytes() []byte {
|
||||||
|
b, _ := json.Marshal(m)
|
||||||
|
return b
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user