From 48c8444b3f0a895605bc42561ed41b55e3c2068f Mon Sep 17 00:00:00 2001 From: imbytecat Date: Mon, 2 Mar 2026 04:36:22 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E7=BB=93=E6=9E=84=EF=BC=8C=E8=A7=A3=E8=80=A6=E7=83=AD?= =?UTF-8?q?=E8=AF=8D=E3=80=81=E7=BB=9F=E4=B8=80=E8=AE=A4=E8=AF=81=E3=80=81?= =?UTF-8?q?=E7=A7=BB=E9=99=A4=20TLS=20=E5=BC=80=E5=85=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 ASRConfig,热词从 doubao 提升为 provider 无关配置 - 移除 SecurityConfig,token 移入 ServerConfig - 移除 tls_auto 配置项,TLS 始终启用(getUserMedia 要求 HTTPS) - validate() 改为基于 provider 白名单验证,增加 resource_id 校验 - 简化 main.go:移除 scheme 变量和 HTTP 降级分支 - 更新 config.example.yaml 为新结构并修正环境变量前缀 --- config.example.yaml | 26 +++++++-------- internal/config/config.go | 69 +++++++++++++++++++++++---------------- main.go | 56 +++++++++++++++---------------- 3 files changed, 81 insertions(+), 70 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index e8005f1..c13977b 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1,22 +1,22 @@ # VoicePaste config -# Environment variables override these values (prefix: none, direct mapping) +# Environment variables override these values (prefix: VOICEPASTE_) -# 火山引擎豆包 ASR 配置 -doubao: - app_id: "" # env: DOUBAO_APP_ID - access_token: "" # env: DOUBAO_ACCESS_TOKEN - resource_id: "volc.seedasr.sauc.duration" # env: DOUBAO_RESOURCE_ID - hotwords: # 可选:本地热词列表 +# ASR 通用配置 +asr: + provider: doubao # env: VOICEPASTE_ASR_PROVIDER — ASR 引擎(目前支持: doubao) + hotwords: # 可选:热词列表,提升特定词汇识别准确率 # - 张三 # - 李四 # - VoicePaste # - 人工智能 +# 火山引擎豆包 ASR 凭证 +doubao: + app_id: "" # env: VOICEPASTE_DOUBAO_APP_ID + access_token: "" # env: VOICEPASTE_DOUBAO_ACCESS_TOKEN + resource_id: "volc.seedasr.sauc.duration" # env: VOICEPASTE_DOUBAO_RESOURCE_ID + # 服务配置 server: - port: 8443 # env: PORT - tls_auto: true # env: TLS_AUTO — 自动 TLS (AnyIP + 自签名降级) - -# 安全配置 -security: - token: "" # 留空则不需要认证;填写后访问需携带 token 参数 + port: 8443 # env: VOICEPASTE_SERVER_PORT + token: "" # env: VOICEPASTE_SERVER_TOKEN — 留空则不需要认证;填写后访问需携带 token 参数 diff --git a/internal/config/config.go b/internal/config/config.go index 7795327..47b260f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,45 +7,48 @@ import ( "strings" "sync" "sync/atomic" + "github.com/fsnotify/fsnotify" "github.com/spf13/viper" ) -// DoubaoConfig holds 火山引擎豆包 ASR credentials. -type DoubaoConfig struct { - AppID string `mapstructure:"app_id"` - AccessToken string `mapstructure:"access_token"` - ResourceID string `mapstructure:"resource_id"` - Hotwords []string `mapstructure:"hotwords"` // 本地热词列表 +// ASRConfig holds ASR settings independent of any specific provider. +type ASRConfig struct { + Provider string `mapstructure:"provider"` + Hotwords []string `mapstructure:"hotwords"` // 通用热词列表 } -// SecurityConfig holds authentication settings. -type SecurityConfig struct { - Token string `mapstructure:"token"` +// DoubaoConfig holds 火山引擎豆包 ASR credentials. +type DoubaoConfig struct { + AppID string `mapstructure:"app_id"` + AccessToken string `mapstructure:"access_token"` + ResourceID string `mapstructure:"resource_id"` } // ServerConfig holds server settings. type ServerConfig struct { - Port int `mapstructure:"port"` - TLSAuto bool `mapstructure:"tls_auto"` + Port int `mapstructure:"port"` + Token string `mapstructure:"token"` } // Config is the top-level configuration. type Config struct { - Doubao DoubaoConfig `mapstructure:"doubao"` - Server ServerConfig `mapstructure:"server"` - Security SecurityConfig `mapstructure:"security"` + ASR ASRConfig `mapstructure:"asr"` + Doubao DoubaoConfig `mapstructure:"doubao"` + Server ServerConfig `mapstructure:"server"` } // defaults returns a Config with default values. func defaults() Config { return Config{ + ASR: ASRConfig{ + Provider: "doubao", + }, Doubao: DoubaoConfig{ ResourceID: "volc.seedasr.sauc.duration", }, Server: ServerConfig{ - Port: 8443, - TLSAuto: true, + Port: 8443, }, } } @@ -70,9 +73,9 @@ func Load(path string) (Config, error) { // Set defaults def := defaults() + v.SetDefault("asr.provider", def.ASR.Provider) v.SetDefault("doubao.resource_id", def.Doubao.ResourceID) v.SetDefault("server.port", def.Server.Port) - v.SetDefault("server.tls_auto", def.Server.TLSAuto) // Allow env var overrides (e.g., VOICEPASTE_DOUBAO_APP_ID) v.SetEnvPrefix("voicepaste") @@ -98,13 +101,23 @@ func Load(path string) (Config, error) { store(cfg) return cfg, nil } -// validate checks required fields. + +// validate checks required fields based on the configured ASR provider. func validate(cfg Config) error { - if cfg.Doubao.AppID == "" { - return fmt.Errorf("doubao.app_id is required") - } - if cfg.Doubao.AccessToken == "" { - return fmt.Errorf("doubao.access_token is required") + provider := strings.TrimSpace(strings.ToLower(cfg.ASR.Provider)) + switch provider { + case "doubao": + if cfg.Doubao.AppID == "" { + return fmt.Errorf("doubao.app_id is required when asr.provider is \"doubao\"") + } + if cfg.Doubao.AccessToken == "" { + return fmt.Errorf("doubao.access_token is required when asr.provider is \"doubao\"") + } + if cfg.Doubao.ResourceID == "" { + return fmt.Errorf("doubao.resource_id is required when asr.provider is \"doubao\"") + } + default: + return fmt.Errorf("unsupported asr.provider: %q (supported: doubao)", cfg.ASR.Provider) } return nil } @@ -170,9 +183,9 @@ func WatchAndReload(path string) func() { v.SetConfigType("yaml") // Set defaults def := defaults() + v.SetDefault("asr.provider", def.ASR.Provider) v.SetDefault("doubao.resource_id", def.Doubao.ResourceID) v.SetDefault("server.port", def.Server.Port) - v.SetDefault("server.tls_auto", def.Server.TLSAuto) v.SetEnvPrefix("voicepaste") v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) v.AutomaticEnv() @@ -247,8 +260,8 @@ func Get() Config { if v := global.Load(); v != nil { cfg := v.(Config) // Deep copy hotwords slice to prevent external modifications - if cfg.Doubao.Hotwords != nil { - cfg.Doubao.Hotwords = append([]string(nil), cfg.Doubao.Hotwords...) + if cfg.ASR.Hotwords != nil { + cfg.ASR.Hotwords = append([]string(nil), cfg.ASR.Hotwords...) } return cfg } @@ -259,8 +272,8 @@ func Get() Config { // Deep copies slices to ensure immutability. func store(cfg Config) { // Deep copy hotwords to prevent external modifications - if cfg.Doubao.Hotwords != nil { - cfg.Doubao.Hotwords = append([]string(nil), cfg.Doubao.Hotwords...) + if cfg.ASR.Hotwords != nil { + cfg.ASR.Hotwords = append([]string(nil), cfg.ASR.Hotwords...) } global.Store(cfg) } diff --git a/main.go b/main.go index 8711d82..59f86e9 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "os" "os/signal" "syscall" + "github.com/imbytecat/voicepaste/internal/asr" "github.com/imbytecat/voicepaste/internal/config" "github.com/imbytecat/voicepaste/internal/paste" @@ -34,8 +35,8 @@ func main() { initClipboard() lanIPs := mustDetectLANIPs() lanIP := lanIPs[0] - tlsResult, scheme := setupTLS(cfg, lanIP) - printBanner(cfg, tlsResult, lanIPs, scheme) + tlsResult := mustSetupTLS(lanIP) + printBanner(cfg, tlsResult, lanIPs) srv := createServer(cfg, lanIP, tlsResult) runWithGracefulShutdown(srv) } @@ -70,59 +71,56 @@ func mustDetectLANIPs() []string { return lanIPs } -func setupTLS(cfg config.Config, lanIP string) (*vpTLS.Result, string) { - if !cfg.Server.TLSAuto { - return nil, "http" - } +func mustSetupTLS(lanIP string) *vpTLS.Result { tlsResult, err := vpTLS.GetTLSConfig(lanIP) if err != nil { slog.Error("TLS setup failed", "error", err) os.Exit(1) } - return tlsResult, "https" + return tlsResult } -func printBanner(cfg config.Config, tlsResult *vpTLS.Result, lanIPs []string, scheme string) { +func printBanner(cfg config.Config, tlsResult *vpTLS.Result, lanIPs []string) { fmt.Println() fmt.Println("╔══════════════════════════════════════╗") fmt.Println("║ VoicePaste 就绪 ║") fmt.Println("╚══════════════════════════════════════╝") fmt.Println() - printAddresses(cfg, tlsResult, lanIPs, scheme) - printCertInfo(tlsResult, cfg.Server.TLSAuto) - printAuthInfo(cfg.Security.Token) + printAddresses(cfg, tlsResult, lanIPs) + printCertInfo(tlsResult) + printAuthInfo(cfg.Server.Token) fmt.Println() fmt.Println(" 在手机浏览器中打开上方地址") fmt.Println(" 按 Ctrl+C 停止服务") fmt.Println() } -func printAddresses(cfg config.Config, tlsResult *vpTLS.Result, lanIPs []string, scheme string) { - token := cfg.Security.Token +func printAddresses(cfg config.Config, tlsResult *vpTLS.Result, lanIPs []string) { + token := cfg.Server.Token if len(lanIPs) == 1 { - host := lanIP(tlsResult, lanIPs[0]) - fmt.Printf(" 地址: %s\n", buildURL(scheme, host, cfg.Server.Port, token)) + host := lanIPHost(tlsResult, lanIPs[0]) + fmt.Printf(" 地址: %s\n", buildURL(host, cfg.Server.Port, token)) return } fmt.Println(" 地址:") for _, ip := range lanIPs { - host := lanIP(tlsResult, ip) - fmt.Printf(" - %s\n", buildURL(scheme, host, cfg.Server.Port, token)) + host := lanIPHost(tlsResult, ip) + fmt.Printf(" - %s\n", buildURL(host, cfg.Server.Port, token)) } } -func lanIP(tlsResult *vpTLS.Result, ip string) string { +func lanIPHost(tlsResult *vpTLS.Result, ip string) string { if tlsResult != nil { return vpTLS.AnyIPHost(ip) } return ip } -func printCertInfo(tlsResult *vpTLS.Result, tlsAuto bool) { +func printCertInfo(tlsResult *vpTLS.Result) { if tlsResult != nil { fmt.Println(" 证书: AnyIP(浏览器信任)") - } else if tlsAuto { - fmt.Println(" 证书: 配置错误(TLS 启用但未获取证书)") + } else { + fmt.Println(" 证书: 获取失败") } } @@ -140,22 +138,21 @@ func createServer(cfg config.Config, lanIP string, tlsResult *vpTLS.Result) *ser if tlsResult != nil { tlsConfig = tlsResult.Config } - srv := server.New(cfg.Security.Token, lanIP, webContent, tlsConfig) + srv := server.New(cfg.Server.Token, lanIP, webContent, tlsConfig) asrFactory := buildASRFactory() - wsHandler := ws.NewHandler(cfg.Security.Token, paste.Paste, asrFactory) + wsHandler := ws.NewHandler(cfg.Server.Token, paste.Paste, asrFactory) wsHandler.Register(srv.App()) return srv } func buildASRFactory() func(chan<- ws.ServerMsg) (func([]byte), func(), error) { return func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) { - // Read latest config on each new connection cfg := config.Get() asrCfg := asr.Config{ AppID: cfg.Doubao.AppID, AccessToken: cfg.Doubao.AccessToken, ResourceID: cfg.Doubao.ResourceID, - Hotwords: cfg.Doubao.Hotwords, + Hotwords: cfg.ASR.Hotwords, } client, err := asr.Dial(asrCfg, resultCh) if err != nil { @@ -186,9 +183,10 @@ func runWithGracefulShutdown(srv *server.Server) { os.Exit(1) } } -func buildURL(scheme, host string, port int, token string) string { + +func buildURL(host string, port int, token string) string { if token != "" { - return fmt.Sprintf("%s://%s:%d/?token=%s", scheme, host, port, token) + return fmt.Sprintf("https://%s:%d/?token=%s", host, port, token) } - return fmt.Sprintf("%s://%s:%d/", scheme, host, port) -} \ No newline at end of file + return fmt.Sprintf("https://%s:%d/", host, port) +}