refactor(engine): Cleanup binds, fstrings, printf formatting
daryllimyt committed May 8, 2024
1 parent 32cc183 commit 659143d
Showing 15 changed files with 136 additions and 112 deletions.
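The pattern applied throughout is the same: replace eager f-string/printf interpolation and .bind() chains with loguru's deferred brace formatting and structured keyword fields. A minimal sketch of the before/after (field names here are illustrative, not from the diff):

# Before: the f-string is interpolated even if the level is filtered out,
# and context only arrives via a .bind() chain
logger.bind(env=env).warning(f"App started with origins={origins}")

# After: formatting is deferred and keyword fields travel on the record
logger.warning("App started", env=env, origins=origins)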
24 changes: 18 additions & 6 deletions tracecat/api/app.py
@@ -26,6 +26,7 @@
authenticate_user_or_service,
)
from tracecat.config import TRACECAT__APP_ENV, TRACECAT__RUNNER_URL
+from tracecat.contexts import ctx_session_role
from tracecat.db import (
Action,
ActionRun,
@@ -136,12 +137,19 @@ def create_app(**kwargs) -> FastAPI:
app.add_middleware(RequestLoggingMiddleware)

# TODO: Check TRACECAT__APP_ENV to set methods and headers
-logger.bind(env=TRACECAT__APP_ENV, origins=cors_origins_kwargs).warning("App started")
+logger.warning("App started", env=TRACECAT__APP_ENV, origins=cors_origins_kwargs)


# Catch-all exception handler to prevent stack traces from leaking
@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
+logger.error(
+    "Unexpected error: {!s}",
+    exc,
+    role=ctx_session_role.get(),
+    params=request.query_params,
+    path=request.url.path,
+)
return ORJSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"message": "An unexpected error occurred. Please try again later."},
@@ -166,7 +174,7 @@ async def check_runner_health() -> dict[str, str]:
try:
response.raise_for_status()
except Exception as e:
logger.error(f"Error checking runner health: {e}", exc_info=True)
logger.opt(exception=e).error("Error checking runner health", error=e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error checking runner health",
@@ -506,7 +514,11 @@ async def trigger_workflow_run(
entrypoint_key=params.action_key,
entrypoint_payload=params.payload,
)
logger.debug(f"Triggering workflow: {workflow_id = }, {workflow_params = }")
logger.debug(
"Triggering workflow",
workflow_id=workflow_id,
workflow_params=workflow_params,
)
async with AuthenticatedRunnerClient(role=service_role) as client:
response = await client.post(
f"/workflows/{workflow_id}",
@@ -515,7 +527,7 @@
try:
response.raise_for_status()
except Exception as e:
logger.error(f"Error triggering workflow: {e}", exc_info=True)
logger.opt(exception=e).error("Error triggering workflow", error=e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error triggering workflow",
@@ -947,7 +959,7 @@ def authenticate_webhook(
try:
webhook = result.one()
except NoResultFound as e:
logger.error("Webhook does not exist: %s", e)
logger.opt(exception=e).error("Webhook does not exist", error=e)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Resource not found"
) from e
@@ -960,7 +972,7 @@
try:
action = result.one()
except Exception as e:
logger.error("Action does not exist: %s", e)
logger.opt(exception=e).error("Action does not exist", error=e)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Resource not found"
) from e
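For context, loguru's opt(exception=...) attaches the exception and its traceback to the log record, which is what these replacements rely on instead of interpolating str(e) into the message. A minimal sketch (the failing function is illustrative):

from loguru import logger

def risky_call() -> None:
    raise ValueError("boom")

try:
    risky_call()
except Exception as e:
    # The traceback is carried on the record and rendered by the sink,
    # without leaking exception details into the message string itself
    logger.opt(exception=e).error("risky_call failed")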
6 changes: 3 additions & 3 deletions tracecat/api/completions.py
@@ -150,7 +150,7 @@ def _to_disciminated_union(cons: list[CategoryConstraint]) -> tuple[str, str]:
Returns:
tuple[str, str]: The discriminated union type and the supporting types
"""
logger.info(f"Creating discriminated union for {cons =}")
logger.info("Creating discriminated union", cons=cons)
supporting_tags = {}
for tc in cons:
tag = tc.tag
@@ -245,8 +245,8 @@ async def stream_case_completions(
output_cls=CaseMissingFieldsResponse,
field_cons=field_cons,
)
logger.info("馃 Starting case completions for %d cases...", len(cases))
logger.bind(system_context=system_context).debug("System context")
logger.info("馃 Starting case completions for {} cases...", len(cases))
logger.debug("System context: {}", system_context=system_context)

async def task(case: Case) -> str:
prompt = f"""Case JSON Object: ```\n{case.model_dump_json()}\n```"""
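The completions module builds a discriminated union from tag constraints, and the model_dump_json() call above suggests pydantic models. For reference, a pydantic discriminated union looks like this (a minimal sketch; the models are illustrative, not the ones this module generates):

from typing import Literal, Union

from pydantic import BaseModel, Field

class MalwareTag(BaseModel):
    tag: Literal["malware"]
    family: str

class PhishingTag(BaseModel):
    tag: Literal["phishing"]
    target: str

class CaseTag(BaseModel):
    # pydantic dispatches validation on the "tag" discriminator field
    value: Union[MalwareTag, PhishingTag] = Field(discriminator="tag")

print(CaseTag(value={"tag": "malware", "family": "emotet"}))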
2 changes: 1 addition & 1 deletion tracecat/auth.py
@@ -270,7 +270,7 @@ async def _get_role_from_jwt(token: str | bytes) -> Role:
if user_id is None:
raise HTTP_EXC("No sub claim in JWT")
except ExpiredSignatureError as e:
logger.error(f"ExpiredSignatureError: {e}")
logger.error("Signature expired", error=e)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Session expired",
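The flow above maps ExpiredSignatureError to a 401 and a missing sub claim to an error. A minimal sketch of the same shape with a PyJWT-style API (an assumption; tracecat may use a different JWT library, and the helper name is illustrative):

import jwt  # PyJWT

def get_user_id(token: str, key: str) -> str:
    try:
        claims = jwt.decode(token, key, algorithms=["HS256"])
    except jwt.ExpiredSignatureError as e:
        # Expired sessions surface as an auth failure, not a server error
        raise PermissionError("Session expired") from e
    user_id = claims.get("sub")
    if user_id is None:
        raise PermissionError("No sub claim in JWT")
    return user_id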
4 changes: 2 additions & 2 deletions tracecat/concurrency.py
@@ -17,7 +17,7 @@ def _run_serialized_fn(serialized_wrapped_fn: bytes, role: Role, /, *args, **kwa
# NOTE: This is not the raw function - it is still wrapped by the `wrapper` decorator
wrapped_fn: Callable[_P, Any] = cloudpickle.loads(serialized_wrapped_fn)
ctx_session_role.set(role)
logger.bind(role=role).debug("Running serialized function")
logger.debug("Running serialized function", role=role)
kwargs["__role"] = role
res = wrapped_fn(*args, **kwargs)
return res
@@ -28,6 +28,6 @@ class CloudpickleProcessPoolExecutor(ProcessPoolExecutor):
def submit(self, fn: Callable[_P, Any], /, *args, **kwargs):
# We need to pass the role to the function running in the child process
role = ctx_session_role.get()
logger.bind(role=role).debug("Serializing function")
logger.debug("Serializing function", role=role)
serialized_fn = cloudpickle.dumps(fn)
return super().submit(_run_serialized_fn, serialized_fn, role, *args, **kwargs)
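The executor above re-seeds ctx_session_role inside the worker, because contextvars do not cross process boundaries on their own. A minimal sketch of that idea in isolation (names are illustrative):

import contextvars
from concurrent.futures import ProcessPoolExecutor

ctx_role: contextvars.ContextVar[str | None] = contextvars.ContextVar("role", default=None)

def _run_with_role(role: str | None, fn, *args):
    # Re-set the contextvar in the child process before running the task
    ctx_role.set(role)
    return fn(*args)

def submit_with_role(pool: ProcessPoolExecutor, fn, *args):
    # Capture the parent's value and ship it alongside the task
    return pool.submit(_run_with_role, ctx_role.get(), fn, *args)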
2 changes: 1 addition & 1 deletion tracecat/etl/aws_cloudtrail.py
@@ -184,7 +184,7 @@ def load_cloudtrail_logs(
organization_id: str | None = None,
) -> pl.LazyFrame:
logger.info(
"馃啑 Download AWS CloudTrail logs from: account_id=%r across regions=%s",
"馃啑 Download AWS CloudTrail logs from: account_id={!r} across regions={!s}",
account_id,
regions,
)
50 changes: 25 additions & 25 deletions tracecat/integrations/_registry.py
@@ -59,7 +59,7 @@ def register(
"""Decorator factory to register a new integration function with additional parameters."""

def decorator_register(func: FunctionType):
logger.info(f"Registering integration {func.__name__}")
logger.info("Registering integration", name=func.__name__)
validate_type_constraints(func)
platform = get_integration_platform(func)
key = get_integration_key(func)
@@ -74,25 +74,27 @@ def wrapper(*args, **kwargs):
2. Inject all secret keys into the execution environment.
3. Clean up the environment after the function has executed.
"""
-_secrets: list[Secret] = []
-try:
-    role = kwargs.pop("__role", None)
-    # Get secrets from the secrets API
-    self._logger = logger.bind(pid=os.getpid())
-    self._logger.bind(key=key).info("Executing in subprocess")
-
-    if secrets:
-        self._logger.bind(secrets=secrets).info("Pull secrets")
-        _secrets = self._get_secrets(role=role, secret_names=secrets)
-        self._set_secrets(_secrets)
-
-    return func(*args, **kwargs)
-except Exception as e:
-    self._logger.error(f"Error running integration '{key}': {e}")
-    raise
-finally:
-    self._logger.info(f"Cleaning up after integration '{key}'.")
-    self._unset_secrets(_secrets)
+secret_objs: list[Secret] = []
+role: Role = kwargs.pop("__role", None)
+with logger.contextualize(user_id=role.user_id, pid=os.getpid()):
+    try:
+        # Get secrets from the secrets API
+        logger.info("Executing in subprocess", key=key)
+
+        if secrets:
+            logger.info("Pull secrets", secrets=secrets)
+            secret_objs = self._get_secrets(
+                role=role, secret_names=secrets
+            )
+            self._set_secrets(secret_objs)
+
+        return func(*args, **kwargs)
+    except Exception as e:
+        logger.error("Error running integration {!r}: {!s}", key, e)
+        raise
+    finally:
+        logger.info("Cleaning up after integration {!r}", key)
+        self._unset_secrets(secret_objs)

if key in self._integrations:
raise ValueError(f"Integration '{key}' is already registered.")
@@ -114,21 +116,19 @@ def wrapper(*args, **kwargs):
def _get_secrets(self, role: Role, secret_names: list[str]) -> list[Secret]:
"""Retrieve secrets from the secrets API."""

-self._logger.opt(lazy=True).debug(
-    "Getting secrets {secret_names}", secret_names=lambda: secret_names
-)
+logger.debug("Getting secrets {}", secret_names)
return asyncio.run(batch_get_secrets(role, secret_names))

def _set_secrets(self, secrets: list[Secret]):
"""Set secrets in the environment."""
for secret in secrets:
self._logger.info(f"Setting secret {secret!r}")
logger.info("Setting secret {!r}", secret.name)
for kv in secret.keys:
os.environ[kv.key] = kv.value

def _unset_secrets(self, secrets: list[Secret]):
for secret in secrets:
self._logger.info(f"Deleting secret {secret.name!r}")
logger.info("Deleting secret {!r}", secret.name)
for kv in secret.keys:
del os.environ[kv.key]

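For reference, logger.contextualize() is a context manager that binds fields to every record emitted inside its block, which is what the new wrapper uses in place of per-call .bind() chains. A minimal sketch (the task function is illustrative):

import os

from loguru import logger

def run_integration(user_id: str) -> None:
    # Every record logged inside the block carries user_id and pid
    with logger.contextualize(user_id=user_id, pid=os.getpid()):
        logger.info("Executing in subprocess")

run_integration("user-123")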
4 changes: 2 additions & 2 deletions tracecat/llm.py
@@ -66,7 +66,7 @@ def parse_choice(choice: Choice) -> str | dict[str, Any]:
{"role": "user", "content": prompt},
]

logger.info("馃 Calling OpenAI API with model: %s...", model)
logger.info("馃 Calling OpenAI API with {} model...", model)
response: ChatCompletion = await client.chat.completions.create( # type: ignore[call-overload]
model=model,
response_format={"type": response_format},
@@ -76,7 +76,7 @@ def parse_choice(choice: Choice) -> str | dict[str, Any]:
**kwargs,
)
# TODO: Should track these metrics
-logger.bind(usage=response.usage).info("Usage")
+logger.info("Usage", usage=response.usage)
if stream:
return response

6 changes: 3 additions & 3 deletions tracecat/messaging/consumer.py
@@ -39,10 +39,10 @@ async def prepare_queue(*, channel: Channel, exchange: str, routing_keys: list[s
await queue.bind(ex, routing_key=routing_key)
yield queue
except Exception as e:
logger.error(f"Error in prepare_exchange: {e}", exc_info=True)
logger.opt(exception=e).error("Error in prepare_exchange", exchange=exchange)
finally:
# Cleanup
logger.info(f"Cleaning up exchange {exchange!r}")
logger.info("Cleaning up exchange", exchange=exchange)
if queue:
for routing_key in routing_keys:
await queue.unbind(ex, routing_key=routing_key)
@@ -101,6 +101,6 @@ async def _subscribe():
async for message in _subscribe():
yield message
except Exception as e:
logger.error(f"Error in event subscription: {e}", exc_info=True)
logger.opt(exception=e).error("Error in event subscription")
finally:
logger.info("Closing event subscription")
4 changes: 1 addition & 3 deletions tracecat/messaging/producer.py
@@ -24,9 +24,7 @@ async def event_producer(

for routing_key in routing_keys:
await ex.publish(message, routing_key=routing_key)
-logger.bind(routing_key=routing_key, body=message.body).debug(
-    "Published message"
-)
+logger.debug("Published message", routing_key=routing_key, body=message.body)


async def publish(
5 changes: 3 additions & 2 deletions tracecat/middleware/request.py
@@ -13,13 +13,14 @@ async def dispatch(self, request: Request, call_next):
)

# Log the incoming request with parameters
-request.app.logger.bind(
+request.app.logger.debug(
+    "Incoming request",
method=request.method,
scheme=request.url.scheme,
hostname=request.url.hostname,
path=request.url.path,
params=request_params,
body=request_body,
).debug("Request")
)

return await call_next(request)
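Worth noting for this style: in recent loguru versions, keyword arguments passed to a logging call are captured into the record's extra dict (unless disabled via opt(capture=False)), so any sink can render them as structured fields. A minimal sketch, with an illustrative format string:

import sys

from loguru import logger

logger.remove()
# Render the structured fields alongside the message
logger.add(sys.stderr, format="{time} {level} {message} | {extra}")
logger.debug("Incoming request", method="GET", path="/health")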
