Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(integration): Implement AWS GuardDuty (#112)
- Loading branch information
1 parent
238dd7e
commit 9c8c154
Showing
7 changed files
with
224 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
"""ETL functions for AWS GuardDuty findings. | ||
API reference: https://docs.aws.amazon.com/guardduty/latest/ug/guardduty_finding-types-active.html | ||
""" | ||
|
||
import logging | ||
from collections.abc import Generator | ||
from datetime import datetime | ||
from functools import partial | ||
from itertools import chain | ||
from typing import TYPE_CHECKING | ||
|
||
import boto3 | ||
import diskcache as dc | ||
import mmh3 | ||
import polars as pl | ||
from tqdm.contrib.concurrent import thread_map | ||
|
||
from tracecat.config import TRACECAT__TRIAGE_DIR | ||
from tracecat.contexts import ctx_session_role | ||
from tracecat.logger import standard_logger | ||
|
||
if TYPE_CHECKING: | ||
from mypy_boto3_guardduty.type_defs import GetFindingsResponseTypeDef | ||
|
||
logger = standard_logger("runner.aws_guardduty")


# Suppress botocore info logs
logging.getLogger("botocore").setLevel(logging.CRITICAL)


# On-disk cache directory for GuardDuty findings (used by the diskcache
# layer in `load_guardduty_findings`); created eagerly at import time.
AWS_GUARDDUTY__TRIAGE_DIR = TRACECAT__TRIAGE_DIR / "aws_guardduty"
AWS_GUARDDUTY__TRIAGE_DIR.mkdir(parents=True, exist_ok=True)


# Maximum number of finding IDs per GetFindings request (GuardDuty API limit).
GET_FINDINGS_MAX_CHUNK_SIZE = 50
|
||
|
||
def _get_all_guardduty_findings(
    chunk_size: int = GET_FINDINGS_MAX_CHUNK_SIZE,
) -> pl.DataFrame:
    """Fetch all GuardDuty findings across every detector in the default region.

    Lists finding IDs per detector, then fetches the full finding documents
    in parallel chunks via `GetFindings`, which accepts at most
    `GET_FINDINGS_MAX_CHUNK_SIZE` IDs per call.

    Args:
        chunk_size: Maximum number of finding IDs per `GetFindings` request.
            Clamped to the API limit `GET_FINDINGS_MAX_CHUNK_SIZE`.

    Returns:
        All retrieved GuardDuty findings as a Polars DataFrame.
    """
    client = boto3.client("guardduty")
    list_findings_paginator = client.get_paginator("list_findings")

    # For all detectors, list finding IDs then fetch the finding documents.
    findings: list[dict] = []
    detectors = client.list_detectors()["DetectorIds"]
    # GetFindings rejects requests exceeding the documented maximum ID count.
    chunk_size = min(chunk_size, GET_FINDINGS_MAX_CHUNK_SIZE)

    def chunker(finding_ids: list[str]) -> Generator[list[str], None, None]:
        """Yield successive `chunk_size`-sized slices of `finding_ids`."""
        for i in range(0, len(finding_ids), chunk_size):
            yield finding_ids[i : i + chunk_size]

    def getter(finding_ids: list[str], *, detector_id: str) -> list[dict]:
        """Fetch the full finding documents for one chunk of finding IDs.

        Creates its own client because this runs inside `thread_map` worker
        threads; sharing one boto3 client across threads is not guaranteed
        safe.
        """
        thread_client = boto3.client("guardduty")
        response = thread_client.get_findings(
            DetectorId=detector_id, FindingIds=finding_ids
        )
        return response.get("Findings", [])

    for detector_id in detectors:
        finding_ids: list[str] = []
        # TODO: Parallelize this?
        for page in list_findings_paginator.paginate(DetectorId=detector_id):
            finding_ids.extend(page.get("FindingIds", []))
        logger.info(f"Found {len(finding_ids)} findings in detector {detector_id}")

        detector_findings: list[list[dict]] = thread_map(
            partial(getter, detector_id=detector_id),
            chunker(finding_ids=finding_ids),
            desc="Getting AWS GuardDuty findings",
        )
        findings.extend(chain.from_iterable(detector_findings))

    logger.info(f"Retrieved {len(findings)} GuardDuty findings")
    df = pl.DataFrame(findings)
    return df
|
||
|
||
# Struct-typed columns that `_stringify_struct_columns` JSON-encodes.
GUARDDUTY_DEFAULT_STRUCT_COLS = ["Service", "Resource"]


def _stringify_struct_columns(df: pl.DataFrame | pl.LazyFrame) -> pl.LazyFrame:
    """Return a lazy frame with the default struct columns JSON-encoded.

    Each column in `GUARDDUTY_DEFAULT_STRUCT_COLS` is replaced by its
    `struct.json_encode()` string representation.
    """
    encoders = [
        pl.col(column).struct.json_encode()
        for column in GUARDDUTY_DEFAULT_STRUCT_COLS
    ]
    lazy_df = df.lazy()
    return lazy_df.with_columns(encoders)
|
||
|
||
def load_guardduty_findings(
    start: datetime,
    end: datetime,
    account_id: str,
    organization_id: str,
) -> pl.LazyFrame:
    """Load AWS GuardDuty findings for the specified time range.

    Caches the full (unfiltered) findings set on disk for 10 minutes to
    avoid repeated expensive API calls; the time-range filter is applied
    on every call after the cache lookup.

    Args:
        start: Inclusive lower bound on the finding `CreatedAt` timestamp.
        end: Inclusive upper bound on the finding `CreatedAt` timestamp.
        account_id: AWS account ID; used only to scope the cache key.
        organization_id: AWS organization ID; used only to scope the cache key.

    Returns:
        GuardDuty findings as a Polars LazyFrame.
    """
    # Include the session role in the cache key to avoid collisions
    # when possibly serving multiple users concurrently
    role = ctx_session_role.get()
    logger.info(f"Loading GuardDuty findings for role {role}")

    # Delimit key components so adjacent values cannot collide
    # (e.g. start="a", end="bc" vs start="ab", end="c").
    key = mmh3.hash(
        f"{role}:{start}:{end}:{account_id}:{organization_id}".encode(), seed=42
    )

    df: pl.DataFrame
    dt_col = "CreatedAt"
    with dc.Cache(directory=AWS_GUARDDUTY__TRIAGE_DIR) as cache:
        if key in cache:
            logger.info("Cache hit for GuardDuty findings")
            # Structs here are already stringified
            df = cache[key]
        else:
            logger.info("Cache miss for GuardDuty findings")
            df = (
                _get_all_guardduty_findings()
                .lazy()
                .pipe(_stringify_struct_columns)
                .collect(streaming=True)
            )
            # Cache for 10 minutes
            cache.set(key=key, value=df, expire=600)
    # Apply time range filter.
    # NOTE(review): assumes `CreatedAt` is a temporal column comparable with
    # `start`/`end` datetimes -- confirm against the GuardDuty response schema.
    df = df.filter(pl.col(dt_col).is_between(start, end))
    return df.lazy()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
"""Native integration to query AWS GuardDuty findings. | ||
Optional secrets: `aws-guardduty` secret with keys `AWS_ACCOUNT_ID` and `AWS_ORGANIZATION_ID`. | ||
Note: this integration DOES NOT support IAM credential based authentication. | ||
Secrets are only used to obscure potentially sensitive data (account ID, organization ID). | ||
""" | ||
|
||
import os | ||
from typing import Any | ||
|
||
import dateutil.parser | ||
|
||
from tracecat.etl.aws_guardduty import load_guardduty_findings | ||
from tracecat.etl.query_builder import pl_sql_query | ||
from tracecat.integrations._registry import registry | ||
|
||
|
||
@registry.register(
    description="Query AWS GuardDuty findings", secrets=["aws-guardduty"]
)
def query_guardduty_findings(
    start: str,
    end: str,
    query: str,
    account_id: str | None = None,
    organization_id: str | None = None,
) -> list[dict[str, Any]]:
    """Run a SQL query over GuardDuty findings within a time range.

    Args:
        start: Start of the time range; any `dateutil`-parseable string.
        end: End of the time range; any `dateutil`-parseable string.
        query: SQL query executed against the findings frame.
        account_id: AWS account ID. Falls back to the `AWS_ACCOUNT_ID`
            environment variable; raises `KeyError` if neither is set.
        organization_id: AWS organization ID. Falls back to the
            `AWS_ORGANIZATION_ID` environment variable; raises `KeyError`
            if neither is set.

    Returns:
        The query result rows as a list of dicts.
    """
    account_id = account_id or os.environ["AWS_ACCOUNT_ID"]
    organization_id = organization_id or os.environ["AWS_ORGANIZATION_ID"]
    start_dt = dateutil.parser.parse(start)
    end_dt = dateutil.parser.parse(end)
    # `load_guardduty_findings` caches findings on disk, keyed by the
    # session role and these call arguments.
    findings_lf = load_guardduty_findings(
        start=start_dt,
        end=end_dt,
        account_id=account_id,
        organization_id=organization_id,
    )
    queried_findings = pl_sql_query(lf=findings_lf, query=query, eager=True).to_dicts()
    return queried_findings
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters