feat: add Doubao ASR client and paste module
This commit is contained in:
165
internal/asr/client.go
Normal file
165
internal/asr/client.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package asr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/websocket"
|
||||
"github.com/google/uuid"
|
||||
wsMsg "github.com/imbytecat/voicepaste/internal/ws"
|
||||
)
|
||||
|
||||
const (
|
||||
doubaoEndpoint = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async"
|
||||
writeTimeout = 10 * time.Second
|
||||
readTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// Config holds Doubao ASR connection parameters.
|
||||
type Config struct {
|
||||
AppKey string
|
||||
AccessKey string
|
||||
ResourceID string
|
||||
}
|
||||
|
||||
// Client manages a single ASR session with Doubao.
|
||||
type Client struct {
|
||||
cfg Config
|
||||
conn *websocket.Conn
|
||||
seq atomic.Int32
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
closeCh chan struct{}
|
||||
log *slog.Logger
|
||||
}
|
||||
// Dial connects to Doubao ASR and sends the initial FullClientRequest.
|
||||
// resultCh receives partial/final results. Caller must call Close() when done.
|
||||
func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) {
|
||||
connID := uuid.New().String()
|
||||
headers := http.Header{
|
||||
"X-Api-App-Key": {cfg.AppKey},
|
||||
"X-Api-Access-Key": {cfg.AccessKey},
|
||||
"X-Api-Resource-Id": {cfg.ResourceID},
|
||||
"X-Api-Connect-Id": {connID},
|
||||
}
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
conn, _, err := dialer.Dial(doubaoEndpoint, headers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial doubao: %w", err)
|
||||
}
|
||||
c := &Client{
|
||||
cfg: cfg,
|
||||
conn: conn,
|
||||
closeCh: make(chan struct{}),
|
||||
log: slog.With("conn_id", connID),
|
||||
}
|
||||
// Send FullClientRequest
|
||||
req := &FullClientRequest{
|
||||
User: UserMeta{UID: connID},
|
||||
Audio: AudioMeta{
|
||||
Format: "raw",
|
||||
Codec: "pcm",
|
||||
Rate: 16000,
|
||||
Bits: 16,
|
||||
Channel: 1,
|
||||
},
|
||||
Request: RequestMeta{
|
||||
ModelName: "seedasr-2.0",
|
||||
EnableITN: true,
|
||||
EnablePUNC: true,
|
||||
EnableDDC: true,
|
||||
ShowUtterances: true,
|
||||
},
|
||||
}
|
||||
c.seq.Store(1)
|
||||
data, err := EncodeFullClientRequest(req, c.seq.Load())
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("encode full request: %w", err)
|
||||
}
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
if err := conn.WriteMessage(websocket.BinaryMessage, data); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("send full request: %w", err)
|
||||
}
|
||||
// Start read loop
|
||||
go c.readLoop(resultCh)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// SendAudio sends a PCM audio frame to Doubao.
|
||||
func (c *Client) SendAudio(pcm []byte, last bool) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return fmt.Errorf("client closed")
|
||||
}
|
||||
seq := c.seq.Add(1)
|
||||
data, err := EncodeAudioFrame(seq, pcm, last)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode audio: %w", err)
|
||||
}
|
||||
_ = c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
return c.conn.WriteMessage(websocket.BinaryMessage, data)
|
||||
}
|
||||
// readLoop reads server responses and forwards them to resultCh.
|
||||
func (c *Client) readLoop(resultCh chan<- wsMsg.ServerMsg) {
|
||||
defer func() {
|
||||
c.mu.Lock()
|
||||
c.closed = true
|
||||
c.mu.Unlock()
|
||||
close(c.closeCh)
|
||||
}()
|
||||
for {
|
||||
_ = c.conn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||
_, data, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
c.log.Debug("asr read done", "err", err)
|
||||
return
|
||||
}
|
||||
resp, err := ParseResponse(data)
|
||||
if err != nil {
|
||||
c.log.Warn("parse asr response", "err", err)
|
||||
continue
|
||||
}
|
||||
if resp.Code != 0 {
|
||||
c.log.Error("asr error", "code", resp.Code, "msg", resp.ErrMsg)
|
||||
resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgError, Message: resp.ErrMsg}
|
||||
return
|
||||
}
|
||||
// Determine if this is a final result by checking utterances
|
||||
isFinal := false
|
||||
text := resp.Text
|
||||
for _, u := range resp.Utterances {
|
||||
if u.Definite {
|
||||
isFinal = true
|
||||
text = u.Text
|
||||
break
|
||||
}
|
||||
}
|
||||
if isFinal {
|
||||
resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgFinal, Text: text}
|
||||
} else if text != "" {
|
||||
resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgPartial, Text: text}
|
||||
}
|
||||
if resp.IsLast {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// Close shuts down the ASR connection.
|
||||
func (c *Client) Close() {
|
||||
c.mu.Lock()
|
||||
if !c.closed {
|
||||
c.conn.Close()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
// Wait for readLoop to finish
|
||||
<-c.closeCh
|
||||
}
|
||||
236
internal/asr/protocol.go
Normal file
236
internal/asr/protocol.go
Normal file
@@ -0,0 +1,236 @@
|
||||
// Package asr implements the Doubao (豆包) Seed-ASR-2.0 streaming speech
|
||||
// recognition client using the custom binary WebSocket protocol.
|
||||
package asr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ── Protocol constants ──
|
||||
|
||||
const protocolVersion = 0x01
|
||||
const headerSizeUnit = 0x01 // header = 1 * 4 = 4 bytes
|
||||
|
||||
// MessageType (upper nibble of byte 1)
|
||||
type MessageType uint8
|
||||
|
||||
const (
|
||||
MsgFullClientRequest MessageType = 0x1
|
||||
MsgAudioOnlyRequest MessageType = 0x2
|
||||
MsgFullServerResponse MessageType = 0x9
|
||||
MsgServerError MessageType = 0xF
|
||||
)
|
||||
|
||||
// MessageFlags (lower nibble of byte 1)
|
||||
type MessageFlags uint8
|
||||
|
||||
const (
|
||||
FlagNoSeq MessageFlags = 0x0
|
||||
FlagPosSeq MessageFlags = 0x1
|
||||
FlagLastNoSeq MessageFlags = 0x2
|
||||
FlagNegSeq MessageFlags = 0x3 // last packet with seq
|
||||
)
|
||||
|
||||
// Serialization (upper nibble of byte 2)
|
||||
const (
|
||||
SerNone uint8 = 0x0
|
||||
SerJSON uint8 = 0x1
|
||||
)
|
||||
|
||||
// Compression (lower nibble of byte 2)
|
||||
const (
|
||||
CompNone uint8 = 0x0
|
||||
CompGzip uint8 = 0x1
|
||||
)
|
||||
|
||||
// ── Header encoding ──
|
||||
|
||||
func encodeHeader(mt MessageType, flags MessageFlags, ser, comp uint8) []byte {
|
||||
return []byte{
|
||||
(protocolVersion << 4) | headerSizeUnit,
|
||||
(uint8(mt) << 4) | uint8(flags),
|
||||
(ser << 4) | comp,
|
||||
0x00, // reserved
|
||||
}
|
||||
}
|
||||
|
||||
// ── Gzip helpers ──
|
||||
|
||||
func gzipCompress(data []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
w := gzip.NewWriter(&buf)
|
||||
if _, err := w.Write(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func gzipDecompress(data []byte) ([]byte, error) {
|
||||
r, err := gzip.NewReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
return io.ReadAll(r)
|
||||
}
|
||||
|
||||
// ── Request builders ──
|
||||
|
||||
// FullClientRequest is the JSON payload for the initial handshake.
|
||||
type FullClientRequest struct {
|
||||
User UserMeta `json:"user"`
|
||||
Audio AudioMeta `json:"audio"`
|
||||
Request RequestMeta `json:"request"`
|
||||
}
|
||||
|
||||
type UserMeta struct {
|
||||
UID string `json:"uid"`
|
||||
}
|
||||
|
||||
type AudioMeta struct {
|
||||
Format string `json:"format"`
|
||||
Codec string `json:"codec"`
|
||||
Rate int `json:"rate"`
|
||||
Bits int `json:"bits"`
|
||||
Channel int `json:"channel"`
|
||||
}
|
||||
|
||||
type RequestMeta struct {
|
||||
ModelName string `json:"model_name"`
|
||||
EnableITN bool `json:"enable_itn"`
|
||||
EnablePUNC bool `json:"enable_punc"`
|
||||
EnableDDC bool `json:"enable_ddc"`
|
||||
ShowUtterances bool `json:"show_utterances"`
|
||||
}
|
||||
// EncodeFullClientRequest builds the binary message for the initial handshake.
|
||||
func EncodeFullClientRequest(req *FullClientRequest, seq int32) ([]byte, error) {
|
||||
payloadJSON, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
compressed, err := gzipCompress(payloadJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gzip compress: %w", err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
buf.Write(encodeHeader(MsgFullClientRequest, FlagPosSeq, SerJSON, CompGzip))
|
||||
_ = binary.Write(&buf, binary.BigEndian, seq)
|
||||
_ = binary.Write(&buf, binary.BigEndian, int32(len(compressed)))
|
||||
buf.Write(compressed)
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
// EncodeAudioFrame builds a binary audio-only request.
|
||||
// If last is true, seq is sent as negative to signal end of stream.
|
||||
func EncodeAudioFrame(seq int32, pcm []byte, last bool) ([]byte, error) {
|
||||
flags := FlagPosSeq
|
||||
wireSeq := seq
|
||||
if last {
|
||||
flags = FlagNegSeq
|
||||
wireSeq = -seq
|
||||
}
|
||||
compressed, err := gzipCompress(pcm)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gzip audio: %w", err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
buf.Write(encodeHeader(MsgAudioOnlyRequest, flags, SerNone, CompGzip))
|
||||
_ = binary.Write(&buf, binary.BigEndian, wireSeq)
|
||||
_ = binary.Write(&buf, binary.BigEndian, int32(len(compressed)))
|
||||
buf.Write(compressed)
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
// ── Response parsing ──
|
||||
// ServerResponse is the parsed response from Doubao.
|
||||
type ServerResponse struct {
|
||||
Code int
|
||||
IsLast bool
|
||||
Text string
|
||||
Utterances []Utterance
|
||||
ErrMsg string
|
||||
}
|
||||
type Utterance struct {
|
||||
Text string `json:"text"`
|
||||
Definite bool `json:"definite"`
|
||||
}
|
||||
type responsePayload struct {
|
||||
Result struct {
|
||||
Text string `json:"text"`
|
||||
Utterances []Utterance `json:"utterances"`
|
||||
} `json:"result"`
|
||||
AudioInfo struct {
|
||||
Duration int `json:"duration"`
|
||||
} `json:"audio_info"`
|
||||
}
|
||||
// ParseResponse decodes a binary server response.
|
||||
func ParseResponse(msg []byte) (*ServerResponse, error) {
|
||||
if len(msg) < 4 {
|
||||
return nil, fmt.Errorf("message too short: %d bytes", len(msg))
|
||||
}
|
||||
headerSize := int(msg[0]&0x0F) * 4
|
||||
msgType := MessageType(msg[1] >> 4)
|
||||
flags := MessageFlags(msg[1] & 0x0F)
|
||||
compression := msg[2] & 0x0F
|
||||
payload := msg[headerSize:]
|
||||
resp := &ServerResponse{}
|
||||
// Parse sequence if present
|
||||
if flags&0x01 != 0 && len(payload) >= 4 {
|
||||
payload = payload[4:] // skip sequence
|
||||
}
|
||||
if flags == FlagNegSeq || flags == FlagLastNoSeq {
|
||||
resp.IsLast = true
|
||||
}
|
||||
// Parse event if flag bit 2 set
|
||||
if flags&0x04 != 0 && len(payload) >= 4 {
|
||||
payload = payload[4:] // skip event
|
||||
}
|
||||
switch msgType {
|
||||
case MsgFullServerResponse:
|
||||
if len(payload) < 4 {
|
||||
return resp, nil
|
||||
}
|
||||
payloadSize := int(binary.BigEndian.Uint32(payload[:4]))
|
||||
payload = payload[4:]
|
||||
if payloadSize == 0 || len(payload) == 0 {
|
||||
return resp, nil
|
||||
}
|
||||
if compression == CompGzip {
|
||||
var err error
|
||||
payload, err = gzipDecompress(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decompress response: %w", err)
|
||||
}
|
||||
}
|
||||
var rp responsePayload
|
||||
if err := json.Unmarshal(payload, &rp); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
resp.Text = rp.Result.Text
|
||||
resp.Utterances = rp.Result.Utterances
|
||||
case MsgServerError:
|
||||
if len(payload) < 8 {
|
||||
return resp, nil
|
||||
}
|
||||
resp.Code = int(binary.BigEndian.Uint32(payload[:4]))
|
||||
payloadSize := int(binary.BigEndian.Uint32(payload[4:8]))
|
||||
payload = payload[8:]
|
||||
if payloadSize > 0 && len(payload) > 0 {
|
||||
if compression == CompGzip {
|
||||
var err error
|
||||
payload, err = gzipDecompress(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decompress error: %w", err)
|
||||
}
|
||||
}
|
||||
resp.ErrMsg = string(payload)
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
55
internal/paste/paste.go
Normal file
55
internal/paste/paste.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Package paste handles writing text to the system clipboard and simulating
|
||||
// Ctrl+V / Cmd+V to paste into the currently focused application.
|
||||
package paste
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/go-vgo/robotgo"
|
||||
"golang.design/x/clipboard"
|
||||
)
|
||||
|
||||
const keyDelay = 50 * time.Millisecond
|
||||
|
||||
// Init initializes the clipboard subsystem. Must be called once at startup.
|
||||
func Init() error {
|
||||
return clipboard.Init()
|
||||
}
|
||||
|
||||
// Paste writes text to the clipboard and simulates a paste keystroke.
|
||||
// Falls back to clipboard-only if key simulation fails.
|
||||
func Paste(text string) error {
|
||||
// Write to clipboard
|
||||
clipboard.Write(clipboard.FmtText, []byte(text))
|
||||
|
||||
// Small delay to ensure clipboard is ready
|
||||
time.Sleep(keyDelay)
|
||||
|
||||
// Simulate paste keystroke
|
||||
if err := simulatePaste(); err != nil {
|
||||
slog.Warn("key simulation failed, text is in clipboard", "err", err)
|
||||
return fmt.Errorf("key simulation failed (text is in clipboard): %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClipboardOnly writes text to clipboard without simulating a keystroke.
|
||||
func ClipboardOnly(text string) {
|
||||
clipboard.Write(clipboard.FmtText, []byte(text))
|
||||
}
|
||||
|
||||
func simulatePaste() error {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
robotgo.KeyTap("v", "command")
|
||||
case "linux", "windows":
|
||||
robotgo.KeyTap("v", "control")
|
||||
default:
|
||||
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user