refactor: 简化热词配置为豆包控制台 ID
- 移除本地热词列表配置,改为直接使用豆包控制台的热词表 ID - 删除 internal/asr/hotwords.go(不再需要本地解析) - 简化 client.go 逻辑,直接传递 boosting_table_id - 移除 protocol.go 中的 boosting_table_name 字段 - 更新配置示例,添加控制台链接说明 使用方法: 1. 在豆包控制台创建热词表:https://console.volcengine.com/speech/hotword 2. 复制热词表 ID 到 config.yaml 的 boosting_table_id 字段
This commit is contained in:
@@ -6,12 +6,7 @@ doubao:
|
|||||||
app_id: "" # env: DOUBAO_APP_ID
|
app_id: "" # env: DOUBAO_APP_ID
|
||||||
access_token: "" # env: DOUBAO_ACCESS_TOKEN
|
access_token: "" # env: DOUBAO_ACCESS_TOKEN
|
||||||
resource_id: "volc.seedasr.sauc.duration" # env: DOUBAO_RESOURCE_ID
|
resource_id: "volc.seedasr.sauc.duration" # env: DOUBAO_RESOURCE_ID
|
||||||
hotwords: # 可选:热词列表,格式 "词|权重" 或 "词"(默认权重 4)
|
boosting_table_id: "" # 可选:热词表 ID(从控制台 https://console.volcengine.com/speech/hotword 创建)
|
||||||
# - 张三|8
|
|
||||||
# - 李四|8
|
|
||||||
# - VoicePaste|10
|
|
||||||
# - 人工智能|6
|
|
||||||
# - 测试
|
|
||||||
|
|
||||||
# 服务配置
|
# 服务配置
|
||||||
server:
|
server:
|
||||||
|
|||||||
@@ -20,10 +20,10 @@ const (
|
|||||||
|
|
||||||
// Config holds Doubao ASR connection parameters.
|
// Config holds Doubao ASR connection parameters.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
AppID string
|
AppID string
|
||||||
AccessToken string
|
AccessToken string
|
||||||
ResourceID string
|
ResourceID string
|
||||||
Hotwords []string // 热词列表,格式 "词|权重" 或 "词"
|
BoostingTableID string // 热词表 ID(从控制台创建)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client manages a single ASR session with Doubao.
|
// Client manages a single ASR session with Doubao.
|
||||||
@@ -58,21 +58,6 @@ func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) {
|
|||||||
closeCh: make(chan struct{}),
|
closeCh: make(chan struct{}),
|
||||||
log: slog.With("conn_id", connID),
|
log: slog.With("conn_id", connID),
|
||||||
}
|
}
|
||||||
// Parse hotwords configuration
|
|
||||||
var boostingTableName string
|
|
||||||
if len(cfg.Hotwords) > 0 {
|
|
||||||
entries, err := ParseHotwords(cfg.Hotwords)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("invalid hotwords config, skipping", "err", err)
|
|
||||||
} else {
|
|
||||||
boostingTableName = GenerateTableName(entries)
|
|
||||||
tableContent := FormatHotwordsTable(entries)
|
|
||||||
slog.Info("hotwords enabled",
|
|
||||||
"count", len(entries),
|
|
||||||
"table_name", boostingTableName,
|
|
||||||
"content", tableContent)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Send FullClientRequest
|
// Send FullClientRequest
|
||||||
req := &FullClientRequest{
|
req := &FullClientRequest{
|
||||||
User: UserMeta{UID: connID},
|
User: UserMeta{UID: connID},
|
||||||
@@ -84,15 +69,15 @@ func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) {
|
|||||||
Channel: 1,
|
Channel: 1,
|
||||||
},
|
},
|
||||||
Request: RequestMeta{
|
Request: RequestMeta{
|
||||||
ModelName: "seedasr-2.0",
|
ModelName: "seedasr-2.0",
|
||||||
EnableITN: true,
|
EnableITN: true,
|
||||||
EnablePUNC: true,
|
EnablePUNC: true,
|
||||||
EnableDDC: true,
|
EnableDDC: true,
|
||||||
ShowUtterances: true,
|
ShowUtterances: true,
|
||||||
ResultType: "full",
|
ResultType: "full",
|
||||||
EnableNonstream: true,
|
EnableNonstream: true,
|
||||||
EndWindowSize: 800,
|
EndWindowSize: 800,
|
||||||
BoostingTableName: boostingTableName,
|
BoostingTableID: cfg.BoostingTableID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
data, err := EncodeFullClientRequest(req)
|
data, err := EncodeFullClientRequest(req)
|
||||||
|
|||||||
@@ -1,107 +0,0 @@
|
|||||||
package asr
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"unicode/utf8"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// DefaultHotwordWeight is the default weight when not specified.
|
|
||||||
DefaultHotwordWeight = 4
|
|
||||||
// MaxHotwordWeight is the maximum allowed weight.
|
|
||||||
MaxHotwordWeight = 10
|
|
||||||
// MinHotwordWeight is the minimum allowed weight.
|
|
||||||
MinHotwordWeight = 1
|
|
||||||
// MaxHotwordLength is the maximum character count per hotword.
|
|
||||||
MaxHotwordLength = 10
|
|
||||||
)
|
|
||||||
|
|
||||||
// HotwordEntry represents a single hotword with its weight.
|
|
||||||
type HotwordEntry struct {
|
|
||||||
Word string
|
|
||||||
Weight int // 1-10, default 4
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseHotwords parses raw hotword strings from config.
|
|
||||||
// Format: "word|weight" or "word" (default weight 4).
|
|
||||||
// Returns error if any hotword is invalid.
|
|
||||||
func ParseHotwords(raw []string) ([]HotwordEntry, error) {
|
|
||||||
if len(raw) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
entries := make([]HotwordEntry, 0, len(raw))
|
|
||||||
for i, line := range raw {
|
|
||||||
line = strings.TrimSpace(line)
|
|
||||||
if line == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.Split(line, "|")
|
|
||||||
word := strings.TrimSpace(parts[0])
|
|
||||||
|
|
||||||
// Validate word length
|
|
||||||
if utf8.RuneCountInString(word) > MaxHotwordLength {
|
|
||||||
return nil, fmt.Errorf("hotword %d: exceeds %d characters: %q", i+1, MaxHotwordLength, word)
|
|
||||||
}
|
|
||||||
if word == "" {
|
|
||||||
return nil, fmt.Errorf("hotword %d: empty word", i+1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse weight
|
|
||||||
weight := DefaultHotwordWeight
|
|
||||||
if len(parts) > 1 {
|
|
||||||
weightStr := strings.TrimSpace(parts[1])
|
|
||||||
if weightStr != "" {
|
|
||||||
w, err := strconv.Atoi(weightStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("hotword %d: invalid weight %q: %w", i+1, weightStr, err)
|
|
||||||
}
|
|
||||||
if w < MinHotwordWeight || w > MaxHotwordWeight {
|
|
||||||
return nil, fmt.Errorf("hotword %d: weight %d out of range [%d, %d]", i+1, w, MinHotwordWeight, MaxHotwordWeight)
|
|
||||||
}
|
|
||||||
weight = w
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
entries = append(entries, HotwordEntry{
|
|
||||||
Word: word,
|
|
||||||
Weight: weight,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return entries, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FormatHotwordsTable generates Doubao-compatible hotword table content.
|
|
||||||
// Format: each line "word|weight\n".
|
|
||||||
func FormatHotwordsTable(entries []HotwordEntry) string {
|
|
||||||
if len(entries) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
var sb strings.Builder
|
|
||||||
for _, e := range entries {
|
|
||||||
sb.WriteString(e.Word)
|
|
||||||
sb.WriteString("|")
|
|
||||||
sb.WriteString(strconv.Itoa(e.Weight))
|
|
||||||
sb.WriteString("\n")
|
|
||||||
}
|
|
||||||
return sb.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateTableName generates a unique table name based on hotword content.
|
|
||||||
// Uses SHA256 hash of the formatted table content (first 16 hex chars).
|
|
||||||
func GenerateTableName(entries []HotwordEntry) string {
|
|
||||||
if len(entries) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
content := FormatHotwordsTable(entries)
|
|
||||||
hash := sha256.Sum256([]byte(content))
|
|
||||||
return "voicepaste_" + hex.EncodeToString(hash[:])[:16]
|
|
||||||
}
|
|
||||||
@@ -111,9 +111,8 @@ type RequestMeta struct {
|
|||||||
ShowUtterances bool `json:"show_utterances"`
|
ShowUtterances bool `json:"show_utterances"`
|
||||||
ResultType string `json:"result_type,omitempty"`
|
ResultType string `json:"result_type,omitempty"`
|
||||||
EnableNonstream bool `json:"enable_nonstream,omitempty"`
|
EnableNonstream bool `json:"enable_nonstream,omitempty"`
|
||||||
EndWindowSize int `json:"end_window_size,omitempty"`
|
EndWindowSize int `json:"end_window_size,omitempty"`
|
||||||
BoostingTableID string `json:"boosting_table_id,omitempty"` // 热词表 ID
|
BoostingTableID string `json:"boosting_table_id,omitempty"` // 热词表 ID
|
||||||
BoostingTableName string `json:"boosting_table_name,omitempty"` // 热词表名称
|
|
||||||
}
|
}
|
||||||
// EncodeFullClientRequest builds the binary message for the initial handshake.
|
// EncodeFullClientRequest builds the binary message for the initial handshake.
|
||||||
// nostream mode: header(4) + payload_size(4) + gzip(json)
|
// nostream mode: header(4) + payload_size(4) + gzip(json)
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ import (
|
|||||||
|
|
||||||
// DoubaoConfig holds 火山引擎豆包 ASR credentials.
|
// DoubaoConfig holds 火山引擎豆包 ASR credentials.
|
||||||
type DoubaoConfig struct {
|
type DoubaoConfig struct {
|
||||||
AppID string `yaml:"app_id"`
|
AppID string `yaml:"app_id"`
|
||||||
AccessToken string `yaml:"access_token"`
|
AccessToken string `yaml:"access_token"`
|
||||||
ResourceID string `yaml:"resource_id"`
|
ResourceID string `yaml:"resource_id"`
|
||||||
Hotwords []string `yaml:"hotwords"` // 热词列表,格式 "词|权重" 或 "词"
|
BoostingTableID string `yaml:"boosting_table_id"` // 热词表 ID(从控制台创建)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SecurityConfig holds authentication settings.
|
// SecurityConfig holds authentication settings.
|
||||||
|
|||||||
8
main.go
8
main.go
@@ -144,10 +144,10 @@ func createServer(cfg config.Config, lanIP string, tlsResult *vpTLS.Result) *ser
|
|||||||
|
|
||||||
func buildASRFactory(cfg config.Config) func(chan<- ws.ServerMsg) (func([]byte), func(), error) {
|
func buildASRFactory(cfg config.Config) func(chan<- ws.ServerMsg) (func([]byte), func(), error) {
|
||||||
asrCfg := asr.Config{
|
asrCfg := asr.Config{
|
||||||
AppID: cfg.Doubao.AppID,
|
AppID: cfg.Doubao.AppID,
|
||||||
AccessToken: cfg.Doubao.AccessToken,
|
AccessToken: cfg.Doubao.AccessToken,
|
||||||
ResourceID: cfg.Doubao.ResourceID,
|
ResourceID: cfg.Doubao.ResourceID,
|
||||||
Hotwords: cfg.Doubao.Hotwords,
|
BoostingTableID: cfg.Doubao.BoostingTableID,
|
||||||
}
|
}
|
||||||
return func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) {
|
return func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) {
|
||||||
client, err := asr.Dial(asrCfg, resultCh)
|
client, err := asr.Dial(asrCfg, resultCh)
|
||||||
|
|||||||
Reference in New Issue
Block a user