From 68826abc0c01a15d7e07e5925a7c8d0a5e9d6571 Mon Sep 17 00:00:00 2001 From: Zino Kader Date: Fri, 24 Feb 2023 02:36:36 +0100 Subject: [PATCH] feat: file overwrite prompt (with configurability) --- go.mod | 1 + go.sum | 2 + internal/config/config.go | 16 +++-- internal/file/file.go | 30 ++++++-- ui/constants.go | 31 +++++++-- ui/filetable/filetable.go | 3 + ui/receiver/receiver.go | 142 +++++++++++++++++++++++++++++++------- ui/sender/sender.go | 2 +- 8 files changed, 185 insertions(+), 42 deletions(-) diff --git a/go.mod b/go.mod index 04ba7af..7cea214 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/charmbracelet/bubbletea v0.23.2 github.com/charmbracelet/lipgloss v0.6.0 github.com/docker/go-connections v0.4.0 + github.com/erikgeiser/promptkit v0.8.0 github.com/fatih/structs v1.1.0 github.com/klauspost/pgzip v1.2.5 github.com/mattn/go-runewidth v0.0.14 diff --git a/go.sum b/go.sum index bf20cb5..909b151 100644 --- a/go.sum +++ b/go.sum @@ -300,6 +300,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/erikgeiser/promptkit v0.8.0 h1:bvOzPs6RLyfRZDSgVWOghQEiBSRHQ3zmDdxcV8zOc+E= +github.com/erikgeiser/promptkit v0.8.0/go.mod h1:QxyFbCrrj20PyvV5b+ckWPozbgX11s04GeRlmTCIMTo= github.com/evanphx/json-patch v4.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= diff --git a/internal/config/config.go b/internal/config/config.go index ccfc705..3be111d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -14,25 +14,27 @@ const CONFIG_FILE_NAME = "config" const CONFIG_FILE_EXT = "yml" type Config struct { - Relay string `mapstructure:"relay"` - Verbose bool `mapstructure:"verbose"` + Relay string `mapstructure:"relay"` + Verbose bool `mapstructure:"verbose"` + PromptOverwriteFiles bool `mapstructure:"prompt_overwrite_files"` } func GetDefault() Config { return Config{ - Relay: "167.71.65.96:80", - Verbose: false, + Relay: "167.71.65.96:80", + Verbose: false, + PromptOverwriteFiles: true, } } func ToMap(config Config) map[string]any { - p := map[string]any{} + m := map[string]any{} for _, field := range structs.Fields(config) { key := field.Tag("mapstructure") value := field.Value() - p[key] = value + m[key] = value } - return p + return m } func ToYaml(config Config) []byte { diff --git a/internal/file/file.go b/internal/file/file.go index 9ced7c7..d800e12 100644 --- a/internal/file/file.go +++ b/internal/file/file.go @@ -15,6 +15,8 @@ import ( const SEND_TEMP_FILE_NAME_PREFIX = "portal-send-temp" const RECEIVE_TEMP_FILE_NAME_PREFIX = "portal-receive-temp" +type OverwriteDecider func(fileName string) (bool, error) + func ReadFiles(fileNames []string) ([]*os.File, error) { var files []*os.File for _, fileName := range fileNames { @@ -27,9 +29,9 @@ func ReadFiles(fileNames []string) ([]*os.File, error) { return files, nil } -// ArchiveAndCompressFiles tars and gzip-compresses files into a temporary file, returning it +// PackFiles tars and gzip-compresses files into a temporary file, returning it // along with the resulting size -func ArchiveAndCompressFiles(files []*os.File) (*os.File, int64, error) { +func PackFiles(files []*os.File) (*os.File, int64, error) { // chained writers -> writing to tw writes to gw -> writes to temporary file tempFile, err := os.CreateTemp(os.TempDir(), SEND_TEMP_FILE_NAME_PREFIX) if err != nil { @@ -60,9 +62,9 @@ func ArchiveAndCompressFiles(files []*os.File) (*os.File, int64, error) { return tempFile, fileInfo.Size(), nil } -// DecompressAndUnarchiveBytes gzip-decompresses and un-tars files into the current working directory +// UnpackFiles gzip-decompresses and un-tars files into the current working directory // and returns the names and decompressed size of the created files -func DecompressAndUnarchiveBytes(reader io.Reader) ([]string, int64, error) { +func UnpackFiles(reader io.Reader, decideOverwrite OverwriteDecider) ([]string, int64, error) { // chained readers -> gr reads from reader -> tr reads from gr gr, err := pgzip.NewReader(reader) if err != nil { @@ -94,24 +96,39 @@ func DecompressAndUnarchiveBytes(reader io.Reader) ([]string, int64, error) { fileTarget := filepath.Join(cwd, header.Name) switch header.Typeflag { + case tar.TypeDir: if _, err := os.Stat(fileTarget); err != nil { if err := os.MkdirAll(fileTarget, 0755); err != nil { return nil, 0, err } } + case tar.TypeReg: + if fileExists(fileTarget) { + shouldOverwrite, err := decideOverwrite(fileTarget) + if err != nil { + return nil, 0, err + } + if !shouldOverwrite { + continue + } + } + f, err := os.OpenFile(fileTarget, os.O_CREATE|os.O_RDWR, os.FileMode(header.Mode)) if err != nil { return nil, 0, err } + if _, err := io.Copy(f, tr); err != nil { return nil, 0, err } + fileInfo, err := f.Stat() if err != nil { return nil, 0, err } + decompressedSize += fileInfo.Size() createdFiles = append(createdFiles, header.Name) f.Close() @@ -195,6 +212,11 @@ func addToTarArchive(tw *tar.Writer, file *os.File) error { }) } +func fileExists(filename string) bool { + _, err := os.Stat(filename) + return !os.IsNotExist(err) +} + // optimistically remove files created by portal with the specified prefix func RemoveTemporaryFiles(prefix string) { tempFiles, err := os.ReadDir(os.TempDir()) diff --git a/ui/constants.go b/ui/constants.go index 3ed634e..608f1b1 100644 --- a/ui/constants.go +++ b/ui/constants.go @@ -26,10 +26,12 @@ const ( ) type KeyMap struct { - Quit key.Binding - CopyPassword key.Binding - FileListUp key.Binding - FileListDown key.Binding + Quit key.Binding + CopyPassword key.Binding + FileListUp key.Binding + FileListDown key.Binding + OverwritePromptYes key.Binding + OverwritePromptNo key.Binding } func (k KeyMap) ShortHelp() []key.Binding { @@ -38,12 +40,21 @@ func (k KeyMap) ShortHelp() []key.Binding { k.CopyPassword, k.FileListUp, k.FileListDown, + k.OverwritePromptYes, + k.OverwritePromptNo, } } func (k KeyMap) FullHelp() [][]key.Binding { return [][]key.Binding{ - {k.Quit, k.CopyPassword, k.FileListUp, k.FileListDown}, + { + k.Quit, + k.CopyPassword, + k.FileListUp, + k.FileListDown, + k.OverwritePromptYes, + k.OverwritePromptNo, + }, } } @@ -73,6 +84,16 @@ var Keys = KeyMap{ key.WithHelp("(↓/j)", "file summary down"), key.WithDisabled(), ), + OverwritePromptYes: key.NewBinding( + key.WithKeys("y", "Y"), + key.WithHelp("(Y/y)", "confirm overwrite"), + key.WithDisabled(), + ), + OverwritePromptNo: key.NewBinding( + key.WithKeys("n", "N"), + key.WithHelp("(N/n)", "deny overwrite"), + key.WithDisabled(), + ), } var PadText = strings.Repeat(" ", MARGIN) diff --git a/ui/filetable/filetable.go b/ui/filetable/filetable.go index d23f701..82f3d3a 100644 --- a/ui/filetable/filetable.go +++ b/ui/filetable/filetable.go @@ -162,5 +162,8 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } func (m Model) View() string { + if len(m.rows) == 0 { + return "" + } return fileTableStyle.Render(m.table.View()) + "\n\n" } diff --git a/ui/receiver/receiver.go b/ui/receiver/receiver.go index 064c386..e9750ff 100644 --- a/ui/receiver/receiver.go +++ b/ui/receiver/receiver.go @@ -21,6 +21,9 @@ import ( "github.com/charmbracelet/bubbles/spinner" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/erikgeiser/promptkit" + "github.com/erikgeiser/promptkit/confirmation" + "github.com/spf13/viper" ) // ------------------------------------------------------ Ui State ----------------------------------------------------- @@ -31,6 +34,7 @@ const ( showEstablishing uiState = iota showReceivingProgress showDecompressing + showOverwritePrompt showFinished ) @@ -47,8 +51,16 @@ type receiveDoneMsg struct { temp *os.File } +type overwritePromptRequestMsg struct { + fileName string +} + +type overwritePromptResponseMsg struct { + shouldOverwrite bool +} + type decompressionDoneMsg struct { - filenames []string + fileNames []string decompressedPayloadSize int64 } @@ -67,8 +79,10 @@ type model struct { transferType transfer.Type password string - ctx context.Context - msgs chan interface{} + ctx context.Context + msgs chan interface{} + overwritePromptRequests chan overwritePromptRequestMsg + overwritePromptResponses chan overwritePromptResponseMsg rendezvousAddr string @@ -81,6 +95,7 @@ type model struct { spinner spinner.Model transferProgress transferprogress.Model fileTable filetable.Model + overwritePrompt confirmation.Model help help.Model keys ui.KeyMap } @@ -88,14 +103,16 @@ type model struct { // New creates a new receiver program. func New(addr string, password string, opts ...Option) *tea.Program { m := model{ - transferProgress: transferprogress.New(), - msgs: make(chan interface{}, 10), - fileTable: filetable.New(), - password: password, - rendezvousAddr: addr, - help: help.New(), - keys: ui.Keys, - ctx: context.Background(), + transferProgress: transferprogress.New(), + msgs: make(chan interface{}, 10), + overwritePromptRequests: make(chan overwritePromptRequestMsg), + overwritePromptResponses: make(chan overwritePromptResponseMsg), + fileTable: filetable.New(), + password: password, + rendezvousAddr: addr, + help: help.New(), + keys: ui.Keys, + ctx: context.Background(), } for _, opt := range opts { opt(&m) @@ -173,6 +190,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case receiveDoneMsg: m.state = showDecompressing m.resetSpinner() + message := fmt.Sprintf("Transfer completed in %s with average transfer speed %s/s", time.Since(m.transferProgress.TransferStartTime).Round(time.Millisecond).String(), ui.ByteCountSI(m.transferProgress.TransferSpeedEstimateBps), @@ -180,11 +198,37 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.fileTable.SetMaxHeight(math.MaxInt) m.fileTable = m.fileTable.Finalize().(filetable.Model) - return m, ui.TaskCmd(message, tea.Batch(m.spinner.Tick, decompressCmd(msg.temp))) + + cmds := []tea.Cmd{m.spinner.Tick, + m.listenOverwritePromptRequestsCmd(), + m.decompressCmd(msg.temp), + } + + return m, ui.TaskCmd(message, tea.Batch(cmds...)) + + case overwritePromptRequestMsg: + var cmds []tea.Cmd + m.state = showOverwritePrompt + m.keys.OverwritePromptYes.SetEnabled(true) + m.keys.OverwritePromptNo.SetEnabled(true) + m.resetSpinner() + cmds = append(cmds, m.spinner.Tick) + + prompt := confirmation.New(fmt.Sprintf("Overwrite file '%s'?", msg.fileName), confirmation.Yes) + m.overwritePrompt = *confirmation.NewModel(prompt) + m.overwritePrompt.MaxWidth = m.width + m.overwritePrompt.WrapMode = promptkit.HardWrap + m.overwritePrompt.Template = confirmation.TemplateYN + m.overwritePrompt.ResultTemplate = confirmation.ResultTemplateYN + m.overwritePrompt.KeyMap.Abort = []string{} + m.overwritePrompt.KeyMap.Toggle = []string{} + cmds = append(cmds, m.overwritePrompt.Init()) + + return m, tea.Batch(cmds...) case decompressionDoneMsg: m.state = showFinished - m.receivedFiles = msg.filenames + m.receivedFiles = msg.fileNames m.decompressedPayloadSize = msg.decompressedPayloadSize m.fileTable.SetFiles(m.receivedFiles) @@ -194,6 +238,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, ui.ErrorCmd(errors.New(msg.Error())) case tea.KeyMsg: + var cmds []tea.Cmd switch { case key.Matches(msg, m.keys.Quit): return m, tea.Quit @@ -201,21 +246,43 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { fileTableModel, fileTableCmd := m.fileTable.Update(msg) m.fileTable = fileTableModel.(filetable.Model) + cmds = append(cmds, fileTableCmd) + + _, promptCmd := m.overwritePrompt.Update(msg) + if m.state == showOverwritePrompt { + switch msg.String() { + case "left", "right": + cmds = append(cmds, promptCmd) + case "y", "Y", "n", "N", "enter": + m.state = showDecompressing + m.keys.OverwritePromptYes.SetEnabled(false) + m.keys.OverwritePromptNo.SetEnabled(false) + shouldOverwrite, _ := m.overwritePrompt.Value() + m.overwritePromptResponses <- overwritePromptResponseMsg{shouldOverwrite} + cmds = append(cmds, m.listenOverwritePromptRequestsCmd()) + } + } - return m, fileTableCmd + return m, tea.Batch(cmds...) case tea.WindowSizeMsg: m.width = msg.Width transferProgressModel, transferProgressCmd := m.transferProgress.Update(msg) m.transferProgress = transferProgressModel.(transferprogress.Model) + fileTableModel, fileTableCmd := m.fileTable.Update(msg) m.fileTable = fileTableModel.(filetable.Model) - return m, tea.Batch(transferProgressCmd, fileTableCmd) + + m.overwritePrompt.MaxWidth = msg.Width - 2*ui.MARGIN - 4 + _, promptCmd := m.overwritePrompt.Update(msg) + + return m, tea.Batch(transferProgressCmd, fileTableCmd, promptCmd) default: - var cmd tea.Cmd - m.spinner, cmd = m.spinner.Update(msg) - return m, cmd + var spinnerCmd tea.Cmd + m.spinner, spinnerCmd = m.spinner.Update(msg) + _, promptCmd := m.overwritePrompt.Update(msg) + return m, tea.Batch(spinnerCmd, promptCmd) } } @@ -243,6 +310,14 @@ func (m model) View() string { ui.PadText + m.transferProgress.View() + "\n\n" + ui.PadText + m.help.View(m.keys) + "\n\n" + case showOverwritePrompt: + waitingText := fmt.Sprintf("%s Waiting for file overwrite confirmation", m.spinner.View()) + return ui.PadText + ui.LogSeparator(m.width) + + ui.PadText + ui.InfoStyle(waitingText) + "\n\n" + + ui.PadText + m.transferProgress.View() + "\n\n" + + ui.PadText + m.overwritePrompt.View() + "\n\n" + + ui.PadText + m.help.View(m.keys) + "\n\n" + case showDecompressing: payloadSize := ui.BoldText(ui.ByteCountSI(m.payloadSize)) decompressingText := fmt.Sprintf("%s Decompressing payload (%s compressed) and writing to disk", m.spinner.View(), payloadSize) @@ -253,10 +328,10 @@ func (m model) View() string { case showFinished: oneOrMoreFiles := "object" - if len(m.receivedFiles) > 1 { + if len(m.receivedFiles) == 0 || len(m.receivedFiles) > 1 { oneOrMoreFiles += "s" } - finishedText := fmt.Sprintf("Received %d %s (%s compressed)", len(m.receivedFiles), oneOrMoreFiles, ui.ByteCountSI(m.payloadSize)) + finishedText := fmt.Sprintf("Received %d %s (%s decompressed)", len(m.receivedFiles), oneOrMoreFiles, ui.ByteCountSI(m.decompressedPayloadSize)) return ui.PadText + ui.LogSeparator(m.width) + ui.PadText + ui.InfoStyle(finishedText) + "\n\n" + ui.PadText + m.transferProgress.View() + "\n\n" + @@ -320,19 +395,36 @@ func listenReceiveCmd(msgs chan interface{}) tea.Cmd { } } -func decompressCmd(temp *os.File) tea.Cmd { +func (m *model) listenOverwritePromptRequestsCmd() tea.Cmd { return func() tea.Msg { - // reset file position for reading + return <-m.overwritePromptRequests + } +} + +func (m *model) decompressCmd(temp *os.File) tea.Cmd { + return func() tea.Msg { + // Reset file position for reading. _, err := temp.Seek(0, 0) if err != nil { return ui.ErrorMsg(err) } - filenames, decompressedSize, err := file.DecompressAndUnarchiveBytes(temp) + // promptFunc is a no-op if we allow overwriting files without prompts. + promptFunc := func(fileName string) (bool, error) { return true, nil } + if viper.GetBool("prompt_overwrite_files") { + promptFunc = func(fileName string) (bool, error) { + m.overwritePromptRequests <- overwritePromptRequestMsg{fileName} + overwritePromptResponse := <-m.overwritePromptResponses + return overwritePromptResponse.shouldOverwrite, nil + } + } + + fileNames, size, err := file.UnpackFiles(temp, promptFunc) if err != nil { return ui.ErrorMsg(err) } - return decompressionDoneMsg{filenames: filenames, decompressedPayloadSize: decompressedSize} + + return decompressionDoneMsg{fileNames, size} } } @@ -341,7 +433,7 @@ func decompressCmd(temp *os.File) tea.Cmd { func (m *model) resetSpinner() { m.spinner = spinner.New() m.spinner.Style = lipgloss.NewStyle().Foreground(lipgloss.Color(ui.ELEMENT_COLOR)) - if m.state == showEstablishing { + if m.state == showEstablishing || m.state == showOverwritePrompt { m.spinner.Spinner = ui.WaitingSpinner } if m.state == showDecompressing { diff --git a/ui/sender/sender.go b/ui/sender/sender.go index 2e8dcb7..8a9668d 100644 --- a/ui/sender/sender.go +++ b/ui/sender/sender.go @@ -405,7 +405,7 @@ func readFilesCmd(paths []string) tea.Cmd { // provided files. func compressFilesCmd(files []*os.File) tea.Cmd { return func() tea.Msg { - tar, size, err := file.ArchiveAndCompressFiles(files) + tar, size, err := file.PackFiles(files) if err != nil { return ui.ErrorMsg(err) }