Files
voicepaste/internal/asr/protocol.go
imbytecat 96d685fdf2 feat: 添加豆包 ASR 热词功能支持
- 在 config.yaml 中添加 hotwords 配置项,支持本地管理热词列表
- 实现热词解析、格式化和表名生成工具(internal/asr/hotwords.go)
- 在 ASR 连接建立时自动将热词发送给豆包(boosting_table_name 参数)
- 支持热词权重配置(1-10,默认 4),格式:"词|权重" 或 "词"
- 支持配置热重载,修改热词后新连接自动生效
- 为未来动态热词功能预留扩展接口

热词格式示例:
  hotwords:
    - 张三|8
    - VoicePaste|10
    - 人工智能|6
2026-03-02 00:55:37 +08:00

239 lines
6.3 KiB
Go

// Package asr implements the Doubao (豆包) Seed-ASR-2.0 streaming speech
// recognition client using the custom binary WebSocket protocol.
package asr
import (
"bytes"
"compress/gzip"
"encoding/binary"
"encoding/json"
"fmt"
"io"
)
// ── Protocol constants ──

// Fixed values packed into byte 0 of every frame header: the protocol version
// occupies the upper nibble and the header length (in 4-byte words) the lower.
const (
	protocolVersion = 0x01 // protocol version 1
	headerSizeUnit  = 0x01 // header length in 4-byte words: 1 * 4 = 4 bytes
)
// MessageType identifies the kind of frame. It is stored in the upper nibble
// of header byte 1.
type MessageType uint8

// Frame kinds sent by the client (requests) and by the server (responses).
const (
	MsgFullClientRequest  = MessageType(0x1) // handshake frame carrying the JSON request
	MsgAudioOnlyRequest   = MessageType(0x2) // audio chunk
	MsgFullServerResponse = MessageType(0x9) // recognition result
	MsgServerError        = MessageType(0xF) // error frame: code + message
)
// MessageFlags qualifies a frame (whether a sequence number is present and
// whether it is the last packet). It is stored in the lower nibble of header
// byte 1.
type MessageFlags uint8

// Flag values used in the lower nibble of header byte 1.
const (
	FlagNoSeq     = MessageFlags(0x0) // no sequence number
	FlagPosSeq    = MessageFlags(0x1) // sequence number present
	FlagLastNoSeq = MessageFlags(0x2) // last packet, no sequence number
	FlagNegSeq    = MessageFlags(0x3) // last packet with seq
)
// Header byte 2 packs the payload serialization method into its upper nibble
// (Ser*) and the payload compression method into its lower nibble (Comp*).
const (
	SerNone, SerJSON   uint8 = 0x0, 0x1 // raw bytes / JSON
	CompNone, CompGzip uint8 = 0x0, 0x1 // uncompressed / gzip
)
// ── Header encoding ──

// encodeHeader returns a freshly allocated 4-byte frame header:
//
//	byte 0: protocol version (high nibble) | header size in 4-byte words (low nibble)
//	byte 1: message type (high nibble)     | message flags (low nibble)
//	byte 2: serialization (high nibble)    | compression (low nibble)
//	byte 3: reserved, always zero
func encodeHeader(mt MessageType, flags MessageFlags, ser, comp uint8) []byte {
	hdr := make([]byte, 4)
	hdr[0] = protocolVersion<<4 | headerSizeUnit
	hdr[1] = uint8(mt)<<4 | uint8(flags)
	hdr[2] = ser<<4 | comp
	// hdr[3] keeps its zero value (reserved).
	return hdr
}
// ── Gzip helpers ──

// gzipCompress returns data encoded as a single gzip stream at the default
// compression level.
func gzipCompress(data []byte) ([]byte, error) {
	out := new(bytes.Buffer)
	zw := gzip.NewWriter(out)
	_, err := zw.Write(data)
	if err == nil {
		// Close flushes buffered data and writes the gzip footer; the stream
		// in out is incomplete until it returns.
		err = zw.Close()
	}
	if err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// gzipDecompress inflates a gzip stream held in data and returns the
// decompressed bytes. It fails if data is not a valid, complete gzip stream.
func gzipDecompress(data []byte) ([]byte, error) {
	zr, err := gzip.NewReader(bytes.NewReader(data))
	if err != nil {
		return nil, err
	}
	out, err := io.ReadAll(zr)
	// Close does not touch data; any decompression error has already been
	// reported by io.ReadAll, so its result is not needed.
	_ = zr.Close()
	return out, err
}
// ── Request builders ──

type (
	// FullClientRequest is the JSON payload for the initial handshake; it is
	// serialized by EncodeFullClientRequest.
	FullClientRequest struct {
		User    UserMeta    `json:"user"`
		Audio   AudioMeta   `json:"audio"`
		Request RequestMeta `json:"request"`
	}

	// UserMeta identifies the calling user.
	UserMeta struct {
		UID string `json:"uid"`
	}

	// AudioMeta describes the audio stream sent after the handshake.
	AudioMeta struct {
		Format  string `json:"format"`
		Codec   string `json:"codec"`
		Rate    int    `json:"rate"`
		Bits    int    `json:"bits"`
		Channel int    `json:"channel"`
	}

	// RequestMeta carries recognition options. Fields tagged omitempty are
	// dropped from the JSON when they hold their zero value.
	RequestMeta struct {
		ModelName       string `json:"model_name"`
		EnableITN       bool   `json:"enable_itn"`
		EnablePUNC      bool   `json:"enable_punc"`
		EnableDDC       bool   `json:"enable_ddc"`
		ShowUtterances  bool   `json:"show_utterances"`
		ResultType      string `json:"result_type,omitempty"`
		EnableNonstream bool   `json:"enable_nonstream,omitempty"`
		EndWindowSize   int    `json:"end_window_size,omitempty"`

		// Hotword (boosting) table selection.
		// NOTE(review): confirm against the Doubao API docs whether these keys
		// belong directly under "request" or inside a nested object.
		BoostingTableID   string `json:"boosting_table_id,omitempty"`   // hotword table ID
		BoostingTableName string `json:"boosting_table_name,omitempty"` // hotword table name
	}
)
// EncodeFullClientRequest serializes req into the binary handshake frame
// (nostream mode):
//
//	header (4 bytes) | payload size (4 bytes, big-endian) | gzip(JSON(req))
//
// The header marks the frame as a full client request with no sequence
// number, JSON serialization and gzip compression.
func EncodeFullClientRequest(req *FullClientRequest) ([]byte, error) {
	raw, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("marshal request: %w", err)
	}
	body, err := gzipCompress(raw)
	if err != nil {
		return nil, fmt.Errorf("gzip compress: %w", err)
	}

	hdr := encodeHeader(MsgFullClientRequest, FlagNoSeq, SerJSON, CompGzip)
	var size [4]byte
	binary.BigEndian.PutUint32(size[:], uint32(len(body)))

	frame := make([]byte, 0, len(hdr)+len(size)+len(body))
	frame = append(frame, hdr...)
	frame = append(frame, size[:]...)
	frame = append(frame, body...)
	return frame, nil
}
// EncodeAudioFrame builds a binary audio-only request.
// nostream mode: header(4) + payload_size(4) + gzip(pcm)
// last=true sets FlagLastNoSeq to signal end of stream.
func EncodeAudioFrame(pcm []byte, last bool) ([]byte, error) {
flags := FlagNoSeq
if last {
flags = FlagLastNoSeq
}
compressed, err := gzipCompress(pcm)
if err != nil {
return nil, fmt.Errorf("gzip audio: %w", err)
}
var buf bytes.Buffer
buf.Write(encodeHeader(MsgAudioOnlyRequest, flags, SerNone, CompGzip))
_ = binary.Write(&buf, binary.BigEndian, int32(len(compressed)))
buf.Write(compressed)
return buf.Bytes(), nil
}
// ── Response parsing ──

type (
	// ServerResponse is the parsed response from Doubao, as produced by
	// ParseResponse.
	ServerResponse struct {
		Code       int         // error code; populated only for error frames
		IsLast     bool        // set when the frame is flagged as the final packet
		Text       string      // recognized text (result.text)
		Utterances []Utterance // recognized segments (result.utterances)
		ErrMsg     string      // error message; populated only for error frames
	}

	// Utterance is one recognized segment of speech.
	Utterance struct {
		Text string `json:"text"`
		// Presumably true once this segment's text is final — confirm against
		// the API docs.
		Definite bool `json:"definite"`
	}

	// responsePayload mirrors the JSON body of a full server response. JSON
	// keys not declared here are ignored by encoding/json; AudioInfo is
	// decoded but not currently copied into ServerResponse.
	responsePayload struct {
		Result struct {
			Text       string      `json:"text"`
			Utterances []Utterance `json:"utterances"`
		} `json:"result"`
		AudioInfo struct {
			Duration int `json:"duration"`
		} `json:"audio_info"`
	}
)
// ParseResponse decodes one binary frame received from the server.
//
// Frame layout: a header of (msg[0] & 0x0F) * 4 bytes, then, depending on the
// header flags, an optional 4-byte sequence number (flag bit 0) and an
// optional 4-byte event (flag bit 2), followed by the message body:
//
//	MsgFullServerResponse: payload size (uint32 BE) | payload (JSON, optionally gzip)
//	MsgServerError:        error code (uint32 BE) | payload size (uint32 BE) | message
//
// Flag bit 1 marks the final packet and sets IsLast. Frames of any other
// message type yield a ServerResponse carrying only IsLast. Bytes beyond the
// declared payload size are ignored.
//
// An error is returned when the header size field is invalid, when the
// declared payload size exceeds the bytes actually present, or when the
// payload cannot be decompressed or (for results) unmarshaled.
func ParseResponse(msg []byte) (*ServerResponse, error) {
	// Individual bits of the flags nibble (lower nibble of header byte 1).
	const (
		flagHasSeq   = 0x1 // 4-byte sequence number follows the header
		flagLast     = 0x2 // final packet (FlagLastNoSeq / FlagNegSeq)
		flagHasEvent = 0x4 // 4-byte event follows the optional sequence number
	)

	if len(msg) < 4 {
		return nil, fmt.Errorf("message too short: %d bytes", len(msg))
	}
	// Validate the header length before slicing: an oversized value would
	// panic, and zero would make the header bytes part of the payload.
	headerSize := int(msg[0]&0x0F) * 4
	if headerSize < 4 || headerSize > len(msg) {
		return nil, fmt.Errorf("invalid header size %d for %d-byte message", headerSize, len(msg))
	}
	msgType := MessageType(msg[1] >> 4)
	flags := MessageFlags(msg[1] & 0x0F)
	compression := msg[2] & 0x0F
	payload := msg[headerSize:]

	resp := &ServerResponse{
		// Test the bit instead of comparing the whole nibble so the marker is
		// still honored when the event flag is set too.
		IsLast: flags&flagLast != 0,
	}
	if flags&flagHasSeq != 0 && len(payload) >= 4 {
		payload = payload[4:] // skip sequence number
	}
	if flags&flagHasEvent != 0 && len(payload) >= 4 {
		payload = payload[4:] // skip event
	}

	switch msgType {
	case MsgFullServerResponse:
		if len(payload) < 4 {
			return resp, nil
		}
		body, err := sizedPayload(payload[4:], binary.BigEndian.Uint32(payload[:4]))
		if err != nil {
			return nil, fmt.Errorf("response: %w", err)
		}
		if len(body) == 0 {
			return resp, nil
		}
		if compression == CompGzip {
			if body, err = gzipDecompress(body); err != nil {
				return nil, fmt.Errorf("decompress response: %w", err)
			}
		}
		var rp responsePayload
		if err := json.Unmarshal(body, &rp); err != nil {
			return nil, fmt.Errorf("unmarshal response: %w", err)
		}
		resp.Text = rp.Result.Text
		resp.Utterances = rp.Result.Utterances

	case MsgServerError:
		if len(payload) < 8 {
			return resp, nil
		}
		resp.Code = int(binary.BigEndian.Uint32(payload[:4]))
		body, err := sizedPayload(payload[8:], binary.BigEndian.Uint32(payload[4:8]))
		if err != nil {
			return nil, fmt.Errorf("server error %d: %w", resp.Code, err)
		}
		if len(body) > 0 {
			if compression == CompGzip {
				if body, err = gzipDecompress(body); err != nil {
					return nil, fmt.Errorf("decompress error: %w", err)
				}
			}
			resp.ErrMsg = string(body)
		}
	}
	return resp, nil
}

// sizedPayload returns the first size bytes of b, or an error when b holds
// fewer than size bytes (a truncated frame). The comparison is done in
// uint64 so a large size cannot wrap negative on 32-bit platforms.
func sizedPayload(b []byte, size uint32) ([]byte, error) {
	if uint64(size) > uint64(len(b)) {
		return nil, fmt.Errorf("declared payload size %d exceeds remaining %d bytes", size, len(b))
	}
	return b[:size], nil
}