diff --git a/tool/golang.go b/tool/golang.go index d9fb081..b8cc721 100644 --- a/tool/golang.go +++ b/tool/golang.go @@ -4,6 +4,7 @@ import ( "go/ast" "go/parser" "go/token" + "io/ioutil" "os" "path/filepath" "runtime" @@ -31,8 +32,7 @@ func (b *GoBuilder) Build(rootDir string) error { return err } for i := range paths { - dir := filepath.Dir(paths[i]) - if err = buildDir(b.conf, dir); err != nil { + if err = buildDir(b.conf, paths[i]); err != nil { return err } } @@ -63,7 +63,7 @@ func walkMainDir(rootDir string) (paths []string, err error) { return err } if yes { - paths = append(paths, path) + paths = append(paths, filepath.Dir(path)) } return nil @@ -91,3 +91,51 @@ func hasMain(srcfile string) (bool, error) { } return false, nil } + +func walkPkgDir(rootDir string) (paths []string, err error) { + return paths, filepath.Walk(rootDir, func(path string, info os.FileInfo, e error) error { + if e != nil { + return e + } + + if !info.IsDir() { + return nil + } + + if info.Name() == "vendor" || + (runtime.GOOS != "windows" && strings.HasPrefix(info.Name(), ".")) { + return filepath.SkipDir + } + + yes, err := isGoPkg(path) + if err != nil { + return err + } + if yes { + paths = append(paths, path) + } + return nil + }) +} + +// isGoPkg 判断路径是否是golang的包 +func isGoPkg(path string) (yes bool, err error) { + path = strings.TrimSpace(path) + if path == "" { + return false, nil + } + infos, err := ioutil.ReadDir(path) + if err != nil { + return false, err + } + + for i := range infos { + if infos[i].IsDir() { + continue + } + if ext := filepath.Ext(infos[i].Name()); ext == ".go" { // TODO 排除go test目录 + return true, nil + } + } + return false, nil +} diff --git a/tool/golang_test.go b/tool/golang_test.go index 96179ad..665a5b3 100644 --- a/tool/golang_test.go +++ b/tool/golang_test.go @@ -1,13 +1,14 @@ package tool import ( - "testing" - + "errors" + "fmt" "os" - "path/filepath" "strings" + "testing" + "github.com/bouk/monkey" . "github.com/smartystreets/goconvey/convey" ) @@ -36,11 +37,96 @@ func TestWalkMainDir(t *testing.T) { dir, err := os.Getwd() So(err, ShouldBeNil) So(strings.HasSuffix(dir, filepath.Join("src", "github.com", "voidint", "gbb", "tool")), ShouldBeTrue) - workspace := strings.TrimRight(dir, "tool") + workspace := strings.TrimRight(dir, fmt.Sprintf("%ctool", os.PathSeparator)) paths, err := walkMainDir(workspace) So(err, ShouldBeNil) So(paths, ShouldNotBeEmpty) So(len(paths), ShouldEqual, 1) - So(paths[0], ShouldEqual, filepath.Join(workspace, "main.go")) + So(paths[0], ShouldEqual, workspace) + }) +} + +func TestWalkPkgDir(t *testing.T) { + Convey("查找指定目录及其子目录下所有满足golang包的目录路径", t, func() { + wd, err := os.Getwd() + So(err, ShouldBeNil) + paths, err := walkPkgDir(wd) + So(err, ShouldBeNil) + So(paths, ShouldNotBeEmpty) + So(len(paths), ShouldEqual, 1) + So(paths[0], ShouldEqual, wd) + + paths, err = walkPkgDir(strings.TrimRight(wd, "tool")) + So(err, ShouldBeNil) + So(paths, ShouldNotBeEmpty) + So(len(paths), ShouldEqual, 7) + So(paths, ShouldContain, strings.TrimRight(wd, "tool")) + So(paths, ShouldContain, filepath.Join(strings.TrimRight(wd, "tool"), "build")) + So(paths, ShouldContain, filepath.Join(strings.TrimRight(wd, "tool"), "cmd")) + So(paths, ShouldContain, filepath.Join(strings.TrimRight(wd, "tool"), "config")) + So(paths, ShouldContain, filepath.Join(strings.TrimRight(wd, "tool"), "tool")) + So(paths, ShouldContain, filepath.Join(strings.TrimRight(wd, "tool"), "util")) + So(paths, ShouldContain, filepath.Join(strings.TrimRight(wd, "tool"), "variable")) + + Convey("检查指定路径是否是golang包路径报错", func() { + var ErrIsGoPkg = errors.New("error for test") + monkey.Patch(isGoPkg, func(path string) (yes bool, err error) { + return false, ErrIsGoPkg + }) + defer monkey.Unpatch(isGoPkg) + + paths, err := walkPkgDir(wd) + So(err, ShouldNotBeNil) + So(err, ShouldEqual, ErrIsGoPkg) + So(paths, ShouldBeEmpty) + }) + }) +} + +func TestIsGoPkg(t *testing.T) { + Convey("判断是否是golang包目录", t, func() { + wd, err := os.Getwd() + So(err, ShouldBeNil) + So(wd, ShouldNotBeBlank) + Convey("合法路径", func() { + Convey("路径下包含的全部是go源文件", func() { + yes, err := isGoPkg(wd) + So(err, ShouldBeNil) + So(yes, ShouldBeTrue) + }) + Convey("路径下包含的全部是目录,不包含任何go源文件", func() { + path := filepath.Join(wd, "test") + So(os.MkdirAll(filepath.Join(wd, "test", "subtest0"), 0755), ShouldBeNil) + So(os.MkdirAll(filepath.Join(wd, "test", "subtest1"), 0755), ShouldBeNil) + defer os.RemoveAll(path) + + yes, err := isGoPkg(path) + So(err, ShouldBeNil) + So(yes, ShouldBeFalse) + }) + + Convey("路径下既包含目录,还包含go源文件", func() { + yes, err := isGoPkg(strings.TrimRight(wd, "tool")) + So(err, ShouldBeNil) + So(yes, ShouldBeTrue) + }) + }) + Convey("非法路径", func() { + Convey("路径为空", func() { + yes, err := isGoPkg("") + So(err, ShouldBeNil) + So(yes, ShouldBeFalse) + }) + Convey("路径非目录", func() { + yes, err := isGoPkg(filepath.Join(wd, "golang_test.go")) + So(err, ShouldNotBeNil) + So(yes, ShouldBeFalse) + }) + Convey("路径不存在", func() { + yes, err := isGoPkg(filepath.Join(wd, "not_exist_dir")) + So(err, ShouldNotBeNil) + So(yes, ShouldBeFalse) + }) + }) }) }