package ws import ( "encoding/json" "fmt" "log/slog" "sync" "time" "github.com/gofiber/contrib/v3/websocket" "github.com/gofiber/fiber/v3" "github.com/google/uuid" "github.com/vmihailenco/msgpack/v5" ) const wsReadTimeout = 75 * time.Second type session struct { conn *websocket.Conn log *slog.Logger resultCh chan ServerMsg stateMu sync.Mutex writeMu sync.Mutex previewMu sync.Mutex state string sessionID string seq int64 lastAudioSeq uint64 previewText string sendAudio func([]byte) cleanup func() } type PasteFunc func(text string) error type ASRFactory func(resultCh chan<- ServerMsg) (sendAudio func(pcm []byte), cleanup func(), err error) type Handler struct { token string pasteFunc PasteFunc asrFactory ASRFactory } func NewHandler(token string, pasteFn PasteFunc, asrFn ASRFactory) *Handler { return &Handler{ token: token, pasteFunc: pasteFn, asrFactory: asrFn, } } func (h *Handler) Register(app *fiber.App) { app.Use("/ws", func(c fiber.Ctx) error { if !websocket.IsWebSocketUpgrade(c) { return fiber.ErrUpgradeRequired } if h.token != "" { q := c.Query("token") if q != h.token { return c.Status(fiber.StatusUnauthorized).SendString("invalid token") } } return c.Next() }) app.Get("/ws", websocket.New(h.handleConn)) } func (h *Handler) handleConn(c *websocket.Conn) { sess := &session{ conn: c, log: slog.With("remote", c.RemoteAddr().String()), resultCh: make(chan ServerMsg, 64), state: StateIdle, sessionID: "", seq: 1, lastAudioSeq: 0, } sess.log.Info("ws connected") defer sess.log.Info("ws disconnected") defer close(sess.resultCh) defer sess.forceStop() var wg sync.WaitGroup wg.Add(1) go sess.writerLoop(&wg) defer wg.Wait() _ = sess.writeJSON(ServerMsg{Type: MsgReady, Message: "ok"}) _ = sess.writeJSON(ServerMsg{Type: MsgState, State: StateIdle}) for { _ = c.SetReadDeadline(time.Now().Add(wsReadTimeout)) mt, data, err := c.ReadMessage() if err != nil { break } switch mt { case websocket.BinaryMessage: sess.handleAudioFrame(data) case websocket.TextMessage: h.handleTextMessage(sess, data) } } } func (s *session) writerLoop(wg *sync.WaitGroup) { defer wg.Done() for msg := range s.resultCh { if msg.Type == MsgPartial || msg.Type == MsgFinal { s.previewMu.Lock() s.previewText = msg.Text s.previewMu.Unlock() s.stateMu.Lock() msg.SessionID = s.sessionID msg.Seq = s.seq s.seq++ s.stateMu.Unlock() } if err := s.writeJSON(msg); err != nil { s.log.Warn("ws write error", "err", err) return } } } func (s *session) writeJSON(msg ServerMsg) error { s.writeMu.Lock() defer s.writeMu.Unlock() return s.conn.WriteMessage(websocket.TextMessage, msg.Bytes()) } func (s *session) handleAudioFrame(data []byte) { packet, err := decodeAudioPacket(data) if err != nil { s.log.Warn("invalid audio packet", "err", err) return } s.stateMu.Lock() if s.state != StateRecording || s.sendAudio == nil { s.stateMu.Unlock() return } if packet.SessionID != s.sessionID { s.stateMu.Unlock() return } if packet.Seq <= s.lastAudioSeq { s.stateMu.Unlock() return } lastSeq := s.lastAudioSeq if lastSeq > 0 && packet.Seq > lastSeq+1 { s.log.Warn( "audio seq gap", "session_id", s.sessionID, "last", lastSeq, "curr", packet.Seq, ) } s.lastAudioSeq = packet.Seq sendAudio := s.sendAudio s.stateMu.Unlock() sendAudio(packet.PCM) } type audioPacket struct { Version int `msgpack:"v"` SessionID string `msgpack:"sessionId"` Seq uint64 `msgpack:"seq"` PCM []byte `msgpack:"pcm"` } func decodeAudioPacket(data []byte) (audioPacket, error) { var packet audioPacket if err := msgpack.Unmarshal(data, &packet); err != nil { return audioPacket{}, fmt.Errorf("msgpack decode failed: %w", err) } if packet.Version != 1 { return audioPacket{}, fmt.Errorf("unsupported version: %d", packet.Version) } if packet.SessionID == "" || len(packet.SessionID) > 96 { return audioPacket{}, fmt.Errorf("invalid session id") } if packet.Seq == 0 { return audioPacket{}, fmt.Errorf("invalid seq") } if len(packet.PCM) == 0 || len(packet.PCM)%2 != 0 { return audioPacket{}, fmt.Errorf("invalid pcm size") } return packet, nil } func (h *Handler) handleTextMessage(s *session, data []byte) { var msg ClientMsg if err := json.Unmarshal(data, &msg); err != nil { s.log.Warn("invalid json", "err", err) return } switch msg.Type { case MsgHello: _ = s.writeJSON(ServerMsg{Type: MsgReady, Message: "ok"}) s.stateMu.Lock() state := s.state sid := s.sessionID s.stateMu.Unlock() _ = s.writeJSON(ServerMsg{Type: MsgState, State: state, SessionID: sid}) case MsgPing: _ = s.writeJSON(ServerMsg{Type: MsgPong, TS: msg.TS}) case MsgStart: h.handleStart(s, msg) case MsgStop: h.handleStop(s, msg) case MsgPaste: h.handlePaste(s, msg.Text, msg.SessionID) default: h.sendError(s, "bad_message", "unsupported message type", true, "") } } func (h *Handler) handleStart(s *session, msg ClientMsg) { if msg.SessionID == "" { msg.SessionID = uuid.NewString() } if len(msg.SessionID) > 96 { h.sendError(s, "invalid_session", "invalid sessionId", false, "") return } s.stateMu.Lock() if s.state != StateIdle { current := s.sessionID s.stateMu.Unlock() h.sendError(s, "busy", "session is not idle", true, current) return } s.state = StateRecording s.sessionID = msg.SessionID s.seq = 1 s.lastAudioSeq = 0 s.sendAudio = nil s.cleanup = nil s.stateMu.Unlock() s.previewMu.Lock() s.previewText = "" s.previewMu.Unlock() sa, cl, err := h.asrFactory(s.resultCh) if err != nil { s.stateMu.Lock() s.state = StateIdle s.sessionID = "" s.sendAudio = nil s.cleanup = nil s.stateMu.Unlock() h.sendError(s, "start_failed", "ASR start failed", true, "") _ = s.writeJSON(ServerMsg{Type: MsgState, State: StateIdle}) return } s.stateMu.Lock() if s.sessionID != msg.SessionID || s.state != StateRecording { s.stateMu.Unlock() cl() return } s.sendAudio = sa s.cleanup = cl s.stateMu.Unlock() _ = s.writeJSON(ServerMsg{Type: MsgStartAck, SessionID: msg.SessionID}) _ = s.writeJSON(ServerMsg{Type: MsgState, State: StateRecording, SessionID: msg.SessionID}) s.log.Info("recording started", "session_id", msg.SessionID) } func (h *Handler) handleStop(s *session, msg ClientMsg) { s.stateMu.Lock() if s.state == StateIdle { s.stateMu.Unlock() return } if s.state == StateStopping { sid := s.sessionID s.stateMu.Unlock() _ = s.writeJSON(ServerMsg{Type: MsgStopAck, SessionID: sid}) return } if msg.SessionID != "" && msg.SessionID != s.sessionID { current := s.sessionID s.stateMu.Unlock() h.sendError(s, "session_mismatch", "stop sessionId mismatch", false, current) return } s.state = StateStopping sid := s.sessionID s.stateMu.Unlock() _ = s.writeJSON(ServerMsg{Type: MsgStopAck, SessionID: sid}) _ = s.writeJSON(ServerMsg{Type: MsgState, State: StateStopping, SessionID: sid}) go h.finalizeStop(s, sid) } func (h *Handler) finalizeStop(s *session, sid string) { cleanup := s.detachCleanup(sid) if cleanup != nil { cleanup() } s.previewMu.Lock() finalText := s.previewText s.previewText = "" s.previewMu.Unlock() if finalText != "" && h.pasteFunc != nil { if err := h.pasteFunc(finalText); err != nil { h.sendError(s, "paste_failed", "auto-paste failed", true, sid) } else { _ = s.writeJSON(ServerMsg{Type: MsgPasted, SessionID: sid}) } } s.stateMu.Lock() if s.sessionID == sid { s.state = StateIdle s.sessionID = "" s.seq = 1 s.lastAudioSeq = 0 s.sendAudio = nil s.cleanup = nil } s.stateMu.Unlock() _ = s.writeJSON(ServerMsg{Type: MsgState, State: StateIdle}) s.log.Info("recording stopped", "session_id", sid) } func (h *Handler) handlePaste(s *session, text, sid string) { if text == "" { return } if h.pasteFunc != nil { if err := h.pasteFunc(text); err != nil { h.sendError(s, "paste_failed", "paste failed", true, sid) } else { _ = s.writeJSON(ServerMsg{Type: MsgPasted, SessionID: sid}) } } } func (h *Handler) sendError(s *session, code, message string, retryable bool, sid string) { err := s.writeJSON(ServerMsg{ Type: MsgError, Code: code, Message: message, Retryable: retryable, SessionID: sid, }) if err != nil { s.log.Warn("send error message failed", "err", err) } } func (s *session) detachCleanup(sid string) func() { s.stateMu.Lock() defer s.stateMu.Unlock() if sid != "" && s.sessionID != sid { return nil } cleanup := s.cleanup s.cleanup = nil s.sendAudio = nil return cleanup } func (s *session) forceStop() { cleanup := s.detachCleanup("") if cleanup != nil { cleanup() } s.stateMu.Lock() s.state = StateIdle s.sessionID = "" s.seq = 1 s.lastAudioSeq = 0 s.stateMu.Unlock() }