diff options
Diffstat (limited to 'cmd/ponzu/options.go')
-rw-r--r-- | cmd/ponzu/options.go | 221 |
1 files changed, 206 insertions, 15 deletions
diff --git a/cmd/ponzu/options.go b/cmd/ponzu/options.go index 4d102c5..d71d362 100644 --- a/cmd/ponzu/options.go +++ b/cmd/ponzu/options.go @@ -4,10 +4,12 @@ import ( "errors" "fmt" "io" + "io/ioutil" "os" "os/exec" "path/filepath" "strings" + "time" ) func newProjectInDir(path string) error { @@ -19,18 +21,11 @@ func newProjectInDir(path string) error { if _, err := os.Stat(path); !os.IsNotExist(err) { fmt.Println("Path exists, overwrite contents? (y/N):") - var answer string - _, err := fmt.Scanf("%s\n", &answer) + answer, err := getAnswer() if err != nil { - if err.Error() == "unexpected newline" { - answer = "" - } else { - return err - } + return err } - answer = strings.ToLower(answer) - switch answer { case "n", "no", "\r\n", "\n", "": fmt.Println("") @@ -41,7 +36,7 @@ func newProjectInDir(path string) error { return fmt.Errorf("Failed to overwrite %s. \n%s", path, err) } - return createProjInDir(path) + return createProjectInDir(path) default: fmt.Println("Input not recognized. No files overwritten. Answer as 'y' or 'n' only.") @@ -50,16 +45,35 @@ func newProjectInDir(path string) error { return nil } - return createProjInDir(path) + return createProjectInDir(path) } var ponzuRepo = []string{"github.com", "ponzu-cms", "ponzu"} -func createProjInDir(path string) error { +func getAnswer() (string, error) { + var answer string + _, err := fmt.Scanf("%s\n", &answer) + if err != nil { + if err.Error() == "unexpected newline" { + answer = "" + } else { + return "", err + } + } + + answer = strings.ToLower(answer) + + return answer, nil +} + +func createProjectInDir(path string) error { gopath := os.Getenv("GOPATH") repo := ponzuRepo local := filepath.Join(gopath, "src", filepath.Join(repo...)) network := "https://" + strings.Join(repo, "/") + ".git" + if !strings.HasPrefix(path, gopath) { + path = filepath.Join(gopath, path) + } // create the directory or overwrite it err := os.MkdirAll(path, os.ModeDir|os.ModePerm) @@ -106,7 +120,7 @@ func createProjInDir(path string) error { } err = localClone.Wait() if err != nil { - fmt.Println("Couldn't clone from", local, ". Trying network...") + fmt.Println("Couldn't clone from", local, "- trying network...") // try to git clone the repository over the network networkClone := exec.Command("git", "clone", network, path) @@ -168,7 +182,7 @@ func vendorCorePackages(path string) error { return nil } -func copyFile(src, dst string) error { +func copyFileNoRoot(src, dst string) error { noRoot := strings.Split(src, string(filepath.Separator))[1:] path := filepath.Join(noRoot...) dstFile, err := os.Create(filepath.Join(dst, path)) @@ -226,7 +240,7 @@ func copyFilesWarnConflicts(srcDir, dstDir string, conflicts []string) error { return nil } - err = copyFile(path, dstDir) + err = copyFileNoRoot(path, dstDir) if err != nil { return err } @@ -313,3 +327,180 @@ func buildPonzuServer(args []string) error { return nil } + +func copyAll(src, dst string) error { + err := filepath.Walk(src, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + sep := string(filepath.Separator) + + // base == the ponzu project dir + string(filepath.Separator) + parts := strings.Split(src, sep) + base := strings.Join(parts[:len(parts)-1], sep) + base += sep + + target := filepath.Join(dst, path[len(base):]) + + // if its a directory, make dir in dst + if info.IsDir() { + err := os.MkdirAll(target, os.ModeDir|os.ModePerm) + if err != nil { + return err + } + } else { + // if its a file, move file to dir of dst + err = os.Rename(path, target) + if err != nil { + return err + } + } + + return nil + }) + if err != nil { + return err + } + + return nil +} + +func upgradePonzuProjectDir(path string) error { + core := []string{ + ".gitattributes", + "LICENSE", + "ponzu-banner.png", + "README.md", + "cmd", + "deployment", + "management", + "system", + } + + stamp := fmt.Sprintf("ponzu-%d.bak", time.Now().Unix()) + temp := filepath.Join(os.TempDir(), stamp) + err := os.Mkdir(temp, os.ModeDir|os.ModePerm) + if err != nil { + return err + } + + // track non-Ponzu core items (added by user) + var user []os.FileInfo + list, err := ioutil.ReadDir(path) + if err != nil { + return err + } + + for _, item := range list { + // check if in core + var isCore bool + for _, name := range core { + if item.Name() == name { + isCore = true + break + } + } + + if !isCore { + user = append(user, item) + } + } + + // move non-Ponzu files to temp location + fmt.Println("Preserving files to be restored after upgrade...") + for _, item := range user { + src := filepath.Join(path, item.Name()) + if item.IsDir() { + err := os.Mkdir(filepath.Join(temp, item.Name()), os.ModeDir|os.ModePerm) + if err != nil { + return err + } + } + + err := copyAll(src, temp) + if err != nil { + return err + } + + fmt.Println(" [-]", item.Name()) + + } + + // remove all files in path + for _, item := range list { + err := os.RemoveAll(filepath.Join(path, item.Name())) + if err != nil { + return fmt.Errorf("Failed to remove old Ponzu files.\n%s", err) + } + } + + err = os.Chdir(temp) + if err != nil { + fmt.Println("Coudln't change directory to:", temp) + } + + // re-create the project dir + err = os.MkdirAll(path, os.ModeDir|os.ModePerm) + if err != nil { + return err + } + + err = os.Chdir(path) + if err != nil { + return err + } + + err = createProjectInDir(path) + if err != nil { + fmt.Println("") + fmt.Println("Upgrade failed...") + fmt.Println("Your code is backed up at the following location:") + fmt.Println(temp) + fmt.Println("") + fmt.Println("Manually create a new Ponzu project here and copy those files within it to fully restore.") + fmt.Println("") + return err + } + + // move non-Ponzu files from temp location backed + restore, err := ioutil.ReadDir(temp) + if err != nil { + return err + } + + fmt.Println("Restoring files preserved before upgrade...") + for _, r := range restore { + p := filepath.Join(temp, r.Name()) + err = copyAll(p, path) + if err != nil { + fmt.Println("Couldn't merge your previous project files with upgraded one.") + fmt.Println("Manually copy your files from the following directory:") + fmt.Println(temp) + return err + } + + fmt.Println(" [+]", r.Name()) + } + + // refresh dir to show files after restoring + err = os.Chdir(temp) + if err != nil { + return err + } + err = os.Chdir(path) + if err != nil { + return err + } + + // clean-up + backups := []string{filepath.Join(path, stamp), temp} + for _, bak := range backups { + err := os.RemoveAll(bak) + if err != nil { + return err + } + } + + return nil +} |