diff --git a/pkgs/archive/archive.go b/pkgs/archive/archive.go index 126f33a..30bc818 100644 --- a/pkgs/archive/archive.go +++ b/pkgs/archive/archive.go @@ -4,6 +4,8 @@ package archive import ( "bytes" + "errors" + "fmt" "io" "log" "os" @@ -38,6 +40,16 @@ func New() (*Archive, error) { } func (archive *Archive) Write(path string, mode os.FileMode) error { + // Archive verification is done in these steps: + // 1. write archive to a temp location + // 2. checksum the temp archive + // 3. compare size of archive with amount of free space in target dir + // 4. extract the archive to make sure it's valid / can be extracted + // 5. copy archive to destination dir + // 6. checksum target, compare to temp file checksum + // 7. rename archive at destination to final target name + + targetDir := filepath.Dir(path) if err := archive.writeCpio(); err != nil { return err @@ -47,7 +59,93 @@ func (archive *Archive) Write(path string, mode os.FileMode) error { return err } - if err := archive.writeCompressed(path, mode); err != nil { + // 1. write archive to a temp location + tmpOutDir, err := ioutil.TempDir("", filepath.Base(path)) + if err != nil { + log.Print("Unable to create temporary work dir") + return err + } + + tmpOutFile := filepath.Join(tmpOutDir, filepath.Base(path)) + + if err := archive.writeCompressed(tmpOutFile, mode); err != nil { + return err + } + defer os.Remove(tmpOutFile) + + // 2. checksum the temp archive + tmpFileChecksum, err := checksum(tmpOutFile) + if err != nil { + return err + } + + // 3. compare size of archive with amount of free space in target dir + tmpOutFileSize, err := os.Stat(tmpOutFile) + if err != nil { + log.Print("Unable to stat tmp output file: ", tmpOutFile) + } + actualFreeSpace, err := misc.FreeSpace(targetDir) + // leave 10% free at target, because we're not monsters + freeSpace := int64(float64(actualFreeSpace) * 0.9) + if err != nil { + log.Print("Unable to verify free space of target directory: ", targetDir) + return err + } + if tmpOutFileSize.Size() >= freeSpace { + return errors.New(fmt.Sprintf("Not enough free space in target dir (%q) for file. Need: %d bytes, free space: %d bytes", + targetDir, tmpOutFileSize.Size(), freeSpace)) + } + + // 4. extract the archive to make sure it's valid / can be extracted + extractDir, err := ioutil.TempDir(tmpOutDir, "extract-test") + if err != nil { + return err + } + defer os.RemoveAll(extractDir) + if err := extract(tmpOutFile, extractDir); err != nil { + log.Print("Extraction of archive failed!") + return err + } + + // 5. copy archive to destination dir + tmpTargetFileFd, err := ioutil.TempFile(targetDir, filepath.Base(path)) + if err != nil { + log.Print("Unable to create temp file in target dir: ", targetDir) + return err + } + tmpOutFileFd, err := os.Open(tmpOutFile) + if err != nil { + log.Print("Unable to open temp file", tmpOutFile) + return err + } + defer tmpOutFileFd.Close() + if _, err := io.Copy(tmpTargetFileFd, tmpOutFileFd); err != nil { + return err + } + // fsync + if err := tmpTargetFileFd.Sync(); err != nil { + log.Print("Unable to call fsync on temp file: ", targetDir) + return err + } + if err := tmpTargetFileFd.Close(); err != nil { + log.Print("Unable to save temp file to target dir: ", targetDir) + return err + } + + // 6. checksum target, compare to temp file checksum + targetFileChecksum, err := checksum(tmpTargetFileFd.Name()) + + if tmpFileChecksum != targetFileChecksum { + return errors.New(fmt.Sprintf("Unable to save archive to path %q, checksum mismatch (expected: %q, got: %q)", + path, tmpFileChecksum, targetFileChecksum)) + } + + // 7. rename archive at destination to final target name + if err := os.Rename(tmpTargetFileFd.Name(), path); err != nil { + log.Print("Unable to save archive to path: ", path) + return err + } + if err := os.Chmod(path, mode); err != nil { return err }