From aa66c698c85df45cc8526e73289636741f90ba15 Mon Sep 17 00:00:00 2001 From: windvalley Date: Wed, 10 Jan 2024 11:20:39 +0800 Subject: [PATCH] refactor: optimize errors handler --- internal/cmd/command.go | 4 ++-- internal/cmd/fetch.go | 2 +- internal/cmd/push.go | 10 +++++----- internal/cmd/root.go | 13 ++++++++----- internal/cmd/script.go | 6 +++--- internal/cmd/vault/decrypt.go | 4 ++-- internal/cmd/vault/decrypt_file.go | 8 ++++---- internal/cmd/vault/encrypt.go | 4 ++-- internal/cmd/vault/encrypt_file.go | 10 +++++++--- internal/cmd/vault/vault.go | 18 ++++++++++-------- internal/cmd/vault/view.go | 11 +++++++---- internal/pkg/aes/aes.go | 2 +- internal/pkg/sshtask/sshtask.go | 11 ++++++----- pkg/util/cobra.go | 4 ++-- pkg/util/error.go | 16 +++++----------- scripts/makefiles/tools.makefile | 2 +- 16 files changed, 66 insertions(+), 59 deletions(-) diff --git a/internal/cmd/command.go b/internal/cmd/command.go index 3027774..6ca42f8 100644 --- a/internal/cmd/command.go +++ b/internal/cmd/command.go @@ -73,7 +73,7 @@ Execute commands on target hosts.`, Example: commandCmdExamples, PreRun: func(cmd *cobra.Command, args []string) { if errs := configflags.Config.Validate(); len(errs) != 0 { - util.CheckErr(errs) + util.PrintErrExit(errs) } if noSafeCheck { @@ -87,7 +87,7 @@ Execute commands on target hosts.`, } if err := checkCommand(shellCommand, configflags.Config.Run.CommandBlacklist); err != nil { - util.CheckErr(err) + util.PrintErrExit(err) } } }, diff --git a/internal/cmd/fetch.go b/internal/cmd/fetch.go index 53c54a7..947d1c9 100644 --- a/internal/cmd/fetch.go +++ b/internal/cmd/fetch.go @@ -58,7 +58,7 @@ Copy files and dirs from target hosts to local.`, Find more examples at: https://github.com/windvalley/gossh/blob/main/docs/fetch.md`, PreRun: func(cmd *cobra.Command, args []string) { if errs := configflags.Config.Validate(); len(errs) != 0 { - util.CheckErr(errs) + util.PrintErrExit(errs) } }, Run: func(cmd *cobra.Command, args []string) { diff --git a/internal/cmd/push.go b/internal/cmd/push.go index e416d62..e7fc6b8 100644 --- a/internal/cmd/push.go +++ b/internal/cmd/push.go @@ -64,14 +64,14 @@ Copy local files and dirs to target hosts.`, Find more examples at: https://github.com/windvalley/gossh/blob/main/docs/push.md`, PreRun: func(cmd *cobra.Command, args []string) { if errs := configflags.Config.Validate(); len(errs) != 0 { - util.CheckErr(errs) + util.PrintErrExit(errs) } if len(files) != 0 { for _, f := range files { _, err := os.Stat(f) if err != nil { - util.CheckErr(err) + util.PrintErrExit(err) } } } @@ -88,7 +88,7 @@ Copy local files and dirs to target hosts.`, workDir, err := os.Getwd() if err != nil { - util.CheckErr(err) + util.PrintErrExit(err) } for _, f := range files { @@ -97,12 +97,12 @@ Copy local files and dirs to target hosts.`, zipFile := path.Join(workDir, zipName) if err := util.Zip(strings.TrimSuffix(f, string(os.PathSeparator)), zipFile); err != nil { - util.CheckErr(err) + util.PrintErrExit(err) } stat, err := os.Stat(zipFile) if err != nil { - util.CheckErr(err) + util.PrintErrExit(err) } //nolint:gomnd log.Debugf("zip file '%s' size: %d MB", zipFile, stat.Size()/1024/1024) diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 08e9eab..e0f3cd8 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -90,7 +90,9 @@ func initConfig() { } else { // Find home directory. home, err := os.UserHomeDir() - util.CheckErr(err) + if err != nil { + util.PrintErrExit(err) + } // Search the default configuration file. viper.AddConfigPath(".") @@ -99,21 +101,22 @@ func initConfig() { viper.SetConfigName(".gossh") } - viper.AutomaticEnv() // read in environment variables that match + // Read in environment variables that match. + viper.AutomaticEnv() // If a config file is found, read it in. _ = viper.ReadInConfig() if err := viper.BindPFlags(rootCmd.PersistentFlags()); err != nil { - util.CheckErr(err) + util.PrintErrExit(err) } if err := viper.Unmarshal(&configflags.Config); err != nil { - util.CheckErr(err) + util.PrintErrExit(err) } if err := configflags.Config.Complete(); err != nil { - util.CheckErr(err) + util.PrintErrExit(err) } } diff --git a/internal/cmd/script.go b/internal/cmd/script.go index 43ea5be..7a1a610 100644 --- a/internal/cmd/script.go +++ b/internal/cmd/script.go @@ -57,11 +57,11 @@ Execute a local shell script on target hosts.`, Find more examples at: https://github.com/windvalley/gossh/blob/main/docs/script.md`, PreRun: func(cmd *cobra.Command, args []string) { if errs := configflags.Config.Validate(); len(errs) != 0 { - util.CheckErr(errs) + util.PrintErrExit(errs) } if scriptFile != "" && !util.FileExists(scriptFile) { - util.CheckErr(fmt.Sprintf("script '%s' not found", scriptFile)) + util.PrintErrExit(fmt.Sprintf("script '%s' not found", scriptFile)) } if noSafeCheck { @@ -75,7 +75,7 @@ Execute a local shell script on target hosts.`, } if err := checkScript(scriptFile, configflags.Config.Run.CommandBlacklist); err != nil { - util.CheckErr(err) + util.PrintErrExit(err) } } }, diff --git a/internal/cmd/vault/decrypt.go b/internal/cmd/vault/decrypt.go index 4e76382..0c783fa 100644 --- a/internal/cmd/vault/decrypt.go +++ b/internal/cmd/vault/decrypt.go @@ -55,7 +55,7 @@ Decrypt content encrypted by vault.`, } if !aes.IsAES256CipherText(args[0]) { - util.CheckErr(fmt.Sprintf("'%s' is not vault encrypted content", args[0])) + util.PrintErrExit(fmt.Sprintf("'%s' is not vault encrypted content", args[0])) } return nil @@ -65,8 +65,8 @@ Decrypt content encrypted by vault.`, plainText, err := aes.AES256Decode(args[0], vaultPass) if err != nil { err = fmt.Errorf("decrypt failed: %w", err) + util.PrintErrExit(err) } - util.CheckErr(err) fmt.Printf("\n%s\n", plainText) }, diff --git a/internal/cmd/vault/decrypt_file.go b/internal/cmd/vault/decrypt_file.go index 77722bd..44d35b6 100644 --- a/internal/cmd/vault/decrypt_file.go +++ b/internal/cmd/vault/decrypt_file.go @@ -34,8 +34,6 @@ import ( var deOutputFile string -// decryptFileCmd represents the vault decrypt-file command -// //nolint:dupl var decryptFileCmd = &cobra.Command{ Use: "decrypt-file FILENAME", @@ -66,7 +64,7 @@ Decrypt vault encrypted file.`, } if !util.FileExists(args[0]) { - util.CheckErr(fmt.Sprintf("file '%s' not found", args[0])) + util.PrintErrExit(fmt.Sprintf("file '%s' not found", args[0])) } return nil @@ -77,7 +75,9 @@ Decrypt vault encrypted file.`, file := args[0] content, err := decryptFile(file, vaultPass) - util.CheckErr(err) + if err != nil { + util.PrintErrExit(err) + } handleOutput(content, file, deOutputFile) diff --git a/internal/cmd/vault/encrypt.go b/internal/cmd/vault/encrypt.go index e8b1fe0..3ee00f0 100644 --- a/internal/cmd/vault/encrypt.go +++ b/internal/cmd/vault/encrypt.go @@ -61,14 +61,14 @@ Encrypt sensitive content.`, plainPassword, err := getPlainPassword(args) if err != nil { err = fmt.Errorf("get plaintext to be encrypted failed: %s", err) + util.PrintErrExit(err) } - util.CheckErr(err) encryptContent, err := aes.AES256Encode(plainPassword, vaultPass) if err != nil { err = fmt.Errorf("encrypt failed: %w", err) + util.PrintErrExit(err) } - util.CheckErr(err) fmt.Printf("\n%s\n", encryptContent) }, diff --git a/internal/cmd/vault/encrypt_file.go b/internal/cmd/vault/encrypt_file.go index 1374302..54077bd 100644 --- a/internal/cmd/vault/encrypt_file.go +++ b/internal/cmd/vault/encrypt_file.go @@ -67,7 +67,7 @@ Encrypt a file.`, } if !util.FileExists(args[0]) { - util.CheckErr(fmt.Sprintf("file '%s' not found", args[0])) + util.PrintErrExit(fmt.Sprintf("file '%s' not found", args[0])) } return nil @@ -78,7 +78,9 @@ Encrypt a file.`, file := args[0] content, err := encryptFile(file, vaultPass) - util.CheckErr(err) + if err != nil { + util.PrintErrExit(err) + } handleOutput(content, file, outputFile) @@ -108,7 +110,9 @@ func handleOutput(content, originalFile, newFile string) { err = writeContentToOriFile(originalFile, content) } - util.CheckErr(err) + if err != nil { + util.PrintErrExit(err) + } } func encryptFile(file, vaultPass string) (string, error) { diff --git a/internal/cmd/vault/vault.go b/internal/cmd/vault/vault.go index 3860ea3..24b6a11 100644 --- a/internal/cmd/vault/vault.go +++ b/internal/cmd/vault/vault.go @@ -103,7 +103,7 @@ func getVaultConfirmPassword() string { prompt := "New Vault password: " password, err := getConfirmPasswordFromPrompt(prompt) if err != nil { - util.CheckErr(fmt.Sprintf("get vault password from terminal prompt failed: %s", err)) + util.PrintErrExit(fmt.Sprintf("get vault password from terminal prompt failed: %s", err)) } log.Debugf("Vault: confirmed vault password that from terminal prompt") @@ -124,7 +124,7 @@ func GetVaultPassword() string { for { password, err = getPasswordFromPrompt(prompt) if err != nil { - util.CheckErr(fmt.Sprintf("get vault password from terminal prompt '%s' failed: %s", prompt, err)) + util.PrintErrExit(fmt.Sprintf("get vault password from terminal prompt '%s' failed: %s", prompt, err)) } if password != "" { break @@ -142,20 +142,22 @@ func getVaultPasswordFromFile() string { vaultPassFile := configflags.Config.Auth.VaultPassFile if vaultPassFile != "" { ok, err := isExectuable(vaultPassFile) - util.CheckErr(err) + if err != nil { + util.PrintErrExit(err) + } if ok { bin := fmt.Sprintf("./%s", vaultPassFile) out, err1 := exec.Command(bin).Output() if err1 != nil { - util.CheckErr(fmt.Errorf( + util.PrintErrExit(fmt.Errorf( "problem executing file '%s': %s, if this is not a executable file, "+ "remove the executable bit from the file", vaultPassFile, err1)) } vaultPass := strings.TrimSpace(string(out)) if vaultPass == "" { - util.CheckErr(fmt.Sprintf( + util.PrintErrExit(fmt.Sprintf( "problem executing file '%s': output cannot be empty, if this is not a script, "+ "remove the executable bit from the file", vaultPassFile)) } @@ -168,16 +170,16 @@ func getVaultPasswordFromFile() string { passwordContent, err := os.ReadFile(vaultPassFile) if err != nil { err = fmt.Errorf("read vault password file '%s' failed: %w", vaultPassFile, err) + util.PrintErrExit(err) } - util.CheckErr(err) vaultPass := strings.TrimSpace(string(passwordContent)) if vaultPass == "" { - util.CheckErr("vault password file cannot be empty") + util.PrintErrExit("vault password file cannot be empty") } if strings.HasPrefix(vaultPass, "#!/") { - util.CheckErr(fmt.Sprintf( + util.PrintErrExit(fmt.Sprintf( "'%s' looks like a script file, please add the executable bit to this file", vaultPassFile, )) diff --git a/internal/cmd/vault/view.go b/internal/cmd/vault/view.go index ad8c3d9..dd391fc 100644 --- a/internal/cmd/vault/view.go +++ b/internal/cmd/vault/view.go @@ -54,7 +54,7 @@ View vault encrypted file.`, } if !util.FileExists(args[0]) { - util.CheckErr(fmt.Sprintf("file '%s' not found", args[0])) + util.PrintErrExit(fmt.Sprintf("file '%s' not found", args[0])) } return nil @@ -65,9 +65,12 @@ View vault encrypted file.`, file := args[0] decryptContent, err := decryptFile(file, vaultPass) - util.CheckErr(err) + if err != nil { + util.PrintErrExit(err) + } - err = util.LessContent(decryptContent) - util.CheckErr(err) + if err := util.LessContent(decryptContent); err != nil { + util.PrintErrExit(err) + } }, } diff --git a/internal/pkg/aes/aes.go b/internal/pkg/aes/aes.go index bb45a55..139ff4d 100644 --- a/internal/pkg/aes/aes.go +++ b/internal/pkg/aes/aes.go @@ -53,7 +53,7 @@ func AES256Encode(plainText, key string) (string, error) { func AES256Decode(hexCipherText, key string) (string, error) { defer func() { if err := recover(); err != nil { - util.CheckErr("decryption failed: wrong vault password") + util.PrintErrExit("decryption failed: wrong vault password") } }() diff --git a/internal/pkg/sshtask/sshtask.go b/internal/pkg/sshtask/sshtask.go index 1d20425..71951de 100644 --- a/internal/pkg/sshtask/sshtask.go +++ b/internal/pkg/sshtask/sshtask.go @@ -231,8 +231,9 @@ func (t *Task) batchRunSSH() { t.err = errors.New("need flag '-d/--dest-path' or '-l/--hosts.list'") } else { if !util.DirExists(t.dstDir) { - err := os.MkdirAll(t.dstDir, os.ModePerm) - util.CheckErr(err) + if err := os.MkdirAll(t.dstDir, os.ModeDir); err != nil { + util.PrintErrExit(err) + } } } } @@ -694,8 +695,8 @@ func getDefaultPassword(auth *configflags.Auth) string { passwordContent, err := os.ReadFile(authFile) if err != nil { err = fmt.Errorf("read password file '%s' failed: %w", authFile, err) + util.PrintErrExit(err) } - util.CheckErr(err) password = strings.TrimSpace(string(passwordContent)) @@ -784,8 +785,8 @@ func getPasswordFromPrompt(loginUser string) string { passwordByte, err := term.ReadPassword(0) if err != nil { err = fmt.Errorf("get password from terminal failed: %s", err) + util.PrintErrExit(err) } - util.CheckErr(err) password := string(passwordByte) @@ -803,7 +804,7 @@ func getRealPass(pass string, host, objectType string) string { realPass, err := aes.AES256Decode(pass, vaultPass) if err != nil { log.Debugf("Vault: decrypt %s for '%s' failed: %s", objectType, host, err) - util.CheckErr(err) + util.PrintErrExit(err) } log.Debugf("Vault: decrypt %s for '%s' success", objectType, host) diff --git a/pkg/util/cobra.go b/pkg/util/cobra.go index 7905992..45b9152 100644 --- a/pkg/util/cobra.go +++ b/pkg/util/cobra.go @@ -38,7 +38,7 @@ func CobraCheckErrWithHelp(cmd *cobra.Command, errMsg interface{}) { fmt.Println() - CheckErr(errMsg) + PrintErrExit(errMsg) } } @@ -46,7 +46,7 @@ func CobraCheckErrWithHelp(cmd *cobra.Command, errMsg interface{}) { func CobraMarkHiddenGlobalFlags(command *cobra.Command, flags ...string) { for _, v := range flags { if err := command.Flags().MarkHidden(v); err != nil { - CheckErr(fmt.Sprintf("cannot mark hidden flag: %s", err)) + PrintErrExit(fmt.Sprintf("cannot mark hidden flag: %s", err)) } } } diff --git a/pkg/util/error.go b/pkg/util/error.go index 0ad10ec..b97ec42 100644 --- a/pkg/util/error.go +++ b/pkg/util/error.go @@ -29,17 +29,11 @@ import ( "github.com/fatih/color" ) -// CheckErr and exit. -func CheckErr(msg interface{}) { - if msg != nil { - fmt.Fprintln(os.Stderr, color.RedString("Error:"), msg) - os.Exit(1) - } +func PrintErr(msg interface{}) { + fmt.Fprintln(os.Stderr, color.RedString("Error:"), msg) } -// PrintErr with red color if err not nil. -func PrintErr(msg interface{}) { - if msg != nil { - fmt.Fprintln(os.Stderr, color.RedString("Error:"), msg) - } +func PrintErrExit(msg interface{}) { + fmt.Fprintln(os.Stderr, color.RedString("Error:"), msg) + os.Exit(1) } diff --git a/scripts/makefiles/tools.makefile b/scripts/makefiles/tools.makefile index dfa8f40..1842e50 100644 --- a/scripts/makefiles/tools.makefile +++ b/scripts/makefiles/tools.makefile @@ -20,7 +20,7 @@ install.swagger: .PHONY: install.golangci-lint install.golangci-lint: - @${GO} install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.41.1 + @${GO} install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.54.2 .PHONY: install.go-junit-report install.go-junit-report: