Skip to content

Commit ad259eb

Browse files
committed
host auto fill attrs
1 parent ad122d8 commit ad259eb

File tree

5 files changed

+203
-91
lines changed

5 files changed

+203
-91
lines changed

cmd/root.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ var rootCmd = &cobra.Command{
5353
if len(args) > 0 {
5454
return fmt.Errorf("host name and args can not be used together")
5555
}
56-
} else {
56+
if len(taskConfig.Tags) > 0 {
57+
return fmt.Errorf("host name and tags can not be used together")
58+
}
59+
} else if len(taskConfig.Tags) == 0 {
5760
if len(args) == 0 {
5861
return fmt.Errorf("host name is required")
5962
}
@@ -96,7 +99,7 @@ func init() {
9699
rootCmd.Flags().Uint16VarP(&taskConfig.Port, "port", "p", 0, "remote host port")
97100
rootCmd.Flags().StringArrayVar(&taskConfig.IdentityFiles, "identity", []string{}, "identity file")
98101
rootCmd.Flags().IntVarP(&taskConfig.Parallel, "parallel", "", 1, "max parallel run tasks num")
99-
rootCmd.Flags().StringVarP(&taskConfig.Tags, "tags", "t", "", "tags filter")
102+
rootCmd.Flags().StringArrayVarP(&taskConfig.Tags, "tags", "t", []string{}, "tags filter")
100103
rootCmd.Flags().BoolVarP(&taskConfig.FailedContinue, "force", "f", false, "force run when failed")
101104

102105
// remote command

config/config.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"os"
66
"path"
7+
"strconv"
78
"strings"
89
"text/tabwriter"
910

@@ -98,11 +99,15 @@ func ListConfigHosts(name string, user string, tags string) error {
9899
return err
99100
}
100101
for _, host := range hosts {
101-
row := fmt.Sprintf("%s\t%s\t%s\t%d\t%s\t%s\t",
102+
port := ""
103+
if host.Port != 0 {
104+
port = strconv.Itoa(int(host.Port))
105+
}
106+
row := fmt.Sprintf("%s\t%s\t%s\t%s\t%s\t%s\t",
102107
strings.Join(host.Patterns, ","),
103108
host.HostName,
104109
host.Username,
105-
host.Port,
110+
port,
106111
host.ProxyJump,
107112
strings.Join(host.TagList, ","),
108113
)

config/host.go

Lines changed: 154 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import (
44
"fmt"
55
"os"
66
"path/filepath"
7-
"regexp"
7+
"slices"
88
"strconv"
99
"strings"
1010

@@ -72,34 +72,6 @@ func NewHost(username, hostname string, port uint16, proxyJump string, identityF
7272
host.HostName = parts[0]
7373
}
7474

75-
// port
76-
if host.Port == 0 {
77-
portInt, err := strconv.Atoi(ssh_config.Get(host.HostName, "Port"))
78-
if err != nil {
79-
return nil, fmt.Errorf("invalid port format: %v", host.HostName)
80-
} else if portInt <= 0 || portInt >= 65536 {
81-
return nil, fmt.Errorf("invalid port format: %v", host.HostName)
82-
} else {
83-
host.Port = uint16(portInt)
84-
}
85-
86-
// default port
87-
if host.Port == 0 {
88-
host.Port = 22
89-
}
90-
}
91-
92-
// username
93-
if host.Username == "" {
94-
host.Username = ssh_config.Get(host.HostName, "User")
95-
if host.Username == "" {
96-
host.Username = os.Getenv("USER")
97-
}
98-
if host.Username == "" {
99-
host.Username = "root"
100-
}
101-
}
102-
10375
// identity files
10476
for _, identityFile := range identityFiles {
10577
if _, err := os.Stat(identityFile); err != nil {
@@ -108,42 +80,7 @@ func NewHost(username, hostname string, port uint16, proxyJump string, identityF
10880
host.IdentityFiles = append(host.IdentityFiles, identityFile)
10981
}
11082

111-
if len(host.IdentityFiles) == 0 {
112-
identityFiles := ssh_config.GetAll(host.HostName, "IdentityFile")
113-
for _, identityFile := range identityFiles {
114-
if _, err := os.Stat(identityFile); err != nil {
115-
continue
116-
}
117-
host.IdentityFiles = append(host.IdentityFiles, identityFile)
118-
}
119-
120-
// default identity file
121-
if len(host.IdentityFiles) == 0 {
122-
host.IdentityFiles = []string{filepath.Join(os.Getenv("HOME"), ".ssh", "id_rsa")}
123-
}
124-
}
125-
126-
// proxy jump
127-
if host.ProxyJump == "" {
128-
host.ProxyJump = ssh_config.Get(host.HostName, "ProxyJump")
129-
}
130-
131-
// jump list
132-
for _, jump := range strings.Split(host.ProxyJump, ",") {
133-
if jump == "" {
134-
continue
135-
}
136-
jumpHost, err := NewHost("", jump, 0, "", host.IdentityFiles)
137-
if err != nil {
138-
return nil, err
139-
}
140-
host.JumpList = append(host.JumpList, jumpHost)
141-
}
142-
143-
rawHostname := ssh_config.Get(host.HostName, "HostName")
144-
if rawHostname != "" {
145-
host.HostName = rawHostname
146-
}
83+
host.FillAttrsWithSSHConfig()
14784

14885
logger.Debugf("host: %+#v", host)
14986
return host, nil
@@ -171,27 +108,164 @@ func (host *Host) JumpString() string {
171108
return strings.Join(hosts, ",")
172109
}
173110

174-
func CheckTags(tags string) error {
175-
if matched, err := regexp.MatchString(`[0-9a-zA-z_\-,]*`, tags); err != nil {
176-
return err
177-
} else if !matched {
178-
return fmt.Errorf("Invalid tags!")
111+
func (host *Host) MatchTags(tags []string) bool {
112+
if len(tags) == 0 {
113+
return false
114+
}
115+
for _, tag := range tags {
116+
if slices.Contains(host.TagList, tag) {
117+
return true
118+
}
119+
}
120+
return false
121+
}
122+
123+
func (host *Host) fillUsername() {
124+
if host.Username != "" {
125+
return
126+
}
127+
128+
// fill username with patterns
129+
for _, pattern := range host.Patterns {
130+
if strings.ContainsAny(pattern, "*!?") {
131+
continue
132+
}
133+
host.Username = ssh_config.Get(pattern, "User")
134+
if host.Username != "" {
135+
break
136+
}
137+
}
138+
139+
// fill username with host name
140+
if host.Username == "" {
141+
host.Username = ssh_config.Get(host.HostName, "User")
142+
}
143+
144+
// fill username with environment variable
145+
if host.Username == "" {
146+
host.Username = os.Getenv("USER")
147+
}
148+
149+
// default username
150+
if host.Username == "" {
151+
host.Username = "root"
152+
}
153+
}
154+
155+
func (host *Host) fillPort() {
156+
if host.Port != 0 {
157+
return
158+
}
159+
160+
// fill port with patterns
161+
for _, pattern := range host.Patterns {
162+
if strings.ContainsAny(pattern, "*!?") {
163+
continue
164+
}
165+
portInt, err := strconv.Atoi(ssh_config.Get(pattern, "Port"))
166+
if err != nil {
167+
continue
168+
}
169+
if portInt <= 0 || portInt >= 65536 {
170+
continue
171+
}
172+
host.Port = uint16(portInt)
173+
return
174+
}
175+
176+
// fill port with host name
177+
if host.Port == 0 {
178+
portInt, err := strconv.Atoi(ssh_config.Get(host.HostName, "Port"))
179+
if err != nil {
180+
return
181+
}
182+
host.Port = uint16(portInt)
183+
}
184+
185+
// fill port with host name
186+
if host.Port == 0 {
187+
host.Port = 22
179188
}
180-
return nil
181189
}
182190

183-
func GetHostNames() (names []string) {
184-
hosts, err := GetHostsFromSSHConfig()
185-
if err != nil {
186-
return nil
191+
func (host *Host) fillProxyJump() {
192+
if host.ProxyJump != "" {
193+
return
187194
}
188-
for _, host := range hosts {
189-
for _, pattern := range host.Patterns {
190-
if strings.ContainsAny(pattern, "*!?") {
195+
196+
// fill proxy jump with patterns
197+
for _, pattern := range host.Patterns {
198+
if strings.ContainsAny(pattern, "*!?") {
199+
continue
200+
}
201+
host.ProxyJump = ssh_config.Get(pattern, "ProxyJump")
202+
}
203+
204+
// fill proxy jump with host name
205+
if host.ProxyJump == "" {
206+
host.ProxyJump = ssh_config.Get(host.HostName, "ProxyJump")
207+
}
208+
209+
// jump list
210+
if host.ProxyJump != "" {
211+
for _, jump := range strings.Split(host.ProxyJump, ",") {
212+
if jump == "" {
213+
continue
214+
}
215+
jumpHost, err := NewHost("", jump, 0, "", host.IdentityFiles)
216+
if err != nil {
191217
continue
192218
}
193-
names = append(names, pattern)
219+
host.JumpList = append(host.JumpList, jumpHost)
220+
}
221+
}
222+
}
223+
224+
func (host *Host) fillIdentityFiles() {
225+
if len(host.IdentityFiles) > 0 {
226+
return
227+
}
228+
229+
// fill identity files with patterns
230+
for _, pattern := range host.Patterns {
231+
if strings.ContainsAny(pattern, "*!?") {
232+
continue
233+
}
234+
identityFiles := ssh_config.GetAll(pattern, "IdentityFile")
235+
for _, identityFile := range identityFiles {
236+
if _, err := os.Stat(identityFile); err == nil {
237+
host.IdentityFiles = append(host.IdentityFiles, identityFile)
238+
}
194239
}
195240
}
196-
return names
241+
242+
// fill identity files with host name
243+
if len(host.IdentityFiles) == 0 {
244+
identityFiles := ssh_config.GetAll(host.HostName, "IdentityFile")
245+
for _, identityFile := range identityFiles {
246+
if _, err := os.Stat(identityFile); err == nil {
247+
host.IdentityFiles = append(host.IdentityFiles, identityFile)
248+
}
249+
}
250+
}
251+
252+
// default identity file
253+
if len(host.IdentityFiles) == 0 {
254+
defaultIdentityFile := filepath.Join(os.Getenv("HOME"), ".ssh", "id_rsa")
255+
if _, err := os.Stat(defaultIdentityFile); err == nil {
256+
host.IdentityFiles = []string{defaultIdentityFile}
257+
}
258+
}
259+
}
260+
261+
func (host *Host) FillAttrsWithSSHConfig() {
262+
host.fillUsername()
263+
host.fillPort()
264+
host.fillIdentityFiles()
265+
host.fillProxyJump() // must after identity files
266+
267+
rawHostname := ssh_config.Get(host.HostName, "HostName")
268+
if rawHostname != "" {
269+
host.HostName = rawHostname
270+
}
197271
}

config/ssh_config.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ func GetHostsFromSSHConfig() (hosts []*Host, err error) {
7272
continue
7373
}
7474

75+
// 从正则配置中解析
76+
host.FillAttrsWithSSHConfig()
77+
7578
// tags
7679
if hostConfig.EOLComment != "" {
7780
tagsRegex := regexp.MustCompile(`tags:([0-9a-zA-z_\-,]*)`)
@@ -93,3 +96,19 @@ func GetHostsFromSSHConfig() (hosts []*Host, err error) {
9396

9497
return hosts, nil
9598
}
99+
100+
func GetHostNames() (names []string) {
101+
hosts, err := GetHostsFromSSHConfig()
102+
if err != nil {
103+
return nil
104+
}
105+
for _, host := range hosts {
106+
for _, pattern := range host.Patterns {
107+
if strings.ContainsAny(pattern, "*!?") {
108+
continue
109+
}
110+
names = append(names, pattern)
111+
}
112+
}
113+
return names
114+
}

config/task_config.go

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"os"
66
"path/filepath"
7+
"regexp"
78
"strings"
89
)
910

@@ -51,7 +52,7 @@ type TaskConfig struct {
5152
Port uint16
5253
ProxyJump string
5354
IdentityFiles []string
54-
Tags string
55+
Tags []string
5556
Targets []string
5657
RemoteListen string
5758
ProxyServer string
@@ -101,9 +102,11 @@ func (cfg *TaskConfig) addTask(target string) (err error) {
101102
}
102103

103104
func (cfg *TaskConfig) InitTasks() error {
104-
if cfg.Tags != "" {
105-
if err := CheckTags(cfg.Tags); err != nil {
106-
return err
105+
if len(cfg.Tags) > 0 {
106+
for _, tag := range cfg.Tags {
107+
if matched, err := regexp.MatchString(`[0-9a-zA-z_\-,]*`, tag); err != nil || !matched {
108+
return fmt.Errorf("invalid tags: %s", tag)
109+
}
107110
}
108111

109112
hosts, err := GetHostsFromSSHConfig()
@@ -112,10 +115,18 @@ func (cfg *TaskConfig) InitTasks() error {
112115
}
113116

114117
for _, host := range hosts {
118+
if !host.MatchTags(cfg.Tags) {
119+
continue
120+
}
115121
task := &Task{
116-
Index: len(cfg.Tasks), Target: host,
117-
UploadSrc: cfg.UploadSrc, UploadDest: cfg.UploadDest,
118-
DownloadSrc: cfg.DownloadSrc, DownloadDest: cfg.DownloadDest,
122+
Index: len(cfg.Tasks),
123+
Target: host,
124+
RemoteListen: cfg.RemoteListen,
125+
ProxyServer: cfg.ProxyServer,
126+
UploadSrc: cfg.UploadSrc,
127+
UploadDest: cfg.UploadDest,
128+
DownloadSrc: cfg.DownloadSrc,
129+
DownloadDest: cfg.DownloadDest,
119130
}
120131
if err = task.ParseCommand(cfg.Command, cfg.Script, cfg.Module); err != nil {
121132
return err

0 commit comments

Comments
 (0)