Skip to content

Commit

Permalink
refactor: optimize output messages and enrich debug log information
Browse files Browse the repository at this point in the history
  • Loading branch information
windvalley committed Jan 5, 2024
1 parent 2c5a417 commit dbe1467
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 180 deletions.
3 changes: 1 addition & 2 deletions internal/pkg/configflags/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ func (o *Output) AddFlagsTo(flags *pflag.FlagSet) {
flags.StringVarP(&o.File, flagOutputFile, "o", o.File, "file to which messages are output")
flags.BoolVarP(&o.JSON, flagOutputJSON, "j", o.JSON, "output messages in json format")
flags.BoolVarP(&o.Condense, flagOutputCondense, "C", o.Condense, "condense output and disable color")
flags.BoolVarP(&o.Quiet, flagOutputQuite, "q", o.Quiet,
"do not output messages to screen (except error messages)")
flags.BoolVarP(&o.Quiet, flagOutputQuite, "q", o.Quiet, "do not output messages to screen")
flags.BoolVarP(&o.Verbose, flagOutputVerbose, "v", o.Verbose, "show debug messages")
}

Expand Down
200 changes: 92 additions & 108 deletions internal/pkg/sshtask/sshtask.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,15 @@ var (
)
)

// TaskType ...
type TaskType int

// ...
const (
CommandTask TaskType = iota
ScriptTask
PushTask
FetchTask
)

// taskResult ...
type taskResult struct {
taskID string
hostsSuccessCount int
Expand All @@ -87,7 +84,7 @@ type pushFiles struct {
zipFiles []string
}

// Task ...
// Task is a ssh task for one or more hosts.
type Task struct {
configFlags *configflags.ConfigFlags

Expand All @@ -97,7 +94,7 @@ type Task struct {
sshClient *batchssh.Client
sshAgent net.Conn
defaultUser string
defaultPass *string
defaultPass string
defaultIdentityFiles []string
defaultSSHAuthMethods []ssh.AuthMethod

Expand All @@ -121,7 +118,7 @@ type Task struct {
err error
}

// NewTask ...
// NewTask create a new task.
func NewTask(taskType TaskType, configFlags *configflags.ConfigFlags) *Task {
defaultIdentityFiles := parseItentityFiles(configFlags.Auth.IdentityFiles)

Expand All @@ -132,14 +129,38 @@ func NewTask(taskType TaskType, configFlags *configflags.ConfigFlags) *Task {
id: time.Now().Format("20060102150405"),
taskType: taskType,
defaultUser: configFlags.Auth.User,
defaultPass: &defaultPass,
defaultPass: defaultPass,
defaultIdentityFiles: defaultIdentityFiles,
taskOutput: make(chan taskResult, 1),
detailOutput: make(chan detailResult),
}
}

// Start task.
// SSH implements batchssh.Tasker.
func (t *Task) SSH(host *batchssh.Host) (string, error) {
lang := t.configFlags.Run.Lang
runAs := t.configFlags.Run.AsUser
sudo := t.configFlags.Run.Sudo

switch t.taskType {
case CommandTask:
return t.sshClient.ExecuteCmd(host, t.command, lang, runAs, sudo)
case ScriptTask:
return t.sshClient.ExecuteScript(host, t.scriptFile, t.dstDir, lang, runAs, sudo, t.remove, t.allowOverwrite)
case PushTask:
return t.sshClient.PushFiles(host, t.pushFiles.files, t.pushFiles.zipFiles, t.dstDir, t.allowOverwrite, t.enableZip)
case FetchTask:
hosts, err := t.getAllHosts()
if err != nil {
return "", err
}
return t.sshClient.FetchFiles(host, t.fetchFiles, t.dstDir, t.tmpDir, sudo, runAs, t.enableZip, len(hosts))
default:
return "", fmt.Errorf("unknown task type: %v", t.taskType)
}
}

// Start to run ssh task.
func (t *Task) Start() {
if t.sshAgent != nil {
defer t.sshAgent.Close()
Expand All @@ -148,7 +169,7 @@ func (t *Task) Start() {
go func() {
defer close(t.taskOutput)
defer close(t.detailOutput)
t.BatchRun()
t.batchRunSSH()
}()

taskTimeout := t.configFlags.Timeout.Task
Expand All @@ -165,86 +186,11 @@ func (t *Task) Start() {
}()
}

t.HandleOutput()
}

// SetTargetHosts ...
func (t *Task) SetTargetHosts(hosts []string) {
t.argHosts = hosts
}

// SetCommand ...
func (t *Task) SetCommand(command string) {
t.command = command
}

// SetScriptFile ...
func (t *Task) SetScriptFile(sciptFile string) {
t.scriptFile = sciptFile
}

// SetPushfiles ...
func (t *Task) SetPushfiles(files, zipFiles []string) {
t.pushFiles = &pushFiles{
files: files,
zipFiles: zipFiles,
}
}

// SetFetchFiles ...
func (t *Task) SetFetchFiles(files []string) {
t.fetchFiles = files
}

// SetScriptOptions ...
func (t *Task) SetScriptOptions(destPath string, remove, allowOverwrite bool) {
t.dstDir = destPath
t.remove = remove
t.allowOverwrite = allowOverwrite
}

// SetPushOptions ...
func (t *Task) SetPushOptions(destPath string, allowOverwrite, enableZip bool) {
t.dstDir = destPath
t.allowOverwrite = allowOverwrite
t.enableZip = enableZip
}

// SetFetchOptions ...
func (t *Task) SetFetchOptions(destPath, tmpDir string, enableZipFiles bool) {
t.dstDir = destPath
t.tmpDir = tmpDir
t.enableZip = enableZipFiles
}

// RunSSH implements batchssh.Task
func (t *Task) RunSSH(host *batchssh.Host) (string, error) {
lang := t.configFlags.Run.Lang
runAs := t.configFlags.Run.AsUser
sudo := t.configFlags.Run.Sudo

switch t.taskType {
case CommandTask:
return t.sshClient.ExecuteCmd(host, t.command, lang, runAs, sudo)
case ScriptTask:
return t.sshClient.ExecuteScript(host, t.scriptFile, t.dstDir, lang, runAs, sudo, t.remove, t.allowOverwrite)
case PushTask:
return t.sshClient.PushFiles(host, t.pushFiles.files, t.pushFiles.zipFiles, t.dstDir, t.allowOverwrite, t.enableZip)
case FetchTask:
hosts, err := t.getAllHosts()
if err != nil {
return "", err
}
return t.sshClient.FetchFiles(host, t.fetchFiles, t.dstDir, t.tmpDir, sudo, runAs, t.enableZip, len(hosts))
default:
return "", fmt.Errorf("unknown task type: %v", t.taskType)
}
t.handleOutput()
}

// BatchRun ...
//
//nolint:gocyclo
func (t *Task) BatchRun() {
func (t *Task) batchRunSSH() {
timeNow := time.Now()

if t.configFlags.Hosts.List {
Expand All @@ -265,15 +211,6 @@ func (t *Task) BatchRun() {
return
}

authConf := t.configFlags.Auth
runConf := t.configFlags.Run

log.Debugf("Default Auth: login user '%s'", authConf.User)

if runConf.Sudo {
log.Debugf("Default Auth: use sudo as user '%s'", runConf.AsUser)
}

switch t.taskType {
case CommandTask:
if t.command == "" {
Expand Down Expand Up @@ -304,6 +241,15 @@ func (t *Task) BatchRun() {
return
}

authConf := t.configFlags.Auth
runConf := t.configFlags.Run

log.Debugf("Default Auth: login user '%s'", authConf.User)

if runConf.Sudo {
log.Debugf("Default Auth: use sudo as user '%s'", runConf.AsUser)
}

t.setDefaultSSHAuthMethods()

t.buildSSHClient()
Expand All @@ -314,7 +260,7 @@ func (t *Task) BatchRun() {
return
}

log.Debugf("got target hosts, count: %d", len(allHosts))
log.Debugf("Got target hosts, count: %d", len(allHosts))

result := t.sshClient.BatchRun(allHosts, t)
successCount, failedCount := 0, 0
Expand Down Expand Up @@ -343,8 +289,7 @@ func (t *Task) BatchRun() {
}
}

// HandleOutput ...
func (t *Task) HandleOutput() {
func (t *Task) handleOutput() {
for res := range t.detailOutput {
// Fix the problem of special characters ^M appearing at the end of
// the line break when writing files in text format.
Expand Down Expand Up @@ -385,12 +330,51 @@ func (t *Task) HandleOutput() {
}
}

// CheckErr ...
func (t *Task) SetTargetHosts(hosts []string) {
t.argHosts = hosts
}

func (t *Task) SetCommand(command string) {
t.command = command
}

func (t *Task) SetScriptFile(sciptFile string) {
t.scriptFile = sciptFile
}

func (t *Task) SetPushfiles(files, zipFiles []string) {
t.pushFiles = &pushFiles{
files: files,
zipFiles: zipFiles,
}
}

func (t *Task) SetFetchFiles(files []string) {
t.fetchFiles = files
}

func (t *Task) SetScriptOptions(destPath string, remove, allowOverwrite bool) {
t.dstDir = destPath
t.remove = remove
t.allowOverwrite = allowOverwrite
}

func (t *Task) SetPushOptions(destPath string, allowOverwrite, enableZip bool) {
t.dstDir = destPath
t.allowOverwrite = allowOverwrite
t.enableZip = enableZip
}

func (t *Task) SetFetchOptions(destPath, tmpDir string, enableZipFiles bool) {
t.dstDir = destPath
t.tmpDir = tmpDir
t.enableZip = enableZipFiles
}

func (t *Task) CheckErr() error {
return t.err
}

// ListHosts ...
func (t *Task) ListHosts() ([]string, error) {
var hosts []string

Expand Down Expand Up @@ -453,7 +437,7 @@ func (t *Task) getAllHosts() ([]*batchssh.Host, error) {
Host: v,
Port: t.configFlags.Hosts.Port,
User: t.defaultUser,
Password: *t.defaultPass,
Password: t.defaultPass,
Keys: t.defaultIdentityFiles,
SSHAuths: t.defaultSSHAuthMethods,
})
Expand Down Expand Up @@ -504,7 +488,7 @@ func (t *Task) getAllHosts() ([]*batchssh.Host, error) {

realPassword := ""
if v.Password == "" {
realPassword = getRealPass(*t.defaultPass, v.Alias, "password")
realPassword = getRealPass(t.defaultPass, v.Alias, "password")
} else {
realPassword = getRealPass(v.Password, v.Alias, "password")
hostSSHAuths = append(hostSSHAuths, ssh.Password(v.Password))
Expand Down Expand Up @@ -636,20 +620,20 @@ func (t *Task) setDefaultSSHAuthMethods() {
auths = append(auths, ssh.PublicKeys(signers...))
}

if *t.defaultPass != "" {
auths = append(auths, ssh.Password(*t.defaultPass))
if t.defaultPass != "" {
auths = append(auths, ssh.Password(t.defaultPass))
} else {
log.Debugf("Default Auth: password of the login user '%s' not provided", t.defaultUser)
}

if *t.defaultPass == "" && t.configFlags.Run.Sudo {
if t.defaultPass == "" && t.configFlags.Run.Sudo {
log.Debugf(
"Default Auth: using sudo as other user needs password. Prompt for password of the login user '%s'",
t.defaultUser,
)

*t.defaultPass = getPasswordFromPrompt(t.defaultUser)
auths = append(auths, ssh.Password(*t.defaultPass))
t.defaultPass = getPasswordFromPrompt(t.defaultUser)
auths = append(auths, ssh.Password(t.defaultPass))
}

t.defaultSSHAuthMethods = auths
Expand Down Expand Up @@ -693,7 +677,7 @@ func (t *Task) getProxySSHAuthMethods() []ssh.AuthMethod {
if t.configFlags.Proxy.Password != "" {
proxyAuths = append(proxyAuths, ssh.Password(t.configFlags.Proxy.Password))
} else {
proxyAuths = append(proxyAuths, ssh.Password(*t.defaultPass))
proxyAuths = append(proxyAuths, ssh.Password(t.defaultPass))
}
log.Debugf("Proxy Auth: received password of the proxy user")

Expand Down
Loading

0 comments on commit dbe1467

Please sign in to comment.