Skip to content

Commit

Permalink
migrate to dbt v3 api for project endpoints (#39214)
Browse files Browse the repository at this point in the history
* feat(providers/dbt): migrate to v3 api for /, {account_id}/projects/, and {account_id}/projects/{project_id}/
* refactor(dbt): make all arguments in _run_and_get_response keyword only argument
  • Loading branch information
Lee-W committed Apr 30, 2024
1 parent 778e8c5 commit d4bdffc
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 32 deletions.
25 changes: 14 additions & 11 deletions airflow/providers/dbt/cloud/hooks/dbt.py
Expand Up @@ -165,7 +165,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:

class DbtCloudHook(HttpHook):
"""
Interact with dbt Cloud using the V2 API.
Interact with dbt Cloud using the V2 (V3 if supported) API.
:param dbt_cloud_conn_id: The ID of the :ref:`dbt Cloud connection <howto/connection:dbt-cloud>`.
"""
Expand Down Expand Up @@ -194,7 +194,7 @@ def _get_tenant_domain(conn: Connection) -> str:

@staticmethod
def get_request_url_params(
tenant: str, endpoint: str, include_related: list[str] | None = None
tenant: str, endpoint: str, include_related: list[str] | None = None, *, api_version: str = "v2"
) -> tuple[str, dict[str, Any]]:
"""
Form URL from base url and endpoint url.
Expand All @@ -207,7 +207,7 @@ def get_request_url_params(
data: dict[str, Any] = {}
if include_related:
data = {"include_related": include_related}
url = f"https://{tenant}/api/v2/accounts/{endpoint or ''}"
url = f"https://{tenant}/api/{api_version}/accounts/{endpoint or ''}"
return url, data

async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str]:
Expand Down Expand Up @@ -270,7 +270,7 @@ def connection(self) -> Connection:

def get_conn(self, *args, **kwargs) -> Session:
tenant = self._get_tenant_domain(self.connection)
self.base_url = f"https://{tenant}/api/v2/accounts/"
self.base_url = f"https://{tenant}/"

session = Session()
session.auth = self.auth_type(self.connection.password)
Expand Down Expand Up @@ -298,23 +298,26 @@ def _paginate(self, endpoint: str, payload: dict[str, Any] | None = None) -> lis

def _run_and_get_response(
self,
*,
method: str = "GET",
endpoint: str | None = None,
payload: str | dict[str, Any] | None = None,
paginate: bool = False,
api_version: str = "v2",
) -> Any:
self.method = method
full_endpoint = f"api/{api_version}/accounts/{endpoint}" if endpoint else None

if paginate:
if isinstance(payload, str):
raise ValueError("Payload cannot be a string to paginate a response.")

if endpoint:
return self._paginate(endpoint=endpoint, payload=payload)
else:
raise ValueError("An endpoint is needed to paginate a response.")
if full_endpoint:
return self._paginate(endpoint=full_endpoint, payload=payload)

return self.run(endpoint=endpoint, data=payload)
raise ValueError("An endpoint is needed to paginate a response.")

return self.run(endpoint=full_endpoint, data=payload)

def list_accounts(self) -> list[Response]:
"""
Expand Down Expand Up @@ -342,7 +345,7 @@ def list_projects(self, account_id: int | None = None) -> list[Response]:
:param account_id: Optional. The ID of a dbt Cloud account.
:return: List of request responses.
"""
return self._run_and_get_response(endpoint=f"{account_id}/projects/", paginate=True)
return self._run_and_get_response(endpoint=f"{account_id}/projects/", paginate=True, api_version="v3")

@fallback_to_default_account
def get_project(self, project_id: int, account_id: int | None = None) -> Response:
Expand All @@ -353,7 +356,7 @@ def get_project(self, project_id: int, account_id: int | None = None) -> Respons
:param account_id: Optional. The ID of a dbt Cloud account.
:return: The request response.
"""
return self._run_and_get_response(endpoint=f"{account_id}/projects/{project_id}/")
return self._run_and_get_response(endpoint=f"{account_id}/projects/{project_id}/", api_version="v3")

@fallback_to_default_account
def list_jobs(
Expand Down
51 changes: 30 additions & 21 deletions tests/providers/dbt/cloud/hooks/test_dbt.py
Expand Up @@ -46,8 +46,8 @@
JOB_ID = 4444
RUN_ID = 5555

BASE_URL = "https://cloud.getdbt.com/api/v2/accounts/"
SINGLE_TENANT_URL = "https://single.tenant.getdbt.com/api/v2/accounts/"
BASE_URL = "https://cloud.getdbt.com/"
SINGLE_TENANT_URL = "https://single.tenant.getdbt.com/"


class TestDbtCloudJobRunStatus:
Expand Down Expand Up @@ -211,7 +211,7 @@ def test_get_account(self, mock_http_run, mock_paginate, conn_id, account_id):
assert hook.method == "GET"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"{_account_id}/", data=None)
hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/", data=None)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
Expand All @@ -229,7 +229,9 @@ def test_list_projects(self, mock_http_run, mock_paginate, conn_id, account_id):

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_not_called()
hook._paginate.assert_called_once_with(endpoint=f"{_account_id}/projects/", payload=None)
hook._paginate.assert_called_once_with(
endpoint=f"api/v3/accounts/{_account_id}/projects/", payload=None
)

@pytest.mark.parametrize(
argnames="conn_id, account_id",
Expand All @@ -245,7 +247,9 @@ def test_get_project(self, mock_http_run, mock_paginate, conn_id, account_id):
assert hook.method == "GET"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"{_account_id}/projects/{PROJECT_ID}/", data=None)
hook.run.assert_called_once_with(
endpoint=f"api/v3/accounts/{_account_id}/projects/{PROJECT_ID}/", data=None
)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
Expand All @@ -263,7 +267,7 @@ def test_list_jobs(self, mock_http_run, mock_paginate, conn_id, account_id):

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook._paginate.assert_called_once_with(
endpoint=f"{_account_id}/jobs/", payload={"order_by": None, "project_id": None}
endpoint=f"api/v2/accounts/{_account_id}/jobs/", payload={"order_by": None, "project_id": None}
)
hook.run.assert_not_called()

Expand All @@ -282,7 +286,8 @@ def test_list_jobs_with_payload(self, mock_http_run, mock_paginate, conn_id, acc

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook._paginate.assert_called_once_with(
endpoint=f"{_account_id}/jobs/", payload={"order_by": "-id", "project_id": PROJECT_ID}
endpoint=f"api/v2/accounts/{_account_id}/jobs/",
payload={"order_by": "-id", "project_id": PROJECT_ID},
)
hook.run.assert_not_called()

Expand All @@ -300,7 +305,7 @@ def test_get_job(self, mock_http_run, mock_paginate, conn_id, account_id):
assert hook.method == "GET"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"{_account_id}/jobs/{JOB_ID}", data=None)
hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}", data=None)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
Expand All @@ -319,7 +324,7 @@ def test_trigger_job_run(self, mock_http_run, mock_paginate, conn_id, account_id

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/jobs/{JOB_ID}/run/",
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/",
data=json.dumps({"cause": cause, "steps_override": None, "schema_override": None}),
)
hook._paginate.assert_not_called()
Expand Down Expand Up @@ -348,7 +353,7 @@ def test_trigger_job_run_with_overrides(self, mock_http_run, mock_paginate, conn

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/jobs/{JOB_ID}/run/",
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/",
data=json.dumps(
{"cause": cause, "steps_override": steps_override, "schema_override": schema_override}
),
Expand Down Expand Up @@ -376,7 +381,7 @@ def test_trigger_job_run_with_additional_run_configs(

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/jobs/{JOB_ID}/run/",
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/",
data=json.dumps(
{
"cause": cause,
Expand Down Expand Up @@ -405,7 +410,7 @@ def test_list_job_runs(self, mock_http_run, mock_paginate, conn_id, account_id):
_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_not_called()
hook._paginate.assert_called_once_with(
endpoint=f"{_account_id}/runs/",
endpoint=f"api/v2/accounts/{_account_id}/runs/",
payload={
"include_related": None,
"job_definition_id": None,
Expand All @@ -431,7 +436,7 @@ def test_list_job_runs_with_payload(self, mock_http_run, mock_paginate, conn_id,
_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_not_called()
hook._paginate.assert_called_once_with(
endpoint=f"{_account_id}/runs/",
endpoint=f"api/v2/accounts/{_account_id}/runs/",
payload={
"include_related": ["job"],
"job_definition_id": JOB_ID,
Expand All @@ -452,7 +457,7 @@ def test_get_job_runs(self, mock_http_run, conn_id, account_id):
assert hook.method == "GET"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"{_account_id}/runs/", data=None)
hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/runs/", data=None)

@pytest.mark.parametrize(
argnames="conn_id, account_id",
Expand All @@ -469,7 +474,7 @@ def test_get_job_run(self, mock_http_run, mock_paginate, conn_id, account_id):

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/runs/{RUN_ID}/", data={"include_related": None}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/", data={"include_related": None}
)
hook._paginate.assert_not_called()

Expand All @@ -488,7 +493,7 @@ def test_get_job_run_with_payload(self, mock_http_run, mock_paginate, conn_id, a

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/runs/{RUN_ID}/", data={"include_related": ["triggers"]}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/", data={"include_related": ["triggers"]}
)
hook._paginate.assert_not_called()

Expand Down Expand Up @@ -543,7 +548,9 @@ def test_cancel_job_run(self, mock_http_run, mock_paginate, conn_id, account_id)
assert hook.method == "POST"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"{_account_id}/runs/{RUN_ID}/cancel/", data=None)
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/cancel/", data=None
)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
Expand All @@ -561,7 +568,7 @@ def test_list_job_run_artifacts(self, mock_http_run, mock_paginate, conn_id, acc

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": None}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": None}
)
hook._paginate.assert_not_called()

Expand All @@ -579,7 +586,9 @@ def test_list_job_run_artifacts_with_payload(self, mock_http_run, mock_paginate,
assert hook.method == "GET"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": 2})
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": 2}
)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
Expand All @@ -598,7 +607,7 @@ def test_get_job_run_artifact(self, mock_http_run, mock_paginate, conn_id, accou

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": None}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": None}
)
hook._paginate.assert_not_called()

Expand All @@ -618,7 +627,7 @@ def test_get_job_run_artifact_with_payload(self, mock_http_run, mock_paginate, c

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": 2}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": 2}
)
hook._paginate.assert_not_called()

Expand Down

0 comments on commit d4bdffc

Please sign in to comment.