From f6e477350737a95c9a05c9899941542a946347bf Mon Sep 17 00:00:00 2001 From: Clayton Craft Date: Sun, 3 Aug 2025 22:40:42 -0700 Subject: [PATCH] misc/getfiles: add tests for getFile This adds some tests for getFile, one of which would have caught the recent recursion issue and other will hopefully catch future regressions. Part-of: https://gitlab.postmarketos.org/postmarketOS/postmarketos-mkinitfs/-/merge_requests/65 [ci:skip-build]: already built successfully in CI --- internal/misc/getfiles_test.go | 149 +++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 internal/misc/getfiles_test.go diff --git a/internal/misc/getfiles_test.go b/internal/misc/getfiles_test.go new file mode 100644 index 0000000..860c430 --- /dev/null +++ b/internal/misc/getfiles_test.go @@ -0,0 +1,149 @@ +// Copyright 2025 Clayton Craft +// SPDX-License-Identifier: GPL-3.0-or-later + +package misc + +import ( + "os" + "path/filepath" + "reflect" + "sort" + "testing" + "time" +) + +func TestGetFile(t *testing.T) { + subtests := []struct { + name string + setup func(tmpDir string) (inputPath string, expectedFiles []string, err error) + required bool + }{ + { + name: "symlink to directory - no infinite recursion", + setup: func(tmpDir string) (string, []string, error) { + // Create target directory with files + targetDir := filepath.Join(tmpDir, "target") + if err := os.MkdirAll(targetDir, 0755); err != nil { + return "", nil, err + } + + testFile1 := filepath.Join(targetDir, "file1.txt") + testFile2 := filepath.Join(targetDir, "file2.txt") + if err := os.WriteFile(testFile1, []byte("content1"), 0644); err != nil { + return "", nil, err + } + if err := os.WriteFile(testFile2, []byte("content2"), 0644); err != nil { + return "", nil, err + } + + // Create symlink pointing to target directory + symlinkPath := filepath.Join(tmpDir, "symlink") + if err := os.Symlink(targetDir, symlinkPath); err != nil { + return "", nil, err + } + + expected := []string{symlinkPath, testFile1, testFile2} + return symlinkPath, expected, nil + }, + required: true, + }, + { + name: "symlink to file - returns both symlink and target", + setup: func(tmpDir string) (string, []string, error) { + // Create target file + targetFile := filepath.Join(tmpDir, "target.txt") + if err := os.WriteFile(targetFile, []byte("content"), 0644); err != nil { + return "", nil, err + } + + // Create symlink pointing to target file + symlinkPath := filepath.Join(tmpDir, "symlink.txt") + if err := os.Symlink(targetFile, symlinkPath); err != nil { + return "", nil, err + } + + expected := []string{symlinkPath, targetFile} + return symlinkPath, expected, nil + }, + required: true, + }, + { + name: "regular file", + setup: func(tmpDir string) (string, []string, error) { + regularFile := filepath.Join(tmpDir, "regular.txt") + if err := os.WriteFile(regularFile, []byte("content"), 0644); err != nil { + return "", nil, err + } + + expected := []string{regularFile} + return regularFile, expected, nil + }, + required: true, + }, + { + name: "regular directory", + setup: func(tmpDir string) (string, []string, error) { + // Create directory with files + dirPath := filepath.Join(tmpDir, "testdir") + if err := os.MkdirAll(dirPath, 0755); err != nil { + return "", nil, err + } + + file1 := filepath.Join(dirPath, "file1.txt") + file2 := filepath.Join(dirPath, "subdir", "file2.txt") + + if err := os.WriteFile(file1, []byte("content1"), 0644); err != nil { + return "", nil, err + } + if err := os.MkdirAll(filepath.Dir(file2), 0755); err != nil { + return "", nil, err + } + if err := os.WriteFile(file2, []byte("content2"), 0644); err != nil { + return "", nil, err + } + + expected := []string{file1, file2} + return dirPath, expected, nil + }, + required: true, + }, + } + + for _, st := range subtests { + t.Run(st.name, func(t *testing.T) { + tmpDir := t.TempDir() + + inputPath, expectedFiles, err := st.setup(tmpDir) + if err != nil { + t.Fatalf("setup failed: %v", err) + } + + // Add timeout protection for infinite recursion test + done := make(chan struct{}) + var files []string + var getFileErr error + + go func() { + defer close(done) + files, getFileErr = getFile(inputPath, st.required) + }() + + select { + case <-done: + if getFileErr != nil { + t.Fatalf("getFile failed: %v", getFileErr) + } + case <-time.After(5 * time.Second): + t.Fatal("getFile appears to be in infinite recursion (timeout)") + } + + // Sort for comparison + sort.Strings(expectedFiles) + sort.Strings(files) + + if !reflect.DeepEqual(expectedFiles, files) { + t.Fatalf("expected: %q, got: %q", expectedFiles, files) + } + }) + } +}