Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add cache for credential to reduce the probability that kpm would be considered a threat #388

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 135 additions & 7 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -19,7 +20,9 @@ import (
"github.com/otiai10/copy"
"golang.org/x/mod/module"
"kcl-lang.io/kcl-go/pkg/kcl"
"oras.land/oras-go/pkg/auth"
"oras.land/oras-go/v2"
remoteauth "oras.land/oras-go/v2/registry/remote/auth"

"kcl-lang.io/kpm/pkg/constants"
"kcl-lang.io/kpm/pkg/downloader"
Expand All @@ -41,6 +44,8 @@ type KpmClient struct {
logWriter io.Writer
// The downloader of the dependencies.
DepDownloader *downloader.DepDownloader
// credential store
credsClient *downloader.CredClient
// The home path of kpm for global configuration file and kcl package storage path.
homePath string
// The settings of kpm loaded from the global configuration file.
Expand Down Expand Up @@ -75,6 +80,33 @@ func (c *KpmClient) SetNoSumCheck(noSumCheck bool) {
c.noSumCheck = noSumCheck
}

// GetCredsClient will return the credential client.
func (c *KpmClient) GetCredsClient() (*downloader.CredClient, error) {
if c.credsClient == nil {
credCli, err := downloader.LoadCredentialFile(c.settings.CredentialsFile)
if err != nil {
return nil, err
}
c.credsClient = credCli
}
return c.credsClient, nil
}

// GetCredentials will return the credentials of the host.
func (c *KpmClient) GetCredentials(hostName string) (*remoteauth.Credential, error) {
credCli, err := c.GetCredsClient()
if err != nil {
return nil, err
}

creds, err := credCli.Credential(hostName)
if err != nil {
return nil, err
}

return creds, nil
}

// GetNoSumCheck will return the 'noSumCheck' flag.
func (c *KpmClient) GetNoSumCheck() bool {
return c.noSumCheck
Expand Down Expand Up @@ -953,7 +985,18 @@ func (c *KpmClient) FillDependenciesInfo(modFile *pkg.ModFile) error {

// AcquireTheLatestOciVersion will acquire the latest version of the OCI reference.
func (c *KpmClient) AcquireTheLatestOciVersion(ociSource downloader.Oci) (string, error) {
ociClient, err := oci.NewOciClient(ociSource.Reg, ociSource.Repo, &c.settings)
repoPath := utils.JoinPath(ociSource.Reg, ociSource.Repo)
cred, err := c.GetCredentials(ociSource.Reg)
if err != nil {
return "", err
}

ociClient, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return "", err
}
Expand Down Expand Up @@ -1098,11 +1141,16 @@ func (c *KpmClient) Download(dep *pkg.Dependency, homePath, localPath string) (*
// clean the temp dir.
defer os.RemoveAll(tmpDir)

credCli, err := c.GetCredsClient()
if err != nil {
return nil, err
}
err = c.DepDownloader.Download(*downloader.NewDownloadOptions(
downloader.WithLocalPath(tmpDir),
downloader.WithSource(dep.Source),
downloader.WithLogWriter(c.logWriter),
downloader.WithSettings(c.settings),
downloader.WithCredsClient(credCli),
))
if err != nil {
return nil, err
Expand Down Expand Up @@ -1276,10 +1324,22 @@ func (c *KpmClient) ParseKclModFile(kclPkg *pkg.KclPkg) (map[string]map[string]s

// LoadPkgFromOci will download the kcl package from the oci repository and return an `KclPkg`.
func (c *KpmClient) DownloadPkgFromOci(dep *downloader.Oci, localPath string) (*pkg.KclPkg, error) {
ociClient, err := oci.NewOciClient(dep.Reg, dep.Repo, &c.settings)
repoPath := utils.JoinPath(dep.Reg, dep.Repo)
cred, err := c.GetCredentials(dep.Reg)
if err != nil {
return nil, err
}

ociClient, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return nil, err
}

ociClient.SetLogWriter(c.logWriter)
// Select the latest tag, if the tag, the user inputed, is empty.
var tagSelected string
Expand Down Expand Up @@ -1478,7 +1538,18 @@ func (c *KpmClient) PullFromOci(localPath, source, tag string) error {

// PushToOci will push a kcl package to oci registry.
func (c *KpmClient) PushToOci(localPath string, ociOpts *opt.OciOptions) error {
ociCli, err := oci.NewOciClient(ociOpts.Reg, ociOpts.Repo, &c.settings)
repoPath := utils.JoinPath(ociOpts.Reg, ociOpts.Repo)
cred, err := c.GetCredentials(ociOpts.Reg)
if err != nil {
return err
}

ociCli, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return err
}
Expand All @@ -1504,12 +1575,46 @@ func (c *KpmClient) PushToOci(localPath string, ociOpts *opt.OciOptions) error {

// LoginOci will login to the oci registry.
func (c *KpmClient) LoginOci(hostname, username, password string) error {
return oci.Login(hostname, username, password, &c.settings)

credCli, err := c.GetCredsClient()
if err != nil {
return err
}

err = credCli.GetAuthClient().LoginWithOpts(
[]auth.LoginOption{
auth.WithLoginHostname(hostname),
auth.WithLoginUsername(username),
auth.WithLoginSecret(password),
}...,
)

if err != nil {
return reporter.NewErrorEvent(
reporter.FailedLogin,
err,
fmt.Sprintf("failed to login '%s', please check registry, username and password is valid", hostname),
)
}

return nil
}

// LogoutOci will logout from the oci registry.
func (c *KpmClient) LogoutOci(hostname string) error {
return oci.Logout(hostname, &c.settings)

credCli, err := c.GetCredsClient()
if err != nil {
return err
}

err = credCli.GetAuthClient().Logout(context.Background(), hostname)

if err != nil {
return reporter.NewErrorEvent(reporter.FailedLogout, err, fmt.Sprintf("failed to logout '%s'", hostname))
}

return nil
}

// ParseOciRef will parser '<repo_name>:<repo_tag>' into an 'OciOptions'.
Expand Down Expand Up @@ -1753,7 +1858,18 @@ func (c *KpmClient) pullTarFromOci(localPath string, ociOpts *opt.OciOptions) er
return reporter.NewErrorEvent(reporter.Bug, err)
}

ociCli, err := oci.NewOciClient(ociOpts.Reg, ociOpts.Repo, &c.settings)
repoPath := utils.JoinPath(ociOpts.Reg, ociOpts.Repo)
cred, err := c.GetCredentials(ociOpts.Reg)
if err != nil {
return err
}

ociCli, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return err
}
Expand Down Expand Up @@ -1790,7 +1906,19 @@ func (c *KpmClient) pullTarFromOci(localPath string, ociOpts *opt.OciOptions) er

// FetchOciManifestConfIntoJsonStr will fetch the oci manifest config of the kcl package from the oci registry and return it into json string.
func (c *KpmClient) FetchOciManifestIntoJsonStr(opts opt.OciFetchOptions) (string, error) {
ociCli, err := oci.NewOciClient(opts.Reg, opts.Repo, &c.settings)

repoPath := utils.JoinPath(opts.Reg, opts.Repo)
cred, err := c.GetCredentials(opts.Reg)
if err != nil {
return "", err
}

ociCli, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return "", err
}
Expand Down
6 changes: 6 additions & 0 deletions pkg/client/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,18 @@ func (rv *RemoteVisitor) Visit(s *downloader.Source, v visitFunc) error {
tmpDir = filepath.Join(tmpDir, constants.GitScheme)
}

credCli, err := rv.kpmcli.GetCredsClient()
if err != nil {
return err
}

defer os.RemoveAll(tmpDir)
err = rv.kpmcli.DepDownloader.Download(*downloader.NewDownloadOptions(
downloader.WithLocalPath(tmpDir),
downloader.WithSource(*s),
downloader.WithLogWriter(rv.kpmcli.GetLogWriter()),
downloader.WithSettings(*rv.kpmcli.GetSettings()),
downloader.WithCredsClient(credCli),
))

if err != nil {
Expand Down
50 changes: 50 additions & 0 deletions pkg/downloader/credential.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package downloader

import (
"fmt"

dockerauth "oras.land/oras-go/pkg/auth/docker"
remoteauth "oras.land/oras-go/v2/registry/remote/auth"
)

// CredClient is the client to get the credentials.
type CredClient struct {
credsClient *dockerauth.Client
}

// LoadCredentialFile loads the credential file and return the CredClient.
func LoadCredentialFile(filepath string) (*CredClient, error) {
authClient, err := dockerauth.NewClientWithDockerFallback(filepath)
if err != nil {
return nil, err
}
dockerAuthClient, ok := authClient.(*dockerauth.Client)
if !ok {
return nil, fmt.Errorf("authClient is not *docker.Client type")
}

return &CredClient{
credsClient: dockerAuthClient,
}, nil
}

// GetAuthClient returns the auth client.
func (cred *CredClient) GetAuthClient() *dockerauth.Client {
return cred.credsClient
}

// Credential will reture the credential info cache in CredClient
func (cred *CredClient) Credential(hostName string) (*remoteauth.Credential, error) {
if len(hostName) == 0 {
return nil, fmt.Errorf("hostName is empty")
}
username, password, err := cred.credsClient.Credential(hostName)
if err != nil {
return nil, err
}

return &remoteauth.Credential{
Username: username,
Password: password,
}, nil
}
29 changes: 28 additions & 1 deletion pkg/downloader/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"kcl-lang.io/kpm/pkg/reporter"
"kcl-lang.io/kpm/pkg/settings"
"kcl-lang.io/kpm/pkg/utils"
remoteauth "oras.land/oras-go/v2/registry/remote/auth"
)

// DownloadOptions is the options for downloading a package.
Expand All @@ -25,10 +26,18 @@ type DownloadOptions struct {
Settings settings.Settings
// LogWriter is the writer to write the log.
LogWriter io.Writer
// credsClient is the client to get the credentials.
credsClient *CredClient
}

type Option func(*DownloadOptions)

func WithCredsClient(credsClient *CredClient) Option {
return func(do *DownloadOptions) {
do.credsClient = credsClient
}
}

func WithLogWriter(logWriter io.Writer) Option {
return func(do *DownloadOptions) {
do.LogWriter = logWriter
Expand Down Expand Up @@ -125,7 +134,25 @@ func (d *OciDownloader) Download(opts DownloadOptions) error {

localPath := opts.LocalPath

ociCli, err := oci.NewOciClient(ociSource.Reg, ociSource.Repo, &opts.Settings)
repoPath := utils.JoinPath(ociSource.Reg, ociSource.Repo)

var cred *remoteauth.Credential
var err error
if opts.credsClient != nil {
cred, err = opts.credsClient.Credential(ociSource.Reg)
if err != nil {
return err
}
} else {
cred = &remoteauth.Credential{}
}

ociCli, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(opts.Settings.DefaultOciPlainHttp()),
)

if err != nil {
return err
}
Expand Down
Loading
Loading