feat: add Doubao ASR client and paste module

This commit is contained in:
2026-03-01 03:03:46 +08:00
parent 39e56e5acc
commit 35032c1777
3 changed files with 456 additions and 0 deletions

165
internal/asr/client.go Normal file
View File

@@ -0,0 +1,165 @@
package asr
import (
"fmt"
"log/slog"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/fasthttp/websocket"
"github.com/google/uuid"
wsMsg "github.com/imbytecat/voicepaste/internal/ws"
)
// Connection settings for the Doubao streaming ASR WebSocket.
const (
// doubaoEndpoint is the WebSocket URL that Dial connects to.
doubaoEndpoint = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async"
// writeTimeout bounds each WebSocket write; it is applied as a fresh write
// deadline before the initial request and before every audio frame.
writeTimeout = 10 * time.Second
// readTimeout bounds each WebSocket read; the read loop re-arms it before
// every message and exits when a read fails.
readTimeout = 30 * time.Second
)
// Config holds Doubao ASR connection parameters. Dial sends each field as an
// HTTP header on the WebSocket handshake request.
type Config struct {
// AppKey is sent as the X-Api-App-Key header.
AppKey string
// AccessKey is sent as the X-Api-Access-Key header.
AccessKey string
// ResourceID is sent as the X-Api-Resource-Id header.
ResourceID string
}
// Client manages a single ASR session with Doubao. It is created by Dial,
// which also starts the goroutine that reads server responses.
type Client struct {
// cfg is the configuration that was passed to Dial.
cfg Config
// conn is the WebSocket connection to the Doubao endpoint.
conn *websocket.Conn
// seq is the sequence number of the most recently encoded outgoing frame.
// Dial sets it to 1 for the initial request; SendAudio increments it once
// per audio frame.
seq atomic.Int32
// mu guards closed and is held by SendAudio for the duration of a write.
mu sync.Mutex
// closed is set when the read loop exits; SendAudio then refuses to send.
closed bool
// closeCh is closed by the read loop when it exits; Close waits on it.
closeCh chan struct{}
// log is a logger tagged with this connection's "conn_id" attribute.
log *slog.Logger
}
// Dial opens a WebSocket connection to the Doubao ASR endpoint, sends the
// initial FullClientRequest (raw PCM, 16000 rate, 16 bits, 1 channel, model
// "seedasr-2.0"), and starts a background goroutine that forwards recognition
// results to resultCh.
//
// A fresh UUID is generated per call and used as the X-Api-Connect-Id header,
// as the request's user ID, and as the "conn_id" log attribute.
//
// On success the caller owns the returned Client and must call Close when
// done. On failure the connection (if any) has already been closed and the
// returned error wraps the underlying cause.
func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) {
	connID := uuid.New().String()

	hdr := http.Header{}
	hdr.Set("X-Api-App-Key", cfg.AppKey)
	hdr.Set("X-Api-Access-Key", cfg.AccessKey)
	hdr.Set("X-Api-Resource-Id", cfg.ResourceID)
	hdr.Set("X-Api-Connect-Id", connID)

	dialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second}
	conn, _, err := dialer.Dial(doubaoEndpoint, hdr)
	if err != nil {
		return nil, fmt.Errorf("dial doubao: %w", err)
	}

	// abort releases the socket and reports which setup step failed.
	abort := func(step string, cause error) (*Client, error) {
		conn.Close()
		return nil, fmt.Errorf("%s: %w", step, cause)
	}

	client := &Client{
		cfg:     cfg,
		conn:    conn,
		closeCh: make(chan struct{}),
		log:     slog.With("conn_id", connID),
	}

	// The initial request carries sequence number 1; audio frames sent
	// through SendAudio continue counting from there.
	const firstSeq int32 = 1
	client.seq.Store(firstSeq)

	frame, err := EncodeFullClientRequest(&FullClientRequest{
		User: UserMeta{UID: connID},
		Audio: AudioMeta{
			Format:  "raw",
			Codec:   "pcm",
			Rate:    16000,
			Bits:    16,
			Channel: 1,
		},
		Request: RequestMeta{
			ModelName:      "seedasr-2.0",
			EnableITN:      true,
			EnablePUNC:     true,
			EnableDDC:      true,
			ShowUtterances: true,
		},
	}, firstSeq)
	if err != nil {
		return abort("encode full request", err)
	}

	// The deadline error is deliberately ignored; only the write result is checked.
	_ = conn.SetWriteDeadline(time.Now().Add(writeTimeout))
	if err := conn.WriteMessage(websocket.BinaryMessage, frame); err != nil {
		return abort("send full request", err)
	}

	// From here on the read loop is the only reader of the connection.
	go client.readLoop(resultCh)
	return client, nil
}
// SendAudio encodes one PCM audio frame and writes it to Doubao as a binary
// WebSocket message. Set last to true on the final frame of the stream.
//
// The client mutex is held for the whole call, so concurrent callers are
// serialized. Each call that gets past the closed check consumes the next
// sequence number. It returns an error if the client is already closed, if
// encoding fails, or if the write fails (the write error is returned as-is).
func (c *Client) SendAudio(pcm []byte, last bool) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return fmt.Errorf("client closed")
	}

	frame, err := EncodeAudioFrame(c.seq.Add(1), pcm, last)
	if err != nil {
		return fmt.Errorf("encode audio: %w", err)
	}

	// The deadline error is deliberately ignored; only the write result is returned.
	deadline := time.Now().Add(writeTimeout)
	_ = c.conn.SetWriteDeadline(deadline)
	return c.conn.WriteMessage(websocket.BinaryMessage, frame)
}
// readLoop reads server responses and forwards them to resultCh until the
// connection read fails (including a read timeout), the server reports a
// non-zero error code, or the response flagged as last arrives.
//
// For every successfully parsed response it sends at most one message:
//   - MsgError with the server's message when the response code is non-zero;
//   - MsgFinal with the text of the first utterance marked definite;
//   - MsgPartial with the overall result text when no utterance is definite
//     and the text is non-empty.
//
// Responses that fail to parse are logged and skipped.
//
// On exit the client is marked closed, the connection is closed, and closeCh
// is closed so that Close can return. Closing the connection here matters:
// Close skips conn.Close once closed is set, so without it a loop that ends
// on its own would leave the socket open.
//
// NOTE(review): sends on resultCh are plain blocking sends. If the receiver
// stops draining the channel the loop stalls there, and Close will wait on
// closeCh until the send completes.
func (c *Client) readLoop(resultCh chan<- wsMsg.ServerMsg) {
	defer func() {
		c.mu.Lock()
		c.closed = true
		// Release the socket while holding the lock so no SendAudio write is
		// in flight. The error is ignored because Close may already have
		// closed the connection.
		_ = c.conn.Close()
		c.mu.Unlock()
		close(c.closeCh)
	}()

	for {
		// Re-arm the read deadline for every message. The error is
		// deliberately ignored; only the read result is checked.
		_ = c.conn.SetReadDeadline(time.Now().Add(readTimeout))
		_, data, err := c.conn.ReadMessage()
		if err != nil {
			c.log.Debug("asr read done", "err", err)
			return
		}

		resp, err := ParseResponse(data)
		if err != nil {
			c.log.Warn("parse asr response", "err", err)
			continue
		}
		if resp.Code != 0 {
			c.log.Error("asr error", "code", resp.Code, "msg", resp.ErrMsg)
			resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgError, Message: resp.ErrMsg}
			return
		}

		// A response is final when any utterance is marked definite; the
		// first such utterance supplies the text.
		isFinal := false
		text := resp.Text
		for _, u := range resp.Utterances {
			if u.Definite {
				isFinal = true
				text = u.Text
				break
			}
		}

		switch {
		case isFinal:
			resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgFinal, Text: text}
		case text != "":
			resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgPartial, Text: text}
		}

		if resp.IsLast {
			return
		}
	}
}
// Close shuts down the ASR connection and blocks until the read loop has
// exited. The connection is closed here only if the read loop has not already
// marked the client closed; in either case Close waits on closeCh, which the
// read loop closes when it returns.
func (c *Client) Close() {
	func() {
		c.mu.Lock()
		defer c.mu.Unlock()
		if c.closed {
			return
		}
		// Closing the socket makes the pending read fail, which ends the read loop.
		c.conn.Close()
	}()

	// Wait for the read loop to finish.
	<-c.closeCh
}

236
internal/asr/protocol.go Normal file
View File

@@ -0,0 +1,236 @@
// Package asr implements the Doubao (豆包) Seed-ASR-2.0 streaming speech
// recognition client using the custom binary WebSocket protocol.
package asr
import (
"bytes"
"compress/gzip"
"encoding/binary"
"encoding/json"
"fmt"
"io"
)
// ── Protocol constants ──
// protocolVersion is written to the upper nibble of header byte 0.
const protocolVersion = 0x01
// headerSizeUnit is written to the lower nibble of header byte 0 and gives the
// header length in 4-byte units.
const headerSizeUnit = 0x01 // header = 1 * 4 = 4 bytes
// MessageType identifies the kind of protocol message. It occupies the upper
// nibble of header byte 1.
type MessageType uint8
const (
// MsgFullClientRequest marks the initial JSON request sent by the client.
MsgFullClientRequest MessageType = 0x1
// MsgAudioOnlyRequest marks a client frame that carries audio data.
MsgAudioOnlyRequest MessageType = 0x2
// MsgFullServerResponse marks a server response whose payload is parsed as
// a JSON recognition result.
MsgFullServerResponse MessageType = 0x9
// MsgServerError marks a server response that carries an error code and
// an error message.
MsgServerError MessageType = 0xF
)
// MessageFlags carries per-message flags. It occupies the lower nibble of
// header byte 1. When parsing, bit 0 indicates that a 4-byte sequence number
// follows the header.
type MessageFlags uint8
const (
// FlagNoSeq: no sequence number follows the header.
FlagNoSeq MessageFlags = 0x0
// FlagPosSeq: a sequence number follows; used for non-final frames.
FlagPosSeq MessageFlags = 0x1
// FlagLastNoSeq: treated as the last packet when parsing; no sequence
// number follows.
FlagLastNoSeq MessageFlags = 0x2
// FlagNegSeq: last packet, sent with the negated sequence number.
FlagNegSeq MessageFlags = 0x3 // last packet with seq
)
// Payload serialization method, stored in the upper nibble of header byte 2.
const (
// SerNone is used for audio frames.
SerNone uint8 = 0x0
// SerJSON is used for the initial JSON request.
SerJSON uint8 = 0x1
)
// Payload compression method, stored in the lower nibble of header byte 2.
const (
// CompNone: the payload is used as-is when parsing.
CompNone uint8 = 0x0
// CompGzip: the payload is gzip-compressed.
CompGzip uint8 = 0x1
)
// ── Header encoding ──
// encodeHeader builds the fixed 4-byte message header:
//
//	byte 0: protocol version (upper nibble) | header size in 4-byte units (lower)
//	byte 1: message type (upper nibble)     | message flags (lower)
//	byte 2: serialization (upper nibble)    | compression (lower)
//	byte 3: reserved, always zero
func encodeHeader(mt MessageType, flags MessageFlags, ser, comp uint8) []byte {
	hdr := make([]byte, 4) // hdr[3] stays zero (reserved)
	hdr[0] = protocolVersion<<4 | headerSizeUnit
	hdr[1] = uint8(mt)<<4 | uint8(flags)
	hdr[2] = ser<<4 | comp
	return hdr
}
// ── Gzip helpers ──
// gzipCompress returns data as a complete gzip stream using the default
// compression level. It returns a nil slice and the error if writing or
// finalizing the stream fails.
func gzipCompress(data []byte) ([]byte, error) {
	out := new(bytes.Buffer)
	zw := gzip.NewWriter(out)

	_, err := zw.Write(data)
	if err == nil {
		// Close flushes pending data and writes the gzip trailer.
		err = zw.Close()
	}
	if err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// gzipDecompress inflates a gzip stream and returns the decompressed bytes.
// It returns an error if data does not start with a valid gzip header or if
// reading the stream fails; in the latter case whatever was read before the
// failure is returned alongside the error.
func gzipDecompress(data []byte) ([]byte, error) {
	src := bytes.NewReader(data)
	zr, err := gzip.NewReader(src)
	if err != nil {
		return nil, err
	}
	defer zr.Close()

	plain, err := io.ReadAll(zr)
	return plain, err
}
// ── Request builders ──
// FullClientRequest is the JSON payload for the initial handshake. It is
// marshalled, gzip-compressed and framed by EncodeFullClientRequest.
type FullClientRequest struct {
// User identifies the caller.
User UserMeta `json:"user"`
// Audio describes the audio stream that will follow.
Audio AudioMeta `json:"audio"`
// Request selects the model and recognition options.
Request RequestMeta `json:"request"`
}
// UserMeta is the "user" section of FullClientRequest.
type UserMeta struct {
// UID is the user identifier, serialized as "uid".
UID string `json:"uid"`
}
// AudioMeta is the "audio" section of FullClientRequest and describes the
// audio that the client will stream.
type AudioMeta struct {
// Format is the container format name, serialized as "format".
Format string `json:"format"`
// Codec is the codec name, serialized as "codec".
Codec string `json:"codec"`
// Rate is presumably the sample rate in Hz — confirm against the Doubao API docs.
Rate int `json:"rate"`
// Bits is presumably the bit depth per sample — confirm against the Doubao API docs.
Bits int `json:"bits"`
// Channel is presumably the channel count — confirm against the Doubao API docs.
Channel int `json:"channel"`
}
// RequestMeta is the "request" section of FullClientRequest and carries the
// model selection and recognition options.
// NOTE(review): the option meanings below are inferred from the field names;
// confirm them against the Doubao API documentation.
type RequestMeta struct {
// ModelName selects the recognition model, serialized as "model_name".
ModelName string `json:"model_name"`
// EnableITN presumably toggles inverse text normalization.
EnableITN bool `json:"enable_itn"`
// EnablePUNC presumably toggles automatic punctuation.
EnablePUNC bool `json:"enable_punc"`
// EnableDDC is serialized as "enable_ddc"; its meaning is not evident from this file.
EnableDDC bool `json:"enable_ddc"`
// ShowUtterances presumably asks the server to include per-utterance
// results in its responses.
ShowUtterances bool `json:"show_utterances"`
}
// EncodeFullClientRequest builds the binary message for the initial handshake.
//
// The wire layout is:
//
//	4-byte header (full client request, positive sequence, JSON, gzip)
//	4-byte big-endian sequence number
//	4-byte big-endian length of the compressed payload
//	gzip-compressed JSON encoding of req
//
// It returns an error if req cannot be marshalled or compressed.
func EncodeFullClientRequest(req *FullClientRequest, seq int32) ([]byte, error) {
	body, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("marshal request: %w", err)
	}

	packed, err := gzipCompress(body)
	if err != nil {
		return nil, fmt.Errorf("gzip compress: %w", err)
	}

	// header (4) + sequence (4) + payload size (4) + payload
	msg := make([]byte, 0, 12+len(packed))
	msg = append(msg, encodeHeader(MsgFullClientRequest, FlagPosSeq, SerJSON, CompGzip)...)
	msg = binary.BigEndian.AppendUint32(msg, uint32(seq))
	msg = binary.BigEndian.AppendUint32(msg, uint32(len(packed)))
	msg = append(msg, packed...)
	return msg, nil
}
// EncodeAudioFrame builds a binary audio-only request.
// If last is true, seq is sent as negative to signal end of stream.
func EncodeAudioFrame(seq int32, pcm []byte, last bool) ([]byte, error) {
flags := FlagPosSeq
wireSeq := seq
if last {
flags = FlagNegSeq
wireSeq = -seq
}
compressed, err := gzipCompress(pcm)
if err != nil {
return nil, fmt.Errorf("gzip audio: %w", err)
}
var buf bytes.Buffer
buf.Write(encodeHeader(MsgAudioOnlyRequest, flags, SerNone, CompGzip))
_ = binary.Write(&buf, binary.BigEndian, wireSeq)
_ = binary.Write(&buf, binary.BigEndian, int32(len(compressed)))
buf.Write(compressed)
return buf.Bytes(), nil
}
// ── Response parsing ──
// ServerResponse is the parsed response from Doubao, as produced by
// ParseResponse.
type ServerResponse struct {
// Code is the error code from a server error message; it stays zero for
// recognition results.
Code int
// IsLast reports whether the message flags marked this as the last packet.
IsLast bool
// Text is the "text" field of the JSON result.
Text string
// Utterances is the "utterances" list of the JSON result.
Utterances []Utterance
// ErrMsg is the payload of a server error message, decompressed if needed.
ErrMsg string
}
// Utterance is one entry of the "utterances" list in a recognition result.
type Utterance struct {
// Text is the utterance text.
Text string `json:"text"`
// Definite is the server's "definite" flag; the client treats a response
// containing a definite utterance as a final result.
Definite bool `json:"definite"`
}
// responsePayload mirrors the JSON body of a full server response. Only the
// fields declared here are decoded; any others are ignored by json.Unmarshal.
type responsePayload struct {
// Result holds the recognized text and the per-utterance breakdown.
Result struct {
Text string `json:"text"`
Utterances []Utterance `json:"utterances"`
} `json:"result"`
// AudioInfo is decoded but not copied into ServerResponse.
AudioInfo struct {
// Duration is presumably the audio duration in milliseconds — TODO confirm.
Duration int `json:"duration"`
} `json:"audio_info"`
}
// ParseResponse decodes a binary server response.
//
// The message starts with a header whose length, in 4-byte units, is given by
// the lower nibble of byte 0. Byte 1 holds the message type (upper nibble) and
// flags (lower nibble); the lower nibble of byte 2 holds the compression
// method. After the header come, in order and each only when its flag bit is
// set, a 4-byte sequence number (bit 0) and a 4-byte event number (bit 2);
// both are skipped.
//
// For a full server response the remainder is a 4-byte big-endian payload
// size followed by the (optionally gzip-compressed) JSON result. For a server
// error it is a 4-byte error code, a 4-byte payload size, and the (optionally
// gzip-compressed) error message. Other message types yield a response with
// only IsLast populated.
//
// An error is returned when the message is shorter than 4 bytes, when the
// declared header size is smaller than 4 or larger than the message, or when
// the payload cannot be decompressed or unmarshalled. A message that is too
// short to contain a payload is not an error: the partially filled response
// is returned.
func ParseResponse(msg []byte) (*ServerResponse, error) {
	if len(msg) < 4 {
		return nil, fmt.Errorf("message too short: %d bytes", len(msg))
	}

	// Validate the declared header size before slicing: the nibble can claim
	// up to 60 bytes, which would otherwise panic on a shorter message.
	headerSize := int(msg[0]&0x0F) * 4
	if headerSize < 4 || headerSize > len(msg) {
		return nil, fmt.Errorf("invalid header size %d in %d-byte message", headerSize, len(msg))
	}
	msgType := MessageType(msg[1] >> 4)
	flags := MessageFlags(msg[1] & 0x0F)
	compression := msg[2] & 0x0F
	payload := msg[headerSize:]

	resp := &ServerResponse{}

	// Skip the sequence number if present.
	if flags&0x01 != 0 && len(payload) >= 4 {
		payload = payload[4:]
	}
	if flags == FlagNegSeq || flags == FlagLastNoSeq {
		resp.IsLast = true
	}
	// Skip the event number if flag bit 2 is set.
	if flags&0x04 != 0 && len(payload) >= 4 {
		payload = payload[4:]
	}

	switch msgType {
	case MsgFullServerResponse:
		if len(payload) < 4 {
			return resp, nil
		}
		size := binary.BigEndian.Uint32(payload[:4])
		payload = payload[4:]
		if size == 0 || len(payload) == 0 {
			return resp, nil
		}
		body, err := decodeResponseBody(payload, size, compression)
		if err != nil {
			return nil, fmt.Errorf("decompress response: %w", err)
		}
		var rp responsePayload
		if err := json.Unmarshal(body, &rp); err != nil {
			return nil, fmt.Errorf("unmarshal response: %w", err)
		}
		resp.Text = rp.Result.Text
		resp.Utterances = rp.Result.Utterances

	case MsgServerError:
		if len(payload) < 8 {
			return resp, nil
		}
		resp.Code = int(binary.BigEndian.Uint32(payload[:4]))
		size := binary.BigEndian.Uint32(payload[4:8])
		payload = payload[8:]
		if size > 0 && len(payload) > 0 {
			body, err := decodeResponseBody(payload, size, compression)
			if err != nil {
				return nil, fmt.Errorf("decompress error: %w", err)
			}
			resp.ErrMsg = string(body)
		}
	}

	return resp, nil
}

// decodeResponseBody trims payload to its declared size, so that trailing
// bytes are not fed to the decoder, and gunzips it when compression is
// CompGzip. A declared size larger than the available bytes is left for the
// decoder to reject.
func decodeResponseBody(payload []byte, size uint32, compression uint8) ([]byte, error) {
	if uint64(size) < uint64(len(payload)) {
		payload = payload[:size]
	}
	if compression != CompGzip {
		return payload, nil
	}
	return gzipDecompress(payload)
}