diff --git a/cmd/guacone/cmd/datadog_malware.go b/cmd/guacone/cmd/datadog_malware.go
new file mode 100644
index 0000000000..4097937d32
--- /dev/null
+++ b/cmd/guacone/cmd/datadog_malware.go
@@ -0,0 +1,189 @@
+//
+// Copyright 2024 The GUAC Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"os"
+	"time"
+
+	"github.com/Khan/genqlient/graphql"
+	"github.com/guacsec/guac/pkg/assembler/clients/generated"
+	"github.com/guacsec/guac/pkg/certifier"
+	"github.com/guacsec/guac/pkg/certifier/certify"
+	"github.com/guacsec/guac/pkg/certifier/components/root_package"
+	"github.com/guacsec/guac/pkg/certifier/datadog_malware"
+	"github.com/guacsec/guac/pkg/cli"
+	"github.com/guacsec/guac/pkg/handler/processor"
+	"github.com/guacsec/guac/pkg/ingestor"
+	"github.com/guacsec/guac/pkg/logging"
+	"github.com/spf13/cobra"
+	"github.com/spf13/viper"
+)
+
+// datadogMalwareOptions holds the validated flag values for the
+// datadog-malware certifier command.
+type datadogMalwareOptions struct {
+	graphqlEndpoint string
+	headerFile      string
+	poll            bool
+	interval        time.Duration
+	addedLatency    *time.Duration
+	batchSize       int
+	lastScan        *int
+}
+
+// datadogMalwareCmd creates and registers the Datadog malware certifier, then
+// runs it over the packages returned by the GraphQL package query.
+var datadogMalwareCmd = &cobra.Command{
+	Use:   "datadog-malware [flags]",
+	Short: "Runs the Datadog malicious package certifier",
+	Run: func(cmd *cobra.Command, args []string) {
+		opts, err := validateDatadogMalwareFlags(
+			viper.GetString("gql-addr"),
+			viper.GetString("header-file"),
+			viper.GetString("interval"),
+			viper.GetString("certifier-latency"),
+			viper.GetInt("certifier-batch-size"),
+			viper.GetInt("last-scan"),
+			viper.GetBool("poll"),
+		)
+		if err != nil {
+			fmt.Printf("unable to validate flags: %v\n", err)
+			_ = cmd.Help()
+			os.Exit(1)
+		}
+
+		ctx := logging.WithLogger(context.Background())
+		logger := logging.FromContext(ctx)
+		transport := cli.HTTPHeaderTransport(ctx, opts.headerFile, http.DefaultTransport)
+
+		assemblerFunc := ingestor.GetAssembler(ctx, logger, opts.graphqlEndpoint, transport)
+		ddCertifier, err := datadog_malware.NewDatadogMalwareCertifier(ctx, assemblerFunc)
+		if err != nil {
+			logger.Fatalf("unable to create datadog certifier: %v", err)
+		}
+
+		if err := certify.RegisterCertifier(func() certifier.Certifier { return ddCertifier }, certifier.CertifierDatadogMalware); err != nil {
+			logger.Fatalf("unable to register datadog certifier: %v", err)
+		}
+
+		httpClient := http.Client{Transport: transport}
+		gqlclient := graphql.NewClient(opts.graphqlEndpoint, &httpClient)
+
+		// NOTE(review): the package query uses QueryTypeVulnerability; confirm
+		// that is the intended query type for malware certification.
+		packageQuery := root_package.NewPackageQuery(gqlclient, generated.QueryTypeVulnerability, opts.batchSize, 1000, opts.addedLatency, opts.lastScan)
+
+		// totalNum counts the documents handed to emit.
+		totalNum := 0
+		emit := func(d *processor.Document) error {
+			// data dog certifier does not need to emit anything but just for the sake of completeness
+			if _, err := ingestor.Ingest(ctx, d, opts.graphqlEndpoint, transport, nil, false, false, false, false); err != nil {
+				return fmt.Errorf("unable to ingest document: %w", err)
+			}
+			totalNum++
+			return nil
+		}
+
+		// errHandler logs how the certifier ended and returns true either way.
+		errHandler := func(err error) bool {
+			if err != nil {
+				logger.Errorf("certifier ended with error: %v", err)
+				return true
+			}
+			logger.Info("certifier ended gracefully")
+			return true
+		}
+
+		if err := certify.Certify(ctx, packageQuery, emit, errHandler, opts.poll, opts.interval); err != nil {
+			logger.Fatal(err)
+		}
+
+		// NOTE(review): totalNum only grows when emit is called; the Datadog
+		// certifier ingests through assemblerFunc, so this presumably logs 0.
+		logger.Infof("completed certifying %d packages", totalNum)
+	},
+}
+
+// validateDatadogMalwareFlags converts the raw flag values into a
+// datadogMalwareOptions.
+//
+// interval must parse as a time.Duration. certifierLatencyStr is optional: an
+// empty string leaves addedLatency nil. A lastScan of 0 leaves lastScan nil.
+// An error is returned when either duration string fails to parse.
+func validateDatadogMalwareFlags(
+	graphqlEndpoint,
+	headerFile,
+	interval,
+	certifierLatencyStr string,
+	batchSize int, lastScan int,
+	poll bool,
+) (datadogMalwareOptions, error) {
+	opts := datadogMalwareOptions{
+		graphqlEndpoint: graphqlEndpoint,
+		headerFile:      headerFile,
+		batchSize:       batchSize,
+		poll:            poll,
+	}
+
+	i, err := time.ParseDuration(interval)
+	if err != nil {
+		return opts, fmt.Errorf("failed to parse interval with error: %w", err)
+	}
+	opts.interval = i
+
+	// addedLatency stays nil when no latency was requested.
+	if certifierLatencyStr != "" {
+		addedLatency, err := time.ParseDuration(certifierLatencyStr)
+		if err != nil {
+			return opts, fmt.Errorf("failed to parse duration with error: %w", err)
+		}
+		opts.addedLatency = &addedLatency
+	}
+
+	// A lastScan of 0 is treated as "not set".
+	if lastScan != 0 {
+		opts.lastScan = &lastScan
+	}
+	return opts, nil
+}
+
+// init registers the command's flags with viper and attaches the command to
+// certifierCmd. It exits the process if the flags cannot be built or bound.
+func init() {
+	set, err := cli.BuildFlags([]string{
+		"interval",
+		"header-file",
+		"certifier-latency",
+		"certifier-batch-size",
+		"last-scan",
+		"poll",
+	})
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "failed to setup flag: %v\n", err)
+		os.Exit(1)
+	}
+	datadogMalwareCmd.PersistentFlags().AddFlagSet(set)
+	if err := viper.BindPFlags(datadogMalwareCmd.PersistentFlags()); err != nil {
+		fmt.Fprintf(os.Stderr, "failed to bind flags: %v\n", err)
+		os.Exit(1)
+	}
+
+	certifierCmd.AddCommand(datadogMalwareCmd)
+}
diff --git a/pkg/certifier/certifier.go b/pkg/certifier/certifier.go
index 35e47ea99e..d187094054 100644
--- a/pkg/certifier/certifier.go
+++ b/pkg/certifier/certifier.go
@@ -51,4 +51,5 @@ const (
 	CertifierClearlyDefined CertifierType = "CD"
 	CertifierScorecard      CertifierType = "scorecard"
 	CertifierEOL            CertifierType = "EOL"
+	CertifierDatadogMalware CertifierType = "DATADOG_MALWARE"
 )
diff --git a/pkg/certifier/datadog_malware/datadog_malware.go b/pkg/certifier/datadog_malware/datadog_malware.go
new file mode 100644
index 0000000000..cad9fc70e8
--- /dev/null
+++ b/pkg/certifier/datadog_malware/datadog_malware.go
@@ -0,0 +1,303 @@
+//
+// Copyright 2024 The GUAC Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package datadog_malware
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/guacsec/guac/pkg/assembler"
+	"github.com/guacsec/guac/pkg/assembler/clients/generated"
+	ingestor "github.com/guacsec/guac/pkg/assembler/clients/helpers"
+	"github.com/guacsec/guac/pkg/assembler/helpers"
+	"github.com/guacsec/guac/pkg/certifier"
+	"github.com/guacsec/guac/pkg/certifier/components/root_package"
+	"github.com/guacsec/guac/pkg/clients"
+	"github.com/guacsec/guac/pkg/handler/processor"
+	"github.com/guacsec/guac/pkg/logging"
+	"github.com/guacsec/guac/pkg/version"
+	jsoniter "github.com/json-iterator/go"
+	"golang.org/x/time/rate"
+)
+
+var (
+	json              = jsoniter.ConfigCompatibleWithStandardLibrary
+	rateLimit         = 10000
+	rateLimitInterval = 30 * time.Second
+)
+
+const (
+	NPM_MANIFEST_URL        string = "https://raw.githubusercontent.com/DataDog/malicious-software-packages-dataset/main/samples/npm/manifest.json"
+	PYPI_MANIFEST_URL       string = "https://raw.githubusercontent.com/DataDog/malicious-software-packages-dataset/main/samples/pypi/manifest.json"
+	DatadogMalwareCertifier string = "datadog_malware_certifier"
+
+	// datadogMalwareOrigin is the Origin recorded on every CertifyBad predicate.
+	datadogMalwareOrigin = "Datadog Malicious Software Packages Dataset"
+
+	// manifestRefreshInterval is the interval at which we periodically re-fetch
+	// the Datadog malicious packages manifests.
+	manifestRefreshInterval = 4 * time.Hour
+)
+
+// ErrDatadogMalwareComponentTypeMismatch is returned by CertifyComponent when
+// rootComponent is not a []*root_package.PackageNode.
+var ErrDatadogMalwareComponentTypeMismatch error = errors.New("rootComponent type is not []*root_package.PackageNode")
+
+// MaliciousPackages maps a package name to the versions listed for it in a manifest.
+type MaliciousPackages map[string][]string
+
+// assemblerFuncType is the function used to ingest the generated predicates.
+type assemblerFuncType func([]assembler.IngestPredicates) (*ingestor.AssemblerIngestedIDs, error)
+
+// datadogMalwareCertifier certifies packages against the cached manifests.
+type datadogMalwareCertifier struct {
+	httpClient    *http.Client
+	assemblerFunc assemblerFuncType
+
+	// dataMu guards npmData and pypiData.
+	dataMu   sync.RWMutex
+	npmData  MaliciousPackages
+	pypiData MaliciousPackages
+}
+
+// CertifierOption defines functional options for the certifier
+type CertifierOption func(*datadogMalwareCertifier)
+
+// WithHTTPClient allows overriding the default HTTP client
+func WithHTTPClient(client *http.Client) CertifierOption {
+	return func(d *datadogMalwareCertifier) {
+		d.httpClient = client
+	}
+}
+
+// NewDatadogMalwareCertifier initializes the Datadog Malicious Software Packages certifier.
+//
+// The Datadog malicious software packages dataset is a public repository that catalogs
+// known malicious artifacts for platforms such as npm and PyPI. By using this data,
+// we can certify suspect packages in our GUAC graph as known bad (i.e. CertifyBad).
+// https://github.com/DataDog/malicious-software-packages-dataset
+//
+//   - added predicates: each discovered malicious package is ingested with a `CertifyBad`
+//     predicate, describing the justification, origin, and collector used.
+//
+//   - recommended intervals: this certifier is typically run on a recurring basis or
+//     with a background refresh at least every few hours. the default here is every 4 hours.
+//
+// The manifests are fetched once before returning and a failure is returned as an
+// error. If the provided context is canceled, the refresh loop will exit.
+func NewDatadogMalwareCertifier(ctx context.Context, assemblerFunc assemblerFuncType, opts ...CertifierOption) (certifier.Certifier, error) {
+	limiter := rate.NewLimiter(rate.Every(rateLimitInterval), rateLimit)
+	transport := clients.NewRateLimitedTransport(version.UATransport, limiter)
+	defaultClient := &http.Client{Transport: transport}
+
+	d := &datadogMalwareCertifier{
+		httpClient:    defaultClient,
+		assemblerFunc: assemblerFunc,
+	}
+
+	// apply user-provided options
+	for _, opt := range opts {
+		opt(d)
+	}
+
+	// Fetch immediately on creation to have initial data
+	if err := d.fetchManifests(ctx); err != nil {
+		return nil, fmt.Errorf("failed to fetch Datadog Malicious Software Packages manifests: %w", err)
+	}
+
+	// Start a background loop to refresh data every so often
+	go d.startManifestRefreshLoop(ctx)
+
+	return d, nil
+}
+
+// startManifestRefreshLoop uses a ticker to periodically fetch the manifests. A failed
+// refresh is logged and the cached data is kept. The loop returns when ctx is done.
+func (d *datadogMalwareCertifier) startManifestRefreshLoop(ctx context.Context) {
+	ticker := time.NewTicker(manifestRefreshInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			if err := d.fetchManifests(ctx); err != nil {
+				logging.FromContext(ctx).Errorf("failed to refresh Datadog Malicious Software Packages manifests: %v", err)
+			}
+		case <-ctx.Done():
+			return
+		}
+	}
+}
+
+// fetchManifests updates the local npmData and pypiData from remote sources.
+// Both manifests are downloaded before dataMu is taken, so CertifyComponent is
+// never blocked behind network I/O; the cached data is only replaced when both
+// downloads succeed.
+func (d *datadogMalwareCertifier) fetchManifests(ctx context.Context) error {
+	npmData, err := d.fetchManifest(ctx, "NPM", NPM_MANIFEST_URL)
+	if err != nil {
+		return err
+	}
+	pypiData, err := d.fetchManifest(ctx, "PyPI", PYPI_MANIFEST_URL)
+	if err != nil {
+		return err
+	}
+
+	d.dataMu.Lock()
+	defer d.dataMu.Unlock()
+	d.npmData = npmData
+	d.pypiData = pypiData
+	return nil
+}
+
+// fetchManifest downloads and decodes the manifest at url. name ("NPM" or
+// "PyPI") only labels the returned errors. A non-200 response is an error, so
+// an error page is never decoded as a manifest.
+func (d *datadogMalwareCertifier) fetchManifest(ctx context.Context, name, url string) (MaliciousPackages, error) {
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create %s manifest request: %w", name, err)
+	}
+	resp, err := d.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to fetch %s manifest: %w", name, err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("failed to fetch %s manifest: unexpected status %q", name, resp.Status)
+	}
+
+	data := make(MaliciousPackages)
+	if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
+		return nil, fmt.Errorf("failed to decode %s manifest: %w", name, err)
+	}
+	return data, nil
+}
+
+// npmPackageName builds the manifest key for an npm package. A namespace that
+// starts with "@" or "%40" yields "@scope/name"; any other non-empty namespace
+// yields "namespace/name"; without a namespace the bare name is used.
+func npmPackageName(pkg *generated.PkgInputSpec) string {
+	if pkg.Namespace == nil || *pkg.Namespace == "" {
+		return pkg.Name
+	}
+	ns := *pkg.Namespace
+	trimmedNS := strings.TrimPrefix(strings.TrimPrefix(ns, "@"), "%40")
+	if strings.HasPrefix(ns, "@") || strings.HasPrefix(ns, "%40") {
+		return "@" + trimmedNS + "/" + pkg.Name
+	}
+	return trimmedNS + "/" + pkg.Name
+}
+
+// newCertifyBad builds the CertifyBad predicate for pkg with the given match
+// type and justification.
+func newCertifyBad(pkg *generated.PkgInputSpec, match generated.PkgMatchType, justification string, knownSince time.Time) assembler.CertifyBadIngest {
+	return assembler.CertifyBadIngest{
+		Pkg:          pkg,
+		PkgMatchFlag: generated.MatchFlags{Pkg: match},
+		CertifyBad: &generated.CertifyBadInputSpec{
+			Justification: justification,
+			Origin:        datadogMalwareOrigin,
+			Collector:     DatadogMalwareCertifier,
+			KnownSince:    knownSince,
+		},
+	}
+}
+
+// CertifyComponent checks the provided packages against the Datadog malicious datasets.
+// rootComponent must be a []*root_package.PackageNode, otherwise
+// ErrDatadogMalwareComponentTypeMismatch is returned. Purls that fail to parse
+// and packages that are neither npm nor pypi are skipped. All matches are
+// passed to the assembler function in a single call; docChannel is not used.
+func (d *datadogMalwareCertifier) CertifyComponent(ctx context.Context, rootComponent interface{}, docChannel chan<- *processor.Document) error {
+	logger := logging.FromContext(ctx)
+	packageNodes, ok := rootComponent.([]*root_package.PackageNode)
+	if !ok {
+		return ErrDatadogMalwareComponentTypeMismatch
+	}
+
+	// Snapshot the maps under the read lock; a refresh swaps in new maps and
+	// never mutates the ones captured here.
+	d.dataMu.RLock()
+	npmData, pypiData := d.npmData, d.pypiData
+	d.dataMu.RUnlock()
+
+	currentTime := time.Now().UTC()
+	var certifyBads []assembler.CertifyBadIngest
+
+	for _, node := range packageNodes {
+		purl := node.Purl
+		pkgInput, err := helpers.PurlToPkg(purl)
+		if err != nil {
+			logger.Debugf("failed to parse purl '%s' into package: %v", purl, err)
+			continue
+		}
+
+		var (
+			maliciousVersions []string
+			found             bool
+		)
+		switch pkgInput.Type {
+		case "npm":
+			maliciousVersions, found = npmData[npmPackageName(pkgInput)]
+		case "pypi":
+			maliciousVersions, found = pypiData[pkgInput.Name]
+		default:
+			logger.Debugf("Skipping package %s, not npm or pypi", purl)
+			continue
+		}
+		if !found {
+			continue
+		}
+
+		switch {
+		case len(maliciousVersions) == 0:
+			// interpret empty versions array as "all versions malicious"
+			certifyBads = append(certifyBads, newCertifyBad(pkgInput, generated.PkgMatchTypeAllVersions, "All versions of this package are malicious.", currentTime))
+		case pkgInput.Version == nil || *pkgInput.Version == "":
+			// no specific version requested, treat all known malicious versions as applying to all versions
+			certifyBads = append(certifyBads, newCertifyBad(pkgInput, generated.PkgMatchTypeAllVersions, "All versions of this package are malicious according to Datadog's dataset.", currentTime))
+		case containsVersion(maliciousVersions, *pkgInput.Version):
+			justification := fmt.Sprintf("Package version %s found in Datadog's malicious software packages dataset.", *pkgInput.Version)
+			certifyBads = append(certifyBads, newCertifyBad(pkgInput, generated.PkgMatchTypeSpecificVersion, justification, currentTime))
+		default:
+			logger.Debugf("Package %s version %s not found in malicious dataset", purl, *pkgInput.Version)
+		}
+	}
+
+	if len(certifyBads) == 0 {
+		return nil
+	}
+	if _, err := d.assemblerFunc([]assembler.IngestPredicates{{CertifyBad: certifyBads}}); err != nil {
+		return fmt.Errorf("unable to assemble graphs: %w", err)
+	}
+	return nil
+}
+
+// containsVersion checks if a given version string is in the malicious versions list
+func containsVersion(maliciousVersions []string, versionToCheck string) bool {
+	for _, v := range maliciousVersions {
+		if v == versionToCheck {
+			return true
+		}
+	}
+	return false
+}
diff --git a/pkg/certifier/datadog_malware/datadog_malware_test.go b/pkg/certifier/datadog_malware/datadog_malware_test.go
new file mode 100644
index 0000000000..fc561172c2
--- /dev/null
+++ b/pkg/certifier/datadog_malware/datadog_malware_test.go
@@ -0,0 +1,320 @@
+//
+// Copyright 2024 The GUAC Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package datadog_malware
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"path"
+	"strings"
+	"testing"
+
+	"github.com/guacsec/guac/pkg/assembler"
+	"github.com/guacsec/guac/pkg/assembler/clients/generated"
+	ingestor "github.com/guacsec/guac/pkg/assembler/clients/helpers"
+	"github.com/guacsec/guac/pkg/assembler/helpers"
+	"github.com/guacsec/guac/pkg/certifier/components/root_package"
+	"github.com/guacsec/guac/pkg/logging"
+	"github.com/stretchr/testify/assert"
+)
+
+// TestDatadogMalwareCertifier_CertifyComponent runs CertifyComponent against
+// manifests served by a local test server and checks the captured predicates.
+func TestDatadogMalwareCertifier_CertifyComponent(t *testing.T) {
+	// set up test server that handles both npm and pypi manifests
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		var response map[string][]string
+		switch path.Base(r.URL.Path) {
+		case "manifest.json":
+			if strings.Contains(r.URL.Path, "npm") {
+				response = map[string][]string{
+					"@malicious/package":          {"1.0.0"},
+					"evil-package":                {"2.0.0"},
+					"@malicious/emptyversions":    {},
+					"@malicious/multipleversions": {"1.0.0", "2.0.0"},
+					"@malicious/singleversion":    {"1.0.0"},
+				}
+			} else if strings.Contains(r.URL.Path, "pypi") {
+				response = map[string][]string{
+					"malicious-pypi":     {"0.1.0"},
+					"emptyversions-pypi": {},
+					"manyversions-pypi":  {"0.1.0", "0.2.0"},
+					"singleversion-pypi": {"3.3.3"},
+				}
+			}
+		default:
+			http.Error(w, "not found", http.StatusNotFound)
+			return
+		}
+
+		err := json.NewEncoder(w).Encode(response)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+	}))
+	defer server.Close()
+
+	// create a custom client that redirects all requests to our test server
+	testClient := &http.Client{
+		Transport: &mockTransport{
+			server:          server,
+			npmManifestURL:  NPM_MANIFEST_URL,
+			pypiManifestURL: PYPI_MANIFEST_URL,
+		},
+	}
+
+	ctx := logging.WithLogger(context.Background())
+
+	type testCase struct {
+		name           string
+		rootComponent  interface{}
+		expectedPreds  int
+		wantPackages   []string
+		wantPkgMatch   generated.PkgMatchType
+		wantErr        bool
+		errMessage     error
+		assemblerError error
+	}
+
+	tests := []testCase{
+		{
+			name: "certify malicious npm and pypi packages",
+			rootComponent: []*root_package.PackageNode{
+				{Purl: "pkg:npm/%40malicious/package@1.0.0"},
+				{Purl: "pkg:npm/evil-package@2.0.0"},
+				{Purl: "pkg:pypi/malicious-pypi@0.1.0"},
+				{Purl: "pkg:maven/safe/package@1.0.0"},
+			},
+			expectedPreds: 3,
+			wantPackages: []string{
+				"pkg:npm/%40malicious/package@1.0.0",
+				"pkg:npm/evil-package@2.0.0",
+				"pkg:pypi/malicious-pypi@0.1.0",
+			},
+			wantPkgMatch: generated.PkgMatchTypeSpecificVersion,
+		},
+		{
+			name:          "bad component type",
+			rootComponent: map[string]string{},
+			expectedPreds: 0,
+			wantErr:       true,
+			errMessage:    ErrDatadogMalwareComponentTypeMismatch,
+		},
+		{
+			name: "no malicious packages",
+			rootComponent: []*root_package.PackageNode{
+				{Purl: "pkg:npm/safe-package@1.0.0"},
+				{Purl: "pkg:pypi/good-package@1.0.0"},
+			},
+			expectedPreds: 0,
+		},
+		{
+			name: "assembler error",
+			rootComponent: []*root_package.PackageNode{
+				{Purl: "pkg:npm/%40malicious/multipleversions@1.0.0"},
+			},
+			expectedPreds:  1,
+			wantPackages:   []string{"pkg:npm/%40malicious/multipleversions@1.0.0"},
+			wantPkgMatch:   generated.PkgMatchTypeSpecificVersion,
+			wantErr:        true,
+			assemblerError: errors.New("assembler error"),
+		},
+
+		// empty versions list, no package version specified
+		{
+			name: "empty versions, no package version (npm)",
+			rootComponent: []*root_package.PackageNode{
+				{Purl: "pkg:npm/%40malicious/emptyversions"},
+			},
+			expectedPreds: 1,
+			wantPackages:  []string{"pkg:npm/%40malicious/emptyversions"},
+			wantPkgMatch:  generated.PkgMatchTypeAllVersions,
+		},
+		{
+			name: "empty versions, no package version (pypi)",
+			rootComponent: []*root_package.PackageNode{
+				{Purl: "pkg:pypi/emptyversions-pypi"},
+			},
+			expectedPreds: 1,
+			wantPackages:  []string{"pkg:pypi/emptyversions-pypi"},
+			wantPkgMatch:  generated.PkgMatchTypeAllVersions,
+		},
+
+		// empty versions list, with package version specified
+		{
+			name: "empty versions, package version specified (npm)",
+			rootComponent: []*root_package.PackageNode{
+				{Purl: "pkg:npm/%40malicious/emptyversions@9.9.9"},
+			},
+			expectedPreds: 1,
+			wantPackages:  []string{"pkg:npm/%40malicious/emptyversions@9.9.9"},
+			wantPkgMatch:  generated.PkgMatchTypeAllVersions,
+		},
+
+		// non-empty versions, no package version specified -> all malicious
+		{
+			name: "non-empty versions, no version specified (npm)",
+			rootComponent: []*root_package.PackageNode{
+				{Purl: "pkg:npm/%40malicious/multipleversions"},
+			},
+			expectedPreds: 1,
+			wantPackages:  []string{"pkg:npm/%40malicious/multipleversions"},
+			wantPkgMatch:  generated.PkgMatchTypeAllVersions,
+		},
+		{
+			name: "non-empty versions, no version specified (pypi)",
+			rootComponent: []*root_package.PackageNode{
+				{Purl: "pkg:pypi/manyversions-pypi"},
+			},
+			expectedPreds: 1,
+			wantPackages:  []string{"pkg:pypi/manyversions-pypi"},
+			wantPkgMatch:  generated.PkgMatchTypeAllVersions,
+		},
+
+		// non-empty versions, with a specified malicious version
+		{
+			name: "non-empty versions, specified malicious version (npm)",
+			rootComponent: []*root_package.PackageNode{
+				{Purl: "pkg:npm/%40malicious/multipleversions@1.0.0"},
+			},
+			expectedPreds: 1,
+			wantPackages:  []string{"pkg:npm/%40malicious/multipleversions@1.0.0"},
+			wantPkgMatch:  generated.PkgMatchTypeSpecificVersion,
+		},
+		{
+			name: "non-empty versions, specified malicious version (pypi)",
+			rootComponent: []*root_package.PackageNode{
+				{Purl: "pkg:pypi/manyversions-pypi@0.1.0"},
+			},
+			expectedPreds: 1,
+			wantPackages:  []string{"pkg:pypi/manyversions-pypi@0.1.0"},
+			wantPkgMatch:  generated.PkgMatchTypeSpecificVersion,
+		},
+
+		// non-empty versions, with a specified non-malicious version
+		{
+			name: "non-empty versions, specified non-malicious version (npm)",
+			rootComponent: []*root_package.PackageNode{
+				{Purl: "pkg:npm/%40malicious/multipleversions@3.0.0"},
+			},
+			expectedPreds: 0,
+		},
+		{
+			name: "non-empty versions, specified non-malicious version (pypi)",
+			rootComponent: []*root_package.PackageNode{
+				{Purl: "pkg:pypi/manyversions-pypi@1.1.1"},
+			},
+			expectedPreds: 0,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Cancel the context when the subtest ends so the certifier's
+			// background refresh goroutine stops instead of leaking.
+			ctx, cancel := context.WithCancel(ctx)
+			defer cancel()
+
+			var capturedPreds []assembler.IngestPredicates
+			mockAssembler := func(preds []assembler.IngestPredicates) (*ingestor.AssemblerIngestedIDs, error) {
+				capturedPreds = preds
+				if tt.assemblerError != nil {
+					return nil, tt.assemblerError
+				}
+				return &ingestor.AssemblerIngestedIDs{}, nil
+			}
+
+			malwareCertifier, err := NewDatadogMalwareCertifier(ctx, mockAssembler, WithHTTPClient(testClient))
+			if err != nil {
+				t.Fatalf("Failed to create certifier: %v", err)
+			}
+
+			err = malwareCertifier.CertifyComponent(ctx, tt.rootComponent, nil)
+
+			if (err != nil) != tt.wantErr {
+				t.Errorf("CertifyComponent() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+
+			if err != nil {
+				if tt.errMessage != nil && !errors.Is(err, tt.errMessage) {
+					t.Errorf("CertifyComponent() error = %v, want error = %v", err, tt.errMessage)
+				}
+				if tt.assemblerError != nil && !errors.Is(err, tt.assemblerError) {
+					t.Errorf("CertifyComponent() error = %v, want wrapped error = %v", err, tt.assemblerError)
+				}
+				return
+			}
+
+			if tt.expectedPreds > 0 {
+				assert.Len(t, capturedPreds, 1, "Should have one IngestPredicates")
+				assert.Len(t, capturedPreds[0].CertifyBad, tt.expectedPreds, "Should have expected number of CertifyBad predicates")
+
+				foundPackages := make(map[string]bool)
+				for _, certifyBad := range capturedPreds[0].CertifyBad {
+					purl := helpers.PkgInputSpecToPurl(certifyBad.Pkg)
+					foundPackages[purl] = true
+
+					// verify predicate content
+					assert.Equal(t, DatadogMalwareCertifier, certifyBad.CertifyBad.Collector)
+					assert.Equal(t, "Datadog Malicious Software Packages Dataset", certifyBad.CertifyBad.Origin)
+					assert.NotEmpty(t, certifyBad.CertifyBad.Justification)
+					assert.Equal(t, tt.wantPkgMatch, certifyBad.PkgMatchFlag.Pkg)
+					assert.False(t, certifyBad.CertifyBad.KnownSince.IsZero())
+				}
+
+				for _, pkg := range tt.wantPackages {
+					assert.True(t, foundPackages[pkg], fmt.Sprintf("Package %s was not found in certifications", pkg))
+				}
+			} else {
+				assert.Len(t, capturedPreds, 0, "Should have no IngestPredicates")
+			}
+		})
+	}
+}
+
+// mockTransport redirects the manifest URLs to fixed paths on the test server
+type mockTransport struct {
+	server          *httptest.Server
+	npmManifestURL  string
+	pypiManifestURL string
+}
+
+// RoundTrip rewrites a manifest request to target the test server; any other URL is an error.
+func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+	newURL := m.server.URL
+	switch req.URL.String() {
+	case m.npmManifestURL:
+		newURL += "/npm/manifest.json"
+	case m.pypiManifestURL:
+		newURL += "/pypi/manifest.json"
+	default:
+		return nil, fmt.Errorf("unexpected URL: %s", req.URL.String())
+	}
+
+	target, err := req.URL.Parse(newURL)
+	if err != nil {
+		return nil, err
+	}
+	newReq := req.Clone(req.Context())
+	newReq.URL = target
+
+	return m.server.Client().Do(newReq)
+}