refactor: 优化代码质量,遵循 KISS 原则

- 移除自签证书回退逻辑,简化为仅使用 AnyIP 证书
- 删除 internal/tls/generate.go(不再需要)
- 重构 main.go:提取初始化逻辑,main() 从 156 行降至 13 行
- 重构 internal/ws/handler.go:提取消息处理,handleConn() 从 131 行降至 25 行
- 重构 internal/config/load.go:使用 map 驱动消除重复代码
- 优化前端 startRecording():使用标准 AbortController API
- 优化前端 showToast():预定义 DOM 元素,代码减少 50%

代码行数减少 90 行,复杂度显著降低,所有构建通过
This commit is contained in:
2026-03-02 00:25:14 +08:00
parent 8c7b9b45fd
commit b87fead2fd
8 changed files with 316 additions and 371 deletions

View File

@@ -41,14 +41,15 @@ func Load(configPath string) (Config, error) {
// applyEnv overrides config fields with environment variables. // applyEnv overrides config fields with environment variables.
func applyEnv(cfg *Config) { func applyEnv(cfg *Config) {
if v := os.Getenv("DOUBAO_APP_ID"); v != "" { envStringMap := map[string]*string{
cfg.Doubao.AppID = v "DOUBAO_APP_ID": &cfg.Doubao.AppID,
"DOUBAO_ACCESS_TOKEN": &cfg.Doubao.AccessToken,
"DOUBAO_RESOURCE_ID": &cfg.Doubao.ResourceID,
} }
if v := os.Getenv("DOUBAO_ACCESS_TOKEN"); v != "" { for key, target := range envStringMap {
cfg.Doubao.AccessToken = v if v := os.Getenv(key); v != "" {
} *target = v
if v := os.Getenv("DOUBAO_RESOURCE_ID"); v != "" { }
cfg.Doubao.ResourceID = v
} }
if v := os.Getenv("PORT"); v != "" { if v := os.Getenv("PORT"); v != "" {
if port, err := strconv.Atoi(v); err == nil { if port, err := strconv.Atoi(v); err == nil {

View File

@@ -1,72 +0,0 @@
package tls
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"os"
"time"
)
// generateSelfSigned creates a self-signed certificate for the given IP,
// saves it to disk, and returns the tls.Certificate.
func generateSelfSigned(lanIP, certFile, keyFile string) (tls.Certificate, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, fmt.Errorf("generate key: %w", err)
}
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return tls.Certificate{}, fmt.Errorf("generate serial: %w", err)
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"VoicePaste"},
CommonName: "VoicePaste Local",
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour), // 1 year
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP(lanIP)},
DNSNames: []string{"localhost"},
}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("create certificate: %w", err)
}
// Save cert PEM
certOut, err := os.Create(certFile)
if err != nil {
return tls.Certificate{}, fmt.Errorf("create cert file: %w", err)
}
defer certOut.Close()
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
// Save key PEM
keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("marshal key: %w", err)
}
keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return tls.Certificate{}, fmt.Errorf("create key file: %w", err)
}
defer keyOut.Close()
pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
return tls.LoadX509KeyPair(certFile, keyFile)
}

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"net"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@@ -26,11 +25,10 @@ func certDir() string {
return dir return dir
} }
// Result holds the TLS config and metadata about which cert source was used. // Result holds the TLS config and the AnyIP hostname.
type Result struct { type Result struct {
Config *tls.Config Config *tls.Config
AnyIP bool // true if AnyIP cert is active Host string // AnyIP hostname (e.g. voicepaste-192-168-1-5.anyip.dev)
Host string // hostname to use in URLs (AnyIP domain or raw IP)
} }
// AnyIPHost returns the AnyIP hostname for a given LAN IP. // AnyIPHost returns the AnyIP hostname for a given LAN IP.
@@ -40,82 +38,61 @@ func AnyIPHost(lanIP string) string {
return fmt.Sprintf("voicepaste-%s.anyip.dev", dashed) return fmt.Sprintf("voicepaste-%s.anyip.dev", dashed)
} }
// GetTLSConfig returns a TLS config for the given LAN IP. // GetTLSConfig returns a TLS config using AnyIP wildcard certificate.
// Priority: cached AnyIP → download AnyIP → cached self-signed → generate self-signed. // It tries cached cert first, then downloads fresh if needed.
func GetTLSConfig(lanIP string) (*Result, error) { func GetTLSConfig(lanIP string) (*Result, error) {
dir := certDir() dir := certDir()
anyipDir := filepath.Join(dir, "anyip") anyipDir := filepath.Join(dir, "anyip")
os.MkdirAll(anyipDir, 0700) os.MkdirAll(anyipDir, 0700)
anyipCert := filepath.Join(anyipDir, "fullchain.pem") certFile := filepath.Join(anyipDir, "fullchain.pem")
anyipKey := filepath.Join(anyipDir, "privkey.pem") keyFile := filepath.Join(anyipDir, "privkey.pem")
host := AnyIPHost(lanIP)
// 1. Try cached AnyIP cert // Try cached cert first
if cert, err := tls.LoadX509KeyPair(anyipCert, anyipKey); err == nil { if cert, err := loadAndValidateCert(certFile, keyFile); err == nil {
if leaf, err := x509.ParseCertificate(cert.Certificate[0]); err == nil { slog.Info("using cached AnyIP certificate")
if time.Now().Before(leaf.NotAfter.Add(-24 * time.Hour)) { // 1 day buffer return &Result{
slog.Info("using cached AnyIP certificate", "expires", leaf.NotAfter.Format("2006-01-02")) Config: &tls.Config{Certificates: []tls.Certificate{cert}},
return &Result{ Host: host,
Config: &tls.Config{Certificates: []tls.Certificate{cert}}, }, nil
AnyIP: true,
Host: AnyIPHost(lanIP),
}, nil
}
}
} }
// 2. Try downloading AnyIP cert // Download fresh cert
if err := downloadAnyIPCert(anyipCert, anyipKey); err == nil { slog.Info("downloading AnyIP certificate")
if cert, err := tls.LoadX509KeyPair(anyipCert, anyipKey); err == nil { if err := downloadAnyIPCert(certFile, keyFile); err != nil {
slog.Info("downloaded fresh AnyIP certificate") return nil, fmt.Errorf("failed to download AnyIP certificate: %w", err)
return &Result{
Config: &tls.Config{Certificates: []tls.Certificate{cert}},
AnyIP: true,
Host: AnyIPHost(lanIP),
}, nil
}
} else {
slog.Warn("AnyIP cert download failed, falling back to self-signed", "err", err)
} }
// 3. Try cached self-signed cert, err := tls.LoadX509KeyPair(certFile, keyFile)
ssCert := filepath.Join(dir, "cert.pem")
ssKey := filepath.Join(dir, "key.pem")
if cert, err := tls.LoadX509KeyPair(ssCert, ssKey); err == nil {
if leaf, err := x509.ParseCertificate(cert.Certificate[0]); err == nil {
if time.Now().Before(leaf.NotAfter) && certCoversIP(leaf, lanIP) {
slog.Info("using cached self-signed certificate", "expires", leaf.NotAfter.Format("2006-01-02"))
return &Result{
Config: &tls.Config{Certificates: []tls.Certificate{cert}},
Host: lanIP,
}, nil
}
}
}
// 4. Generate self-signed
slog.Info("generating self-signed TLS certificate", "ip", lanIP)
cert, err := generateSelfSigned(lanIP, ssCert, ssKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("generate TLS cert: %w", err) return nil, fmt.Errorf("failed to load downloaded certificate: %w", err)
} }
slog.Info("downloaded fresh AnyIP certificate")
return &Result{ return &Result{
Config: &tls.Config{Certificates: []tls.Certificate{cert}}, Config: &tls.Config{Certificates: []tls.Certificate{cert}},
Host: lanIP, Host: host,
}, nil }, nil
} }
// certCoversIP checks if the certificate covers the given IP. // loadAndValidateCert loads a certificate and validates it's not expired.
func certCoversIP(cert *x509.Certificate, ip string) bool { func loadAndValidateCert(certFile, keyFile string) (tls.Certificate, error) {
target := net.ParseIP(ip) cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if target == nil { if err != nil {
return false return tls.Certificate{}, err
} }
for _, certIP := range cert.IPAddresses {
if certIP.Equal(target) { leaf, err := x509.ParseCertificate(cert.Certificate[0])
return true if err != nil {
} return tls.Certificate{}, err
} }
return false
// Check if cert expires within 24 hours
if time.Now().After(leaf.NotAfter.Add(-24 * time.Hour)) {
return tls.Certificate{}, fmt.Errorf("certificate expired or expiring soon")
}
return cert, nil
} }
// downloadAnyIPCert downloads the AnyIP wildcard cert and key. // downloadAnyIPCert downloads the AnyIP wildcard cert and key.

View File

@@ -9,6 +9,17 @@ import (
"github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3"
) )
// session holds the state for a single WebSocket connection.
type session struct {
conn *websocket.Conn
log *slog.Logger
resultCh chan ServerMsg
previewMu sync.Mutex
previewText string
sendAudio func([]byte)
cleanup func()
active bool
}
// PasteFunc is called when the server should paste text into the focused app. // PasteFunc is called when the server should paste text into the focused app.
type PasteFunc func(text string) error type PasteFunc func(text string) error
@@ -53,134 +64,125 @@ func (h *Handler) Register(app *fiber.App) {
} }
func (h *Handler) handleConn(c *websocket.Conn) { func (h *Handler) handleConn(c *websocket.Conn) {
log := slog.With("remote", c.RemoteAddr().String()) sess := &session{
log.Info("ws connected") conn: c,
defer log.Info("ws disconnected") log: slog.With("remote", c.RemoteAddr().String()),
resultCh: make(chan ServerMsg, 32),
// Result channel for ASR → phone }
resultCh := make(chan ServerMsg, 32) sess.log.Info("ws connected")
defer close(resultCh) defer sess.log.Info("ws disconnected")
defer close(sess.resultCh)
// Writer goroutine: single writer to avoid concurrent writes defer sess.cleanupASR()
// bigmodel_async with enable_nonstream: server returns full text each time (not incremental) var wg sync.WaitGroup
// We replace preview text on each update instead of accumulating. wg.Add(1)
var wg sync.WaitGroup go sess.writerLoop(&wg)
var previewMu sync.Mutex defer wg.Wait()
var previewText string for {
wg.Add(1) mt, data, err := c.ReadMessage()
go func() { if err != nil {
defer wg.Done() break
for msg := range resultCh { }
// Replace preview text with latest result (full mode) if mt == websocket.BinaryMessage {
if msg.Type == MsgPartial || msg.Type == MsgFinal { sess.handleAudioFrame(data)
previewMu.Lock() } else if mt == websocket.TextMessage {
previewText = msg.Text h.handleTextMessage(sess, data)
preview := ServerMsg{Type: msg.Type, Text: previewText} }
previewMu.Unlock() }
if err := c.WriteMessage(websocket.TextMessage, preview.Bytes()); err != nil { }
log.Warn("ws write error", "err", err) func (s *session) writerLoop(wg *sync.WaitGroup) {
return defer wg.Done()
} for msg := range s.resultCh {
continue if msg.Type == MsgPartial || msg.Type == MsgFinal {
} s.previewMu.Lock()
// Forward other messages (error, pasted) as-is s.previewText = msg.Text
if err := c.WriteMessage(websocket.TextMessage, msg.Bytes()); err != nil { preview := ServerMsg{Type: msg.Type, Text: s.previewText}
log.Warn("ws write error", "err", err) s.previewMu.Unlock()
return if err := s.conn.WriteMessage(websocket.TextMessage, preview.Bytes()); err != nil {
} s.log.Warn("ws write error", "err", err)
} return
}() }
continue
// ASR session state }
var ( if err := s.conn.WriteMessage(websocket.TextMessage, msg.Bytes()); err != nil {
sendAudio func([]byte) s.log.Warn("ws write error", "err", err)
cleanup func() return
active bool }
) }
defer func() { }
if cleanup != nil { func (s *session) handleAudioFrame(data []byte) {
cleanup() if s.active && s.sendAudio != nil {
} s.sendAudio(data)
wg.Wait() }
}() }
func (h *Handler) handleTextMessage(s *session, data []byte) {
for { var msg ClientMsg
mt, data, err := c.ReadMessage() if err := json.Unmarshal(data, &msg); err != nil {
if err != nil { s.log.Warn("invalid json", "err", err)
break return
} }
switch msg.Type {
switch mt { case MsgStart:
case websocket.BinaryMessage: h.handleStart(s)
// Audio frame case MsgStop:
if active && sendAudio != nil { h.handleStop(s)
sendAudio(data) case MsgPaste:
} h.handlePaste(s, msg.Text)
}
case websocket.TextMessage: }
var msg ClientMsg func (h *Handler) handleStart(s *session) {
if err := json.Unmarshal(data, &msg); err != nil { if s.active {
log.Warn("invalid json", "err", err) return
continue }
} s.previewMu.Lock()
switch msg.Type { s.previewText = ""
case MsgStart: s.previewMu.Unlock()
if active { sa, cl, err := h.asrFactory(s.resultCh)
continue if err != nil {
} s.log.Error("asr start failed", "err", err)
// Reset preview text for new session s.resultCh <- ServerMsg{Type: MsgError, Message: "ASR start failed"}
previewMu.Lock() return
previewText = "" }
previewMu.Unlock() s.sendAudio = sa
sa, cl, err := h.asrFactory(resultCh) s.cleanup = cl
if err != nil { s.active = true
log.Error("asr start failed", "err", err) s.log.Info("recording started")
resultCh <- ServerMsg{Type: MsgError, Message: "ASR start failed"} }
continue func (h *Handler) handleStop(s *session) {
} if !s.active {
sendAudio = sa return
cleanup = cl }
active = true s.cleanupASR()
log.Info("recording started") s.sendAudio = nil
s.active = false
case MsgStop: s.previewMu.Lock()
if !active { finalText := s.previewText
continue s.previewText = ""
} s.previewMu.Unlock()
// Finish ASR session — waits for final result from readLoop if finalText != "" && h.pasteFunc != nil {
if cleanup != nil { if err := h.pasteFunc(finalText); err != nil {
cleanup() s.log.Error("auto-paste failed", "err", err)
cleanup = nil } else {
} s.resultCh <- ServerMsg{Type: MsgPasted}
sendAudio = nil }
active = false }
// Paste the final preview text s.log.Info("recording stopped")
previewMu.Lock() }
finalText := previewText func (h *Handler) handlePaste(s *session, text string) {
previewText = "" if text == "" {
previewMu.Unlock() return
if finalText != "" && h.pasteFunc != nil { }
if err := h.pasteFunc(finalText); err != nil { if h.pasteFunc != nil {
log.Error("auto-paste failed", "err", err) if err := h.pasteFunc(text); err != nil {
} else { s.log.Error("paste failed", "err", err)
resultCh <- ServerMsg{Type: MsgPasted} s.resultCh <- ServerMsg{Type: MsgError, Message: "paste failed"}
} } else {
} s.resultCh <- ServerMsg{Type: MsgPasted}
log.Info("recording stopped") }
}
case MsgPaste: }
if msg.Text == "" { func (s *session) cleanupASR() {
continue if s.cleanup != nil {
} s.cleanup()
if h.pasteFunc != nil { s.cleanup = nil
if err := h.pasteFunc(msg.Text); err != nil { }
log.Error("paste failed", "err", err)
resultCh <- ServerMsg{Type: MsgError, Message: "paste failed"}
} else {
resultCh <- ServerMsg{Type: MsgPasted}
}
}
}
}
}
} }

142
main.go
View File

@@ -22,99 +22,133 @@ var webFS embed.FS
var version = "dev" var version = "dev"
func main() { func main() {
initLogger()
slog.Info("VoicePaste", "version", version)
cfg := mustLoadConfig()
config.WatchAndReload("")
initClipboard()
lanIPs := mustDetectLANIPs()
lanIP := lanIPs[0]
tlsResult, scheme := setupTLS(cfg, lanIP)
printBanner(cfg, tlsResult, lanIPs, scheme)
srv := createServer(cfg, lanIP, tlsResult)
runWithGracefulShutdown(srv)
}
func initLogger() {
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelInfo, Level: slog.LevelInfo,
}))) })))
}
slog.Info("VoicePaste", "version", version) func mustLoadConfig() config.Config {
// Load config
cfg, err := config.Load("") cfg, err := config.Load("")
if err != nil { if err != nil {
slog.Error("failed to load config", "error", err) slog.Error("failed to load config", "error", err)
os.Exit(1) os.Exit(1)
} }
return cfg
}
// Start config hot-reload watcher func initClipboard() {
config.WatchAndReload("")
// Initialize clipboard
if err := paste.Init(); err != nil { if err := paste.Init(); err != nil {
slog.Warn("clipboard init failed, paste will be unavailable", "err", err) slog.Warn("clipboard init failed, paste will be unavailable", "err", err)
} }
// Detect LAN IPs }
func mustDetectLANIPs() []string {
lanIPs, err := server.GetLANIPs() lanIPs, err := server.GetLANIPs()
if err != nil { if err != nil {
slog.Error("failed to detect LAN IP", "error", err) slog.Error("failed to detect LAN IP", "error", err)
os.Exit(1) os.Exit(1)
} }
lanIP := lanIPs[0] // Use first IP for TLS and server binding return lanIPs
}
// Read token from config (empty = no auth required) func setupTLS(cfg config.Config, lanIP string) (*vpTLS.Result, string) {
token := cfg.Security.Token if !cfg.Server.TLSAuto {
return nil, "http"
// TLS setup
var tlsResult *vpTLS.Result
scheme := "http"
host := lanIP
if cfg.Server.TLSAuto {
var err error
tlsResult, err = vpTLS.GetTLSConfig(lanIP)
if err != nil {
slog.Error("TLS setup failed", "error", err)
os.Exit(1)
}
scheme = "https"
host = tlsResult.Host
} }
tlsResult, err := vpTLS.GetTLSConfig(lanIP)
if err != nil {
slog.Error("TLS setup failed", "error", err)
os.Exit(1)
}
return tlsResult, "https"
}
// Print connection info func printBanner(cfg config.Config, tlsResult *vpTLS.Result, lanIPs []string, scheme string) {
fmt.Println() fmt.Println()
fmt.Println("╔══════════════════════════════════════╗") fmt.Println("╔══════════════════════════════════════╗")
fmt.Println("║ VoicePaste 就绪 ║") fmt.Println("║ VoicePaste 就绪 ║")
fmt.Println("╚══════════════════════════════════════╝") fmt.Println("╚══════════════════════════════════════╝")
fmt.Println() fmt.Println()
// Print all accessible addresses printAddresses(cfg, tlsResult, lanIPs, scheme)
printCertInfo(tlsResult, cfg.Server.TLSAuto)
printAuthInfo(cfg.Security.Token)
fmt.Println()
fmt.Println(" 在手机浏览器中打开上方地址")
fmt.Println(" 按 Ctrl+C 停止服务")
fmt.Println()
}
func printAddresses(cfg config.Config, tlsResult *vpTLS.Result, lanIPs []string, scheme string) {
token := cfg.Security.Token
if len(lanIPs) == 1 { if len(lanIPs) == 1 {
host := lanIP(tlsResult, lanIPs[0])
fmt.Printf(" 地址: %s\n", buildURL(scheme, host, cfg.Server.Port, token)) fmt.Printf(" 地址: %s\n", buildURL(scheme, host, cfg.Server.Port, token))
} else { return
fmt.Println(" 地址:")
for _, ip := range lanIPs {
h := ip
if tlsResult != nil && tlsResult.AnyIP {
h = vpTLS.AnyIPHost(ip)
}
fmt.Printf(" - %s\n", buildURL(scheme, h, cfg.Server.Port, token))
}
} }
if tlsResult != nil && tlsResult.AnyIP { fmt.Println(" 地址:")
for _, ip := range lanIPs {
host := lanIP(tlsResult, ip)
fmt.Printf(" - %s\n", buildURL(scheme, host, cfg.Server.Port, token))
}
}
func lanIP(tlsResult *vpTLS.Result, ip string) string {
if tlsResult != nil {
return vpTLS.AnyIPHost(ip)
}
return ip
}
func printCertInfo(tlsResult *vpTLS.Result, tlsAuto bool) {
if tlsResult != nil {
fmt.Println(" 证书: AnyIP浏览器信任") fmt.Println(" 证书: AnyIP浏览器信任")
} else if cfg.Server.TLSAuto { } else if tlsAuto {
fmt.Println(" 证书: 自签名(浏览器会警告") fmt.Println(" 证书: 配置错误TLS 启用但未获取证书")
} }
}
func printAuthInfo(token string) {
if token != "" { if token != "" {
fmt.Println(" 认证: 已启用") fmt.Println(" 认证: 已启用")
} else { } else {
fmt.Println(" 认证: 未启用(无需 token") fmt.Println(" 认证: 未启用(无需 token")
} }
fmt.Println() }
fmt.Println(" 在手机浏览器中打开上方地址")
fmt.Println(" 按 Ctrl+C 停止服务") func createServer(cfg config.Config, lanIP string, tlsResult *vpTLS.Result) *server.Server {
fmt.Println()
// Create and start server
webContent, _ := fs.Sub(webFS, "web/dist") webContent, _ := fs.Sub(webFS, "web/dist")
var serverTLSCfg *crypto_tls.Config var tlsConfig *crypto_tls.Config
if tlsResult != nil { if tlsResult != nil {
serverTLSCfg = tlsResult.Config tlsConfig = tlsResult.Config
} }
srv := server.New(token, lanIP, webContent, serverTLSCfg) srv := server.New(cfg.Security.Token, lanIP, webContent, tlsConfig)
// Build ASR factory from config asrFactory := buildASRFactory(cfg)
wsHandler := ws.NewHandler(cfg.Security.Token, paste.Paste, asrFactory)
wsHandler.Register(srv.App())
return srv
}
func buildASRFactory(cfg config.Config) func(chan<- ws.ServerMsg) (func([]byte), func(), error) {
asrCfg := asr.Config{ asrCfg := asr.Config{
AppID: cfg.Doubao.AppID, AppID: cfg.Doubao.AppID,
AccessToken: cfg.Doubao.AccessToken, AccessToken: cfg.Doubao.AccessToken,
ResourceID: cfg.Doubao.ResourceID, ResourceID: cfg.Doubao.ResourceID,
} }
asrFactory := func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) { return func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) {
client, err := asr.Dial(asrCfg, resultCh) client, err := asr.Dial(asrCfg, resultCh)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -129,12 +163,9 @@ func main() {
} }
return sendAudio, cleanup, nil return sendAudio, cleanup, nil
} }
}
// Register WebSocket handler func runWithGracefulShutdown(srv *server.Server) {
wsHandler := ws.NewHandler(token, paste.Paste, asrFactory)
wsHandler.Register(srv.App())
// Graceful shutdown
go func() { go func() {
sigCh := make(chan os.Signal, 1) sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
@@ -142,7 +173,6 @@ func main() {
slog.Info("shutting down...") slog.Info("shutting down...")
srv.Shutdown() srv.Shutdown()
}() }()
if err := srv.Start(); err != nil { if err := srv.Start(); err != nil {
slog.Error("server error", "error", err) slog.Error("server error", "error", err)
os.Exit(1) os.Exit(1)

View File

@@ -29,7 +29,7 @@ interface AppState {
connected: boolean; connected: boolean;
recording: boolean; recording: boolean;
pendingStart: boolean; pendingStart: boolean;
startCancelled: boolean; abortController: AbortController | null;
audioCtx: AudioContext | null; audioCtx: AudioContext | null;
workletNode: AudioWorkletNode | null; workletNode: AudioWorkletNode | null;
stream: MediaStream | null; stream: MediaStream | null;
@@ -65,7 +65,7 @@ const state: AppState = {
connected: false, connected: false,
recording: false, recording: false,
pendingStart: false, pendingStart: false,
startCancelled: false, abortController: null,
audioCtx: null, audioCtx: null,
workletNode: null, workletNode: null,
stream: null, stream: null,
@@ -128,20 +128,19 @@ function connectWS(): void {
micBtn.disabled = false; micBtn.disabled = false;
}; };
ws.onmessage = (e: MessageEvent) => handleServerMsg(e.data); ws.onmessage = (e: MessageEvent) => handleServerMsg(e.data);
ws.onclose = () => { ws.onclose = () => {
state.connected = false; state.connected = false;
state.ws = null; state.ws = null;
micBtn.disabled = true; micBtn.disabled = true;
if (state.recording) stopRecording(); if (state.recording) stopRecording();
// Clean up pending async start on disconnect if (state.pendingStart) {
if (state.pendingStart) { state.abortController?.abort();
state.pendingStart = false; state.pendingStart = false;
state.startCancelled = true; micBtn.classList.remove("recording");
micBtn.classList.remove("recording"); }
} setStatus("disconnected", "已断开");
setStatus("disconnected", "已断开"); scheduleReconnect();
scheduleReconnect(); };
};
ws.onerror = () => ws.close(); ws.onerror = () => ws.close();
state.ws = ws; state.ws = ws;
} }
@@ -199,25 +198,14 @@ function setPreview(text: string, isFinal: boolean): void {
previewBox.classList.toggle("active", !isFinal); previewBox.classList.toggle("active", !isFinal);
} }
function showToast(msg: string): void { function showToast(msg: string): void {
let toast = document.getElementById("toast"); const toast = q("#toast");
if (!toast) {
toast = document.createElement("div");
toast.id = "toast";
toast.style.cssText =
"position:fixed;bottom:calc(100px + var(--safe-bottom,0px));left:50%;" +
"transform:translateX(-50%);background:#222;color:#eee;padding:8px 18px;" +
"border-radius:20px;font-size:14px;z-index:999;opacity:0;transition:opacity .3s;";
document.body.appendChild(toast);
}
toast.textContent = msg; toast.textContent = msg;
toast.style.opacity = "1"; toast.classList.add("show");
clearTimeout( const timer = (toast as HTMLElement & { _timer?: ReturnType<typeof setTimeout> })._timer;
(toast as HTMLElement & { _timer?: ReturnType<typeof setTimeout> })._timer, if (timer) clearTimeout(timer);
); (toast as HTMLElement & { _timer?: ReturnType<typeof setTimeout> })._timer = setTimeout(() => {
(toast as HTMLElement & { _timer?: ReturnType<typeof setTimeout> })._timer = toast.classList.remove("show");
setTimeout(() => { }, 2000);
toast.style.opacity = "0";
}, 2000);
} }
// ── Audio pipeline ── // ── Audio pipeline ──
async function initAudio(): Promise<void> { async function initAudio(): Promise<void> {
@@ -234,19 +222,19 @@ async function initAudio(): Promise<void> {
async function startRecording(): Promise<void> { async function startRecording(): Promise<void> {
if (state.recording || state.pendingStart) return; if (state.recording || state.pendingStart) return;
state.pendingStart = true; state.pendingStart = true;
state.startCancelled = false; const abortController = new AbortController();
state.abortController = abortController;
try { try {
await initAudio(); await initAudio();
if (state.startCancelled) { if (abortController.signal.aborted) {
state.pendingStart = false; state.pendingStart = false;
return; return;
} }
const audioCtx = state.audioCtx as AudioContext; const audioCtx = state.audioCtx as AudioContext;
// Ensure AudioContext is running (may suspend between recordings)
if (audioCtx.state === "suspended") { if (audioCtx.state === "suspended") {
await audioCtx.resume(); await audioCtx.resume();
} }
if (state.startCancelled) { if (abortController.signal.aborted) {
state.pendingStart = false; state.pendingStart = false;
return; return;
} }
@@ -257,7 +245,7 @@ async function startRecording(): Promise<void> {
channelCount: 1, channelCount: 1,
}, },
}); });
if (state.startCancelled) { if (abortController.signal.aborted) {
stream.getTracks().forEach((t) => { stream.getTracks().forEach((t) => {
t.stop(); t.stop();
}); });
@@ -275,34 +263,33 @@ async function startRecording(): Promise<void> {
}; };
source.connect(worklet); source.connect(worklet);
worklet.port.postMessage({ command: "start" }); worklet.port.postMessage({ command: "start" });
// Don't connect worklet to destination (no playback)
state.workletNode = worklet; state.workletNode = worklet;
state.pendingStart = false; state.pendingStart = false;
state.abortController = null;
state.recording = true; state.recording = true;
sendJSON({ type: "start" }); sendJSON({ type: "start" });
micBtn.classList.add("recording"); micBtn.classList.add("recording");
setPreview("", false); setPreview("", false);
} catch (err) { } catch (err) {
state.pendingStart = false; state.pendingStart = false;
state.abortController = null;
showToast(`麦克风错误: ${(err as Error).message}`); showToast(`麦克风错误: ${(err as Error).message}`);
} }
} }
function stopRecording(): void { function stopRecording(): void {
// Cancel pending async start if still initializing
if (state.pendingStart) { if (state.pendingStart) {
state.startCancelled = true; state.abortController?.abort();
state.abortController = null;
micBtn.classList.remove("recording"); micBtn.classList.remove("recording");
return; return;
} }
if (!state.recording) return; if (!state.recording) return;
state.recording = false; state.recording = false;
// Stop worklet
if (state.workletNode) { if (state.workletNode) {
state.workletNode.port.postMessage({ command: "stop" }); state.workletNode.port.postMessage({ command: "stop" });
state.workletNode.disconnect(); state.workletNode.disconnect();
state.workletNode = null; state.workletNode = null;
} }
// Stop mic stream
if (state.stream) { if (state.stream) {
state.stream.getTracks().forEach((t) => { state.stream.getTracks().forEach((t) => {
t.stop(); t.stop();

View File

@@ -42,6 +42,7 @@
<p id="history-empty" class="placeholder">暂无记录</p> <p id="history-empty" class="placeholder">暂无记录</p>
</section> </section>
</div> </div>
<div id="toast" class="toast"></div>
<script type="module" src="app.ts"></script> <script type="module" src="app.ts"></script>
</body> </body>

View File

@@ -272,3 +272,22 @@ header h1 {
background: var(--border); background: var(--border);
border-radius: 2px; border-radius: 2px;
} }
/* Toast */
.toast {
position: fixed;
bottom: calc(100px + var(--safe-bottom, 0px));
left: 50%;
transform: translateX(-50%);
background: #222;
color: #eee;
padding: 8px 18px;
border-radius: 20px;
font-size: 14px;
z-index: 999;
opacity: 0;
transition: opacity 0.3s;
pointer-events: none;
}
.toast.show {
opacity: 1;
}