- 切换到 bigmodel_async endpoint 并启用 enable_nonstream - 第一遍流式识别提供实时文字预览 - VAD 分句后自动触发第二遍非流式识别提升准确率 - 修改文本处理逻辑从累加改为替换(适配 full 模式) - 统一配置字段命名:app_key → app_id, access_key → access_token
237 lines
6.1 KiB
Go
237 lines
6.1 KiB
Go
// Package asr implements the Doubao (豆包) Seed-ASR-2.0 streaming speech
|
|
// recognition client using the custom binary WebSocket protocol.
|
|
package asr
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/gzip"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
)
|
|
|
|
// ── Protocol constants ──
|
|
|
|
// Fixed values packed into header byte 0 by encodeHeader: the protocol
// version goes in the upper nibble and the header size, counted in 4-byte
// units, in the lower nibble.
const (
	protocolVersion = 0x01
	headerSizeUnit  = 0x01 // header = 1 * 4 = 4 bytes
)
|
|
|
|
// MessageType identifies the kind of message. It is carried in the upper
// nibble of header byte 1.
type MessageType uint8

const (
	// MsgFullClientRequest is the type used by EncodeFullClientRequest for
	// the initial JSON handshake.
	MsgFullClientRequest MessageType = 0b0001

	// MsgAudioOnlyRequest is the type used by EncodeAudioFrame for audio data.
	MsgAudioOnlyRequest MessageType = 0b0010

	// MsgFullServerResponse is the server message that ParseResponse decodes
	// as a JSON recognition result.
	MsgFullServerResponse MessageType = 0b1001

	// MsgServerError is the server message that ParseResponse decodes as an
	// error code followed by an error message.
	MsgServerError MessageType = 0b1111
)
|
|
|
|
// MessageFlags holds the message-type-specific flags. They are carried in the
// lower nibble of header byte 1.
type MessageFlags uint8

const (
	// FlagNoSeq: no sequence number, not the last packet.
	FlagNoSeq MessageFlags = 0b0000

	// FlagPosSeq: a sequence number follows the header.
	FlagPosSeq MessageFlags = 0b0001

	// FlagLastNoSeq: last packet, without a sequence number.
	FlagLastNoSeq MessageFlags = 0b0010

	// FlagNegSeq: last packet, with a sequence number.
	FlagNegSeq MessageFlags = 0b0011
)
|
|
|
|
// Payload serialization methods, carried in the upper nibble of header byte 2.
const (
	// SerNone marks a payload with no serialization (used for audio frames).
	SerNone = uint8(0x0)

	// SerJSON marks a JSON payload (used for the initial client request).
	SerJSON = uint8(0x1)
)
|
|
|
|
// Payload compression methods, carried in the lower nibble of header byte 2.
const (
	// CompNone marks an uncompressed payload.
	CompNone = uint8(0x0)

	// CompGzip marks a gzip-compressed payload.
	CompGzip = uint8(0x1)
)
|
|
|
|
// ── Header encoding ──
|
|
|
|
func encodeHeader(mt MessageType, flags MessageFlags, ser, comp uint8) []byte {
|
|
return []byte{
|
|
(protocolVersion << 4) | headerSizeUnit,
|
|
(uint8(mt) << 4) | uint8(flags),
|
|
(ser << 4) | comp,
|
|
0x00, // reserved
|
|
}
|
|
}
|
|
|
|
// ── Gzip helpers ──
|
|
|
|
// gzipCompress returns data compressed as a single gzip stream using the
// default compression level. On any write or close error it returns nil and
// that error.
func gzipCompress(data []byte) ([]byte, error) {
	out := new(bytes.Buffer)
	zw := gzip.NewWriter(out)

	_, err := zw.Write(data)
	if err == nil {
		// Close flushes pending data and writes the gzip trailer.
		err = zw.Close()
	}
	if err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
|
|
|
|
// gzipDecompress inflates a gzip stream held in data and returns the
// decompressed bytes. It returns nil and an error when the gzip header cannot
// be read; a failure while reading the body returns io.ReadAll's result
// unchanged.
func gzipDecompress(data []byte) ([]byte, error) {
	src := bytes.NewReader(data)
	zr, err := gzip.NewReader(src)
	if err != nil {
		return nil, err
	}
	defer zr.Close()

	plain, readErr := io.ReadAll(zr)
	return plain, readErr
}
|
|
|
|
// ── Request builders ──
|
|
|
|
// FullClientRequest is the JSON payload for the initial handshake.
|
|
type FullClientRequest struct {
|
|
User UserMeta `json:"user"`
|
|
Audio AudioMeta `json:"audio"`
|
|
Request RequestMeta `json:"request"`
|
|
}
|
|
|
|
// UserMeta is the "user" section of FullClientRequest.
type UserMeta struct {
	// UID is serialized as "uid".
	// NOTE(review): presumably a caller-chosen user identifier; confirm
	// against the service documentation.
	UID string `json:"uid"`
}
|
|
|
|
// AudioMeta is the "audio" section of FullClientRequest and describes the
// audio that will be sent.
// NOTE(review): accepted values and units are defined by the service and are
// not visible in this file; confirm against the service documentation.
type AudioMeta struct {
	// Format is serialized as "format".
	Format string `json:"format"`

	// Codec is serialized as "codec".
	Codec string `json:"codec"`

	// Rate is serialized as "rate" (presumably the sample rate in Hz).
	Rate int `json:"rate"`

	// Bits is serialized as "bits" (presumably bits per sample).
	Bits int `json:"bits"`

	// Channel is serialized as "channel" (presumably the channel count).
	Channel int `json:"channel"`
}
|
|
|
|
// RequestMeta is the "request" section of FullClientRequest and carries the
// recognition options. The last three fields are omitted from the JSON when
// they hold their zero value.
// NOTE(review): the meaning of each option is defined by the service; the
// hints below are inferred from field names and should be confirmed.
type RequestMeta struct {
	// ModelName is serialized as "model_name".
	ModelName string `json:"model_name"`

	// EnableITN is serialized as "enable_itn" (presumably inverse text
	// normalization).
	EnableITN bool `json:"enable_itn"`

	// EnablePUNC is serialized as "enable_punc" (presumably punctuation).
	EnablePUNC bool `json:"enable_punc"`

	// EnableDDC is serialized as "enable_ddc".
	EnableDDC bool `json:"enable_ddc"`

	// ShowUtterances is serialized as "show_utterances".
	ShowUtterances bool `json:"show_utterances"`

	// ResultType is serialized as "result_type"; omitted when empty.
	ResultType string `json:"result_type,omitempty"`

	// EnableNonstream is serialized as "enable_nonstream"; omitted when
	// false. Presumably requests a second, non-streaming recognition pass.
	EnableNonstream bool `json:"enable_nonstream,omitempty"`

	// EndWindowSize is serialized as "end_window_size"; omitted when zero.
	EndWindowSize int `json:"end_window_size,omitempty"`
}
|
|
// EncodeFullClientRequest builds the binary message for the initial handshake.
|
|
// nostream mode: header(4) + payload_size(4) + gzip(json)
|
|
func EncodeFullClientRequest(req *FullClientRequest) ([]byte, error) {
|
|
payloadJSON, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal request: %w", err)
|
|
}
|
|
compressed, err := gzipCompress(payloadJSON)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gzip compress: %w", err)
|
|
}
|
|
var buf bytes.Buffer
|
|
buf.Write(encodeHeader(MsgFullClientRequest, FlagNoSeq, SerJSON, CompGzip))
|
|
_ = binary.Write(&buf, binary.BigEndian, int32(len(compressed)))
|
|
buf.Write(compressed)
|
|
return buf.Bytes(), nil
|
|
}
|
|
// EncodeAudioFrame builds a binary audio-only request.
|
|
// nostream mode: header(4) + payload_size(4) + gzip(pcm)
|
|
// last=true sets FlagLastNoSeq to signal end of stream.
|
|
func EncodeAudioFrame(pcm []byte, last bool) ([]byte, error) {
|
|
flags := FlagNoSeq
|
|
if last {
|
|
flags = FlagLastNoSeq
|
|
}
|
|
compressed, err := gzipCompress(pcm)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("gzip audio: %w", err)
|
|
}
|
|
var buf bytes.Buffer
|
|
buf.Write(encodeHeader(MsgAudioOnlyRequest, flags, SerNone, CompGzip))
|
|
_ = binary.Write(&buf, binary.BigEndian, int32(len(compressed)))
|
|
buf.Write(compressed)
|
|
return buf.Bytes(), nil
|
|
}
|
|
// ── Response parsing ──
|
|
// ServerResponse is the parsed response from Doubao.
|
|
type ServerResponse struct {
|
|
Code int
|
|
IsLast bool
|
|
Text string
|
|
Utterances []Utterance
|
|
ErrMsg string
|
|
}
|
|
// Utterance is one entry of the "utterances" list in a server result.
type Utterance struct {
	// Text is decoded from "text".
	Text string `json:"text"`

	// Definite is decoded from "definite".
	// NOTE(review): presumably marks a finalized segment; confirm against
	// the service documentation.
	Definite bool `json:"definite"`
}
|
|
type responsePayload struct {
|
|
Result struct {
|
|
Text string `json:"text"`
|
|
Utterances []Utterance `json:"utterances"`
|
|
} `json:"result"`
|
|
AudioInfo struct {
|
|
Duration int `json:"duration"`
|
|
} `json:"audio_info"`
|
|
}
|
|
// ParseResponse decodes a binary server response.
|
|
func ParseResponse(msg []byte) (*ServerResponse, error) {
|
|
if len(msg) < 4 {
|
|
return nil, fmt.Errorf("message too short: %d bytes", len(msg))
|
|
}
|
|
headerSize := int(msg[0]&0x0F) * 4
|
|
msgType := MessageType(msg[1] >> 4)
|
|
flags := MessageFlags(msg[1] & 0x0F)
|
|
compression := msg[2] & 0x0F
|
|
payload := msg[headerSize:]
|
|
resp := &ServerResponse{}
|
|
// Parse sequence if present
|
|
if flags&0x01 != 0 && len(payload) >= 4 {
|
|
payload = payload[4:] // skip sequence
|
|
}
|
|
if flags == FlagNegSeq || flags == FlagLastNoSeq {
|
|
resp.IsLast = true
|
|
}
|
|
// Parse event if flag bit 2 set
|
|
if flags&0x04 != 0 && len(payload) >= 4 {
|
|
payload = payload[4:] // skip event
|
|
}
|
|
switch msgType {
|
|
case MsgFullServerResponse:
|
|
if len(payload) < 4 {
|
|
return resp, nil
|
|
}
|
|
payloadSize := int(binary.BigEndian.Uint32(payload[:4]))
|
|
payload = payload[4:]
|
|
if payloadSize == 0 || len(payload) == 0 {
|
|
return resp, nil
|
|
}
|
|
if compression == CompGzip {
|
|
var err error
|
|
payload, err = gzipDecompress(payload)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decompress response: %w", err)
|
|
}
|
|
}
|
|
var rp responsePayload
|
|
if err := json.Unmarshal(payload, &rp); err != nil {
|
|
return nil, fmt.Errorf("unmarshal response: %w", err)
|
|
}
|
|
resp.Text = rp.Result.Text
|
|
resp.Utterances = rp.Result.Utterances
|
|
case MsgServerError:
|
|
if len(payload) < 8 {
|
|
return resp, nil
|
|
}
|
|
resp.Code = int(binary.BigEndian.Uint32(payload[:4]))
|
|
payloadSize := int(binary.BigEndian.Uint32(payload[4:8]))
|
|
payload = payload[8:]
|
|
if payloadSize > 0 && len(payload) > 0 {
|
|
if compression == CompGzip {
|
|
var err error
|
|
payload, err = gzipDecompress(payload)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decompress error: %w", err)
|
|
}
|
|
}
|
|
resp.ErrMsg = string(payload)
|
|
}
|
|
}
|
|
return resp, nil
|
|
} |