// Source file: voicepaste/internal/ws/handler.go (Go, 403 lines, 8.8 KiB)

package ws
import (
"encoding/json"
"fmt"
"log/slog"
"sync"
"time"
"github.com/gofiber/contrib/v3/websocket"
"github.com/gofiber/fiber/v3"
"github.com/google/uuid"
"github.com/vmihailenco/msgpack/v5"
)
// wsReadTimeout is the read deadline handleConn sets before every ReadMessage
// call. A connection that delivers no frame within this window fails its next
// read, which ends the connection's read loop.
const wsReadTimeout = 75 * time.Second
// session holds the per-connection state shared between the read loop, the
// writer goroutine and the asynchronous stop finalizer.
type session struct {
	// conn is the underlying WebSocket connection.
	conn *websocket.Conn
	// log is a logger pre-tagged with the remote address.
	log *slog.Logger
	// resultCh carries server messages to writerLoop; it is also the channel
	// handed to the ASR factory.
	resultCh chan ServerMsg

	// stateMu guards state, sessionID, seq, lastAudioSeq, sendAudio and cleanup.
	stateMu sync.Mutex
	// writeMu serializes writes to conn.
	writeMu sync.Mutex
	// previewMu guards previewText.
	previewMu sync.Mutex

	// state is the current session state (StateIdle, StateRecording, StateStopping).
	state string
	// sessionID identifies the active recording session; empty when idle.
	sessionID string
	// seq is the next sequence number stamped on partial/final messages.
	seq int64
	// lastAudioSeq is the sequence number of the last accepted audio packet.
	lastAudioSeq uint64
	// previewText is the text of the most recent partial/final message.
	previewText string
	// sendAudio forwards PCM to the ASR backend; nil when no backend is attached.
	sendAudio func([]byte)
	// cleanup releases the ASR backend; nil when no backend is attached.
	cleanup func()
}
// PasteFunc pastes text on behalf of the client. The handler calls it with the
// final transcript after a stop and with the text of explicit paste messages;
// a non-nil error is reported to the client as a "paste_failed" error.
type PasteFunc func(text string) error
// ASRFactory starts a speech-recognition backend for one recording session.
// It is called with the session's result channel. On success, sendAudio is
// invoked with the PCM payload of each accepted audio packet and cleanup is
// invoked once when the session is stopped. A non-nil err aborts the start.
type ASRFactory func(resultCh chan<- ServerMsg) (sendAudio func(pcm []byte), cleanup func(), err error)
// Handler wires the WebSocket endpoint to the paste and ASR dependencies.
type Handler struct {
	// token is the shared secret expected in the "token" query parameter;
	// an empty token disables the check.
	token string
	// pasteFunc performs the paste; may be nil, in which case pasting is skipped.
	pasteFunc PasteFunc
	// asrFactory starts an ASR backend for each recording session.
	asrFactory ASRFactory
}
// NewHandler returns a Handler that authenticates with token (empty disables
// authentication), pastes through pasteFn and starts ASR backends via asrFn.
func NewHandler(token string, pasteFn PasteFunc, asrFn ASRFactory) *Handler {
	h := new(Handler)
	h.token = token
	h.pasteFunc = pasteFn
	h.asrFactory = asrFn
	return h
}
// Register mounts the WebSocket endpoint at /ws on app.
//
// A guard middleware rejects requests that are not WebSocket upgrades and,
// when a token is configured, requests whose "token" query parameter does not
// match it. Accepted requests are handed to handleConn.
func (h *Handler) Register(app *fiber.App) {
	guard := func(c fiber.Ctx) error {
		switch {
		case !websocket.IsWebSocketUpgrade(c):
			return fiber.ErrUpgradeRequired
		// NOTE(review): plain string comparison is not constant-time; consider
		// crypto/subtle.ConstantTimeCompare if the token is a real secret.
		case h.token != "" && c.Query("token") != h.token:
			return c.Status(fiber.StatusUnauthorized).SendString("invalid token")
		default:
			return c.Next()
		}
	}

	app.Use("/ws", guard)
	app.Get("/ws", websocket.New(h.handleConn))
}
// handleConn serves one WebSocket connection until a read fails.
//
// It creates the per-connection session, starts the writer goroutine that
// forwards messages from resultCh to the client, sends the initial ready and
// idle-state messages, and then dispatches every incoming frame: binary frames
// go to handleAudioFrame, text frames to handleTextMessage. Other frame types
// are ignored.
func (h *Handler) handleConn(c *websocket.Conn) {
	sess := &session{
		conn:     c,
		log:      slog.With("remote", c.RemoteAddr().String()),
		resultCh: make(chan ServerMsg, 64),
		state:    StateIdle,
		seq:      1,
	}
	sess.log.Info("ws connected")
	defer sess.log.Info("ws disconnected")

	var wg sync.WaitGroup
	wg.Add(1)
	go sess.writerLoop(&wg)

	// Deferred calls run last-in-first-out, so teardown executes as:
	//   1. forceStop       - detach and run the ASR cleanup, reset state
	//   2. close(resultCh) - ends writerLoop's range over the channel
	//   3. wg.Wait         - wait for writerLoop to return
	// Waiting before closing would block forever on a connection with no
	// pending results, because writerLoop only returns once the channel is
	// closed or a write fails.
	defer wg.Wait()
	defer close(sess.resultCh)
	defer sess.forceStop()

	// Best-effort greeting: if the peer is already gone the read below fails
	// and the connection is torn down anyway.
	_ = sess.writeJSON(ServerMsg{Type: MsgReady, Message: "ok"})
	_ = sess.writeJSON(ServerMsg{Type: MsgState, State: StateIdle})

	for {
		// Refresh the deadline before every read so only an idle peer times out.
		_ = c.SetReadDeadline(time.Now().Add(wsReadTimeout))
		mt, data, err := c.ReadMessage()
		if err != nil {
			return
		}
		switch mt {
		case websocket.BinaryMessage:
			sess.handleAudioFrame(data)
		case websocket.TextMessage:
			h.handleTextMessage(sess, data)
		}
	}
}
// writerLoop forwards every message received on resultCh to the client and
// signals wg when it returns. Partial and final transcripts are recorded as
// the current preview text and stamped with the active session ID and the next
// sequence number before being written. The loop ends when resultCh is closed
// or a write fails.
func (s *session) writerLoop(wg *sync.WaitGroup) {
	defer wg.Done()

	for out := range s.resultCh {
		isTranscript := out.Type == MsgPartial || out.Type == MsgFinal
		if isTranscript {
			s.previewMu.Lock()
			s.previewText = out.Text
			s.previewMu.Unlock()

			s.stateMu.Lock()
			out.SessionID, out.Seq = s.sessionID, s.seq
			s.seq++
			s.stateMu.Unlock()
		}

		err := s.writeJSON(out)
		if err != nil {
			s.log.Warn("ws write error", "err", err)
			return
		}
	}
}
// writeJSON sends msg to the client as a text frame, using msg.Bytes() as the
// payload. writeMu is held for the duration so that writes issued from
// different goroutines of this session do not overlap.
func (s *session) writeJSON(msg ServerMsg) error {
	s.writeMu.Lock()
	defer s.writeMu.Unlock()

	payload := msg.Bytes()
	return s.conn.WriteMessage(websocket.TextMessage, payload)
}
// handleAudioFrame decodes one binary frame and, if it belongs to the active
// recording, forwards its PCM payload to the ASR backend.
//
// Frames are silently dropped when the session is not recording, no backend is
// attached, the packet's session ID differs from the active one, or its
// sequence number is not greater than the last accepted one. A forward jump in
// the sequence is logged but the packet is still accepted.
func (s *session) handleAudioFrame(data []byte) {
	pkt, err := decodeAudioPacket(data)
	if err != nil {
		s.log.Warn("invalid audio packet", "err", err)
		return
	}

	s.stateMu.Lock()
	recording := s.state == StateRecording && s.sendAudio != nil
	if !recording || pkt.SessionID != s.sessionID || pkt.Seq <= s.lastAudioSeq {
		s.stateMu.Unlock()
		return
	}
	if prev := s.lastAudioSeq; prev > 0 && pkt.Seq > prev+1 {
		s.log.Warn("audio seq gap", "session_id", s.sessionID, "last", prev, "curr", pkt.Seq)
	}
	s.lastAudioSeq = pkt.Seq
	// Capture the callback under the lock, but invoke it after unlocking so
	// the backend call does not run while stateMu is held.
	forward := s.sendAudio
	s.stateMu.Unlock()

	forward(pkt.PCM)
}
// audioPacket is the msgpack-encoded payload of a binary audio frame.
type audioPacket struct {
	// Version is the packet format version; decodeAudioPacket accepts only 1.
	Version int `msgpack:"v"`
	// SessionID names the recording session the packet belongs to.
	SessionID string `msgpack:"sessionId"`
	// Seq is the packet's sequence number; zero is rejected.
	Seq uint64 `msgpack:"seq"`
	// PCM is the raw audio payload; it must be non-empty and of even length
	// (presumably 16-bit samples - confirm against the client encoder).
	PCM []byte `msgpack:"pcm"`
}
// decodeAudioPacket unmarshals a msgpack-encoded audio packet and validates it.
//
// It returns an error when decoding fails, the version is not 1, the session
// ID is empty or longer than 96 bytes, the sequence number is zero, or the PCM
// payload is empty or has an odd length. On error the returned packet is the
// zero value.
func decodeAudioPacket(data []byte) (audioPacket, error) {
	var decoded audioPacket
	if err := msgpack.Unmarshal(data, &decoded); err != nil {
		return audioPacket{}, fmt.Errorf("msgpack decode failed: %w", err)
	}

	idLen, pcmLen := len(decoded.SessionID), len(decoded.PCM)
	switch {
	case decoded.Version != 1:
		return audioPacket{}, fmt.Errorf("unsupported version: %d", decoded.Version)
	case idLen == 0 || idLen > 96:
		return audioPacket{}, fmt.Errorf("invalid session id")
	case decoded.Seq == 0:
		return audioPacket{}, fmt.Errorf("invalid seq")
	case pcmLen == 0 || pcmLen%2 != 0:
		return audioPacket{}, fmt.Errorf("invalid pcm size")
	}
	return decoded, nil
}
// handleTextMessage parses one JSON control message and dispatches it by type.
// Malformed JSON is logged and dropped; an unrecognized type is answered with
// a "bad_message" error.
func (h *Handler) handleTextMessage(s *session, data []byte) {
	var in ClientMsg
	if err := json.Unmarshal(data, &in); err != nil {
		s.log.Warn("invalid json", "err", err)
		return
	}

	switch in.Type {
	case MsgStart:
		h.handleStart(s, in)
	case MsgStop:
		h.handleStop(s, in)
	case MsgPaste:
		h.handlePaste(s, in.Text, in.SessionID)
	case MsgPing:
		// Echo the client's timestamp back in the pong.
		_ = s.writeJSON(ServerMsg{Type: MsgPong, TS: in.TS})
	case MsgHello:
		// Re-send the greeting followed by a snapshot of the current state.
		_ = s.writeJSON(ServerMsg{Type: MsgReady, Message: "ok"})
		s.stateMu.Lock()
		snapshot := ServerMsg{Type: MsgState, State: s.state, SessionID: s.sessionID}
		s.stateMu.Unlock()
		_ = s.writeJSON(snapshot)
	default:
		h.sendError(s, "bad_message", "unsupported message type", true, "")
	}
}
// handleStart begins a new recording session in response to a start message.
//
// A missing session ID is replaced by a fresh UUID; an ID longer than 96 bytes
// is rejected with "invalid_session". If the session is not idle the request
// is rejected with "busy". Otherwise the session moves to StateRecording, the
// ASR backend is created through the factory and, on success, a start
// acknowledgement and the new state are sent to the client. If the factory
// fails, the error is logged, the session returns to idle and the client
// receives "start_failed" followed by the idle state.
func (h *Handler) handleStart(s *session, msg ClientMsg) {
	if msg.SessionID == "" {
		msg.SessionID = uuid.NewString()
	}
	if len(msg.SessionID) > 96 {
		h.sendError(s, "invalid_session", "invalid sessionId", false, "")
		return
	}

	s.stateMu.Lock()
	if s.state != StateIdle {
		current := s.sessionID
		s.stateMu.Unlock()
		h.sendError(s, "busy", "session is not idle", true, current)
		return
	}
	// Claim the session before the backend exists. Audio frames that arrive
	// in the meantime are dropped because sendAudio is still nil.
	s.state = StateRecording
	s.sessionID = msg.SessionID
	s.seq = 1
	s.lastAudioSeq = 0
	s.sendAudio = nil
	s.cleanup = nil
	s.stateMu.Unlock()

	s.previewMu.Lock()
	s.previewText = ""
	s.previewMu.Unlock()

	// The factory is called without holding stateMu.
	sa, cl, err := h.asrFactory(s.resultCh)
	if err != nil {
		// Record the cause server-side; the client only gets a generic error.
		s.log.Warn("asr start failed", "session_id", msg.SessionID, "err", err)
		s.stateMu.Lock()
		s.state = StateIdle
		s.sessionID = ""
		s.sendAudio = nil
		s.cleanup = nil
		s.stateMu.Unlock()
		h.sendError(s, "start_failed", "ASR start failed", true, "")
		_ = s.writeJSON(ServerMsg{Type: MsgState, State: StateIdle})
		return
	}

	s.stateMu.Lock()
	if s.sessionID != msg.SessionID || s.state != StateRecording {
		// The session changed while the backend was starting; release the
		// backend that was just created instead of attaching it.
		s.stateMu.Unlock()
		if cl != nil {
			cl()
		}
		return
	}
	s.sendAudio = sa
	s.cleanup = cl
	s.stateMu.Unlock()

	_ = s.writeJSON(ServerMsg{Type: MsgStartAck, SessionID: msg.SessionID})
	_ = s.writeJSON(ServerMsg{Type: MsgState, State: StateRecording, SessionID: msg.SessionID})
	s.log.Info("recording started", "session_id", msg.SessionID)
}
// handleStop processes a stop message.
//
// It does nothing when the session is idle and simply re-acknowledges when a
// stop is already in progress. A non-empty session ID that differs from the
// active one is rejected with "session_mismatch". Otherwise the session moves
// to StateStopping, the client receives a stop acknowledgement and the new
// state, and finalization continues in a separate goroutine.
func (h *Handler) handleStop(s *session, msg ClientMsg) {
	s.stateMu.Lock()
	active := s.sessionID
	switch {
	case s.state == StateIdle:
		s.stateMu.Unlock()
		return
	case s.state == StateStopping:
		s.stateMu.Unlock()
		_ = s.writeJSON(ServerMsg{Type: MsgStopAck, SessionID: active})
		return
	case msg.SessionID != "" && msg.SessionID != active:
		s.stateMu.Unlock()
		h.sendError(s, "session_mismatch", "stop sessionId mismatch", false, active)
		return
	}
	s.state = StateStopping
	s.stateMu.Unlock()

	_ = s.writeJSON(ServerMsg{Type: MsgStopAck, SessionID: active})
	_ = s.writeJSON(ServerMsg{Type: MsgState, State: StateStopping, SessionID: active})

	go h.finalizeStop(s, active)
}
// finalizeStop completes the stop of session sid; handleStop runs it in its
// own goroutine.
//
// It detaches and runs the ASR cleanup, takes the last preview text as the
// final transcript and, when that text is non-empty and a paste function is
// configured, pastes it. A paste failure is logged and reported to the client
// as "paste_failed"; success is reported with a pasted message. The session is
// then reset to idle (only if sid is still the active session) and the idle
// state is sent to the client.
func (h *Handler) finalizeStop(s *session, sid string) {
	if cleanup := s.detachCleanup(sid); cleanup != nil {
		cleanup()
	}

	s.previewMu.Lock()
	finalText := s.previewText
	s.previewText = ""
	s.previewMu.Unlock()

	if finalText != "" && h.pasteFunc != nil {
		if err := h.pasteFunc(finalText); err != nil {
			// Keep the underlying cause in the server log; the client only
			// receives a generic error code.
			s.log.Warn("auto-paste failed", "session_id", sid, "err", err)
			h.sendError(s, "paste_failed", "auto-paste failed", true, sid)
		} else {
			_ = s.writeJSON(ServerMsg{Type: MsgPasted, SessionID: sid})
		}
	}

	s.stateMu.Lock()
	// Only reset if this session is still the active one.
	if s.sessionID == sid {
		s.state = StateIdle
		s.sessionID = ""
		s.seq = 1
		s.lastAudioSeq = 0
		s.sendAudio = nil
		s.cleanup = nil
	}
	s.stateMu.Unlock()

	// NOTE(review): the idle notification is written after stateMu is
	// released, so a start handled in between could be announced to the
	// client before this message - confirm whether clients tolerate that.
	_ = s.writeJSON(ServerMsg{Type: MsgState, State: StateIdle})
	s.log.Info("recording stopped", "session_id", sid)
}
// handlePaste pastes text at the client's explicit request.
//
// Empty text, or a Handler without a paste function, is a no-op. A paste
// failure is logged and reported to the client as "paste_failed"; success is
// reported with a pasted message carrying sid.
func (h *Handler) handlePaste(s *session, text, sid string) {
	if text == "" || h.pasteFunc == nil {
		return
	}
	if err := h.pasteFunc(text); err != nil {
		// Keep the underlying cause in the server log; the client only
		// receives a generic error code.
		s.log.Warn("paste failed", "session_id", sid, "err", err)
		h.sendError(s, "paste_failed", "paste failed", true, sid)
		return
	}
	_ = s.writeJSON(ServerMsg{Type: MsgPasted, SessionID: sid})
}
// sendError sends an error message with the given code, human-readable
// message, retryable flag and session ID to the client. A failure to write the
// message is logged and otherwise ignored.
func (h *Handler) sendError(s *session, code, message string, retryable bool, sid string) {
	report := ServerMsg{
		Type:      MsgError,
		Code:      code,
		Message:   message,
		Retryable: retryable,
		SessionID: sid,
	}
	if err := s.writeJSON(report); err != nil {
		s.log.Warn("send error message failed", "err", err)
	}
}
// detachCleanup removes the ASR callbacks from the session and returns the
// cleanup function so the caller can run it outside the lock. When sid is
// non-empty and is not the active session ID, nothing is detached and nil is
// returned. The result is also nil when no cleanup was attached.
func (s *session) detachCleanup(sid string) func() {
	s.stateMu.Lock()
	defer s.stateMu.Unlock()

	mismatch := sid != "" && sid != s.sessionID
	if mismatch {
		return nil
	}

	detached := s.cleanup
	s.cleanup, s.sendAudio = nil, nil
	return detached
}
// forceStop unconditionally tears down any active recording: it detaches and
// runs the ASR cleanup regardless of session ID, then resets the session to
// its idle defaults.
func (s *session) forceStop() {
	if release := s.detachCleanup(""); release != nil {
		release()
	}

	s.stateMu.Lock()
	defer s.stateMu.Unlock()
	s.state, s.sessionID = StateIdle, ""
	s.seq, s.lastAudioSeq = 1, 0
}