diff --git a/internal/pkg/configflags/output.go b/internal/pkg/configflags/output.go index fd0f208..2ae2469 100644 --- a/internal/pkg/configflags/output.go +++ b/internal/pkg/configflags/output.go @@ -57,8 +57,7 @@ func (o *Output) AddFlagsTo(flags *pflag.FlagSet) { flags.StringVarP(&o.File, flagOutputFile, "o", o.File, "file to which messages are output") flags.BoolVarP(&o.JSON, flagOutputJSON, "j", o.JSON, "output messages in json format") flags.BoolVarP(&o.Condense, flagOutputCondense, "C", o.Condense, "condense output and disable color") - flags.BoolVarP(&o.Quiet, flagOutputQuite, "q", o.Quiet, - "do not output messages to screen (except error messages)") + flags.BoolVarP(&o.Quiet, flagOutputQuite, "q", o.Quiet, "do not output messages to screen") flags.BoolVarP(&o.Verbose, flagOutputVerbose, "v", o.Verbose, "show debug messages") } diff --git a/internal/pkg/sshtask/sshtask.go b/internal/pkg/sshtask/sshtask.go index 039a21b..3ec0301 100644 --- a/internal/pkg/sshtask/sshtask.go +++ b/internal/pkg/sshtask/sshtask.go @@ -55,10 +55,8 @@ var ( ) ) -// TaskType ... type TaskType int -// ... const ( CommandTask TaskType = iota ScriptTask @@ -66,7 +64,6 @@ const ( FetchTask ) -// taskResult ... type taskResult struct { taskID string hostsSuccessCount int @@ -87,7 +84,7 @@ type pushFiles struct { zipFiles []string } -// Task ... +// Task is a ssh task for one or more hosts. type Task struct { configFlags *configflags.ConfigFlags @@ -97,7 +94,7 @@ type Task struct { sshClient *batchssh.Client sshAgent net.Conn defaultUser string - defaultPass *string + defaultPass string defaultIdentityFiles []string defaultSSHAuthMethods []ssh.AuthMethod @@ -121,7 +118,7 @@ type Task struct { err error } -// NewTask ... +// NewTask create a new task. func NewTask(taskType TaskType, configFlags *configflags.ConfigFlags) *Task { defaultIdentityFiles := parseItentityFiles(configFlags.Auth.IdentityFiles) @@ -132,14 +129,38 @@ func NewTask(taskType TaskType, configFlags *configflags.ConfigFlags) *Task { id: time.Now().Format("20060102150405"), taskType: taskType, defaultUser: configFlags.Auth.User, - defaultPass: &defaultPass, + defaultPass: defaultPass, defaultIdentityFiles: defaultIdentityFiles, taskOutput: make(chan taskResult, 1), detailOutput: make(chan detailResult), } } -// Start task. +// SSH implements batchssh.Tasker. +func (t *Task) SSH(host *batchssh.Host) (string, error) { + lang := t.configFlags.Run.Lang + runAs := t.configFlags.Run.AsUser + sudo := t.configFlags.Run.Sudo + + switch t.taskType { + case CommandTask: + return t.sshClient.ExecuteCmd(host, t.command, lang, runAs, sudo) + case ScriptTask: + return t.sshClient.ExecuteScript(host, t.scriptFile, t.dstDir, lang, runAs, sudo, t.remove, t.allowOverwrite) + case PushTask: + return t.sshClient.PushFiles(host, t.pushFiles.files, t.pushFiles.zipFiles, t.dstDir, t.allowOverwrite, t.enableZip) + case FetchTask: + hosts, err := t.getAllHosts() + if err != nil { + return "", err + } + return t.sshClient.FetchFiles(host, t.fetchFiles, t.dstDir, t.tmpDir, sudo, runAs, t.enableZip, len(hosts)) + default: + return "", fmt.Errorf("unknown task type: %v", t.taskType) + } +} + +// Start to run ssh task. func (t *Task) Start() { if t.sshAgent != nil { defer t.sshAgent.Close() @@ -148,7 +169,7 @@ func (t *Task) Start() { go func() { defer close(t.taskOutput) defer close(t.detailOutput) - t.BatchRun() + t.batchRunSSH() }() taskTimeout := t.configFlags.Timeout.Task @@ -165,86 +186,11 @@ func (t *Task) Start() { }() } - t.HandleOutput() -} - -// SetTargetHosts ... -func (t *Task) SetTargetHosts(hosts []string) { - t.argHosts = hosts -} - -// SetCommand ... -func (t *Task) SetCommand(command string) { - t.command = command -} - -// SetScriptFile ... -func (t *Task) SetScriptFile(sciptFile string) { - t.scriptFile = sciptFile -} - -// SetPushfiles ... -func (t *Task) SetPushfiles(files, zipFiles []string) { - t.pushFiles = &pushFiles{ - files: files, - zipFiles: zipFiles, - } -} - -// SetFetchFiles ... -func (t *Task) SetFetchFiles(files []string) { - t.fetchFiles = files -} - -// SetScriptOptions ... -func (t *Task) SetScriptOptions(destPath string, remove, allowOverwrite bool) { - t.dstDir = destPath - t.remove = remove - t.allowOverwrite = allowOverwrite -} - -// SetPushOptions ... -func (t *Task) SetPushOptions(destPath string, allowOverwrite, enableZip bool) { - t.dstDir = destPath - t.allowOverwrite = allowOverwrite - t.enableZip = enableZip -} - -// SetFetchOptions ... -func (t *Task) SetFetchOptions(destPath, tmpDir string, enableZipFiles bool) { - t.dstDir = destPath - t.tmpDir = tmpDir - t.enableZip = enableZipFiles -} - -// RunSSH implements batchssh.Task -func (t *Task) RunSSH(host *batchssh.Host) (string, error) { - lang := t.configFlags.Run.Lang - runAs := t.configFlags.Run.AsUser - sudo := t.configFlags.Run.Sudo - - switch t.taskType { - case CommandTask: - return t.sshClient.ExecuteCmd(host, t.command, lang, runAs, sudo) - case ScriptTask: - return t.sshClient.ExecuteScript(host, t.scriptFile, t.dstDir, lang, runAs, sudo, t.remove, t.allowOverwrite) - case PushTask: - return t.sshClient.PushFiles(host, t.pushFiles.files, t.pushFiles.zipFiles, t.dstDir, t.allowOverwrite, t.enableZip) - case FetchTask: - hosts, err := t.getAllHosts() - if err != nil { - return "", err - } - return t.sshClient.FetchFiles(host, t.fetchFiles, t.dstDir, t.tmpDir, sudo, runAs, t.enableZip, len(hosts)) - default: - return "", fmt.Errorf("unknown task type: %v", t.taskType) - } + t.handleOutput() } -// BatchRun ... -// //nolint:gocyclo -func (t *Task) BatchRun() { +func (t *Task) batchRunSSH() { timeNow := time.Now() if t.configFlags.Hosts.List { @@ -265,15 +211,6 @@ func (t *Task) BatchRun() { return } - authConf := t.configFlags.Auth - runConf := t.configFlags.Run - - log.Debugf("Default Auth: login user '%s'", authConf.User) - - if runConf.Sudo { - log.Debugf("Default Auth: use sudo as user '%s'", runConf.AsUser) - } - switch t.taskType { case CommandTask: if t.command == "" { @@ -304,6 +241,15 @@ func (t *Task) BatchRun() { return } + authConf := t.configFlags.Auth + runConf := t.configFlags.Run + + log.Debugf("Default Auth: login user '%s'", authConf.User) + + if runConf.Sudo { + log.Debugf("Default Auth: use sudo as user '%s'", runConf.AsUser) + } + t.setDefaultSSHAuthMethods() t.buildSSHClient() @@ -314,7 +260,7 @@ func (t *Task) BatchRun() { return } - log.Debugf("got target hosts, count: %d", len(allHosts)) + log.Debugf("Got target hosts, count: %d", len(allHosts)) result := t.sshClient.BatchRun(allHosts, t) successCount, failedCount := 0, 0 @@ -343,8 +289,7 @@ func (t *Task) BatchRun() { } } -// HandleOutput ... -func (t *Task) HandleOutput() { +func (t *Task) handleOutput() { for res := range t.detailOutput { // Fix the problem of special characters ^M appearing at the end of // the line break when writing files in text format. @@ -385,12 +330,51 @@ func (t *Task) HandleOutput() { } } -// CheckErr ... +func (t *Task) SetTargetHosts(hosts []string) { + t.argHosts = hosts +} + +func (t *Task) SetCommand(command string) { + t.command = command +} + +func (t *Task) SetScriptFile(sciptFile string) { + t.scriptFile = sciptFile +} + +func (t *Task) SetPushfiles(files, zipFiles []string) { + t.pushFiles = &pushFiles{ + files: files, + zipFiles: zipFiles, + } +} + +func (t *Task) SetFetchFiles(files []string) { + t.fetchFiles = files +} + +func (t *Task) SetScriptOptions(destPath string, remove, allowOverwrite bool) { + t.dstDir = destPath + t.remove = remove + t.allowOverwrite = allowOverwrite +} + +func (t *Task) SetPushOptions(destPath string, allowOverwrite, enableZip bool) { + t.dstDir = destPath + t.allowOverwrite = allowOverwrite + t.enableZip = enableZip +} + +func (t *Task) SetFetchOptions(destPath, tmpDir string, enableZipFiles bool) { + t.dstDir = destPath + t.tmpDir = tmpDir + t.enableZip = enableZipFiles +} + func (t *Task) CheckErr() error { return t.err } -// ListHosts ... func (t *Task) ListHosts() ([]string, error) { var hosts []string @@ -453,7 +437,7 @@ func (t *Task) getAllHosts() ([]*batchssh.Host, error) { Host: v, Port: t.configFlags.Hosts.Port, User: t.defaultUser, - Password: *t.defaultPass, + Password: t.defaultPass, Keys: t.defaultIdentityFiles, SSHAuths: t.defaultSSHAuthMethods, }) @@ -504,7 +488,7 @@ func (t *Task) getAllHosts() ([]*batchssh.Host, error) { realPassword := "" if v.Password == "" { - realPassword = getRealPass(*t.defaultPass, v.Alias, "password") + realPassword = getRealPass(t.defaultPass, v.Alias, "password") } else { realPassword = getRealPass(v.Password, v.Alias, "password") hostSSHAuths = append(hostSSHAuths, ssh.Password(v.Password)) @@ -636,20 +620,20 @@ func (t *Task) setDefaultSSHAuthMethods() { auths = append(auths, ssh.PublicKeys(signers...)) } - if *t.defaultPass != "" { - auths = append(auths, ssh.Password(*t.defaultPass)) + if t.defaultPass != "" { + auths = append(auths, ssh.Password(t.defaultPass)) } else { log.Debugf("Default Auth: password of the login user '%s' not provided", t.defaultUser) } - if *t.defaultPass == "" && t.configFlags.Run.Sudo { + if t.defaultPass == "" && t.configFlags.Run.Sudo { log.Debugf( "Default Auth: using sudo as other user needs password. Prompt for password of the login user '%s'", t.defaultUser, ) - *t.defaultPass = getPasswordFromPrompt(t.defaultUser) - auths = append(auths, ssh.Password(*t.defaultPass)) + t.defaultPass = getPasswordFromPrompt(t.defaultUser) + auths = append(auths, ssh.Password(t.defaultPass)) } t.defaultSSHAuthMethods = auths @@ -693,7 +677,7 @@ func (t *Task) getProxySSHAuthMethods() []ssh.AuthMethod { if t.configFlags.Proxy.Password != "" { proxyAuths = append(proxyAuths, ssh.Password(t.configFlags.Proxy.Password)) } else { - proxyAuths = append(proxyAuths, ssh.Password(*t.defaultPass)) + proxyAuths = append(proxyAuths, ssh.Password(t.defaultPass)) } log.Debugf("Proxy Auth: received password of the proxy user") diff --git a/pkg/batchssh/batchssh.go b/pkg/batchssh/batchssh.go index bbfbc27..e9dba51 100644 --- a/pkg/batchssh/batchssh.go +++ b/pkg/batchssh/batchssh.go @@ -49,9 +49,9 @@ const ( FailedIdentifier = "FAILED" ) -// Task execute command or copy file or execute script. -type Task interface { - RunSSH(host *Host) (string, error) +// Tasker for ssh. +type Tasker interface { + SSH(host *Host) (string, error) } // Result of ssh command. @@ -104,10 +104,7 @@ func NewClient(options ...func(*Client)) *Client { } // BatchRun command on remote servers. -func (c *Client) BatchRun( - hosts []*Host, - sshTask Task, -) <-chan *Result { +func (c *Client) BatchRun(hosts []*Host, sshTask Tasker) <-chan *Result { hostCh := make(chan *Host) go func() { defer close(hostCh) @@ -128,7 +125,7 @@ func (c *Client) BatchRun( go func() { defer close(done) - output, err := sshTask.RunSSH(host) + output, err := sshTask.SSH(host) if err != nil { result = &Result{host.Alias, FailedIdentifier, err.Error()} } else { @@ -192,7 +189,7 @@ func (c *Client) ExecuteCmd(host *Host, command, lang, runAs string, sudo bool) command = exportLang + command } - return c.executeCmd(session, command, host.Password) + return c.executeCmd(session, command, host) } // ExecuteScript on remote host. @@ -249,7 +246,7 @@ func (c *Client) ExecuteScript( command = exportLang + script } - return c.executeCmd(session, command, host.Password) + return c.executeCmd(session, command, host) } // PushFiles to remote host. @@ -321,7 +318,7 @@ func (c *Client) PushFiles( dstDir, dstZipFile, ), - host.Password, + host, ) if err != nil { return "", err @@ -406,8 +403,7 @@ func (c *Client) FetchFiles( dstDir = filepath.Join(dstDir, host.Host) err = os.MkdirAll(dstDir, os.ModePerm) if err != nil { - log.Errorf("make local dir '%s' failed: %v", dstDir, err) - return "", err + return "", fmt.Errorf("make local dir '%s' failed: %v", dstDir, err) } log.Debugf("make local dir '%s'", dstDir) } @@ -468,7 +464,7 @@ func (c *Client) FetchFiles( return ret, nil } -func (c *Client) executeCmd(session *ssh.Session, command, password string) (string, error) { +func (c *Client) executeCmd(session *ssh.Session, command string, host *Host) (string, error) { modes := ssh.TerminalModes{ ssh.ECHO: 0, ssh.TTY_OP_ISPEED: 28800, @@ -490,7 +486,7 @@ func (c *Client) executeCmd(session *ssh.Session, command, password string) (str return "", err } - out, isWrongPass := c.handleOutput(w, r, password) + out, isWrongPass := c.handleOutput(w, r, host.Password) done := make(chan struct{}) go func() { @@ -512,9 +508,10 @@ func (c *Client) executeCmd(session *ssh.Session, command, password string) (str <-done if err != nil { - log.Debugf("'%s' executed failed: %s", command, err) + log.Debugf("%s: execute command '%s' failed, error: %v, output: %s", host.Host, command, err, outputStr) return "", errors.New(outputStr) } + log.Debugf("%s: execute command '%s' success, output: %s", host.Host, command, outputStr) return outputStr, nil } @@ -640,7 +637,7 @@ func WithProxyServer(proxyServer, user string, port int, auths []ssh.AuthMethod) proxySSHConfig, ) if err1 != nil { - c.Proxy.Err = fmt.Errorf("connet to proxy %s:%d failed: %s", proxyServer, port, err1) + c.Proxy.Err = fmt.Errorf("connect to proxy %s:%d failed: %s", proxyServer, port, err1) return } @@ -648,13 +645,3 @@ func WithProxyServer(proxyServer, user string, port int, auths []ssh.AuthMethod) c.Proxy.SSHClient = proxyClient } } - -// 判断是否是目录 -func isDir(path string) bool { - stat, err := os.Stat(path) - if err != nil { - return false - } - - return stat.IsDir() -} diff --git a/pkg/batchssh/fetch.go b/pkg/batchssh/fetch.go index 6a96130..1de1321 100644 --- a/pkg/batchssh/fetch.go +++ b/pkg/batchssh/fetch.go @@ -19,8 +19,7 @@ func (c *Client) fetchFileOrDir( ) error { fStat, err := ftpC.Stat(srcFile) if err != nil { - log.Errorf("%s: stat '%s' failed: %v", host, srcFile, err) - return err + return fmt.Errorf("%s: stat '%s' failed: %v", host, srcFile, err) } if !fStat.IsDir() { @@ -40,8 +39,7 @@ func (c *Client) fetchFileOrDir( localFilePath := path.Join(dstDir, filepath.Base(srcFile)) err = os.MkdirAll(localFilePath, fStat.Mode().Perm()) if err != nil { - log.Errorf("make local dir '%s' failed: %v", localFilePath, err) - return err + return fmt.Errorf("make local dir '%s' failed: %v", localFilePath, err) } log.Debugf("make local dir '%s'", localFilePath) @@ -51,14 +49,12 @@ func (c *Client) fetchFileOrDir( if item.IsDir() { err = c.fetchFileOrDir(ftpC, remoteFilePath, localFilePath, host) if err != nil { - log.Errorf("%s: fetchFileOrDir '%s' failed, error: %v", host, remoteFilePath, err) - return err + return fmt.Errorf("%s: fetchFileOrDir '%s' failed, error: %v", host, remoteFilePath, err) } } else { err = fetchFile(ftpC, remoteFilePath, localFilePath, host) if err != nil { - log.Errorf("%s: fetchFile '%s' failed, error: %v", host, localFilePath, err) - return err + return fmt.Errorf("%s: fetchFile '%s' failed, error: %v", host, localFilePath, err) } } } @@ -112,7 +108,7 @@ func fetchFile( return fmt.Errorf("chmod local file '%s' failed: %w", dstFile, err) } - log.Debugf("%s: '%s' -> '%s fetched", host, srcFile, dstFile) + log.Debugf("%s: %s -> %s fetched", host, srcFile, dstFile) return nil } diff --git a/pkg/batchssh/fetch_zip.go b/pkg/batchssh/fetch_zip.go index e879322..0daa383 100644 --- a/pkg/batchssh/fetch_zip.go +++ b/pkg/batchssh/fetch_zip.go @@ -57,18 +57,16 @@ fi`, zippedFileFullpath, srcFileName, ), - host.Password, + host, ) if err != nil { - log.Errorf("%s: zip '%s' failed: %s", host.Host, srcFile, err) - return err + return fmt.Errorf("%s: zip '%s' failed: %s", host.Host, srcFile, err) } log.Debugf("%s: zip '%s' cost %s", host.Host, srcFile, time.Since(timeStart)) if err = fetchZipFile(ftpC, zippedFileFullpath, dstDir); err != nil { - log.Errorf("%s: fetch zip file '%s' failed: %s", host.Host, zippedFileFullpath, err) - return err + return fmt.Errorf("%s: fetch zip file '%s' failed: %s", host.Host, zippedFileFullpath, err) } log.Debugf("%s: fetched zip file '%s'", host.Host, zippedFileFullpath) @@ -81,25 +79,23 @@ fi`, _, err = c.executeCmd( session2, fmt.Sprintf("sudo -u %s -H bash -c 'rm -f %s'", runAs, zippedFileFullpath), - host.Password, + host, ) if err != nil { - log.Errorf("%s: remove '%s' failed: %s", host.Host, zippedFileFullpath, err) - return err + return fmt.Errorf("%s: remove '%s' failed: %s", host.Host, zippedFileFullpath, err) } log.Debugf("%s: removed '%s'", host.Host, zippedFileFullpath) localZippedFileFullpath := path.Join(dstDir, tmpZipFile) defer func() { if err := os.Remove(localZippedFileFullpath); err != nil { - log.Errorf("remove '%s' failed: %s", localZippedFileFullpath, err) + log.Debugf("remove '%s' failed: %s", localZippedFileFullpath, err) } else { log.Debugf("removed '%s'", localZippedFileFullpath) } }() if err := util.Unzip(localZippedFileFullpath, dstDir); err != nil { - log.Errorf("unzip '%s' to '%s' failed: %s", localZippedFileFullpath, dstDir, err) - return err + return fmt.Errorf("unzip '%s' to '%s' failed: %s", localZippedFileFullpath, dstDir, err) } log.Debugf("unzipped '%s' to '%s'", localZippedFileFullpath, dstDir) diff --git a/pkg/batchssh/pushv1.go b/pkg/batchssh/push.go similarity index 85% rename from pkg/batchssh/pushv1.go rename to pkg/batchssh/push.go index 4d320ba..613e08b 100644 --- a/pkg/batchssh/pushv1.go +++ b/pkg/batchssh/push.go @@ -12,6 +12,7 @@ import ( "github.com/pkg/sftp" "github.com/windvalley/gossh/pkg/log" + "github.com/windvalley/gossh/pkg/util" ) // pushFileOrDir is less efficient than pushFileOrDirV2. @@ -22,7 +23,7 @@ func (c *Client) pushFileOrDir( srcFile, dstDir, host string, allowOverwrite bool, ) error { - if !isDir(srcFile) { + if !util.IsDir(srcFile) { _, err := c.pushFile(ftpC, srcFile, dstDir, host, allowOverwrite) if err != nil { return err @@ -39,8 +40,7 @@ func (c *Client) pushFileOrDir( remoteFilePath := path.Join(dstDir, filepath.Base(srcFile)) err = ftpC.MkdirAll(remoteFilePath) if err != nil { - log.Errorf("%s: mkdir '%s' failed", host, remoteFilePath) - return err + return fmt.Errorf("%s: mkdir '%s' failed, error: %v", host, remoteFilePath, err) } for _, item := range localFiles { @@ -49,14 +49,12 @@ func (c *Client) pushFileOrDir( if item.IsDir() { err = c.pushFileOrDir(ftpC, localFilePath, remoteFilePath, host, allowOverwrite) if err != nil { - log.Errorf("%s: pushFileOrDir '%s' failed", host, localFilePath) - return err + return fmt.Errorf("%s: pushFileOrDir '%s' failed, error: %v", host, localFilePath, err) } } else { _, err = c.pushFile(ftpC, localFilePath, remoteFilePath, host, allowOverwrite) if err != nil { - log.Errorf("%s: pushFile '%s' failed", host, localFilePath) - return err + return fmt.Errorf("%s: pushFile '%s' failed, error: %v", host, localFilePath, err) } } } @@ -120,7 +118,7 @@ func (c *Client) pushFile( return nil, err } - log.Debugf("%s: '%s' -> '%s", host, srcFile, dstFile) + log.Debugf("%s: %s -> %s pushed", host, srcFile, dstFile) return file, nil } diff --git a/pkg/batchssh/pushv2.go b/pkg/batchssh/pushv2.go index a8684b0..ef572fa 100644 --- a/pkg/batchssh/pushv2.go +++ b/pkg/batchssh/pushv2.go @@ -12,6 +12,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/windvalley/gossh/pkg/log" + "github.com/windvalley/gossh/pkg/util" ) const ( @@ -41,7 +42,7 @@ func (c *Client) pushFileOrDirV2( } defer session.Close() - if !isDir(srcFile) { + if !util.IsDir(srcFile) { return pushFileV2(session, srcFile, dstDir, host) } @@ -52,19 +53,19 @@ func pushFileV2(session *ssh.Session, src, dest, host string) error { go func() { w, err := session.StdinPipe() if err != nil { - log.Errorf("%s: failed to open stdin, error: %v", host, err) + log.Debugf("%s: failed to open stdin, error: %v", host, err) return } defer w.Close() fileinfo, err := os.Stat(src) if err != nil { - log.Errorf("%s: failed to get file stat, error: %v", host, err) + log.Debugf("%s: failed to get file stat, file: %s, error: %v", host, src, err) return } if err := createFile(w, host, src, fileinfo); err != nil { - return + log.Debugf("%s: failed to create file, file: %s, error: %v", host, src, err) } }() @@ -75,21 +76,21 @@ func pushDirV2(session *ssh.Session, src, dest, host string) error { go func() { w, err := session.StdinPipe() if err != nil { - log.Errorf("failed to open stdin, error: %v", err) + log.Debugf("%s: failed to open stdin, error: %v", host, err) return } defer w.Close() fileinfo, err := os.Stat(src) if err != nil { - log.Errorf("failed to get file stat, error: %v", err) + log.Debugf("%s: failed to get file stat, file: %s, error: %v", host, src, err) return } fmt.Fprintln(w, PushBeginFolder+getMode(fileinfo), PushBeginEndFolder, fileinfo.Name()) if err := walkDir(w, src, host); err != nil { - log.Errorf("failed to walk dir, error: %v", err) + log.Debugf("%s: failed to walk dir, dir: %s, error: %v", host, src, err) return } @@ -164,8 +165,7 @@ func getMode(f fs.FileInfo) string { func checkAllowOverWrite(ftpC *sftp.Client, host, dstFile string) error { f, err := ftpC.Stat(dstFile) if err != nil && !os.IsNotExist(err) { - log.Errorf("%s: failed to stat %s, error: %v", host, dstFile, err) - return err + return fmt.Errorf("%s: failed to stat %s, error: %v", host, dstFile, err) } if f != nil { diff --git a/pkg/log/log.go b/pkg/log/log.go index dc6babb..73bdde6 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -26,6 +26,8 @@ import ( "fmt" "io" "os" + + "github.com/fatih/color" ) // User can directly use package level functions @@ -57,18 +59,20 @@ func Init(logfile string, json, verbose, quiet, condense bool) { if logfile != "" { //nolint:gomnd - file, err := os.OpenFile(logfile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + f, err := os.OpenFile(logfile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) if err != nil { - fmt.Printf("Failed to log to '%s'\n", logfile) + errMsg := color.YellowString(fmt.Sprintf("Warning: Failed to write log to %s: %v\n", logfile, err)) + println(errMsg) + if quiet { std.Out = io.Discard } } else { if !quiet { - mw := io.MultiWriter(os.Stdout, file) + mw := io.MultiWriter(os.Stdout, f) std.Out = mw } else { - std.Out = file + std.Out = f } } } else { diff --git a/pkg/util/file.go b/pkg/util/file.go index ecad1a1..ee7cdd7 100644 --- a/pkg/util/file.go +++ b/pkg/util/file.go @@ -27,7 +27,6 @@ import ( "strings" ) -// FileExists ... func FileExists(path string) bool { path = rebuildPath(path) @@ -39,7 +38,6 @@ func FileExists(path string) bool { return !f.IsDir() } -// DirExists ... func DirExists(path string) bool { path = rebuildPath(path) @@ -51,6 +49,15 @@ func DirExists(path string) bool { return f.IsDir() } +func IsDir(path string) bool { + stat, err := os.Stat(path) + if err != nil { + return false + } + + return stat.IsDir() +} + func rebuildPath(path string) string { homeDir := os.Getenv("HOME") if strings.HasPrefix(path, "~/") {