From 16f6698a4ee733212c78bacfcb4e1088e9a51136 Mon Sep 17 00:00:00 2001 From: John McBride Date: Wed, 11 Feb 2026 22:45:28 -0500 Subject: [PATCH] feat: Utilize spf13/viper for configs Signed-off-by: John McBride --- cmd/tapes/chat/chat.go | 35 ++++--- cmd/tapes/checkout/checkout.go | 26 +++-- cmd/tapes/search/search.go | 31 +++--- cmd/tapes/serve/api/api.go | 38 ++++--- cmd/tapes/serve/proxy/proxy.go | 109 +++++++++++--------- cmd/tapes/serve/serve.go | 144 ++++++++++++++------------ go.mod | 13 ++- go.sum | 27 ++++- pkg/config/config.go | 119 ++++++++++------------ pkg/config/config_test.go | 181 ++++++++++++++++++++++++++++++++- pkg/config/flags.go | 118 +++++++++++++++++++++ pkg/config/types.go | 132 +++++++----------------- pkg/config/viper.go | 89 ++++++++++++++++ 13 files changed, 724 insertions(+), 338 deletions(-) create mode 100644 pkg/config/flags.go create mode 100644 pkg/config/viper.go diff --git a/cmd/tapes/chat/chat.go b/cmd/tapes/chat/chat.go index 208d86b..83bfdde 100644 --- a/cmd/tapes/chat/chat.go +++ b/cmd/tapes/chat/chat.go @@ -24,6 +24,8 @@ import ( ) type chatCommander struct { + flags config.FlagSet + proxyTarget string apiTarget string model string @@ -59,6 +61,11 @@ type ollamaStreamChunk struct { Done bool `json:"done"` } +var chatFlags = config.FlagSet{ + config.FlagAPITarget: {Name: "api-target", Shorthand: "a", ViperKey: "client.api_target", Description: "Tapes API server URL"}, + config.FlagProxyTarget: {Name: "proxy-target", Shorthand: "p", ViperKey: "client.proxy_target", Description: "Tapes proxy URL"}, +} + const chatLongDesc string = `Experimental: Start an interactive chat session through the tapes proxy. The chat command sends messages to an LLM through the configured tapes proxy, @@ -79,7 +86,9 @@ Examples: const chatShortDesc string = "Experimental: Interactive LLM chat through the tapes proxy" func NewChatCmd() *cobra.Command { - cmder := &chatCommander{} + cmder := &chatCommander{ + flags: chatFlags, + } cmd := &cobra.Command{ Use: "chat", @@ -87,23 +96,18 @@ func NewChatCmd() *cobra.Command { Long: chatLongDesc, PreRunE: func(cmd *cobra.Command, _ []string) error { configDir, _ := cmd.Flags().GetString("config-dir") - cfger, err := config.NewConfiger(configDir) + v, err := config.InitViper(configDir) if err != nil { return fmt.Errorf("loading config: %w", err) } - cfg, err := cfger.LoadConfig() - if err != nil { - return fmt.Errorf("loading config: %w", err) - } - - if !cmd.Flags().Changed("api-target") { - cmder.apiTarget = cfg.Client.APITarget - } + config.BindRegisteredFlags(v, cmd, cmder.flags, []string{ + config.FlagAPITarget, + config.FlagProxyTarget, + }) - if !cmd.Flags().Changed("proxy-target") { - cmder.proxyTarget = cfg.Client.ProxyTarget - } + cmder.apiTarget = v.GetString("client.api_target") + cmder.proxyTarget = v.GetString("client.proxy_target") return nil }, RunE: func(cmd *cobra.Command, _ []string) error { @@ -117,9 +121,8 @@ func NewChatCmd() *cobra.Command { }, } - defaults := config.NewDefaultConfig() - cmd.Flags().StringVarP(&cmder.apiTarget, "api-target", "a", defaults.Client.APITarget, "Tapes API server URL") - cmd.Flags().StringVarP(&cmder.proxyTarget, "proxy-target", "p", defaults.Client.ProxyTarget, "Tapes proxy URL") + config.AddStringFlag(cmd, cmder.flags, config.FlagAPITarget, &cmder.apiTarget) + config.AddStringFlag(cmd, cmder.flags, config.FlagProxyTarget, &cmder.proxyTarget) cmd.Flags().StringVarP(&cmder.model, "model", "m", "gemma3:latest", "Model name (e.g., gemma3:1b, ministral-3:latest)") return cmd diff --git a/cmd/tapes/checkout/checkout.go b/cmd/tapes/checkout/checkout.go index 02a3760..0647b98 100644 --- a/cmd/tapes/checkout/checkout.go +++ b/cmd/tapes/checkout/checkout.go @@ -22,6 +22,8 @@ import ( ) type checkoutCommander struct { + flags config.FlagSet + hash string apiTarget string debug bool @@ -48,6 +50,10 @@ type historyMessage struct { Usage *llm.Usage `json:"usage,omitempty"` } +var checkoutFlags = config.FlagSet{ + config.FlagAPITarget: {Name: "api-target", Shorthand: "a", ViperKey: "client.api_target", Description: "Tapes API server URL"}, +} + const checkoutLongDesc string = `Experimental: Checkout a point in the conversation for replay. Fetches the conversation history up to the given hash from the API server @@ -63,7 +69,9 @@ Examples: const checkoutShortDesc string = "Checkout a conversation point" func NewCheckoutCmd() *cobra.Command { - cmder := &checkoutCommander{} + cmder := &checkoutCommander{ + flags: checkoutFlags, + } cmd := &cobra.Command{ Use: "checkout [hash]", @@ -72,19 +80,16 @@ func NewCheckoutCmd() *cobra.Command { Args: cobra.MaximumNArgs(1), PreRunE: func(cmd *cobra.Command, _ []string) error { configDir, _ := cmd.Flags().GetString("config-dir") - cfger, err := config.NewConfiger(configDir) + v, err := config.InitViper(configDir) if err != nil { return fmt.Errorf("loading config: %w", err) } - cfg, err := cfger.LoadConfig() - if err != nil { - return fmt.Errorf("loading config: %w", err) - } + config.BindRegisteredFlags(v, cmd, cmder.flags, []string{ + config.FlagAPITarget, + }) - if !cmd.Flags().Changed("api-target") { - cmder.apiTarget = cfg.Client.APITarget - } + cmder.apiTarget = v.GetString("client.api_target") return nil }, RunE: func(cmd *cobra.Command, args []string) error { @@ -102,8 +107,7 @@ func NewCheckoutCmd() *cobra.Command { }, } - defaults := config.NewDefaultConfig() - cmd.Flags().StringVarP(&cmder.apiTarget, "api-target", "a", defaults.Client.APITarget, "Tapes API server URL") + config.AddStringFlag(cmd, cmder.flags, config.FlagAPITarget, &cmder.apiTarget) return cmd } diff --git a/cmd/tapes/search/search.go b/cmd/tapes/search/search.go index e53a00b..ae6c18e 100644 --- a/cmd/tapes/search/search.go +++ b/cmd/tapes/search/search.go @@ -20,15 +20,20 @@ import ( ) type searchCommander struct { - query string - topK int + flags config.FlagSet + query string + topK int apiTarget string + debug bool - debug bool logger *zap.Logger } +var searchFlags = config.FlagSet{ + config.FlagAPITarget: {Name: "api-target", Shorthand: "a", ViperKey: "client.api_target", Description: "Tapes API server URL"}, +} + const searchLongDesc string = `Search session data via the Tapes API. Search over stored sessions, returning the most relevant sessions based on the @@ -46,7 +51,9 @@ Example: const searchShortDesc string = "Search session data" func NewSearchCmd() *cobra.Command { - cmder := &searchCommander{} + cmder := &searchCommander{ + flags: searchFlags, + } cmd := &cobra.Command{ Use: "search ", @@ -55,19 +62,16 @@ func NewSearchCmd() *cobra.Command { Args: cobra.ExactArgs(1), PreRunE: func(cmd *cobra.Command, _ []string) error { configDir, _ := cmd.Flags().GetString("config-dir") - cfger, err := config.NewConfiger(configDir) + v, err := config.InitViper(configDir) if err != nil { return fmt.Errorf("loading config: %w", err) } - cfg, err := cfger.LoadConfig() - if err != nil { - return fmt.Errorf("loading config: %w", err) - } + config.BindRegisteredFlags(v, cmd, cmder.flags, []string{ + config.FlagAPITarget, + }) - if !cmd.Flags().Changed("api-target") { - cmder.apiTarget = cfg.Client.APITarget - } + cmder.apiTarget = v.GetString("client.api_target") return nil }, RunE: func(cmd *cobra.Command, args []string) error { @@ -83,9 +87,8 @@ func NewSearchCmd() *cobra.Command { }, } - defaults := config.NewDefaultConfig() cmd.Flags().IntVarP(&cmder.topK, "top", "k", 5, "Number of results to return") - cmd.Flags().StringVar(&cmder.apiTarget, "api-target", defaults.Client.APITarget, "Tapes API server URL") + config.AddStringFlag(cmd, cmder.flags, config.FlagAPITarget, &cmder.apiTarget) return cmd } diff --git a/cmd/tapes/serve/api/api.go b/cmd/tapes/serve/api/api.go index 6e29960..8d2c7bb 100644 --- a/cmd/tapes/serve/api/api.go +++ b/cmd/tapes/serve/api/api.go @@ -18,10 +18,19 @@ import ( ) type apiCommander struct { + flags config.FlagSet + listen string debug bool sqlitePath string - logger *zap.Logger + + logger *zap.Logger +} + +// apiFlags defines the flags for the standalone API subcommand. +var apiFlags = config.FlagSet{ + config.FlagAPIListenStandalone: {Name: "listen", Shorthand: "l", ViperKey: "api.listen", Description: "Address for API server to listen on"}, + config.FlagSQLite: {Name: "sqlite", Shorthand: "s", ViperKey: "storage.sqlite_path", Description: "Path to SQLite database"}, } const apiLongDesc string = `Run the Tapes API server for inspecting, managing, and query agent sessions.` @@ -29,7 +38,9 @@ const apiLongDesc string = `Run the Tapes API server for inspecting, managing, a const apiShortDesc string = "Run the Tapes API server" func NewAPICmd() *cobra.Command { - cmder := &apiCommander{} + cmder := &apiCommander{ + flags: apiFlags, + } cmd := &cobra.Command{ Use: "api", @@ -37,22 +48,18 @@ func NewAPICmd() *cobra.Command { Long: apiLongDesc, PreRunE: func(cmd *cobra.Command, _ []string) error { configDir, _ := cmd.Flags().GetString("config-dir") - cfger, err := config.NewConfiger(configDir) + v, err := config.InitViper(configDir) if err != nil { return fmt.Errorf("loading config: %w", err) } - cfg, err := cfger.LoadConfig() - if err != nil { - return fmt.Errorf("loading config: %w", err) - } + config.BindRegisteredFlags(v, cmd, cmder.flags, []string{ + config.FlagAPIListenStandalone, + config.FlagSQLite, + }) - if !cmd.Flags().Changed("listen") { - cmder.listen = cfg.API.Listen - } - if !cmd.Flags().Changed("sqlite") { - cmder.sqlitePath = cfg.Storage.SQLitePath - } + cmder.listen = v.GetString("api.listen") + cmder.sqlitePath = v.GetString("storage.sqlite_path") return nil }, RunE: func(cmd *cobra.Command, _ []string) error { @@ -66,9 +73,8 @@ func NewAPICmd() *cobra.Command { }, } - defaults := config.NewDefaultConfig() - cmd.Flags().StringVarP(&cmder.listen, "listen", "l", defaults.API.Listen, "Address for API server to listen on") - cmd.Flags().StringVarP(&cmder.sqlitePath, "sqlite", "s", "", "Path to SQLite database (default: in-memory)") + config.AddStringFlag(cmd, cmder.flags, config.FlagAPIListenStandalone, &cmder.listen) + config.AddStringFlag(cmd, cmder.flags, config.FlagSQLite, &cmder.sqlitePath) return cmd } diff --git a/cmd/tapes/serve/proxy/proxy.go b/cmd/tapes/serve/proxy/proxy.go index b8d7784..d88cdb3 100644 --- a/cmd/tapes/serve/proxy/proxy.go +++ b/cmd/tapes/serve/proxy/proxy.go @@ -19,6 +19,8 @@ import ( ) type proxyCommander struct { + flags config.FlagSet + listen string upstream string providerType string @@ -28,13 +30,30 @@ type proxyCommander struct { vectorStoreProvider string vectorStoreTarget string - embeddingProvider string - embeddingTarget string - embeddingModel string + embeddingProvider string + embeddingTarget string + embeddingModel string + embeddingDimensions uint logger *zap.Logger } +// proxyFlags defines the flags for the standalone proxy subcommand. +// Uses FlagProxyListenStandalone (--listen/-l) instead of the parent's +// --proxy-listen/-p, and omits --api-listen since this is proxy-only. +var proxyFlags = config.FlagSet{ + config.FlagProxyListenStandalone: {Name: "listen", Shorthand: "l", ViperKey: "proxy.listen", Description: "Address for proxy to listen on"}, + config.FlagUpstream: {Name: "upstream", Shorthand: "u", ViperKey: "proxy.upstream", Description: "Upstream LLM provider URL"}, + config.FlagProvider: {Name: "provider", ViperKey: "proxy.provider", Description: "LLM provider type (anthropic, openai, ollama)"}, + config.FlagSQLite: {Name: "sqlite", Shorthand: "s", ViperKey: "storage.sqlite_path", Description: "Path to SQLite database"}, + config.FlagVectorStoreProv: {Name: "vector-store-provider", ViperKey: "vector_store.provider", Description: "Vector store provider type (e.g., chroma, sqlite)"}, + config.FlagVectorStoreTgt: {Name: "vector-store-target", ViperKey: "vector_store.target", Description: "Vector store target: filepath for sqlite or URL for remote service"}, + config.FlagEmbeddingProv: {Name: "embedding-provider", ViperKey: "embedding.provider", Description: "Embedding provider type (e.g., ollama)"}, + config.FlagEmbeddingTgt: {Name: "embedding-target", ViperKey: "embedding.target", Description: "Embedding provider URL"}, + config.FlagEmbeddingModel: {Name: "embedding-model", ViperKey: "embedding.model", Description: "Embedding model name (e.g., nomic-embed-text)"}, + config.FlagEmbeddingDims: {Name: "embedding-dimensions", ViperKey: "embedding.dimensions", Description: "Embedding dimensionality"}, +} + const proxyLongDesc string = `Run the proxy server. The proxy intercepts all requests and transparently forwards them to the @@ -48,7 +67,9 @@ agentic functionality.` const proxyShortDesc string = "Run the Tapes proxy server" func NewProxyCmd() *cobra.Command { - cmder := &proxyCommander{} + cmder := &proxyCommander{ + flags: proxyFlags, + } cmd := &cobra.Command{ Use: "proxy", @@ -56,43 +77,35 @@ func NewProxyCmd() *cobra.Command { Long: proxyLongDesc, PreRunE: func(cmd *cobra.Command, _ []string) error { configDir, _ := cmd.Flags().GetString("config-dir") - cfger, err := config.NewConfiger(configDir) + v, err := config.InitViper(configDir) if err != nil { return fmt.Errorf("loading config: %w", err) } - cfg, err := cfger.LoadConfig() - if err != nil { - return fmt.Errorf("loading config: %w", err) - } + config.BindRegisteredFlags(v, cmd, cmder.flags, []string{ + config.FlagProxyListenStandalone, + config.FlagUpstream, + config.FlagProvider, + config.FlagSQLite, + config.FlagVectorStoreProv, + config.FlagVectorStoreTgt, + config.FlagEmbeddingProv, + config.FlagEmbeddingTgt, + config.FlagEmbeddingModel, + config.FlagEmbeddingDims, + }) + + cmder.listen = v.GetString("proxy.listen") + cmder.upstream = v.GetString("proxy.upstream") + cmder.providerType = v.GetString("proxy.provider") + cmder.sqlitePath = v.GetString("storage.sqlite_path") + cmder.vectorStoreProvider = v.GetString("vector_store.provider") + cmder.vectorStoreTarget = v.GetString("vector_store.target") + cmder.embeddingProvider = v.GetString("embedding.provider") + cmder.embeddingTarget = v.GetString("embedding.target") + cmder.embeddingModel = v.GetString("embedding.model") + cmder.embeddingDimensions = v.GetUint("embedding.dimensions") - if !cmd.Flags().Changed("listen") { - cmder.listen = cfg.Proxy.Listen - } - if !cmd.Flags().Changed("upstream") { - cmder.upstream = cfg.Proxy.Upstream - } - if !cmd.Flags().Changed("provider") { - cmder.providerType = cfg.Proxy.Provider - } - if !cmd.Flags().Changed("sqlite") { - cmder.sqlitePath = cfg.Storage.SQLitePath - } - if !cmd.Flags().Changed("vector-store-provider") { - cmder.vectorStoreProvider = cfg.VectorStore.Provider - } - if !cmd.Flags().Changed("vector-store-target") { - cmder.vectorStoreTarget = cfg.VectorStore.Target - } - if !cmd.Flags().Changed("embedding-provider") { - cmder.embeddingProvider = cfg.Embedding.Provider - } - if !cmd.Flags().Changed("embedding-target") { - cmder.embeddingTarget = cfg.Embedding.Target - } - if !cmd.Flags().Changed("embedding-model") { - cmder.embeddingModel = cfg.Embedding.Model - } return nil }, RunE: func(cmd *cobra.Command, _ []string) error { @@ -106,16 +119,16 @@ func NewProxyCmd() *cobra.Command { }, } - defaults := config.NewDefaultConfig() - cmd.Flags().StringVarP(&cmder.listen, "listen", "l", defaults.Proxy.Listen, "Address for proxy to listen on") - cmd.Flags().StringVarP(&cmder.upstream, "upstream", "u", defaults.Proxy.Upstream, "Upstream LLM provider URL") - cmd.Flags().StringVarP(&cmder.providerType, "provider", "p", defaults.Proxy.Provider, "LLM provider type (anthropic, openai, ollama)") - cmd.Flags().StringVarP(&cmder.sqlitePath, "sqlite", "s", "", "Path to SQLite database (default: in-memory)") - cmd.Flags().StringVar(&cmder.vectorStoreProvider, "vector-store-provider", defaults.VectorStore.Provider, "Vector store provider type (e.g., chroma, sqlite)") - cmd.Flags().StringVar(&cmder.vectorStoreTarget, "vector-store-target", defaults.VectorStore.Target, "Vector store URL (e.g., http://localhost:8000)") - cmd.Flags().StringVar(&cmder.embeddingProvider, "embedding-provider", defaults.Embedding.Provider, "Embedding provider type (e.g., ollama)") - cmd.Flags().StringVar(&cmder.embeddingTarget, "embedding-target", defaults.Embedding.Target, "Embedding provider URL") - cmd.Flags().StringVar(&cmder.embeddingModel, "embedding-model", defaults.Embedding.Model, "Embedding model name (e.g., nomic-embed-text)") + config.AddStringFlag(cmd, cmder.flags, config.FlagProxyListenStandalone, &cmder.listen) + config.AddStringFlag(cmd, cmder.flags, config.FlagUpstream, &cmder.upstream) + config.AddStringFlag(cmd, cmder.flags, config.FlagProvider, &cmder.providerType) + config.AddStringFlag(cmd, cmder.flags, config.FlagSQLite, &cmder.sqlitePath) + config.AddStringFlag(cmd, cmder.flags, config.FlagVectorStoreProv, &cmder.vectorStoreProvider) + config.AddStringFlag(cmd, cmder.flags, config.FlagVectorStoreTgt, &cmder.vectorStoreTarget) + config.AddStringFlag(cmd, cmder.flags, config.FlagEmbeddingProv, &cmder.embeddingProvider) + config.AddStringFlag(cmd, cmder.flags, config.FlagEmbeddingTgt, &cmder.embeddingTarget) + config.AddStringFlag(cmd, cmder.flags, config.FlagEmbeddingModel, &cmder.embeddingModel) + config.AddUintFlag(cmd, cmder.flags, config.FlagEmbeddingDims, &cmder.embeddingDimensions) return cmd } @@ -151,9 +164,7 @@ func (c *proxyCommander) run() error { ProviderType: c.vectorStoreProvider, Target: c.vectorStoreTarget, Logger: c.logger, - - // TODO - need to make this actually configurable - Dimensions: 1024, + Dimensions: c.embeddingDimensions, }) if err != nil { return fmt.Errorf("creating vector driver: %w", err) diff --git a/cmd/tapes/serve/serve.go b/cmd/tapes/serve/serve.go index 1b76ee8..d7adc26 100644 --- a/cmd/tapes/serve/serve.go +++ b/cmd/tapes/serve/serve.go @@ -28,6 +28,8 @@ import ( ) type ServeCommander struct { + flags config.FlagSet + proxyListen string apiListen string upstream string @@ -47,6 +49,25 @@ type ServeCommander struct { logger *zap.Logger } +// ServeFlags is the shared FlagSet for all serve-family commands. +// The parent "tapes serve" and the "tapes serve proxy" / "tapes serve api" +// subcommands each pick the subset they need from this set. +var ServeFlags = config.FlagSet{ + config.FlagProxyListen: {Name: "proxy-listen", Shorthand: "p", ViperKey: "proxy.listen", Description: "Address for proxy to listen on"}, + config.FlagAPIListen: {Name: "api-listen", Shorthand: "a", ViperKey: "api.listen", Description: "Address for API server to listen on"}, + config.FlagUpstream: {Name: "upstream", Shorthand: "u", ViperKey: "proxy.upstream", Description: "Upstream LLM provider URL"}, + config.FlagProvider: {Name: "provider", ViperKey: "proxy.provider", Description: "LLM provider type (anthropic, openai, ollama)"}, + config.FlagSQLite: {Name: "sqlite", Shorthand: "s", ViperKey: "storage.sqlite_path", Description: "Path to SQLite database"}, + config.FlagVectorStoreProv: {Name: "vector-store-provider", ViperKey: "vector_store.provider", Description: "Vector store provider type (e.g., chroma, sqlite)"}, + config.FlagVectorStoreTgt: {Name: "vector-store-target", ViperKey: "vector_store.target", Description: "Vector store target: filepath for sqlite or URL for remote service"}, + config.FlagEmbeddingProv: {Name: "embedding-provider", ViperKey: "embedding.provider", Description: "Embedding provider type (e.g., ollama)"}, + config.FlagEmbeddingTgt: {Name: "embedding-target", ViperKey: "embedding.target", Description: "Embedding provider URL"}, + config.FlagEmbeddingModel: {Name: "embedding-model", ViperKey: "embedding.model", Description: "Embedding model name (e.g., nomic-embed-text)"}, + config.FlagEmbeddingDims: {Name: "embedding-dimensions", ViperKey: "embedding.dimensions", Description: "Embedding dimensionality"}, + config.FlagProxyListenStandalone: {Name: "listen", Shorthand: "l", ViperKey: "proxy.listen", Description: "Address for proxy to listen on"}, + config.FlagAPIListenStandalone: {Name: "listen", Shorthand: "l", ViperKey: "api.listen", Description: "Address for API server to listen on"}, +} + const serveLongDesc string = `Run Tapes services. Use subcommands to run individual services or all services together: @@ -60,7 +81,9 @@ agentic functionality.` const serveShortDesc string = "Run Tapes services" func NewServeCmd() *cobra.Command { - cmder := &ServeCommander{} + cmder := &ServeCommander{ + flags: ServeFlags, + } cmd := &cobra.Command{ Use: "serve", @@ -68,65 +91,61 @@ func NewServeCmd() *cobra.Command { Long: serveLongDesc, PreRunE: func(cmd *cobra.Command, _ []string) error { configDir, _ := cmd.Flags().GetString("config-dir") - cfger, err := config.NewConfiger(configDir) + v, err := config.InitViper(configDir) if err != nil { return fmt.Errorf("loading config: %w", err) } - cfg, err := cfger.LoadConfig() - if err != nil { - return fmt.Errorf("loading config: %w", err) - } - - // Resolve default sqlite path from dotdir target. - dotdirManager := dotdir.NewManager() - defaultTargetDir, err := dotdirManager.Target(configDir) - if err != nil { - return fmt.Errorf("resolving target dir: %w", err) + config.BindRegisteredFlags(v, cmd, cmder.flags, []string{ + config.FlagProxyListen, + config.FlagAPIListen, + config.FlagUpstream, + config.FlagProvider, + config.FlagSQLite, + config.FlagVectorStoreProv, + config.FlagVectorStoreTgt, + config.FlagEmbeddingProv, + config.FlagEmbeddingTgt, + config.FlagEmbeddingModel, + config.FlagEmbeddingDims, + }) + + // Resolve default sqlite path from dotdir target when not set + // via flag, env, or config file. + if v.GetString("storage.sqlite_path") == "" { + dotdirManager := dotdir.NewManager() + defaultTargetDir, err := dotdirManager.Target(configDir) + if err != nil { + return fmt.Errorf("resolving target dir: %w", err) + } + if defaultTargetDir != "" { + v.Set("storage.sqlite_path", filepath.Join(defaultTargetDir, "tapes.sqlite")) + } } - defaultTargetSqliteFile := filepath.Join(defaultTargetDir, "tapes.sqlite") - if !cmd.Flags().Changed("proxy-listen") { - cmder.proxyListen = cfg.Proxy.Listen - } - if !cmd.Flags().Changed("api-listen") { - cmder.apiListen = cfg.API.Listen - } - if !cmd.Flags().Changed("upstream") { - cmder.upstream = cfg.Proxy.Upstream - } - if !cmd.Flags().Changed("provider") { - cmder.providerType = cfg.Proxy.Provider - } - if !cmd.Flags().Changed("sqlite") { - if cfg.Storage.SQLitePath != "" { - cmder.sqlitePath = cfg.Storage.SQLitePath - } else { - cmder.sqlitePath = defaultTargetSqliteFile + // Same fallback for vector store target. + if v.GetString("vector_store.target") == "" { + dotdirManager := dotdir.NewManager() + defaultTargetDir, err := dotdirManager.Target(configDir) + if err != nil { + return fmt.Errorf("resolving target dir: %w", err) } - } - if !cmd.Flags().Changed("vector-store-provider") { - cmder.vectorStoreProvider = cfg.VectorStore.Provider - } - if !cmd.Flags().Changed("vector-store-target") { - if cfg.VectorStore.Target != "" { - cmder.vectorStoreTarget = cfg.VectorStore.Target - } else { - cmder.vectorStoreTarget = defaultTargetSqliteFile + if defaultTargetDir != "" { + v.Set("vector_store.target", filepath.Join(defaultTargetDir, "tapes.sqlite")) } } - if !cmd.Flags().Changed("embedding-provider") { - cmder.embeddingProvider = cfg.Embedding.Provider - } - if !cmd.Flags().Changed("embedding-target") { - cmder.embeddingTarget = cfg.Embedding.Target - } - if !cmd.Flags().Changed("embedding-model") { - cmder.embeddingModel = cfg.Embedding.Model - } - if !cmd.Flags().Changed("embedding-dimensions") { - cmder.embeddingDimensions = cfg.Embedding.Dimensions - } + + cmder.proxyListen = v.GetString("proxy.listen") + cmder.apiListen = v.GetString("api.listen") + cmder.upstream = v.GetString("proxy.upstream") + cmder.providerType = v.GetString("proxy.provider") + cmder.sqlitePath = v.GetString("storage.sqlite_path") + cmder.vectorStoreProvider = v.GetString("vector_store.provider") + cmder.vectorStoreTarget = v.GetString("vector_store.target") + cmder.embeddingProvider = v.GetString("embedding.provider") + cmder.embeddingTarget = v.GetString("embedding.target") + cmder.embeddingModel = v.GetString("embedding.model") + cmder.embeddingDimensions = v.GetUint("embedding.dimensions") return nil }, RunE: func(cmd *cobra.Command, _ []string) error { @@ -139,18 +158,17 @@ func NewServeCmd() *cobra.Command { }, } - defaults := config.NewDefaultConfig() - cmd.Flags().StringVarP(&cmder.proxyListen, "proxy-listen", "p", defaults.Proxy.Listen, "Address for proxy to listen on") - cmd.Flags().StringVarP(&cmder.apiListen, "api-listen", "a", defaults.API.Listen, "Address for API server to listen on") - cmd.Flags().StringVarP(&cmder.upstream, "upstream", "u", defaults.Proxy.Upstream, "Upstream LLM provider URL") - cmd.Flags().StringVar(&cmder.providerType, "provider", defaults.Proxy.Provider, "LLM provider type (anthropic, openai, ollama)") - cmd.Flags().StringVarP(&cmder.sqlitePath, "sqlite", "s", "", "Path to SQLite database (e.g., ./tapes.sqlite, in-memory)") - cmd.Flags().StringVar(&cmder.vectorStoreProvider, "vector-store-provider", defaults.VectorStore.Provider, "Vector store provider type (e.g., chroma, sqlite)") - cmd.Flags().StringVar(&cmder.vectorStoreTarget, "vector-store-target", defaults.VectorStore.Target, "Vector store target filepath for sqlite or URL for vector store service (e.g., http://localhost:8000, ./db.sqlite)") - cmd.Flags().StringVar(&cmder.embeddingProvider, "embedding-provider", defaults.Embedding.Provider, "Embedding provider type (e.g., ollama)") - cmd.Flags().StringVar(&cmder.embeddingTarget, "embedding-target", defaults.Embedding.Target, "Embedding provider URL") - cmd.Flags().StringVar(&cmder.embeddingModel, "embedding-model", defaults.Embedding.Model, "Embedding model name (e.g., nomic-embed-text)") - cmd.Flags().UintVar(&cmder.embeddingDimensions, "embedding-dimensions", defaults.Embedding.Dimensions, "Embedding dimensionality.") + config.AddStringFlag(cmd, cmder.flags, config.FlagProxyListen, &cmder.proxyListen) + config.AddStringFlag(cmd, cmder.flags, config.FlagAPIListen, &cmder.apiListen) + config.AddStringFlag(cmd, cmder.flags, config.FlagUpstream, &cmder.upstream) + config.AddStringFlag(cmd, cmder.flags, config.FlagProvider, &cmder.providerType) + config.AddStringFlag(cmd, cmder.flags, config.FlagSQLite, &cmder.sqlitePath) + config.AddStringFlag(cmd, cmder.flags, config.FlagVectorStoreProv, &cmder.vectorStoreProvider) + config.AddStringFlag(cmd, cmder.flags, config.FlagVectorStoreTgt, &cmder.vectorStoreTarget) + config.AddStringFlag(cmd, cmder.flags, config.FlagEmbeddingProv, &cmder.embeddingProvider) + config.AddStringFlag(cmd, cmder.flags, config.FlagEmbeddingTgt, &cmder.embeddingTarget) + config.AddStringFlag(cmd, cmder.flags, config.FlagEmbeddingModel, &cmder.embeddingModel) + config.AddUintFlag(cmd, cmder.flags, config.FlagEmbeddingDims, &cmder.embeddingDimensions) cmd.AddCommand(apicmder.NewAPICmd()) cmd.AddCommand(proxycmder.NewProxyCmd()) diff --git a/go.mod b/go.mod index 48a26dc..b1f03c1 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/charmbracelet/bubbles v0.21.0 github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/lipgloss v1.1.0 + github.com/charmbracelet/x/ansi v0.10.1 github.com/gofiber/adaptor/v2 v2.2.1 github.com/gofiber/fiber/v2 v2.52.6 github.com/mattn/go-sqlite3 v1.14.24 @@ -17,6 +18,7 @@ require ( github.com/onsi/ginkgo/v2 v2.27.4 github.com/onsi/gomega v1.39.0 github.com/spf13/cobra v1.10.2 + github.com/spf13/viper v1.21.0 go.uber.org/zap v1.27.1 ) @@ -29,13 +31,14 @@ require ( github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect - github.com/charmbracelet/x/ansi v0.10.1 // indirect github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect github.com/charmbracelet/x/term v0.2.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-openapi/inflect v0.19.0 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/jsonschema-go v0.3.0 // indirect github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect @@ -51,8 +54,14 @@ require ( github.com/mitchellh/go-wordwrap v1.0.1 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/spf13/pflag v1.0.9 // indirect + github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/subosito/gotenv v1.6.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.62.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect diff --git a/go.sum b/go.sum index 1326614..02ae07e 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/gkampitakis/ciinfo v0.3.2 h1:JcuOPk8ZU7nZQjdUhctuhQofk7BGHuIy0c9Ez8BNhXs= github.com/gkampitakis/ciinfo v0.3.2/go.mod h1:1NIwaOcFChN4fa/B0hEBdAb6npDlFL8Bwx4dfRLRqAo= github.com/gkampitakis/go-diff v1.3.2 h1:Qyn0J9XJSDTgnsgHRdz9Zp24RaJeKMUHg2+PDZZdC4M= @@ -57,6 +61,8 @@ github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1v github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/go-test/deep v1.0.3 h1:ZrJSEWsXzPOxaZnFteGEfooLba+ju3FYIbOrS+rQd68= github.com/go-test/deep v1.0.3/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/gofiber/adaptor/v2 v2.2.1 h1:givE7iViQWlsTR4Jh7tB4iXzrlKBgiraB/yTdHs9Lv4= @@ -119,6 +125,8 @@ github.com/onsi/ginkgo/v2 v2.27.4 h1:fcEcQW/A++6aZAZQNUmNjvA9PSOzefMJBerHJ4t8v8Y github.com/onsi/ginkgo/v2 v2.27.4/go.mod h1:ArE1D/XhNXBXCBkKOLkbsb2c81dQHCRcF5zwn/ykDRo= github.com/onsi/gomega v1.39.0 h1:y2ROC3hKFmQZJNFeGAMeHZKkjBL65mIZcvrLQBF9k6Q= github.com/onsi/gomega v1.39.0/go.mod h1:ZCU1pkQcXDO5Sl9/VVEGlDyp+zm0m1cmeG5TOzLgdh4= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -127,14 +135,27 @@ github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= +github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= -github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= diff --git a/pkg/config/config.go b/pkg/config/config.go index cfb9f10..837e0de 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/BurntSushi/toml" + "github.com/spf13/viper" "github.com/papercomputeco/tapes/pkg/dotdir" ) @@ -58,11 +59,6 @@ func NewConfiger(override string) (*Configer, error) { // ValidConfigKeys returns the sorted list of all supported configuration key names. func ValidConfigKeys() []string { - keys := make([]string, 0, len(configKeys)) - for k := range configKeys { - keys = append(keys, k) - } - // Return in a stable, logical order matching the TOML section layout. ordered := []string{ "storage.sqlite_path", @@ -83,7 +79,7 @@ func ValidConfigKeys() []string { // Sanity: only return keys that actually exist in the map. result := make([]string, 0, len(ordered)) for _, k := range ordered { - if _, ok := configKeys[k]; ok { + if validConfigKeys[k] { result = append(result, k) } } @@ -93,7 +89,7 @@ func ValidConfigKeys() []string { for _, k := range result { seen[k] = true } - for _, k := range keys { + for k := range validConfigKeys { if !seen[k] { result = append(result, k) } @@ -104,8 +100,7 @@ func ValidConfigKeys() []string { // IsValidConfigKey returns true if the given key is a supported configuration key. func IsValidConfigKey(key string) bool { - _, ok := configKeys[key] - return ok + return validConfigKeys[key] } func (c *Configer) GetTarget() string { @@ -116,7 +111,6 @@ func (c *Configer) GetTarget() string { // If the file does not exist, returns DefaultConfig() so callers always receive // a fully-populated Config with sane defaults. Fields explicitly set in the file // override the defaults. -// If overrideDir is non-empty, it is used instead of the default .tapes/ location. func (c *Configer) LoadConfig() (*Config, error) { if c.targetPath == "" { return NewDefaultConfig(), nil @@ -135,57 +129,25 @@ func (c *Configer) LoadConfig() (*Config, error) { return nil, err } - // Merge in defaults: fill in any zero-value fields from the loaded config - applyDefaults(cfg) - - return cfg, nil -} - -// applyDefaults fills zero-value fields in cfg with values from DefaultConfig(). -func applyDefaults(cfg *Config) { - defaults := NewDefaultConfig() + // Use viper to merge defaults into the parsed config. + // This replaces the old hand-rolled applyDefaults function. + v := viper.New() + setViperDefaults(v) + v.SetConfigType("toml") - if cfg.Version == 0 { - cfg.Version = defaults.Version + if err := v.ReadConfig(bytes.NewReader(data)); err != nil { + return nil, fmt.Errorf("reading config into viper: %w", err) } - if cfg.Proxy.Provider == "" { - cfg.Proxy.Provider = defaults.Proxy.Provider - } - if cfg.Proxy.Upstream == "" { - cfg.Proxy.Upstream = defaults.Proxy.Upstream - } - if cfg.Proxy.Listen == "" { - cfg.Proxy.Listen = defaults.Proxy.Listen + merged := &Config{} + if err := v.Unmarshal(merged); err != nil { + return nil, fmt.Errorf("unmarshalling config: %w", err) } - if cfg.API.Listen == "" { - cfg.API.Listen = defaults.API.Listen - } - - if cfg.Client.ProxyTarget == "" { - cfg.Client.ProxyTarget = defaults.Client.ProxyTarget - } - if cfg.Client.APITarget == "" { - cfg.Client.APITarget = defaults.Client.APITarget - } + // Preserve the version from the parsed config (version 0 is valid). + merged.Version = cfg.Version - if cfg.VectorStore.Provider == "" { - cfg.VectorStore.Provider = defaults.VectorStore.Provider - } - - if cfg.Embedding.Provider == "" { - cfg.Embedding.Provider = defaults.Embedding.Provider - } - if cfg.Embedding.Target == "" { - cfg.Embedding.Target = defaults.Embedding.Target - } - if cfg.Embedding.Model == "" { - cfg.Embedding.Model = defaults.Embedding.Model - } - if cfg.Embedding.Dimensions == 0 { - cfg.Embedding.Dimensions = defaults.Embedding.Dimensions - } + return merged, nil } // SaveConfig persists the configuration to config.toml in the target .tapes/ directory. @@ -214,8 +176,7 @@ func (c *Configer) SaveConfig(cfg *Config) error { // SetConfigValue loads the config, sets the given key to the given value, and saves it. // Returns an error if the key is not a valid config key. func (c *Configer) SetConfigValue(key string, value string) error { - info, ok := configKeys[key] - if !ok { + if !validConfigKeys[key] { return fmt.Errorf("unknown config key: %q", key) } @@ -224,27 +185,53 @@ func (c *Configer) SetConfigValue(key string, value string) error { return err } - if err := info.set(cfg, value); err != nil { - return err + // Use viper to set the value and unmarshal back to the Config struct. + // This handles type coercion (e.g., string to uint for embedding.dimensions). + v := viper.New() + setViperDefaults(v) + v.SetConfigType("toml") + + // Load existing config into viper if the file exists. + if c.targetPath != "" { + data, err := os.ReadFile(c.targetPath) + if err == nil { + _ = v.ReadConfig(bytes.NewReader(data)) + } + } + + v.Set(key, value) + + updated := &Config{} + if err := v.Unmarshal(updated); err != nil { + return fmt.Errorf("invalid value for %s: %w", key, err) } - return c.SaveConfig(cfg) + // Preserve the version from the loaded config. + updated.Version = cfg.Version + + return c.SaveConfig(updated) } // GetConfigValue loads the config and returns the string representation of the given key. // Returns an error if the key is not a valid config key. func (c *Configer) GetConfigValue(key string) (string, error) { - info, ok := configKeys[key] - if !ok { + if !validConfigKeys[key] { return "", fmt.Errorf("unknown config key: %q", key) } - cfg, err := c.LoadConfig() - if err != nil { - return "", err + v := viper.New() + setViperDefaults(v) + v.SetConfigType("toml") + + // Load existing config into viper if the file exists. + if c.targetPath != "" { + data, err := os.ReadFile(c.targetPath) + if err == nil { + _ = v.ReadConfig(bytes.NewReader(data)) + } } - return info.get(cfg), nil + return v.GetString(key), nil } // PresetConfig returns a Config with sane defaults for the named provider preset. diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 0785dba..f5ec4fc 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -7,6 +7,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/spf13/cobra" "github.com/papercomputeco/tapes/pkg/config" ) @@ -606,7 +607,185 @@ var _ = Describe("NewDefaultConfig", func() { }) }) -var _ = Describe("applyDefaults via LoadConfig", func() { +var _ = Describe("InitViper", func() { + var tmpDir string + + BeforeEach(func() { + var err error + tmpDir, err = os.MkdirTemp("", "viper-test-*") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + os.RemoveAll(tmpDir) + }) + + It("returns viper with defaults when no config file exists", func() { + v, err := config.InitViper(tmpDir) + Expect(err).NotTo(HaveOccurred()) + Expect(v).NotTo(BeNil()) + + defaults := config.NewDefaultConfig() + Expect(v.GetString("proxy.provider")).To(Equal(defaults.Proxy.Provider)) + Expect(v.GetString("proxy.upstream")).To(Equal(defaults.Proxy.Upstream)) + Expect(v.GetString("proxy.listen")).To(Equal(defaults.Proxy.Listen)) + Expect(v.GetString("api.listen")).To(Equal(defaults.API.Listen)) + Expect(v.GetString("client.proxy_target")).To(Equal(defaults.Client.ProxyTarget)) + Expect(v.GetString("client.api_target")).To(Equal(defaults.Client.APITarget)) + }) + + It("reads config file values over defaults", func() { + data := `[proxy] +provider = "anthropic" +upstream = "https://api.anthropic.com" +` + err := os.WriteFile(filepath.Join(tmpDir, "config.toml"), []byte(data), 0o600) + Expect(err).NotTo(HaveOccurred()) + + v, err := config.InitViper(tmpDir) + Expect(err).NotTo(HaveOccurred()) + + Expect(v.GetString("proxy.provider")).To(Equal("anthropic")) + Expect(v.GetString("proxy.upstream")).To(Equal("https://api.anthropic.com")) + // Unset fields should still get defaults + defaults := config.NewDefaultConfig() + Expect(v.GetString("proxy.listen")).To(Equal(defaults.Proxy.Listen)) + }) + + It("respects environment variables with TAPES_ prefix", func() { + os.Setenv("TAPES_PROXY_PROVIDER", "openai") + defer os.Unsetenv("TAPES_PROXY_PROVIDER") + + v, err := config.InitViper(tmpDir) + Expect(err).NotTo(HaveOccurred()) + + Expect(v.GetString("proxy.provider")).To(Equal("openai")) + }) + + It("env vars take precedence over config file values", func() { + data := `[proxy] +provider = "anthropic" +` + err := os.WriteFile(filepath.Join(tmpDir, "config.toml"), []byte(data), 0o600) + Expect(err).NotTo(HaveOccurred()) + + os.Setenv("TAPES_PROXY_PROVIDER", "openai") + defer os.Unsetenv("TAPES_PROXY_PROVIDER") + + v, err := config.InitViper(tmpDir) + Expect(err).NotTo(HaveOccurred()) + + Expect(v.GetString("proxy.provider")).To(Equal("openai")) + }) +}) + +var _ = Describe("BindFlags", func() { + var tmpDir string + + BeforeEach(func() { + var err error + tmpDir, err = os.MkdirTemp("", "bindflag-test-*") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + os.RemoveAll(tmpDir) + }) + + It("binds cobra flags to viper keys via registry", func() { + v, err := config.InitViper(tmpDir) + Expect(err).NotTo(HaveOccurred()) + + fs := config.FlagSet{ + config.FlagAPIListenStandalone: {Name: "listen", Shorthand: "l", ViperKey: "api.listen", Description: "Address for API server to listen on"}, + } + + cmd := &cobra.Command{Use: "test"} + var listen string + config.AddStringFlag(cmd, fs, config.FlagAPIListenStandalone, &listen) + + // Simulate flag being set by user + err = cmd.Flags().Set("listen", ":7777") + Expect(err).NotTo(HaveOccurred()) + + config.BindRegisteredFlags(v, cmd, fs, []string{config.FlagAPIListenStandalone}) + + Expect(v.GetString("api.listen")).To(Equal(":7777")) + }) + + It("falls through to config when flag not set", func() { + data := `[api] +listen = ":5555" +` + err := os.WriteFile(filepath.Join(tmpDir, "config.toml"), []byte(data), 0o600) + Expect(err).NotTo(HaveOccurred()) + + v, err := config.InitViper(tmpDir) + Expect(err).NotTo(HaveOccurred()) + + fs := config.FlagSet{ + config.FlagAPIListenStandalone: {Name: "listen", Shorthand: "l", ViperKey: "api.listen", Description: "Address for API server to listen on"}, + } + + cmd := &cobra.Command{Use: "test"} + var listen string + config.AddStringFlag(cmd, fs, config.FlagAPIListenStandalone, &listen) + + // Do NOT set the flag -- should fall through to config file value + config.BindRegisteredFlags(v, cmd, fs, []string{config.FlagAPIListenStandalone}) + + Expect(v.GetString("api.listen")).To(Equal(":5555")) + }) + + It("skips bindings for nonexistent registry keys", func() { + v, err := config.InitViper(tmpDir) + Expect(err).NotTo(HaveOccurred()) + + fs := config.FlagSet{} + + cmd := &cobra.Command{Use: "test"} + + // "nonexistent" is not in the FlagSet -- should be safely skipped + config.BindRegisteredFlags(v, cmd, fs, []string{"nonexistent"}) + + defaults := config.NewDefaultConfig() + Expect(v.GetString("proxy.listen")).To(Equal(defaults.Proxy.Listen)) + }) + + It("AddStringFlag pulls name, shorthand, and description from FlagSet", func() { + fs := config.FlagSet{ + config.FlagAPITarget: {Name: "api-target", Shorthand: "a", ViperKey: "client.api_target", Description: "Tapes API server URL"}, + } + + cmd := &cobra.Command{Use: "test"} + var target string + config.AddStringFlag(cmd, fs, config.FlagAPITarget, &target) + + f := cmd.Flags().Lookup("api-target") + Expect(f).NotTo(BeNil()) + Expect(f.Shorthand).To(Equal("a")) + Expect(f.Usage).To(Equal("Tapes API server URL")) + + defaults := config.NewDefaultConfig() + Expect(f.DefValue).To(Equal(defaults.Client.APITarget)) + }) + + It("AddUintFlag works for embedding-dimensions", func() { + fs := config.FlagSet{ + config.FlagEmbeddingDims: {Name: "embedding-dimensions", ViperKey: "embedding.dimensions", Description: "Embedding dimensionality"}, + } + + cmd := &cobra.Command{Use: "test"} + var dims uint + config.AddUintFlag(cmd, fs, config.FlagEmbeddingDims, &dims) + + f := cmd.Flags().Lookup("embedding-dimensions") + Expect(f).NotTo(BeNil()) + Expect(f.Usage).To(Equal("Embedding dimensionality")) + }) +}) + +var _ = Describe("viper default merging via LoadConfig", func() { var tmpDir string BeforeEach(func() { diff --git a/pkg/config/flags.go b/pkg/config/flags.go new file mode 100644 index 0000000..ca8b7fc --- /dev/null +++ b/pkg/config/flags.go @@ -0,0 +1,118 @@ +package config + +import ( + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +// Flag is the single source of truth for a CLI flag. +// Commands reference flags by registry key rather than hard-coding names, +// shorthands, defaults, and descriptions inline. This prevents flag drift +// when the same logical flag appears on multiple commands (e.g., --upstream +// on both "tapes serve" and "tapes serve proxy" and "tapes chat"). +type Flag struct { + // Name is the long flag name (e.g. "upstream"). + Name string + + // Shorthand is the one-letter short flag (e.g. "u"). Empty for no shorthand. + Shorthand string + + // ViperKey is the dotted config key this flag maps to (e.g. "proxy.upstream"). + ViperKey string + + // Description is the help text shown in --help output. + Description string +} + +// FlagSet is a mapping of flag names to Flag structs that hold their name, +// shorthand, viper key, etc. +type FlagSet map[string]Flag + +// Flag registry keys. +// Use these constants when calling AddStringFlag, AddUintFlag, +// and BindRegisteredFlags to avoid typos or drift from one command to another. +const ( + FlagProxyListen = "proxy-listen" + FlagAPIListen = "api-listen" + FlagUpstream = "upstream" + FlagProvider = "provider" + FlagSQLite = "sqlite" + FlagVectorStoreProv = "vector-store-provider" + FlagVectorStoreTgt = "vector-store-target" + FlagEmbeddingProv = "embedding-provider" + FlagEmbeddingTgt = "embedding-target" + FlagEmbeddingModel = "embedding-model" + FlagEmbeddingDims = "embedding-dimensions" + FlagAPITarget = "api-target" + FlagProxyTarget = "proxy-target" + + // Standalone subcommand variants use "listen" as the flag name + // but bind to different viper keys depending on the service. + FlagProxyListenStandalone = "proxy-listen-standalone" + FlagAPIListenStandalone = "api-listen-standalone" +) + +// AddStringFlag registers a string flag on cmd from the given FlagSet. +// The flag's name, shorthand, default, and description all come from the +// FlagSet entry so they cannot drift across commands. +func AddStringFlag(cmd *cobra.Command, fs FlagSet, key string, target *string) { + def, ok := fs[key] + if !ok { + return + } + + defaultVal := defaultString(def.ViperKey) + if def.Shorthand != "" { + cmd.Flags().StringVarP(target, def.Name, def.Shorthand, defaultVal, def.Description) + } else { + cmd.Flags().StringVar(target, def.Name, defaultVal, def.Description) + } +} + +// AddUintFlag registers a uint flag on cmd from the given FlagSet. +func AddUintFlag(cmd *cobra.Command, fs FlagSet, registryKey string, target *uint) { + def, ok := fs[registryKey] + if !ok { + return + } + + defaultVal := defaultUint(def.ViperKey) + if def.Shorthand != "" { + cmd.Flags().UintVarP(target, def.Name, def.Shorthand, defaultVal, def.Description) + } else { + cmd.Flags().UintVar(target, def.Name, defaultVal, def.Description) + } +} + +// BindRegisteredFlags binds already-registered flags to viper using definitions +// from the given FlagSet. Call this in PreRunE after InitViper to connect flags +// to the viper precedence chain (flag > env > config file > default). +func BindRegisteredFlags(v *viper.Viper, cmd *cobra.Command, fs FlagSet, registryKeys []string) { + for _, registryKey := range registryKeys { + def, ok := fs[registryKey] + if !ok { + continue + } + + f := cmd.Flags().Lookup(def.Name) + if f == nil { + continue + } + + _ = v.BindPFlag(def.ViperKey, f) + } +} + +// defaultString returns the default string value for a viper key from NewDefaultConfig. +func defaultString(viperKey string) string { + v := viper.New() + setViperDefaults(v) + return v.GetString(viperKey) +} + +// defaultUint returns the default uint value for a viper key from NewDefaultConfig. +func defaultUint(viperKey string) uint { + v := viper.New() + setViperDefaults(v) + return v.GetUint(viperKey) +} diff --git a/pkg/config/types.go b/pkg/config/types.go index 14b01a0..7b43546 100644 --- a/pkg/config/types.go +++ b/pkg/config/types.go @@ -1,132 +1,70 @@ package config -import ( - "fmt" - "strconv" -) - // Config represents the persistent tapes configuration stored as config.toml // in the .tapes/ directory. The TOML layout uses sections for logical grouping. type Config struct { - Version int `toml:"version"` - Storage StorageConfig `toml:"storage"` - Proxy ProxyConfig `toml:"proxy"` - API APIConfig `toml:"api"` - Client ClientConfig `toml:"client"` - VectorStore VectorStoreConfig `toml:"vector_store"` - Embedding EmbeddingConfig `toml:"embedding"` + Version int `toml:"version" mapstructure:"version"` + Storage StorageConfig `toml:"storage" mapstructure:"storage"` + Proxy ProxyConfig `toml:"proxy" mapstructure:"proxy"` + API APIConfig `toml:"api" mapstructure:"api"` + Client ClientConfig `toml:"client" mapstructure:"client"` + VectorStore VectorStoreConfig `toml:"vector_store" mapstructure:"vector_store"` + Embedding EmbeddingConfig `toml:"embedding" mapstructure:"embedding"` } // StorageConfig holds shared storage settings used by both proxy and API. type StorageConfig struct { - SQLitePath string `toml:"sqlite_path,omitempty"` + SQLitePath string `toml:"sqlite_path,omitempty" mapstructure:"sqlite_path"` } // ProxyConfig holds proxy-specific settings. type ProxyConfig struct { - Provider string `toml:"provider,omitempty"` - Upstream string `toml:"upstream,omitempty"` - Listen string `toml:"listen,omitempty"` + Provider string `toml:"provider,omitempty" mapstructure:"provider"` + Upstream string `toml:"upstream,omitempty" mapstructure:"upstream"` + Listen string `toml:"listen,omitempty" mapstructure:"listen"` } // APIConfig holds API server settings. type APIConfig struct { - Listen string `toml:"listen,omitempty"` + Listen string `toml:"listen,omitempty" mapstructure:"listen"` } // ClientConfig holds settings for CLI commands that connect to the running // proxy and API servers (e.g. tapes chat, tapes search, tapes checkout). // Values are full URLs (scheme + host + port). type ClientConfig struct { - ProxyTarget string `toml:"proxy_target,omitempty"` - APITarget string `toml:"api_target,omitempty"` + ProxyTarget string `toml:"proxy_target,omitempty" mapstructure:"proxy_target"` + APITarget string `toml:"api_target,omitempty" mapstructure:"api_target"` } // VectorStoreConfig holds vector store settings. type VectorStoreConfig struct { - Provider string `toml:"provider,omitempty"` - Target string `toml:"target,omitempty"` + Provider string `toml:"provider,omitempty" mapstructure:"provider"` + Target string `toml:"target,omitempty" mapstructure:"target"` } // EmbeddingConfig holds embedding provider settings. type EmbeddingConfig struct { - Provider string `toml:"provider,omitempty"` - Target string `toml:"target,omitempty"` - Model string `toml:"model,omitempty"` - Dimensions uint `toml:"dimensions,omitempty"` -} - -// configKeyInfo maps a user-facing dotted key name to a getter and setter on *Config. -type configKeyInfo struct { - get func(c *Config) string - set func(c *Config, v string) error + Provider string `toml:"provider,omitempty" mapstructure:"provider"` + Target string `toml:"target,omitempty" mapstructure:"target"` + Model string `toml:"model,omitempty" mapstructure:"model"` + Dimensions uint `toml:"dimensions,omitempty" mapstructure:"dimensions"` } -// configKeys is the authoritative map of all supported config keys. +// validConfigKeys is the authoritative set of all supported config keys. // Keys use dotted notation matching the TOML section structure. -var configKeys = map[string]configKeyInfo{ - "storage.sqlite_path": { - get: func(c *Config) string { return c.Storage.SQLitePath }, - set: func(c *Config, v string) error { c.Storage.SQLitePath = v; return nil }, - }, - "proxy.provider": { - get: func(c *Config) string { return c.Proxy.Provider }, - set: func(c *Config, v string) error { c.Proxy.Provider = v; return nil }, - }, - "proxy.upstream": { - get: func(c *Config) string { return c.Proxy.Upstream }, - set: func(c *Config, v string) error { c.Proxy.Upstream = v; return nil }, - }, - "proxy.listen": { - get: func(c *Config) string { return c.Proxy.Listen }, - set: func(c *Config, v string) error { c.Proxy.Listen = v; return nil }, - }, - "api.listen": { - get: func(c *Config) string { return c.API.Listen }, - set: func(c *Config, v string) error { c.API.Listen = v; return nil }, - }, - "client.proxy_target": { - get: func(c *Config) string { return c.Client.ProxyTarget }, - set: func(c *Config, v string) error { c.Client.ProxyTarget = v; return nil }, - }, - "client.api_target": { - get: func(c *Config) string { return c.Client.APITarget }, - set: func(c *Config, v string) error { c.Client.APITarget = v; return nil }, - }, - "vector_store.provider": { - get: func(c *Config) string { return c.VectorStore.Provider }, - set: func(c *Config, v string) error { c.VectorStore.Provider = v; return nil }, - }, - "vector_store.target": { - get: func(c *Config) string { return c.VectorStore.Target }, - set: func(c *Config, v string) error { c.VectorStore.Target = v; return nil }, - }, - "embedding.provider": { - get: func(c *Config) string { return c.Embedding.Provider }, - set: func(c *Config, v string) error { c.Embedding.Provider = v; return nil }, - }, - "embedding.target": { - get: func(c *Config) string { return c.Embedding.Target }, - set: func(c *Config, v string) error { c.Embedding.Target = v; return nil }, - }, - "embedding.model": { - get: func(c *Config) string { return c.Embedding.Model }, - set: func(c *Config, v string) error { c.Embedding.Model = v; return nil }, - }, - "embedding.dimensions": { - get: func(c *Config) string { - if c.Embedding.Dimensions == 0 { - return "" - } - return strconv.FormatUint(uint64(c.Embedding.Dimensions), 10) - }, - set: func(c *Config, v string) error { - n, err := strconv.ParseUint(v, 10, 64) - if err != nil { - return fmt.Errorf("invalid value for embedding.dimensions: %w", err) - } - c.Embedding.Dimensions = uint(n) - return nil - }, - }, +var validConfigKeys = map[string]bool{ + "storage.sqlite_path": true, + "proxy.provider": true, + "proxy.upstream": true, + "proxy.listen": true, + "api.listen": true, + "client.proxy_target": true, + "client.api_target": true, + "vector_store.provider": true, + "vector_store.target": true, + "embedding.provider": true, + "embedding.target": true, + "embedding.model": true, + "embedding.dimensions": true, } diff --git a/pkg/config/viper.go b/pkg/config/viper.go new file mode 100644 index 0000000..f7d5271 --- /dev/null +++ b/pkg/config/viper.go @@ -0,0 +1,89 @@ +package config + +import ( + "errors" + "fmt" + "strings" + + "github.com/spf13/viper" + + "github.com/papercomputeco/tapes/pkg/dotdir" +) + +// InitViper creates and returns a configured *viper.Viper. +// It sets defaults from NewDefaultConfig(), reads the config.toml file +// (if found via dotdir resolution), and binds environment variables +// with the TAPES_ prefix. +// +// Config precedence (highest to lowest): +// 1. CLI flags (once bound via BindRegisteredFlags) +// 2. Environment variables (TAPES_PROXY_LISTEN, TAPES_API_LISTEN, etc.) +// 3. config.toml file values +// 4. Defaults from NewDefaultConfig() +func InitViper(configDir string) (*viper.Viper, error) { + v := viper.New() + + // 1. Register all defaults from NewDefaultConfig(). + setViperDefaults(v) + + // 2. Config file discovery via dotdir resolution. + v.SetConfigName("config") + v.SetConfigType("toml") + + ddm := dotdir.NewManager() + target, err := ddm.Target(configDir) + if err != nil { + return nil, fmt.Errorf("resolving config dir: %w", err) + } + + if target != "" { + v.AddConfigPath(target) + } + + if err := v.ReadInConfig(); err != nil { + // Config file not found errors are fine, defaults will apply. + if !errors.As(err, &viper.ConfigFileNotFoundError{}) { + return nil, fmt.Errorf("reading config: %w", err) + } + } + + // 3. Environment variables: TAPES_PROXY_LISTEN, TAPES_STORAGE_SQLITE_PATH, etc. + v.SetEnvPrefix("TAPES") + v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + v.AutomaticEnv() + + return v, nil +} + +// setViperDefaults registers defaults from NewDefaultConfig() into viper +// using dotted-key notation. This keeps defaults.go as the single source of truth. +func setViperDefaults(v *viper.Viper) { + d := NewDefaultConfig() + + v.SetDefault("version", d.Version) + + // Storage + v.SetDefault("storage.sqlite_path", d.Storage.SQLitePath) + + // Proxy + v.SetDefault("proxy.provider", d.Proxy.Provider) + v.SetDefault("proxy.upstream", d.Proxy.Upstream) + v.SetDefault("proxy.listen", d.Proxy.Listen) + + // API + v.SetDefault("api.listen", d.API.Listen) + + // Client + v.SetDefault("client.proxy_target", d.Client.ProxyTarget) + v.SetDefault("client.api_target", d.Client.APITarget) + + // Vector store + v.SetDefault("vector_store.provider", d.VectorStore.Provider) + v.SetDefault("vector_store.target", d.VectorStore.Target) + + // Embedding + v.SetDefault("embedding.provider", d.Embedding.Provider) + v.SetDefault("embedding.target", d.Embedding.Target) + v.SetDefault("embedding.model", d.Embedding.Model) + v.SetDefault("embedding.dimensions", d.Embedding.Dimensions) +}