feat: 添加豆包 ASR 热词功能支持
- 在 config.yaml 中添加 hotwords 配置项,支持本地管理热词列表
- 实现热词解析、格式化和表名生成工具(internal/asr/hotwords.go)
- 在 ASR 连接建立时自动将热词发送给豆包(boosting_table_name 参数)
- 支持热词权重配置(1-10,默认 4),格式:"词|权重" 或 "词"
- 支持配置热重载,修改热词后新连接自动生效
- 为未来动态热词功能预留扩展接口
热词格式示例:
hotwords:
- 张三|8
- VoicePaste|10
- 人工智能|6
This commit is contained in:
@@ -6,6 +6,12 @@ doubao:
|
|||||||
app_id: "" # env: DOUBAO_APP_ID
|
app_id: "" # env: DOUBAO_APP_ID
|
||||||
access_token: "" # env: DOUBAO_ACCESS_TOKEN
|
access_token: "" # env: DOUBAO_ACCESS_TOKEN
|
||||||
resource_id: "volc.seedasr.sauc.duration" # env: DOUBAO_RESOURCE_ID
|
resource_id: "volc.seedasr.sauc.duration" # env: DOUBAO_RESOURCE_ID
|
||||||
|
hotwords: # 可选:热词列表,格式 "词|权重" 或 "词"(默认权重 4)
|
||||||
|
# - 张三|8
|
||||||
|
# - 李四|8
|
||||||
|
# - VoicePaste|10
|
||||||
|
# - 人工智能|6
|
||||||
|
# - 测试
|
||||||
|
|
||||||
# 服务配置
|
# 服务配置
|
||||||
server:
|
server:
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ const (
|
|||||||
// Config carries the Doubao credentials and the optional hotword list that Dial receives.
type Config struct {
|
type Config struct {
|
||||||
AppID string
|
AppID string
|
||||||
AccessToken string
|
AccessToken string
|
||||||
ResourceID string
|
ResourceID string
|
||||||
|
Hotwords []string // Hotword list; each item is "word|weight" or just "word".
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client manages a single ASR session with Doubao.
|
// Client manages a single ASR session with Doubao.
|
||||||
@@ -57,6 +58,21 @@ func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) {
|
|||||||
closeCh: make(chan struct{}),
|
closeCh: make(chan struct{}),
|
||||||
log: slog.With("conn_id", connID),
|
log: slog.With("conn_id", connID),
|
||||||
}
|
}
|
||||||
|
// Parse hotwords configuration
|
||||||
|
var boostingTableName string
|
||||||
|
if len(cfg.Hotwords) > 0 {
|
||||||
|
entries, err := ParseHotwords(cfg.Hotwords)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("invalid hotwords config, skipping", "err", err)
|
||||||
|
} else {
|
||||||
|
boostingTableName = GenerateTableName(entries)
|
||||||
|
tableContent := FormatHotwordsTable(entries)
|
||||||
|
slog.Info("hotwords enabled",
|
||||||
|
"count", len(entries),
|
||||||
|
"table_name", boostingTableName,
|
||||||
|
"content", tableContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
// Send FullClientRequest
|
// Send FullClientRequest
|
||||||
req := &FullClientRequest{
|
req := &FullClientRequest{
|
||||||
User: UserMeta{UID: connID},
|
User: UserMeta{UID: connID},
|
||||||
@@ -68,14 +84,15 @@ func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) {
|
|||||||
Channel: 1,
|
Channel: 1,
|
||||||
},
|
},
|
||||||
Request: RequestMeta{
|
Request: RequestMeta{
|
||||||
ModelName: "seedasr-2.0",
|
ModelName: "seedasr-2.0",
|
||||||
EnableITN: true,
|
EnableITN: true,
|
||||||
EnablePUNC: true,
|
EnablePUNC: true,
|
||||||
EnableDDC: true,
|
EnableDDC: true,
|
||||||
ShowUtterances: true,
|
ShowUtterances: true,
|
||||||
ResultType: "full",
|
ResultType: "full",
|
||||||
EnableNonstream: true,
|
EnableNonstream: true,
|
||||||
EndWindowSize: 800,
|
EndWindowSize: 800,
|
||||||
|
BoostingTableName: boostingTableName,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
data, err := EncodeFullClientRequest(req)
|
data, err := EncodeFullClientRequest(req)
|
||||||
|
|||||||
107
internal/asr/hotwords.go
Normal file
107
internal/asr/hotwords.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package asr
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Bounds and defaults applied when validating hotwords.
const (
	// MinHotwordWeight is the smallest weight a hotword may carry.
	MinHotwordWeight = 1
	// MaxHotwordWeight is the largest weight a hotword may carry.
	MaxHotwordWeight = 10
	// DefaultHotwordWeight is the weight used when a hotword specifies none.
	DefaultHotwordWeight = 4

	// MaxHotwordLength caps the number of characters in a single hotword.
	MaxHotwordLength = 10
)
|
||||||
|
|
||||||
|
// HotwordEntry pairs one hotword with the weight assigned to it.
type HotwordEntry struct {
	// Word is the hotword text.
	Word string
	// Weight is the hotword's weight in the range 1-10; 4 is the default.
	Weight int
}
|
||||||
|
|
||||||
|
// ParseHotwords parses raw hotword strings from config.
|
||||||
|
// Format: "word|weight" or "word" (default weight 4).
|
||||||
|
// Returns error if any hotword is invalid.
|
||||||
|
func ParseHotwords(raw []string) ([]HotwordEntry, error) {
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]HotwordEntry, 0, len(raw))
|
||||||
|
for i, line := range raw {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(line, "|")
|
||||||
|
word := strings.TrimSpace(parts[0])
|
||||||
|
|
||||||
|
// Validate word length
|
||||||
|
if utf8.RuneCountInString(word) > MaxHotwordLength {
|
||||||
|
return nil, fmt.Errorf("hotword %d: exceeds %d characters: %q", i+1, MaxHotwordLength, word)
|
||||||
|
}
|
||||||
|
if word == "" {
|
||||||
|
return nil, fmt.Errorf("hotword %d: empty word", i+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse weight
|
||||||
|
weight := DefaultHotwordWeight
|
||||||
|
if len(parts) > 1 {
|
||||||
|
weightStr := strings.TrimSpace(parts[1])
|
||||||
|
if weightStr != "" {
|
||||||
|
w, err := strconv.Atoi(weightStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("hotword %d: invalid weight %q: %w", i+1, weightStr, err)
|
||||||
|
}
|
||||||
|
if w < MinHotwordWeight || w > MaxHotwordWeight {
|
||||||
|
return nil, fmt.Errorf("hotword %d: weight %d out of range [%d, %d]", i+1, w, MinHotwordWeight, MaxHotwordWeight)
|
||||||
|
}
|
||||||
|
weight = w
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
entries = append(entries, HotwordEntry{
|
||||||
|
Word: word,
|
||||||
|
Weight: weight,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return entries, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatHotwordsTable generates Doubao-compatible hotword table content.
|
||||||
|
// Format: each line "word|weight\n".
|
||||||
|
func FormatHotwordsTable(entries []HotwordEntry) string {
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, e := range entries {
|
||||||
|
sb.WriteString(e.Word)
|
||||||
|
sb.WriteString("|")
|
||||||
|
sb.WriteString(strconv.Itoa(e.Weight))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateTableName generates a unique table name based on hotword content.
|
||||||
|
// Uses SHA256 hash of the formatted table content (first 16 hex chars).
|
||||||
|
func GenerateTableName(entries []HotwordEntry) string {
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
content := FormatHotwordsTable(entries)
|
||||||
|
hash := sha256.Sum256([]byte(content))
|
||||||
|
return "voicepaste_" + hex.EncodeToString(hash[:])[:16]
|
||||||
|
}
|
||||||
@@ -111,7 +111,9 @@ type RequestMeta struct {
|
|||||||
ShowUtterances bool `json:"show_utterances"`
|
ShowUtterances bool `json:"show_utterances"`
|
||||||
ResultType string `json:"result_type,omitempty"`
|
ResultType string `json:"result_type,omitempty"`
|
||||||
EnableNonstream bool `json:"enable_nonstream,omitempty"`
|
EnableNonstream bool `json:"enable_nonstream,omitempty"`
|
||||||
EndWindowSize int `json:"end_window_size,omitempty"`
|
EndWindowSize int `json:"end_window_size,omitempty"`
|
||||||
|
BoostingTableID string `json:"boosting_table_id,omitempty"` // 热词表 ID
|
||||||
|
BoostingTableName string `json:"boosting_table_name,omitempty"` // 热词表名称
|
||||||
}
|
}
|
||||||
// EncodeFullClientRequest builds the binary message for the initial handshake.
|
// EncodeFullClientRequest builds the binary message for the initial handshake.
|
||||||
// nostream mode: header(4) + payload_size(4) + gzip(json)
|
// nostream mode: header(4) + payload_size(4) + gzip(json)
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ import (
|
|||||||
// DoubaoConfig is the YAML-mapped Doubao section: app ID, access token, resource ID and optional hotwords.
type DoubaoConfig struct {
|
type DoubaoConfig struct {
|
||||||
AppID string `yaml:"app_id"`
|
AppID string `yaml:"app_id"`
|
||||||
AccessToken string `yaml:"access_token"`
|
AccessToken string `yaml:"access_token"`
|
||||||
ResourceID string `yaml:"resource_id"`
|
ResourceID string `yaml:"resource_id"`
|
||||||
|
Hotwords []string `yaml:"hotwords"` // Hotword list; each item is "word|weight" or just "word".
|
||||||
}
|
}
|
||||||
|
|
||||||
// SecurityConfig holds authentication settings.
|
// SecurityConfig holds authentication settings.
|
||||||
|
|||||||
@@ -133,6 +133,11 @@ func (h *Handler) handleStart(s *session) {
|
|||||||
if s.active {
|
if s.active {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Future extension: support runtime dynamic hotwords (Phase 2)
|
||||||
|
// if len(msg.Hotwords) > 0 {
|
||||||
|
// // Priority: runtime hotwords > config.yaml hotwords
|
||||||
|
// // Need to modify asrFactory signature to pass msg.Hotwords
|
||||||
|
// }
|
||||||
s.previewMu.Lock()
|
s.previewMu.Lock()
|
||||||
s.previewText = ""
|
s.previewText = ""
|
||||||
s.previewMu.Unlock()
|
s.previewMu.Unlock()
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ const (
|
|||||||
// ClientMsg is a JSON message tagged by Type; Text is carried only by "paste" messages.
// NOTE(review): presumably the client-to-server direction, judging by the name — confirm.
type ClientMsg struct {
|
type ClientMsg struct {
|
||||||
Type MsgType `json:"type"`
|
Type MsgType `json:"type"`
|
||||||
Text string `json:"text,omitempty"` // Only for "paste"
|
Text string `json:"text,omitempty"` // Only for "paste"
|
||||||
|
// Future extension: dynamic hotwords (Phase 2)
|
||||||
|
// Hotwords []string `json:"hotwords,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Server → Client messages ──
|
// ── Server → Client messages ──
|
||||||
|
|||||||
1
main.go
1
main.go
@@ -147,6 +147,7 @@ func buildASRFactory(cfg config.Config) func(chan<- ws.ServerMsg) (func([]byte),
|
|||||||
AppID: cfg.Doubao.AppID,
|
AppID: cfg.Doubao.AppID,
|
||||||
AccessToken: cfg.Doubao.AccessToken,
|
AccessToken: cfg.Doubao.AccessToken,
|
||||||
ResourceID: cfg.Doubao.ResourceID,
|
ResourceID: cfg.Doubao.ResourceID,
|
||||||
|
Hotwords: cfg.Doubao.Hotwords,
|
||||||
}
|
}
|
||||||
return func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) {
|
return func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) {
|
||||||
client, err := asr.Dial(asrCfg, resultCh)
|
client, err := asr.Dial(asrCfg, resultCh)
|
||||||
|
|||||||
Reference in New Issue
Block a user