Files
voicepaste/internal/asr/client.go

179 lines
4.1 KiB
Go

package asr
import (
"fmt"
"log/slog"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/fasthttp/websocket"
"github.com/google/uuid"
wsMsg "github.com/imbytecat/voicepaste/internal/ws"
)
// Connection settings shared by every ASR session in this package.
const (
// doubaoEndpoint is the WebSocket URL that Dial connects to.
doubaoEndpoint = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async"
// writeTimeout is the write deadline applied before each outgoing message.
writeTimeout = 10 * time.Second
// readTimeout is the read deadline applied before each read in readLoop.
readTimeout = 30 * time.Second
)
// Config holds Doubao ASR connection parameters.
// Dial sends each field as a header on the WebSocket handshake request.
type Config struct {
// AppKey is sent as the X-Api-App-Key header.
AppKey string
// AccessKey is sent as the X-Api-Access-Key header.
AccessKey string
// ResourceID is sent as the X-Api-Resource-Id header.
ResourceID string
}
// Client manages a single ASR session with Doubao.
// It is created by Dial, which also starts the goroutine running readLoop.
type Client struct {
// cfg is the configuration that was passed to Dial.
cfg Config
// conn is the WebSocket connection opened by Dial.
conn *websocket.Conn
// seq is the message sequence number: Dial stores 1 for the initial
// request and SendAudio increments it for every audio frame.
seq atomic.Int32
// mu guards closed and is held by SendAudio for the duration of a write.
mu sync.Mutex
// closed is set to true by readLoop when it exits.
closed bool
// closeCh is closed by readLoop on exit; Finish and Close wait on it.
closeCh chan struct{}
// log is a logger carrying the "conn_id" attribute of this session.
log *slog.Logger
}
// Dial opens a WebSocket connection to the Doubao ASR endpoint, sends the
// initial FullClientRequest, and starts a background goroutine that reads
// server responses.
//
// cfg supplies the credentials sent as X-Api-* handshake headers. A fresh UUID
// is generated per call and used as the X-Api-Connect-Id header, as the
// request's user UID, and as the "conn_id" attribute on the client's logger.
//
// resultCh receives partial results, final results, and error messages. Sends
// on resultCh block, so the caller should keep receiving from it for the
// lifetime of the session. Caller must call Close() when done.
//
// An error is returned if the handshake, the request encoding, or the initial
// write fails; in the latter two cases the connection is closed first.
func Dial(cfg Config, resultCh chan<- wsMsg.ServerMsg) (*Client, error) {
	connID := uuid.New().String()
	headers := http.Header{
		"X-Api-App-Key":     {cfg.AppKey},
		"X-Api-Access-Key":  {cfg.AccessKey},
		"X-Api-Resource-Id": {cfg.ResourceID},
		"X-Api-Connect-Id":  {connID},
	}
	dialer := websocket.Dialer{
		HandshakeTimeout: 10 * time.Second,
	}
	conn, resp, err := dialer.Dial(doubaoEndpoint, headers)
	if err != nil {
		// When the server answers the handshake with a non-upgrade response,
		// the dialer hands that response back; include its status so rejected
		// credentials can be told apart from network failures.
		if resp != nil {
			return nil, fmt.Errorf("dial doubao: http status %q: %w", resp.Status, err)
		}
		return nil, fmt.Errorf("dial doubao: %w", err)
	}

	c := &Client{
		cfg:     cfg,
		conn:    conn,
		closeCh: make(chan struct{}),
		log:     slog.With("conn_id", connID),
	}

	// Describe the audio stream and recognition options for this session.
	req := &FullClientRequest{
		User: UserMeta{UID: connID},
		Audio: AudioMeta{
			Format:  "pcm",
			Codec:   "raw",
			Rate:    16000,
			Bits:    16,
			Channel: 1,
		},
		Request: RequestMeta{
			ModelName:      "seedasr-2.0",
			EnableITN:      true,
			EnablePUNC:     true,
			EnableDDC:      true,
			ShowUtterances: true,
		},
	}

	// The full request carries sequence number 1; audio frames continue from 2.
	c.seq.Store(1)
	data, err := EncodeFullClientRequest(req, c.seq.Load())
	if err != nil {
		conn.Close()
		return nil, fmt.Errorf("encode full request: %w", err)
	}

	_ = conn.SetWriteDeadline(time.Now().Add(writeTimeout))
	if err := conn.WriteMessage(websocket.BinaryMessage, data); err != nil {
		conn.Close()
		return nil, fmt.Errorf("send full request: %w", err)
	}

	// The read loop owns connection teardown from here on.
	go c.readLoop(resultCh)
	return c, nil
}
// SendAudio encodes one PCM audio frame with the next sequence number and
// writes it to the ASR connection as a binary WebSocket message.
//
// pcm is the audio payload and last is forwarded to EncodeAudioFrame; Finish
// calls this with a nil payload and last set to true. The client mutex is held
// for the whole call, so concurrent callers are serialized.
//
// It returns an error if the client is already closed, if encoding fails, or
// if the write to the connection fails.
func (c *Client) SendAudio(pcm []byte, last bool) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return fmt.Errorf("client closed")
	}

	seq := c.seq.Add(1)
	data, err := EncodeAudioFrame(seq, pcm, last)
	if err != nil {
		return fmt.Errorf("encode audio: %w", err)
	}

	// NOTE(review): the deadline error is ignored here, as it is in Dial.
	_ = c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
	if err := c.conn.WriteMessage(websocket.BinaryMessage, data); err != nil {
		// Wrap like every other failure in this file so callers get context.
		return fmt.Errorf("send audio: %w", err)
	}
	return nil
}
// readLoop receives messages from the ASR connection until a read fails, the
// server reports a non-zero code, or a response is flagged as the last one,
// and forwards recognition results to resultCh.
//
// For each response, the text of the first utterance marked Definite is sent
// as a final result; otherwise the response text is sent as a partial result
// when it is non-empty. Responses that fail to parse are logged and skipped.
// Sends on resultCh block until the receiver takes the message.
func (c *Client) readLoop(resultCh chan<- wsMsg.ServerMsg) {
	// Deferred calls run last-in-first-out: the connection is closed first,
	// then the client is flagged closed, and finally closeCh is closed.
	defer close(c.closeCh)
	defer func() {
		c.mu.Lock()
		c.closed = true
		c.mu.Unlock()
	}()
	defer c.conn.Close()

	for {
		_ = c.conn.SetReadDeadline(time.Now().Add(readTimeout))
		_, payload, readErr := c.conn.ReadMessage()
		if readErr != nil {
			c.log.Debug("asr read done", "err", readErr)
			return
		}

		resp, parseErr := ParseResponse(payload)
		if parseErr != nil {
			c.log.Warn("parse asr response", "err", parseErr)
			continue
		}

		if resp.Code != 0 {
			c.log.Error("asr error", "code", resp.Code, "msg", resp.ErrMsg)
			resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgError, Message: resp.ErrMsg}
			return
		}

		// Start from the response text; the first definite utterance, if any,
		// replaces it and turns the message into a final result.
		transcript, definite := resp.Text, false
		for _, utt := range resp.Utterances {
			if !utt.Definite {
				continue
			}
			transcript, definite = utt.Text, true
			break
		}

		switch {
		case definite:
			resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgFinal, Text: transcript}
		case transcript != "":
			resultCh <- wsMsg.ServerMsg{Type: wsMsg.MsgPartial, Text: transcript}
		}

		if resp.IsLast {
			return
		}
	}
}
// Finish sends the last audio frame (nil payload, last set to true) and then
// blocks until the read loop has exited.
//
// If the client is already closed it returns immediately. If the last frame
// cannot be sent, the error is logged and the connection is closed so the
// read loop stops right away; otherwise the wait would last until the server
// ended the session or readTimeout expired.
func (c *Client) Finish() {
	c.mu.Lock()
	alreadyClosed := c.closed
	c.mu.Unlock()
	if alreadyClosed {
		return
	}

	if err := c.SendAudio(nil, true); err != nil {
		c.log.Warn("send last audio frame", "err", err)
		// Closing the connection makes the pending read in readLoop fail, so
		// closeCh is closed promptly. Closing an already-closed connection
		// only returns an error, which is ignored.
		_ = c.conn.Close()
	}

	// mu must not be held here: readLoop takes it before closing closeCh.
	<-c.closeCh
}
// Close forcefully shuts down the ASR connection and blocks until the read
// loop has exited. It returns immediately if the client is already closed.
func (c *Client) Close() {
	c.mu.Lock()
	alreadyClosed := c.closed
	if !alreadyClosed {
		// Closing the connection makes the pending read in readLoop fail.
		c.conn.Close()
	}
	c.mu.Unlock()

	if alreadyClosed {
		return
	}
	// Wait without holding mu: readLoop takes it before closing closeCh.
	<-c.closeCh
}