diff --git a/config.example.yaml b/config.example.yaml index 334a1fa..2a690da 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -6,6 +6,12 @@ doubao: app_id: "" # env: DOUBAO_APP_ID access_token: "" # env: DOUBAO_ACCESS_TOKEN resource_id: "volc.seedasr.sauc.duration" # env: DOUBAO_RESOURCE_ID + hotwords: # 可选:热词列表,格式 "词|权重" 或 "词"(默认权重 4) + # - 张三|8 + # - 李四|8 + # - VoicePaste|10 + # - 人工智能|6 + # - 测试 # 服务配置 server: diff --git a/internal/asr/client.go b/internal/asr/client.go index fa2b8b6..7f06db5 100644 --- a/internal/asr/client.go +++ b/internal/asr/client.go @@ -22,7 +22,8 @@ const ( type Config struct { AppID string AccessToken string - ResourceID string + ResourceID string + Hotwords []string // 热词列表,格式 "词|权重" 或 "词" } // Client manages a single ASR session with Doubao. @@ -57,6 +58,21 @@ func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) { closeCh: make(chan struct{}), log: slog.With("conn_id", connID), } + // Parse hotwords configuration + var boostingTableName string + if len(cfg.Hotwords) > 0 { + entries, err := ParseHotwords(cfg.Hotwords) + if err != nil { + slog.Warn("invalid hotwords config, skipping", "err", err) + } else { + boostingTableName = GenerateTableName(entries) + tableContent := FormatHotwordsTable(entries) + slog.Info("hotwords enabled", + "count", len(entries), + "table_name", boostingTableName, + "content", tableContent) + } + } // Send FullClientRequest req := &FullClientRequest{ User: UserMeta{UID: connID}, @@ -68,14 +84,15 @@ func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) { Channel: 1, }, Request: RequestMeta{ - ModelName: "seedasr-2.0", - EnableITN: true, - EnablePUNC: true, - EnableDDC: true, - ShowUtterances: true, - ResultType: "full", - EnableNonstream: true, - EndWindowSize: 800, + ModelName: "seedasr-2.0", + EnableITN: true, + EnablePUNC: true, + EnableDDC: true, + ShowUtterances: true, + ResultType: "full", + EnableNonstream: true, + EndWindowSize: 800, + BoostingTableName: boostingTableName, }, } data, err := EncodeFullClientRequest(req) diff --git a/internal/asr/hotwords.go b/internal/asr/hotwords.go new file mode 100644 index 0000000..751c457 --- /dev/null +++ b/internal/asr/hotwords.go @@ -0,0 +1,107 @@ +package asr + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "strconv" + "strings" + "unicode/utf8" +) + +const ( + // DefaultHotwordWeight is the default weight when not specified. + DefaultHotwordWeight = 4 + // MaxHotwordWeight is the maximum allowed weight. + MaxHotwordWeight = 10 + // MinHotwordWeight is the minimum allowed weight. + MinHotwordWeight = 1 + // MaxHotwordLength is the maximum character count per hotword. + MaxHotwordLength = 10 +) + +// HotwordEntry represents a single hotword with its weight. +type HotwordEntry struct { + Word string + Weight int // 1-10, default 4 +} + +// ParseHotwords parses raw hotword strings from config. +// Format: "word|weight" or "word" (default weight 4). +// Returns error if any hotword is invalid. +func ParseHotwords(raw []string) ([]HotwordEntry, error) { + if len(raw) == 0 { + return nil, nil + } + + entries := make([]HotwordEntry, 0, len(raw)) + for i, line := range raw { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + parts := strings.Split(line, "|") + word := strings.TrimSpace(parts[0]) + + // Validate word length + if utf8.RuneCountInString(word) > MaxHotwordLength { + return nil, fmt.Errorf("hotword %d: exceeds %d characters: %q", i+1, MaxHotwordLength, word) + } + if word == "" { + return nil, fmt.Errorf("hotword %d: empty word", i+1) + } + + // Parse weight + weight := DefaultHotwordWeight + if len(parts) > 1 { + weightStr := strings.TrimSpace(parts[1]) + if weightStr != "" { + w, err := strconv.Atoi(weightStr) + if err != nil { + return nil, fmt.Errorf("hotword %d: invalid weight %q: %w", i+1, weightStr, err) + } + if w < MinHotwordWeight || w > MaxHotwordWeight { + return nil, fmt.Errorf("hotword %d: weight %d out of range [%d, %d]", i+1, w, MinHotwordWeight, MaxHotwordWeight) + } + weight = w + } + } + + entries = append(entries, HotwordEntry{ + Word: word, + Weight: weight, + }) + } + + return entries, nil +} + +// FormatHotwordsTable generates Doubao-compatible hotword table content. +// Format: each line "word|weight\n". +func FormatHotwordsTable(entries []HotwordEntry) string { + if len(entries) == 0 { + return "" + } + + var sb strings.Builder + for _, e := range entries { + sb.WriteString(e.Word) + sb.WriteString("|") + sb.WriteString(strconv.Itoa(e.Weight)) + sb.WriteString("\n") + } + return sb.String() +} + +// GenerateTableName generates a unique table name based on hotword content. +// Uses SHA256 hash of the formatted table content (first 16 hex chars). +func GenerateTableName(entries []HotwordEntry) string { + if len(entries) == 0 { + return "" + } + + content := FormatHotwordsTable(entries) + hash := sha256.Sum256([]byte(content)) + return "voicepaste_" + hex.EncodeToString(hash[:])[:16] +} diff --git a/internal/asr/protocol.go b/internal/asr/protocol.go index fad30e8..86c2cbc 100644 --- a/internal/asr/protocol.go +++ b/internal/asr/protocol.go @@ -111,7 +111,9 @@ type RequestMeta struct { ShowUtterances bool `json:"show_utterances"` ResultType string `json:"result_type,omitempty"` EnableNonstream bool `json:"enable_nonstream,omitempty"` - EndWindowSize int `json:"end_window_size,omitempty"` + EndWindowSize int `json:"end_window_size,omitempty"` + BoostingTableID string `json:"boosting_table_id,omitempty"` // 热词表 ID + BoostingTableName string `json:"boosting_table_name,omitempty"` // 热词表名称 } // EncodeFullClientRequest builds the binary message for the initial handshake. // nostream mode: header(4) + payload_size(4) + gzip(json) diff --git a/internal/config/config.go b/internal/config/config.go index 25fd2dd..654f833 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,7 +8,8 @@ import ( type DoubaoConfig struct { AppID string `yaml:"app_id"` AccessToken string `yaml:"access_token"` - ResourceID string `yaml:"resource_id"` + ResourceID string `yaml:"resource_id"` + Hotwords []string `yaml:"hotwords"` // 热词列表,格式 "词|权重" 或 "词" } // SecurityConfig holds authentication settings. diff --git a/internal/ws/handler.go b/internal/ws/handler.go index c1bd4d1..c1fd683 100644 --- a/internal/ws/handler.go +++ b/internal/ws/handler.go @@ -133,6 +133,11 @@ func (h *Handler) handleStart(s *session) { if s.active { return } + // Future extension: support runtime dynamic hotwords (Phase 2) + // if len(msg.Hotwords) > 0 { + // // Priority: runtime hotwords > config.yaml hotwords + // // Need to modify asrFactory signature to pass msg.Hotwords + // } s.previewMu.Lock() s.previewText = "" s.previewMu.Unlock() diff --git a/internal/ws/protocol.go b/internal/ws/protocol.go index 45826c9..dbbef09 100644 --- a/internal/ws/protocol.go +++ b/internal/ws/protocol.go @@ -17,6 +17,8 @@ const ( type ClientMsg struct { Type MsgType `json:"type"` Text string `json:"text,omitempty"` // Only for "paste" + // Future extension: dynamic hotwords (Phase 2) + // Hotwords []string `json:"hotwords,omitempty"` } // ── Server → Client messages ── diff --git a/main.go b/main.go index 8581781..4e6f936 100644 --- a/main.go +++ b/main.go @@ -147,6 +147,7 @@ func buildASRFactory(cfg config.Config) func(chan<- ws.ServerMsg) (func([]byte), AppID: cfg.Doubao.AppID, AccessToken: cfg.Doubao.AccessToken, ResourceID: cfg.Doubao.ResourceID, + Hotwords: cfg.Doubao.Hotwords, } return func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) { client, err := asr.Dial(asrCfg, resultCh)