Skip to content

Commit

Permalink
Merge pull request #388 from zong-zhe/refactor-oci-auth
Browse files Browse the repository at this point in the history
feat: add cache for credential to reduce the probability that kpm would be considered a threat
Peefy authored Jul 18, 2024
2 parents b65d5ed + 74b0e81 commit 635551d
Showing 8 changed files with 291 additions and 40 deletions.
142 changes: 135 additions & 7 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"encoding/json"
"fmt"
"io"
@@ -19,7 +20,9 @@ import (
"github.com/otiai10/copy"
"golang.org/x/mod/module"
"kcl-lang.io/kcl-go/pkg/kcl"
"oras.land/oras-go/pkg/auth"
"oras.land/oras-go/v2"
remoteauth "oras.land/oras-go/v2/registry/remote/auth"

"kcl-lang.io/kpm/pkg/constants"
"kcl-lang.io/kpm/pkg/downloader"
@@ -41,6 +44,8 @@ type KpmClient struct {
logWriter io.Writer
// The downloader of the dependencies.
DepDownloader *downloader.DepDownloader
// credential store
credsClient *downloader.CredClient
// The home path of kpm for global configuration file and kcl package storage path.
homePath string
// The settings of kpm loaded from the global configuration file.
@@ -75,6 +80,33 @@ func (c *KpmClient) SetNoSumCheck(noSumCheck bool) {
c.noSumCheck = noSumCheck
}

// GetCredsClient will return the credential client.
func (c *KpmClient) GetCredsClient() (*downloader.CredClient, error) {
if c.credsClient == nil {
credCli, err := downloader.LoadCredentialFile(c.settings.CredentialsFile)
if err != nil {
return nil, err
}
c.credsClient = credCli
}
return c.credsClient, nil
}

// GetCredentials will return the credentials of the host.
func (c *KpmClient) GetCredentials(hostName string) (*remoteauth.Credential, error) {
credCli, err := c.GetCredsClient()
if err != nil {
return nil, err
}

creds, err := credCli.Credential(hostName)
if err != nil {
return nil, err
}

return creds, nil
}

// GetNoSumCheck will return the 'noSumCheck' flag.
func (c *KpmClient) GetNoSumCheck() bool {
return c.noSumCheck
@@ -953,7 +985,18 @@ func (c *KpmClient) FillDependenciesInfo(modFile *pkg.ModFile) error {

// AcquireTheLatestOciVersion will acquire the latest version of the OCI reference.
func (c *KpmClient) AcquireTheLatestOciVersion(ociSource downloader.Oci) (string, error) {
ociClient, err := oci.NewOciClient(ociSource.Reg, ociSource.Repo, &c.settings)
repoPath := utils.JoinPath(ociSource.Reg, ociSource.Repo)
cred, err := c.GetCredentials(ociSource.Reg)
if err != nil {
return "", err
}

ociClient, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return "", err
}
@@ -1098,11 +1141,16 @@ func (c *KpmClient) Download(dep *pkg.Dependency, homePath, localPath string) (*
// clean the temp dir.
defer os.RemoveAll(tmpDir)

credCli, err := c.GetCredsClient()
if err != nil {
return nil, err
}
err = c.DepDownloader.Download(*downloader.NewDownloadOptions(
downloader.WithLocalPath(tmpDir),
downloader.WithSource(dep.Source),
downloader.WithLogWriter(c.logWriter),
downloader.WithSettings(c.settings),
downloader.WithCredsClient(credCli),
))
if err != nil {
return nil, err
@@ -1276,10 +1324,22 @@ func (c *KpmClient) ParseKclModFile(kclPkg *pkg.KclPkg) (map[string]map[string]s

// LoadPkgFromOci will download the kcl package from the oci repository and return an `KclPkg`.
func (c *KpmClient) DownloadPkgFromOci(dep *downloader.Oci, localPath string) (*pkg.KclPkg, error) {
ociClient, err := oci.NewOciClient(dep.Reg, dep.Repo, &c.settings)
repoPath := utils.JoinPath(dep.Reg, dep.Repo)
cred, err := c.GetCredentials(dep.Reg)
if err != nil {
return nil, err
}

ociClient, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return nil, err
}

ociClient.SetLogWriter(c.logWriter)
// Select the latest tag, if the tag, the user inputed, is empty.
var tagSelected string
@@ -1478,7 +1538,18 @@ func (c *KpmClient) PullFromOci(localPath, source, tag string) error {

// PushToOci will push a kcl package to oci registry.
func (c *KpmClient) PushToOci(localPath string, ociOpts *opt.OciOptions) error {
ociCli, err := oci.NewOciClient(ociOpts.Reg, ociOpts.Repo, &c.settings)
repoPath := utils.JoinPath(ociOpts.Reg, ociOpts.Repo)
cred, err := c.GetCredentials(ociOpts.Reg)
if err != nil {
return err
}

ociCli, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return err
}
@@ -1504,12 +1575,46 @@ func (c *KpmClient) PushToOci(localPath string, ociOpts *opt.OciOptions) error {

// LoginOci will login to the oci registry.
func (c *KpmClient) LoginOci(hostname, username, password string) error {
return oci.Login(hostname, username, password, &c.settings)

credCli, err := c.GetCredsClient()
if err != nil {
return err
}

err = credCli.GetAuthClient().LoginWithOpts(
[]auth.LoginOption{
auth.WithLoginHostname(hostname),
auth.WithLoginUsername(username),
auth.WithLoginSecret(password),
}...,
)

if err != nil {
return reporter.NewErrorEvent(
reporter.FailedLogin,
err,
fmt.Sprintf("failed to login '%s', please check registry, username and password is valid", hostname),
)
}

return nil
}

// LogoutOci will logout from the oci registry.
func (c *KpmClient) LogoutOci(hostname string) error {
return oci.Logout(hostname, &c.settings)

credCli, err := c.GetCredsClient()
if err != nil {
return err
}

err = credCli.GetAuthClient().Logout(context.Background(), hostname)

if err != nil {
return reporter.NewErrorEvent(reporter.FailedLogout, err, fmt.Sprintf("failed to logout '%s'", hostname))
}

return nil
}

// ParseOciRef will parser '<repo_name>:<repo_tag>' into an 'OciOptions'.
@@ -1753,7 +1858,18 @@ func (c *KpmClient) pullTarFromOci(localPath string, ociOpts *opt.OciOptions) er
return reporter.NewErrorEvent(reporter.Bug, err)
}

ociCli, err := oci.NewOciClient(ociOpts.Reg, ociOpts.Repo, &c.settings)
repoPath := utils.JoinPath(ociOpts.Reg, ociOpts.Repo)
cred, err := c.GetCredentials(ociOpts.Reg)
if err != nil {
return err
}

ociCli, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return err
}
@@ -1790,7 +1906,19 @@ func (c *KpmClient) pullTarFromOci(localPath string, ociOpts *opt.OciOptions) er

// FetchOciManifestConfIntoJsonStr will fetch the oci manifest config of the kcl package from the oci registry and return it into json string.
func (c *KpmClient) FetchOciManifestIntoJsonStr(opts opt.OciFetchOptions) (string, error) {
ociCli, err := oci.NewOciClient(opts.Reg, opts.Repo, &c.settings)

repoPath := utils.JoinPath(opts.Reg, opts.Repo)
cred, err := c.GetCredentials(opts.Reg)
if err != nil {
return "", err
}

ociCli, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return "", err
}
6 changes: 6 additions & 0 deletions pkg/client/visitor.go
Original file line number Diff line number Diff line change
@@ -150,12 +150,18 @@ func (rv *RemoteVisitor) Visit(s *downloader.Source, v visitFunc) error {
tmpDir = filepath.Join(tmpDir, constants.GitScheme)
}

credCli, err := rv.kpmcli.GetCredsClient()
if err != nil {
return err
}

defer os.RemoveAll(tmpDir)
err = rv.kpmcli.DepDownloader.Download(*downloader.NewDownloadOptions(
downloader.WithLocalPath(tmpDir),
downloader.WithSource(*s),
downloader.WithLogWriter(rv.kpmcli.GetLogWriter()),
downloader.WithSettings(*rv.kpmcli.GetSettings()),
downloader.WithCredsClient(credCli),
))

if err != nil {
50 changes: 50 additions & 0 deletions pkg/downloader/credential.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package downloader

import (
"fmt"

dockerauth "oras.land/oras-go/pkg/auth/docker"
remoteauth "oras.land/oras-go/v2/registry/remote/auth"
)

// CredClient is the client to get the credentials.
type CredClient struct {
credsClient *dockerauth.Client
}

// LoadCredentialFile loads the credential file and return the CredClient.
func LoadCredentialFile(filepath string) (*CredClient, error) {
authClient, err := dockerauth.NewClientWithDockerFallback(filepath)
if err != nil {
return nil, err
}
dockerAuthClient, ok := authClient.(*dockerauth.Client)
if !ok {
return nil, fmt.Errorf("authClient is not *docker.Client type")
}

return &CredClient{
credsClient: dockerAuthClient,
}, nil
}

// GetAuthClient returns the auth client.
func (cred *CredClient) GetAuthClient() *dockerauth.Client {
return cred.credsClient
}

// Credential will reture the credential info cache in CredClient
func (cred *CredClient) Credential(hostName string) (*remoteauth.Credential, error) {
if len(hostName) == 0 {
return nil, fmt.Errorf("hostName is empty")
}
username, password, err := cred.credsClient.Credential(hostName)
if err != nil {
return nil, err
}

return &remoteauth.Credential{
Username: username,
Password: password,
}, nil
}
29 changes: 28 additions & 1 deletion pkg/downloader/downloader.go
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@ import (
"kcl-lang.io/kpm/pkg/reporter"
"kcl-lang.io/kpm/pkg/settings"
"kcl-lang.io/kpm/pkg/utils"
remoteauth "oras.land/oras-go/v2/registry/remote/auth"
)

// DownloadOptions is the options for downloading a package.
@@ -25,10 +26,18 @@ type DownloadOptions struct {
Settings settings.Settings
// LogWriter is the writer to write the log.
LogWriter io.Writer
// credsClient is the client to get the credentials.
credsClient *CredClient
}

type Option func(*DownloadOptions)

func WithCredsClient(credsClient *CredClient) Option {
return func(do *DownloadOptions) {
do.credsClient = credsClient
}
}

func WithLogWriter(logWriter io.Writer) Option {
return func(do *DownloadOptions) {
do.LogWriter = logWriter
@@ -125,7 +134,25 @@ func (d *OciDownloader) Download(opts DownloadOptions) error {

localPath := opts.LocalPath

ociCli, err := oci.NewOciClient(ociSource.Reg, ociSource.Repo, &opts.Settings)
repoPath := utils.JoinPath(ociSource.Reg, ociSource.Repo)

var cred *remoteauth.Credential
var err error
if opts.credsClient != nil {
cred, err = opts.credsClient.Credential(ociSource.Reg)
if err != nil {
return err
}
} else {
cred = &remoteauth.Credential{}
}

ociCli, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(opts.Settings.DefaultOciPlainHttp()),
)

if err != nil {
return err
}
96 changes: 69 additions & 27 deletions pkg/oci/oci.go
Original file line number Diff line number Diff line change
@@ -96,9 +96,44 @@ type OciClient struct {
repo *remote.Repository
ctx *context.Context
logWriter io.Writer
cred *remoteauth.Credential
PullOciOptions *PullOciOptions
}

// OciClientOption configures how we set up the OciClient
type OciClientOption func(*OciClient) error

// WithRepoPath sets the repo path of the OciClient
func WithRepoPath(repoPath string) OciClientOption {
return func(c *OciClient) error {
var err error
c.repo, err = remote.NewRepository(repoPath)
if err != nil {
return fmt.Errorf("repository '%s' not found", repoPath)
}
return nil
}
}

// WithCredential sets the credential of the OciClient
func WithCredential(credential *remoteauth.Credential) OciClientOption {
return func(c *OciClient) error {
c.cred = credential
return nil
}
}

// WithPlainHttp sets the plain http of the OciClient
func WithPlainHttp(plainHttp bool) OciClientOption {
return func(c *OciClient) error {
if c.repo == nil {
return fmt.Errorf("repo is nil")
}
c.repo.PlainHTTP = plainHttp
return nil
}
}

type PullOciOptions struct {
Platform string
CopyOpts *oras.CopyOptions
@@ -112,40 +147,25 @@ func (ociClient *OciClient) GetReference() string {
return ociClient.repo.Reference.String()
}

// NewOciClient will new an OciClient.
// regName is the registry. e.g. ghcr.io or docker.io.
// repoName is the repo name on registry.
func NewOciClient(regName, repoName string, settings *settings.Settings) (*OciClient, error) {
repoPath := utils.JoinPath(regName, repoName)
repo, err := remote.NewRepository(repoPath)

if err != nil {
return nil, reporter.NewErrorEvent(
reporter.RepoNotFound,
err,
fmt.Sprintf("repository '%s' not found", repoPath),
)
// NewOciClientWithOpts will new an OciClient with options.
func NewOciClientWithOpts(opts ...OciClientOption) (*OciClient, error) {
client := &OciClient{}
for _, opt := range opts {
err := opt(client)
if err != nil {
return nil, err
}
}
ctx := context.Background()
repo.PlainHTTP = settings.DefaultOciPlainHttp()

// Login
credential, err := loadCredential(regName, settings)
if err != nil {
return nil, reporter.NewErrorEvent(
reporter.FailedLoadCredential,
err,
fmt.Sprintf("failed to load credential for '%s' from '%s'.", regName, settings.CredentialsFile),
)
}
repo.Client = &remoteauth.Client{
ctx := context.Background()
client.repo.Client = &remoteauth.Client{
Client: retry.DefaultClient,
Cache: remoteauth.DefaultCache,
Credential: remoteauth.StaticCredential(repo.Reference.Host(), *credential),
Credential: remoteauth.StaticCredential(client.repo.Reference.Host(), *client.cred),
}

return &OciClient{
repo: repo,
repo: client.repo,
ctx: &ctx,
PullOciOptions: &PullOciOptions{
CopyOpts: &oras.CopyOptions{
@@ -157,6 +177,28 @@ func NewOciClient(regName, repoName string, settings *settings.Settings) (*OciCl
}, nil
}

// NewOciClient will new an OciClient.
// regName is the registry. e.g. ghcr.io or docker.io.
// repoName is the repo name on registry.
// Deprecated: use NewOciClientWithOpts instead.
func NewOciClient(regName, repoName string, settings *settings.Settings) (*OciClient, error) {
// Login
credential, err := loadCredential(regName, settings)
if err != nil {
return nil, reporter.NewErrorEvent(
reporter.FailedLoadCredential,
err,
fmt.Sprintf("failed to load credential for '%s' from '%s'.", regName, settings.CredentialsFile),
)
}

return NewOciClientWithOpts(
WithRepoPath(utils.JoinPath(regName, repoName)),
WithCredential(credential),
WithPlainHttp(settings.DefaultOciPlainHttp()),
)
}

// The default limit of the store size is 64 MiB.
const DEFAULT_LIMIT_STORE_SIZE = 64 * 1024 * 1024

Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
repository 'invalid_url/' not found
invalid reference: invalid repository ""
repository 'invalid_url/' not found
Original file line number Diff line number Diff line change
@@ -1 +1 @@
kpm run oci://invalid_rul
kpm run oci://invalid_url
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
repository 'invalid_rul/' not found
invalid reference: invalid repository ""
repository 'invalid_url/' not found

0 comments on commit 635551d

Please sign in to comment.