- 使用 corpus.context 参数直接传递热词列表(豆包文档支持)
- 移除 boosting_table_id 配置,避免绑定火山引擎控制台
- 实现 BuildHotwordsContext 函数,将本地热词转换为 JSON 格式
- 热词配置完全本地化,便于迁移到其他 ASR 平台
配置示例:
hotwords:
- 张三
- 李四
- VoicePaste
程序自动转换为豆包 API 要求的格式:
{"hotwords":[{"word":"张三"},{"word":"李四"},{"word":"VoicePaste"}]}
186 lines
4.5 KiB
Go
186 lines
4.5 KiB
Go
package asr
|
|
|
|
import (
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/fasthttp/websocket"
|
|
"github.com/google/uuid"
|
|
wsMsg "github.com/imbytecat/voicepaste/internal/ws"
|
|
)
|
|
|
|
// doubaoEndpoint is the Doubao streaming ASR WebSocket URL (bigmodel_async variant).
const doubaoEndpoint = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async"

// Per-operation deadlines applied to the ASR WebSocket connection.
const (
	// writeTimeout bounds each outgoing frame write.
	writeTimeout = 10 * time.Second
	// readTimeout bounds the wait for each incoming server message.
	readTimeout = 30 * time.Second
)
|
|
|
|
// Config carries the credentials and options needed to open a Doubao ASR
// session.
type Config struct {
	// AppID, AccessToken and ResourceID are sent by Dial as the
	// X-Api-App-Key, X-Api-Access-Key and X-Api-Resource-Id handshake headers.
	AppID, AccessToken, ResourceID string

	// Hotwords is the locally configured hotword list. When non-empty, Dial
	// passes it to the server inline via the request's corpus context.
	Hotwords []string
}
|
|
|
|
// Client manages a single ASR session with Doubao.
|
|
type Client struct {
|
|
cfg Config
|
|
conn *websocket.Conn
|
|
mu sync.Mutex
|
|
closed bool
|
|
closeCh chan struct{}
|
|
log *slog.Logger
|
|
}
|
|
// Dial connects to Doubao ASR and sends the initial FullClientRequest.
|
|
// resultCh receives partial/final results. Caller must call Close() when done.
|
|
func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) {
|
|
connID := uuid.New().String()
|
|
headers := http.Header{
|
|
"X-Api-App-Key": {cfg.AppID},
|
|
"X-Api-Access-Key": {cfg.AccessToken},
|
|
"X-Api-Resource-Id": {cfg.ResourceID},
|
|
"X-Api-Connect-Id": {connID},
|
|
}
|
|
dialer := websocket.Dialer{
|
|
HandshakeTimeout: 10 * time.Second,
|
|
}
|
|
conn, _, err := dialer.Dial(doubaoEndpoint, headers)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("dial doubao: %w", err)
|
|
}
|
|
c := &Client{
|
|
cfg: cfg,
|
|
conn: conn,
|
|
closeCh: make(chan struct{}),
|
|
log: slog.With("conn_id", connID),
|
|
}
|
|
// Build corpus configuration
|
|
var corpus *Corpus
|
|
if len(cfg.Hotwords) > 0 {
|
|
contextJSON, err := BuildHotwordsContext(cfg.Hotwords)
|
|
if err != nil {
|
|
slog.Warn("failed to build hotwords context, skipping", "err", err)
|
|
} else {
|
|
corpus = &Corpus{Context: contextJSON}
|
|
slog.Info("hotwords enabled", "count", len(cfg.Hotwords))
|
|
}
|
|
}
|
|
// Send FullClientRequest
|
|
req := &FullClientRequest{
|
|
User: UserMeta{UID: connID},
|
|
Audio: AudioMeta{
|
|
Format: "pcm",
|
|
Codec: "raw",
|
|
Rate: 16000,
|
|
Bits: 16,
|
|
Channel: 1,
|
|
},
|
|
Request: RequestMeta{
|
|
ModelName: "seedasr-2.0",
|
|
EnableITN: true,
|
|
EnablePUNC: true,
|
|
EnableDDC: true,
|
|
ShowUtterances: true,
|
|
ResultType: "full",
|
|
EnableNonstream: true,
|
|
EndWindowSize: 800,
|
|
Corpus: corpus,
|
|
},
|
|
}
|
|
data, err := EncodeFullClientRequest(req)
|
|
if err != nil {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("encode full request: %w", err)
|
|
}
|
|
_ = conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
|
if err := conn.WriteMessage(websocket.BinaryMessage, data); err != nil {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("send full request: %w", err)
|
|
}
|
|
// Start read loop
|
|
go c.readLoop(resultCh)
|
|
return c, nil
|
|
}
|
|
|
|
// SendAudio sends a PCM audio frame to Doubao.
|
|
func (c *Client) SendAudio(pcm []byte, last bool) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.closed {
|
|
return fmt.Errorf("client closed")
|
|
}
|
|
data, err := EncodeAudioFrame(pcm, last)
|
|
if err != nil {
|
|
return fmt.Errorf("encode audio: %w", err)
|
|
}
|
|
_ = c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
|
return c.conn.WriteMessage(websocket.BinaryMessage, data)
|
|
}
|
|
// readLoop reads server responses and forwards them to resultCh.
|
|
func (c *Client) readLoop(resultCh chan<- wsMsg.ServerMsg) {
|
|
defer func() {
|
|
c.conn.Close()
|
|
c.mu.Lock()
|
|
c.closed = true
|
|
c.mu.Unlock()
|
|
close(c.closeCh)
|
|
}()
|
|
for {
|
|
_ = c.conn.SetReadDeadline(time.Now().Add(readTimeout))
|
|
_, data, err := c.conn.ReadMessage()
|
|
if err != nil {
|
|
c.log.Debug("asr read done", "err", err)
|
|
return
|
|
}
|
|
resp, err := ParseResponse(data)
|
|
if err != nil {
|
|
c.log.Warn("parse asr response", "err", err)
|
|
continue
|
|
}
|
|
if resp.Code != 0 {
|
|
c.log.Error("asr error", "code", resp.Code, "msg", resp.ErrMsg)
|
|
resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgError, Message: resp.ErrMsg}
|
|
return
|
|
}
|
|
// bigmodel_async with enable_nonstream: returns both streaming (partial) and definite (final) results
|
|
text := resp.Text
|
|
if text != "" {
|
|
if resp.IsLast {
|
|
resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgFinal, Text: text}
|
|
} else {
|
|
// Intermediate streaming result (first pass) — preview only
|
|
resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgPartial, Text: text}
|
|
}
|
|
}
|
|
if resp.IsLast {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
// Finish sends the last audio frame and waits for ASR to return final results.
|
|
func (c *Client) Finish() {
|
|
c.mu.Lock()
|
|
if c.closed {
|
|
c.mu.Unlock()
|
|
return
|
|
}
|
|
c.mu.Unlock()
|
|
_ = c.SendAudio(nil, true)
|
|
<-c.closeCh
|
|
}
|
|
|
|
// Close forcefully shuts down the ASR connection.
|
|
func (c *Client) Close() {
|
|
c.mu.Lock()
|
|
if c.closed {
|
|
c.mu.Unlock()
|
|
return
|
|
}
|
|
c.conn.Close()
|
|
c.mu.Unlock()
|
|
<-c.closeCh
|
|
} |