feat: 实现本地热词管理,移除平台绑定
- 使用 corpus.context 参数直接传递热词列表(豆包文档支持)
- 移除 boosting_table_id 配置,避免绑定火山引擎控制台
- 实现 BuildHotwordsContext 函数,将本地热词转换为 JSON 格式
- 热词配置完全本地化,便于迁移到其他 ASR 平台
配置示例:
hotwords:
- 张三
- 李四
- VoicePaste
程序自动转换为豆包 API 要求的格式:
{"hotwords":[{"word":"张三"},{"word":"李四"},{"word":"VoicePaste"}]}
This commit is contained in:
@@ -6,7 +6,11 @@ doubao:
|
|||||||
app_id: "" # env: DOUBAO_APP_ID
|
app_id: "" # env: DOUBAO_APP_ID
|
||||||
access_token: "" # env: DOUBAO_ACCESS_TOKEN
|
access_token: "" # env: DOUBAO_ACCESS_TOKEN
|
||||||
resource_id: "volc.seedasr.sauc.duration" # env: DOUBAO_RESOURCE_ID
|
resource_id: "volc.seedasr.sauc.duration" # env: DOUBAO_RESOURCE_ID
|
||||||
boosting_table_id: "" # 可选:热词表 ID(从控制台 https://console.volcengine.com/speech/hotword 创建)
|
hotwords: # 可选:本地热词列表
|
||||||
|
# - 张三
|
||||||
|
# - 李四
|
||||||
|
# - VoicePaste
|
||||||
|
# - 人工智能
|
||||||
|
|
||||||
# 服务配置
|
# 服务配置
|
||||||
server:
|
server:
|
||||||
|
|||||||
@@ -20,10 +20,10 @@ const (
|
|||||||
|
|
||||||
// Config holds Doubao ASR connection parameters.
|
// Config holds Doubao ASR connection parameters.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
AppID string
|
AppID string
|
||||||
AccessToken string
|
AccessToken string
|
||||||
ResourceID string
|
ResourceID string
|
||||||
BoostingTableID string // 热词表 ID(从控制台创建)
|
Hotwords []string // 本地热词列表
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client manages a single ASR session with Doubao.
|
// Client manages a single ASR session with Doubao.
|
||||||
@@ -58,6 +58,17 @@ func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) {
|
|||||||
closeCh: make(chan struct{}),
|
closeCh: make(chan struct{}),
|
||||||
log: slog.With("conn_id", connID),
|
log: slog.With("conn_id", connID),
|
||||||
}
|
}
|
||||||
|
// Build corpus configuration
|
||||||
|
var corpus *Corpus
|
||||||
|
if len(cfg.Hotwords) > 0 {
|
||||||
|
contextJSON, err := BuildHotwordsContext(cfg.Hotwords)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to build hotwords context, skipping", "err", err)
|
||||||
|
} else {
|
||||||
|
corpus = &Corpus{Context: contextJSON}
|
||||||
|
slog.Info("hotwords enabled", "count", len(cfg.Hotwords))
|
||||||
|
}
|
||||||
|
}
|
||||||
// Send FullClientRequest
|
// Send FullClientRequest
|
||||||
req := &FullClientRequest{
|
req := &FullClientRequest{
|
||||||
User: UserMeta{UID: connID},
|
User: UserMeta{UID: connID},
|
||||||
@@ -69,15 +80,15 @@ func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) {
|
|||||||
Channel: 1,
|
Channel: 1,
|
||||||
},
|
},
|
||||||
Request: RequestMeta{
|
Request: RequestMeta{
|
||||||
ModelName: "seedasr-2.0",
|
ModelName: "seedasr-2.0",
|
||||||
EnableITN: true,
|
EnableITN: true,
|
||||||
EnablePUNC: true,
|
EnablePUNC: true,
|
||||||
EnableDDC: true,
|
EnableDDC: true,
|
||||||
ShowUtterances: true,
|
ShowUtterances: true,
|
||||||
ResultType: "full",
|
ResultType: "full",
|
||||||
EnableNonstream: true,
|
EnableNonstream: true,
|
||||||
EndWindowSize: 800,
|
EndWindowSize: 800,
|
||||||
BoostingTableID: cfg.BoostingTableID,
|
Corpus: corpus,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
data, err := EncodeFullClientRequest(req)
|
data, err := EncodeFullClientRequest(req)
|
||||||
|
|||||||
43
internal/asr/hotwords.go
Normal file
43
internal/asr/hotwords.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package asr
|
||||||
|
|
||||||
|
import (
	"encoding/json"
	"fmt"
	"strings"
)
|
||||||
|
|
||||||
|
// HotwordEntry is one element of the "hotwords" array in the context JSON.
// It serializes as {"word":"<text>"}.
type HotwordEntry struct {
	Word string `json:"word"`
}

// HotwordsContext is the top-level object of the context JSON.
// It serializes as {"hotwords":[{"word":"..."}, ...]}.
type HotwordsContext struct {
	Hotwords []HotwordEntry `json:"hotwords"`
}

// BuildHotwordsContext converts a list of hotword strings into the context
// JSON string, for example:
//
//	{"hotwords":[{"word":"张三"},{"word":"李四"},{"word":"VoicePaste"}]}
//
// Each word is trimmed of surrounding whitespace. Words that are empty after
// trimming and words that repeat an earlier entry are dropped; the order of
// first occurrence is preserved.
//
// It returns "" with a nil error when no usable hotword remains (nil or empty
// input, or only blank entries). A non-nil error is returned only when JSON
// encoding fails.
func BuildHotwordsContext(hotwords []string) (string, error) {
	if len(hotwords) == 0 {
		return "", nil
	}

	entries := make([]HotwordEntry, 0, len(hotwords))
	seen := make(map[string]struct{}, len(hotwords))
	for _, word := range hotwords {
		// Hand-edited config may contain padded or whitespace-only entries.
		word = strings.TrimSpace(word)
		if word == "" {
			continue
		}
		// Sending the same hotword twice only grows the request payload.
		if _, dup := seen[word]; dup {
			continue
		}
		seen[word] = struct{}{}
		entries = append(entries, HotwordEntry{Word: word})
	}

	if len(entries) == 0 {
		return "", nil
	}

	data, err := json.Marshal(HotwordsContext{Hotwords: entries})
	if err != nil {
		return "", fmt.Errorf("marshal hotwords context: %w", err)
	}

	return string(data), nil
}
|
||||||
@@ -103,16 +103,21 @@ type AudioMeta struct {
|
|||||||
Channel int `json:"channel"`
|
Channel int `json:"channel"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Corpus holds hotwords and context configuration.
// It is attached to the request under the "corpus" key.
|
||||||
|
type Corpus struct {
|
||||||
|
// Context is the hotword list passed inline as a JSON string; the key is
// omitted from the encoded request when the string is empty.
Context string `json:"context,omitempty"` // inline hotwords JSON
|
||||||
|
}
|
||||||
|
|
||||||
type RequestMeta struct {
|
type RequestMeta struct {
|
||||||
ModelName string `json:"model_name"`
|
ModelName string `json:"model_name"`
|
||||||
EnableITN bool `json:"enable_itn"`
|
EnableITN bool `json:"enable_itn"`
|
||||||
EnablePUNC bool `json:"enable_punc"`
|
EnablePUNC bool `json:"enable_punc"`
|
||||||
EnableDDC bool `json:"enable_ddc"`
|
EnableDDC bool `json:"enable_ddc"`
|
||||||
ShowUtterances bool `json:"show_utterances"`
|
ShowUtterances bool `json:"show_utterances"`
|
||||||
ResultType string `json:"result_type,omitempty"`
|
ResultType string `json:"result_type,omitempty"`
|
||||||
EnableNonstream bool `json:"enable_nonstream,omitempty"`
|
EnableNonstream bool `json:"enable_nonstream,omitempty"`
|
||||||
EndWindowSize int `json:"end_window_size,omitempty"`
|
EndWindowSize int `json:"end_window_size,omitempty"`
|
||||||
BoostingTableID string `json:"boosting_table_id,omitempty"` // 热词表 ID
|
Corpus *Corpus `json:"corpus,omitempty"` // 语料/热词配置
|
||||||
}
|
}
|
||||||
// EncodeFullClientRequest builds the binary message for the initial handshake.
|
// EncodeFullClientRequest builds the binary message for the initial handshake.
|
||||||
// nostream mode: header(4) + payload_size(4) + gzip(json)
|
// nostream mode: header(4) + payload_size(4) + gzip(json)
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ import (
|
|||||||
|
|
||||||
// DoubaoConfig holds 火山引擎豆包 ASR credentials.
|
// DoubaoConfig holds 火山引擎豆包 ASR credentials.
|
||||||
type DoubaoConfig struct {
|
type DoubaoConfig struct {
|
||||||
AppID string `yaml:"app_id"`
|
AppID string `yaml:"app_id"`
|
||||||
AccessToken string `yaml:"access_token"`
|
AccessToken string `yaml:"access_token"`
|
||||||
ResourceID string `yaml:"resource_id"`
|
ResourceID string `yaml:"resource_id"`
|
||||||
BoostingTableID string `yaml:"boosting_table_id"` // 热词表 ID(从控制台创建)
|
Hotwords []string `yaml:"hotwords"` // 本地热词列表
|
||||||
}
|
}
|
||||||
|
|
||||||
// SecurityConfig holds authentication settings.
|
// SecurityConfig holds authentication settings.
|
||||||
|
|||||||
8
main.go
8
main.go
@@ -144,10 +144,10 @@ func createServer(cfg config.Config, lanIP string, tlsResult *vpTLS.Result) *ser
|
|||||||
|
|
||||||
func buildASRFactory(cfg config.Config) func(chan<- ws.ServerMsg) (func([]byte), func(), error) {
|
func buildASRFactory(cfg config.Config) func(chan<- ws.ServerMsg) (func([]byte), func(), error) {
|
||||||
asrCfg := asr.Config{
|
asrCfg := asr.Config{
|
||||||
AppID: cfg.Doubao.AppID,
|
AppID: cfg.Doubao.AppID,
|
||||||
AccessToken: cfg.Doubao.AccessToken,
|
AccessToken: cfg.Doubao.AccessToken,
|
||||||
ResourceID: cfg.Doubao.ResourceID,
|
ResourceID: cfg.Doubao.ResourceID,
|
||||||
BoostingTableID: cfg.Doubao.BoostingTableID,
|
Hotwords: cfg.Doubao.Hotwords,
|
||||||
}
|
}
|
||||||
return func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) {
|
return func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) {
|
||||||
client, err := asr.Dial(asrCfg, resultCh)
|
client, err := asr.Dial(asrCfg, resultCh)
|
||||||
|
|||||||
Reference in New Issue
Block a user