Skip to content

Commit

Permalink
Merge pull request #64 from Tecnobutrul/allow-http-client-configuration
Browse files Browse the repository at this point in the history
Added support for http client configuration via command arguments
  • Loading branch information
speatzle authored Dec 20, 2024
2 parents d9703ff + 78ed21f commit 6033d6b
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 12 deletions.
49 changes: 39 additions & 10 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package cmd

import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"time"
Expand Down Expand Up @@ -60,6 +59,12 @@ func init() {
rootCmd.PersistentFlags().Uint("mfaRetrys", 3, "How often to retry TOTP Auth, only used in nointeractive modes")
rootCmd.PersistentFlags().Duration("mfaDelay", time.Second*10, "Delay between MFA Attempts, only used in noninteractive modes")

rootCmd.PersistentFlags().Bool("tlsSkipVerify", false, "Allow servers with self-signed certificates")
rootCmd.PersistentFlags().String("tlsClientPrivateKeyFile", "", "Client private key path for mtls")
rootCmd.PersistentFlags().String("tlsClientCertFile", "", "Client certificate path for mtls")
rootCmd.PersistentFlags().String("tlsClientPrivateKey", "", "Client private key for mtls")
rootCmd.PersistentFlags().String("tlsClientCert", "", "Client certificate for mtls")

viper.BindPFlag("debug", rootCmd.PersistentFlags().Lookup("debug"))
viper.BindPFlag("timeout", rootCmd.PersistentFlags().Lookup("timeout"))
viper.BindPFlag("serverAddress", rootCmd.PersistentFlags().Lookup("serverAddress"))
Expand All @@ -72,6 +77,22 @@ func init() {
viper.BindPFlag("mfaTotpOffset", rootCmd.PersistentFlags().Lookup("mfaTotpOffset"))
viper.BindPFlag("mfaRetrys", rootCmd.PersistentFlags().Lookup("mfaRetrys"))
viper.BindPFlag("mfaDelay", rootCmd.PersistentFlags().Lookup("mfaDelay"))

viper.BindPFlag("tlsSkipVerify", rootCmd.PersistentFlags().Lookup("tlsSkipVerify"))
viper.BindPFlag("tlsClientCert", rootCmd.PersistentFlags().Lookup("tlsClientCert"))
viper.BindPFlag("tlsClientPrivateKey", rootCmd.PersistentFlags().Lookup("tlsClientPrivateKey"))
}

func fileToContent(file, contentFlag string) {
if viper.GetBool("debug") {
fmt.Fprintln(os.Stderr, "Loading file:", file)
}
content, err := os.ReadFile(file)
if err != nil {
fmt.Fprintln(os.Stderr, "Error Loading File: ", err)
os.Exit(1)
}
viper.Set(contentFlag, string(content))
}

// initConfig reads in config file and ENV variables if set.
Expand Down Expand Up @@ -107,18 +128,26 @@ func initConfig() {
// Read in Private Key from File if userprivatekeyfile is set
userprivatekeyfile, err := rootCmd.PersistentFlags().GetString("userPrivateKeyFile")
if err == nil && userprivatekeyfile != "" {
if viper.GetBool("debug") {
fmt.Fprintln(os.Stderr, "Loading Private Key from File:", userprivatekeyfile)
}
content, err := ioutil.ReadFile(userprivatekeyfile)
if err != nil {
fmt.Fprintln(os.Stderr, "Error Loading Private Key from File: ", err)
os.Exit(1)
}
viper.Set("userprivatekey", string(content))
fileToContent(userprivatekeyfile, "userPrivateKey")
} else if err != nil && viper.GetBool("debug") {
fmt.Fprintln(os.Stderr, "Getting Private Key File Flag:", err)
}

// Read in Client Certificate Private Key from File if tlsClientPrivateKeyFile is set
tlsclientprivatekeyfile, err := rootCmd.PersistentFlags().GetString("tlsClientPrivateKeyFile")
if err == nil && tlsclientprivatekeyfile != "" {
fileToContent(tlsclientprivatekeyfile, "tlsClientPrivateKey")
} else if err != nil && viper.GetBool("debug") {
fmt.Fprintln(os.Stderr, "Getting Client Certificate Private key File Flag:", err)
}

// Read in Client Certificate from File if tlsClientCertFile is set
tlsclientcertfile, err := rootCmd.PersistentFlags().GetString("tlsClientCertFile")
if err == nil && tlsclientcertfile != "" {
fileToContent(tlsclientcertfile, "tlsClientCert")
} else if err != nil && viper.GetBool("debug") {
fmt.Fprintln(os.Stderr, "Getting Client Certificate File Flag:", err)
}
}

func SetVersionInfo(version, commit, date string, dirty bool) {
Expand Down
6 changes: 5 additions & 1 deletion cmd/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ var verifyCMD = &cobra.Command{
fmt.Println()
}

client, err := api.NewClient(nil, "", serverAddress, userPrivateKey, userPassword)
httpClient, err := util.GetHttpClient()
if err != nil {
return err
}
client, err := api.NewClient(httpClient, "", serverAddress, userPrivateKey, userPassword)
if err != nil {
return fmt.Errorf("Creating Client: %w", err)
}
Expand Down
6 changes: 5 additions & 1 deletion util/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ func GetClient(ctx context.Context) (*api.Client, error) {
fmt.Println()
}

client, err := api.NewClient(nil, "", serverAddress, userPrivateKey, userPassword)
httpClient, err := GetHttpClient()
if err != nil {
return nil, err
}
client, err := api.NewClient(httpClient, "", serverAddress, userPrivateKey, userPassword)
if err != nil {
return nil, fmt.Errorf("Creating Client: %w", err)
}
Expand Down
44 changes: 44 additions & 0 deletions util/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package util

import (
"crypto/tls"
"fmt"
"net/http"

"github.com/spf13/viper"
)

func GetClientCertificate() (tls.Certificate, error) {
cert := viper.GetString("tlsClientCert")
certExists := cert != ""
key := viper.GetString("tlsClientPrivateKey")
keyExists := key != ""
if !certExists && !keyExists {
return tls.Certificate{}, nil
}
if certExists && !keyExists {
return tls.Certificate{}, fmt.Errorf("Client TLS private key is empty, but client TLS cert was set.")
}
if !certExists && keyExists {
return tls.Certificate{}, fmt.Errorf("Client TLS cert is empty, but client TLS private key was set.")
}
return tls.X509KeyPair([]byte(cert), []byte(key))
}

func GetHttpClient() (*http.Client, error) {
tlsSkipVerify := viper.GetBool("tlsSkipVerify")
cert, err := GetClientCertificate()
if err != nil {
return nil, err
}
httpClient := http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: tlsSkipVerify,
},
},
}

return &httpClient, nil
}

0 comments on commit 6033d6b

Please sign in to comment.