From 35032c177758de305f0513885e0586504a44cf38 Mon Sep 17 00:00:00 2001 From: imbytecat Date: Sun, 1 Mar 2026 03:03:46 +0800 Subject: [PATCH] feat: add Doubao ASR client and paste module --- internal/asr/client.go | 165 +++++++++++++++++++++++++++ internal/asr/protocol.go | 236 +++++++++++++++++++++++++++++++++++++++ internal/paste/paste.go | 55 +++++++++ 3 files changed, 456 insertions(+) create mode 100644 internal/asr/client.go create mode 100644 internal/asr/protocol.go create mode 100644 internal/paste/paste.go diff --git a/internal/asr/client.go b/internal/asr/client.go new file mode 100644 index 0000000..bda6052 --- /dev/null +++ b/internal/asr/client.go @@ -0,0 +1,165 @@ +package asr + +import ( + "fmt" + "log/slog" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/fasthttp/websocket" + "github.com/google/uuid" + wsMsg "github.com/imbytecat/voicepaste/internal/ws" +) + +const ( + doubaoEndpoint = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async" + writeTimeout = 10 * time.Second + readTimeout = 30 * time.Second +) + +// Config holds Doubao ASR connection parameters. +type Config struct { + AppKey string + AccessKey string + ResourceID string +} + +// Client manages a single ASR session with Doubao. +type Client struct { + cfg Config + conn *websocket.Conn + seq atomic.Int32 + mu sync.Mutex + closed bool + closeCh chan struct{} + log *slog.Logger +} +// Dial connects to Doubao ASR and sends the initial FullClientRequest. +// resultCh receives partial/final results. Caller must call Close() when done. +func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) { + connID := uuid.New().String() + headers := http.Header{ + "X-Api-App-Key": {cfg.AppKey}, + "X-Api-Access-Key": {cfg.AccessKey}, + "X-Api-Resource-Id": {cfg.ResourceID}, + "X-Api-Connect-Id": {connID}, + } + dialer := websocket.Dialer{ + HandshakeTimeout: 10 * time.Second, + } + conn, _, err := dialer.Dial(doubaoEndpoint, headers) + if err != nil { + return nil, fmt.Errorf("dial doubao: %w", err) + } + c := &Client{ + cfg: cfg, + conn: conn, + closeCh: make(chan struct{}), + log: slog.With("conn_id", connID), + } + // Send FullClientRequest + req := &FullClientRequest{ + User: UserMeta{UID: connID}, + Audio: AudioMeta{ + Format: "raw", + Codec: "pcm", + Rate: 16000, + Bits: 16, + Channel: 1, + }, + Request: RequestMeta{ + ModelName: "seedasr-2.0", + EnableITN: true, + EnablePUNC: true, + EnableDDC: true, + ShowUtterances: true, + }, + } + c.seq.Store(1) + data, err := EncodeFullClientRequest(req, c.seq.Load()) + if err != nil { + conn.Close() + return nil, fmt.Errorf("encode full request: %w", err) + } + _ = conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + if err := conn.WriteMessage(websocket.BinaryMessage, data); err != nil { + conn.Close() + return nil, fmt.Errorf("send full request: %w", err) + } + // Start read loop + go c.readLoop(resultCh) + return c, nil +} + +// SendAudio sends a PCM audio frame to Doubao. +func (c *Client) SendAudio(pcm []byte, last bool) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return fmt.Errorf("client closed") + } + seq := c.seq.Add(1) + data, err := EncodeAudioFrame(seq, pcm, last) + if err != nil { + return fmt.Errorf("encode audio: %w", err) + } + _ = c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + return c.conn.WriteMessage(websocket.BinaryMessage, data) +} +// readLoop reads server responses and forwards them to resultCh. +func (c *Client) readLoop(resultCh chan<- wsMsg.ServerMsg) { + defer func() { + c.mu.Lock() + c.closed = true + c.mu.Unlock() + close(c.closeCh) + }() + for { + _ = c.conn.SetReadDeadline(time.Now().Add(readTimeout)) + _, data, err := c.conn.ReadMessage() + if err != nil { + c.log.Debug("asr read done", "err", err) + return + } + resp, err := ParseResponse(data) + if err != nil { + c.log.Warn("parse asr response", "err", err) + continue + } + if resp.Code != 0 { + c.log.Error("asr error", "code", resp.Code, "msg", resp.ErrMsg) + resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgError, Message: resp.ErrMsg} + return + } + // Determine if this is a final result by checking utterances + isFinal := false + text := resp.Text + for _, u := range resp.Utterances { + if u.Definite { + isFinal = true + text = u.Text + break + } + } + if isFinal { + resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgFinal, Text: text} + } else if text != "" { + resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgPartial, Text: text} + } + if resp.IsLast { + return + } + } +} +// Close shuts down the ASR connection. +func (c *Client) Close() { + c.mu.Lock() + if !c.closed { + c.conn.Close() + } + c.mu.Unlock() + // Wait for readLoop to finish + <-c.closeCh +} \ No newline at end of file diff --git a/internal/asr/protocol.go b/internal/asr/protocol.go new file mode 100644 index 0000000..cd6a340 --- /dev/null +++ b/internal/asr/protocol.go @@ -0,0 +1,236 @@ +// Package asr implements the Doubao (豆包) Seed-ASR-2.0 streaming speech +// recognition client using the custom binary WebSocket protocol. +package asr + +import ( + "bytes" + "compress/gzip" + "encoding/binary" + "encoding/json" + "fmt" + "io" +) + +// ── Protocol constants ── + +const protocolVersion = 0x01 +const headerSizeUnit = 0x01 // header = 1 * 4 = 4 bytes + +// MessageType (upper nibble of byte 1) +type MessageType uint8 + +const ( + MsgFullClientRequest MessageType = 0x1 + MsgAudioOnlyRequest MessageType = 0x2 + MsgFullServerResponse MessageType = 0x9 + MsgServerError MessageType = 0xF +) + +// MessageFlags (lower nibble of byte 1) +type MessageFlags uint8 + +const ( + FlagNoSeq MessageFlags = 0x0 + FlagPosSeq MessageFlags = 0x1 + FlagLastNoSeq MessageFlags = 0x2 + FlagNegSeq MessageFlags = 0x3 // last packet with seq +) + +// Serialization (upper nibble of byte 2) +const ( + SerNone uint8 = 0x0 + SerJSON uint8 = 0x1 +) + +// Compression (lower nibble of byte 2) +const ( + CompNone uint8 = 0x0 + CompGzip uint8 = 0x1 +) + +// ── Header encoding ── + +func encodeHeader(mt MessageType, flags MessageFlags, ser, comp uint8) []byte { + return []byte{ + (protocolVersion << 4) | headerSizeUnit, + (uint8(mt) << 4) | uint8(flags), + (ser << 4) | comp, + 0x00, // reserved + } +} + +// ── Gzip helpers ── + +func gzipCompress(data []byte) ([]byte, error) { + var buf bytes.Buffer + w := gzip.NewWriter(&buf) + if _, err := w.Write(data); err != nil { + return nil, err + } + if err := w.Close(); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func gzipDecompress(data []byte) ([]byte, error) { + r, err := gzip.NewReader(bytes.NewReader(data)) + if err != nil { + return nil, err + } + defer r.Close() + return io.ReadAll(r) +} + +// ── Request builders ── + +// FullClientRequest is the JSON payload for the initial handshake. +type FullClientRequest struct { + User UserMeta `json:"user"` + Audio AudioMeta `json:"audio"` + Request RequestMeta `json:"request"` +} + +type UserMeta struct { + UID string `json:"uid"` +} + +type AudioMeta struct { + Format string `json:"format"` + Codec string `json:"codec"` + Rate int `json:"rate"` + Bits int `json:"bits"` + Channel int `json:"channel"` +} + +type RequestMeta struct { + ModelName string `json:"model_name"` + EnableITN bool `json:"enable_itn"` + EnablePUNC bool `json:"enable_punc"` + EnableDDC bool `json:"enable_ddc"` + ShowUtterances bool `json:"show_utterances"` +} +// EncodeFullClientRequest builds the binary message for the initial handshake. +func EncodeFullClientRequest(req *FullClientRequest, seq int32) ([]byte, error) { + payloadJSON, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + compressed, err := gzipCompress(payloadJSON) + if err != nil { + return nil, fmt.Errorf("gzip compress: %w", err) + } + var buf bytes.Buffer + buf.Write(encodeHeader(MsgFullClientRequest, FlagPosSeq, SerJSON, CompGzip)) + _ = binary.Write(&buf, binary.BigEndian, seq) + _ = binary.Write(&buf, binary.BigEndian, int32(len(compressed))) + buf.Write(compressed) + return buf.Bytes(), nil +} +// EncodeAudioFrame builds a binary audio-only request. +// If last is true, seq is sent as negative to signal end of stream. +func EncodeAudioFrame(seq int32, pcm []byte, last bool) ([]byte, error) { + flags := FlagPosSeq + wireSeq := seq + if last { + flags = FlagNegSeq + wireSeq = -seq + } + compressed, err := gzipCompress(pcm) + if err != nil { + return nil, fmt.Errorf("gzip audio: %w", err) + } + var buf bytes.Buffer + buf.Write(encodeHeader(MsgAudioOnlyRequest, flags, SerNone, CompGzip)) + _ = binary.Write(&buf, binary.BigEndian, wireSeq) + _ = binary.Write(&buf, binary.BigEndian, int32(len(compressed))) + buf.Write(compressed) + return buf.Bytes(), nil +} +// ── Response parsing ── +// ServerResponse is the parsed response from Doubao. +type ServerResponse struct { + Code int + IsLast bool + Text string + Utterances []Utterance + ErrMsg string +} +type Utterance struct { + Text string `json:"text"` + Definite bool `json:"definite"` +} +type responsePayload struct { + Result struct { + Text string `json:"text"` + Utterances []Utterance `json:"utterances"` + } `json:"result"` + AudioInfo struct { + Duration int `json:"duration"` + } `json:"audio_info"` +} +// ParseResponse decodes a binary server response. +func ParseResponse(msg []byte) (*ServerResponse, error) { + if len(msg) < 4 { + return nil, fmt.Errorf("message too short: %d bytes", len(msg)) + } + headerSize := int(msg[0]&0x0F) * 4 + msgType := MessageType(msg[1] >> 4) + flags := MessageFlags(msg[1] & 0x0F) + compression := msg[2] & 0x0F + payload := msg[headerSize:] + resp := &ServerResponse{} + // Parse sequence if present + if flags&0x01 != 0 && len(payload) >= 4 { + payload = payload[4:] // skip sequence + } + if flags == FlagNegSeq || flags == FlagLastNoSeq { + resp.IsLast = true + } + // Parse event if flag bit 2 set + if flags&0x04 != 0 && len(payload) >= 4 { + payload = payload[4:] // skip event + } + switch msgType { + case MsgFullServerResponse: + if len(payload) < 4 { + return resp, nil + } + payloadSize := int(binary.BigEndian.Uint32(payload[:4])) + payload = payload[4:] + if payloadSize == 0 || len(payload) == 0 { + return resp, nil + } + if compression == CompGzip { + var err error + payload, err = gzipDecompress(payload) + if err != nil { + return nil, fmt.Errorf("decompress response: %w", err) + } + } + var rp responsePayload + if err := json.Unmarshal(payload, &rp); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + resp.Text = rp.Result.Text + resp.Utterances = rp.Result.Utterances + case MsgServerError: + if len(payload) < 8 { + return resp, nil + } + resp.Code = int(binary.BigEndian.Uint32(payload[:4])) + payloadSize := int(binary.BigEndian.Uint32(payload[4:8])) + payload = payload[8:] + if payloadSize > 0 && len(payload) > 0 { + if compression == CompGzip { + var err error + payload, err = gzipDecompress(payload) + if err != nil { + return nil, fmt.Errorf("decompress error: %w", err) + } + } + resp.ErrMsg = string(payload) + } + } + return resp, nil +} \ No newline at end of file diff --git a/internal/paste/paste.go b/internal/paste/paste.go new file mode 100644 index 0000000..53bd985 --- /dev/null +++ b/internal/paste/paste.go @@ -0,0 +1,55 @@ +// Package paste handles writing text to the system clipboard and simulating +// Ctrl+V / Cmd+V to paste into the currently focused application. +package paste + +import ( + "fmt" + "log/slog" + "runtime" + "time" + + "github.com/go-vgo/robotgo" + "golang.design/x/clipboard" +) + +const keyDelay = 50 * time.Millisecond + +// Init initializes the clipboard subsystem. Must be called once at startup. +func Init() error { + return clipboard.Init() +} + +// Paste writes text to the clipboard and simulates a paste keystroke. +// Falls back to clipboard-only if key simulation fails. +func Paste(text string) error { + // Write to clipboard + clipboard.Write(clipboard.FmtText, []byte(text)) + + // Small delay to ensure clipboard is ready + time.Sleep(keyDelay) + + // Simulate paste keystroke + if err := simulatePaste(); err != nil { + slog.Warn("key simulation failed, text is in clipboard", "err", err) + return fmt.Errorf("key simulation failed (text is in clipboard): %w", err) + } + + return nil +} + +// ClipboardOnly writes text to clipboard without simulating a keystroke. +func ClipboardOnly(text string) { + clipboard.Write(clipboard.FmtText, []byte(text)) +} + +func simulatePaste() error { + switch runtime.GOOS { + case "darwin": + robotgo.KeyTap("v", "command") + case "linux", "windows": + robotgo.KeyTap("v", "control") + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } + return nil +}