misc/getfiles: add tests for getFile

This adds some tests for getFile, one of which would have caught
the recent recursion issue and other will hopefully catch future
regressions.

Part-of: https://gitlab.postmarketos.org/postmarketOS/postmarketos-mkinitfs/-/merge_requests/65

[ci:skip-build]: already built successfully in CI
This commit is contained in:
Clayton Craft
2025-08-03 22:40:42 -07:00
parent 7a07a16ecb
commit f6e4773507

View File

@@ -0,0 +1,149 @@
// Copyright 2025 Clayton Craft <clayton@craftyguy.net>
// SPDX-License-Identifier: GPL-3.0-or-later
package misc
import (
"os"
"path/filepath"
"reflect"
"sort"
"testing"
"time"
)
func TestGetFile(t *testing.T) {
subtests := []struct {
name string
setup func(tmpDir string) (inputPath string, expectedFiles []string, err error)
required bool
}{
{
name: "symlink to directory - no infinite recursion",
setup: func(tmpDir string) (string, []string, error) {
// Create target directory with files
targetDir := filepath.Join(tmpDir, "target")
if err := os.MkdirAll(targetDir, 0755); err != nil {
return "", nil, err
}
testFile1 := filepath.Join(targetDir, "file1.txt")
testFile2 := filepath.Join(targetDir, "file2.txt")
if err := os.WriteFile(testFile1, []byte("content1"), 0644); err != nil {
return "", nil, err
}
if err := os.WriteFile(testFile2, []byte("content2"), 0644); err != nil {
return "", nil, err
}
// Create symlink pointing to target directory
symlinkPath := filepath.Join(tmpDir, "symlink")
if err := os.Symlink(targetDir, symlinkPath); err != nil {
return "", nil, err
}
expected := []string{symlinkPath, testFile1, testFile2}
return symlinkPath, expected, nil
},
required: true,
},
{
name: "symlink to file - returns both symlink and target",
setup: func(tmpDir string) (string, []string, error) {
// Create target file
targetFile := filepath.Join(tmpDir, "target.txt")
if err := os.WriteFile(targetFile, []byte("content"), 0644); err != nil {
return "", nil, err
}
// Create symlink pointing to target file
symlinkPath := filepath.Join(tmpDir, "symlink.txt")
if err := os.Symlink(targetFile, symlinkPath); err != nil {
return "", nil, err
}
expected := []string{symlinkPath, targetFile}
return symlinkPath, expected, nil
},
required: true,
},
{
name: "regular file",
setup: func(tmpDir string) (string, []string, error) {
regularFile := filepath.Join(tmpDir, "regular.txt")
if err := os.WriteFile(regularFile, []byte("content"), 0644); err != nil {
return "", nil, err
}
expected := []string{regularFile}
return regularFile, expected, nil
},
required: true,
},
{
name: "regular directory",
setup: func(tmpDir string) (string, []string, error) {
// Create directory with files
dirPath := filepath.Join(tmpDir, "testdir")
if err := os.MkdirAll(dirPath, 0755); err != nil {
return "", nil, err
}
file1 := filepath.Join(dirPath, "file1.txt")
file2 := filepath.Join(dirPath, "subdir", "file2.txt")
if err := os.WriteFile(file1, []byte("content1"), 0644); err != nil {
return "", nil, err
}
if err := os.MkdirAll(filepath.Dir(file2), 0755); err != nil {
return "", nil, err
}
if err := os.WriteFile(file2, []byte("content2"), 0644); err != nil {
return "", nil, err
}
expected := []string{file1, file2}
return dirPath, expected, nil
},
required: true,
},
}
for _, st := range subtests {
t.Run(st.name, func(t *testing.T) {
tmpDir := t.TempDir()
inputPath, expectedFiles, err := st.setup(tmpDir)
if err != nil {
t.Fatalf("setup failed: %v", err)
}
// Add timeout protection for infinite recursion test
done := make(chan struct{})
var files []string
var getFileErr error
go func() {
defer close(done)
files, getFileErr = getFile(inputPath, st.required)
}()
select {
case <-done:
if getFileErr != nil {
t.Fatalf("getFile failed: %v", getFileErr)
}
case <-time.After(5 * time.Second):
t.Fatal("getFile appears to be in infinite recursion (timeout)")
}
// Sort for comparison
sort.Strings(expectedFiles)
sort.Strings(files)
if !reflect.DeepEqual(expectedFiles, files) {
t.Fatalf("expected: %q, got: %q", expectedFiles, files)
}
})
}
}