Skip to content

Commit

Permalink
feat: make it possible to read secrets from files
Browse files Browse the repository at this point in the history
  • Loading branch information
tboerger committed Aug 10, 2023
1 parent f33559c commit ed721cf
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 31 deletions.
16 changes: 16 additions & 0 deletions pkg/action/helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package action

// boolP returns a boolean pointer.
func boolP(i bool) *bool {
return &i
}

// stringP returns a string pointer.
func stringP(i string) *string {
return &i
}

// slceP returns a slice pointer.
func sliceP(i []string) *[]string {
return &i
}
59 changes: 28 additions & 31 deletions pkg/action/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package action
import (
"context"
"crypto/tls"
"encoding/base64"
"io"
"net/http"
"os"
Expand Down Expand Up @@ -233,30 +232,6 @@ func handler(cfg *config.Config, logger log.Logger, client *github.Client) *chi.
return mux
}

func boolP(i bool) *bool {
return &i
}

func stringP(i string) *string {
return &i
}

func sliceP(i []string) *[]string {
return &i
}

func contentOrDecode(file string) ([]byte, error) {
decoded, err := base64.StdEncoding.DecodeString(
file,
)

if err != nil {
return os.ReadFile(file)
}

return decoded, nil
}

func useEnterprise(cfg *config.Config, _ log.Logger) bool {
return cfg.Target.BaseURL != ""
}
Expand All @@ -271,7 +246,7 @@ func getClient(cfg *config.Config, logger log.Logger) (*github.Client, error) {
}

if useApplication(cfg, logger) {
privateKey, err := contentOrDecode(cfg.Target.PrivateKey)
privateKey, err := config.Value(cfg.Target.PrivateKey)

if err != nil {
level.Error(logger).Log(
Expand All @@ -286,7 +261,7 @@ func getClient(cfg *config.Config, logger log.Logger) (*github.Client, error) {
http.DefaultTransport,
cfg.Target.AppID,
cfg.Target.InstallID,
privateKey,
[]byte(privateKey),
)

if err != nil {
Expand All @@ -305,12 +280,23 @@ func getClient(cfg *config.Config, logger log.Logger) (*github.Client, error) {
), nil
}

accessToken, err := config.Value(cfg.Target.Token)

if err != nil {
level.Error(logger).Log(
"msg", "Failed to read token",
"err", err,
)

return nil, err
}

return github.NewClient(
oauth2.NewClient(
context.Background(),
oauth2.StaticTokenSource(
&oauth2.Token{
AccessToken: cfg.Target.Token,
AccessToken: accessToken,
},
),
),
Expand All @@ -319,7 +305,7 @@ func getClient(cfg *config.Config, logger log.Logger) (*github.Client, error) {

func getEnterprise(cfg *config.Config, logger log.Logger) (*github.Client, error) {
if useApplication(cfg, logger) {
privateKey, err := contentOrDecode(cfg.Target.PrivateKey)
privateKey, err := config.Value(cfg.Target.PrivateKey)

if err != nil {
level.Error(logger).Log(
Expand All @@ -334,7 +320,7 @@ func getEnterprise(cfg *config.Config, logger log.Logger) (*github.Client, error
http.DefaultTransport,
cfg.Target.AppID,
cfg.Target.InstallID,
privateKey,
[]byte(privateKey),
)

if err != nil {
Expand Down Expand Up @@ -373,6 +359,17 @@ func getEnterprise(cfg *config.Config, logger log.Logger) (*github.Client, error
return client, err
}

accessToken, err := config.Value(cfg.Target.Token)

if err != nil {
level.Error(logger).Log(
"msg", "Failed to read token",
"err", err,
)

return nil, err
}

client, err := github.NewEnterpriseClient(
cfg.Target.BaseURL,
cfg.Target.BaseURL,
Expand All @@ -390,7 +387,7 @@ func getEnterprise(cfg *config.Config, logger log.Logger) (*github.Client, error
),
oauth2.StaticTokenSource(
&oauth2.Token{
AccessToken: cfg.Target.Token,
AccessToken: accessToken,
},
),
),
Expand Down
33 changes: 33 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package config

import (
"encoding/base64"
"fmt"
"os"
"strings"
"time"

"github.com/urfave/cli/v2"
Expand Down Expand Up @@ -65,3 +69,32 @@ type Config struct {
func Load() *Config {
return &Config{}
}

// Value returns the config value based on a DSN.
func Value(val string) (string, error) {
if strings.HasPrefix(val, "file://") {
content, err := os.ReadFile(
strings.TrimPrefix(val, "file://"),
)

if err != nil {
return "", fmt.Errorf("failed to parse secret file: %w", err)
}

return string(content), nil
}

if strings.HasPrefix(val, "base64://") {
content, err := base64.StdEncoding.DecodeString(
strings.TrimPrefix(val, "base64://"),
)

if err != nil {
return "", fmt.Errorf("failed to parse base64 value: %w", err)
}

return string(content), nil
}

return val, nil
}

0 comments on commit ed721cf

Please sign in to comment.