diff options
| author | Louis Burda <dev@sinitax.com> | 2026-01-30 03:04:01 +0100 |
|---|---|---|
| committer | Louis Burda <dev@sinitax.com> | 2026-01-30 03:04:01 +0100 |
| commit | f6487c615cff023db1574e2c23db78bf02a43709 (patch) | |
| tree | 8a0e793a8ea28b2a5eef5dcd509b6c6a2466ee1c /src/nvd/client.py | |
| download | nvdb-py-main.tar.gz nvdb-py-main.zip | |
Diffstat (limited to 'src/nvd/client.py')
| -rw-r--r-- | src/nvd/client.py | 213 |
1 files changed, 213 insertions, 0 deletions
diff --git a/src/nvd/client.py b/src/nvd/client.py new file mode 100644 index 0000000..0b8f0d6 --- /dev/null +++ b/src/nvd/client.py @@ -0,0 +1,213 @@ +"""Main NVD API client.""" + +import os +from typing import Any, Dict, Optional, Type, TypeVar + +import httpx +from pydantic import BaseModel + +from .exceptions import ( + AuthenticationError, + NetworkError, + NotFoundError, + RateLimitError, + ResponseError, + ServerError, + ValidationError, +) +from .rate_limiter import RateLimiter + +T = TypeVar("T", bound=BaseModel) + + +class NVDClient: + """Async client for the NVD API 2.0. + + Example: + async with NVDClient(api_key="your-key") as client: + cve = await client.cve.get_cve("CVE-2021-44228") + """ + + BASE_URL = "https://services.nvd.nist.gov/rest/json" + + def __init__( + self, + api_key: Optional[str] = None, + timeout: float = 30.0, + max_retries: int = 3, + ) -> None: + """Initialize the NVD API client. + + Args: + api_key: Optional API key for higher rate limits (50 req/30s vs 5 req/30s) + timeout: Request timeout in seconds + max_retries: Maximum number of retries for failed requests + """ + self.api_key = api_key or os.getenv("NVD_API_KEY") + self.timeout = timeout + self.max_retries = max_retries + + self._client: Optional[httpx.AsyncClient] = None + self._rate_limiter = RateLimiter() + self._rate_limiter.configure(has_api_key=bool(self.api_key)) + + # Import endpoints here to avoid circular imports + from .endpoints.cpe import CPEEndpoint + from .endpoints.cpematch import CPEMatchEndpoint + from .endpoints.cve import CVEEndpoint + from .endpoints.cve_history import CVEHistoryEndpoint + from .endpoints.source import SourceEndpoint + + self.cve = CVEEndpoint(self) + self.cpe = CPEEndpoint(self) + self.cpematch = CPEMatchEndpoint(self) + self.source = SourceEndpoint(self) + self.history = CVEHistoryEndpoint(self) + + async def __aenter__(self) -> "NVDClient": + """Async context manager entry.""" + headers = {"User-Agent": "nvdb-py/0.1.0"} + if self.api_key: + headers["apiKey"] = self.api_key + + self._client = httpx.AsyncClient( + base_url=self.BASE_URL, + headers=headers, + timeout=self.timeout, + http2=True, + ) + return self + + async def __aexit__(self, *args: Any) -> None: + """Async context manager exit.""" + if self._client: + await self._client.aclose() + + async def request( + self, + method: str, + endpoint: str, + params: Optional[Dict[str, Any]] = None, + response_model: Optional[Type[T]] = None, + ) -> T: + """Make an API request with rate limiting and error handling. + + Args: + method: HTTP method + endpoint: API endpoint path + params: Query parameters + response_model: Pydantic model for response validation + + Returns: + Parsed response object + + Raises: + Various NVDError subclasses based on error type + """ + if not self._client: + raise RuntimeError("Client not initialized. Use 'async with' context manager.") + + # Clean up None values from params + if params: + params = {k: v for k, v in params.items() if v is not None} + + retry_count = 0 + last_error: Optional[Exception] = None + + while retry_count < self.max_retries: + try: + # Wait for rate limiter + await self._rate_limiter.acquire() + + # Make request + response = await self._client.request( + method=method, + url=endpoint, + params=params, + ) + + # Handle response + if response.status_code == 200: + try: + data = response.json() + if response_model: + return response_model.model_validate(data) + return data # type: ignore + except Exception as e: + raise ResponseError(f"Failed to parse response: {e}", response) + + elif response.status_code == 400: + raise ValidationError( + f"Invalid request parameters: {response.text}", response + ) + + elif response.status_code == 403: + if "api key" in response.text.lower(): + raise AuthenticationError( + "Invalid or missing API key", response + ) + raise RateLimitError( + "Rate limit exceeded (403)", retry_after=30, response=response + ) + + elif response.status_code == 404: + # Provide helpful error messages for common mistakes + error_msg = f"Resource not found: {endpoint}" + if endpoint == "/cpes/2.0" and params: + if "cpeMatchString" in params: + error_msg = ( + f"Invalid CPE match string: {params['cpeMatchString']}. " + "CPE match strings must use the format 'cpe:2.3:...'. " + "For keyword searches, use keywordSearch parameter instead. " + "Example: cpe:2.3:a:vendor:product:* or use --keyword for text search." + ) + raise NotFoundError(error_msg, response) + + elif response.status_code == 429: + retry_after = int(response.headers.get("Retry-After", 30)) + raise RateLimitError( + "Rate limit exceeded (429)", retry_after=retry_after, response=response + ) + + elif response.status_code >= 500: + raise ServerError( + f"Server error ({response.status_code}): {response.text}", + response, + ) + + else: + raise ResponseError( + f"Unexpected status code {response.status_code}: {response.text}", + response, + ) + + except (RateLimitError, httpx.RequestError) as e: + last_error = e + retry_count += 1 + if retry_count < self.max_retries: + # Exponential backoff + import asyncio + await asyncio.sleep(2 ** retry_count) + continue + + except ( + AuthenticationError, + ValidationError, + NotFoundError, + ServerError, + ResponseError, + ): + raise + + except Exception as e: + raise NetworkError(f"Network error: {e}") + + # If we exhausted retries, raise the last error + if last_error: + raise last_error + raise NetworkError("Max retries exceeded") + + async def close(self) -> None: + """Close the HTTP client.""" + if self._client: + await self._client.aclose() |
