Files
voicepaste/main.go
imbytecat 96d685fdf2 feat: 添加豆包 ASR 热词功能支持
- 在 config.yaml 中添加 hotwords 配置项,支持本地管理热词列表
- 实现热词解析、格式化和表名生成工具(internal/asr/hotwords.go)
- 在 ASR 连接建立时自动将热词发送给豆包(boosting_table_name 参数)
- 支持热词权重配置(1-10,默认 4),格式:"词|权重" 或 "词"
- 支持配置热重载,修改热词后新连接自动生效
- 为未来动态热词功能预留扩展接口

热词格式示例:
  hotwords:
    - 张三|8
    - VoicePaste|10
    - 人工智能|6
2026-03-02 00:55:37 +08:00

187 lines
5.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
	crypto_tls "crypto/tls"
	"embed"
	"fmt"
	"io/fs"
	"log/slog"
	"net/url"
	"os"
	"os/signal"
	"syscall"

	"github.com/imbytecat/voicepaste/internal/asr"
	"github.com/imbytecat/voicepaste/internal/config"
	"github.com/imbytecat/voicepaste/internal/paste"
	"github.com/imbytecat/voicepaste/internal/server"
	vpTLS "github.com/imbytecat/voicepaste/internal/tls"
	"github.com/imbytecat/voicepaste/internal/ws"
)
// webFS holds the contents of web/dist, embedded into the binary at build
// time. The "all:" prefix makes the embed also include files whose names
// begin with "." or "_".
//
//go:embed all:web/dist
var webFS embed.FS
// version is the version string logged at startup; it defaults to "dev".
// NOTE(review): presumably overridden at build time (e.g. -ldflags -X) — confirm.
var version = "dev"
// main wires the program together in order: logger, configuration (followed
// by config.WatchAndReload), clipboard, LAN address detection, TLS setup,
// the startup banner, and finally the server, which runs until shutdown.
func main() {
	initLogger()
	slog.Info("VoicePaste", "version", version)

	cfg := mustLoadConfig()
	config.WatchAndReload("")
	initClipboard()

	// The first detected address is the one used for TLS and the server;
	// the full list is only shown in the banner.
	addrs := mustDetectLANIPs()
	primaryIP := addrs[0]

	tlsResult, scheme := setupTLS(cfg, primaryIP)
	printBanner(cfg, tlsResult, addrs, scheme)

	runWithGracefulShutdown(createServer(cfg, primaryIP, tlsResult))
}
// initLogger installs, as the process-wide default logger, a text-format
// slog handler that writes to stderr with a minimum level of Info.
func initLogger() {
	opts := &slog.HandlerOptions{Level: slog.LevelInfo}
	handler := slog.NewTextHandler(os.Stderr, opts)
	slog.SetDefault(slog.New(handler))
}
// mustLoadConfig loads the configuration via config.Load("") and returns it.
// If loading fails, the error is logged and the process exits with status 1.
func mustLoadConfig() config.Config {
	loaded, loadErr := config.Load("")
	if loadErr == nil {
		return loaded
	}
	slog.Error("failed to load config", "error", loadErr)
	os.Exit(1)
	return loaded // not reached: os.Exit does not return
}
// initClipboard calls paste.Init. A failure is not fatal: it is logged as a
// warning and startup continues.
func initClipboard() {
	initErr := paste.Init()
	if initErr == nil {
		return
	}
	slog.Warn("clipboard init failed, paste will be unavailable", "err", initErr)
}
// mustDetectLANIPs returns the addresses reported by server.GetLANIPs.
//
// It logs an error and exits with status 1 when detection fails or when no
// address is returned, so the result is guaranteed to be non-empty and
// callers may index its first element without a further length check.
func mustDetectLANIPs() []string {
	lanIPs, err := server.GetLANIPs()
	if err != nil {
		slog.Error("failed to detect LAN IP", "error", err)
		os.Exit(1)
	}
	// An empty list with a nil error would make the caller's lanIPs[0]
	// panic; treat it as the same fatal startup condition.
	if len(lanIPs) == 0 {
		slog.Error("failed to detect LAN IP", "error", "no LAN address found")
		os.Exit(1)
	}
	return lanIPs
}
// setupTLS decides which URL scheme the server uses.
//
// When cfg.Server.TLSAuto is false it returns (nil, "http"). Otherwise it
// asks vpTLS.GetTLSConfig for a TLS result for lanIP and returns it together
// with "https"; if that call fails, the error is logged and the process
// exits with status 1.
func setupTLS(cfg config.Config, lanIP string) (*vpTLS.Result, string) {
	const (
		plainScheme  = "http"
		secureScheme = "https"
	)
	if !cfg.Server.TLSAuto {
		return nil, plainScheme
	}
	result, tlsErr := vpTLS.GetTLSConfig(lanIP)
	if tlsErr == nil {
		return result, secureScheme
	}
	slog.Error("TLS setup failed", "error", tlsErr)
	os.Exit(1)
	return result, secureScheme // not reached: os.Exit does not return
}
// printBanner writes the startup banner to stdout: a boxed title, then the
// address, certificate and authentication sections, then usage hints.
func printBanner(cfg config.Config, tlsResult *vpTLS.Result, lanIPs []string, scheme string) {
	// Title box ("VoicePaste ready"), surrounded by blank lines.
	header := []string{
		"",
		"╔══════════════════════════════════════╗",
		"║ VoicePaste 就绪 ║",
		"╚══════════════════════════════════════╝",
		"",
	}
	for _, line := range header {
		fmt.Println(line)
	}

	printAddresses(cfg, tlsResult, lanIPs, scheme)
	printCertInfo(tlsResult, cfg.Server.TLSAuto)
	printAuthInfo(cfg.Security.Token)

	// Hints: "open the address above in your phone's browser" and
	// "press Ctrl+C to stop the service".
	footer := []string{
		"",
		" 在手机浏览器中打开上方地址",
		" 按 Ctrl+C 停止服务",
		"",
	}
	for _, line := range footer {
		fmt.Println(line)
	}
}
// printAddresses prints the URL(s) the service can be reached at. Exactly one
// address is printed inline after the "地址:" ("address") label; any other
// count prints the label on its own line followed by one "- url" line per
// address.
func printAddresses(cfg config.Config, tlsResult *vpTLS.Result, lanIPs []string, scheme string) {
	// urlFor builds the complete URL for a single LAN IP.
	urlFor := func(ip string) string {
		return buildURL(scheme, lanIP(tlsResult, ip), cfg.Server.Port, cfg.Security.Token)
	}

	if len(lanIPs) != 1 {
		fmt.Println(" 地址:")
		for i := range lanIPs {
			fmt.Printf(" - %s\n", urlFor(lanIPs[i]))
		}
		return
	}
	fmt.Printf(" 地址: %s\n", urlFor(lanIPs[0]))
}
// lanIP returns the host to display for ip: the IP itself when tlsResult is
// nil, otherwise the value of vpTLS.AnyIPHost(ip).
func lanIP(tlsResult *vpTLS.Result, ip string) string {
	if tlsResult == nil {
		return ip
	}
	return vpTLS.AnyIPHost(ip)
}
// printCertInfo prints the certificate line of the startup banner.
//
// With a non-nil tlsResult it reports the AnyIP browser-trusted certificate.
// With a nil tlsResult but tlsAuto set it reports a misconfiguration ("TLS
// enabled but no certificate obtained"). Otherwise it prints nothing.
func printCertInfo(tlsResult *vpTLS.Result, tlsAuto bool) {
	switch {
	case tlsResult != nil:
		fmt.Println(" 证书: AnyIP浏览器信任")
	case tlsAuto:
		// The message previously had a closing parenthesis with no opening one.
		fmt.Println(" 证书: 配置错误(TLS 启用但未获取证书)")
	}
}
// printAuthInfo prints the authentication line of the startup banner:
// "enabled" when token is non-empty, otherwise "not enabled (no token
// needed)". The token itself is never printed here.
func printAuthInfo(token string) {
	if token == "" {
		// The message previously opened a parenthesis without closing it.
		fmt.Println(" 认证: 未启用(无需 token)")
		return
	}
	fmt.Println(" 认证: 已启用")
}
// createServer builds the server and registers the WebSocket handler on it.
//
// The embedded web/dist subtree is handed to server.New as the web content,
// together with the security token, lanIP and, when tlsResult is non-nil, its
// TLS config (nil otherwise). The WebSocket handler receives the same token,
// paste.Paste and the ASR factory built from cfg.
//
// If the embedded subtree cannot be opened, the error is logged and the
// process exits with status 1, matching the other startup helpers.
func createServer(cfg config.Config, lanIP string, tlsResult *vpTLS.Result) *server.Server {
	webContent, err := fs.Sub(webFS, "web/dist")
	if err != nil {
		// Previously this error was discarded, which would have handed a
		// nil filesystem to the server.
		slog.Error("failed to open embedded web assets", "error", err)
		os.Exit(1)
	}

	var tlsConfig *crypto_tls.Config
	if tlsResult != nil {
		tlsConfig = tlsResult.Config
	}

	srv := server.New(cfg.Security.Token, lanIP, webContent, tlsConfig)
	asrFactory := buildASRFactory(cfg)
	wsHandler := ws.NewHandler(cfg.Security.Token, paste.Paste, asrFactory)
	wsHandler.Register(srv.App())
	return srv
}
// buildASRFactory returns a factory that opens one ASR connection per call.
//
// The factory dials with asr.Dial, passing the Doubao settings taken from cfg
// and the given result channel. On success it returns a function that
// forwards one PCM chunk via SendAudio (send errors are logged as warnings,
// not returned) and a cleanup function that calls Finish on the client. On
// dial failure it returns the error and nil functions.
//
// NOTE(review): the ASR settings, hotwords included, are copied from cfg once
// when the factory is built. Whether a later config reload reaches new
// connections cannot be told from this function — confirm.
func buildASRFactory(cfg config.Config) func(chan<- ws.ServerMsg) (func([]byte), func(), error) {
	dialCfg := asr.Config{
		AppID:       cfg.Doubao.AppID,
		AccessToken: cfg.Doubao.AccessToken,
		ResourceID:  cfg.Doubao.ResourceID,
		Hotwords:    cfg.Doubao.Hotwords,
	}

	factory := func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) {
		session, dialErr := asr.Dial(dialCfg, resultCh)
		if dialErr != nil {
			return nil, nil, dialErr
		}

		forward := func(pcm []byte) {
			sendErr := session.SendAudio(pcm, false)
			if sendErr == nil {
				return
			}
			slog.Warn("send audio to asr", "err", sendErr)
		}
		finish := func() {
			session.Finish()
		}
		return forward, finish, nil
	}
	return factory
}
// runWithGracefulShutdown starts srv and blocks until srv.Start returns.
//
// A background goroutine waits for SIGINT or SIGTERM and then calls
// srv.Shutdown. If srv.Start returns an error, it is logged and the process
// exits with status 1.
func runWithGracefulShutdown(srv *server.Server) {
	// Register the handler synchronously, before the goroutine is spawned.
	// Registering inside the goroutine left a window in which an early
	// signal got the default handling instead of a graceful shutdown.
	sigCh := make(chan os.Signal, 1) // buffered: signal.Notify never blocks on send
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

	go func() {
		<-sigCh
		slog.Info("shutting down...")
		// NOTE(review): any value returned by srv.Shutdown is discarded here,
		// as in the original — confirm nothing needs handling.
		srv.Shutdown()
	}()

	if err := srv.Start(); err != nil {
		slog.Error("server error", "error", err)
		os.Exit(1)
	}
}
// buildURL returns "scheme://host:port/" and, when token is non-empty,
// appends it as the "token" query parameter.
//
// The token is query-escaped so that characters such as '&', '#', '+' or
// spaces cannot truncate or corrupt the printed URL. Tokens made only of
// unreserved characters produce exactly the same URL as before.
func buildURL(scheme, host string, port int, token string) string {
	base := fmt.Sprintf("%s://%s:%d/", scheme, host, port)
	if token == "" {
		return base
	}
	return base + "?token=" + url.QueryEscape(token)
}