diff --git a/csharp/extractor/Semmle.Extraction.CSharp.DependencyFetching/DependencyManager.cs b/csharp/extractor/Semmle.Extraction.CSharp.DependencyFetching/DependencyManager.cs index be7eb16253c..d6a19b9cd1a 100644 --- a/csharp/extractor/Semmle.Extraction.CSharp.DependencyFetching/DependencyManager.cs +++ b/csharp/extractor/Semmle.Extraction.CSharp.DependencyFetching/DependencyManager.cs @@ -59,14 +59,14 @@ namespace Semmle.Extraction.CSharp.DependencyFetching this.progressMonitor.FindingFiles(srcDir); packageDirectory = new TemporaryDirectory(ComputeTempDirectory(sourceDir.FullName)); - var allFiles = GetFiles("*.*").ToList(); - var smallFiles = GetSmallFiles(allFiles); - this.fileContent = new FileContent(progressMonitor, GetFileNames(smallFiles)); - this.allSources = GetFileNames(allFiles, ".cs").ToList(); - var allProjects = GetFileNames(allFiles, ".csproj"); + var allFiles = GetAllFiles().ToList(); + var smallFiles = allFiles.SelectSmallFiles(progressMonitor).SelectFileNames(); + this.fileContent = new FileContent(progressMonitor, smallFiles); + this.allSources = allFiles.SelectFileNamesByExtension(".cs").ToList(); + var allProjects = allFiles.SelectFileNamesByExtension(".csproj"); var solutions = options.SolutionFile is not null ? new[] { options.SolutionFile } - : GetFileNames(allFiles, ".sln"); + : allFiles.SelectFileNamesByExtension(".sln"); var dllDirNames = options.DllDirs.Select(Path.GetFullPath).ToList(); @@ -156,7 +156,7 @@ namespace Semmle.Extraction.CSharp.DependencyFetching { progressMonitor.LogInfo($"Generating source files from cshtml and razor files."); - var views = GetFileNames(allFiles, ".cshtml", ".razor").ToArray(); + var views = allFiles.SelectFileNamesByExtension(".cshtml", ".razor").ToArray(); if (views.Length > 0) { @@ -184,31 +184,10 @@ namespace Semmle.Extraction.CSharp.DependencyFetching public DependencyManager(string srcDir) : this(srcDir, DependencyOptions.Default, new ConsoleLogger(Verbosity.Info)) { } - private IEnumerable GetFiles(string pattern, bool recurseSubdirectories = true) => - sourceDir.GetFiles(pattern, new EnumerationOptions - { - RecurseSubdirectories = recurseSubdirectories, - MatchCasing = MatchCasing.CaseInsensitive - }) + private IEnumerable GetAllFiles() => + sourceDir.GetFiles("*.*", new EnumerationOptions { RecurseSubdirectories = true }) .Where(d => d.Extension != ".dll" && !options.ExcludesFile(d.FullName)); - private static IEnumerable GetFileNames(IEnumerable files, params string[] extensions) => - files.Where(fi => !extensions.Any() || extensions.Contains(fi.Extension)).Select(fi => fi.FullName); - - private IEnumerable GetSmallFiles(IEnumerable files) - { - const int oneMb = 1_048_576; - return files.Where(file => - { - if (file.Length > oneMb) - { - progressMonitor.LogDebug($"Skipping {file.FullName} because it is bigger than 1MB."); - return false; - } - return true; - }); - } - /// /// Computes a unique temp directory for the packages associated /// with this source tree. Use a SHA1 of the directory name. @@ -390,16 +369,14 @@ namespace Semmle.Extraction.CSharp.DependencyFetching private void DownloadMissingPackages(List allFiles) { - var nugetConfigs = allFiles - .Where(fi => fi.Name == "nuget.config") - .Select(fi => fi.FullName) - .ToArray(); + var nugetConfigs = allFiles.SelectFileNamesByName("nuget.config").ToArray(); string? nugetConfig = null; if (nugetConfigs.Length > 1) { progressMonitor.MultipleNugetConfig(nugetConfigs); - nugetConfig = GetFiles("nuget.config", recurseSubdirectories: false) - .Select(fi => fi.FullName) + nugetConfig = allFiles + .SelectRootFiles(sourceDir) + .SelectFileNamesByName("nuget.config") .FirstOrDefault(); if (nugetConfig == null) { @@ -412,8 +389,7 @@ namespace Semmle.Extraction.CSharp.DependencyFetching } var alreadyDownloadedPackages = Directory.GetDirectories(packageDirectory.DirInfo.FullName) - .Select(d => Path.GetFileName(d) - .ToLowerInvariant()); + .Select(d => Path.GetFileName(d).ToLowerInvariant()); var notYetDownloadedPackages = fileContent.AllPackages.Except(alreadyDownloadedPackages); foreach (var package in notYetDownloadedPackages) { diff --git a/csharp/extractor/Semmle.Extraction.CSharp.DependencyFetching/FileInfoExtensions.cs b/csharp/extractor/Semmle.Extraction.CSharp.DependencyFetching/FileInfoExtensions.cs new file mode 100644 index 00000000000..1d285d03d04 --- /dev/null +++ b/csharp/extractor/Semmle.Extraction.CSharp.DependencyFetching/FileInfoExtensions.cs @@ -0,0 +1,39 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; + +namespace Semmle.Extraction.CSharp.DependencyFetching +{ + public static class FileInfoExtensions + { + private static IEnumerable SelectFilesAux(this IEnumerable files, Predicate p) => + files.Where(f => p(f)).Select(fi => fi.FullName); + + public static IEnumerable SelectRootFiles(this IEnumerable files, DirectoryInfo dir) => + files.Where(file => file.DirectoryName == dir.FullName); + + internal static IEnumerable SelectSmallFiles(this IEnumerable files, ProgressMonitor progressMonitor) + { + const int oneMb = 1_048_576; + return files.Where(file => + { + if (file.Length > oneMb) + { + progressMonitor.LogDebug($"Skipping {file.FullName} because it is bigger than 1MB."); + return false; + } + return true; + }); + } + + public static IEnumerable SelectFileNamesByExtension(this IEnumerable files, params string[] extensions) => + files.SelectFilesAux(fi => extensions.Contains(fi.Extension)); + + public static IEnumerable SelectFileNamesByName(this IEnumerable files, params string[] names) => + files.SelectFilesAux(fi => names.Any(name => string.Compare(name, fi.Name, true) == 0)); + + public static IEnumerable SelectFileNames(this IEnumerable files) => + files.SelectFilesAux(_ => true); + } +}