Files
voicepaste/internal/asr/client.go
imbytecat 96d685fdf2 feat: 添加豆包 ASR 热词功能支持
- 在 config.yaml 中添加 hotwords 配置项,支持本地管理热词列表
- 实现热词解析、格式化和表名生成工具(internal/asr/hotwords.go)
- 在 ASR 连接建立时自动将热词发送给豆包(boosting_table_name 参数)
- 支持热词权重配置(1-10,默认 4),格式:"词|权重" 或 "词"
- 支持配置热重载,修改热词后新连接自动生效
- 为未来动态热词功能预留扩展接口

热词格式示例:
  hotwords:
    - 张三|8
    - VoicePaste|10
    - 人工智能|6
2026-03-02 00:55:37 +08:00

190 lines
4.7 KiB
Go

package asr
import (
"fmt"
"log/slog"
"net/http"
"sync"
"time"
"github.com/fasthttp/websocket"
"github.com/google/uuid"
wsMsg "github.com/imbytecat/voicepaste/internal/ws"
)
// doubaoEndpoint is the Doubao streaming ASR websocket URL (bigmodel_async flavour).
const doubaoEndpoint = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async"

// Socket deadlines applied to individual websocket operations.
const (
	// writeTimeout bounds each write: the initial request and every audio frame.
	writeTimeout = 10 * time.Second
	// readTimeout bounds the wait for each server message in readLoop.
	readTimeout = 30 * time.Second
)
// Config carries the credentials and recognition options used to open a
// Doubao ASR session.
type Config struct {
	// Credentials; Dial sends them as the X-Api-App-Key, X-Api-Access-Key and
	// X-Api-Resource-Id handshake headers respectively.
	AppID       string
	AccessToken string
	ResourceID  string

	// Hotwords is the hotword list; each entry is "word|weight" or just "word".
	Hotwords []string
}
// Client is one live ASR session with Doubao over a websocket connection.
// Create it with Dial; release it with Finish or Close.
type Client struct {
	cfg  Config
	conn *websocket.Conn

	// mu serialises SendAudio writes on conn and guards closed.
	mu     sync.Mutex
	closed bool // set by readLoop when it exits

	// closeCh is closed by readLoop on exit; Finish and Close block on it.
	closeCh chan struct{}
	log     *slog.Logger // session logger tagged with conn_id
}
// Dial opens a websocket session to the Doubao streaming ASR endpoint, sends
// the initial FullClientRequest (audio format plus recognition options), and
// starts a background goroutine that forwards recognition results to resultCh.
//
// Hotwords in cfg are resolved before dialing. If they fail to parse, they are
// skipped with a warning rather than failing the connection. The caller must
// call Close (or Finish) on the returned Client when done.
//
// NOTE(review): only the generated boosting table name is put in the request;
// the formatted table content is logged but not transmitted here. Confirm the
// table is provisioned server-side under that name, or hotwords have no effect.
func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) {
	connID := uuid.New().String()
	logger := slog.With("conn_id", connID)

	// Resolve hotwords before opening the socket, and log through the
	// connection-scoped logger so the entries carry conn_id.
	boostingTableName := resolveBoostingTable(cfg.Hotwords, logger)

	headers := http.Header{
		"X-Api-App-Key":     {cfg.AppID},
		"X-Api-Access-Key":  {cfg.AccessToken},
		"X-Api-Resource-Id": {cfg.ResourceID},
		"X-Api-Connect-Id":  {connID},
	}
	dialer := websocket.Dialer{
		HandshakeTimeout: 10 * time.Second,
	}
	conn, resp, err := dialer.Dial(doubaoEndpoint, headers)
	if err != nil {
		// On a rejected handshake the dialer still returns the HTTP response;
		// its status code is the most useful diagnostic (e.g. an auth rejection).
		if resp != nil {
			return nil, fmt.Errorf("dial doubao (http %d): %w", resp.StatusCode, err)
		}
		return nil, fmt.Errorf("dial doubao: %w", err)
	}

	// The session is configured once, up front, by the FullClientRequest.
	req := &FullClientRequest{
		User: UserMeta{UID: connID},
		Audio: AudioMeta{
			Format:  "pcm",
			Codec:   "raw",
			Rate:    16000,
			Bits:    16,
			Channel: 1,
		},
		Request: RequestMeta{
			ModelName:         "seedasr-2.0",
			EnableITN:         true,
			EnablePUNC:        true,
			EnableDDC:         true,
			ShowUtterances:    true,
			ResultType:        "full",
			EnableNonstream:   true,
			EndWindowSize:     800,
			BoostingTableName: boostingTableName,
		},
	}
	data, err := EncodeFullClientRequest(req)
	if err != nil {
		conn.Close()
		return nil, fmt.Errorf("encode full request: %w", err)
	}
	_ = conn.SetWriteDeadline(time.Now().Add(writeTimeout))
	if err := conn.WriteMessage(websocket.BinaryMessage, data); err != nil {
		conn.Close()
		return nil, fmt.Errorf("send full request: %w", err)
	}

	c := &Client{
		cfg:     cfg,
		conn:    conn,
		closeCh: make(chan struct{}),
		log:     logger,
	}
	go c.readLoop(resultCh)
	return c, nil
}

// resolveBoostingTable parses raw hotword entries and returns the boosting table
// name to send to Doubao. It returns "" when there are no hotwords or they fail
// to parse; a parse failure is logged, not returned.
func resolveBoostingTable(raw []string, logger *slog.Logger) string {
	if len(raw) == 0 {
		return ""
	}
	entries, err := ParseHotwords(raw)
	if err != nil {
		logger.Warn("invalid hotwords config, skipping", "err", err)
		return ""
	}
	name := GenerateTableName(entries)
	logger.Info("hotwords enabled",
		"count", len(entries),
		"table_name", name,
		"content", FormatHotwordsTable(entries))
	return name
}
// SendAudio encodes one PCM chunk as an audio frame and writes it to Doubao.
// Set last to mark the final frame of the utterance. It returns an error if
// the client has already closed, if encoding fails, or if the write fails.
// Writes are serialised by c.mu.
func (c *Client) SendAudio(pcm []byte, last bool) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return fmt.Errorf("client closed")
	}

	frame, encErr := EncodeAudioFrame(pcm, last)
	if encErr != nil {
		return fmt.Errorf("encode audio: %w", encErr)
	}

	deadline := time.Now().Add(writeTimeout)
	_ = c.conn.SetWriteDeadline(deadline)
	return c.conn.WriteMessage(websocket.BinaryMessage, frame)
}
// readLoop consumes server messages and forwards recognition text to resultCh.
//
// It stops after:
//   - a response flagged IsLast,
//   - a response with a non-zero Code (forwarded as MsgError), or
//   - any read error, including hitting readTimeout.
//
// Responses that fail to parse are logged and skipped. On exit it closes the
// socket, marks the client closed, and closes closeCh to release
// Finish/Close.
//
// NOTE(review): sends on resultCh are blocking. If the consumer stops draining
// it, this goroutine stalls, and so do Finish/Close, which wait on closeCh.
func (c *Client) readLoop(resultCh chan<- wsMsg.ServerMsg) {
	defer func() {
		c.conn.Close()
		c.mu.Lock()
		c.closed = true
		c.mu.Unlock()
		close(c.closeCh)
	}()

	for {
		_ = c.conn.SetReadDeadline(time.Now().Add(readTimeout))
		_, payload, readErr := c.conn.ReadMessage()
		if readErr != nil {
			c.log.Debug("asr read done", "err", readErr)
			return
		}

		resp, parseErr := ParseResponse(payload)
		if parseErr != nil {
			c.log.Warn("parse asr response", "err", parseErr)
			continue
		}

		if resp.Code != 0 {
			c.log.Error("asr error", "code", resp.Code, "msg", resp.ErrMsg)
			resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgError, Message: resp.ErrMsg}
			return
		}

		// With enable_nonstream, bigmodel_async yields both streaming (partial)
		// and definite (final) results; IsLast distinguishes them.
		switch text := resp.Text; {
		case text == "":
			// Nothing to forward for an empty result.
		case resp.IsLast:
			resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgFinal, Text: text}
		default:
			// First-pass streaming result, intended as a preview only.
			resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgPartial, Text: text}
		}

		if resp.IsLast {
			return
		}
	}
}
// Finish signals end-of-audio to Doubao by sending an empty final frame. It
// then blocks until readLoop exits, which happens after the final result,
// a server error, or a read failure/timeout.
//
// If the final frame cannot be sent (including when the client is already
// closed), no final result can arrive. In that case the socket is closed so
// readLoop exits promptly instead of waiting out readTimeout.
func (c *Client) Finish() {
	// SendAudio already checks c.closed under the lock, so no pre-check is needed.
	if err := c.SendAudio(nil, true); err != nil {
		c.log.Debug("send last audio frame", "err", err)
		// Closing the socket unblocks readLoop's pending read. A repeat close
		// is harmless.
		_ = c.conn.Close()
	}
	<-c.closeCh
}
// Close tears down the ASR connection immediately, without sending a final
// frame, and waits for readLoop to exit. Calling it on a client that has
// already closed returns at once.
func (c *Client) Close() {
	c.mu.Lock()
	alreadyClosed := c.closed
	if !alreadyClosed {
		// Closing the socket makes readLoop's pending read fail, ending the loop.
		c.conn.Close()
	}
	c.mu.Unlock()

	if alreadyClosed {
		return
	}
	<-c.closeCh
}