diff --git a/memmap.go b/memmap.go index e6b7d70b..d6c744e8 100644 --- a/memmap.go +++ b/memmap.go @@ -16,9 +16,12 @@ package afero import ( "fmt" "io" + "log" "os" "path/filepath" + + "sort" "strings" "sync" "time" @@ -88,6 +91,24 @@ func (m *MemMapFs) findParent(f *mem.FileData) *mem.FileData { return pfile } +func (m *MemMapFs) findDescendants(name string) []*mem.FileData { + fData := m.getData() + descendants := make([]*mem.FileData, 0, len(fData)) + for p, dFile := range fData { + if strings.HasPrefix(p, name+FilePathSeparator) { + descendants = append(descendants, dFile) + } + } + + sort.Slice(descendants, func(i, j int) bool { + cur := len(strings.Split(descendants[i].Name(), FilePathSeparator)) + next := len(strings.Split(descendants[j].Name(), FilePathSeparator)) + return cur < next + }) + + return descendants +} + func (m *MemMapFs) registerWithParent(f *mem.FileData, perm os.FileMode) { if f == nil { return @@ -309,29 +330,51 @@ func (m *MemMapFs) Rename(oldname, newname string) error { if _, ok := m.getData()[oldname]; ok { m.mu.RUnlock() m.mu.Lock() - m.unRegisterWithParent(oldname) + err := m.unRegisterWithParent(oldname) + if err != nil { + return err + } + fileData := m.getData()[oldname] - delete(m.getData(), oldname) mem.ChangeFileName(fileData, newname) m.getData()[newname] = fileData + + err = m.renameDescendants(oldname, newname) + if err != nil { + return err + } + + delete(m.getData(), oldname) + m.registerWithParent(fileData, 0) m.mu.Unlock() m.mu.RLock() } else { return &os.PathError{Op: "rename", Path: oldname, Err: ErrFileNotFound} } + return nil +} - for p, fileData := range m.getData() { - if strings.HasPrefix(p, oldname+FilePathSeparator) { - m.mu.RUnlock() - m.mu.Lock() - delete(m.getData(), p) - p := strings.Replace(p, oldname, newname, 1) - m.getData()[p] = fileData - m.mu.Unlock() - m.mu.RLock() +func (m *MemMapFs) renameDescendants(oldname, newname string) error { + descendants := m.findDescendants(oldname) + removes := make([]string, 0, len(descendants)) + for _, desc := range descendants { + descNewName := strings.Replace(desc.Name(), oldname, newname, 1) + err := m.unRegisterWithParent(desc.Name()) + if err != nil { + return err } + + removes = append(removes, desc.Name()) + mem.ChangeFileName(desc, descNewName) + m.getData()[descNewName] = desc + + m.registerWithParent(desc, 0) + } + for _, r := range removes { + delete(m.getData(), r) } + return nil } diff --git a/memmap_test.go b/memmap_test.go index 52a492e8..c47fadc8 100644 --- a/memmap_test.go +++ b/memmap_test.go @@ -833,3 +833,88 @@ func TestMemFsRenameDir(t *testing.T) { t.Errorf("Cannot recreate the subdir in the source dir: %s", err) } } + +func TestMemMapFsRename(t *testing.T) { + t.Parallel() + + fs := &MemMapFs{} + tDir := testDir(fs) + rFrom := "/renamefrom" + rTo := "/renameto" + rExists := "/renameexists" + + type test struct { + dirs []string + from string + to string + exists string + } + + parts := strings.Split(tDir, "/") + root := "/" + if len(parts) > 1 { + root = filepath.Join("/", parts[1]) + } + + testData := make([]test, 0, len(parts)) + + i := len(parts) + for i > 0 { + prefix := strings.Join(parts[:i], "/") + suffix := strings.Join(parts[i:], "/") + testData = append(testData, test{ + dirs: []string{ + filepath.Join(prefix, rFrom, suffix), + filepath.Join(prefix, rExists, suffix), + }, + from: filepath.Join(prefix, rFrom), + to: filepath.Join(prefix, rTo), + exists: filepath.Join(prefix, rExists), + }) + i-- + } + + for _, data := range testData { + err := fs.RemoveAll(root) + if err != nil { + t.Fatalf("%s: RemoveAll %q failed: %v", fs.Name(), root, err) + } + + for _, dir := range data.dirs { + err = fs.MkdirAll(dir, os.FileMode(0775)) + if err != nil { + t.Fatalf("%s: MkdirAll %q failed: %v", fs.Name(), dir, err) + } + } + + dataCnt := len(fs.getData()) + err = fs.Rename(data.from, data.to) + if err != nil { + t.Fatalf("%s: rename %q, %q failed: %v", fs.Name(), data.from, data.to, err) + } + err = fs.Mkdir(data.from, os.FileMode(0775)) + if err != nil { + t.Fatalf("%s: Mkdir %q failed: %v", fs.Name(), data.from, err) + } + + err = fs.Rename(data.from, data.exists) + if err != nil { + t.Errorf("%s: rename %q, %q failed: %v", fs.Name(), data.from, data.exists, err) + } + + for p := range fs.getData() { + if strings.Contains(p, data.from) { + t.Errorf("File was not renamed to renameto: %v", p) + } + } + + _, err = fs.Stat(data.to) + if err != nil { + t.Errorf("%s: stat %q failed: %v", fs.Name(), data.to, err) + } + + if dataCnt != len(fs.getData()) { + t.Errorf("invalid data len: expected %v, get %v", dataCnt, len(fs.getData())) + } + } +}