Skip to content

Commit 5724bd1

Browse files
committed
config: add UserSettings.ConfigFinder
Fixes #48.
1 parent aae6f39 commit 5724bd1

File tree

3 files changed

+61
-2
lines changed

3 files changed

+61
-2
lines changed

Diff for: config.go

+39-2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ type configFinder func() string
5353
// files are parsed and cached the first time Get() or GetStrict() is called.
5454
type UserSettings struct {
5555
IgnoreErrors bool
56+
customConfig *Config
57+
customConfigFinder configFinder
5658
systemConfig *Config
5759
systemConfigFinder configFinder
5860
userConfig *Config
@@ -203,6 +205,13 @@ func (u *UserSettings) GetStrict(alias, key string) (string, error) {
203205
if u.onceErr != nil && u.IgnoreErrors == false {
204206
return "", u.onceErr
205207
}
208+
// TODO this is getting repetitive
209+
if u.customConfig != nil {
210+
val, err := findVal(u.customConfig, alias, key)
211+
if err != nil || val != "" {
212+
return val, err
213+
}
214+
}
206215
val, err := findVal(u.userConfig, alias, key)
207216
if err != nil || val != "" {
208217
return val, err
@@ -228,6 +237,12 @@ func (u *UserSettings) GetAllStrict(alias, key string) ([]string, error) {
228237
if u.onceErr != nil && u.IgnoreErrors == false {
229238
return nil, u.onceErr
230239
}
240+
if u.customConfig != nil {
241+
val, err := findAll(u.customConfig, alias, key)
242+
if err != nil || val != nil {
243+
return val, err
244+
}
245+
}
231246
val, err := findAll(u.userConfig, alias, key)
232247
if err != nil || val != nil {
233248
return val, err
@@ -243,16 +258,38 @@ func (u *UserSettings) GetAllStrict(alias, key string) ([]string, error) {
243258
return []string{}, nil
244259
}
245260

261+
// ConfigFinder will invoke f to try to find a ssh config file in a custom
262+
// location on disk, instead of in /etc/ssh or $HOME/.ssh. f should return the
263+
// name of a file containing SSH configuration.
264+
//
265+
// ConfigFinder must be invoked before any calls to Get or GetStrict and panics
266+
// if f is nil. Most users should not need to use this function.
267+
func (u *UserSettings) ConfigFinder(f func() string) {
268+
if f == nil {
269+
panic("cannot call ConfigFinder with nil function")
270+
}
271+
u.customConfigFinder = f
272+
}
273+
246274
func (u *UserSettings) doLoadConfigs() {
247275
u.loadConfigs.Do(func() {
248-
// can't parse user file, that's ok.
249276
var filename string
277+
var err error
278+
if u.customConfigFinder != nil {
279+
filename = u.customConfigFinder()
280+
u.customConfig, err = parseFile(filename)
281+
// IsNotExist should be returned because a user specified this
282+
// function - not existing likely means they made an error
283+
if err != nil {
284+
u.onceErr = err
285+
}
286+
return
287+
}
250288
if u.userConfigFinder == nil {
251289
filename = userConfigFinder()
252290
} else {
253291
filename = u.userConfigFinder()
254292
}
255-
var err error
256293
u.userConfig, err = parseFile(filename)
257294
//lint:ignore S1002 I prefer it this way
258295
if err != nil && os.IsNotExist(err) == false {

Diff for: config_test.go

+12
Original file line numberDiff line numberDiff line change
@@ -455,3 +455,15 @@ func TestNoTrailingNewline(t *testing.T) {
455455
t.Errorf("wrong port: got %q want 4242", port)
456456
}
457457
}
458+
459+
func TestCustomFinder(t *testing.T) {
460+
us := &UserSettings{}
461+
us.ConfigFinder(func() string {
462+
return "testdata/config1"
463+
})
464+
465+
val := us.Get("wap", "User")
466+
if val != "root" {
467+
t.Errorf("expected to find User root, got %q", val)
468+
}
469+
}

Diff for: example_test.go

+10
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package ssh_config_test
22

33
import (
44
"fmt"
5+
"path/filepath"
56
"strings"
67

78
"github.com/kevinburke/ssh_config"
@@ -46,3 +47,12 @@ func ExampleDefault() {
4647
// 22
4748
//
4849
}
50+
51+
func ExampleUserSettings_ConfigFinder() {
52+
// This can be used to test SSH config parsing.
53+
u := ssh_config.UserSettings{}
54+
u.ConfigFinder(func() string {
55+
return filepath.Join("testdata", "test_config")
56+
})
57+
u.Get("example.com", "Host")
58+
}

0 commit comments

Comments
 (0)