diff --git a/main.go b/main.go index 6494bd5..74f6fbe 100644 --- a/main.go +++ b/main.go @@ -674,7 +674,9 @@ func generateInitfs(name string, path string, kernVer string, devinfo deviceinfo "/dev", "/tmp", "/lib", "/boot", "/sysroot", "/etc", } for _, dir := range requiredDirs { - initfsArchive.Dirs[dir] = false + if err := initfsArchive.AddItem(dir, dir); err != nil { + return err + } } if files, err := getInitfsFiles(devinfo); err != nil { diff --git a/pkgs/archive/archive.go b/pkgs/archive/archive.go index 01fd848..bff2933 100644 --- a/pkgs/archive/archive.go +++ b/pkgs/archive/archive.go @@ -6,20 +6,23 @@ package archive import ( "bytes" "compress/flate" - "github.com/cavaliercoder/go-cpio" - "github.com/klauspost/pgzip" - "gitlab.com/postmarketOS/postmarketos-mkinitfs/pkgs/misc" "fmt" "io" "log" "os" "path/filepath" + "sort" "strings" + "sync" + "syscall" + + "github.com/cavaliercoder/go-cpio" + "github.com/klauspost/pgzip" + "gitlab.com/postmarketOS/postmarketos-mkinitfs/pkgs/misc" ) type Archive struct { - Dirs misc.StringSet - Files misc.StringSet + items archiveItems cpioWriter *cpio.Writer buf *bytes.Buffer } @@ -28,8 +31,6 @@ func New() (*Archive, error) { buf := new(bytes.Buffer) archive := &Archive{ cpioWriter: cpio.NewWriter(buf), - Files: make(misc.StringSet), - Dirs: make(misc.StringSet), buf: buf, } @@ -41,6 +42,60 @@ type archiveItem struct { header *cpio.Header } +type archiveItems struct { + items []archiveItem + sync.RWMutex +} + +// Adds the given item to the archiveItems, only if it doesn't already exist in +// the list. The items are kept sorted in ascending order. +func (a *archiveItems) Add(item archiveItem) { + a.Lock() + defer a.Unlock() + + if len(a.items) < 1 { + // empty list + a.items = append(a.items, item) + return + } + + // find existing item, or index of where new item should go + i := sort.Search(len(a.items), func(i int) bool { + return strings.Compare(item.header.Name, a.items[i].header.Name) <= 0 + }) + + if i >= len(a.items) { + // doesn't exist in list, but would be at the very end + a.items = append(a.items, item) + return + } + + if strings.Compare(a.items[i].header.Name, item.header.Name) == 0 { + // already in list + return + } + + // grow list by 1, shift right at index, and insert new string at index + a.items = append(a.items, archiveItem{}) + copy(a.items[i+1:], a.items[i:]) + a.items[i] = item +} + +// iterate through items and send each one over the returned channel +func (a *archiveItems) IterItems() <-chan archiveItem { + ch := make(chan archiveItem) + go func() { + a.RLock() + defer a.RUnlock() + + for _, item := range a.items { + ch <- item + } + close(ch) + }() + return ch +} + func (archive *Archive) Write(path string, mode os.FileMode) error { if err := archive.writeCpio(); err != nil { return err @@ -78,6 +133,11 @@ func (archive *Archive) AddItem(source string, dest string) error { sourceStat, err := os.Lstat(source) if err != nil { + e, ok := err.(*os.PathError) + if e.Err == syscall.ENOENT && ok { + // doesn't exist in current filesystem, assume it's a new directory + return archive.addDir(dest) + } return fmt.Errorf("AddItem: failed to get stat for %q: %w", source, err) } @@ -93,11 +153,6 @@ func (archive *Archive) addFile(source string, dest string) error { return err } - if archive.Files[source] { - // Already written to cpio - return nil - } - sourceStat, err := os.Lstat(source) if err != nil { log.Print("addFile: failed to stat file: ", source) @@ -114,21 +169,18 @@ func (archive *Archive) addFile(source string, dest string) error { } destFilename := strings.TrimPrefix(dest, "/") - hdr := &cpio.Header{ - Name: destFilename, - Linkname: target, - Mode: 0644 | cpio.ModeSymlink, - Size: int64(len(target)), - // Checksum: 1, - } - if err := archive.cpioWriter.WriteHeader(hdr); err != nil { - return err - } - if _, err = archive.cpioWriter.Write([]byte(target)); err != nil { - return err - } - archive.Files[source] = true + archive.items.Add(archiveItem{ + sourcePath: source, + header: &cpio.Header{ + Name: destFilename, + Linkname: target, + Mode: 0644 | cpio.ModeSymlink, + Size: int64(len(target)), + // Checksum: 1, + }, + }) + if filepath.Dir(target) == "." { target = filepath.Join(filepath.Dir(source), target) } @@ -146,30 +198,17 @@ func (archive *Archive) addFile(source string, dest string) error { return err } - // log.Printf("writing file: %q", file) - - fd, err := os.Open(source) - if err != nil { - return err - } - defer fd.Close() - destFilename := strings.TrimPrefix(dest, "/") - hdr := &cpio.Header{ - Name: destFilename, - Mode: cpio.FileMode(sourceStat.Mode().Perm()), - Size: sourceStat.Size(), - // Checksum: 1, - } - if err := archive.cpioWriter.WriteHeader(hdr); err != nil { - return err - } - if _, err = io.Copy(archive.cpioWriter, fd); err != nil { - return err - } - - archive.Files[source] = true + archive.items.Add(archiveItem{ + sourcePath: source, + header: &cpio.Header{ + Name: destFilename, + Mode: cpio.FileMode(sourceStat.Mode().Perm()), + Size: sourceStat.Size(), + // Checksum: 1, + }, + }) return nil } @@ -207,29 +246,48 @@ func (archive *Archive) writeCompressed(path string, mode os.FileMode) error { } func (archive *Archive) writeCpio() error { - // Write any dirs added explicitly - for dir := range archive.Dirs { - archive.addDir(dir) + // having a transient function for actually adding files to the archive + // allows the deferred fd.close to run after every copy and prevent having + // tons of open file handles until the copying is all done + copyToArchive := func(source string, header *cpio.Header) error { + + if err := archive.cpioWriter.WriteHeader(header); err != nil { + return fmt.Errorf("archive.writeCpio: unable to write header: %w", err) + } + + // don't copy actual dirs into the archive, writing the header is enough + if !header.Mode.IsDir() { + if header.Mode.IsRegular() { + fd, err := os.Open(source) + if err != nil { + return fmt.Errorf("archive.writeCpio: uname to open file %q, %w", source, err) + } + defer fd.Close() + if _, err := io.Copy(archive.cpioWriter, fd); err != nil { + return fmt.Errorf("archive.writeCpio: unable to write out archive: %w", err) + } + } else if header.Linkname != "" { + // the contents of a symlink is just need the link name + if _, err := archive.cpioWriter.Write([]byte(header.Linkname)); err != nil { + return fmt.Errorf("archive.writeCpio: unable to write out symlink: %w", err) + } + } else { + return fmt.Errorf("archive.writeCpio: unknown type for file: %s", source) + } + } + + return nil } - // Write files and any missing parent dirs - for file, imported := range archive.Files { - if imported { - continue - } - if err := archive.addFile(file, file); err != nil { + for i := range archive.items.IterItems() { + if err := copyToArchive(i.sourcePath, i.header); err != nil { return err } } - return nil } func (archive *Archive) addDir(dir string) error { - if archive.Dirs[dir] { - // Already imported - return nil - } if dir == "/" { dir = "." } @@ -237,19 +295,13 @@ func (archive *Archive) addDir(dir string) error { subdirs := strings.Split(strings.TrimPrefix(dir, "/"), "/") for i, subdir := range subdirs { path := filepath.Join(strings.Join(subdirs[:i], "/"), subdir) - if archive.Dirs[path] { - // Subdir already imported - continue - } - err := archive.cpioWriter.WriteHeader(&cpio.Header{ - Name: path, - Mode: cpio.ModeDir | 0755, + archive.items.Add(archiveItem{ + sourcePath: path, + header: &cpio.Header{ + Name: path, + Mode: cpio.ModeDir | 0755, + }, }) - if err != nil { - return err - } - archive.Dirs[path] = true - // log.Print("wrote dir: ", path) } return nil diff --git a/pkgs/archive/archive_test.go b/pkgs/archive/archive_test.go new file mode 100644 index 0000000..f44cd24 --- /dev/null +++ b/pkgs/archive/archive_test.go @@ -0,0 +1,189 @@ +// Copyright 2022 Clayton Craft +// SPDX-License-Identifier: GPL-3.0-or-later + +package archive + +import ( + "reflect" + "testing" + + "github.com/cavaliercoder/go-cpio" +) + +func TestArchiveItemsAdd(t *testing.T) { + subtests := []struct { + name string + inItems []archiveItem + inItem archiveItem + expected []archiveItem + }{ + { + name: "empty list", + inItems: []archiveItem{}, + inItem: archiveItem{ + sourcePath: "/foo/bar", + header: &cpio.Header{Name: "/foo/bar"}, + }, + expected: []archiveItem{ + { + sourcePath: "/foo/bar", + header: &cpio.Header{Name: "/foo/bar"}, + }, + }, + }, + { + name: "already exists", + inItems: []archiveItem{ + { + sourcePath: "/bazz/bar", + header: &cpio.Header{Name: "/bazz/bar"}, + }, + { + sourcePath: "/foo", + header: &cpio.Header{Name: "/foo"}, + }, + { + sourcePath: "/foo/bar", + header: &cpio.Header{Name: "/foo/bar"}, + }, + }, + inItem: archiveItem{ + sourcePath: "/foo", + header: &cpio.Header{Name: "/foo"}, + }, + expected: []archiveItem{ + { + sourcePath: "/bazz/bar", + header: &cpio.Header{Name: "/bazz/bar"}, + }, + { + sourcePath: "/foo", + header: &cpio.Header{Name: "/foo"}, + }, + { + sourcePath: "/foo/bar", + header: &cpio.Header{Name: "/foo/bar"}, + }, + }, + }, + { + name: "add new", + inItems: []archiveItem{ + { + sourcePath: "/bazz/bar", + header: &cpio.Header{Name: "/bazz/bar"}, + }, + { + sourcePath: "/foo", + header: &cpio.Header{Name: "/foo"}, + }, + { + sourcePath: "/foo/bar", + header: &cpio.Header{Name: "/foo/bar"}, + }, + { + sourcePath: "/foo/bar1", + header: &cpio.Header{Name: "/foo/bar1"}, + }, + }, + inItem: archiveItem{ + sourcePath: "/foo/bar0", + header: &cpio.Header{Name: "/foo/bar0"}, + }, + expected: []archiveItem{ + { + sourcePath: "/bazz/bar", + header: &cpio.Header{Name: "/bazz/bar"}, + }, + { + sourcePath: "/foo", + header: &cpio.Header{Name: "/foo"}, + }, + { + sourcePath: "/foo/bar", + header: &cpio.Header{Name: "/foo/bar"}, + }, + { + sourcePath: "/foo/bar0", + header: &cpio.Header{Name: "/foo/bar0"}, + }, + { + sourcePath: "/foo/bar1", + header: &cpio.Header{Name: "/foo/bar1"}, + }, + }, + }, + { + name: "add new at beginning", + inItems: []archiveItem{ + { + sourcePath: "/foo", + header: &cpio.Header{Name: "/foo"}, + }, + { + sourcePath: "/foo/bar", + header: &cpio.Header{Name: "/foo/bar"}, + }, + }, + inItem: archiveItem{ + sourcePath: "/bazz/bar", + header: &cpio.Header{Name: "/bazz/bar"}, + }, + expected: []archiveItem{ + { + sourcePath: "/bazz/bar", + header: &cpio.Header{Name: "/bazz/bar"}, + }, + { + sourcePath: "/foo", + header: &cpio.Header{Name: "/foo"}, + }, + { + sourcePath: "/foo/bar", + header: &cpio.Header{Name: "/foo/bar"}, + }, + }, + }, + { + name: "add new at end", + inItems: []archiveItem{ + { + sourcePath: "/bazz/bar", + header: &cpio.Header{Name: "/bazz/bar"}, + }, + { + sourcePath: "/foo", + header: &cpio.Header{Name: "/foo"}, + }, + }, + inItem: archiveItem{ + sourcePath: "/zzz/bazz", + header: &cpio.Header{Name: "/zzz/bazz"}, + }, + expected: []archiveItem{ + { + sourcePath: "/bazz/bar", + header: &cpio.Header{Name: "/bazz/bar"}, + }, + { + sourcePath: "/foo", + header: &cpio.Header{Name: "/foo"}, + }, + { + sourcePath: "/zzz/bazz", + header: &cpio.Header{Name: "/zzz/bazz"}, + }, + }, + }, + } + + for _, st := range subtests { + t.Run(st.name, func(t *testing.T) { + a := archiveItems{items: st.inItems} + a.Add(st.inItem) + if !reflect.DeepEqual(st.expected, a.items) { + t.Fatal("expected:", st.expected, " got: ", a.items) + } + }) + } +}