diff --git a/main.go b/main.go index 484ba50..e36101b 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,7 @@ import ( "os" "os/exec" "path/filepath" + "regexp" "strings" "time" @@ -621,22 +622,21 @@ func getModulesInDir(files misc.StringSet, modPath string) error { // anywhere func getModule(files misc.StringSet, modName string, modDir string) error { - deps, err := getModuleDeps(modName, modDir) - if err != nil { - return err + modDep := filepath.Join(modDir, "modules.dep") + if !exists(modDep) { + log.Fatal("Kernel module.dep not found: ", modDir) } - if len(deps) == 0 { - // retry and swap - and _ in module name - if strings.Contains(modName, "-") { - modName = strings.ReplaceAll(modName, "-", "_") - } else { - modName = strings.ReplaceAll(modName, "_", "-") - } - deps, err = getModuleDeps(modName, modDir) - if err != nil { - return err - } + fd, err := os.Open(modDep) + if err != nil { + log.Print("Unable to open modules.dep: ", modDep) + return err + } + defer fd.Close() + + deps, err := getModuleDeps(modName, fd) + if err != nil { + return err } for _, dep := range deps { @@ -651,29 +651,35 @@ func getModule(files misc.StringSet, modName string, modDir string) error { return err } -func getModuleDeps(modName string, modDir string) ([]string, error) { +// Get the canonicalized name for the module as represented in the given modules.dep io.reader +func getModuleDeps(modName string, modulesDep io.Reader) ([]string, error) { var deps []string - modDep := filepath.Join(modDir, "modules.dep") - if !exists(modDep) { - log.Fatal("Kernel module.dep not found: ", modDir) + // split the module name on - and/or _, build a regex for matching + splitRe := regexp.MustCompile("[-_]+") + var modNameReStr string + for _, s := range splitRe.Split(modName, -1) { + if modNameReStr != "" { + modNameReStr += "[-_]+" + s + } else { + modNameReStr = s + } } + re := regexp.MustCompile(modNameReStr) - fd, err := os.Open(modDep) - if err != nil { - log.Print("Unable to open modules.dep: ", modDep) - return deps, err - } - - defer fd.Close() - s := bufio.NewScanner(fd) + s := bufio.NewScanner(modulesDep) for s.Scan() { fields := strings.Fields(s.Text()) - fields[0] = strings.TrimSuffix(fields[0], ":") - if modName != filepath.Base(stripExts(fields[0])) { + if len(fields) == 0 { continue } - deps = append(deps, fields...) + fields[0] = strings.TrimSuffix(fields[0], ":") + + found := re.FindAll([]byte(filepath.Base(stripExts(fields[0]))), -1) + if len(found) > 0 { + deps = append(deps, fields...) + break + } } if err := s.Err(); err != nil { log.Print("Unable to get module + dependencies: ", modName)