feat: add Doubao ASR client and paste module
This commit is contained in:
236
internal/asr/protocol.go
Normal file
236
internal/asr/protocol.go
Normal file
@@ -0,0 +1,236 @@
|
||||
// Package asr implements the Doubao (豆包) Seed-ASR-2.0 streaming speech
|
||||
// recognition client using the custom binary WebSocket protocol.
|
||||
package asr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ── Protocol constants ──
|
||||
|
||||
const protocolVersion = 0x01
|
||||
const headerSizeUnit = 0x01 // header = 1 * 4 = 4 bytes
|
||||
|
||||
// MessageType (upper nibble of byte 1)
|
||||
type MessageType uint8
|
||||
|
||||
const (
|
||||
MsgFullClientRequest MessageType = 0x1
|
||||
MsgAudioOnlyRequest MessageType = 0x2
|
||||
MsgFullServerResponse MessageType = 0x9
|
||||
MsgServerError MessageType = 0xF
|
||||
)
|
||||
|
||||
// MessageFlags (lower nibble of byte 1)
|
||||
type MessageFlags uint8
|
||||
|
||||
const (
|
||||
FlagNoSeq MessageFlags = 0x0
|
||||
FlagPosSeq MessageFlags = 0x1
|
||||
FlagLastNoSeq MessageFlags = 0x2
|
||||
FlagNegSeq MessageFlags = 0x3 // last packet with seq
|
||||
)
|
||||
|
||||
// Serialization (upper nibble of byte 2)
|
||||
const (
|
||||
SerNone uint8 = 0x0
|
||||
SerJSON uint8 = 0x1
|
||||
)
|
||||
|
||||
// Compression (lower nibble of byte 2)
|
||||
const (
|
||||
CompNone uint8 = 0x0
|
||||
CompGzip uint8 = 0x1
|
||||
)
|
||||
|
||||
// ── Header encoding ──
|
||||
|
||||
func encodeHeader(mt MessageType, flags MessageFlags, ser, comp uint8) []byte {
|
||||
return []byte{
|
||||
(protocolVersion << 4) | headerSizeUnit,
|
||||
(uint8(mt) << 4) | uint8(flags),
|
||||
(ser << 4) | comp,
|
||||
0x00, // reserved
|
||||
}
|
||||
}
|
||||
|
||||
// ── Gzip helpers ──
|
||||
|
||||
func gzipCompress(data []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
w := gzip.NewWriter(&buf)
|
||||
if _, err := w.Write(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func gzipDecompress(data []byte) ([]byte, error) {
|
||||
r, err := gzip.NewReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
return io.ReadAll(r)
|
||||
}
|
||||
|
||||
// ── Request builders ──
|
||||
|
||||
// FullClientRequest is the JSON payload for the initial handshake.
|
||||
type FullClientRequest struct {
|
||||
User UserMeta `json:"user"`
|
||||
Audio AudioMeta `json:"audio"`
|
||||
Request RequestMeta `json:"request"`
|
||||
}
|
||||
|
||||
type UserMeta struct {
|
||||
UID string `json:"uid"`
|
||||
}
|
||||
|
||||
type AudioMeta struct {
|
||||
Format string `json:"format"`
|
||||
Codec string `json:"codec"`
|
||||
Rate int `json:"rate"`
|
||||
Bits int `json:"bits"`
|
||||
Channel int `json:"channel"`
|
||||
}
|
||||
|
||||
type RequestMeta struct {
|
||||
ModelName string `json:"model_name"`
|
||||
EnableITN bool `json:"enable_itn"`
|
||||
EnablePUNC bool `json:"enable_punc"`
|
||||
EnableDDC bool `json:"enable_ddc"`
|
||||
ShowUtterances bool `json:"show_utterances"`
|
||||
}
|
||||
// EncodeFullClientRequest builds the binary message for the initial handshake.
|
||||
func EncodeFullClientRequest(req *FullClientRequest, seq int32) ([]byte, error) {
|
||||
payloadJSON, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
compressed, err := gzipCompress(payloadJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gzip compress: %w", err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
buf.Write(encodeHeader(MsgFullClientRequest, FlagPosSeq, SerJSON, CompGzip))
|
||||
_ = binary.Write(&buf, binary.BigEndian, seq)
|
||||
_ = binary.Write(&buf, binary.BigEndian, int32(len(compressed)))
|
||||
buf.Write(compressed)
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
// EncodeAudioFrame builds a binary audio-only request.
|
||||
// If last is true, seq is sent as negative to signal end of stream.
|
||||
func EncodeAudioFrame(seq int32, pcm []byte, last bool) ([]byte, error) {
|
||||
flags := FlagPosSeq
|
||||
wireSeq := seq
|
||||
if last {
|
||||
flags = FlagNegSeq
|
||||
wireSeq = -seq
|
||||
}
|
||||
compressed, err := gzipCompress(pcm)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gzip audio: %w", err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
buf.Write(encodeHeader(MsgAudioOnlyRequest, flags, SerNone, CompGzip))
|
||||
_ = binary.Write(&buf, binary.BigEndian, wireSeq)
|
||||
_ = binary.Write(&buf, binary.BigEndian, int32(len(compressed)))
|
||||
buf.Write(compressed)
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
// ── Response parsing ──
|
||||
// ServerResponse is the parsed response from Doubao.
|
||||
type ServerResponse struct {
|
||||
Code int
|
||||
IsLast bool
|
||||
Text string
|
||||
Utterances []Utterance
|
||||
ErrMsg string
|
||||
}
|
||||
type Utterance struct {
|
||||
Text string `json:"text"`
|
||||
Definite bool `json:"definite"`
|
||||
}
|
||||
type responsePayload struct {
|
||||
Result struct {
|
||||
Text string `json:"text"`
|
||||
Utterances []Utterance `json:"utterances"`
|
||||
} `json:"result"`
|
||||
AudioInfo struct {
|
||||
Duration int `json:"duration"`
|
||||
} `json:"audio_info"`
|
||||
}
|
||||
// ParseResponse decodes a binary server response.
|
||||
func ParseResponse(msg []byte) (*ServerResponse, error) {
|
||||
if len(msg) < 4 {
|
||||
return nil, fmt.Errorf("message too short: %d bytes", len(msg))
|
||||
}
|
||||
headerSize := int(msg[0]&0x0F) * 4
|
||||
msgType := MessageType(msg[1] >> 4)
|
||||
flags := MessageFlags(msg[1] & 0x0F)
|
||||
compression := msg[2] & 0x0F
|
||||
payload := msg[headerSize:]
|
||||
resp := &ServerResponse{}
|
||||
// Parse sequence if present
|
||||
if flags&0x01 != 0 && len(payload) >= 4 {
|
||||
payload = payload[4:] // skip sequence
|
||||
}
|
||||
if flags == FlagNegSeq || flags == FlagLastNoSeq {
|
||||
resp.IsLast = true
|
||||
}
|
||||
// Parse event if flag bit 2 set
|
||||
if flags&0x04 != 0 && len(payload) >= 4 {
|
||||
payload = payload[4:] // skip event
|
||||
}
|
||||
switch msgType {
|
||||
case MsgFullServerResponse:
|
||||
if len(payload) < 4 {
|
||||
return resp, nil
|
||||
}
|
||||
payloadSize := int(binary.BigEndian.Uint32(payload[:4]))
|
||||
payload = payload[4:]
|
||||
if payloadSize == 0 || len(payload) == 0 {
|
||||
return resp, nil
|
||||
}
|
||||
if compression == CompGzip {
|
||||
var err error
|
||||
payload, err = gzipDecompress(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decompress response: %w", err)
|
||||
}
|
||||
}
|
||||
var rp responsePayload
|
||||
if err := json.Unmarshal(payload, &rp); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
resp.Text = rp.Result.Text
|
||||
resp.Utterances = rp.Result.Utterances
|
||||
case MsgServerError:
|
||||
if len(payload) < 8 {
|
||||
return resp, nil
|
||||
}
|
||||
resp.Code = int(binary.BigEndian.Uint32(payload[:4]))
|
||||
payloadSize := int(binary.BigEndian.Uint32(payload[4:8]))
|
||||
payload = payload[8:]
|
||||
if payloadSize > 0 && len(payload) > 0 {
|
||||
if compression == CompGzip {
|
||||
var err error
|
||||
payload, err = gzipDecompress(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decompress error: %w", err)
|
||||
}
|
||||
}
|
||||
resp.ErrMsg = string(payload)
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
Reference in New Issue
Block a user