# path: root/src/cvedb/client.py
# blob: 0babf2bab9b10b0c985a629f165dbef24c253e39
from urllib.parse import quote

import httpx

from cvedb.models import CVE, CVEWithCPEs

# Root URL of the CVEDB service; passed as ``base_url`` to the httpx.Client
# that CVEDBClient creates, so request paths below are relative to it.
BASE_URL = "https://cvedb.shodan.io"


class CVEDBClient:
    """Synchronous client for the CVEDB HTTP API rooted at ``BASE_URL``.

    Wraps an ``httpx.Client``. Use the instance as a context manager, or call
    :meth:`close` explicitly, so the underlying HTTP client is closed.
    """

    def __init__(self, timeout: float = 30.0):
        """Create the underlying HTTP client.

        Args:
            timeout: Timeout handed to ``httpx.Client``.
        """
        self._client = httpx.Client(base_url=BASE_URL, timeout=timeout)

    def __enter__(self):
        """Return this client so it can be bound by ``with ... as``."""
        return self

    def __exit__(self, *args):
        """Close the HTTP client; exceptions from the ``with`` body propagate."""
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()

    def _get_json(self, path: str, params: dict | None = None):
        """GET ``path`` and return the decoded JSON body.

        Raises whatever ``raise_for_status()`` raises for an error response
        (``httpx.HTTPStatusError``).
        """
        resp = self._client.get(path, params=params)
        resp.raise_for_status()
        return resp.json()

    def get_cve(self, cve_id: str) -> CVEWithCPEs:
        """Fetch a single CVE record from ``/cve/{cve_id}``.

        Args:
            cve_id: Identifier placed in the request path.

        Returns:
            The response body validated as a ``CVEWithCPEs``.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        # Percent-encode the id so characters such as "/", "?" or "#" cannot
        # change the requested path or append a query string. Identifiers
        # made of letters, digits and "-" pass through unchanged.
        data = self._get_json(f"/cve/{quote(cve_id, safe='')}")
        return CVEWithCPEs.model_validate(data)

    def get_cpes(
        self,
        product: str,
        skip: int = 0,
        limit: int | None = 1000,
        count: bool = False,
    ) -> list[str] | int:
        """Query ``/cpes`` for a product.

        Args:
            product: Sent as the ``product`` query parameter.
            skip: Sent as the ``skip`` query parameter.
            limit: Sent as the ``limit`` query parameter; omitted from the
                request when ``None``.
            count: Sent as the ``count`` query parameter and also selects
                which field of the response is returned.

        Returns:
            The response's ``total`` field (0 if absent) when ``count`` is
            true, otherwise its ``cpes`` field (empty list if absent).

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        params: dict = {"product": product, "skip": skip, "count": count}
        if limit is not None:
            params["limit"] = limit
        data = self._get_json("/cpes", params)
        if count:
            return data.get("total", 0)
        return data.get("cpes", [])

    def get_cves(
        self,
        cpe23: str | None = None,
        product: str | None = None,
        skip: int = 0,
        limit: int | None = 1000,
        count: bool = False,
        is_kev: bool = False,
        sort_by_epss: bool = False,
        start_date: str | None = None,
        end_date: str | None = None,
    ) -> list[CVE] | int:
        """Query ``/cves`` with the given filters.

        Args:
            cpe23: Sent as ``cpe23`` when non-empty.
            product: Sent as ``product`` when non-empty.
            skip: Sent as the ``skip`` query parameter.
            limit: Sent as the ``limit`` query parameter; omitted from the
                request when ``None``.
            count: Sent as the ``count`` query parameter and also selects
                which field of the response is returned.
            is_kev: When true, ``is_kev=True`` is sent; otherwise the
                parameter is omitted.
            sort_by_epss: When true, ``sort_by_epss=True`` is sent; otherwise
                the parameter is omitted.
            start_date: Sent as ``start_date`` when non-empty. The expected
                date format is defined by the API, not checked here.
            end_date: Sent as ``end_date`` when non-empty (same caveat).

        Returns:
            The response's ``total`` field (0 if absent) when ``count`` is
            true, otherwise each entry of its ``cves`` field validated as a
            ``CVE`` (empty list if the field is absent).

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        params: dict = {"skip": skip, "count": count}
        if limit is not None:
            params["limit"] = limit
        # Optional filters are only sent when set; flags are sent as True
        # and never as an explicit False.
        optional = {
            "cpe23": cpe23,
            "product": product,
            "is_kev": True if is_kev else None,
            "sort_by_epss": True if sort_by_epss else None,
            "start_date": start_date,
            "end_date": end_date,
        }
        params.update({key: value for key, value in optional.items() if value})
        data = self._get_json("/cves", params)
        if count:
            return data.get("total", 0)
        return [CVE.model_validate(c) for c in data.get("cves", [])]