This adds several steps to the "write" action: the archive is written to a temporary location, checksummed, test-extracted to confirm it is valid, copied to the destination directory, checksummed again to verify the copy, and finally moved into place with Rename so that any existing file at the destination is replaced atomically.
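For context, a minimal sketch of how a caller might drive the exported API below (the busybox path and the output path are hypothetical examples, not taken from mkinitfs itself):

	initfs, err := archive.New()
	if err != nil {
		log.Fatal(err)
	}
	// Write a file into the cpio archive; missing parent dirs are added too.
	if err := initfs.AddFile("/bin/busybox", "/bin/busybox"); err != nil {
		log.Fatal(err)
	}
	// Compress, verify, and atomically install the archive at the target path.
	if err := initfs.Write("/boot/initramfs", 0644); err != nil {
		log.Fatal(err)
	}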
// Copyright 2021 Clayton Craft <clayton@craftyguy.net>
// SPDX-License-Identifier: GPL-3.0-or-later
package archive

import (
	"bytes"
	"compress/flate"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"os"
	"path/filepath"
	"strings"

	"github.com/cavaliercoder/go-cpio"
	"github.com/klauspost/pgzip"
	"gitlab.com/postmarketOS/mkinitfs/pkgs/misc"
)
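
// Archive is a cpio archive being assembled in memory: Dirs and Files track
// which directories and files have been queued or written, cpioWriter writes
// entries into buf, and buf holds the uncompressed cpio data until it is
// compressed and written out.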
type Archive struct {
	Dirs       misc.StringSet
	Files      misc.StringSet
	cpioWriter *cpio.Writer
	buf        *bytes.Buffer
}
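
// New returns an empty Archive with an initialized cpio writer and empty
// file/directory sets.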
func New() (*Archive, error) {
	buf := new(bytes.Buffer)
	archive := &Archive{
		cpioWriter: cpio.NewWriter(buf),
		Files:      make(misc.StringSet),
		Dirs:       make(misc.StringSet),
		buf:        buf,
	}

	return archive, nil
}
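
// Write finalizes the cpio archive, compresses it, and installs it at path
// with the given mode. The archive is staged and verified in a temporary
// location first, then copied into the target directory and atomically
// renamed into place.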
func (archive *Archive) Write(path string, mode os.FileMode) error {
	// Archive verification is done in these steps:
	// 1. write archive to a temp location
	// 2. checksum the temp archive
	// 3. compare size of archive with amount of free space in target dir
	// 4. extract the archive to make sure it's valid / can be extracted
	// 5. copy archive to destination dir
	// 6. checksum target, compare to temp file checksum
	// 7. rename archive at destination to final target name

	targetDir := filepath.Dir(path)

	if err := archive.writeCpio(); err != nil {
		return err
	}

	if err := archive.cpioWriter.Close(); err != nil {
		return err
	}

	// 1. write archive to a temp location
	tmpOutDir, err := ioutil.TempDir("", filepath.Base(path))
	if err != nil {
		log.Print("Unable to create temporary work dir")
		return err
	}

	tmpOutFile := filepath.Join(tmpOutDir, filepath.Base(path))

	if err := archive.writeCompressed(tmpOutFile, mode); err != nil {
		return err
	}
	defer os.Remove(tmpOutFile)

	// 2. checksum the temp archive
	tmpFileChecksum, err := checksum(tmpOutFile)
	if err != nil {
		return err
	}

	// 3. compare size of archive with amount of free space in target dir
	tmpOutFileSize, err := os.Stat(tmpOutFile)
	if err != nil {
		log.Print("Unable to stat tmp output file: ", tmpOutFile)
		return err
	}
	actualFreeSpace, err := misc.FreeSpace(targetDir)
	if err != nil {
		log.Print("Unable to verify free space of target directory: ", targetDir)
		return err
	}
	// leave 10% free at target, because we're not monsters
	freeSpace := int64(float64(actualFreeSpace) * 0.9)
	if tmpOutFileSize.Size() >= freeSpace {
		return fmt.Errorf("Not enough free space in target dir (%q) for file. Need: %d bytes, free space: %d bytes",
			targetDir, tmpOutFileSize.Size(), freeSpace)
	}

	// 4. extract the archive to make sure it's valid / can be extracted
	extractDir, err := ioutil.TempDir(tmpOutDir, "extract-test")
	if err != nil {
		return err
	}
	defer os.RemoveAll(extractDir)
	if err := extract(tmpOutFile, extractDir); err != nil {
		log.Print("Extraction of archive failed!")
		return err
	}

	// 5. copy archive to destination dir
	tmpTargetFileFd, err := ioutil.TempFile(targetDir, filepath.Base(path))
	if err != nil {
		log.Print("Unable to create temp file in target dir: ", targetDir)
		return err
	}
	tmpOutFileFd, err := os.Open(tmpOutFile)
	if err != nil {
		log.Print("Unable to open temp file: ", tmpOutFile)
		return err
	}
	defer tmpOutFileFd.Close()
	if _, err := io.Copy(tmpTargetFileFd, tmpOutFileFd); err != nil {
		return err
	}
	// fsync
	if err := tmpTargetFileFd.Sync(); err != nil {
		log.Print("Unable to call fsync on temp file: ", targetDir)
		return err
	}
	if err := tmpTargetFileFd.Close(); err != nil {
		log.Print("Unable to save temp file to target dir: ", targetDir)
		return err
	}

	// 6. checksum target, compare to temp file checksum
	targetFileChecksum, err := checksum(tmpTargetFileFd.Name())
	if err != nil {
		return err
	}

	if tmpFileChecksum != targetFileChecksum {
		return fmt.Errorf("Unable to save archive to path %q, checksum mismatch (expected: %q, got: %q)",
			path, tmpFileChecksum, targetFileChecksum)
	}

	// 7. rename archive at destination to final target name
	if err := os.Rename(tmpTargetFileFd.Name(), path); err != nil {
		log.Print("Unable to save archive to path: ", path)
		return err
	}
	if err := os.Chmod(path, mode); err != nil {
		return err
	}

	return nil
}
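
// checksum returns the hex-encoded SHA-256 digest of the file at path,
// reading it in 64 KiB chunks.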
func checksum(path string) (string, error) {
	var sum string

	buf := make([]byte, 64*1024)
	hasher := sha256.New()
	fd, err := os.Open(path)
	if err != nil {
		log.Print("Unable to checksum: ", path)
		return sum, err
	}
	defer fd.Close()

	// Read file in chunks
	for {
		n, err := fd.Read(buf)
		if n > 0 {
			if _, err := hasher.Write(buf[:n]); err != nil {
				log.Print("Unable to checksum: ", path)
				return sum, err
			}
		}

		if err == io.EOF {
			break
		}
		if err != nil {
			log.Print("Unable to checksum: ", path)
			return sum, err
		}
	}
	sum = hex.EncodeToString(hasher.Sum(nil))
	return sum, nil
}
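
// AddFile writes file into the cpio archive at dest, creating any missing
// parent directory entries first. Symlinks are stored as symlink entries and
// their targets are added to the archive recursively.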
func (archive *Archive) AddFile(file string, dest string) error {
	if err := archive.addDir(filepath.Dir(dest)); err != nil {
		return err
	}

	if archive.Files[file] {
		// Already written to cpio
		return nil
	}

	fileStat, err := os.Lstat(file)
	if err != nil {
		log.Print("AddFile: failed to stat file: ", file)
		return err
	}

	// Symlink: write symlink to archive then set 'file' to link target
	if fileStat.Mode()&os.ModeSymlink != 0 {
		// log.Printf("File %q is a symlink", file)
		target, err := os.Readlink(file)
		if err != nil {
			log.Print("AddFile: failed to get symlink target: ", file)
			return err
		}

		destFilename := strings.TrimPrefix(dest, "/")
		hdr := &cpio.Header{
			Name:     destFilename,
			Linkname: target,
			Mode:     0644 | cpio.ModeSymlink,
			Size:     int64(len(target)),
			// Checksum: 1,
		}
		if err := archive.cpioWriter.WriteHeader(hdr); err != nil {
			return err
		}
		if _, err = archive.cpioWriter.Write([]byte(target)); err != nil {
			return err
		}

		archive.Files[file] = true
		if filepath.Dir(target) == "." {
			target = filepath.Join(filepath.Dir(file), target)
		}
		// make sure target is an absolute path
		if !filepath.IsAbs(target) {
			target, err = misc.RelativeSymlinkTargetToDir(target, filepath.Dir(file))
			if err != nil {
				return err
			}
		}
		// TODO: add verbose mode, print stuff like this:
		// log.Printf("symlink: %q, target: %q", file, target)
		// write symlink target
		err = archive.AddFile(target, target)
		return err
	}

	// log.Printf("writing file: %q", file)

	fd, err := os.Open(file)
	if err != nil {
		return err
	}
	defer fd.Close()

	destFilename := strings.TrimPrefix(dest, "/")
	hdr := &cpio.Header{
		Name: destFilename,
		Mode: cpio.FileMode(fileStat.Mode().Perm()),
		Size: fileStat.Size(),
		// Checksum: 1,
	}
	if err := archive.cpioWriter.WriteHeader(hdr); err != nil {
		return err
	}

	if _, err = io.Copy(archive.cpioWriter, fd); err != nil {
		return err
	}

	archive.Files[file] = true

	return nil
}
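
// extract decompresses the gzip'd cpio archive at path and unpacks its
// contents under dest. Write uses this to prove the staged archive can
// actually be extracted before installing it.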
func extract(path string, dest string) error {
	tDir, err := ioutil.TempDir("", "archive-extract")
	if err != nil {
		return err
	}
	defer os.RemoveAll(tDir)

	srcFd, err := os.Open(path)
	if err != nil {
		return err
	}
	defer srcFd.Close()

	// TODO: support more compression types
	gz, err := pgzip.NewReader(srcFd)
	if err != nil {
		return err
	}
	defer gz.Close()

	cpioArchive := cpio.NewReader(gz)
	for {
		hdr, err := cpioArchive.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
		destPath := filepath.Join(dest, hdr.Name)
		if hdr.Mode.IsDir() {
			if err := os.MkdirAll(destPath, 0755); err != nil {
				return err
			}
		} else {
			destFd, err := os.Create(destPath)
			if err != nil {
				return err
			}
			if _, err := io.Copy(destFd, cpioArchive); err != nil {
				destFd.Close()
				return err
			}
			// Close explicitly rather than deferring inside the loop.
			if err := destFd.Close(); err != nil {
				return err
			}
		}
	}

	return nil
}
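
// writeCompressed gzip-compresses the in-memory cpio data to path, syncs it
// to disk, and sets the requested file mode.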
func (archive *Archive) writeCompressed(path string, mode os.FileMode) error {
	// TODO: support other compression formats, based on deviceinfo
	fd, err := os.Create(path)
	if err != nil {
		return err
	}
	defer fd.Close()

	gz, err := pgzip.NewWriterLevel(fd, flate.BestSpeed)
	if err != nil {
		return err
	}

	if _, err = io.Copy(gz, archive.buf); err != nil {
		return err
	}

	if err := gz.Close(); err != nil {
		return err
	}

	// call fsync just to be sure
	if err := fd.Sync(); err != nil {
		return err
	}

	if err := os.Chmod(path, mode); err != nil {
		return err
	}

	return nil
}
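
// writeCpio writes all queued directories and any not-yet-imported files into
// the in-memory cpio archive.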
func (archive *Archive) writeCpio() error {
	// Write any dirs added explicitly
	for dir := range archive.Dirs {
		if err := archive.addDir(dir); err != nil {
			return err
		}
	}

	// Write files and any missing parent dirs
	for file, imported := range archive.Files {
		if imported {
			continue
		}
		if err := archive.AddFile(file, file); err != nil {
			return err
		}
	}

	return nil
}
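
// addDir writes cpio directory entries for dir and each of its parents,
// skipping any that have already been written.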
func (archive *Archive) addDir(dir string) error {
	if archive.Dirs[dir] {
		// Already imported
		return nil
	}
	if dir == "/" {
		dir = "."
	}

	subdirs := strings.Split(strings.TrimPrefix(dir, "/"), "/")
	for i, subdir := range subdirs {
		path := filepath.Join(strings.Join(subdirs[:i], "/"), subdir)
		if archive.Dirs[path] {
			// Subdir already imported
			continue
		}
		err := archive.cpioWriter.WriteHeader(&cpio.Header{
			Name: path,
			Mode: cpio.ModeDir | 0755,
		})
		if err != nil {
			return err
		}
		archive.Dirs[path] = true
		// log.Print("wrote dir: ", path)
	}

	return nil
}