#!/var/ossec/framework/python/bin/python3
#
# =============================================================================
# custom-epss.py — Wazuh Integration: EPSS Score Enrichment via FIRST API
# =============================================================================
# Placement : /var/ossec/integrations/custom-epss.py
# Permissions: chmod 750 /var/ossec/integrations/custom-epss.py
#              chown root:wazuh /var/ossec/integrations/custom-epss.py
#
# ossec.conf snippet:
#   <integration>
#     <name>custom-epss</name>
#     <group>vulnerability-detector</group>
#     <alert_format>json</alert_format>
#   </integration>
#
# EPSS API reference: https://www.first.org/epss/api
# Cache file       : /var/ossec/tmp/epss_cache.json  (TTL = 24 h)
# =============================================================================

import sys
import os
import json
import logging
import time
import ssl
import certifi
import urllib.request
import urllib.error
from datetime import datetime, timezone

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# FIRST EPSS endpoint; fetch_epss() adds the ?cve=... query string.
EPSS_API_URL   = "https://api.first.org/data/v1/epss"
# JSON cache shaped {cve_id: {"cached_at": <time.time() seconds>, "data": {...}}}.
CACHE_FILE     = "/var/ossec/tmp/epss_cache.json"
CACHE_TTL      = 86400          # 24 hours in seconds
LOG_FILE       = "/var/ossec/logs/custom-epss.log"
# Unix datagram socket that send_to_wazuh() writes enriched events to.
SOCKET_PATH    = "/var/ossec/queue/sockets/queue"
# Numeric <queue_id> prefix of the "<queue_id>:<location>:<message>" format
# written to SOCKET_PATH.
# NOTE(review): this was previously described as the "AR queue"; stock Wazuh
# integrations use "1" for ordinary event injection — confirm the intent.
WAZUH_QUEUE_ID = 1
REQUEST_TIMEOUT = 10            # seconds

# ---------------------------------------------------------------------------
# Logging setup
# ---------------------------------------------------------------------------
# Runs at import time: make sure the log directory exists so basicConfig's
# file handler can open LOG_FILE.
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)

logging.basicConfig(
    filename=LOG_FILE,
    level=logging.INFO,  # logger.debug() calls in this module are suppressed at this level
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%S%z",
)
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Cache helpers
# ---------------------------------------------------------------------------

def _load_cache() -> dict:
    """
    Return the on-disk EPSS cache as a dict.

    Any failure returns an empty dict, so callers fall back to the live API.
    Failures include:
      - a missing or unreadable file;
      - invalid JSON or non-UTF-8 bytes;
      - a top-level JSON value that is not an object.
    Every failure except "file does not exist" is logged as a warning.
    """
    try:
        with open(CACHE_FILE, "r", encoding="utf-8") as fh:
            cache = json.load(fh)
    except FileNotFoundError:
        # First run, or the cache was removed: not an error.
        return {}
    except (ValueError, OSError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("Cache read error (%s) — starting fresh.", exc)
        return {}

    if not isinstance(cache, dict):
        # Callers use .get()/item assignment; a list or scalar would crash them.
        logger.warning(
            "Cache read error (top-level %s, expected object) — starting fresh.",
            type(cache).__name__,
        )
        return {}
    return cache


def _save_cache(cache: dict) -> None:
    """
    Persist *cache* to CACHE_FILE as JSON, creating the parent directory if needed.

    Caching is best-effort. Any filesystem error is logged and swallowed so a
    cache problem does not abort the caller. This includes failing to create
    the directory.
    """
    try:
        # makedirs lives inside the try: a PermissionError here must not
        # propagate any more than a failed write would.
        os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
        with open(CACHE_FILE, "w", encoding="utf-8") as fh:
            json.dump(cache, fh)
    except OSError as exc:
        logger.error("Cache write error: %s", exc)


def get_cached_epss(cve_id: str) -> dict | None:
    """
    Return cached EPSS data for *cve_id* if it exists and is still fresh.

    An entry is fresh when at most CACHE_TTL seconds have passed since its
    ``cached_at`` timestamp. Returns None when the entry is absent or stale.
    It also returns None, with a warning logged, when the entry is malformed:
    not an object, a non-numeric ``cached_at``, or a ``data`` that is not an
    object. In every such case the caller falls back to the live API instead
    of crashing.
    """
    cache = _load_cache()
    entry = cache.get(cve_id) if isinstance(cache, dict) else None
    if not entry:
        return None

    if not isinstance(entry, dict):
        logger.warning("Ignoring malformed cache entry for %s.", cve_id)
        return None
    cached_at = entry.get("cached_at", 0)
    data = entry.get("data")
    if not isinstance(cached_at, (int, float)) or not isinstance(data, dict):
        logger.warning("Ignoring malformed cache entry for %s.", cve_id)
        return None

    age = time.time() - cached_at
    if age > CACHE_TTL:
        logger.debug("Cache miss (stale, age=%.0fs) for %s.", age, cve_id)
        return None
    logger.info("Cache hit for %s (age=%.0fs).", cve_id, age)
    return data


def store_epss_in_cache(cve_id: str, data: dict) -> None:
    """
    Write *data* for *cve_id* into the cache with a fresh timestamp.

    Every write also drops entries that can never be served as a cache hit.
    These are entries older than CACHE_TTL, lacking a numeric ``cached_at``,
    or not an object at all. The cache file is re-read on every lookup and
    rewritten on every store, so pruning keeps it from growing without bound.
    """
    now = time.time()
    cache = _load_cache()
    if not isinstance(cache, dict):
        cache = {}

    pruned = {
        key: entry
        for key, entry in cache.items()
        if isinstance(entry, dict)
        and isinstance(entry.get("cached_at"), (int, float))
        and now - entry["cached_at"] <= CACHE_TTL
    }
    pruned[cve_id] = {"cached_at": now, "data": data}
    _save_cache(pruned)
    logger.debug("Cached EPSS data for %s.", cve_id)


# ---------------------------------------------------------------------------
# FIRST EPSS API
# ---------------------------------------------------------------------------

def _build_ssl_context() -> ssl.SSLContext:
    """
    Return a certificate-verifying TLS context for the FIRST API.

    Starts from ``ssl.create_default_context()`` (platform defaults) and adds
    every CA bundle below that loads: certifi's bundle plus the Debian/Ubuntu
    and RHEL-family system bundles. Loading the system bundles keeps locally
    installed CAs working. Verification is never disabled. If no extra bundle
    loads, the platform defaults alone are used, and a failed handshake
    surfaces in the caller as a URLError.
    """
    ctx = ssl.create_default_context()
    loaded = 0
    for cafile in (
        certifi.where(),
        "/etc/ssl/certs/ca-certificates.crt",  # Debian / Ubuntu
        "/etc/pki/tls/certs/ca-bundle.crt",    # RHEL / CentOS / Fedora
    ):
        try:
            ctx.load_verify_locations(cafile=cafile)
            loaded += 1
        except OSError as exc:  # missing file, or ssl.SSLError (an OSError subclass)
            logger.debug("Skipping CA bundle %s: %s", cafile, exc)
    if not loaded:
        logger.warning("No CA bundle loaded — relying on platform default certificates.")
    return ctx


def _parse_epss_payload(raw: object, cve_id: str) -> dict | None:
    """
    Extract the first EPSS record from a decoded API response.

    The result has the same shape ``fetch_epss`` documents. Returns None, with
    a warning logged, when:
      - the response has no usable ``data`` list;
      - the first record is not an object;
      - its ``epss`` or ``percentile`` value cannot be parsed by float().
    """
    scores = raw.get("data") if isinstance(raw, dict) else None
    if not scores or not isinstance(scores, list):
        logger.warning("No EPSS data returned for %s.", cve_id)
        return None

    entry = scores[0]
    if not isinstance(entry, dict):
        logger.warning("Malformed EPSS record for %s: %r", cve_id, entry)
        return None

    record = {
        "cve":        entry.get("cve", cve_id),
        "epss":       entry.get("epss", "0"),
        "percentile": entry.get("percentile", "0"),
        "date":       entry.get("date", ""),
    }
    # build_enriched_event() calls float() on these. Reject a null or garbage
    # value here so it is neither cached nor able to crash the run later.
    try:
        float(record["epss"])
        float(record["percentile"])
    except (TypeError, ValueError):
        logger.warning("Non-numeric EPSS values for %s: %r", cve_id, entry)
        return None
    return record


def fetch_epss(cve_id: str) -> dict | None:
    """
    Query the FIRST EPSS API for *cve_id*.

    Returns a dict like:
        {
            "cve":        "CVE-2021-44228",
            "epss":       "0.97565",
            "percentile": "0.99976",
            "date":       "2025-04-09"
        }
    or None on failure. Failure conditions are logged, not raised:
      - a network, HTTP or TLS error;
      - an undecodable response body;
      - no EPSS record in the response;
      - a record whose ``epss``/``percentile`` is not numeric.
    """
    import http.client
    import urllib.parse

    # cve_id originates from alert content; urlencode escapes it so it cannot
    # smuggle extra query parameters into the request.
    url = f"{EPSS_API_URL}?{urllib.parse.urlencode({'cve': cve_id})}"
    logger.info("Fetching EPSS for %s from %s", cve_id, url)
    ctx = _build_ssl_context()
    try:
        req = urllib.request.Request(
            url,
            headers={"Accept": "application/json", "User-Agent": "wazuh-epss-integration/1.0"},
        )
        with urllib.request.urlopen(req, timeout=REQUEST_TIMEOUT, context=ctx) as resp:
            raw = json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as exc:
        logger.error("HTTP %s fetching EPSS for %s: %s", exc.code, cve_id, exc.reason)
        return None
    except urllib.error.URLError as exc:
        logger.error("URL error fetching EPSS for %s: %s", cve_id, exc.reason)
        return None
    except (ValueError, OSError, http.client.HTTPException) as exc:
        # ValueError: invalid JSON or non-UTF-8 body; OSError: timeouts/resets
        # while reading; HTTPException: e.g. a truncated (IncompleteRead) body.
        logger.error("Unexpected error for %s: %s", cve_id, exc)
        return None

    return _parse_epss_payload(raw, cve_id)


def get_epss(cve_id: str) -> dict | None:
    """
    Look up EPSS data for *cve_id*, consulting the on-disk cache before the API.

    A fresh (truthy) cache entry is returned as-is. Otherwise the FIRST API is
    queried. A truthy API result is saved to the cache and then returned; a
    falsy result is returned unchanged (None on failure).
    """
    if cached := get_cached_epss(cve_id):
        return cached

    live = fetch_epss(cve_id)
    if not live:
        return live
    store_epss_in_cache(cve_id, live)
    return live


# ---------------------------------------------------------------------------
# EPSS risk label helper
# ---------------------------------------------------------------------------

def epss_risk_label(score: float) -> str:
    """
    Map an EPSS probability score to a coarse risk bucket.

    Buckets are checked from most to least severe; lower bounds are inclusive:
      score >= 0.90 -> "critical"
      score >= 0.50 -> "high"
      score >= 0.10 -> "medium"
      otherwise     -> "low"

    NaN also yields "low", because every comparison with NaN is False.
    The cut-offs are opinionated thresholds, not part of the EPSS model itself.
    """
    bands = (
        (0.90, "critical"),
        (0.50, "high"),
        (0.10, "medium"),
    )
    return next((label for floor, label in bands if score >= floor), "low")


# ---------------------------------------------------------------------------
# Wazuh socket writer
# ---------------------------------------------------------------------------

def send_to_wazuh(event: dict) -> None:
    """
    Inject *event* as a JSON message into the Wazuh manager queue.

    One datagram is sent to SOCKET_PATH in the format
    ``<WAZUH_QUEUE_ID>:epss_integration:<json_payload>``. With the current
    configuration that is ``1:epss_integration:...``. A custom rule on the
    manager is needed to alert on these events.
    NOTE(review): the earlier docstring mentioned rule ID "100 200" — confirm
    the ID against your local ruleset.

    Socket errors (any OSError, e.g. no listener on the socket) are logged
    and swallowed rather than raised.
    """
    import socket

    # Build the prefix from the configured constant so the constant and the
    # wire format cannot drift apart.
    payload = f"{WAZUH_QUEUE_ID}:epss_integration:{json.dumps(event)}".encode("utf-8")

    try:
        with socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) as sock:
            sock.sendto(payload, SOCKET_PATH)
    except OSError as exc:
        logger.error("Failed to write to Wazuh queue: %s", exc)
        return
    logger.info("Event sent to Wazuh queue for %s.", event.get("epss", {}).get("cve"))


# ---------------------------------------------------------------------------
# Alert parsing
# ---------------------------------------------------------------------------

def extract_cve(alert: dict) -> str | None:
    """
    Extract a CVE identifier from the incoming Wazuh alert.

    Wazuh vulnerability-detector alerts typically expose the CVE in:
      alert["data"]["vulnerability"]["cve"]
    A fallback path through the rule description is also attempted.
    """
    try:
        cve = alert["data"]["vulnerability"]["cve"]
        if cve and cve.upper().startswith("CVE-"):
            return cve.strip()
    except (KeyError, TypeError):
        pass

    # Fallback: scan the rule description for a CVE pattern
    import re
    description = alert.get("rule", {}).get("description", "")
    match = re.search(r"CVE-\d{4}-\d{4,7}", description, re.IGNORECASE)
    if match:
        return match.group(0).upper()

    return None


# ---------------------------------------------------------------------------
# Main integration logic
# ---------------------------------------------------------------------------

def build_enriched_event(alert: dict, epss_data: dict) -> dict:
    """
    Assemble the event that is re-injected into the Wazuh queue.

    The event has two parts:
      - ``original_alert``: the alert's ``id``, ``rule``, ``agent`` and
        ``data``. Missing ``rule``/``agent``/``data`` become ``{}``.
      - ``epss``: the numeric score and percentile, both also scaled to
        percent, plus a risk label and the model date.
    ``epss.cached`` is always False here; main() overwrites it with the real
    cache status.

    Raises ValueError/TypeError if the ``epss`` or ``percentile`` value in
    *epss_data* cannot be converted with float().
    """
    probability = float(epss_data.get("epss", 0))
    rank = float(epss_data.get("percentile", 0))
    generated_at = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    context = dict(
        id=alert.get("id"),
        rule=alert.get("rule", {}),
        agent=alert.get("agent", {}),
        data=alert.get("data", {}),
    )

    enrichment = {
        "cve": epss_data.get("cve"),
        "score": probability,
        "score_pct": round(probability * 100, 4),
        "percentile": rank,
        "percentile_pct": round(rank * 100, 2),
        "risk_label": epss_risk_label(probability),
        "model_date": epss_data.get("date"),
        "source": "FIRST EPSS API",
        "cached": False,
    }

    # Insertion order (integration, timestamp, original_alert, epss) is what
    # json.dumps serialises, so keep it stable.
    event = {"integration": "epss", "timestamp": generated_at}
    event["original_alert"] = context
    event["epss"] = enrichment
    return event


def main() -> None:
    """
    Entry point run once per alert by the Wazuh integrator.

    Reads the alert JSON named in argv[1] and extracts a CVE. It then looks up
    the EPSS score (fresh cache entry first, otherwise the FIRST API) and sends
    an enriched event back to the Wazuh queue.

    Exit codes:
      - 1 on bad arguments or an unusable alert file;
      - 0 when there is simply nothing to enrich.
    """
    # ------------------------------------------------------------------
    # 1. Read Wazuh integration arguments
    #    argv[1] = path to the JSON alert file
    #    argv[2] = API key (unused here, but kept for interface compat)
    #    argv[3] = hook URL (unused)
    # ------------------------------------------------------------------
    if len(sys.argv) < 2:
        logger.error("Usage: custom-epss.py <alert_file> [api_key] [hook_url]")
        sys.exit(1)

    alert_file = sys.argv[1]

    try:
        with open(alert_file, "r", encoding="utf-8") as fh:
            alert = json.load(fh)
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.error("Cannot read alert file %s: %s", alert_file, exc)
        sys.exit(1)

    if not isinstance(alert, dict):
        logger.error("Alert file %s does not contain a JSON object.", alert_file)
        sys.exit(1)

    rule = alert.get("rule")
    logger.info("Processing alert id=%s rule=%s",
                alert.get("id"), rule.get("id") if isinstance(rule, dict) else None)

    # ------------------------------------------------------------------
    # 2. Extract CVE identifier
    # ------------------------------------------------------------------
    cve_id = extract_cve(alert)
    if not cve_id:
        logger.info("No CVE found in alert — skipping EPSS enrichment.")
        sys.exit(0)

    logger.info("CVE identified: %s", cve_id)

    # ------------------------------------------------------------------
    # 3. Fetch EPSS (cache-aware)
    # ------------------------------------------------------------------
    # Probe the cache *before* fetching. get_epss() writes a fresh API result
    # into the cache, so probing afterwards would report cached=True for every
    # successful lookup. bool() mirrors get_epss()'s truthiness check.
    epss_data = get_cached_epss(cve_id)
    from_cache = bool(epss_data)
    if not from_cache:
        epss_data = get_epss(cve_id)
    if not epss_data:
        logger.warning("EPSS data unavailable for %s — no event generated.", cve_id)
        sys.exit(0)

    # ------------------------------------------------------------------
    # 4. Build enriched event
    # ------------------------------------------------------------------
    event = build_enriched_event(alert, epss_data)
    event["epss"]["cached"] = from_cache

    logger.info(
        "EPSS for %s: score=%.5f (%s%%) percentile=%.2f%% label=%s cached=%s",
        cve_id,
        event["epss"]["score"],
        event["epss"]["score_pct"],
        event["epss"]["percentile_pct"],
        event["epss"]["risk_label"],
        from_cache,
    )

    # ------------------------------------------------------------------
    # 5. Send enriched event back to Wazuh
    # ------------------------------------------------------------------
    send_to_wazuh(event)
    logger.info("EPSS enrichment complete for %s.", cve_id)


if __name__ == "__main__":
    # Run the enrichment only when executed as a script (as the integrator
    # does), not when the module is imported.
    main()
