diff --git a/docs/command.md b/docs/command.md index dd05b5f..b5fccc8 100644 --- a/docs/command.md +++ b/docs/command.md @@ -5,11 +5,8 @@ Execute shell commands on target hosts. ## Examples ```sh -# Use sudo as user 'zhangsan' to execute commands on target hosts. -$ gossh command host[1-3] -e "uptime" -s -U zhangsan - # Set timeout seconds for executing commands on each target host. -$ gossh command host[1-3] -e "uptime" --timeout.command 10 +$ gossh command host[1-3] -e "uptime" -t 20 # Connect target hosts by proxy server 10.16.0.1. $ gossh command host[1-3] -e "uptime" -X 10.16.0.1 @@ -17,3 +14,5 @@ $ gossh command host[1-3] -e "uptime" -X 10.16.0.1 # Specify concurrency connections. $ gossh command host[1-3] -e "uptime" -c 10 ``` + + diff --git a/internal/cmd/command.go b/internal/cmd/command.go index d29d146..0e0241e 100644 --- a/internal/cmd/command.go +++ b/internal/cmd/command.go @@ -23,14 +23,35 @@ THE SOFTWARE. package cmd import ( + "fmt" + "regexp" + "strings" + "github.com/spf13/cobra" "github.com/windvalley/gossh/internal/pkg/configflags" "github.com/windvalley/gossh/internal/pkg/sshtask" + "github.com/windvalley/gossh/pkg/log" "github.com/windvalley/gossh/pkg/util" ) -var shellCommand string +var ( + shellCommand string + noSafeCheck bool +) + +var defaultCommandBlacklist = []string{ + "rm", + "reboot", + "halt", + "shutdown", + "poweroff", + "init", + "mkfs", + "mkfs.*", + "umount", + "dd", +} const commandCmdExamples = ` Execute command 'uptime' on target hosts. @@ -55,6 +76,21 @@ Execute commands on target hosts.`, if errs := configflags.Config.Validate(); len(errs) != 0 { util.CheckErr(errs) } + + if noSafeCheck { + log.Debugf("Skip the safety check of commands before execution") + } else { + if len(configflags.Config.Run.CommandBlacklist) == 0 { + configflags.Config.Run.CommandBlacklist = defaultCommandBlacklist + log.Debugf("Using default command blacklist for the safety check: %s", defaultCommandBlacklist) + } else { + log.Debugf("Using custom command blacklist for the safety check: %s", configflags.Config.Run.CommandBlacklist) + } + + if err := checkCommand(shellCommand, configflags.Config.Run.CommandBlacklist); err != nil { + util.CheckErr(err) + } + } }, Run: func(cmd *cobra.Command, args []string) { task := sshtask.NewTask(sshtask.CommandTask, configflags.Config) @@ -76,4 +112,43 @@ func init() { "", "commands to be executed on target hosts", ) + commandCmd.Flags().BoolVarP( + &noSafeCheck, + "no-safe-check", + "n", + false, + "ignore dangerous commands (from '-B,--run.command-blacklist') check", + ) +} + +func checkCommand(command string, commandBlacklist []string) error { + unsafeCommands := make([]string, 0) + + commands := strings.FieldsFunc(command, func(r rune) bool { + if r == ';' || r == '|' || r == '&' || r == ' ' { + return true + } + return false + }) + + for _, cmd := range commands { + for _, unsafeCmd := range commandBlacklist { + re := regexp.MustCompile(fmt.Sprintf(`^%s(?:\s+|;)*$`, unsafeCmd)) + if re.MatchString(cmd) { + unsafeCommands = append(unsafeCommands, cmd) + break + } + } + } + + if len(unsafeCommands) > 0 { + unsafeCommands = util.RemoveDuplStr(unsafeCommands) + + return fmt.Errorf( + "found dangerous commands: '%s', you can add '-n/--no-safe-check' flag to ignore this check", + strings.Join(unsafeCommands, ", "), + ) + } + + return nil } diff --git a/internal/cmd/config.go b/internal/cmd/config.go index 8347863..5f3dacd 100644 --- a/internal/cmd/config.go +++ b/internal/cmd/config.go @@ -91,6 +91,12 @@ run: # Default: 1 concurrency: %d + # Linux Command Blacklist for gossh subcommands 'command' and 'script'. + # Commands listed in this blacklist will be prohibited from executing on remote hosts for security reasons. + # You can add flag '-n, --no-safe-check' to disable this feature. + # Default: ["rm", "reboot", "halt", "shutdown", "init", "mkfs", "mkfs.*", "umount", "dd"] + command-blacklist: [] + output: # File to which messages are output. # Default: "" @@ -195,6 +201,7 @@ func init() { "auth.identity-files", "proxy.identity-files", "hosts.list", + "run.command-blacklist", ) command.Parent().HelpFunc()(command, strings) diff --git a/internal/pkg/configflags/configflags.go b/internal/pkg/configflags/configflags.go index 82e3e48..9d31b5e 100644 --- a/internal/pkg/configflags/configflags.go +++ b/internal/pkg/configflags/configflags.go @@ -75,10 +75,21 @@ func (c *ConfigFlags) Complete() error { if err := c.Auth.Complete(); err != nil { return err } - + if err := c.Hosts.Complete(); err != nil { + return err + } + if err := c.Run.Complete(); err != nil { + return err + } + if err := c.Output.Complete(); err != nil { + return err + } if err := c.Proxy.Complete(); err != nil { return err } + if err := c.Timeout.Complete(); err != nil { + return err + } return nil } diff --git a/internal/pkg/configflags/run.go b/internal/pkg/configflags/run.go index c83e0a9..49dd393 100644 --- a/internal/pkg/configflags/run.go +++ b/internal/pkg/configflags/run.go @@ -24,23 +24,26 @@ package configflags import ( "fmt" + "strings" "github.com/spf13/pflag" ) const ( - flagRunSudo = "run.sudo" - flagRunAsUser = "run.as-user" - flagRunLang = "run.lang" - flagRunConcurrency = "run.concurrency" + flagRunSudo = "run.sudo" + flagRunAsUser = "run.as-user" + flagRunLang = "run.lang" + flagRunConcurrency = "run.concurrency" + flagRunCommandBlacklist = "run.command-blacklist" ) // Run ... type Run struct { - Sudo bool `json:"sudo" mapstructure:"sudo"` - AsUser string `json:"as-user" mapstructure:"as-user"` - Lang string `json:"lang" mapstructure:"lang"` - Concurrency int `json:"concurrency" mapstructure:"concurrency"` + Sudo bool `json:"sudo" mapstructure:"sudo"` + AsUser string `json:"as-user" mapstructure:"as-user"` + Lang string `json:"lang" mapstructure:"lang"` + Concurrency int `json:"concurrency" mapstructure:"concurrency"` + CommandBlacklist []string `json:"command-blacklist" mapstructure:"command-blacklist"` } // NewRun ... @@ -66,10 +69,28 @@ func (r *Run) AddFlagsTo(flags *pflag.FlagSet) { ) flags.IntVarP(&r.Concurrency, flagRunConcurrency, "c", r.Concurrency, "number of concurrent connections") + flags.StringSliceVarP( + &r.CommandBlacklist, + flagRunCommandBlacklist, + "B", + r.CommandBlacklist, + `commands that are prohibited from execution on target hosts +(default: rm,reboot,halt,shutdown,init,mkfs,mkfs.*,umount,dd)`, + ) } // Complete ... func (r *Run) Complete() error { + newSlice := make([]string, 0) + for _, s := range r.CommandBlacklist { + item := strings.TrimSpace(s) + if item != "" { + newSlice = append(newSlice, item) + } + } + + r.CommandBlacklist = newSlice + return nil } diff --git a/internal/pkg/sshtask/sshtask.go b/internal/pkg/sshtask/sshtask.go index 3ec0301..1d20425 100644 --- a/internal/pkg/sshtask/sshtask.go +++ b/internal/pkg/sshtask/sshtask.go @@ -299,7 +299,7 @@ func (t *Task) handleOutput() { outputNoSudoPrompt := "" re, err := regexp.Compile(sudoPromptRegex) if err != nil { - log.Debugf("re compile '%s' failed: %s", sudoPromptRegex, err) + log.Debugf("regexp compile '%s' failed: %s", sudoPromptRegex, err) } else { outputNoSudoPrompt = re.ReplaceAllString(outputNoR, "") } diff --git a/pkg/util/slice.go b/pkg/util/slice.go index 34d5796..c6f0f93 100644 --- a/pkg/util/slice.go +++ b/pkg/util/slice.go @@ -37,3 +37,13 @@ func RemoveDuplStr(strSlice []string) []string { return set } + +func IsInStringSlice(a string, list []string) bool { + for _, b := range list { + if b == a { + return true + } + } + + return false +}