"""Main NVD API client."""

import asyncio
import os
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Any, Dict, Optional, Type, TypeVar

import httpx
from pydantic import BaseModel

from .exceptions import (
    AuthenticationError,
    NetworkError,
    NotFoundError,
    RateLimitError,
    ResponseError,
    ServerError,
    ValidationError,
)
from .rate_limiter import RateLimiter

T = TypeVar("T", bound=BaseModel)

# Seconds to wait after a rate-limit response that carries no usable
# Retry-After header (also used for rate-limit 403s).
_DEFAULT_RETRY_AFTER = 30


class NVDClient:
    """Async client for the NVD API 2.0.

    Example:
        async with NVDClient(api_key="your-key") as client:
            cve = await client.cve.get_cve("CVE-2021-44228")
    """

    BASE_URL = "https://services.nvd.nist.gov/rest/json"

    def __init__(
        self,
        api_key: Optional[str] = None,
        timeout: float = 30.0,
        max_retries: int = 3,
    ) -> None:
        """Initialize the NVD API client.

        Args:
            api_key: Optional API key for higher rate limits (50 req/30s vs 5 req/30s).
                Falls back to the NVD_API_KEY environment variable when not given.
            timeout: Request timeout in seconds
            max_retries: Maximum number of attempts made for a request that fails
                with a retryable error (rate limiting or a transport error).
                At least one attempt is always made, even if this is below 1.
        """
        self.api_key = api_key or os.getenv("NVD_API_KEY")
        self.timeout = timeout
        self.max_retries = max_retries
        self._client: Optional[httpx.AsyncClient] = None
        self._rate_limiter = RateLimiter()
        self._rate_limiter.configure(has_api_key=bool(self.api_key))

        # Import endpoints here to avoid circular imports
        from .endpoints.cpe import CPEEndpoint
        from .endpoints.cpematch import CPEMatchEndpoint
        from .endpoints.cve import CVEEndpoint
        from .endpoints.cve_history import CVEHistoryEndpoint
        from .endpoints.source import SourceEndpoint

        self.cve = CVEEndpoint(self)
        self.cpe = CPEEndpoint(self)
        self.cpematch = CPEMatchEndpoint(self)
        self.source = SourceEndpoint(self)
        self.history = CVEHistoryEndpoint(self)

    async def __aenter__(self) -> "NVDClient":
        """Async context manager entry: open the underlying HTTP client."""
        headers = {"User-Agent": "nvdb-py/0.1.0"}
        if self.api_key:
            # NVD expects the key in an "apiKey" request header.
            headers["apiKey"] = self.api_key

        self._client = httpx.AsyncClient(
            base_url=self.BASE_URL,
            headers=headers,
            timeout=self.timeout,
            http2=True,
        )
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Async context manager exit: close the underlying HTTP client."""
        await self.close()

    async def request(
        self,
        method: str,
        endpoint: str,
        params: Optional[Dict[str, Any]] = None,
        response_model: Optional[Type[T]] = None,
    ) -> T:
        """Make an API request with rate limiting and error handling.

        Rate-limit errors (403 without an API-key message, 429) and httpx
        transport errors are retried up to ``max(1, self.max_retries)`` total
        attempts. Between attempts the client waits for the longer of an
        exponential backoff (2, 4, 8, ... seconds) and the server-requested
        ``retry_after``.

        Args:
            method: HTTP method
            endpoint: API endpoint path
            params: Query parameters; entries whose value is None are dropped
            response_model: Pydantic model for response validation

        Returns:
            Parsed response object (the raw decoded JSON when no model is given)

        Raises:
            RuntimeError: If the client is not open (used outside ``async with``
                or after ``close()``).
            AuthenticationError, ValidationError, NotFoundError, ServerError,
            ResponseError: Raised immediately, without retrying.
            RateLimitError, httpx.RequestError: The last retryable error once
                all attempts are exhausted.
            NetworkError: Any other unexpected exception, chained as its cause.
        """
        if self._client is None:
            raise RuntimeError("Client not initialized. Use 'async with' context manager.")

        # Clean up None values from params so optional filters are omitted.
        if params:
            params = {k: v for k, v in params.items() if v is not None}

        # Always make at least one attempt; max_retries <= 0 used to skip the
        # request entirely and fail with "Max retries exceeded".
        max_attempts = max(1, self.max_retries)
        last_error: Optional[Exception] = None

        for attempt in range(1, max_attempts + 1):
            try:
                await self._rate_limiter.acquire()
                response = await self._client.request(
                    method=method,
                    url=endpoint,
                    params=params,
                )
                return self._handle_response(response, endpoint, params, response_model)
            except (RateLimitError, httpx.RequestError) as e:
                last_error = e
                if attempt < max_attempts:
                    await asyncio.sleep(self._retry_delay(e, attempt))
                continue
            except (
                AuthenticationError,
                ValidationError,
                NotFoundError,
                ServerError,
                ResponseError,
            ):
                raise
            except Exception as e:
                raise NetworkError(f"Network error: {e}") from e

        # If we exhausted retries, raise the last error
        if last_error is not None:
            raise last_error
        raise NetworkError("Max retries exceeded")

    def _handle_response(
        self,
        response: httpx.Response,
        endpoint: str,
        params: Optional[Dict[str, Any]],
        response_model: Optional[Type[T]],
    ) -> Any:
        """Return the parsed body of a 200 response, or raise the matching NVDError."""
        status = response.status_code

        if status == 200:
            try:
                data = response.json()
                if response_model is not None:
                    return response_model.model_validate(data)
                return data
            except Exception as e:
                raise ResponseError(f"Failed to parse response: {e}", response) from e

        if status == 400:
            raise ValidationError(
                f"Invalid request parameters: {response.text}", response
            )

        if status == 403:
            # NVD answers 403 both for bad keys and for exceeding the rate
            # limit; only the response text tells them apart.
            if "api key" in response.text.lower():
                raise AuthenticationError(
                    "Invalid or missing API key", response
                )
            raise RateLimitError(
                "Rate limit exceeded (403)",
                retry_after=_DEFAULT_RETRY_AFTER,
                response=response,
            )

        if status == 404:
            raise NotFoundError(self._not_found_message(endpoint, params), response)

        if status == 429:
            retry_after = self._parse_retry_after(response.headers.get("Retry-After"))
            raise RateLimitError(
                "Rate limit exceeded (429)", retry_after=retry_after, response=response
            )

        if status >= 500:
            raise ServerError(
                f"Server error ({status}): {response.text}",
                response,
            )

        raise ResponseError(
            f"Unexpected status code {status}: {response.text}",
            response,
        )

    @staticmethod
    def _not_found_message(endpoint: str, params: Optional[Dict[str, Any]]) -> str:
        """Build a 404 message, with a hint for malformed CPE match strings."""
        if endpoint == "/cpes/2.0" and params and "cpeMatchString" in params:
            return (
                f"Invalid CPE match string: {params['cpeMatchString']}. "
                "CPE match strings must use the format 'cpe:2.3:...'. "
                "For keyword searches, use keywordSearch parameter instead. "
                "Example: cpe:2.3:a:vendor:product:* or use --keyword for text search."
            )
        return f"Resource not found: {endpoint}"

    @staticmethod
    def _parse_retry_after(
        value: Optional[str], default: int = _DEFAULT_RETRY_AFTER
    ) -> int:
        """Convert a Retry-After header value into a non-negative number of seconds.

        Accepts both forms allowed by HTTP: delay-seconds ("120") and an
        HTTP-date. Missing or unparseable values fall back to ``default``
        instead of surfacing a ValueError.
        """
        if value is None:
            return default
        value = value.strip()
        try:
            return max(0, int(value))
        except ValueError:
            pass
        try:
            retry_at = parsedate_to_datetime(value)
        except (TypeError, ValueError):
            # Older Pythons raise TypeError on garbage; newer raise ValueError.
            return default
        if retry_at.tzinfo is None:
            # A "-0000" zone yields a naive datetime; HTTP-dates are GMT.
            retry_at = retry_at.replace(tzinfo=timezone.utc)
        delay = (retry_at - datetime.now(timezone.utc)).total_seconds()
        return max(0, int(delay))

    @staticmethod
    def _retry_delay(error: Exception, attempt: int) -> float:
        """Seconds to sleep before the next attempt.

        Exponential backoff (2 ** attempt), raised to the error's
        ``retry_after`` when it carries a longer one (rate-limit errors do).
        Retrying sooner than the server asks would just hit the limit again.
        """
        delay = float(2 ** attempt)
        retry_after = getattr(error, "retry_after", None)
        if isinstance(retry_after, (int, float)) and retry_after > delay:
            delay = float(retry_after)
        return delay

    async def close(self) -> None:
        """Close the HTTP client; calling it again is a no-op.

        The client reference is cleared so a later ``request()`` fails with a
        clear "not initialized" error rather than an httpx closed-client error.
        """
        if self._client is not None:
            client, self._client = self._client, None
            await client.aclose()