aboutsummaryrefslogtreecommitdiffstats
path: root/src/nvd/client.py
diff options
context:
space:
mode:
authorLouis Burda <dev@sinitax.com>2026-01-30 03:04:01 +0100
committerLouis Burda <dev@sinitax.com>2026-01-30 03:04:01 +0100
commitf6487c615cff023db1574e2c23db78bf02a43709 (patch)
tree8a0e793a8ea28b2a5eef5dcd509b6c6a2466ee1c /src/nvd/client.py
downloadnvdb-py-main.tar.gz
nvdb-py-main.zip
Add initial versionHEADmain
Diffstat (limited to 'src/nvd/client.py')
-rw-r--r--src/nvd/client.py213
1 files changed, 213 insertions, 0 deletions
diff --git a/src/nvd/client.py b/src/nvd/client.py
new file mode 100644
index 0000000..0b8f0d6
--- /dev/null
+++ b/src/nvd/client.py
@@ -0,0 +1,213 @@
+"""Main NVD API client."""
+
+import os
+from typing import Any, Dict, Optional, Type, TypeVar
+
+import httpx
+from pydantic import BaseModel
+
+from .exceptions import (
+ AuthenticationError,
+ NetworkError,
+ NotFoundError,
+ RateLimitError,
+ ResponseError,
+ ServerError,
+ ValidationError,
+)
+from .rate_limiter import RateLimiter
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class NVDClient:
+ """Async client for the NVD API 2.0.
+
+ Example:
+ async with NVDClient(api_key="your-key") as client:
+ cve = await client.cve.get_cve("CVE-2021-44228")
+ """
+
+ BASE_URL = "https://services.nvd.nist.gov/rest/json"
+
+ def __init__(
+ self,
+ api_key: Optional[str] = None,
+ timeout: float = 30.0,
+ max_retries: int = 3,
+ ) -> None:
+ """Initialize the NVD API client.
+
+ Args:
+ api_key: Optional API key for higher rate limits (50 req/30s vs 5 req/30s)
+ timeout: Request timeout in seconds
+ max_retries: Maximum number of retries for failed requests
+ """
+ self.api_key = api_key or os.getenv("NVD_API_KEY")
+ self.timeout = timeout
+ self.max_retries = max_retries
+
+ self._client: Optional[httpx.AsyncClient] = None
+ self._rate_limiter = RateLimiter()
+ self._rate_limiter.configure(has_api_key=bool(self.api_key))
+
+ # Import endpoints here to avoid circular imports
+ from .endpoints.cpe import CPEEndpoint
+ from .endpoints.cpematch import CPEMatchEndpoint
+ from .endpoints.cve import CVEEndpoint
+ from .endpoints.cve_history import CVEHistoryEndpoint
+ from .endpoints.source import SourceEndpoint
+
+ self.cve = CVEEndpoint(self)
+ self.cpe = CPEEndpoint(self)
+ self.cpematch = CPEMatchEndpoint(self)
+ self.source = SourceEndpoint(self)
+ self.history = CVEHistoryEndpoint(self)
+
+ async def __aenter__(self) -> "NVDClient":
+ """Async context manager entry."""
+ headers = {"User-Agent": "nvdb-py/0.1.0"}
+ if self.api_key:
+ headers["apiKey"] = self.api_key
+
+ self._client = httpx.AsyncClient(
+ base_url=self.BASE_URL,
+ headers=headers,
+ timeout=self.timeout,
+ http2=True,
+ )
+ return self
+
+ async def __aexit__(self, *args: Any) -> None:
+ """Async context manager exit."""
+ if self._client:
+ await self._client.aclose()
+
+ async def request(
+ self,
+ method: str,
+ endpoint: str,
+ params: Optional[Dict[str, Any]] = None,
+ response_model: Optional[Type[T]] = None,
+ ) -> T:
+ """Make an API request with rate limiting and error handling.
+
+ Args:
+ method: HTTP method
+ endpoint: API endpoint path
+ params: Query parameters
+ response_model: Pydantic model for response validation
+
+ Returns:
+ Parsed response object
+
+ Raises:
+ Various NVDError subclasses based on error type
+ """
+ if not self._client:
+ raise RuntimeError("Client not initialized. Use 'async with' context manager.")
+
+ # Clean up None values from params
+ if params:
+ params = {k: v for k, v in params.items() if v is not None}
+
+ retry_count = 0
+ last_error: Optional[Exception] = None
+
+ while retry_count < self.max_retries:
+ try:
+ # Wait for rate limiter
+ await self._rate_limiter.acquire()
+
+ # Make request
+ response = await self._client.request(
+ method=method,
+ url=endpoint,
+ params=params,
+ )
+
+ # Handle response
+ if response.status_code == 200:
+ try:
+ data = response.json()
+ if response_model:
+ return response_model.model_validate(data)
+ return data # type: ignore
+ except Exception as e:
+ raise ResponseError(f"Failed to parse response: {e}", response)
+
+ elif response.status_code == 400:
+ raise ValidationError(
+ f"Invalid request parameters: {response.text}", response
+ )
+
+ elif response.status_code == 403:
+ if "api key" in response.text.lower():
+ raise AuthenticationError(
+ "Invalid or missing API key", response
+ )
+ raise RateLimitError(
+ "Rate limit exceeded (403)", retry_after=30, response=response
+ )
+
+ elif response.status_code == 404:
+ # Provide helpful error messages for common mistakes
+ error_msg = f"Resource not found: {endpoint}"
+ if endpoint == "/cpes/2.0" and params:
+ if "cpeMatchString" in params:
+ error_msg = (
+ f"Invalid CPE match string: {params['cpeMatchString']}. "
+ "CPE match strings must use the format 'cpe:2.3:...'. "
+ "For keyword searches, use keywordSearch parameter instead. "
+ "Example: cpe:2.3:a:vendor:product:* or use --keyword for text search."
+ )
+ raise NotFoundError(error_msg, response)
+
+ elif response.status_code == 429:
+ retry_after = int(response.headers.get("Retry-After", 30))
+ raise RateLimitError(
+ "Rate limit exceeded (429)", retry_after=retry_after, response=response
+ )
+
+ elif response.status_code >= 500:
+ raise ServerError(
+ f"Server error ({response.status_code}): {response.text}",
+ response,
+ )
+
+ else:
+ raise ResponseError(
+ f"Unexpected status code {response.status_code}: {response.text}",
+ response,
+ )
+
+ except (RateLimitError, httpx.RequestError) as e:
+ last_error = e
+ retry_count += 1
+ if retry_count < self.max_retries:
+ # Exponential backoff
+ import asyncio
+ await asyncio.sleep(2 ** retry_count)
+ continue
+
+ except (
+ AuthenticationError,
+ ValidationError,
+ NotFoundError,
+ ServerError,
+ ResponseError,
+ ):
+ raise
+
+ except Exception as e:
+ raise NetworkError(f"Network error: {e}")
+
+ # If we exhausted retries, raise the last error
+ if last_error:
+ raise last_error
+ raise NetworkError("Max retries exceeded")
+
+ async def close(self) -> None:
+ """Close the HTTP client."""
+ if self._client:
+ await self._client.aclose()