Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ECS Executor: Set tasks to RUNNING state once active #39212

Merged
merged 6 commits into from May 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 19 additions & 6 deletions airflow/executors/base_executor.py
Expand Up @@ -303,19 +303,23 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
self.running.add(key)

def change_state(
    self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True
) -> None:
    """
    Change state of the task.

    Records ``(state, info)`` in the event buffer for ``key`` and, unless
    told otherwise, drops ``key`` from the set of running task instances.

    :param key: Unique key for the task instance
    :param state: State to set for the task.
    :param info: Executor information for the task instance
    :param remove_running: Whether or not to remove the TI key from running set
    """
    self.log.debug("Changing state: %s", key)
    if remove_running:
        # A missing key is not an error here; it is only logged at debug level.
        try:
            self.running.remove(key)
        except KeyError:
            self.log.debug("Could not find key: %s", key)
    self.event_buffer[key] = state, info

def fail(self, key: TaskInstanceKey, info=None) -> None:
Expand Down Expand Up @@ -345,6 +349,15 @@ def queued(self, key: TaskInstanceKey, info=None) -> None:
"""
self.change_state(key, TaskInstanceState.QUEUED, info)

def running_state(self, key: TaskInstanceKey, info=None) -> None:
    """
    Record the RUNNING state for a task instance.

    Delegates to ``change_state`` with ``remove_running=False``.

    :param key: Unique key for the task instance
    :param info: Executor information for the task instance
    """
    # Arguments are passed positionally (only ``remove_running`` by keyword);
    # keep this call shape stable for callers/tests that assert on it.
    self.change_state(key, TaskInstanceState.RUNNING, info, remove_running=False)

def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]:
"""
Return and flush the event buffer.
Expand Down
5 changes: 0 additions & 5 deletions airflow/executors/debug_executor.py
Expand Up @@ -155,8 +155,3 @@ def end(self) -> None:

def terminate(self) -> None:
    """Set the executor's ``_terminated`` event."""
    self._terminated.set()

def change_state(
    self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True
) -> None:
    """
    Record a state change for a task instance.

    Stores ``(state, info)`` in the event buffer and, when ``remove_running``
    is true, removes ``key`` from the running set.

    :param key: Unique key for the task instance
    :param state: State to set for the task.
    :param info: Executor information for the task instance
    :param remove_running: Whether or not to remove the TI key from running set.
        Accepted so this override stays call-compatible with callers that pass
        ``remove_running=...``; defaults to the previous always-remove behaviour.
    :raises KeyError: if ``remove_running`` is true and ``key`` is not in the running set
    """
    if remove_running:
        self.log.debug("Popping %s from executor task queue.", key)
        self.running.remove(key)
    self.event_buffer[key] = state, info
Expand Up @@ -400,7 +400,12 @@ def attempt_task_runs(self):
else:
task = run_task_response["tasks"][0]
self.active_workers.add_task(task, task_key, queue, cmd, exec_config, attempt_number)
self.queued(task_key, task.task_arn)
try:
self.running_state(task_key, task.task_arn)
except AttributeError:
# running_state is newly added, and only needed to support task adoption (an optional
# executor feature).
pass
if failure_reasons:
self.log.error(
"Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.",
Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/celery/executors/celery_executor.py
Expand Up @@ -368,8 +368,10 @@ def update_all_task_states(self) -> None:
if state:
self.update_task_state(key, state, info)

def change_state(
    self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True
) -> None:
    """
    Change state of the task and stop tracking it in ``self.tasks``.

    :param key: Unique key for the task instance
    :param state: State to set for the task.
    :param info: Executor information for the task instance
    :param remove_running: Forwarded to the parent ``change_state``
    """
    super().change_state(key, state, info, remove_running=remove_running)
    # NOTE(review): the entry is popped from ``self.tasks`` even when
    # ``remove_running`` is False -- confirm this is intended.
    self.tasks.pop(key, None)

def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:
Expand Down
53 changes: 52 additions & 1 deletion tests/executors/test_base_executor.py
Expand Up @@ -33,7 +33,7 @@
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.state import State, TaskInstanceState


def test_supports_sentry():
Expand Down Expand Up @@ -363,3 +363,54 @@ def test_running_retry_attempt_type(loop_duration, total_tries):
assert a.elapsed > min_seconds_for_test
assert a.total_tries == total_tries
assert a.tries_after_min == 1


def test_state_fail():
    """``fail`` empties the running set and buffers a FAILED event."""
    ti_key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
    base_executor = BaseExecutor()
    base_executor.running.add(ti_key)

    base_executor.fail(ti_key, info="info")

    assert not base_executor.running
    assert base_executor.event_buffer[ti_key] == (TaskInstanceState.FAILED, "info")


def test_state_success():
    """``success`` empties the running set and buffers a SUCCESS event."""
    ti_key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
    base_executor = BaseExecutor()
    base_executor.running.add(ti_key)

    base_executor.success(ti_key, info="info")

    assert not base_executor.running
    assert base_executor.event_buffer[ti_key] == (TaskInstanceState.SUCCESS, "info")


def test_state_queued():
    """``queued`` empties the running set and buffers a QUEUED event."""
    ti_key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
    base_executor = BaseExecutor()
    base_executor.running.add(ti_key)

    base_executor.queued(ti_key, info="info")

    assert not base_executor.running
    assert base_executor.event_buffer[ti_key] == (TaskInstanceState.QUEUED, "info")


def test_state_generic():
    """``change_state`` called directly removes the key and buffers the given state."""
    executor = BaseExecutor()
    key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
    executor.running.add(key)
    info = "info"
    # Exercise the generic entry point itself rather than the ``queued``
    # wrapper, which has its own dedicated test.
    executor.change_state(key, TaskInstanceState.QUEUED, info=info)
    assert not executor.running
    assert executor.event_buffer[key] == (TaskInstanceState.QUEUED, info)


def test_state_running():
    """``running_state`` buffers a RUNNING event and leaves the running set populated."""
    ti_key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
    base_executor = BaseExecutor()
    base_executor.running.add(ti_key)

    base_executor.running_state(ti_key, info="info")

    # Moving to RUNNING must not drop the key from the running set.
    assert base_executor.running
    assert base_executor.event_buffer[ti_key] == (TaskInstanceState.RUNNING, "info")
Expand Up @@ -367,7 +367,8 @@ def test_stopped_tasks(self):
class TestAwsEcsExecutor:
"""Tests the AWS ECS Executor."""

def test_execute(self, mock_airflow_key, mock_executor):
@mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor.change_state")
def test_execute(self, change_state_mock, mock_airflow_key, mock_executor):
"""Test execution from end-to-end."""
airflow_key = mock_airflow_key()

Expand All @@ -393,6 +394,9 @@ def test_execute(self, mock_airflow_key, mock_executor):
# Task is stored in active worker.
assert 1 == len(mock_executor.active_workers)
assert ARN1 in mock_executor.active_workers.task_by_key(airflow_key).task_arn
change_state_mock.assert_called_once_with(
airflow_key, TaskInstanceState.RUNNING, ARN1, remove_running=False
)

@mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
def test_success_execute_api_exception(self, mock_backoff, mock_executor):
Expand Down