pkg/archive/Write: implement writing of archive with verification

This adds several steps to the "write" action, by writing it to a temp
location, extracting it, checksumming it, copying to destination,
verifying checksum, and using Rename to atomically replace any existing
file in the destination.
This commit is contained in:
Clayton Craft
2021-08-10 13:23:56 -07:00
parent 1716445e9d
commit f0bf13c9f2

View File

@@ -4,6 +4,8 @@ package archive
import ( import (
"bytes" "bytes"
"errors"
"fmt"
"io" "io"
"log" "log"
"os" "os"
@@ -38,6 +40,16 @@ func New() (*Archive, error) {
} }
func (archive *Archive) Write(path string, mode os.FileMode) error { func (archive *Archive) Write(path string, mode os.FileMode) error {
// Archive verification is done in these steps:
// 1. write archive to a temp location
// 2. checksum the temp archive
// 3. compare size of archive with amount of free space in target dir
// 4. extract the archive to make sure it's valid / can be extracted
// 5. copy archive to destination dir
// 6. checksum target, compare to temp file checksum
// 7. rename archive at destination to final target name
targetDir := filepath.Dir(path)
if err := archive.writeCpio(); err != nil { if err := archive.writeCpio(); err != nil {
return err return err
@@ -47,7 +59,93 @@ func (archive *Archive) Write(path string, mode os.FileMode) error {
return err return err
} }
if err := archive.writeCompressed(path, mode); err != nil { // 1. write archive to a temp location
tmpOutDir, err := ioutil.TempDir("", filepath.Base(path))
if err != nil {
log.Print("Unable to create temporary work dir")
return err
}
tmpOutFile := filepath.Join(tmpOutDir, filepath.Base(path))
if err := archive.writeCompressed(tmpOutFile, mode); err != nil {
return err
}
defer os.Remove(tmpOutFile)
// 2. checksum the temp archive
tmpFileChecksum, err := checksum(tmpOutFile)
if err != nil {
return err
}
// 3. compare size of archive with amount of free space in target dir
tmpOutFileSize, err := os.Stat(tmpOutFile)
if err != nil {
log.Print("Unable to stat tmp output file: ", tmpOutFile)
}
actualFreeSpace, err := misc.FreeSpace(targetDir)
// leave 10% free at target, because we're not monsters
freeSpace := int64(float64(actualFreeSpace) * 0.9)
if err != nil {
log.Print("Unable to verify free space of target directory: ", targetDir)
return err
}
if tmpOutFileSize.Size() >= freeSpace {
return errors.New(fmt.Sprintf("Not enough free space in target dir (%q) for file. Need: %d bytes, free space: %d bytes",
targetDir, tmpOutFileSize.Size(), freeSpace))
}
// 4. extract the archive to make sure it's valid / can be extracted
extractDir, err := ioutil.TempDir(tmpOutDir, "extract-test")
if err != nil {
return err
}
defer os.RemoveAll(extractDir)
if err := extract(tmpOutFile, extractDir); err != nil {
log.Print("Extraction of archive failed!")
return err
}
// 5. copy archive to destination dir
tmpTargetFileFd, err := ioutil.TempFile(targetDir, filepath.Base(path))
if err != nil {
log.Print("Unable to create temp file in target dir: ", targetDir)
return err
}
tmpOutFileFd, err := os.Open(tmpOutFile)
if err != nil {
log.Print("Unable to open temp file", tmpOutFile)
return err
}
defer tmpOutFileFd.Close()
if _, err := io.Copy(tmpTargetFileFd, tmpOutFileFd); err != nil {
return err
}
// fsync
if err := tmpTargetFileFd.Sync(); err != nil {
log.Print("Unable to call fsync on temp file: ", targetDir)
return err
}
if err := tmpTargetFileFd.Close(); err != nil {
log.Print("Unable to save temp file to target dir: ", targetDir)
return err
}
// 6. checksum target, compare to temp file checksum
targetFileChecksum, err := checksum(tmpTargetFileFd.Name())
if tmpFileChecksum != targetFileChecksum {
return errors.New(fmt.Sprintf("Unable to save archive to path %q, checksum mismatch (expected: %q, got: %q)",
path, tmpFileChecksum, targetFileChecksum))
}
// 7. rename archive at destination to final target name
if err := os.Rename(tmpTargetFileFd.Name(), path); err != nil {
log.Print("Unable to save archive to path: ", path)
return err
}
if err := os.Chmod(path, mode); err != nil {
return err return err
} }