diff options
Diffstat (limited to 'src/nvd')
| -rw-r--r-- | src/nvd/__init__.py | 39 | ||||
| -rw-r--r-- | src/nvd/__version__.py | 1 | ||||
| -rw-r--r-- | src/nvd/cli/__init__.py | 1 | ||||
| -rw-r--r-- | src/nvd/cli/commands/__init__.py | 1 | ||||
| -rw-r--r-- | src/nvd/cli/commands/config.py | 95 | ||||
| -rw-r--r-- | src/nvd/cli/commands/cpe.py | 170 | ||||
| -rw-r--r-- | src/nvd/cli/commands/cve.py | 150 | ||||
| -rw-r--r-- | src/nvd/cli/formatters.py | 137 | ||||
| -rw-r--r-- | src/nvd/cli/main.py | 72 | ||||
| -rw-r--r-- | src/nvd/client.py | 213 | ||||
| -rw-r--r-- | src/nvd/endpoints/__init__.py | 1 | ||||
| -rw-r--r-- | src/nvd/endpoints/cpe.py | 93 | ||||
| -rw-r--r-- | src/nvd/endpoints/cpematch.py | 101 | ||||
| -rw-r--r-- | src/nvd/endpoints/cve.py | 205 | ||||
| -rw-r--r-- | src/nvd/endpoints/cve_history.py | 79 | ||||
| -rw-r--r-- | src/nvd/endpoints/source.py | 81 | ||||
| -rw-r--r-- | src/nvd/exceptions.py | 61 | ||||
| -rw-r--r-- | src/nvd/models.py | 315 | ||||
| -rw-r--r-- | src/nvd/py.typed | 0 | ||||
| -rw-r--r-- | src/nvd/rate_limiter.py | 67 |
20 files changed, 1882 insertions, 0 deletions
diff --git a/src/nvd/__init__.py b/src/nvd/__init__.py new file mode 100644 index 0000000..4821325 --- /dev/null +++ b/src/nvd/__init__.py @@ -0,0 +1,39 @@ +"""NVD API - Python library and CLI for the US National Vulnerability Database API 2.0.""" + +from .__version__ import __version__ +from .client import NVDClient +from .exceptions import ( + AuthenticationError, + NetworkError, + NotFoundError, + NVDError, + RateLimitError, + ResponseError, + ServerError, + ValidationError, +) +from .models import ( + CPEData, + CPEMatchString, + CVEChange, + CVEData, + SourceData, +) + +__all__ = [ + "__version__", + "NVDClient", + "NVDError", + "RateLimitError", + "AuthenticationError", + "ValidationError", + "NotFoundError", + "ServerError", + "NetworkError", + "ResponseError", + "CVEData", + "CPEData", + "CPEMatchString", + "CVEChange", + "SourceData", +] diff --git a/src/nvd/__version__.py b/src/nvd/__version__.py new file mode 100644 index 0000000..3dc1f76 --- /dev/null +++ b/src/nvd/__version__.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/src/nvd/cli/__init__.py b/src/nvd/cli/__init__.py new file mode 100644 index 0000000..73851cd --- /dev/null +++ b/src/nvd/cli/__init__.py @@ -0,0 +1 @@ +"""CLI package for NVD API.""" diff --git a/src/nvd/cli/commands/__init__.py b/src/nvd/cli/commands/__init__.py new file mode 100644 index 0000000..3f4c467 --- /dev/null +++ b/src/nvd/cli/commands/__init__.py @@ -0,0 +1 @@ +"""CLI commands package.""" diff --git a/src/nvd/cli/commands/config.py b/src/nvd/cli/commands/config.py new file mode 100644 index 0000000..46da597 --- /dev/null +++ b/src/nvd/cli/commands/config.py @@ -0,0 +1,95 @@ +"""Configuration CLI commands.""" + +import os +import sys +from pathlib import Path +from typing import Optional + +import typer +import yaml +from rich.console import Console +from rich.table import Table + +CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]} + +app = typer.Typer( + no_args_is_help=True, + help="Manage nvdb configuration (API keys, settings)", + context_settings=CONTEXT_SETTINGS, + epilog="Examples:\n" + " nvdb config show\n" + " nvdb config set-api-key YOUR_API_KEY\n" + " nvdb config clear", +) + +# Console for stderr (status messages, errors) +console = Console(stderr=True) + +CONFIG_DIR = Path.home() / ".config" / "nvd" +CONFIG_FILE = CONFIG_DIR / "config.yaml" + + +def load_config() -> dict: + """Load configuration from file.""" + if not CONFIG_FILE.exists(): + return {} + with open(CONFIG_FILE) as f: + return yaml.safe_load(f) or {} + + +def save_config(config: dict) -> None: + """Save configuration to file.""" + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + with open(CONFIG_FILE, "w") as f: + yaml.dump(config, f) + + +@app.command("set-api-key") +def set_api_key( + api_key: str = typer.Argument(..., help="Your NVD API key"), +) -> None: + """Set the NVD API key in config file.""" + config = load_config() + config["api_key"] = api_key + save_config(config) + console.print(f"[green]API key saved to {CONFIG_FILE}[/green]") + console.print("[yellow]Tip: You can also set the NVD_API_KEY environment variable[/yellow]") + + +@app.command("show") +def show_config() -> None: + """Show current configuration.""" + config = load_config() + env_api_key = os.getenv("NVD_API_KEY") + + table = Table(title="NVD API Configuration", show_header=True, header_style="bold magenta") + table.add_column("Setting", style="cyan") + table.add_column("Value", style="green") + table.add_column("Source", style="yellow") + + if config.get("api_key"): + masked_key = config["api_key"][:8] + "..." if len(config["api_key"]) > 8 else "***" + table.add_row("API Key", masked_key, f"Config file ({CONFIG_FILE})") + elif env_api_key: + masked_key = env_api_key[:8] + "..." if len(env_api_key) > 8 else "***" + table.add_row("API Key", masked_key, "Environment variable") + else: + table.add_row("API Key", "Not set", "N/A") + + console.print(table) + + if not config.get("api_key") and not env_api_key: + console.print("\n[yellow]No API key configured. Using unauthenticated access (5 req/30s)[/yellow]") + console.print("[blue]Set an API key for higher rate limits (50 req/30s):[/blue]") + console.print(" nvdb config set-api-key YOUR_KEY") + console.print(" or export NVD_API_KEY=YOUR_KEY") + + +@app.command("clear") +def clear_config() -> None: + """Clear saved configuration.""" + if CONFIG_FILE.exists(): + CONFIG_FILE.unlink() + console.print("[green]Configuration cleared[/green]") + else: + console.print("[yellow]No configuration file found[/yellow]") diff --git a/src/nvd/cli/commands/cpe.py b/src/nvd/cli/commands/cpe.py new file mode 100644 index 0000000..16a6cbc --- /dev/null +++ b/src/nvd/cli/commands/cpe.py @@ -0,0 +1,170 @@ +"""CPE CLI commands.""" + +import asyncio +import sys +from typing import Optional + +import typer +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn + +from ...client import NVDClient +from ..formatters import ( + format_cpe_table, + format_json, + format_json_lines, + format_match_criteria_table, + format_yaml, +) + +CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]} + +app = typer.Typer( + no_args_is_help=True, + help="Query CPE (product) information from the NVD", + context_settings=CONTEXT_SETTINGS, + epilog="Examples:\n" + " # Keyword search (recommended for most searches)\n" + " nvdb cpe search --keyword 'apache'\n" + " nvdb cpe search --keyword 'windows 10'\n\n" + " # CPE match string (requires cpe:2.3 format)\n" + " nvdb cpe search --match-string 'cpe:2.3:a:microsoft:*'\n" + " nvdb cpe search --match-string 'cpe:2.3:o:linux:*'\n\n" + " # Get match criteria for a CVE\n" + " nvdb cpe matches --cve CVE-2021-44228", +) + +# Console for stderr (progress, errors, warnings) +console = Console(stderr=True) + + +@app.command("get") +def get_cpe( + cpe_id: str = typer.Argument(..., help="CPE Name UUID"), + api_key: Optional[str] = typer.Option(None, "--api-key", envvar="NVD_API_KEY", help="NVD API key"), + output_format: str = typer.Option("table", "--format", "-f", help="Output format: table, json, yaml"), +) -> None: + """Get details for a specific CPE by UUID.""" + + async def _get() -> None: + async with NVDClient(api_key=api_key) as client: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + progress.add_task(f"Fetching CPE {cpe_id}...", total=None) + cpe = await client.cpe.get_cpe(cpe_id) + + if output_format == "json": + format_json(cpe) + elif output_format == "yaml": + format_yaml(cpe) + else: + format_cpe_table([cpe]) + + asyncio.run(_get()) + + +@app.command("search") +def search_cpes( + keyword: Optional[str] = typer.Option(None, "--keyword", "-k", help="Keyword to search in CPE titles (e.g., 'windows' or 'apache')"), + match_string: Optional[str] = typer.Option(None, "--match-string", "-m", help="CPE match string in cpe:2.3 format (e.g., 'cpe:2.3:a:vendor:product:*')"), + limit: int = typer.Option(20, "--limit", "-l", help="Maximum number of results"), + api_key: Optional[str] = typer.Option(None, "--api-key", envvar="NVD_API_KEY", help="NVD API key"), + output_format: str = typer.Option("table", "--format", "-f", help="Output format: table, json, yaml"), +) -> None: + """Search for CPEs. + + Use --keyword for simple text search in product names. + Use --match-string for CPE formatted strings (cpe:2.3:...). + """ + + async def _search() -> None: + from ...exceptions import NotFoundError + + results = [] + try: + async with NVDClient(api_key=api_key) as client: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("Searching CPEs...", total=None) + + async for cpe in client.cpe.search_cpes( + keyword_search=keyword, + cpe_match_string=match_string, + ): + results.append(cpe) + if len(results) >= limit: + break + progress.update(task, description=f"Searching CPEs... ({len(results)} found)") + + if not results: + console.print("[yellow]No CPEs found matching criteria[/yellow]") + return + + if output_format == "json": + format_json_lines(results) + elif output_format == "yaml": + for cpe in results: + format_yaml(cpe) + console.print("---") + else: + format_cpe_table(results) + except NotFoundError as e: + console.print(f"[red]Error:[/red] {e.message}") + if match_string: + console.print("\n[yellow]Tip:[/yellow] Use --keyword for text search, or use CPE format for --match-string") + console.print("Example: [blue]nvdb cpe search --keyword 'soft-serve'[/blue]") + console.print("Or: [blue]nvdb cpe search --match-string 'cpe:2.3:a:*:soft*'[/blue]") + raise typer.Exit(1) + + asyncio.run(_search()) + + +@app.command("matches") +def get_matches( + cve_id: Optional[str] = typer.Option(None, "--cve", help="CVE ID to get match criteria for"), + match_criteria_id: Optional[str] = typer.Option(None, "--id", help="Specific match criteria UUID"), + limit: int = typer.Option(20, "--limit", "-l", help="Maximum number of results"), + api_key: Optional[str] = typer.Option(None, "--api-key", envvar="NVD_API_KEY", help="NVD API key"), + output_format: str = typer.Option("table", "--format", "-f", help="Output format: table, json, yaml"), +) -> None: + """Get CPE match criteria.""" + + async def _matches() -> None: + results = [] + async with NVDClient(api_key=api_key) as client: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("Fetching match criteria...", total=None) + + async for match in client.cpematch.search_match_criteria( + cve_id=cve_id, + match_criteria_id=match_criteria_id, + ): + results.append(match) + if len(results) >= limit: + break + progress.update(task, description=f"Fetching match criteria... ({len(results)} found)") + + if not results: + console.print("[yellow]No match criteria found[/yellow]") + return + + if output_format == "json": + format_json_lines(results) + elif output_format == "yaml": + for match in results: + format_yaml(match) + console.print("---") + else: + format_match_criteria_table(results) + + asyncio.run(_matches()) diff --git a/src/nvd/cli/commands/cve.py b/src/nvd/cli/commands/cve.py new file mode 100644 index 0000000..3f28ac7 --- /dev/null +++ b/src/nvd/cli/commands/cve.py @@ -0,0 +1,150 @@ +"""CVE CLI commands.""" + +import asyncio +import sys +from typing import Optional + +import typer +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn + +from ...client import NVDClient +from ..formatters import format_cve_table, format_json, format_json_lines, format_yaml + +CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]} + +app = typer.Typer( + no_args_is_help=True, + help="Query CVE (vulnerability) information from the NVD", + context_settings=CONTEXT_SETTINGS, + epilog="Examples:\n" + " nvdb cve get CVE-2021-44228\n" + " nvdb cve search --keyword 'sql injection' --severity HIGH\n" + " nvdb cve search --has-kev --limit 20\n" + " nvdb cve history CVE-2021-44228", +) + +# Console for stderr (progress, errors, warnings) +console = Console(stderr=True) + + +@app.command("get") +def get_cve( + cve_id: str = typer.Argument(..., help="CVE ID (e.g., CVE-2021-44228)"), + api_key: Optional[str] = typer.Option(None, "--api-key", envvar="NVD_API_KEY", help="NVD API key"), + output_format: str = typer.Option("table", "--format", "-f", help="Output format: table, json, yaml"), +) -> None: + """Get details for a specific CVE.""" + + async def _get() -> None: + async with NVDClient(api_key=api_key) as client: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + progress.add_task(f"Fetching {cve_id}...", total=None) + cve = await client.cve.get_cve(cve_id) + + if output_format == "json": + format_json(cve) + elif output_format == "yaml": + format_yaml(cve) + else: + format_cve_table([cve]) + + asyncio.run(_get()) + + +@app.command("search") +def search_cves( + keyword: Optional[str] = typer.Option(None, "--keyword", "-k", help="Keyword to search in descriptions"), + cpe: Optional[str] = typer.Option(None, "--cpe", help="CPE name to filter by"), + severity: Optional[str] = typer.Option(None, "--severity", "-s", help="CVSS v3 severity: LOW, MEDIUM, HIGH, CRITICAL"), + cvss_v2_severity: Optional[str] = typer.Option(None, "--cvss-v2-severity", help="CVSS v2 severity"), + cvss_v3_severity: Optional[str] = typer.Option(None, "--cvss-v3-severity", help="CVSS v3 severity"), + cwe_id: Optional[str] = typer.Option(None, "--cwe", help="CWE ID (e.g., CWE-79)"), + has_kev: bool = typer.Option(False, "--has-kev", help="Only CVEs in CISA KEV catalog"), + pub_start: Optional[str] = typer.Option(None, "--pub-start", help="Publication start date (ISO-8601)"), + pub_end: Optional[str] = typer.Option(None, "--pub-end", help="Publication end date (ISO-8601)"), + mod_start: Optional[str] = typer.Option(None, "--mod-start", help="Last modified start date (ISO-8601)"), + mod_end: Optional[str] = typer.Option(None, "--mod-end", help="Last modified end date (ISO-8601)"), + limit: int = typer.Option(20, "--limit", "-l", help="Maximum number of results"), + api_key: Optional[str] = typer.Option(None, "--api-key", envvar="NVD_API_KEY", help="NVD API key"), + output_format: str = typer.Option("table", "--format", "-f", help="Output format: table, json, yaml"), +) -> None: + """Search for CVEs with various filters.""" + + async def _search() -> None: + results = [] + async with NVDClient(api_key=api_key) as client: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("Searching CVEs...", total=None) + + async for cve in client.cve.search_cves( + keyword_search=keyword, + cpe_name=cpe, + cvss_v2_severity=cvss_v2_severity, + cvss_v3_severity=cvss_v3_severity or severity, + cwe_id=cwe_id, + has_kev=has_kev if has_kev else None, + pub_start_date=pub_start, + pub_end_date=pub_end, + last_mod_start_date=mod_start, + last_mod_end_date=mod_end, + ): + results.append(cve) + if len(results) >= limit: + break + progress.update(task, description=f"Searching CVEs... ({len(results)} found)") + + if not results: + console.print("[yellow]No CVEs found matching criteria[/yellow]") + return + + if output_format == "json": + format_json_lines(results) + elif output_format == "yaml": + for cve in results: + format_yaml(cve) + console.print("---") + else: + format_cve_table(results) + + asyncio.run(_search()) + + +@app.command("history") +def get_history( + cve_id: str = typer.Argument(..., help="CVE ID to get history for"), + api_key: Optional[str] = typer.Option(None, "--api-key", envvar="NVD_API_KEY", help="NVD API key"), + output_format: str = typer.Option("json", "--format", "-f", help="Output format: json, yaml"), +) -> None: + """Get change history for a CVE.""" + + async def _history() -> None: + async with NVDClient(api_key=api_key) as client: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + progress.add_task(f"Fetching history for {cve_id}...", total=None) + history = await client.history.get_cve_history(cve_id) + + if not history: + console.print(f"[yellow]No history found for {cve_id}[/yellow]") + return + + if output_format == "yaml": + for change in history: + format_yaml(change) + console.print("---") + else: + format_json_lines(history) + + asyncio.run(_history()) diff --git a/src/nvd/cli/formatters.py b/src/nvd/cli/formatters.py new file mode 100644 index 0000000..fa95e9e --- /dev/null +++ b/src/nvd/cli/formatters.py @@ -0,0 +1,137 @@ +"""Output formatters for CLI.""" + +import json +import sys +from typing import Any, List + +import yaml +from rich.console import Console +from rich.table import Table + +from ..models import CPEData, CPEMatchString, CVEData, SourceData + +# Console for data output - explicitly use stdout +console = Console(file=sys.stdout, stderr=False) + + +def format_json(data: Any) -> None: + """Format output as JSON.""" + if hasattr(data, "model_dump"): + output = data.model_dump(mode="json", exclude_none=True) + else: + output = data + print(json.dumps(output, indent=2, default=str, ensure_ascii=False)) + + +def format_json_lines(data: List[Any]) -> None: + """Format multiple items as JSON lines (one per line).""" + for item in data: + if hasattr(item, "model_dump"): + output = item.model_dump(mode="json", exclude_none=True) + else: + output = item + print(json.dumps(output, default=str)) + + +def format_yaml(data: Any) -> None: + """Format output as YAML.""" + if hasattr(data, "model_dump"): + output = data.model_dump(mode="json", exclude_none=True) + else: + output = data + print(yaml.dump(output, default_flow_style=False, sort_keys=False, allow_unicode=True), end="") + + +def format_cve_table(cves: List[CVEData]) -> None: + """Format CVEs as a table.""" + table = Table(title="CVE Results", show_header=True, header_style="bold magenta") + table.add_column("CVE ID", style="cyan", no_wrap=True) + table.add_column("Published", style="green") + table.add_column("CVSS v3", justify="right", style="yellow") + table.add_column("Status", style="blue") + table.add_column("Description", style="white", max_width=60) + + for cve in cves: + score = cve.cvss_v3_score or cve.cvss_v2_score + score_str = f"{score:.1f}" if score else "N/A" + + description = cve.description[:100] + "..." if len(cve.description) > 100 else cve.description + + table.add_row( + cve.id, + cve.published.strftime("%Y-%m-%d"), + score_str, + cve.vulnStatus, + description, + ) + + console.print(table) + + +def format_cpe_table(cpes: List[CPEData]) -> None: + """Format CPEs as a table.""" + table = Table(title="CPE Results", show_header=True, header_style="bold magenta") + table.add_column("CPE Name", style="cyan", no_wrap=True, max_width=50) + table.add_column("Title", style="green", max_width=40) + table.add_column("Deprecated", style="yellow") + table.add_column("Last Modified", style="blue") + + for cpe in cpes: + table.add_row( + cpe.cpeName, + cpe.title, + "Yes" if cpe.deprecated else "No", + cpe.lastModified.strftime("%Y-%m-%d"), + ) + + console.print(table) + + +def format_match_criteria_table(matches: List[CPEMatchString]) -> None: + """Format CPE match criteria as a table.""" + table = Table(title="CPE Match Criteria", show_header=True, header_style="bold magenta") + table.add_column("Criteria", style="cyan", max_width=50) + table.add_column("Status", style="green") + table.add_column("Version Range", style="yellow", max_width=30) + + for match in matches: + version_range = "" + if match.versionStartIncluding: + version_range += f">={match.versionStartIncluding} " + if match.versionStartExcluding: + version_range += f">{match.versionStartExcluding} " + if match.versionEndIncluding: + version_range += f"<={match.versionEndIncluding} " + if match.versionEndExcluding: + version_range += f"<{match.versionEndExcluding} " + + table.add_row( + match.criteria, + match.status, + version_range.strip() or "All versions", + ) + + console.print(table) + + +def format_source_table(sources: List[SourceData]) -> None: + """Format sources as a table.""" + table = Table(title="Data Sources", show_header=True, header_style="bold magenta") + table.add_column("Name", style="cyan") + table.add_column("Contact Email", style="green") + table.add_column("Identifiers", style="yellow", max_width=40) + table.add_column("Created", style="blue") + + for source in sources: + identifiers = ", ".join(source.sourceIdentifiers[:3]) + if len(source.sourceIdentifiers) > 3: + identifiers += f" (+{len(source.sourceIdentifiers) - 3} more)" + + table.add_row( + source.name, + source.contactEmail, + identifiers, + source.created.strftime("%Y-%m-%d"), + ) + + console.print(table) diff --git a/src/nvd/cli/main.py b/src/nvd/cli/main.py new file mode 100644 index 0000000..fc90fe5 --- /dev/null +++ b/src/nvd/cli/main.py @@ -0,0 +1,72 @@ +"""Main CLI entry point.""" + +import sys + +import typer +from rich.console import Console +from rich.panel import Panel +from rich.text import Text + +from .commands import config, cpe, cve + +# Enable -h as shorthand for --help +CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]} + +app = typer.Typer( + name="nvdb", + help="NVD API CLI - Query the US National Vulnerability Database", + add_completion=False, + rich_markup_mode="rich", + context_settings=CONTEXT_SETTINGS, +) + +# Console for stderr (help messages, examples) +console = Console(stderr=True) + + +def show_examples() -> None: + """Show usage examples.""" + examples = Text() + examples.append("Quick Examples:\n\n", style="bold cyan") + examples.append(" # Get a specific CVE\n", style="dim") + examples.append(" nvdb cve get CVE-2021-44228\n\n", style="green") + examples.append(" # Search for critical vulnerabilities\n", style="dim") + examples.append(" nvdb cve search --severity CRITICAL --limit 10\n\n", style="green") + examples.append(" # Search CVEs in CISA KEV catalog\n", style="dim") + examples.append(" nvdb cve search --has-kev --limit 20\n\n", style="green") + examples.append(" # Search CPEs\n", style="dim") + examples.append(" nvdb cpe search --keyword 'windows 10'\n\n", style="green") + examples.append(" # Configure API key\n", style="dim") + examples.append(" nvdb config set-api-key YOUR_API_KEY\n\n", style="green") + examples.append("Get help for specific commands:\n", style="bold yellow") + examples.append(" nvdb cve --help\n", style="blue") + examples.append(" nvdb cpe --help\n", style="blue") + examples.append(" nvdb config --help\n", style="blue") + + console.print(Panel(examples, title="[bold]nvdb - NVD API CLI[/bold]", border_style="blue")) + + +@app.callback(invoke_without_command=True) +def main_callback(ctx: typer.Context) -> None: + """Main callback to show examples when no command is provided.""" + if ctx.invoked_subcommand is None and len(sys.argv) == 1: + show_examples() + raise typer.Exit() + + +# Register command groups +app.add_typer(cve.app, name="cve", help="CVE (vulnerability) commands") +app.add_typer(cpe.app, name="cpe", help="CPE (product) commands") +app.add_typer(config.app, name="config", help="Configuration commands") + + +@app.command() +def version() -> None: + """Show version information.""" + from .. import __version__ + + console.print(f"nvdb-py version: {__version__}") + + +if __name__ == "__main__": + app() diff --git a/src/nvd/client.py b/src/nvd/client.py new file mode 100644 index 0000000..0b8f0d6 --- /dev/null +++ b/src/nvd/client.py @@ -0,0 +1,213 @@ +"""Main NVD API client.""" + +import os +from typing import Any, Dict, Optional, Type, TypeVar + +import httpx +from pydantic import BaseModel + +from .exceptions import ( + AuthenticationError, + NetworkError, + NotFoundError, + RateLimitError, + ResponseError, + ServerError, + ValidationError, +) +from .rate_limiter import RateLimiter + +T = TypeVar("T", bound=BaseModel) + + +class NVDClient: + """Async client for the NVD API 2.0. + + Example: + async with NVDClient(api_key="your-key") as client: + cve = await client.cve.get_cve("CVE-2021-44228") + """ + + BASE_URL = "https://services.nvd.nist.gov/rest/json" + + def __init__( + self, + api_key: Optional[str] = None, + timeout: float = 30.0, + max_retries: int = 3, + ) -> None: + """Initialize the NVD API client. + + Args: + api_key: Optional API key for higher rate limits (50 req/30s vs 5 req/30s) + timeout: Request timeout in seconds + max_retries: Maximum number of retries for failed requests + """ + self.api_key = api_key or os.getenv("NVD_API_KEY") + self.timeout = timeout + self.max_retries = max_retries + + self._client: Optional[httpx.AsyncClient] = None + self._rate_limiter = RateLimiter() + self._rate_limiter.configure(has_api_key=bool(self.api_key)) + + # Import endpoints here to avoid circular imports + from .endpoints.cpe import CPEEndpoint + from .endpoints.cpematch import CPEMatchEndpoint + from .endpoints.cve import CVEEndpoint + from .endpoints.cve_history import CVEHistoryEndpoint + from .endpoints.source import SourceEndpoint + + self.cve = CVEEndpoint(self) + self.cpe = CPEEndpoint(self) + self.cpematch = CPEMatchEndpoint(self) + self.source = SourceEndpoint(self) + self.history = CVEHistoryEndpoint(self) + + async def __aenter__(self) -> "NVDClient": + """Async context manager entry.""" + headers = {"User-Agent": "nvdb-py/0.1.0"} + if self.api_key: + headers["apiKey"] = self.api_key + + self._client = httpx.AsyncClient( + base_url=self.BASE_URL, + headers=headers, + timeout=self.timeout, + http2=True, + ) + return self + + async def __aexit__(self, *args: Any) -> None: + """Async context manager exit.""" + if self._client: + await self._client.aclose() + + async def request( + self, + method: str, + endpoint: str, + params: Optional[Dict[str, Any]] = None, + response_model: Optional[Type[T]] = None, + ) -> T: + """Make an API request with rate limiting and error handling. + + Args: + method: HTTP method + endpoint: API endpoint path + params: Query parameters + response_model: Pydantic model for response validation + + Returns: + Parsed response object + + Raises: + Various NVDError subclasses based on error type + """ + if not self._client: + raise RuntimeError("Client not initialized. Use 'async with' context manager.") + + # Clean up None values from params + if params: + params = {k: v for k, v in params.items() if v is not None} + + retry_count = 0 + last_error: Optional[Exception] = None + + while retry_count < self.max_retries: + try: + # Wait for rate limiter + await self._rate_limiter.acquire() + + # Make request + response = await self._client.request( + method=method, + url=endpoint, + params=params, + ) + + # Handle response + if response.status_code == 200: + try: + data = response.json() + if response_model: + return response_model.model_validate(data) + return data # type: ignore + except Exception as e: + raise ResponseError(f"Failed to parse response: {e}", response) + + elif response.status_code == 400: + raise ValidationError( + f"Invalid request parameters: {response.text}", response + ) + + elif response.status_code == 403: + if "api key" in response.text.lower(): + raise AuthenticationError( + "Invalid or missing API key", response + ) + raise RateLimitError( + "Rate limit exceeded (403)", retry_after=30, response=response + ) + + elif response.status_code == 404: + # Provide helpful error messages for common mistakes + error_msg = f"Resource not found: {endpoint}" + if endpoint == "/cpes/2.0" and params: + if "cpeMatchString" in params: + error_msg = ( + f"Invalid CPE match string: {params['cpeMatchString']}. " + "CPE match strings must use the format 'cpe:2.3:...'. " + "For keyword searches, use keywordSearch parameter instead. " + "Example: cpe:2.3:a:vendor:product:* or use --keyword for text search." + ) + raise NotFoundError(error_msg, response) + + elif response.status_code == 429: + retry_after = int(response.headers.get("Retry-After", 30)) + raise RateLimitError( + "Rate limit exceeded (429)", retry_after=retry_after, response=response + ) + + elif response.status_code >= 500: + raise ServerError( + f"Server error ({response.status_code}): {response.text}", + response, + ) + + else: + raise ResponseError( + f"Unexpected status code {response.status_code}: {response.text}", + response, + ) + + except (RateLimitError, httpx.RequestError) as e: + last_error = e + retry_count += 1 + if retry_count < self.max_retries: + # Exponential backoff + import asyncio + await asyncio.sleep(2 ** retry_count) + continue + + except ( + AuthenticationError, + ValidationError, + NotFoundError, + ServerError, + ResponseError, + ): + raise + + except Exception as e: + raise NetworkError(f"Network error: {e}") + + # If we exhausted retries, raise the last error + if last_error: + raise last_error + raise NetworkError("Max retries exceeded") + + async def close(self) -> None: + """Close the HTTP client.""" + if self._client: + await self._client.aclose() diff --git a/src/nvd/endpoints/__init__.py b/src/nvd/endpoints/__init__.py new file mode 100644 index 0000000..551a8fe --- /dev/null +++ b/src/nvd/endpoints/__init__.py @@ -0,0 +1 @@ +"""NVD API endpoints.""" diff --git a/src/nvd/endpoints/cpe.py b/src/nvd/endpoints/cpe.py new file mode 100644 index 0000000..9edd23a --- /dev/null +++ b/src/nvd/endpoints/cpe.py @@ -0,0 +1,93 @@ +"""CPE API endpoint.""" + +from typing import TYPE_CHECKING, AsyncIterator, Optional + +from ..models import CPEData, CPEResponse + +if TYPE_CHECKING: + from ..client import NVDClient + + +class CPEEndpoint: + """CPE (Common Platform Enumeration) API endpoint.""" + + def __init__(self, client: "NVDClient") -> None: + self.client = client + + async def get_cpe(self, cpe_name_id: str) -> CPEData: + """Get a specific CPE by UUID. + + Args: + cpe_name_id: CPE Name UUID + + Returns: + CPE data object + """ + response = await self.client.request( + "GET", + "/cpes/2.0", + params={"cpeNameId": cpe_name_id}, + response_model=CPEResponse, + ) + if not response.products: + raise ValueError(f"CPE {cpe_name_id} not found") + return response.products[0].cpe + + async def search_cpes( + self, + cpe_name_id: Optional[str] = None, + cpe_match_string: Optional[str] = None, + keyword_search: Optional[str] = None, + keyword_exact_match: Optional[bool] = None, + last_mod_start_date: Optional[str] = None, + last_mod_end_date: Optional[str] = None, + match_criteria_id: Optional[str] = None, + results_per_page: int = 10000, + start_index: int = 0, + ) -> AsyncIterator[CPEData]: + """Search for CPEs. + + Args: + cpe_name_id: CPE Name UUID + cpe_match_string: CPE match string pattern + keyword_search: Keyword to search in titles and references + keyword_exact_match: Require exact keyword match + last_mod_start_date: Last modified start date (ISO-8601) + last_mod_end_date: Last modified end date (ISO-8601) + match_criteria_id: Match criteria UUID + results_per_page: Results per page (max 10000) + start_index: Starting index for pagination + + Yields: + CPE data objects + """ + params = { + "cpeNameId": cpe_name_id, + "cpeMatchString": cpe_match_string, + "keywordSearch": keyword_search, + "keywordExactMatch": keyword_exact_match, + "lastModStartDate": last_mod_start_date, + "lastModEndDate": last_mod_end_date, + "matchCriteriaId": match_criteria_id, + "resultsPerPage": results_per_page, + "startIndex": start_index, + } + + current_index = start_index + while True: + params["startIndex"] = current_index + response = await self.client.request( + "GET", + "/cpes/2.0", + params=params, + response_model=CPEResponse, + ) + + for item in response.products: + yield item.cpe + + # Check if there are more results + if current_index + response.resultsPerPage >= response.totalResults: + break + + current_index += response.resultsPerPage diff --git a/src/nvd/endpoints/cpematch.py b/src/nvd/endpoints/cpematch.py new file mode 100644 index 0000000..0575859 --- /dev/null +++ b/src/nvd/endpoints/cpematch.py @@ -0,0 +1,101 @@ +"""CPE Match Criteria API endpoint.""" + +from typing import TYPE_CHECKING, AsyncIterator, List, Optional + +from ..models import CPEMatchResponse, CPEMatchString + +if TYPE_CHECKING: + from ..client import NVDClient + + +class CPEMatchEndpoint: + """CPE Match Criteria API endpoint.""" + + def __init__(self, client: "NVDClient") -> None: + self.client = client + + async def get_match_criteria(self, match_criteria_id: str) -> CPEMatchString: + """Get specific match criteria by UUID. + + Args: + match_criteria_id: Match criteria UUID + + Returns: + Match criteria object + """ + response = await self.client.request( + "GET", + "/cpematch/2.0", + params={"matchCriteriaId": match_criteria_id}, + response_model=CPEMatchResponse, + ) + if not response.matchStrings: + raise ValueError(f"Match criteria {match_criteria_id} not found") + return response.matchStrings[0] + + async def get_cve_match_criteria(self, cve_id: str) -> List[CPEMatchString]: + """Get all match criteria for a specific CVE. + + Args: + cve_id: CVE identifier + + Returns: + List of match criteria objects + """ + results: List[CPEMatchString] = [] + async for match in self.search_match_criteria(cve_id=cve_id): + results.append(match) + return results + + async def search_match_criteria( + self, + cve_id: Optional[str] = None, + match_criteria_id: Optional[str] = None, + match_string_search: Optional[str] = None, + last_mod_start_date: Optional[str] = None, + last_mod_end_date: Optional[str] = None, + results_per_page: int = 500, + start_index: int = 0, + ) -> AsyncIterator[CPEMatchString]: + """Search for CPE match criteria. + + Args: + cve_id: CVE identifier to get match strings for + match_criteria_id: Specific match criteria UUID + match_string_search: Match string pattern to search + last_mod_start_date: Last modified start date (ISO-8601) + last_mod_end_date: Last modified end date (ISO-8601) + results_per_page: Results per page (max 500) + start_index: Starting index for pagination + + Yields: + Match criteria objects + """ + params = { + "cveId": cve_id, + "matchCriteriaId": match_criteria_id, + "matchStringSearch": match_string_search, + "lastModStartDate": last_mod_start_date, + "lastModEndDate": last_mod_end_date, + "resultsPerPage": results_per_page, + "startIndex": start_index, + } + + current_index = start_index + while True: + params["startIndex"] = current_index + response = await self.client.request( + "GET", + "/cpematch/2.0", + params=params, + response_model=CPEMatchResponse, + ) + + for match_string in response.matchStrings: + yield match_string + + # Check if there are more results + if current_index + response.resultsPerPage >= response.totalResults: + break + + current_index += response.resultsPerPage diff --git a/src/nvd/endpoints/cve.py b/src/nvd/endpoints/cve.py new file mode 100644 index 0000000..5029f59 --- /dev/null +++ b/src/nvd/endpoints/cve.py @@ -0,0 +1,205 @@ +"""CVE API endpoint.""" + +from datetime import datetime +from typing import TYPE_CHECKING, AsyncIterator, Optional + +from ..models import CVEData, CVEResponse + +if TYPE_CHECKING: + from ..client import NVDClient + + +class CVEEndpoint: + """CVE API endpoint with full parameter support.""" + + def __init__(self, client: "NVDClient") -> None: + self.client = client + + async def get_cve(self, cve_id: str) -> CVEData: + """Get a specific CVE by ID. + + Args: + cve_id: CVE identifier (e.g., "CVE-2021-44228") + + Returns: + CVE data object + """ + response = await self.client.request( + "GET", + "/cves/2.0", + params={"cveId": cve_id}, + response_model=CVEResponse, + ) + if not response.vulnerabilities: + raise ValueError(f"CVE {cve_id} not found") + return response.vulnerabilities[0].cve + + async def search_cves( + self, + # CVE identification + cve_id: Optional[str] = None, + # CPE filtering + cpe_name: Optional[str] = None, + virtual_match_string: Optional[str] = None, + # Date ranges (ISO-8601 format, max 120 days) + pub_start_date: Optional[str] = None, + pub_end_date: Optional[str] = None, + last_mod_start_date: Optional[str] = None, + last_mod_end_date: Optional[str] = None, + kev_start_date: Optional[str] = None, + kev_end_date: Optional[str] = None, + # CVSS v2 filtering + cvss_v2_severity: Optional[str] = None, # LOW, MEDIUM, HIGH + cvss_v2_metrics: Optional[str] = None, + # CVSS v3 filtering + cvss_v3_severity: Optional[str] = None, # LOW, MEDIUM, HIGH, CRITICAL + cvss_v3_metrics: Optional[str] = None, + # CVSS v4 filtering + cvss_v4_severity: Optional[str] = None, # LOW, MEDIUM, HIGH, CRITICAL + cvss_v4_metrics: Optional[str] = None, + # CWE filtering + cwe_id: Optional[str] = None, # e.g., "CWE-79" + # Boolean filters + has_cert_alerts: Optional[bool] = None, + has_cert_notes: Optional[bool] = None, + has_kev: Optional[bool] = None, + has_oval: Optional[bool] = None, + is_vulnerable: Optional[bool] = None, + no_rejected: Optional[bool] = None, + # Keyword search + keyword_search: Optional[str] = None, + keyword_exact_match: Optional[bool] = None, + # Source + source_identifier: Optional[str] = None, + # Version filtering (requires cpe_name) + version_start: Optional[str] = None, + version_start_type: Optional[str] = None, # "including" or "excluding" + version_end: Optional[str] = None, + version_end_type: Optional[str] = None, # "including" or "excluding" + # Pagination + results_per_page: int = 2000, + start_index: int = 0, + ) -> AsyncIterator[CVEData]: + """Search for CVEs with extensive filtering options. + + Args: + cve_id: Specific CVE identifier + cpe_name: CPE 2.3 name + virtual_match_string: Virtual CPE match string + pub_start_date: Publication start date (ISO-8601) + pub_end_date: Publication end date (ISO-8601) + last_mod_start_date: Last modified start date (ISO-8601) + last_mod_end_date: Last modified end date (ISO-8601) + kev_start_date: KEV catalog start date (ISO-8601) + kev_end_date: KEV catalog end date (ISO-8601) + cvss_v2_severity: CVSS v2 severity (LOW, MEDIUM, HIGH) + cvss_v2_metrics: CVSS v2 vector string + cvss_v3_severity: CVSS v3 severity (LOW, MEDIUM, HIGH, CRITICAL) + cvss_v3_metrics: CVSS v3 vector string + cvss_v4_severity: CVSS v4 severity (LOW, MEDIUM, HIGH, CRITICAL) + cvss_v4_metrics: CVSS v4 vector string + cwe_id: CWE identifier (e.g., "CWE-79") + has_cert_alerts: Filter for CERT alerts + has_cert_notes: Filter for CERT notes + has_kev: Filter for CISA KEV catalog entries + has_oval: Filter for OVAL records + is_vulnerable: Filter for vulnerable CPE configurations + no_rejected: Exclude rejected CVEs + keyword_search: Keyword to search in descriptions + keyword_exact_match: Require exact keyword match + source_identifier: Data source identifier + version_start: Start version for CPE filtering + version_start_type: "including" or "excluding" + version_end: End version for CPE filtering + version_end_type: "including" or "excluding" + results_per_page: Results per page (max 2000) + start_index: Starting index for pagination + + Yields: + CVE data objects + """ + params = { + "cveId": cve_id, + "cpeName": cpe_name, + "virtualMatchString": virtual_match_string, + "pubStartDate": pub_start_date, + "pubEndDate": pub_end_date, + "lastModStartDate": last_mod_start_date, + "lastModEndDate": last_mod_end_date, + "kevStartDate": kev_start_date, + "kevEndDate": kev_end_date, + "cvssV2Severity": cvss_v2_severity, + "cvssV2Metrics": cvss_v2_metrics, + "cvssV3Severity": cvss_v3_severity, + "cvssV3Metrics": cvss_v3_metrics, + "cvssV4Severity": cvss_v4_severity, + "cvssV4Metrics": cvss_v4_metrics, + "cweId": cwe_id, + "hasCertAlerts": has_cert_alerts, + "hasCertNotes": has_cert_notes, + "hasKev": has_kev, + "hasOval": has_oval, + "isVulnerable": is_vulnerable, + "noRejected": no_rejected, + "keywordSearch": keyword_search, + "keywordExactMatch": keyword_exact_match, + "sourceIdentifier": source_identifier, + "versionStart": version_start, + "versionStartType": version_start_type, + "versionEnd": version_end, + "versionEndType": version_end_type, + "resultsPerPage": results_per_page, + "startIndex": start_index, + } + + current_index = start_index + while True: + params["startIndex"] = current_index + response = await self.client.request( + "GET", + "/cves/2.0", + params=params, + response_model=CVEResponse, + ) + + for item in response.vulnerabilities: + yield item.cve + + # Check if there are more results + if current_index + response.resultsPerPage >= response.totalResults: + break + + current_index += response.resultsPerPage + + async def get_cves_by_cpe( + self, cpe_name: str, **kwargs: object + ) -> AsyncIterator[CVEData]: + """Get CVEs for a specific CPE. + + Args: + cpe_name: CPE 2.3 name + **kwargs: Additional search parameters + + Yields: + CVE data objects + """ + async for cve in self.search_cves(cpe_name=cpe_name, **kwargs): + yield cve + + async def get_cves_by_keyword( + self, keyword: str, exact_match: bool = False, **kwargs: object + ) -> AsyncIterator[CVEData]: + """Search CVEs by keyword. + + Args: + keyword: Keyword to search + exact_match: Require exact match + **kwargs: Additional search parameters + + Yields: + CVE data objects + """ + async for cve in self.search_cves( + keyword_search=keyword, keyword_exact_match=exact_match, **kwargs + ): + yield cve diff --git a/src/nvd/endpoints/cve_history.py b/src/nvd/endpoints/cve_history.py new file mode 100644 index 0000000..c20670d --- /dev/null +++ b/src/nvd/endpoints/cve_history.py @@ -0,0 +1,79 @@ +"""CVE Change History API endpoint.""" + +from typing import TYPE_CHECKING, AsyncIterator, List, Optional + +from ..models import CVEChange, CVEChangeResponse + +if TYPE_CHECKING: + from ..client import NVDClient + + +class CVEHistoryEndpoint: + """CVE Change History API endpoint.""" + + def __init__(self, client: "NVDClient") -> None: + self.client = client + + async def get_cve_history(self, cve_id: str) -> List[CVEChange]: + """Get complete change history for a CVE. + + Args: + cve_id: CVE identifier + + Returns: + List of change events + """ + changes: List[CVEChange] = [] + async for change in self.search_changes(cve_id=cve_id): + changes.append(change) + return changes + + async def search_changes( + self, + cve_id: Optional[str] = None, + change_start_date: Optional[str] = None, + change_end_date: Optional[str] = None, + event_name: Optional[str] = None, + results_per_page: int = 5000, + start_index: int = 0, + ) -> AsyncIterator[CVEChange]: + """Search CVE change history. + + Args: + cve_id: CVE identifier to get history for + change_start_date: Change start date (ISO-8601) + change_end_date: Change end date (ISO-8601) + event_name: Event type (e.g., "Initial Analysis", "CVE Modified") + results_per_page: Results per page (max 5000) + start_index: Starting index for pagination + + Yields: + Change event objects + """ + params = { + "cveId": cve_id, + "changeStartDate": change_start_date, + "changeEndDate": change_end_date, + "eventName": event_name, + "resultsPerPage": results_per_page, + "startIndex": start_index, + } + + current_index = start_index + while True: + params["startIndex"] = current_index + response = await self.client.request( + "GET", + "/cvehistory/2.0", + params=params, + response_model=CVEChangeResponse, + ) + + for change in response.cveChanges: + yield change + + # Check if there are more results + if current_index + response.resultsPerPage >= response.totalResults: + break + + current_index += response.resultsPerPage diff --git a/src/nvd/endpoints/source.py b/src/nvd/endpoints/source.py new file mode 100644 index 0000000..6235c96 --- /dev/null +++ b/src/nvd/endpoints/source.py @@ -0,0 +1,81 @@ +"""Source API endpoint.""" + +from typing import TYPE_CHECKING, AsyncIterator, Optional + +from ..models import SourceData, SourceResponse + +if TYPE_CHECKING: + from ..client import NVDClient + + +class SourceEndpoint: + """Source API endpoint for data source organizations.""" + + def __init__(self, client: "NVDClient") -> None: + self.client = client + + async def get_source(self, source_identifier: str) -> SourceData: + """Get a specific source by identifier. + + Args: + source_identifier: Source identifier + + Returns: + Source data object + """ + response = await self.client.request( + "GET", + "/source/2.0", + params={"sourceIdentifier": source_identifier}, + response_model=SourceResponse, + ) + if not response.sources: + raise ValueError(f"Source {source_identifier} not found") + return response.sources[0] + + async def list_sources( + self, + source_identifier: Optional[str] = None, + last_mod_start_date: Optional[str] = None, + last_mod_end_date: Optional[str] = None, + results_per_page: int = 1000, + start_index: int = 0, + ) -> AsyncIterator[SourceData]: + """List data sources. + + Args: + source_identifier: Filter by specific source identifier + last_mod_start_date: Last modified start date (ISO-8601) + last_mod_end_date: Last modified end date (ISO-8601) + results_per_page: Results per page (max 1000) + start_index: Starting index for pagination + + Yields: + Source data objects + """ + params = { + "sourceIdentifier": source_identifier, + "lastModStartDate": last_mod_start_date, + "lastModEndDate": last_mod_end_date, + "resultsPerPage": results_per_page, + "startIndex": start_index, + } + + current_index = start_index + while True: + params["startIndex"] = current_index + response = await self.client.request( + "GET", + "/source/2.0", + params=params, + response_model=SourceResponse, + ) + + for source in response.sources: + yield source + + # Check if there are more results + if current_index + response.resultsPerPage >= response.totalResults: + break + + current_index += response.resultsPerPage diff --git a/src/nvd/exceptions.py b/src/nvd/exceptions.py new file mode 100644 index 0000000..a4c774b --- /dev/null +++ b/src/nvd/exceptions.py @@ -0,0 +1,61 @@ +"""Custom exceptions for the NVD API client.""" + +from typing import Any, Optional + + +class NVDError(Exception): + """Base exception for all NVD API errors.""" + + def __init__(self, message: str, response: Optional[Any] = None) -> None: + super().__init__(message) + self.message = message + self.response = response + + +class RateLimitError(NVDError): + """Raised when rate limit is exceeded.""" + + def __init__( + self, + message: str = "Rate limit exceeded", + retry_after: Optional[int] = None, + response: Optional[Any] = None, + ) -> None: + super().__init__(message, response) + self.retry_after = retry_after + + +class AuthenticationError(NVDError): + """Raised when authentication fails (invalid API key).""" + + pass + + +class ValidationError(NVDError): + """Raised when request parameters are invalid.""" + + pass + + +class NotFoundError(NVDError): + """Raised when a resource is not found (404).""" + + pass + + +class ServerError(NVDError): + """Raised when the NVD API returns a server error (5xx).""" + + pass + + +class NetworkError(NVDError): + """Raised when a network error occurs.""" + + pass + + +class ResponseError(NVDError): + """Raised when response parsing fails.""" + + pass diff --git a/src/nvd/models.py b/src/nvd/models.py new file mode 100644 index 0000000..77fb7c4 --- /dev/null +++ b/src/nvd/models.py @@ -0,0 +1,315 @@ +"""Pydantic models for NVD API responses.""" + +from datetime import datetime +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +# Common models +class LangString(BaseModel): + """Multilingual string.""" + + lang: str + value: str + + +# CVSS Models +class CVSSMetricV2(BaseModel): + """CVSS v2.0 metrics.""" + + source: str + type: str + cvssData: Dict[str, Any] + baseSeverity: Optional[str] = None + exploitabilityScore: Optional[float] = None + impactScore: Optional[float] = None + acInsufInfo: Optional[bool] = None + obtainAllPrivilege: Optional[bool] = None + obtainUserPrivilege: Optional[bool] = None + obtainOtherPrivilege: Optional[bool] = None + userInteractionRequired: Optional[bool] = None + + +class CVSSMetricV3(BaseModel): + """CVSS v3.x metrics.""" + + source: str + type: str + cvssData: Dict[str, Any] + exploitabilityScore: Optional[float] = None + impactScore: Optional[float] = None + + +class CVSSMetricV4(BaseModel): + """CVSS v4.0 metrics.""" + + source: str + type: str + cvssData: Dict[str, Any] + + +class Metrics(BaseModel): + """CVSS metrics container.""" + + cvssMetricV2: Optional[List[CVSSMetricV2]] = None + cvssMetricV30: Optional[List[CVSSMetricV3]] = None + cvssMetricV31: Optional[List[CVSSMetricV3]] = None + cvssMetricV40: Optional[List[CVSSMetricV4]] = None + + +# CPE Models +class CPEMatch(BaseModel): + """CPE match criteria.""" + + vulnerable: bool + criteria: str + matchCriteriaId: str + versionStartIncluding: Optional[str] = None + versionStartExcluding: Optional[str] = None + versionEndIncluding: Optional[str] = None + versionEndExcluding: Optional[str] = None + + +class Node(BaseModel): + """Configuration node.""" + + operator: str + negate: bool + cpeMatch: List[CPEMatch] + + +class Configuration(BaseModel): + """CVE configuration.""" + + nodes: List[Node] + + +# Weakness Models +class WeaknessDescription(BaseModel): + """CWE description.""" + + lang: str + value: str + + +class Weakness(BaseModel): + """CWE weakness.""" + + source: str + type: str + description: List[WeaknessDescription] + + +# Reference Models +class Reference(BaseModel): + """External reference.""" + + url: str + source: Optional[str] = None + tags: Optional[List[str]] = None + + +# CVE Models +class CVEData(BaseModel): + """Core CVE data.""" + + id: str = Field(alias="cveId") + sourceIdentifier: str + published: datetime + lastModified: datetime + vulnStatus: str + cveTags: Optional[List[Dict[str, str]]] = None + descriptions: List[LangString] + metrics: Optional[Metrics] = None + weaknesses: Optional[List[Weakness]] = None + configurations: Optional[List[Configuration]] = None + references: List[Reference] + vendorComments: Optional[List[Dict[str, Any]]] = None + + class Config: + populate_by_name = True + + @property + def description(self) -> str: + """Get English description.""" + for desc in self.descriptions: + if desc.lang == "en": + return desc.value + return self.descriptions[0].value if self.descriptions else "" + + @property + def cvss_v3_score(self) -> Optional[float]: + """Get primary CVSS v3.x base score.""" + if self.metrics: + if self.metrics.cvssMetricV31: + return self.metrics.cvssMetricV31[0].cvssData.get("baseScore") + if self.metrics.cvssMetricV30: + return self.metrics.cvssMetricV30[0].cvssData.get("baseScore") + return None + + @property + def cvss_v2_score(self) -> Optional[float]: + """Get CVSS v2.0 base score.""" + if self.metrics and self.metrics.cvssMetricV2: + return self.metrics.cvssMetricV2[0].cvssData.get("baseScore") + return None + + +class CVEItem(BaseModel): + """CVE item wrapper.""" + + cve: CVEData + + +class CVEResponse(BaseModel): + """CVE API response.""" + + resultsPerPage: int + startIndex: int + totalResults: int + format: str + version: str + timestamp: datetime + vulnerabilities: List[CVEItem] + + +# CVE Change History Models +class ChangeDetail(BaseModel): + """Individual change detail.""" + + action: Optional[str] = None + type: Optional[str] = None + oldValue: Optional[str] = None + newValue: Optional[str] = None + + +class CVEChange(BaseModel): + """CVE change event.""" + + cveId: str + eventName: str + cveChangeId: str + sourceIdentifier: str + created: datetime + details: Optional[List[ChangeDetail]] = None + + +class CVEChangeResponse(BaseModel): + """CVE Change History API response.""" + + resultsPerPage: int + startIndex: int + totalResults: int + format: str + version: str + timestamp: datetime + cveChanges: List[CVEChange] + + +# CPE Models +class CPETitle(BaseModel): + """CPE title in different languages.""" + + title: str + lang: str + + +class CPERef(BaseModel): + """CPE reference.""" + + ref: str + type: Optional[str] = None + + +class CPEData(BaseModel): + """CPE product data.""" + + cpeName: str + cpeNameId: str + created: datetime + lastModified: datetime + deprecated: bool = False + titles: Optional[List[CPETitle]] = None + refs: Optional[List[CPERef]] = None + deprecatedBy: Optional[List[Dict[str, str]]] = None + + @property + def title(self) -> str: + """Get English title.""" + if self.titles: + for t in self.titles: + if t.lang == "en": + return t.title + return self.titles[0].title + return "" + + +class CPEItem(BaseModel): + """CPE item wrapper from API response.""" + + cpe: CPEData + + +class CPEResponse(BaseModel): + """CPE API response.""" + + resultsPerPage: int + startIndex: int + totalResults: int + format: str + version: str + timestamp: datetime + products: List[CPEItem] + + +# CPE Match Criteria Models +class CPEMatchString(BaseModel): + """CPE match string data.""" + + matchCriteriaId: str + criteria: str + lastModified: datetime + cpeLastModified: Optional[datetime] = None + created: datetime + status: str + matches: Optional[List[Dict[str, str]]] = None + versionStartIncluding: Optional[str] = None + versionStartExcluding: Optional[str] = None + versionEndIncluding: Optional[str] = None + versionEndExcluding: Optional[str] = None + + +class CPEMatchResponse(BaseModel): + """CPE Match Criteria API response.""" + + resultsPerPage: int + startIndex: int + totalResults: int + format: str + version: str + timestamp: datetime + matchStrings: List[CPEMatchString] + + +# Source Models +class SourceData(BaseModel): + """Data source organization.""" + + name: str + contactEmail: str + sourceIdentifiers: List[str] + lastModified: datetime + created: datetime + + +class SourceResponse(BaseModel): + """Source API response.""" + + resultsPerPage: int + startIndex: int + totalResults: int + format: str + version: str + timestamp: datetime + sources: List[SourceData] diff --git a/src/nvd/py.typed b/src/nvd/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/nvd/py.typed diff --git a/src/nvd/rate_limiter.py b/src/nvd/rate_limiter.py new file mode 100644 index 0000000..be0fddf --- /dev/null +++ b/src/nvd/rate_limiter.py @@ -0,0 +1,67 @@ +"""Async rate limiter for NVD API requests.""" + +import asyncio +import time +from collections import deque +from typing import Deque + + +class RateLimiter: + """Sliding window rate limiter for async operations. + + NVD API rate limits: + - Without API key: 5 requests per 30 seconds + - With API key: 50 requests per 30 seconds + """ + + def __init__(self, max_requests: int = 5, window_seconds: int = 30) -> None: + """Initialize rate limiter. + + Args: + max_requests: Maximum number of requests allowed in the time window + window_seconds: Time window in seconds + """ + self.max_requests = max_requests + self.window_seconds = window_seconds + self._requests: Deque[float] = deque() + self._lock = asyncio.Lock() + + async def acquire(self) -> None: + """Acquire permission to make a request, waiting if necessary.""" + async with self._lock: + now = time.time() + + # Remove requests outside the current window + while self._requests and self._requests[0] <= now - self.window_seconds: + self._requests.popleft() + + # If at limit, wait until oldest request expires + if len(self._requests) >= self.max_requests: + sleep_time = self._requests[0] + self.window_seconds - now + if sleep_time > 0: + await asyncio.sleep(sleep_time) + # Recursively try again after waiting + return await self.acquire() + + # Record this request + self._requests.append(now) + + def configure(self, has_api_key: bool) -> None: + """Configure rate limiter based on whether an API key is present. + + Args: + has_api_key: Whether an API key is being used + """ + if has_api_key: + self.max_requests = 50 + else: + self.max_requests = 5 + + async def __aenter__(self) -> "RateLimiter": + """Async context manager entry.""" + await self.acquire() + return self + + async def __aexit__(self, *args: object) -> None: + """Async context manager exit.""" + pass |
