Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle db isolation for mapped operators and task groups #39259

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion airflow/api_internal/endpoints/rpc_api_endpoint.py
Expand Up @@ -26,7 +26,8 @@
from flask import Response

from airflow.jobs.job import Job, most_recent_job
from airflow.models.taskinstance import _get_template_context, _update_rtif
from airflow.models.taskinstance import _get_template_context, _record_task_map_for_downstreams, _update_rtif
from airflow.models.xcom_arg import _get_task_map_length
from airflow.sensors.base import _orig_start_date
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.utils.session import create_session
Expand Down Expand Up @@ -56,8 +57,10 @@ def _initialize_map() -> dict[str, Callable]:
_default_action_log_internal,
_get_template_context,
_get_ti_db_access,
_get_task_map_length,
_update_rtif,
_orig_start_date,
_record_task_map_for_downstreams,
DagFileProcessor.update_import_errors,
DagFileProcessor.manage_slas,
DagFileProcessorManager.deactivate_stale_dags,
Expand Down
36 changes: 31 additions & 5 deletions airflow/models/taskinstance.py
Expand Up @@ -497,8 +497,14 @@ def _execute_callable(context: Context, **execute_callable_kwargs):
for key, value in xcom_value.items():
task_instance.xcom_push(key=key, value=value, session=session_or_null)
task_instance.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session_or_null)
if TYPE_CHECKING:
assert task_orig.dag
_record_task_map_for_downstreams(
task_instance=task_instance, task=task_orig, value=xcom_value, session=session_or_null
task_instance=task_instance,
task=task_orig,
dag=task_orig.dag,
value=xcom_value,
session=session_or_null,
)
return result

Expand Down Expand Up @@ -1003,25 +1009,43 @@ def _refresh_from_task(
task_instance_mutation_hook(task_instance)


@internal_api_call
@provide_session
def _record_task_map_for_downstreams(
*, task_instance: TaskInstance | TaskInstancePydantic, task: Operator, value: Any, session: Session
*,
task_instance: TaskInstance | TaskInstancePydantic,
task: Operator,
dag: DAG,
value: Any,
session: Session,
) -> None:
"""
Record the task map for downstream tasks.

:param task_instance: the task instance
:param task: The task object
:param dag: the dag associated with the task
:param value: The value
:param session: SQLAlchemy ORM Session

:meta private:
"""
# if not task._dag:
dstandish marked this conversation as resolved.
Show resolved Hide resolved
# task._dag = dag # required when on RPC server side

dstandish marked this conversation as resolved.
Show resolved Hide resolved
# when taking task over RPC, we need to add the dag back
if isinstance(task, MappedOperator):
if not task.dag:
task.dag = dag
elif not task._dag:
task._dag = dag
Comment on lines +1033 to +1038
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a fan for this… can we do this earlier in the stack, say when the task is created instead?

Copy link
Contributor Author

@dstandish dstandish Apr 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the issue @uranusjr is that this is early in the stack when it's an RPC call. The only earlier place we could do it is in the decorator. WDYT? We could stick it in a private function to get it out of the way, and reuse it elsewhere in the module....

when it is not an RPC call, this has no effect and is not needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fyi @uranusjr this is resolved here (Use sentinel to elide the dag object on reserialization) but i can't make this PR yet because it's depending on too many other PRs to get merged first


if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate.
return
# TODO: We don't push TaskMap for mapped task instances because it's not
# currently possible for a downstream to depend on one individual mapped
# task instance. This will change when we implement task mapping inside
# a mapped task group, and we'll need to further analyze the case.
# currently possible for a downstream to depend on one individual mapped
# task instance. This will change when we implement task mapping inside
# a mapped task group, and we'll need to further analyze the case.
Comment on lines +1043 to +1045
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accidental?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no -- adding the indent will make it so all of this text is "part of" the todo (i.e. all show yellow in IDE) if we don't do this then it looks like separate comment... just a driveby "fix" but i can remove if you like

if isinstance(task, MappedOperator):
return
if value is None:
Expand Down Expand Up @@ -3167,6 +3191,8 @@ def render_templates(
# MappedOperator is useless for template rendering, and we need to be
# able to access the unmapped task instead.
original_task.render_template_fields(context, jinja_env)
if isinstance(self.task, MappedOperator):
self.task = context["ti"].task
Comment on lines +3191 to +3192
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this is an interesting one @potiuk. The way mapped operators are "expanded" or "unmapped"... it happens inside of MappedOperator.render_template_fields. It does so by replacing the task attr on the ti in the context dictionary, which in the non-db-isolation case mutates what is here self.task! But in db isolation case, the context dict is created via RPC and so the pydantic TI in the context dict is not the same as the PydanticTI that is running.... it's .... quite complicated. But anyway this here is one way to ensure that the task gets properly unmapped -- we don't here rely on mutating the TI in the context dict.

Comment on lines +3191 to +3192
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this not work with BaseOperator? The conditional makes this a lot weirder.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That’s right, because ti.task is only mutated when it is a MappedOperator. Otherwise ti.task is the result of an RPC call and, long story short, it can’t be used

Copy link
Contributor Author

@dstandish dstandish Apr 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that make sense @uranusjr ?

so with normal task, self.task is the task that is created locally, and there is no need to override it from the one in context dict. and if you did that then you'd take a task object that isn't quite complete, essentially because we don't have proper serialization of Task since there's no real Task entity and no TaskPydantic. But generally it's not a problem because most of the time we don't need to serialize a task object.

in the mappedoperator case though, as we saw last night, "unmapping" is achieved by mutating the ti in the context dict, and it relies on the assumption that the TI in the context dict is the same object as the one that is created locally and being run, which isn't true when the context comes from RPC.

if searching for alternatives, we could look at not relying on the context dict for this "unmapping". e.g. we could forword the "original" ti object to the thing doing the unmapping so we don't need to mutate what's in context.

another option would be, upon receiving a fresh context dict over RPC, we could replace the TIs in the context with the local TIPydantic object -- or something to this effect. then perhaps we could keep the context["ti"] mutation approach for unmapping.

we could also look at changing the way we handle context over RPC. currently it's just a "working" approach but not optimal because there's no laziness. we could optimize by making each context object an accessor that is an RPC call (and we should do something like this ). and something like that could help here too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense, but if isinstance(self.task, MappedOperator) is an awkward condition to check for the case.

upon receiving a fresh context dict over RPC, we could replace the TIs in the context with the local TIPydantic object

This sounds somewhat promising. Instead of just the ti, we could probably try to replace the entire relationship (including e.g. dag) so we can get rid of needing to pass in dag separately into _record_task_map_for_downstreams.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense, but if isinstance(self.task, MappedOperator) is an awkward condition to check for the case.

yeah, i see what you're saying. e.g. better would be for the code to "tell us" when an unmap has happened.

like when we call

original_task.render_template_fields(context, jinja_env)

that could like... return a new task when it creates one. that would certainly make it more obvious what is going on too.


return original_task

Expand Down
100 changes: 58 additions & 42 deletions airflow/models/xcom_arg.py
Expand Up @@ -23,9 +23,10 @@

from sqlalchemy import func, or_, select

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import AirflowException, XComNotFound
from airflow.models import MappedOperator, TaskInstance
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskmixin import DependencyMixin
from airflow.utils.db import exists_query
from airflow.utils.mixins import ResolveMixin
Expand Down Expand Up @@ -222,6 +223,53 @@ def __exit__(self, exc_type, exc_val, exc_tb):
SetupTeardownContext.set_work_task_roots_and_leaves()


@internal_api_call
@provide_session
def _get_task_map_length(
    *,
    dag_id: str,
    task_id: str,
    run_id: str,
    is_mapped: bool,
    session: Session = NEW_SESSION,
) -> int | None:
    """
    Look up the mapped-task length recorded for a task in a given dag run.

    :param dag_id: ID of the dag the task belongs to
    :param task_id: ID of the upstream task whose map length is needed
    :param run_id: ID of the dag run being examined
    :param is_mapped: whether the upstream task is itself a mapped operator
    :param session: SQLAlchemy ORM session (injected by ``provide_session``)
    :return: the number of map entries, or ``None`` when it cannot be
        determined yet (mapped upstream still has unfinished instances)

    :meta private:
    """
    from airflow.models.taskinstance import TaskInstance
    from airflow.models.taskmap import TaskMap
    from airflow.models.xcom import XCom

    if not is_mapped:
        # Plain tasks record their pushed length in the TaskMap table.
        return session.scalar(
            select(TaskMap.length).where(
                TaskMap.dag_id == dag_id,
                TaskMap.run_id == run_id,
                TaskMap.task_id == task_id,
                TaskMap.map_index < 0,
            )
        )

    # Mapped upstream: the length is only known once every expanded task
    # instance has finished, so first check for unfinished instances.
    has_unfinished = exists_query(
        TaskInstance.dag_id == dag_id,
        TaskInstance.run_id == run_id,
        TaskInstance.task_id == task_id,
        # 'state' can be NULL, and NULL never matches an IN clause
        # ("NULL = NULL" is not true in SQL), so test for NULL explicitly.
        or_(
            TaskInstance.state.is_(None),
            TaskInstance.state.in_(s.value for s in State.unfinished if s is not None),
        ),
        session=session,
    )
    if has_unfinished:
        return None  # Not all of the expanded tis are done yet.
    # All instances finished: count their pushed return-value XComs.
    return session.scalar(
        select(func.count(XCom.map_index)).where(
            XCom.dag_id == dag_id,
            XCom.run_id == run_id,
            XCom.task_id == task_id,
            XCom.map_index >= 0,
            XCom.key == XCOM_RETURN_KEY,
        )
    )


class PlainXComArg(XComArg):
"""Reference to one single XCom without any additional semantics.

Expand Down Expand Up @@ -364,51 +412,19 @@ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
return super().zip(*others, fillvalue=fillvalue)

def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCom

task = self.operator
if isinstance(task, MappedOperator):
unfinished_ti_exists = exists_query(
TaskInstance.dag_id == task.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.task_id == task.task_id,
# Special NULL treatment is needed because 'state' can be NULL.
# The "IN" part would produce "NULL NOT IN ..." and eventually
# "NULl = NULL", which is a big no-no in SQL.
or_(
TaskInstance.state.is_(None),
TaskInstance.state.in_(s.value for s in State.unfinished if s is not None),
),
session=session,
)
if unfinished_ti_exists:
return None # Not all of the expanded tis are done yet.
query = select(func.count(XCom.map_index)).where(
XCom.dag_id == task.dag_id,
XCom.run_id == run_id,
XCom.task_id == task.task_id,
XCom.map_index >= 0,
XCom.key == XCOM_RETURN_KEY,
)
else:
query = select(TaskMap.length).where(
TaskMap.dag_id == task.dag_id,
TaskMap.run_id == run_id,
TaskMap.task_id == task.task_id,
TaskMap.map_index < 0,
)
return session.scalar(query)
return _get_task_map_length(
dag_id=self.operator.dag_id,
task_id=self.operator.task_id,
is_mapped=isinstance(self.operator, MappedOperator),
run_id=run_id,
session=session,
)

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
from airflow.models.taskinstance import TaskInstance

ti = context["ti"]
if not isinstance(ti, TaskInstance):
raise NotImplementedError("Wait for AIP-44 implementation to complete")

if TYPE_CHECKING:
assert isinstance(ti, TaskInstance)
task_id = self.operator.task_id
map_indexes = ti.get_relevant_upstream_map_indexes(
self.operator,
Expand Down