"""Source API endpoint."""

from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Optional

from ..models import SourceData, SourceResponse

if TYPE_CHECKING:
    from ..client import NVDClient


class SourceEndpoint:
    """Source API endpoint for data source organizations."""

    def __init__(self, client: "NVDClient") -> None:
        """Bind the endpoint to the client used to issue requests.

        Args:
            client: Client whose ``request`` coroutine performs the HTTP call
        """
        self.client = client

    async def get_source(self, source_identifier: str) -> SourceData:
        """Get a specific source by identifier.

        Args:
            source_identifier: Source identifier

        Returns:
            The first source object returned for the identifier

        Raises:
            ValueError: If the response contains no sources
        """
        response = await self.client.request(
            "GET",
            "/source/2.0",
            params={"sourceIdentifier": source_identifier},
            response_model=SourceResponse,
        )

        if not response.sources:
            raise ValueError(f"Source {source_identifier} not found")

        return response.sources[0]

    async def list_sources(
        self,
        source_identifier: Optional[str] = None,
        last_mod_start_date: Optional[str] = None,
        last_mod_end_date: Optional[str] = None,
        results_per_page: int = 1000,
        start_index: int = 0,
    ) -> AsyncIterator[SourceData]:
        """List data sources, transparently following pagination.

        Filters left as ``None`` are omitted from the query string rather
        than being sent as empty or ``"None"`` values.

        Args:
            source_identifier: Filter by specific source identifier
            last_mod_start_date: Last modified start date (ISO-8601)
            last_mod_end_date: Last modified end date (ISO-8601)
            results_per_page: Results per page (max 1000)
            start_index: Starting index for pagination

        Yields:
            Source data objects
        """
        filters: Dict[str, Any] = {
            "sourceIdentifier": source_identifier,
            "lastModStartDate": last_mod_start_date,
            "lastModEndDate": last_mod_end_date,
            "resultsPerPage": results_per_page,
        }
        # Drop unset filters so they never reach the HTTP layer as None.
        base_params: Dict[str, Any] = {
            key: value for key, value in filters.items() if value is not None
        }

        current_index = start_index
        while True:
            # Build a fresh dict per request so the client never sees a
            # params object that is mutated after the call.
            response = await self.client.request(
                "GET",
                "/source/2.0",
                params={**base_params, "startIndex": current_index},
                response_model=SourceResponse,
            )

            for source in response.sources:
                yield source

            page_size = response.resultsPerPage
            # A non-positive page size would never advance the cursor, so
            # stop instead of re-requesting the same page forever.
            if page_size <= 0:
                break

            # Check if there are more results
            if current_index + page_size >= response.totalResults:
                break

            current_index += page_size