diff --git a/clients/stellartoml/client.go b/clients/stellartoml/client.go index 373f3d36f0..3f92789fd7 100644 --- a/clients/stellartoml/client.go +++ b/clients/stellartoml/client.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net/http" + "os" "github.com/BurntSushi/toml" "github.com/stellar/go/address" @@ -55,6 +56,14 @@ func (c *Client) GetStellarTomlByAddress(addr string) (*Response, error) { return c.GetStellarToml(domain) } +func getWellKnownPathFromEnv() string { + path := os.Getenv("STELLAR_TOML_PATH") + if path == "" { + return WellKnownPath + } + return path +} + // url returns the appropriate url to load for resolving domain's stellar.toml // file func (c *Client) url(domain string) string { @@ -66,5 +75,7 @@ func (c *Client) url(domain string) string { scheme = "https" } - return fmt.Sprintf("%s://%s%s", scheme, domain, WellKnownPath) + wellKnownPath := getWellKnownPathFromEnv() + + return fmt.Sprintf("%s://%s%s", scheme, domain, wellKnownPath) } diff --git a/clients/stellartoml/client_test.go b/clients/stellartoml/client_test.go index b93c1e44a0..f113c5ac4e 100644 --- a/clients/stellartoml/client_test.go +++ b/clients/stellartoml/client_test.go @@ -1,6 +1,7 @@ package stellartoml import ( + "os" "strings" "testing" @@ -64,3 +65,17 @@ func TestClient(t *testing.T) { assert.Contains(t, err.Error(), "toml decode failed") } } + +func TestGetWellKnownPathFromEnv(t *testing.T) { + // Backup and defer restore + orig := os.Getenv("STELLAR_TOML_PATH") + defer os.Setenv("STELLAR_TOML_PATH", orig) + + // Test default + os.Unsetenv("STELLAR_TOML_PATH") + assert.Equal(t, WellKnownPath, getWellKnownPathFromEnv()) + + // Test custom path + os.Setenv("STELLAR_TOML_PATH", "/custom/path/stellar.toml") + assert.Equal(t, "/custom/path/stellar.toml", getWellKnownPathFromEnv()) +}