Skip to content

Commit

Permalink
Fix: fix CLI arg <-> config <-> env parity
Browse files Browse the repository at this point in the history
  • Loading branch information
yunginnanet committed Jun 26, 2024
1 parent 23c65b9 commit 93dcb98
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 24 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
git.tcp.direct/kayos/common v0.9.7
github.com/fasthttp/router v1.5.1
github.com/knadh/koanf/parsers/toml v0.1.0
github.com/knadh/koanf/providers/basicflag v1.0.0
github.com/knadh/koanf/providers/env v0.1.0
github.com/knadh/koanf/v2 v2.1.1
github.com/rs/zerolog v1.33.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ github.com/knadh/koanf/maps v0.1.1 h1:G5TjmUh2D7G2YWf5SQQqSiHRJEjaicvU0KpypqB3NI
github.com/knadh/koanf/maps v0.1.1/go.mod h1:npD/QZY3V6ghQDdcQzl1W4ICNVTkohC8E73eI2xW4yI=
github.com/knadh/koanf/parsers/toml v0.1.0 h1:S2hLqS4TgWZYj4/7mI5m1CQQcWurxUz6ODgOub/6LCI=
github.com/knadh/koanf/parsers/toml v0.1.0/go.mod h1:yUprhq6eo3GbyVXFFMdbfZSo928ksS+uo0FFqNMnO18=
github.com/knadh/koanf/providers/basicflag v1.0.0 h1:qB0es/9fYsLuYnrKazxNCuWtkv3JFX1lI1druUsDDvY=
github.com/knadh/koanf/providers/basicflag v1.0.0/go.mod h1:n0NlnaxXUCER/WIzRroT9q3Np+FiZ9pSjrC6A/OozI8=
github.com/knadh/koanf/providers/env v0.1.0 h1:LqKteXqfOWyx5Ab9VfGHmjY9BvRXi+clwyZozgVRiKg=
github.com/knadh/koanf/providers/env v0.1.0/go.mod h1:RE8K9GbACJkeEnkl8L/Qcj8p4ZyPXZIQ191HJi44ZaQ=
github.com/knadh/koanf/v2 v2.1.1 h1:/R8eXqasSTsmDCsAyYj+81Wteg8AqrV9CP6gvsTsOmM=
Expand Down
73 changes: 62 additions & 11 deletions internal/config/command_line.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,88 @@ import (
"flag"
"io"
"os"
"slices"
"strings"

"github.com/yunginnanet/HellPot/internal/extra"
"github.com/yunginnanet/HellPot/internal/version"
)

var CLIFlags = flag.NewFlagSet("config", flag.ExitOnError)
var CLIFlags = flag.NewFlagSet("config", flag.ContinueOnError)

var (
sliceDefs = make(map[string][]string)
slicePtrs = make(map[string]*string)
)

func addCLIFlags() {
parse := func(k string, v interface{}, nestedName string) {
switch casted := v.(type) {
case bool:
CLIFlags.Bool(nestedName, casted, "set "+k)
case string:
CLIFlags.String(nestedName, casted, "set "+k)
case int:
CLIFlags.Int(nestedName, casted, "set "+k)
case float64:
CLIFlags.Float64(nestedName, casted, "set "+k)
case []string:
sliceDefs[nestedName] = casted
joined := strings.Join(sliceDefs[nestedName], ",")
slicePtrs[nestedName] = CLIFlags.String(nestedName, joined, "set "+k)
}
}

for key, val := range Defaults.val {
if _, ok := val.(map[string]interface{}); !ok {
parse(key, val, key)
continue
}
nested, ok := val.(map[string]interface{})
if !ok {
// linter was confused by the above check
panic("unreachable, if you see this you have entered a real life HellPot")
}
for k, v := range nested {
nestedName := key + "." + k
parse(k, v, nestedName)
}
}
}

var replacer = map[string][]string{
"-h": {"-help"},
"-v": {"-version"},
"-c": {"-config"},
"-g": {"-bespoke.enable_grimoire", "true", "-bespoke.grimoire_file"},
}

func InitCLI() {
newArgs := make([]string, 0)
for _, arg := range os.Args {
if repl, ok := replacer[arg]; ok {
newArgs = append(newArgs, repl...)
continue
}
// check for unit test flags
if !strings.HasPrefix(arg, "-test.") {
newArgs = append(newArgs, arg)
}
}

CLIFlags.Bool("logger-debug", false, "force debug logging")
CLIFlags.Bool("logger-trace", false, "force trace logging")
CLIFlags.Bool("logger-nocolor", false, "force no color logging")
CLIFlags.String("bespoke-grimoire", "", "specify a custom file used for text generation")
newArgs = slices.Compact(newArgs)

CLIFlags.Bool("banner", false, "show banner and version then exit")
CLIFlags.Bool("genconfig", false, "write default config to stdout then exit")
CLIFlags.Bool("h", false, "show this help and exit")
CLIFlags.Bool("help", false, "show this help and exit")
CLIFlags.String("c", "", "specify config file")
CLIFlags.String("config", "", "specify config file")
CLIFlags.String("version", "", "show version and exit")
CLIFlags.String("v", "", "show version and exit")

addCLIFlags()

if err := CLIFlags.Parse(newArgs[1:]); err != nil {
println(err.Error())
// flag.ExitOnError will call os.Exit(2)
os.Exit(2)
}
if os.Getenv("HELLPOT_CONFIG") != "" {
if err := CLIFlags.Set("config", os.Getenv("HELLPOT_CONFIG")); err != nil {
Expand All @@ -45,11 +95,11 @@ func InitCLI() {
panic(err)
}
}
if CLIFlags.Lookup("h").Value.String() == "true" || CLIFlags.Lookup("help").Value.String() == "true" {
if CLIFlags.Lookup("help").Value.String() == "true" {
CLIFlags.Usage()
os.Exit(0)
}
if CLIFlags.Lookup("version").Value.String() == "true" || CLIFlags.Lookup("v").Value.String() == "true" {
if CLIFlags.Lookup("version").Value.String() == "true" {
_, _ = os.Stdout.WriteString("HellPot version: " + version.Version + "\n")
os.Exit(0)
}
Expand All @@ -66,4 +116,5 @@ func InitCLI() {
extra.Banner()
os.Exit(0)
}

}
4 changes: 4 additions & 0 deletions internal/config/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,8 @@ var defOpts = map[string]interface{}{
"deception": map[string]interface{}{
"server_name": "nginx",
},
"bespoke": map[string]interface{}{
"grimoire_file": "",
"enable_grimoire": false,
},
}
4 changes: 2 additions & 2 deletions internal/config/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,6 @@ type DevilsPlaythings struct {

// Customization represents the configuration for the customizations.
type Customization struct {
CustomHeffalump bool `koanf:"custom_heffalump"`
Grimoire string `koanf:"grimoire"`
CustomHeffalump bool `koanf:"enable_grimoire"`
Grimoire string `koanf:"grimoire_file"`
}
141 changes: 131 additions & 10 deletions internal/config/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package config
import (
"fmt"
"io"
"slices"
"strings"

"github.com/knadh/koanf/parsers/toml"
flags "github.com/knadh/koanf/providers/basicflag"
"github.com/knadh/koanf/providers/env"
"github.com/knadh/koanf/v2"
)
Expand All @@ -26,19 +28,89 @@ func (r *readerProvider) Read() (map[string]interface{}, error) {
return toml.Parser().Unmarshal(b) //nolint:wrapcheck
}

func Setup(source io.Reader) (*Parameters, error) {
k := koanf.New(".")
func normalizeMap(m map[string]interface{}) map[string]interface{} {
for k, v := range m {
ogk := k
k = strings.ToLower(k)

if err := k.Load(Defaults, nil); err != nil {
return nil, fmt.Errorf("failed to load defaults: %w", err)
var sslice []string
var sliceOK bool

if sslice, sliceOK = v.([]string); !sliceOK {
goto justLower
}
for i, s := range sslice {
sslice[i] = strings.ToLower(s)
}
slices.Sort(sslice)
m[k] = sslice
justLower:
if k != ogk {
delete(m, ogk)
}
}
return m
}

if source != nil {
if err := k.Load(&readerProvider{source}, toml.Parser()); err != nil {
return nil, fmt.Errorf("failed to read config: %w", err)
func (p *Parameters) merge(ogk *koanf.Koanf, newk *koanf.Koanf, friendlyName string) error {
if ogk == nil {
panic("original koanf is nil")
}
if newk == nil {
return nil
}
dirty := false

newKeys := normalizeMap(newk.All())

if len(newk.All()) == 0 || len(newKeys) == 0 {
return nil
}

for k, v := range newKeys {
if !ogk.Exists(k) {
if err := ogk.Set(k, v); err != nil {
panic(fmt.Sprintf("failed to set key %s: %v", k, err))
}
dirty = true
continue
}

ogv := ogk.Get(k)
if ogv == nil {
if err := ogk.Set(k, v); err != nil {
panic(fmt.Sprintf("failed to set key %s: %v", k, err))
}
dirty = true
continue
}

if _, hasDefault := Defaults.val[k]; !hasDefault {
continue
}

if ogv == Defaults.val[k] && v != ogv {
if err := ogk.Set(k, v); err != nil {
panic(fmt.Sprintf("failed to set key %s: %v", k, err))
}
dirty = true
}
}

if !dirty {
return nil
}

println("found configuration overrides in " + friendlyName)

if err := ogk.Merge(newk); err != nil {
return fmt.Errorf("failed to merge env config: %w", err)
}

return nil
}

func (p *Parameters) LoadEnv(k *koanf.Koanf) error {
envK := koanf.New(".")

envErr := envK.Load(env.Provider("HELLPOT_", ".", func(s string) string {
Expand All @@ -50,9 +122,50 @@ func Setup(source io.Reader) (*Parameters, error) {
return s
}), nil)

if envErr == nil && envK != nil && len(envK.All()) > 0 {
if err := k.Merge(envK); err != nil {
return nil, fmt.Errorf("failed to merge env config: %w", err)
if envErr != nil {
return fmt.Errorf("failed to load env: %w", envErr)
}

if err := p.merge(k, envK, "environment variables"); err != nil {
return err
}

return nil
}

func parseCLISlice(key string, value string) (string, interface{}) {
if _, ok := slicePtrs[key]; !ok {
return key, value
}
split := strings.Split(value, ",")
slices.Sort(split)
return key, split
}

func (p *Parameters) LoadFlags(k *koanf.Koanf) error {
flagsK := koanf.New(".")

if err := flagsK.Load(flags.ProviderWithValue(CLIFlags, ".", parseCLISlice), nil); err != nil {
return fmt.Errorf("failed to load flags: %w", err)
}

if err := p.merge(k, flagsK, "cli arguments"); err != nil {
return err
}

return nil
}

func Setup(source io.Reader) (*Parameters, error) {
k := koanf.New(".")

if err := k.Load(Defaults, nil); err != nil {
return nil, fmt.Errorf("failed to load defaults: %w", err)
}

if source != nil {
if err := k.Load(&readerProvider{source}, toml.Parser()); err != nil {
return nil, fmt.Errorf("failed to read config: %w", err)
}
}

Expand All @@ -64,6 +177,14 @@ func Setup(source io.Reader) (*Parameters, error) {
p.UsingDefaults = true
}

if err := p.LoadFlags(k); err != nil {
return nil, err
}

if err := p.LoadEnv(k); err != nil {
return nil, err
}

if err := k.Unmarshal("", p); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
Expand Down
6 changes: 5 additions & 1 deletion internal/http/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func getSrv(r *router.Router) fasthttp.Server {
}
}

func setupHeffalump(config *config.Parameters) error {
func SetupHeffalump(config *config.Parameters) error {
switch config.Bespoke.CustomHeffalump {
case true:
content, err := os.ReadFile(config.Bespoke.Grimoire)
Expand Down Expand Up @@ -151,6 +151,10 @@ func Serve(config *config.Parameters) error {
log = config.GetLogger()
runningConfig = config

if err := SetupHeffalump(config); err != nil {
return fmt.Errorf("failed to setup heffalump: %w", err)
}

l := config.HTTP.Bind + ":" + strconv.Itoa(int(config.HTTP.Port))

r := router.New()
Expand Down

0 comments on commit 93dcb98

Please sign in to comment.