Skip to content

Commit eea636d

Browse files
Add data-dir flag (ava-labs#1386)
Co-authored-by: Stephen <[email protected]>
1 parent 885f34a commit eea636d

File tree

4 files changed

+57
-29
lines changed

4 files changed

+57
-29
lines changed

config/config.go

+15-14
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ var (
7474
func GetRunnerConfig(v *viper.Viper) (runner.Config, error) {
7575
config := runner.Config{
7676
DisplayVersionAndExit: v.GetBool(VersionKey),
77-
BuildDir: os.ExpandEnv(v.GetString(BuildDirKey)),
77+
BuildDir: GetExpandedArg(v, BuildDirKey),
7878
PluginMode: v.GetBool(PluginModeKey),
7979
}
8080

@@ -100,6 +100,7 @@ func GetRunnerConfig(v *viper.Viper) (runner.Config, error) {
100100

101101
foundBuildDir := false
102102
for _, dir := range defaultBuildDirs {
103+
dir = GetExpandedString(v, dir)
103104
if validBuildDir(dir) {
104105
config.BuildDir = dir
105106
foundBuildDir = true
@@ -136,7 +137,7 @@ func getConsensusConfig(v *viper.Viper) avalanche.Parameters {
136137

137138
func getLoggingConfig(v *viper.Viper) (logging.Config, error) {
138139
loggingConfig := logging.Config{}
139-
loggingConfig.Directory = os.ExpandEnv(v.GetString(LogsDirKey))
140+
loggingConfig.Directory = GetExpandedArg(v, LogsDirKey)
140141
var err error
141142
loggingConfig.LogLevel, err = logging.ToLevel(v.GetString(LogLevelKey))
142143
if err != nil {
@@ -194,7 +195,7 @@ func getIPCConfig(v *viper.Viper) node.IPCConfig {
194195
config.IPCDefaultChainIDs = strings.Split(v.GetString(IpcsChainIDsKey), ",")
195196
}
196197
if v.IsSet(IpcsPathKey) {
197-
config.IPCPath = os.ExpandEnv(v.GetString(IpcsPathKey))
198+
config.IPCPath = GetExpandedArg(v, IpcsPathKey)
198199
}
199200
return config
200201
}
@@ -213,7 +214,7 @@ func getHTTPConfig(v *viper.Viper) (node.HTTPConfig, error) {
213214
return node.HTTPConfig{}, fmt.Errorf("unable to decode base64 content: %w", err)
214215
}
215216
case v.IsSet(HTTPSKeyFileKey):
216-
httpsKeyFilepath := os.ExpandEnv(v.GetString(HTTPSKeyFileKey))
217+
httpsKeyFilepath := GetExpandedArg(v, HTTPSKeyFileKey)
217218
if httpsKey, err = os.ReadFile(filepath.Clean(httpsKeyFilepath)); err != nil {
218219
return node.HTTPConfig{}, err
219220
}
@@ -227,7 +228,7 @@ func getHTTPConfig(v *viper.Viper) (node.HTTPConfig, error) {
227228
return node.HTTPConfig{}, fmt.Errorf("unable to decode base64 content: %w", err)
228229
}
229230
case v.IsSet(HTTPSCertFileKey):
230-
httpsCertFilepath := os.ExpandEnv(v.GetString(HTTPSCertFileKey))
231+
httpsCertFilepath := GetExpandedArg(v, HTTPSCertFileKey)
231232
if httpsCert, err = os.ReadFile(filepath.Clean(httpsCertFilepath)); err != nil {
232233
return node.HTTPConfig{}, err
233234
}
@@ -599,7 +600,7 @@ func getIPConfig(v *viper.Viper) (node.IPConfig, error) {
599600

600601
func getProfilerConfig(v *viper.Viper) (profiler.Config, error) {
601602
config := profiler.Config{
602-
Dir: os.ExpandEnv(v.GetString(ProfileDirKey)),
603+
Dir: GetExpandedArg(v, ProfileDirKey),
603604
Enabled: v.GetBool(ProfileContinuousEnabledKey),
604605
Freq: v.GetDuration(ProfileContinuousFreqKey),
605606
MaxNumFiles: v.GetInt(ProfileContinuousMaxFilesKey),
@@ -633,8 +634,8 @@ func getStakingTLSCertFromFlag(v *viper.Viper) (tls.Certificate, error) {
633634

634635
func getStakingTLSCertFromFile(v *viper.Viper) (tls.Certificate, error) {
635636
// Parse the staking key/cert paths and expand environment variables
636-
stakingKeyPath := os.ExpandEnv(v.GetString(StakingKeyPathKey))
637-
stakingCertPath := os.ExpandEnv(v.GetString(StakingCertPathKey))
637+
stakingKeyPath := GetExpandedArg(v, StakingKeyPathKey)
638+
stakingCertPath := GetExpandedArg(v, StakingCertPathKey)
638639

639640
// If staking key/cert locations are specified but not found, error
640641
if v.IsSet(StakingKeyPathKey) || v.IsSet(StakingCertPathKey) {
@@ -684,8 +685,8 @@ func getStakingConfig(v *viper.Viper, networkID uint32) (node.StakingConfig, err
684685
config := node.StakingConfig{
685686
EnableStaking: v.GetBool(StakingEnabledKey),
686687
DisabledStakingWeight: v.GetUint64(StakingDisabledWeightKey),
687-
StakingKeyPath: os.ExpandEnv(v.GetString(StakingKeyPathKey)),
688-
StakingCertPath: os.ExpandEnv(v.GetString(StakingCertPathKey)),
688+
StakingKeyPath: GetExpandedArg(v, StakingKeyPathKey),
689+
StakingCertPath: GetExpandedArg(v, StakingCertPathKey),
689690
}
690691
if !config.EnableStaking && config.DisabledStakingWeight == 0 {
691692
return node.StakingConfig{}, errInvalidStakerWeights
@@ -755,7 +756,7 @@ func getGenesisData(v *viper.Viper, networkID uint32) ([]byte, ids.ID, error) {
755756

756757
// if content is not specified go for the file
757758
if v.IsSet(GenesisConfigFileKey) {
758-
genesisFileName := os.ExpandEnv(v.GetString(GenesisConfigFileKey))
759+
genesisFileName := GetExpandedArg(v, GenesisConfigFileKey)
759760
return genesis.FromFile(networkID, genesisFileName)
760761
}
761762

@@ -794,7 +795,7 @@ func getDatabaseConfig(v *viper.Viper, networkID uint32) (node.DatabaseConfig, e
794795
return node.DatabaseConfig{}, fmt.Errorf("unable to decode base64 content: %w", err)
795796
}
796797
} else if v.IsSet(DBConfigFileKey) {
797-
path := os.ExpandEnv(v.GetString(DBConfigFileKey))
798+
path := GetExpandedArg(v, DBConfigFileKey)
798799
configBytes, err = os.ReadFile(path)
799800
if err != nil {
800801
return node.DatabaseConfig{}, err
@@ -804,7 +805,7 @@ func getDatabaseConfig(v *viper.Viper, networkID uint32) (node.DatabaseConfig, e
804805
return node.DatabaseConfig{
805806
Name: v.GetString(DBTypeKey),
806807
Path: filepath.Join(
807-
os.ExpandEnv(v.GetString(DBPathKey)),
808+
GetExpandedArg(v, DBPathKey),
808809
constants.NetworkName(networkID),
809810
),
810811
Config: configBytes,
@@ -866,7 +867,7 @@ func getVMManager(v *viper.Viper) (vms.Manager, error) {
866867

867868
// getPathFromDirKey reads flag value from viper instance and then checks the folder existence
868869
func getPathFromDirKey(v *viper.Viper, configKey string) (string, error) {
869-
configDir := os.ExpandEnv(v.GetString(configKey))
870+
configDir := GetExpandedArg(v, configKey)
870871
cleanPath := filepath.Clean(configDir)
871872
ok, err := storage.FolderExists(cleanPath)
872873
if err != nil {

config/flags.go

+40-13
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414

1515
"github.com/kardianos/osext"
1616

17+
"github.com/spf13/viper"
18+
1719
"github.com/ava-labs/avalanchego/database/leveldb"
1820
"github.com/ava-labs/avalanchego/database/memdb"
1921
"github.com/ava-labs/avalanchego/database/rocksdb"
@@ -26,21 +28,21 @@ import (
2628
const (
2729
DefaultHTTPPort = 9650
2830
DefaultStakingPort = 9651
31+
32+
AvalancheGoDataDirVar = "AVALANCHEGO_DATA_DIR"
33+
defaultUnexpandedDataDir = "$" + AvalancheGoDataDirVar
2934
)
3035

31-
// Results of parsing the CLI
3236
var (
33-
defaultNetworkName = constants.MainnetName
34-
homeDir = os.ExpandEnv("$HOME")
35-
prefixedAppName = fmt.Sprintf(".%s", constants.AppName)
36-
defaultDataDir = filepath.Join(homeDir, prefixedAppName)
37-
defaultDBDir = filepath.Join(defaultDataDir, "db")
38-
defaultProfileDir = filepath.Join(defaultDataDir, "profiles")
39-
defaultStakingPath = filepath.Join(defaultDataDir, "staking")
37+
// [defaultUnexpandedDataDir] will be expanded when reading the flags
38+
defaultDataDir = filepath.Join("$HOME", ".avalanchego")
39+
defaultDBDir = filepath.Join(defaultUnexpandedDataDir, "db")
40+
defaultLogDir = filepath.Join(defaultUnexpandedDataDir, "logs")
41+
defaultProfileDir = filepath.Join(defaultUnexpandedDataDir, "profiles")
42+
defaultStakingPath = filepath.Join(defaultUnexpandedDataDir, "staking")
4043
defaultStakingKeyPath = filepath.Join(defaultStakingPath, "staker.key")
4144
defaultStakingCertPath = filepath.Join(defaultStakingPath, "staker.crt")
42-
defaultLogDirectory = filepath.Join(defaultDataDir, "logs")
43-
defaultConfigDir = filepath.Join(defaultDataDir, "configs")
45+
defaultConfigDir = filepath.Join(defaultUnexpandedDataDir, "configs")
4446
defaultChainConfigDir = filepath.Join(defaultConfigDir, "chains")
4547
defaultVMConfigDir = filepath.Join(defaultConfigDir, "vms")
4648
defaultVMAliasFilePath = filepath.Join(defaultVMConfigDir, "aliases.json")
@@ -63,7 +65,7 @@ func init() {
6365
defaultBuildDirs = append(defaultBuildDirs,
6466
wd,
6567
filepath.Join("/", "usr", "local", "lib", constants.AppName),
66-
defaultDataDir,
68+
defaultUnexpandedDataDir,
6769
)
6870
}
6971

@@ -79,6 +81,8 @@ func addProcessFlags(fs *flag.FlagSet) {
7981
}
8082

8183
func addNodeFlags(fs *flag.FlagSet) {
84+
// Home directory
85+
fs.String(DataDirKey, defaultDataDir, "Sets the base data directory where default sub-directories will be placed unless otherwise specified.")
8286
// System
8387
fs.Uint64(FdLimitKey, ulimit.DefaultFDLimit, "Attempts to raise the process file descriptor limit to at least this value and error if the value is above the system max")
8488

@@ -93,7 +97,7 @@ func addNodeFlags(fs *flag.FlagSet) {
9397
fs.String(GenesisConfigContentKey, "", "Specifies base64 encoded genesis content")
9498

9599
// Network ID
96-
fs.String(NetworkNameKey, defaultNetworkName, "Network ID this node will connect to")
100+
fs.String(NetworkNameKey, constants.MainnetName, "Network ID this node will connect to")
97101

98102
// AVAX fees
99103
fs.Uint64(TxFeeKey, genesis.LocalParams.TxFee, "Transaction fee, in nAVAX")
@@ -108,7 +112,7 @@ func addNodeFlags(fs *flag.FlagSet) {
108112
fs.String(DBConfigContentKey, "", "Specifies base64 encoded database config content")
109113

110114
// Logging
111-
fs.String(LogsDirKey, defaultLogDirectory, "Logging directory for Avalanche")
115+
fs.String(LogsDirKey, defaultLogDir, "Logging directory for Avalanche")
112116
fs.String(LogLevelKey, "info", "The log level. Should be one of {verbo, debug, trace, info, warn, error, fatal, off}")
113117
fs.String(LogDisplayLevelKey, "", "The log display level. If left blank, will inherit the value of log-level. Otherwise, should be one of {verbo, debug, trace, info, warn, error, fatal, off}")
114118
fs.String(LogFormatKey, "auto", "The structure of log format. Defaults to 'auto' which formats terminal-like logs, when the output is a terminal. Otherwise, should be one of {auto, plain, colors, json}")
@@ -355,3 +359,26 @@ func BuildFlagSet() *flag.FlagSet {
355359
addNodeFlags(fs)
356360
return fs
357361
}
362+
363+
// GetExpandedArg gets the string in viper corresponding to [key] and expands
364+
// any variables using the OS env. If the [AvalancheGoDataDirVar] var is used,
365+
// we expand the value of the variable with the string in viper corresponding to
366+
// [DataDirKey].
367+
func GetExpandedArg(v *viper.Viper, key string) string {
368+
return GetExpandedString(v, v.GetString(key))
369+
}
370+
371+
// GetExpandedString expands [s] with any variables using the OS env. If the
372+
// [AvalancheGoDataDirVar] var is used, we expand the value of the variable with
373+
// the string in viper corresponding to [DataDirKey].
374+
func GetExpandedString(v *viper.Viper, s string) string {
375+
return os.Expand(
376+
s,
377+
func(strVar string) string {
378+
if strVar == AvalancheGoDataDirVar {
379+
return os.ExpandEnv(v.GetString(DataDirKey))
380+
}
381+
return os.Getenv(strVar)
382+
},
383+
)
384+
}

config/keys.go

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package config
55

66
// #nosec G101
77
const (
8+
DataDirKey = "data-dir"
89
ConfigFileKey = "config-file"
910
ConfigContentKey = "config-file-content"
1011
ConfigContentTypeKey = "config-file-content-type"

config/viper.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"flag"
1010
"fmt"
1111
"io"
12-
"os"
1312
"strings"
1413

1514
"github.com/spf13/viper"
@@ -49,7 +48,7 @@ func BuildViper(fs *flag.FlagSet, args []string) (*viper.Viper, error) {
4948
}
5049

5150
case v.IsSet(ConfigFileKey):
52-
filename := os.ExpandEnv(v.GetString(ConfigFileKey))
51+
filename := GetExpandedArg(v, ConfigFileKey)
5352
v.SetConfigFile(filename)
5453
if err := v.ReadInConfig(); err != nil {
5554
return nil, err

0 commit comments

Comments
 (0)