#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0-or-later
#
# Enriches the input CycloneDX SBOM with vulnerability information from the NVD
# database.
#
# The NVD database is cloned using a mirror of it and the content is compared
# locally.
#
# Example usage:
# $ make show-info | utils/generate-cyclonedx | support/script/cve-check --nvd-path dl/buildroot-nvd/
from collections import defaultdict
from pathlib import Path
from typing import TypedDict
from datetime import datetime, timezone
import argparse
import sys
import json
import cve as cvecheck


class Options(TypedDict):
    """Options that tune the vulnerability analysis.

    Every key is required ('total' is left at its default of True).
    """

    # When true, CVEs whose status is "doesn't affect" for a component are
    # still added to the output (with the 'resolved' analysis state) instead
    # of being skipped.
    include_resolved: bool


# Long help text of the command-line interface; main() passes it to argparse
# as the parser 'description'.
DESCRIPTION = """
Enriches the input CycloneDX SBOM with vulnerability information from the NVD
database.

The NVD database is cloned using a mirror of it and the content is compared
locally.

Always run this script from the output of 'generate-cyclonedx'. Do not re-run
this script over an already analysed SBOM.
"""


# Third ancestor of this file's path, i.e. the grandparent of the directory
# that holds this script. main() uses it as the base of the default
# '--nvd-path' value ('<brpath>/dl/buildroot-nvd').
# NOTE(review): presumably the Buildroot top-level directory - confirm.
brpath = Path(__file__).parent.parent.parent


def datetime_to_rfc3339(dt_string):
    """Normalize a datetime string to RFC 3339 format in UTC with a Z suffix.

    A value without timezone information is treated as UTC; a value that
    carries an offset is converted to UTC. The fractional seconds are always
    written with millisecond precision (finer digits are truncated), so an
    input such as the one below only gains the Z suffix.

    Input:  "1999-01-01T05:00:00.000"
    Output: "1999-01-01T05:00:00.000Z"

    Args:
        dt_string (str): ISO 8601 datetime, optionally ending with 'Z' or a
            numeric UTC offset.

    Returns:
        str: The same instant formatted as 'YYYY-MM-DDTHH:MM:SS.mmmZ'.

    Raises:
        ValueError: If 'dt_string' cannot be parsed by
            datetime.fromisoformat().
    """
    # A literal 'Z' is only understood by fromisoformat() on recent Python
    # versions, so spell the UTC offset out explicitly.
    dt = datetime.fromisoformat(dt_string.replace('Z', '+00:00'))

    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    else:
        dt = dt.astimezone(timezone.utc)

    # A bare isoformat() drops the fraction when it is zero and prints six
    # digits otherwise; pin the precision so the output layout is stable.
    return dt.isoformat(timespec='milliseconds').replace('+00:00', 'Z')


def cve_api_get_lang_from_list(values, lang="en") -> (str | None):
    for x in values:
        if x.get("lang") == lang:
            return x.get("value")
    return None


def nvd_cve_weaknesses_to_cdx(weaknesses) -> list[int]:
    """
    Extract the numeric CWE identifiers from an NVD 'weaknesses' array.

    Every English entry of every weakness node is inspected, so a node that
    lists several CWEs contributes all of them. Values that are not of the
    form 'CWE-<number>' (for instance 'NVD-CWE-noinfo') are skipped, and a
    CWE reported by several nodes is only returned once.

    See the CycloneDX specification for 'cwes' [1]

    [1] https://cyclonedx.org/docs/1.6/json/#vulnerabilities_items_cwes

    Args:
        weaknesses (list): Weakness nodes, each optionally holding a
            'description' list of {'lang', 'value'} dictionaries.

    Returns:
        list[int]: CWE numbers in order of first appearance, no duplicates.
    """
    res = []

    for node in weaknesses:
        for desc in node.get("description", []):
            if desc.get("lang") != "en":
                continue

            value = desc.get("value")
            if not isinstance(value, str):
                continue

            cwe = value.replace("CWE-", "")

            # isdecimal() only accepts characters int() can parse, unlike
            # isnumeric() which also accepts e.g. superscripts or fractions.
            if not cwe.isdecimal():
                continue

            cwe_id = int(cwe)
            if cwe_id not in res:
                res.append(cwe_id)

    return res


def nvd_cve_cvss_to_cdx(metrics):
    """
    Convert the NVD 'metrics' object into a list of CycloneDX ratings.

    Each metric found under each key of 'metrics' yields one rating. The
    'score', 'severity' and 'vector' fields are only emitted when the
    corresponding data is present.

    See the CycloneDX specification for 'ratings' [1]

    [1] https://cyclonedx.org/docs/1.6/json/#vulnerabilities_items_ratings

    Args:
        metrics (dict): Maps an NVD metric key (e.g. 'cvssMetricV31') to a
            list of metric dictionaries holding a 'cvssData' dictionary.

    Returns:
        list: Rating dictionaries with a 'method' and, when available,
        'score', 'severity' (lower-cased) and 'vector'.
    """

    KEY_METHOD_DICT = {
        "cvssMetricV40": "CVSSv4",
        "cvssMetricV31": "CVSSv31",
        # The NVD API names the CVSS 3.0 metrics 'cvssMetricV30'.
        "cvssMetricV30": "CVSSv3",
        # Previously recognised spelling, kept for backward compatibility.
        "cvssMetricV3": "CVSSv3",
        "cvssMetricV2": "CVSSv2",
    }

    res = []

    for key, values in metrics.items():
        method = KEY_METHOD_DICT.get(key, "other")

        for value in values:
            data = value.get("cvssData", {})
            rating = {"method": method}

            if "baseScore" in data:
                rating["score"] = data["baseScore"]

            # CVSS v2 metrics carry 'baseSeverity' next to 'cvssData' rather
            # than inside it, so fall back to the metric object itself.
            severity = data.get("baseSeverity", value.get("baseSeverity"))
            if isinstance(severity, str):
                rating["severity"] = severity.lower()

            if "vectorString" in data:
                rating["vector"] = data["vectorString"]

            res.append(rating)

    return res


def nvd_cve_references_to_cdx(references):
    """
    Build a list of {'title', 'url'} dictionaries from NVD references.

    A reference is only kept when it has both a 'url' and a 'tags' key and
    its 'tags' is a non-empty list. One of the tags is used as the title.

    Args:
        references (list): Reference entries of an NVD CVE.

    Returns:
        list: One {'title', 'url'} dictionary per retained reference, in
        input order.
    """
    def pick_title(tags):
        # The first tag that does not mention "Advisory" wins; when every
        # tag mentions it, the first tag is used.
        # NOTE(review): preferring the non-"Advisory" tags looks like it may
        # be an inverted condition - confirm the intended title selection.
        for tag in tags:
            if "Advisory" not in tag:
                return tag
        return tags[0]

    required_keys = {"url", "tags"}
    entries = []

    for reference in references:
        if not required_keys <= set(reference):
            continue

        tags = reference["tags"]
        if isinstance(tags, list) and tags:
            entries.append({
                "title": pick_title(tags),
                "url": reference["url"]
            })

    return entries


def nvd_cve_to_cdx_vulnerability(nvd_cve):
    """
    Convert a CVE object in the NVD API layout into a CycloneDX
    vulnerability that fits the spec (see [1]).

    'id', 'description' and 'source' are always present in the result. The
    'published', 'updated', 'cwes', 'ratings' and 'advisories' fields are
    only added when the matching NVD key exists in 'nvd_cve'.

    [1] https://cyclonedx.org/docs/1.6/json/#vulnerabilities

    Args:
        nvd_cve (dict): CVE object; must contain an 'id' key.

    Returns:
        dict: The CycloneDX vulnerability.
    """
    cve_id = nvd_cve["id"]
    english_description = cve_api_get_lang_from_list(nvd_cve.get("descriptions", []))

    vulnerability = {
        "id": cve_id,
        "description": english_description or "",
        "source": {
            "name": "NVD",
            "url": f"https://nvd.nist.gov/vuln/detail/{cve_id}"
        },
    }

    # (CycloneDX key, NVD key, converter), listed in output order.
    optional_fields = (
        ("published", "published", datetime_to_rfc3339),
        ("updated", "lastModified", datetime_to_rfc3339),
        ("cwes", "weaknesses", nvd_cve_weaknesses_to_cdx),
        ("ratings", "metrics", nvd_cve_cvss_to_cdx),
        ("advisories", "references", nvd_cve_references_to_cdx),
    )

    for cdx_key, nvd_key, convert in optional_fields:
        if nvd_key in nvd_cve:
            vulnerability[cdx_key] = convert(nvd_cve[nvd_key])

    return vulnerability


def vuln_append_or_update_affects_if_exists(vulnerabilities, vulnerability):
    """
    Updates a matching 'vulnerability' from the 'vulnerabilities' list or
    appends it as a new entry.

    A vulnerability is considered 'matching' if it shares the same 'id' AND
    either:

    1. An identical 'affects' entry (same 'ref'). This match has priority
       over every entry of the list: the existing analysis is kept, which
       covers the case where a vulnerability was ignored from the generated
       SBOM.
    2. An identical 'analysis.state'. The new 'ref' is added to the
       'affects' of the first such entry.

    In both cases the existing entry keeps its own 'analysis' and 'affects'
    objects and only receives the other fields of 'vulnerability'. The
    'vulnerability' argument itself is never modified; when nothing matches
    that very object is appended to the list.

    Only the first 'affects' entry of 'vulnerability' is considered.

    Args:
        vulnerabilities (list): The vulnerabilities array reference retrieved
            from the input CycloneDX SBOM
        vulnerability (dict): Vulnerability to add to the 'vulnerabilities' list.
    """
    vuln_id = vulnerability.get("id")
    new_analysis = vulnerability.get("analysis", {}).get("state")
    new_ref = next((a.get("ref") for a in vulnerability.get("affects", [])), None)

    same_ref_vuln = None
    same_analysis_vuln = None

    for curr_vuln in vulnerabilities:
        if curr_vuln.get("id") != vuln_id:
            continue

        curr_vuln_refs = [a.get("ref") for a in curr_vuln.get("affects", [])]
        if new_ref is not None and new_ref in curr_vuln_refs:
            # A ref match always wins, even over an earlier entry that only
            # shares the analysis state.
            same_ref_vuln = curr_vuln
            break

        curr_vuln_analysis = curr_vuln.get("analysis", {}).get("state")
        if same_analysis_vuln is None and curr_vuln_analysis == new_analysis:
            same_analysis_vuln = curr_vuln

    target = same_ref_vuln if same_ref_vuln is not None else same_analysis_vuln

    if target is None:
        # No same ID w/ same analysis or same ref.
        vulnerabilities.append(vulnerability)
        return

    if same_ref_vuln is None and new_ref is not None:
        # The same analysis, add a new affect reference.
        target.setdefault("affects", []).append({"ref": new_ref})

    # Refresh the metadata only: the existing 'analysis' (possibly carrying
    # more than the state) and 'affects' must not be overwritten.
    target.update({
        key: value for key, value in vulnerability.items()
        if key not in ("analysis", "affects")
    })


def check_package_cve_affects(cve: cvecheck.CVE, cpe_product_pkgs, sbom, opt: Options):
    """
    Record in 'sbom' how a single CVE relates to the matching components.

    For every product listed in 'cve.affected_products', each component
    registered under that product is tested with 'cve.affects()'. Components
    with an unknown status are skipped, and so are the ones the CVE doesn't
    affect unless 'include_resolved' is set. For the remaining ones a
    CycloneDX vulnerability is built and merged into the 'vulnerabilities'
    list of the SBOM (created when missing).

    Args:
        cve (cvecheck.CVE): CVE entry to test.
        cpe_product_pkgs (dict): Maps a CPE product name to the list of SBOM
            components using it.
        sbom (dict): SBOM whose 'vulnerabilities' list is updated in place.
        opt (Options): Options for the analysis.
    """
    known_vulnerabilities = sbom.setdefault("vulnerabilities", [])

    candidates = (
        component
        for product in cve.affected_products
        for component in cpe_product_pkgs.get(product, [])
    )

    for component in candidates:
        status = cve.affects(component["name"], component["version"], component["cpe"])

        if status == cve.CVE_UNKNOWN:
            continue

        if not opt["include_resolved"] and status == cve.CVE_DOESNT_AFFECT:
            continue

        # A fresh dictionary is needed for each component: it may end up
        # stored in the SBOM as is.
        entry = nvd_cve_to_cdx_vulnerability(cve.nvd_cve)

        if status == cve.CVE_AFFECTS:
            state = "exploitable"
        else:
            state = "resolved"

        entry["analysis"] = {"state": state}
        entry["affects"] = [{"ref": component["bom-ref"]}]

        vuln_append_or_update_affects_if_exists(known_vulnerabilities, entry)


def check_package_cves(nvd_path: Path, sbom, opt: Options):
    """
    Walk through every entry of the NVD API mirror and compare it with the
    components of 'sbom'.

    Components are first grouped by the product of their CPE; components
    without a 'cpe' or without a 'version' are left out. Every CVE read from
    'nvd_path' is then checked against that index, which enriches the
    vulnerabilities of 'sbom' with the analysis of the matching ones.

    Args:
        nvd_path (Path): Path of the mirror of the NVD API.
        sbom (dict): Input SBOM containing a set of vulnerabilities that will be enriched.
        opt (Options): Options for the analysis.
    """
    components_by_product = defaultdict(list)

    for component in sbom.get("components", []):
        if not (component.get("cpe") and component.get("version")):
            continue

        product = cvecheck.CPE(component["cpe"]).product
        components_by_product[product].append(component)

    for nvd_entry in cvecheck.CVE.read_nvd_dir(nvd_path):
        check_package_cve_affects(nvd_entry, components_by_product, sbom, opt)


def enrich_vulnerabilities(nvd_path: Path, sbom):
    """
    Refresh the vulnerabilities already listed in 'sbom' with the content of
    the NVD API mirror.

    Only entries whose 'id' starts with 'CVE-' (case-insensitive) are looked
    up. When the mirror has no such entry a warning is printed on stderr and
    the vulnerability is left untouched.

    Args:
        nvd_path (Path): Path of the mirror of the NVD API.
        sbom (dict): Input SBOM containing a set of vulnerabilities that will be enriched.
    """
    for entry in sbom.setdefault("vulnerabilities", []):
        identifier = entry.get("id")
        is_cve = identifier is not None and identifier.upper().startswith("CVE-")
        if not is_cve:
            continue

        record = cvecheck.CVE.read_nvd_entry(nvd_path, identifier)

        if record is not None:
            entry.update(nvd_cve_to_cdx_vulnerability(record.nvd_cve))
            continue

        print(f"Warning: '{identifier}' doesn't exist in NVD database.", file=sys.stderr)


def main():
    """
    Command-line entry point.

    Reads a CycloneDX SBOM (from '--in-file' or from a piped stdin), makes
    sure the NVD directory exists and updates it unless '--no-nvd-update' is
    given, then either refreshes the vulnerabilities already present
    ('--enrich-only') or runs the full analysis. The resulting SBOM is
    written as indented JSON to '--out-file' (stdout by default).

    Prints the help and exits with status 1 when no input is available.
    """
    cli = argparse.ArgumentParser(description=DESCRIPTION)

    # Only read stdin implicitly when something is piped into it.
    piped_stdin = None if sys.stdin.isatty() else sys.stdin

    cli.add_argument("-i", "--in-file", nargs="?", type=argparse.FileType("r"),
                     default=piped_stdin)
    cli.add_argument("-o", "--out-file", nargs="?", type=argparse.FileType("w"),
                     default=sys.stdout)
    cli.add_argument('--nvd-path', dest='nvd_path',
                     default=brpath / 'dl' / 'buildroot-nvd',
                     help='Path to the local NVD database',
                     type=lambda p: Path(p).expanduser().resolve())
    cli.add_argument("--enrich-only", action='store_true',
                     help=("Only update metadata for the vulnerabilities currently present "
                           "in the input CycloneDX SBOM. Don't do an analysis."))
    cli.add_argument("--include-resolved", action='store_true',
                     help=("Add vulnerabilities already 'resolved' that don't affect a "
                           "component to the output CycloneDX vulnerabilities analysis."))
    cli.add_argument("--no-nvd-update", action='store_true',
                     help="Doesn't update the NVD database.")

    args = cli.parse_args()

    if args.in_file is None or args.nvd_path is None:
        cli.print_help()
        sys.exit(1)

    sbom = json.load(args.in_file)

    opt = Options(include_resolved=args.include_resolved)

    args.nvd_path.mkdir(parents=True, exist_ok=True)
    if not args.no_nvd_update:
        cvecheck.CVE.download_nvd(args.nvd_path)

    if not args.enrich_only:
        check_package_cves(args.nvd_path, sbom, opt)
    else:
        enrich_vulnerabilities(args.nvd_path, sbom)

    args.out_file.write(json.dumps(sbom, indent=2) + '\n')


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
