From 4120d6451e4e473422ab3cfc4abe4ad9fe6f2d10 Mon Sep 17 00:00:00 2001 From: imbytecat Date: Mon, 2 Mar 2026 01:16:34 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E7=AE=80=E5=8C=96=E7=83=AD?= =?UTF-8?q?=E8=AF=8D=E9=85=8D=E7=BD=AE=E4=B8=BA=E8=B1=86=E5=8C=85=E6=8E=A7?= =?UTF-8?q?=E5=88=B6=E5=8F=B0=20ID?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除本地热词列表配置,改为直接使用豆包控制台的热词表 ID - 删除 internal/asr/hotwords.go(不再需要本地解析) - 简化 client.go 逻辑,直接传递 boosting_table_id - 移除 protocol.go 中的 boosting_table_name 字段 - 更新配置示例,添加控制台链接说明 使用方法: 1. 在豆包控制台创建热词表:https://console.volcengine.com/speech/hotword 2. 复制热词表 ID 到 config.yaml 的 boosting_table_id 字段 --- config.example.yaml | 7 +-- internal/asr/client.go | 41 +++++---------- internal/asr/hotwords.go | 107 -------------------------------------- internal/asr/protocol.go | 5 +- internal/config/config.go | 8 +-- main.go | 8 +-- 6 files changed, 24 insertions(+), 152 deletions(-) delete mode 100644 internal/asr/hotwords.go diff --git a/config.example.yaml b/config.example.yaml index 2a690da..d24b55a 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -6,12 +6,7 @@ doubao: app_id: "" # env: DOUBAO_APP_ID access_token: "" # env: DOUBAO_ACCESS_TOKEN resource_id: "volc.seedasr.sauc.duration" # env: DOUBAO_RESOURCE_ID - hotwords: # 可选:热词列表,格式 "词|权重" 或 "词"(默认权重 4) - # - 张三|8 - # - 李四|8 - # - VoicePaste|10 - # - 人工智能|6 - # - 测试 + boosting_table_id: "" # 可选:热词表 ID(从控制台 https://console.volcengine.com/speech/hotword 创建) # 服务配置 server: diff --git a/internal/asr/client.go b/internal/asr/client.go index 7f06db5..6bdad1b 100644 --- a/internal/asr/client.go +++ b/internal/asr/client.go @@ -20,10 +20,10 @@ const ( // Config holds Doubao ASR connection parameters. type Config struct { - AppID string - AccessToken string - ResourceID string - Hotwords []string // 热词列表,格式 "词|权重" 或 "词" + AppID string + AccessToken string + ResourceID string + BoostingTableID string // 热词表 ID(从控制台创建) } // Client manages a single ASR session with Doubao. @@ -58,21 +58,6 @@ func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) { closeCh: make(chan struct{}), log: slog.With("conn_id", connID), } - // Parse hotwords configuration - var boostingTableName string - if len(cfg.Hotwords) > 0 { - entries, err := ParseHotwords(cfg.Hotwords) - if err != nil { - slog.Warn("invalid hotwords config, skipping", "err", err) - } else { - boostingTableName = GenerateTableName(entries) - tableContent := FormatHotwordsTable(entries) - slog.Info("hotwords enabled", - "count", len(entries), - "table_name", boostingTableName, - "content", tableContent) - } - } // Send FullClientRequest req := &FullClientRequest{ User: UserMeta{UID: connID}, @@ -84,15 +69,15 @@ func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) { Channel: 1, }, Request: RequestMeta{ - ModelName: "seedasr-2.0", - EnableITN: true, - EnablePUNC: true, - EnableDDC: true, - ShowUtterances: true, - ResultType: "full", - EnableNonstream: true, - EndWindowSize: 800, - BoostingTableName: boostingTableName, + ModelName: "seedasr-2.0", + EnableITN: true, + EnablePUNC: true, + EnableDDC: true, + ShowUtterances: true, + ResultType: "full", + EnableNonstream: true, + EndWindowSize: 800, + BoostingTableID: cfg.BoostingTableID, }, } data, err := EncodeFullClientRequest(req) diff --git a/internal/asr/hotwords.go b/internal/asr/hotwords.go deleted file mode 100644 index 751c457..0000000 --- a/internal/asr/hotwords.go +++ /dev/null @@ -1,107 +0,0 @@ -package asr - -import ( - "crypto/sha256" - "encoding/hex" - "fmt" - "strconv" - "strings" - "unicode/utf8" -) - -const ( - // DefaultHotwordWeight is the default weight when not specified. - DefaultHotwordWeight = 4 - // MaxHotwordWeight is the maximum allowed weight. - MaxHotwordWeight = 10 - // MinHotwordWeight is the minimum allowed weight. - MinHotwordWeight = 1 - // MaxHotwordLength is the maximum character count per hotword. - MaxHotwordLength = 10 -) - -// HotwordEntry represents a single hotword with its weight. -type HotwordEntry struct { - Word string - Weight int // 1-10, default 4 -} - -// ParseHotwords parses raw hotword strings from config. -// Format: "word|weight" or "word" (default weight 4). -// Returns error if any hotword is invalid. -func ParseHotwords(raw []string) ([]HotwordEntry, error) { - if len(raw) == 0 { - return nil, nil - } - - entries := make([]HotwordEntry, 0, len(raw)) - for i, line := range raw { - line = strings.TrimSpace(line) - if line == "" { - continue - } - - parts := strings.Split(line, "|") - word := strings.TrimSpace(parts[0]) - - // Validate word length - if utf8.RuneCountInString(word) > MaxHotwordLength { - return nil, fmt.Errorf("hotword %d: exceeds %d characters: %q", i+1, MaxHotwordLength, word) - } - if word == "" { - return nil, fmt.Errorf("hotword %d: empty word", i+1) - } - - // Parse weight - weight := DefaultHotwordWeight - if len(parts) > 1 { - weightStr := strings.TrimSpace(parts[1]) - if weightStr != "" { - w, err := strconv.Atoi(weightStr) - if err != nil { - return nil, fmt.Errorf("hotword %d: invalid weight %q: %w", i+1, weightStr, err) - } - if w < MinHotwordWeight || w > MaxHotwordWeight { - return nil, fmt.Errorf("hotword %d: weight %d out of range [%d, %d]", i+1, w, MinHotwordWeight, MaxHotwordWeight) - } - weight = w - } - } - - entries = append(entries, HotwordEntry{ - Word: word, - Weight: weight, - }) - } - - return entries, nil -} - -// FormatHotwordsTable generates Doubao-compatible hotword table content. -// Format: each line "word|weight\n". -func FormatHotwordsTable(entries []HotwordEntry) string { - if len(entries) == 0 { - return "" - } - - var sb strings.Builder - for _, e := range entries { - sb.WriteString(e.Word) - sb.WriteString("|") - sb.WriteString(strconv.Itoa(e.Weight)) - sb.WriteString("\n") - } - return sb.String() -} - -// GenerateTableName generates a unique table name based on hotword content. -// Uses SHA256 hash of the formatted table content (first 16 hex chars). -func GenerateTableName(entries []HotwordEntry) string { - if len(entries) == 0 { - return "" - } - - content := FormatHotwordsTable(entries) - hash := sha256.Sum256([]byte(content)) - return "voicepaste_" + hex.EncodeToString(hash[:])[:16] -} diff --git a/internal/asr/protocol.go b/internal/asr/protocol.go index 86c2cbc..12a5b71 100644 --- a/internal/asr/protocol.go +++ b/internal/asr/protocol.go @@ -111,9 +111,8 @@ type RequestMeta struct { ShowUtterances bool `json:"show_utterances"` ResultType string `json:"result_type,omitempty"` EnableNonstream bool `json:"enable_nonstream,omitempty"` - EndWindowSize int `json:"end_window_size,omitempty"` - BoostingTableID string `json:"boosting_table_id,omitempty"` // 热词表 ID - BoostingTableName string `json:"boosting_table_name,omitempty"` // 热词表名称 + EndWindowSize int `json:"end_window_size,omitempty"` + BoostingTableID string `json:"boosting_table_id,omitempty"` // 热词表 ID } // EncodeFullClientRequest builds the binary message for the initial handshake. // nostream mode: header(4) + payload_size(4) + gzip(json) diff --git a/internal/config/config.go b/internal/config/config.go index 654f833..573e912 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,10 +6,10 @@ import ( // DoubaoConfig holds 火山引擎豆包 ASR credentials. type DoubaoConfig struct { - AppID string `yaml:"app_id"` - AccessToken string `yaml:"access_token"` - ResourceID string `yaml:"resource_id"` - Hotwords []string `yaml:"hotwords"` // 热词列表,格式 "词|权重" 或 "词" + AppID string `yaml:"app_id"` + AccessToken string `yaml:"access_token"` + ResourceID string `yaml:"resource_id"` + BoostingTableID string `yaml:"boosting_table_id"` // 热词表 ID(从控制台创建) } // SecurityConfig holds authentication settings. diff --git a/main.go b/main.go index 4e6f936..b03cc37 100644 --- a/main.go +++ b/main.go @@ -144,10 +144,10 @@ func createServer(cfg config.Config, lanIP string, tlsResult *vpTLS.Result) *ser func buildASRFactory(cfg config.Config) func(chan<- ws.ServerMsg) (func([]byte), func(), error) { asrCfg := asr.Config{ - AppID: cfg.Doubao.AppID, - AccessToken: cfg.Doubao.AccessToken, - ResourceID: cfg.Doubao.ResourceID, - Hotwords: cfg.Doubao.Hotwords, + AppID: cfg.Doubao.AppID, + AccessToken: cfg.Doubao.AccessToken, + ResourceID: cfg.Doubao.ResourceID, + BoostingTableID: cfg.Doubao.BoostingTableID, } return func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) { client, err := asr.Dial(asrCfg, resultCh)