- 在 config.yaml 中添加 hotwords 配置项,支持本地管理热词列表
- 实现热词解析、格式化和表名生成工具(internal/asr/hotwords.go)
- 在 ASR 连接建立时自动将热词发送给豆包(boosting_table_name 参数)
- 支持热词权重配置(1-10,默认 4),格式:"词|权重" 或 "词"
- 支持配置热重载,修改热词后新连接自动生效
- 为未来动态热词功能预留扩展接口
热词格式示例:
hotwords:
- 张三|8
- VoicePaste|10
- 人工智能|6
193 lines
4.8 KiB
Go
193 lines
4.8 KiB
Go
package ws
|
|
|
|
import (
|
|
"encoding/json"
|
|
"log/slog"
|
|
"sync"
|
|
|
|
"github.com/gofiber/contrib/v3/websocket"
|
|
"github.com/gofiber/fiber/v3"
|
|
)
|
|
|
|
// session holds the state for a single WebSocket connection.
|
|
type session struct {
|
|
conn *websocket.Conn
|
|
log *slog.Logger
|
|
resultCh chan ServerMsg
|
|
previewMu sync.Mutex
|
|
previewText string
|
|
sendAudio func([]byte)
|
|
cleanup func()
|
|
active bool
|
|
}
|
|
// PasteFunc pastes text into the currently focused application. The handler
// invokes it when a recording stops with recognized text and when a client
// sends an explicit paste request. A non-nil error reports that the paste
// did not succeed.
type PasteFunc func(text string) error
|
|
|
|
// ASRFactory creates an ASR session. It returns a channel that receives
|
|
// partial/final results, and a function to send audio frames.
|
|
// The cleanup function must be called when the session ends.
|
|
type ASRFactory func(resultCh chan<- ServerMsg) (sendAudio func(pcm []byte), cleanup func(), err error)
|
|
|
|
// Handler holds dependencies for WebSocket connections.
|
|
type Handler struct {
|
|
token string
|
|
pasteFunc PasteFunc
|
|
asrFactory ASRFactory
|
|
}
|
|
|
|
// NewHandler creates a WS handler with the given dependencies.
|
|
func NewHandler(token string, pasteFn PasteFunc, asrFn ASRFactory) *Handler {
|
|
return &Handler{
|
|
token: token,
|
|
pasteFunc: pasteFn,
|
|
asrFactory: asrFn,
|
|
}
|
|
}
|
|
|
|
// Register adds the /ws route to the Fiber app.
|
|
func (h *Handler) Register(app *fiber.App) {
|
|
// Token check middleware (before upgrade)
|
|
app.Use("/ws", func(c fiber.Ctx) error {
|
|
if !websocket.IsWebSocketUpgrade(c) {
|
|
return fiber.ErrUpgradeRequired
|
|
}
|
|
if h.token != "" {
|
|
q := c.Query("token")
|
|
if q != h.token {
|
|
return c.Status(fiber.StatusUnauthorized).SendString("invalid token")
|
|
}
|
|
}
|
|
return c.Next()
|
|
})
|
|
|
|
app.Get("/ws", websocket.New(h.handleConn))
|
|
}
|
|
|
|
func (h *Handler) handleConn(c *websocket.Conn) {
|
|
sess := &session{
|
|
conn: c,
|
|
log: slog.With("remote", c.RemoteAddr().String()),
|
|
resultCh: make(chan ServerMsg, 32),
|
|
}
|
|
sess.log.Info("ws connected")
|
|
defer sess.log.Info("ws disconnected")
|
|
defer close(sess.resultCh)
|
|
defer sess.cleanupASR()
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go sess.writerLoop(&wg)
|
|
defer wg.Wait()
|
|
for {
|
|
mt, data, err := c.ReadMessage()
|
|
if err != nil {
|
|
break
|
|
}
|
|
if mt == websocket.BinaryMessage {
|
|
sess.handleAudioFrame(data)
|
|
} else if mt == websocket.TextMessage {
|
|
h.handleTextMessage(sess, data)
|
|
}
|
|
}
|
|
}
|
|
func (s *session) writerLoop(wg *sync.WaitGroup) {
|
|
defer wg.Done()
|
|
for msg := range s.resultCh {
|
|
if msg.Type == MsgPartial || msg.Type == MsgFinal {
|
|
s.previewMu.Lock()
|
|
s.previewText = msg.Text
|
|
preview := ServerMsg{Type: msg.Type, Text: s.previewText}
|
|
s.previewMu.Unlock()
|
|
if err := s.conn.WriteMessage(websocket.TextMessage, preview.Bytes()); err != nil {
|
|
s.log.Warn("ws write error", "err", err)
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
if err := s.conn.WriteMessage(websocket.TextMessage, msg.Bytes()); err != nil {
|
|
s.log.Warn("ws write error", "err", err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
func (s *session) handleAudioFrame(data []byte) {
|
|
if s.active && s.sendAudio != nil {
|
|
s.sendAudio(data)
|
|
}
|
|
}
|
|
func (h *Handler) handleTextMessage(s *session, data []byte) {
|
|
var msg ClientMsg
|
|
if err := json.Unmarshal(data, &msg); err != nil {
|
|
s.log.Warn("invalid json", "err", err)
|
|
return
|
|
}
|
|
switch msg.Type {
|
|
case MsgStart:
|
|
h.handleStart(s)
|
|
case MsgStop:
|
|
h.handleStop(s)
|
|
case MsgPaste:
|
|
h.handlePaste(s, msg.Text)
|
|
}
|
|
}
|
|
func (h *Handler) handleStart(s *session) {
|
|
if s.active {
|
|
return
|
|
}
|
|
// Future extension: support runtime dynamic hotwords (Phase 2)
|
|
// if len(msg.Hotwords) > 0 {
|
|
// // Priority: runtime hotwords > config.yaml hotwords
|
|
// // Need to modify asrFactory signature to pass msg.Hotwords
|
|
// }
|
|
s.previewMu.Lock()
|
|
s.previewText = ""
|
|
s.previewMu.Unlock()
|
|
sa, cl, err := h.asrFactory(s.resultCh)
|
|
if err != nil {
|
|
s.log.Error("asr start failed", "err", err)
|
|
s.resultCh <- ServerMsg{Type: MsgError, Message: "ASR start failed"}
|
|
return
|
|
}
|
|
s.sendAudio = sa
|
|
s.cleanup = cl
|
|
s.active = true
|
|
s.log.Info("recording started")
|
|
}
|
|
func (h *Handler) handleStop(s *session) {
|
|
if !s.active {
|
|
return
|
|
}
|
|
s.cleanupASR()
|
|
s.sendAudio = nil
|
|
s.active = false
|
|
s.previewMu.Lock()
|
|
finalText := s.previewText
|
|
s.previewText = ""
|
|
s.previewMu.Unlock()
|
|
if finalText != "" && h.pasteFunc != nil {
|
|
if err := h.pasteFunc(finalText); err != nil {
|
|
s.log.Error("auto-paste failed", "err", err)
|
|
} else {
|
|
s.resultCh <- ServerMsg{Type: MsgPasted}
|
|
}
|
|
}
|
|
s.log.Info("recording stopped")
|
|
}
|
|
func (h *Handler) handlePaste(s *session, text string) {
|
|
if text == "" {
|
|
return
|
|
}
|
|
if h.pasteFunc != nil {
|
|
if err := h.pasteFunc(text); err != nil {
|
|
s.log.Error("paste failed", "err", err)
|
|
s.resultCh <- ServerMsg{Type: MsgError, Message: "paste failed"}
|
|
} else {
|
|
s.resultCh <- ServerMsg{Type: MsgPasted}
|
|
}
|
|
}
|
|
}
|
|
func (s *session) cleanupASR() {
|
|
if s.cleanup != nil {
|
|
s.cleanup()
|
|
s.cleanup = nil
|
|
}
|
|
} |