From b786d9f90bf3446cde4be310425376938efe0e9d Mon Sep 17 00:00:00 2001 From: imbytecat Date: Mon, 2 Mar 2026 01:36:14 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E6=9C=AC=E5=9C=B0?= =?UTF-8?q?=E7=83=AD=E8=AF=8D=E7=AE=A1=E7=90=86=EF=BC=8C=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E5=B9=B3=E5=8F=B0=E7=BB=91=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 使用 corpus.context 参数直接传递热词列表(豆包文档支持) - 移除 boosting_table_id 配置,避免绑定火山引擎控制台 - 实现 BuildHotwordsContext 函数,将本地热词转换为 JSON 格式 - 热词配置完全本地化,便于迁移到其他 ASR 平台 配置示例: hotwords: - 张三 - 李四 - VoicePaste 程序自动转换为豆包 API 要求的格式: {"hotwords":[{"word":"张三"},{"word":"李四"},{"word":"VoicePaste"}]} --- config.example.yaml | 6 +++++- internal/asr/client.go | 37 +++++++++++++++++++++------------ internal/asr/hotwords.go | 43 +++++++++++++++++++++++++++++++++++++++ internal/asr/protocol.go | 23 +++++++++++++-------- internal/config/config.go | 8 ++++---- main.go | 8 ++++---- 6 files changed, 94 insertions(+), 31 deletions(-) create mode 100644 internal/asr/hotwords.go diff --git a/config.example.yaml b/config.example.yaml index d24b55a..e8005f1 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -6,7 +6,11 @@ doubao: app_id: "" # env: DOUBAO_APP_ID access_token: "" # env: DOUBAO_ACCESS_TOKEN resource_id: "volc.seedasr.sauc.duration" # env: DOUBAO_RESOURCE_ID - boosting_table_id: "" # 可选:热词表 ID(从控制台 https://console.volcengine.com/speech/hotword 创建) + hotwords: # 可选:本地热词列表 + # - 张三 + # - 李四 + # - VoicePaste + # - 人工智能 # 服务配置 server: diff --git a/internal/asr/client.go b/internal/asr/client.go index 6bdad1b..f3b65c8 100644 --- a/internal/asr/client.go +++ b/internal/asr/client.go @@ -20,10 +20,10 @@ const ( // Config holds Doubao ASR connection parameters. type Config struct { - AppID string - AccessToken string - ResourceID string - BoostingTableID string // 热词表 ID(从控制台创建) + AppID string + AccessToken string + ResourceID string + Hotwords []string // 本地热词列表 } // Client manages a single ASR session with Doubao. @@ -58,6 +58,17 @@ func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) { closeCh: make(chan struct{}), log: slog.With("conn_id", connID), } + // Build corpus configuration + var corpus *Corpus + if len(cfg.Hotwords) > 0 { + contextJSON, err := BuildHotwordsContext(cfg.Hotwords) + if err != nil { + slog.Warn("failed to build hotwords context, skipping", "err", err) + } else { + corpus = &Corpus{Context: contextJSON} + slog.Info("hotwords enabled", "count", len(cfg.Hotwords)) + } + } // Send FullClientRequest req := &FullClientRequest{ User: UserMeta{UID: connID}, @@ -69,15 +80,15 @@ func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) { Channel: 1, }, Request: RequestMeta{ - ModelName: "seedasr-2.0", - EnableITN: true, - EnablePUNC: true, - EnableDDC: true, - ShowUtterances: true, - ResultType: "full", - EnableNonstream: true, - EndWindowSize: 800, - BoostingTableID: cfg.BoostingTableID, + ModelName: "seedasr-2.0", + EnableITN: true, + EnablePUNC: true, + EnableDDC: true, + ShowUtterances: true, + ResultType: "full", + EnableNonstream: true, + EndWindowSize: 800, + Corpus: corpus, }, } data, err := EncodeFullClientRequest(req) diff --git a/internal/asr/hotwords.go b/internal/asr/hotwords.go new file mode 100644 index 0000000..4613e12 --- /dev/null +++ b/internal/asr/hotwords.go @@ -0,0 +1,43 @@ +package asr + +import ( + "encoding/json" + "fmt" +) + +// HotwordEntry represents a single hotword for context JSON. +type HotwordEntry struct { + Word string `json:"word"` +} + +// HotwordsContext represents the context JSON structure for hotwords. +type HotwordsContext struct { + Hotwords []HotwordEntry `json:"hotwords"` +} + +// BuildHotwordsContext converts a list of hotword strings to context JSON string. +// Returns empty string if hotwords list is empty. +func BuildHotwordsContext(hotwords []string) (string, error) { + if len(hotwords) == 0 { + return "", nil + } + + entries := make([]HotwordEntry, 0, len(hotwords)) + for _, word := range hotwords { + if word != "" { + entries = append(entries, HotwordEntry{Word: word}) + } + } + + if len(entries) == 0 { + return "", nil + } + + ctx := HotwordsContext{Hotwords: entries} + data, err := json.Marshal(ctx) + if err != nil { + return "", fmt.Errorf("marshal hotwords context: %w", err) + } + + return string(data), nil +} diff --git a/internal/asr/protocol.go b/internal/asr/protocol.go index 12a5b71..4984048 100644 --- a/internal/asr/protocol.go +++ b/internal/asr/protocol.go @@ -103,16 +103,21 @@ type AudioMeta struct { Channel int `json:"channel"` } +// Corpus holds hotwords and context configuration. +type Corpus struct { + Context string `json:"context,omitempty"` // 热词直传 JSON +} + type RequestMeta struct { - ModelName string `json:"model_name"` - EnableITN bool `json:"enable_itn"` - EnablePUNC bool `json:"enable_punc"` - EnableDDC bool `json:"enable_ddc"` - ShowUtterances bool `json:"show_utterances"` - ResultType string `json:"result_type,omitempty"` - EnableNonstream bool `json:"enable_nonstream,omitempty"` - EndWindowSize int `json:"end_window_size,omitempty"` - BoostingTableID string `json:"boosting_table_id,omitempty"` // 热词表 ID + ModelName string `json:"model_name"` + EnableITN bool `json:"enable_itn"` + EnablePUNC bool `json:"enable_punc"` + EnableDDC bool `json:"enable_ddc"` + ShowUtterances bool `json:"show_utterances"` + ResultType string `json:"result_type,omitempty"` + EnableNonstream bool `json:"enable_nonstream,omitempty"` + EndWindowSize int `json:"end_window_size,omitempty"` + Corpus *Corpus `json:"corpus,omitempty"` // 语料/热词配置 } // EncodeFullClientRequest builds the binary message for the initial handshake. // nostream mode: header(4) + payload_size(4) + gzip(json) diff --git a/internal/config/config.go b/internal/config/config.go index 573e912..10ab115 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,10 +6,10 @@ import ( // DoubaoConfig holds 火山引擎豆包 ASR credentials. type DoubaoConfig struct { - AppID string `yaml:"app_id"` - AccessToken string `yaml:"access_token"` - ResourceID string `yaml:"resource_id"` - BoostingTableID string `yaml:"boosting_table_id"` // 热词表 ID(从控制台创建) + AppID string `yaml:"app_id"` + AccessToken string `yaml:"access_token"` + ResourceID string `yaml:"resource_id"` + Hotwords []string `yaml:"hotwords"` // 本地热词列表 } // SecurityConfig holds authentication settings. diff --git a/main.go b/main.go index b03cc37..4e6f936 100644 --- a/main.go +++ b/main.go @@ -144,10 +144,10 @@ func createServer(cfg config.Config, lanIP string, tlsResult *vpTLS.Result) *ser func buildASRFactory(cfg config.Config) func(chan<- ws.ServerMsg) (func([]byte), func(), error) { asrCfg := asr.Config{ - AppID: cfg.Doubao.AppID, - AccessToken: cfg.Doubao.AccessToken, - ResourceID: cfg.Doubao.ResourceID, - BoostingTableID: cfg.Doubao.BoostingTableID, + AppID: cfg.Doubao.AppID, + AccessToken: cfg.Doubao.AccessToken, + ResourceID: cfg.Doubao.ResourceID, + Hotwords: cfg.Doubao.Hotwords, } return func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) { client, err := asr.Dial(asrCfg, resultCh)