Files
voicepaste/main.go
imbytecat 0720505ef6 refactor: 使用 Viper 重构配置管理并实现生产级热重载
- 引入 Viper 库替代手动 YAML 解析
- 实现基于 fsnotify 的配置文件热重载
- 使用 atomic.Value 保证并发安全的配置读写
- 添加配置验证(必填字段检查)
- 深拷贝 Hotwords 切片防止数据竞争
- 使用绝对路径匹配提升跨平台可靠性
- 支持启动失败后重试(不锁死状态)
- 提供 stop 函数正确清理 watcher 资源
- 通过 Oracle 多轮审计确认生产就绪
2026-03-02 02:57:47 +08:00

194 lines
5.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
	crypto_tls "crypto/tls"
	"embed"
	"fmt"
	"io/fs"
	"log/slog"
	"net/url"
	"os"
	"os/signal"
	"syscall"

	"github.com/imbytecat/voicepaste/internal/asr"
	"github.com/imbytecat/voicepaste/internal/config"
	"github.com/imbytecat/voicepaste/internal/paste"
	"github.com/imbytecat/voicepaste/internal/server"
	vpTLS "github.com/imbytecat/voicepaste/internal/tls"
	"github.com/imbytecat/voicepaste/internal/ws"
)
// webFS holds the built front-end bundle, which createServer serves after
// narrowing it to the web/dist subtree. The "all:" prefix makes embed also
// include files whose names start with "." or "_", which it skips by default.
//
//go:embed all:web/dist
var webFS embed.FS

// version is logged at startup. It defaults to "dev"; presumably release
// builds override it with -ldflags "-X main.version=..." — confirm against
// the build scripts.
var version = "dev"
// main wires VoicePaste together: logging, configuration (with a file
// watcher for hot reload), the clipboard, TLS, the startup banner and
// finally the HTTP/WebSocket server, which runs until SIGINT or SIGTERM.
func main() {
	initLogger()
	slog.Info("VoicePaste", "version", version)

	cfg := mustLoadConfig()

	// WatchAndReload hands back a stop function for the config watcher that
	// may be nil, so only schedule it when one was returned.
	if stopWatch := config.WatchAndReload(""); stopWatch != nil {
		defer stopWatch()
	}

	initClipboard()

	addrs := mustDetectLANIPs()
	// The first detected address is the one TLS and the server are bound to.
	primaryIP := addrs[0]

	tlsResult, scheme := setupTLS(cfg, primaryIP)
	printBanner(cfg, tlsResult, addrs, scheme)

	runWithGracefulShutdown(createServer(cfg, primaryIP, tlsResult))
}
// initLogger installs a text-format slog handler writing to stderr as the
// process-wide default logger, emitting records at Info level and above.
func initLogger() {
	opts := &slog.HandlerOptions{Level: slog.LevelInfo}
	logger := slog.New(slog.NewTextHandler(os.Stderr, opts))
	slog.SetDefault(logger)
}
// mustLoadConfig loads the configuration via config.Load with an empty path
// argument and returns it. Any load error is logged and terminates the
// process with exit status 1.
func mustLoadConfig() config.Config {
	loaded, loadErr := config.Load("")
	if loadErr != nil {
		slog.Error("failed to load config", "error", loadErr)
		os.Exit(1)
	}
	return loaded
}
// initClipboard prepares the clipboard backend used by paste.Paste.
// A failure is deliberately non-fatal: it is logged as a warning and
// startup continues, but pasting will not work.
func initClipboard() {
	if err := paste.Init(); err != nil {
		// Use the same "error" key as every other log call in this file so
		// structured log queries match consistently.
		slog.Warn("clipboard init failed, paste will be unavailable", "error", err)
	}
}
// mustDetectLANIPs returns the machine's LAN IP addresses as reported by
// server.GetLANIPs. The result is guaranteed non-empty: a detection error,
// or an empty result, is logged and terminates the process with status 1.
func mustDetectLANIPs() []string {
	lanIPs, err := server.GetLANIPs()
	if err != nil {
		slog.Error("failed to detect LAN IP", "error", err)
		os.Exit(1)
	}
	// main indexes lanIPs[0]; an empty slice returned without an error would
	// otherwise crash with an index-out-of-range panic.
	if len(lanIPs) == 0 {
		slog.Error("failed to detect LAN IP", "error", "no LAN IP address found")
		os.Exit(1)
	}
	return lanIPs
}
// setupTLS decides how the server is exposed. With cfg.Server.TLSAuto set,
// it obtains a TLS configuration for hostIP and reports scheme "https";
// otherwise it returns a nil result and scheme "http". A TLS setup failure
// is logged and terminates the process with status 1.
func setupTLS(cfg config.Config, hostIP string) (*vpTLS.Result, string) {
	if cfg.Server.TLSAuto {
		result, tlsErr := vpTLS.GetTLSConfig(hostIP)
		if tlsErr != nil {
			slog.Error("TLS setup failed", "error", tlsErr)
			os.Exit(1)
		}
		return result, "https"
	}
	return nil, "http"
}
// printBanner writes the human-readable startup banner to stdout: a title
// box, the URLs to open, certificate and authentication status, and usage
// hints.
func printBanner(cfg config.Config, tlsResult *vpTLS.Result, lanIPs []string, scheme string) {
	header := []string{
		"",
		"╔══════════════════════════════════════╗",
		"║ VoicePaste 就绪 ║",
		"╚══════════════════════════════════════╝",
		"",
	}
	for _, line := range header {
		fmt.Println(line)
	}

	printAddresses(cfg, tlsResult, lanIPs, scheme)
	printCertInfo(tlsResult, cfg.Server.TLSAuto)
	printAuthInfo(cfg.Security.Token)

	footer := []string{
		"",
		" 在手机浏览器中打开上方地址",
		" 按 Ctrl+C 停止服务",
		"",
	}
	for _, line := range footer {
		fmt.Println(line)
	}
}
// printAddresses prints the URL(s) a phone should open. A single address is
// printed inline; otherwise a header line is followed by one bulleted URL per
// address. Each URL carries the auth token when one is configured.
func printAddresses(cfg config.Config, tlsResult *vpTLS.Result, lanIPs []string, scheme string) {
	urlFor := func(ip string) string {
		return buildURL(scheme, lanIP(tlsResult, ip), cfg.Server.Port, cfg.Security.Token)
	}

	if len(lanIPs) != 1 {
		fmt.Println(" 地址:")
		for _, ip := range lanIPs {
			fmt.Printf(" - %s\n", urlFor(ip))
		}
		return
	}
	fmt.Printf(" 地址: %s\n", urlFor(lanIPs[0]))
}
// lanIP returns the host part to put in a printed URL for ip. Without TLS
// the raw IP is used; with TLS the IP is passed through vpTLS.AnyIPHost
// (presumably mapping it to a hostname the certificate covers — confirm).
func lanIP(tlsResult *vpTLS.Result, ip string) string {
	if tlsResult == nil {
		return ip
	}
	return vpTLS.AnyIPHost(ip)
}
// printCertInfo prints the certificate status line of the startup banner.
// tlsResult is the value produced by setupTLS and tlsAuto mirrors
// cfg.Server.TLSAuto. Nothing is printed when TLS is disabled.
func printCertInfo(tlsResult *vpTLS.Result, tlsAuto bool) {
	switch {
	case tlsResult != nil:
		fmt.Println(" 证书: AnyIP(浏览器信任)")
	case tlsAuto:
		// Defensive: setupTLS exits on error, so this only triggers if TLS
		// setup returned a nil result without reporting an error.
		fmt.Println(" 证书: 配置错误(TLS 启用但未获取证书)")
	}
}
// printAuthInfo prints whether token authentication is enabled, which is
// the case whenever a non-empty token is configured.
func printAuthInfo(token string) {
	if token == "" {
		fmt.Println(" 认证: 未启用(无需 token)")
		return
	}
	fmt.Println(" 认证: 已启用")
}
// createServer builds the server that serves the embedded web UI and
// registers the WebSocket handler on its app.
//
// It passes server.New the configured token, lanIP, the web/dist subtree of
// the embedded assets, and the TLS config from tlsResult (nil when TLS is
// off). The WebSocket handler receives the same token, paste.Paste as its
// paste callback, and an ASR factory that reads the latest config per
// connection.
func createServer(cfg config.Config, lanIP string, tlsResult *vpTLS.Result) *server.Server {
	// fs.Sub only fails for an invalid path. Treat that as fatal, like the
	// other startup errors, rather than silently serving a nil filesystem.
	webContent, err := fs.Sub(webFS, "web/dist")
	if err != nil {
		slog.Error("failed to open embedded web assets", "error", err)
		os.Exit(1)
	}

	var tlsConfig *crypto_tls.Config
	if tlsResult != nil {
		tlsConfig = tlsResult.Config
	}

	srv := server.New(cfg.Security.Token, lanIP, webContent, tlsConfig)
	wsHandler := ws.NewHandler(cfg.Security.Token, paste.Paste, buildASRFactory())
	wsHandler.Register(srv.App())
	return srv
}
// buildASRFactory returns the per-connection ASR factory given to the
// WebSocket handler.
//
// Each factory call reads the current config via config.Get, so hot-reloaded
// Doubao credentials and hotwords apply to new connections. It dials an ASR
// client that delivers results on resultCh and returns:
//   - sendAudio: forwards a PCM chunk (with the final flag false); send errors
//     are logged, not returned.
//   - cleanup: calls client.Finish.
//
// A dial error is returned unchanged.
func buildASRFactory() func(chan<- ws.ServerMsg) (func([]byte), func(), error) {
	return func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) {
		// Read the latest config on each new connection.
		cfg := config.Get()
		asrCfg := asr.Config{
			AppID:       cfg.Doubao.AppID,
			AccessToken: cfg.Doubao.AccessToken,
			ResourceID:  cfg.Doubao.ResourceID,
			// NOTE(review): passed through as returned by config.Get; confirm
			// Get hands out a copy so a concurrent reload cannot share it.
			Hotwords: cfg.Doubao.Hotwords,
		}

		client, err := asr.Dial(asrCfg, resultCh)
		if err != nil {
			return nil, nil, err
		}

		sendAudio := func(pcm []byte) {
			if err := client.SendAudio(pcm, false); err != nil {
				// "error" key matches the rest of the file's structured logs.
				slog.Warn("send audio to asr", "error", err)
			}
		}
		cleanup := func() {
			client.Finish()
		}
		return sendAudio, cleanup, nil
	}
}
// runWithGracefulShutdown starts srv and blocks until it stops. On SIGINT or
// SIGTERM it calls srv.Shutdown. If Start returns an error, it is logged and
// the process exits with status 1.
//
// Servers commonly return from Start/Serve as soon as their listeners close,
// while Shutdown is still draining open connections. So once a signal has
// been received, this waits for Shutdown itself to return before handing
// control back to main. The first signal also restores default signal
// handling, so a second Ctrl+C terminates the process if draining hangs.
func runWithGracefulShutdown(srv *server.Server) {
	// Register synchronously so a signal arriving right after startup is
	// relayed here instead of killing the process with the default action.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

	stopping := make(chan struct{})     // closed once a shutdown signal arrives
	shutdownDone := make(chan struct{}) // closed once srv.Shutdown returns

	go func() {
		<-sigCh
		// Undo Notify so a second signal gets the default (terminating) action.
		signal.Stop(sigCh)
		close(stopping)
		slog.Info("shutting down...")
		srv.Shutdown()
		close(shutdownDone)
	}()

	err := srv.Start()

	// stopping is closed before Shutdown is invoked, so if Start returned
	// because of a shutdown this reliably waits for draining to finish.
	select {
	case <-stopping:
		<-shutdownDone
	default:
	}

	if err != nil {
		slog.Error("server error", "error", err)
		os.Exit(1)
	}
}
// buildURL returns the root URL of the server, "scheme://host:port/". When
// token is non-empty it is appended as the "token" query parameter. The token
// is query-escaped so characters such as '&', '#', '+' or spaces cannot break
// or truncate the URL.
func buildURL(scheme, host string, port int, token string) string {
	if token == "" {
		return fmt.Sprintf("%s://%s:%d/", scheme, host, port)
	}
	return fmt.Sprintf("%s://%s:%d/?token=%s", scheme, host, port, url.QueryEscape(token))
}