diff --git a/internal/config/config.go b/internal/config/config.go index 601772f..7795327 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,9 +1,12 @@ package config import ( + "fmt" "log/slog" + "path/filepath" "strings" - + "sync" + "sync/atomic" "github.com/fsnotify/fsnotify" "github.com/spf13/viper" ) @@ -47,6 +50,13 @@ func defaults() Config { } } +// global holds the current config atomically for concurrent reads. +var global atomic.Value +var watcher *fsnotify.Watcher +var watcherMu sync.Mutex +var watchStarted bool +var watchStartErr error + // Load reads config from file (or uses defaults if file doesn't exist). // Empty path defaults to "config.yaml". func Load(path string) (Config, error) { @@ -81,54 +91,176 @@ func Load(path string) (Config, error) { if err := v.Unmarshal(&cfg); err != nil { return Config{}, err } - + // Validate before storing + if err := validate(cfg); err != nil { + return Config{}, err + } + store(cfg) return cfg, nil } +// validate checks required fields. +func validate(cfg Config) error { + if cfg.Doubao.AppID == "" { + return fmt.Errorf("doubao.app_id is required") + } + if cfg.Doubao.AccessToken == "" { + return fmt.Errorf("doubao.access_token is required") + } + return nil +} // WatchAndReload starts watching config file for changes and reloads automatically. // Empty path defaults to "config.yaml". -func WatchAndReload(path string) { +// Returns a function to stop watching. Can only be called once. +func WatchAndReload(path string) func() { + watcherMu.Lock() + defer watcherMu.Unlock() + // Check if already started + if watchStarted { + if watchStartErr != nil { + return func() {} // Previous start failed, return no-op + } + // Already running, return existing stop function + return func() { + watcherMu.Lock() + defer watcherMu.Unlock() + if watcher != nil { + if err := watcher.Close(); err != nil { + slog.Warn("failed to close config watcher", "err", err) + } + watcher = nil + watchStarted = false + slog.Info("config watcher stopped") + } + } + } if path == "" { path = "config.yaml" } - + // Get absolute path for reliable matching + absPath, err := filepath.Abs(path) + if err != nil { + slog.Error("failed to resolve config path", "err", err) + watchStartErr = err + // Don't set watchStarted = true, allow retry + return func() {} + } + w, err := fsnotify.NewWatcher() + if err != nil { + slog.Error("failed to create config watcher", "err", err) + watchStartErr = err + // Don't set watchStarted = true, allow retry + return func() {} + } + watchDir := filepath.Dir(absPath) + if err := w.Add(watchDir); err != nil { + slog.Error("failed to watch config directory", "err", err, "dir", watchDir) + w.Close() + watchStartErr = err + // Don't set watchStarted = true, allow retry + return func() {} + } + // Assign to global watcher before marking as started + watcher = w + watchStarted = true + watchStartErr = nil + // Create Viper instance for reading config v := viper.New() - v.SetConfigFile(path) + v.SetConfigFile(absPath) // Use absolute path for consistency v.SetConfigType("yaml") - - // Set defaults (same as Load) + // Set defaults def := defaults() v.SetDefault("doubao.resource_id", def.Doubao.ResourceID) v.SetDefault("server.port", def.Server.Port) v.SetDefault("server.tls_auto", def.Server.TLSAuto) - v.SetEnvPrefix("voicepaste") v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) v.AutomaticEnv() - // Initial read (ignore error if file doesn't exist) if err := v.ReadInConfig(); err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); !ok { slog.Warn("config watch: initial read failed", "err", err) + } else { + slog.Warn("config file not found, will watch for creation", "path", path) } } - - // Watch for changes - v.WatchConfig() - v.OnConfigChange(func(e fsnotify.Event) { - slog.Info("config file changed, reloading", "file", e.Name) - var cfg Config - if err := v.Unmarshal(&cfg); err != nil { - slog.Error("config reload failed", "err", err) - return + // Start event loop in goroutine + go func() { + for { + select { + case event, ok := <-w.Events: + if !ok { + return + } + // Normalize event path for comparison + eventPath := filepath.Clean(event.Name) + if eventPath != absPath { + continue + } + // Process Write, Create, and Rename events (common editor patterns) + if event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create || event.Op&fsnotify.Rename == fsnotify.Rename { + slog.Info("config file changed, reloading", "file", absPath, "op", event.Op) + // Re-read the file + if err := v.ReadInConfig(); err != nil { + slog.Error("config reload: read failed", "err", err) + continue + } + var cfg Config + if err := v.Unmarshal(&cfg); err != nil { + slog.Error("config reload: unmarshal failed", "err", err) + continue + } + // Validate before applying + if err := validate(cfg); err != nil { + slog.Warn("config reload: validation failed, keeping old config", "err", err) + continue + } + store(cfg) + slog.Info("config reloaded and applied successfully") + } + case err, ok := <-w.Errors: + if !ok { + return + } + slog.Error("config watcher error", "err", err) + } } - slog.Info("config reloaded successfully") - }) + }() + // Return stop function + return func() { + watcherMu.Lock() + defer watcherMu.Unlock() + if watcher != nil { + if err := watcher.Close(); err != nil { + slog.Warn("failed to close config watcher", "err", err) + } + watcher = nil + watchStarted = false + slog.Info("config watcher stopped") + } + } } -// Get returns the current config snapshot. -// Note: After switching to Viper, this is deprecated. -// Use viper.Get* methods or Load() directly instead. +// Get returns the current config snapshot. Safe for concurrent use. +// Returns a deep copy to prevent external modifications. func Get() Config { + if v := global.Load(); v != nil { + cfg := v.(Config) + // Deep copy hotwords slice to prevent external modifications + if cfg.Doubao.Hotwords != nil { + cfg.Doubao.Hotwords = append([]string(nil), cfg.Doubao.Hotwords...) + } + return cfg + } return defaults() } + +// store updates the global config. +// Deep copies slices to ensure immutability. +func store(cfg Config) { + // Deep copy hotwords to prevent external modifications + if cfg.Doubao.Hotwords != nil { + cfg.Doubao.Hotwords = append([]string(nil), cfg.Doubao.Hotwords...) + } + global.Store(cfg) +} diff --git a/main.go b/main.go index caab353..8711d82 100644 --- a/main.go +++ b/main.go @@ -25,7 +25,12 @@ func main() { initLogger() slog.Info("VoicePaste", "version", version) cfg := mustLoadConfig() - go config.WatchAndReload("") + stopWatch := config.WatchAndReload("") + defer func() { + if stopWatch != nil { + stopWatch() + } + }() initClipboard() lanIPs := mustDetectLANIPs() lanIP := lanIPs[0] @@ -136,20 +141,22 @@ func createServer(cfg config.Config, lanIP string, tlsResult *vpTLS.Result) *ser tlsConfig = tlsResult.Config } srv := server.New(cfg.Security.Token, lanIP, webContent, tlsConfig) - asrFactory := buildASRFactory(cfg) + asrFactory := buildASRFactory() wsHandler := ws.NewHandler(cfg.Security.Token, paste.Paste, asrFactory) wsHandler.Register(srv.App()) return srv } -func buildASRFactory(cfg config.Config) func(chan<- ws.ServerMsg) (func([]byte), func(), error) { - asrCfg := asr.Config{ - AppID: cfg.Doubao.AppID, - AccessToken: cfg.Doubao.AccessToken, - ResourceID: cfg.Doubao.ResourceID, - Hotwords: cfg.Doubao.Hotwords, - } +func buildASRFactory() func(chan<- ws.ServerMsg) (func([]byte), func(), error) { return func(resultCh chan<- ws.ServerMsg) (func([]byte), func(), error) { + // Read latest config on each new connection + cfg := config.Get() + asrCfg := asr.Config{ + AppID: cfg.Doubao.AppID, + AccessToken: cfg.Doubao.AccessToken, + ResourceID: cfg.Doubao.ResourceID, + Hotwords: cfg.Doubao.Hotwords, + } client, err := asr.Dial(asrCfg, resultCh) if err != nil { return nil, nil, err