Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add deferrable support to DatabricksNotebookOperator #39295

Merged
merged 28 commits into from May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
95d3a77
add deferrable support to DatabricksNotebookOperator
rawwar Apr 28, 2024
0d5cd7c
Merge branch 'main' of https://github.com/apache/airflow into kalyan/…
rawwar May 3, 2024
e7db2d3
Merge branch 'main' of https://github.com/apache/airflow into kalyan/…
rawwar May 4, 2024
054cf64
Merge branch 'main' of https://github.com/apache/airflow into kalyan/…
rawwar May 5, 2024
fb78207
Merge branch 'main' into kalyan/db-notebook-deferrable-support
rawwar May 5, 2024
59dc579
refactor defer call
rawwar May 6, 2024
521d921
Merge branch 'main' into kalyan/db-notebook-deferrable-support
rawwar May 6, 2024
e646a06
update caller
rawwar May 6, 2024
798bc22
update caller
rawwar May 6, 2024
e27966f
Update airflow/providers/databricks/operators/databricks.py
rawwar May 6, 2024
5c88c49
fix issue with repair_run check
rawwar May 6, 2024
cfdbcaf
Merge branch 'main' into kalyan/db-notebook-deferrable-support
rawwar May 6, 2024
a7ee133
update logs for failed state
rawwar May 6, 2024
1618ab7
Merge branch 'main' into kalyan/db-notebook-deferrable-support
rawwar May 6, 2024
dfecbf0
rewrite execute_complete
rawwar May 7, 2024
fcd550b
Merge branch 'main' of https://github.com/apache/airflow into kalyan/…
rawwar May 11, 2024
59a0f47
add test
rawwar May 11, 2024
1cb5b69
add test for termination before defer
rawwar May 11, 2024
52793ee
call execute in tests
rawwar May 11, 2024
dbf25f5
add run id in tests
rawwar May 11, 2024
5198c44
update exception message when job not successful
rawwar May 11, 2024
3a5494b
update error message in tests
rawwar May 11, 2024
46204df
fix tests
rawwar May 12, 2024
fd4041e
refactor tests
rawwar May 12, 2024
7c2dc0f
Merge branch 'main' of https://github.com/apache/airflow into kalyan/…
rawwar May 12, 2024
0698252
update execute_complete
rawwar May 12, 2024
5e8414e
Merge branch 'main' of https://github.com/apache/airflow into kalyan/…
rawwar May 13, 2024
e853173
assert Trigger type in deferrable test
rawwar May 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow/providers/databricks/hooks/databricks_base.py
Expand Up @@ -80,6 +80,7 @@ class BaseDatabricksHook(BaseHook):
:param retry_delay: The number of seconds to wait between retries (it
might be a floating point number).
:param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
:param caller: The name of the operator that is calling the hook.
"""

conn_name_attr: str = "databricks_conn_id"
Expand Down
40 changes: 34 additions & 6 deletions airflow/providers/databricks/operators/databricks.py
Expand Up @@ -167,7 +167,7 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)

error_message = f"Job run failed with terminal state: {run_state} and with the errors {errors}"

if event["repair_run"]:
if event.get("repair_run"):
log.warning(
"%s but since repair run is set, repairing the run with all failed tasks",
error_message,
Expand Down Expand Up @@ -923,9 +923,11 @@ class DatabricksNotebookOperator(BaseOperator):
:param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
:param databricks_conn_id: The name of the Airflow connection to use.
:param deferrable: Run operator in the deferrable mode.
"""

template_fields = ("notebook_params",)
CALLER = "DatabricksNotebookOperator"
rawwar marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
Expand All @@ -942,6 +944,7 @@ def __init__(
databricks_retry_args: dict[Any, Any] | None = None,
wait_for_termination: bool = True,
databricks_conn_id: str = "databricks_default",
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs: Any,
):
self.notebook_path = notebook_path
Expand All @@ -958,19 +961,20 @@ def __init__(
self.wait_for_termination = wait_for_termination
self.databricks_conn_id = databricks_conn_id
self.databricks_run_id: int | None = None
self.deferrable = deferrable
super().__init__(**kwargs)

@cached_property
def _hook(self) -> DatabricksHook:
return self._get_hook(caller="DatabricksNotebookOperator")
return self._get_hook(caller=self.CALLER)

def _get_hook(self, caller: str) -> DatabricksHook:
return DatabricksHook(
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
caller=caller,
caller=self.CALLER,
)

def _get_task_timeout_seconds(self) -> int:
Expand Down Expand Up @@ -1041,6 +1045,19 @@ def monitor_databricks_job(self) -> None:
run = self._hook.get_run(self.databricks_run_id)
run_state = RunState(**run["state"])
self.log.info("Current state of the job: %s", run_state.life_cycle_state)
if self.deferrable and not run_state.is_terminal:
return self.defer(
trigger=DatabricksExecutionTrigger(
run_id=self.databricks_run_id,
databricks_conn_id=self.databricks_conn_id,
polling_period_seconds=self.polling_period_seconds,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
caller=self.CALLER,
),
method_name=DEFER_METHOD_NAME,
)
while not run_state.is_terminal:
time.sleep(self.polling_period_seconds)
run = self._hook.get_run(self.databricks_run_id)
Expand All @@ -1056,13 +1073,24 @@ def monitor_databricks_job(self) -> None:
)
if not run_state.is_successful:
raise AirflowException(
"Task failed. Final state %s. Reason: %s",
run_state.result_state,
run_state.state_message,
f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
pankajkoti marked this conversation as resolved.
Show resolved Hide resolved
)
self.log.info("Task succeeded. Final state %s.", run_state.result_state)

def execute(self, context: Context) -> None:
    """Launch the notebook job; unless fire-and-forget, track it to completion."""
    self.launch_notebook_job()
    if not self.wait_for_termination:
        # Fire-and-forget: caller did not ask us to wait on the run.
        return
    self.monitor_databricks_job()

def execute_complete(self, context: dict | None, event: dict) -> None:
    """Resume point after deferral: validate the trigger-reported run state.

    Raises AirflowException when the run did not terminate normally or
    terminated without success; logs success otherwise.
    """
    state = RunState.from_json(event["run_state"])
    life_cycle = state.life_cycle_state
    if life_cycle != "TERMINATED":
        # Run ended in an abnormal lifecycle state (e.g. INTERNAL_ERROR).
        message = f"Databricks job failed with state {life_cycle}. " f"Message: {state.state_message}"
        raise AirflowException(message)
    if not state.is_successful:
        raise AirflowException(
            f"Task failed. Final state {state.result_state}. Reason: {state.state_message}"
        )
    self.log.info("Task succeeded. Final state %s.", state.result_state)
2 changes: 2 additions & 0 deletions airflow/providers/databricks/triggers/databricks.py
Expand Up @@ -48,6 +48,7 @@ def __init__(
retry_args: dict[Any, Any] | None = None,
run_page_url: str | None = None,
repair_run: bool = False,
caller: str = "DatabricksExecutionTrigger",
) -> None:
super().__init__()
self.run_id = run_id
Expand All @@ -63,6 +64,7 @@ def __init__(
retry_limit=self.retry_limit,
retry_delay=self.retry_delay,
retry_args=retry_args,
caller=caller,
)

def serialize(self) -> tuple[str, dict[str, Any]]:
Expand Down
48 changes: 46 additions & 2 deletions tests/providers/databricks/operators/test_databricks.py
Expand Up @@ -1865,6 +1865,50 @@ def test_execute_without_wait_for_termination(self):
operator.launch_notebook_job.assert_called_once()
operator.monitor_databricks_job.assert_not_called()

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_execute_with_deferrable(self, mock_databricks_hook):
    """A non-terminal run in deferrable mode must defer with a DatabricksExecutionTrigger."""
    hook = mock_databricks_hook.return_value
    hook.get_run.return_value = {"state": {"life_cycle_state": "PENDING"}}
    operator = DatabricksNotebookOperator(
        task_id="test_task",
        notebook_path="test_path",
        source="test_source",
        databricks_conn_id="test_conn_id",
        wait_for_termination=True,
        deferrable=True,
    )
    operator.databricks_run_id = 12345

    with pytest.raises(TaskDeferred) as deferred:
        operator.monitor_databricks_job()

    trigger = deferred.value.trigger
    assert isinstance(trigger, DatabricksExecutionTrigger), "Trigger is not a DatabricksExecutionTrigger"
    assert deferred.value.method_name == "execute_complete"

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_execute_with_deferrable_early_termination(self, mock_databricks_hook):
    """A run already terminal (and failed) before deferring must raise, not defer."""
    terminal_state = {
        "life_cycle_state": "TERMINATED",
        "result_state": "FAILED",
        "state_message": "FAILURE",
    }
    mock_databricks_hook.return_value.get_run.return_value = {"state": terminal_state}
    operator = DatabricksNotebookOperator(
        task_id="test_task",
        notebook_path="test_path",
        source="test_source",
        databricks_conn_id="test_conn_id",
        wait_for_termination=True,
        deferrable=True,
    )
    operator.databricks_run_id = 12345

    with pytest.raises(AirflowException) as err:
        operator.monitor_databricks_job()

    assert str(err.value) == "Task failed. Final state FAILED. Reason: FAILURE"

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_monitor_databricks_job_successful_raises_no_exception(self, mock_databricks_hook):
mock_databricks_hook.return_value.get_run.return_value = {
Expand Down Expand Up @@ -1896,10 +1940,10 @@ def test_monitor_databricks_job_failed(self, mock_databricks_hook):

operator.databricks_run_id = 12345

exception_message = "'Task failed. Final state %s. Reason: %s', 'FAILED', 'FAILURE'"
with pytest.raises(AirflowException) as exc_info:
operator.monitor_databricks_job()
assert exception_message in str(exc_info.value)
exception_message = "Task failed. Final state FAILED. Reason: FAILURE"
assert exception_message == str(exc_info.value)

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_launch_notebook_job(self, mock_databricks_hook):
Expand Down