Skip to content

Commit

Permalink
ECS Executor: Set tasks to RUNNING state once active (#39212)
Browse files Browse the repository at this point in the history
Tasks were previously being put into QUEUED state after they were active
in the ECS executor. This was to store executor state for task adoption
but had the side effect of removing them from the list of running task
instances (which has other knock-on effects). Instead, change tasks into
the RUNNING state, and do not remove them from the list of running
tasks.

* Update change_state usage in debug and celery executor

- DebugExecutor: was overriding the change_state method from the base
executor, but changing no behaviour, so move to using the base executor
implementation
- CeleryExecutor: Plumb through the new param so that the signature
matches the base executor

* Call running_state in try/catch for backcompat
  • Loading branch information
o-nikolas committed May 6, 2024
1 parent 8965f2e commit a74b5f0
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 18 deletions.
25 changes: 19 additions & 6 deletions airflow/executors/base_executor.py
Expand Up @@ -303,19 +303,23 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
self.running.add(key)

def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
def change_state(
self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True
) -> None:
"""
Change state of the task.
:param info: Executor information for the task instance
:param key: Unique key for the task instance
:param state: State to set for the task.
:param info: Executor information for the task instance
:param remove_running: Whether or not to remove the TI key from running set
"""
self.log.debug("Changing state: %s", key)
try:
self.running.remove(key)
except KeyError:
self.log.debug("Could not find key: %s", key)
if remove_running:
try:
self.running.remove(key)
except KeyError:
self.log.debug("Could not find key: %s", key)
self.event_buffer[key] = state, info

def fail(self, key: TaskInstanceKey, info=None) -> None:
Expand Down Expand Up @@ -345,6 +349,15 @@ def queued(self, key: TaskInstanceKey, info=None) -> None:
"""
self.change_state(key, TaskInstanceState.QUEUED, info)

def running_state(self, key: TaskInstanceKey, info=None) -> None:
"""
Set running state for the event.
:param info: Executor information for the task instance
:param key: Unique key for the task instance
"""
self.change_state(key, TaskInstanceState.RUNNING, info, remove_running=False)

def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]:
"""
Return and flush the event buffer.
Expand Down
5 changes: 0 additions & 5 deletions airflow/executors/debug_executor.py
Expand Up @@ -155,8 +155,3 @@ def end(self) -> None:

def terminate(self) -> None:
self._terminated.set()

def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
self.log.debug("Popping %s from executor task queue.", key)
self.running.remove(key)
self.event_buffer[key] = state, info
9 changes: 7 additions & 2 deletions airflow/jobs/scheduler_job_runner.py
Expand Up @@ -692,7 +692,12 @@ def _process_executor_events(self, session: Session) -> int:
ti_primary_key_to_try_number_map[ti_key.primary] = ti_key.try_number

self.log.info("Received executor event with state %s for task instance %s", state, ti_key)
if state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS, TaskInstanceState.QUEUED):
if state in (
TaskInstanceState.FAILED,
TaskInstanceState.SUCCESS,
TaskInstanceState.QUEUED,
TaskInstanceState.RUNNING,
):
tis_with_right_state.append(ti_key)

# Return if no finished tasks
Expand All @@ -711,7 +716,7 @@ def _process_executor_events(self, session: Session) -> int:
buffer_key = ti.key.with_try_number(try_number)
state, info = event_buffer.pop(buffer_key)

if state == TaskInstanceState.QUEUED:
if state in (TaskInstanceState.QUEUED, TaskInstanceState.RUNNING):
ti.external_executor_id = info
self.log.info("Setting external_id for %s to %s", ti, info)
continue
Expand Down
7 changes: 6 additions & 1 deletion airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
Expand Up @@ -400,7 +400,12 @@ def attempt_task_runs(self):
else:
task = run_task_response["tasks"][0]
self.active_workers.add_task(task, task_key, queue, cmd, exec_config, attempt_number)
self.queued(task_key, task.task_arn)
try:
self.running_state(task_key, task.task_arn)
except AttributeError:
# running_state is newly added, and only needed to support task adoption (an optional
# executor feature).
pass
if failure_reasons:
self.log.error(
"Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.",
Expand Down
10 changes: 8 additions & 2 deletions airflow/providers/celery/executors/celery_executor.py
Expand Up @@ -368,8 +368,14 @@ def update_all_task_states(self) -> None:
if state:
self.update_task_state(key, state, info)

def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
super().change_state(key, state, info)
def change_state(
self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True
) -> None:
try:
super().change_state(key, state, info, remove_running=remove_running)
except AttributeError:
# Earlier versions of the BaseExecutor don't accept the remove_running parameter for this method
super().change_state(key, state, info)
self.tasks.pop(key, None)

def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:
Expand Down
53 changes: 52 additions & 1 deletion tests/executors/test_base_executor.py
Expand Up @@ -33,7 +33,7 @@
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.state import State, TaskInstanceState


def test_supports_sentry():
Expand Down Expand Up @@ -363,3 +363,54 @@ def test_running_retry_attempt_type(loop_duration, total_tries):
assert a.elapsed > min_seconds_for_test
assert a.total_tries == total_tries
assert a.tries_after_min == 1


def test_state_fail():
executor = BaseExecutor()
key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
executor.running.add(key)
info = "info"
executor.fail(key, info=info)
assert not executor.running
assert executor.event_buffer[key] == (TaskInstanceState.FAILED, info)


def test_state_success():
executor = BaseExecutor()
key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
executor.running.add(key)
info = "info"
executor.success(key, info=info)
assert not executor.running
assert executor.event_buffer[key] == (TaskInstanceState.SUCCESS, info)


def test_state_queued():
executor = BaseExecutor()
key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
executor.running.add(key)
info = "info"
executor.queued(key, info=info)
assert not executor.running
assert executor.event_buffer[key] == (TaskInstanceState.QUEUED, info)


def test_state_generic():
executor = BaseExecutor()
key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
executor.running.add(key)
info = "info"
executor.queued(key, info=info)
assert not executor.running
assert executor.event_buffer[key] == (TaskInstanceState.QUEUED, info)


def test_state_running():
executor = BaseExecutor()
key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
executor.running.add(key)
info = "info"
executor.running_state(key, info=info)
# Running state should not remove a command as running
assert executor.running
assert executor.event_buffer[key] == (TaskInstanceState.RUNNING, info)
Expand Up @@ -367,7 +367,8 @@ def test_stopped_tasks(self):
class TestAwsEcsExecutor:
"""Tests the AWS ECS Executor."""

def test_execute(self, mock_airflow_key, mock_executor):
@mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor.change_state")
def test_execute(self, change_state_mock, mock_airflow_key, mock_executor):
"""Test execution from end-to-end."""
airflow_key = mock_airflow_key()

Expand All @@ -393,6 +394,9 @@ def test_execute(self, mock_airflow_key, mock_executor):
# Task is stored in active worker.
assert 1 == len(mock_executor.active_workers)
assert ARN1 in mock_executor.active_workers.task_by_key(airflow_key).task_arn
change_state_mock.assert_called_once_with(
airflow_key, TaskInstanceState.RUNNING, ARN1, remove_running=False
)

@mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
def test_success_execute_api_exception(self, mock_backoff, mock_executor):
Expand Down

0 comments on commit a74b5f0

Please sign in to comment.