Skip to content

Commit

Permalink
Bugfix yaml parsing for GKEStartKueueInsideClusterOperator (#39234)
Browse files Browse the repository at this point in the history
* Bugfix yaml parsing for GKEStartKueueInsideClusterOperator

* Unit tests
  • Loading branch information
moiseenkov committed May 7, 2024
1 parent 06b3b02 commit 287c107
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 12 deletions.
14 changes: 3 additions & 11 deletions airflow/providers/google/cloud/operators/kubernetes_engine.py
Expand Up @@ -19,7 +19,6 @@

from __future__ import annotations

import re
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
Expand Down Expand Up @@ -566,17 +565,10 @@ def pod_hook(self) -> GKEPodHook:
def _get_yaml_content_from_file(kueue_yaml_url) -> list[dict]:
"""Download content of YAML file and separate it into several dictionaries."""
response = requests.get(kueue_yaml_url, allow_redirects=True)
yaml_dicts = []
if response.status_code == 200:
yaml_data = response.text
documents = re.split(r"---\n", yaml_data)

for document in documents:
document_dict = yaml.safe_load(document)
yaml_dicts.append(document_dict)
else:
if response.status_code != 200:
raise AirflowException("Was not able to read the yaml file from given URL")
return yaml_dicts

return list(yaml.safe_load_all(response.text))

def execute(self, context: Context):
self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails(
Expand Down
22 changes: 22 additions & 0 deletions tests/providers/google/cloud/operators/test_kubernetes_engine.py
Expand Up @@ -129,6 +129,7 @@
requests:
storage: 5Gi
"""
KUEUE_YAML_URL = "http://test-url/config.yaml"


class TestGoogleCloudPlatformContainerOperator:
Expand Down Expand Up @@ -641,6 +642,27 @@ def test_gcp_conn_id(self, mock_get_credentials):

assert hook.gcp_conn_id == "test_conn"

@mock.patch(f"{GKE_HOOK_MODULE_PATH}.requests")
@mock.patch(f"{GKE_HOOK_MODULE_PATH}.yaml")
def test_get_yaml_content_from_file(self, mock_yaml, mock_requests):
yaml_content_expected = [mock.MagicMock(), mock.MagicMock()]
mock_yaml.safe_load_all.return_value = yaml_content_expected
response_text_expected = "response test expected"
mock_requests.get.return_value = mock.MagicMock(status_code=200, text=response_text_expected)

yaml_content_actual = GKEStartKueueInsideClusterOperator._get_yaml_content_from_file(KUEUE_YAML_URL)

assert yaml_content_actual == yaml_content_expected
mock_requests.get.assert_called_once_with(KUEUE_YAML_URL, allow_redirects=True)
mock_yaml.safe_load_all.assert_called_once_with(response_text_expected)

@mock.patch(f"{GKE_HOOK_MODULE_PATH}.requests")
def test_get_yaml_content_from_file_exception(self, mock_requests):
mock_requests.get.return_value = mock.MagicMock(status_code=400)

with pytest.raises(AirflowException):
GKEStartKueueInsideClusterOperator._get_yaml_content_from_file(KUEUE_YAML_URL)


class TestGKEPodOperatorAsync:
def setup_method(self):
Expand Down
Expand Up @@ -103,7 +103,7 @@
project_id=GCP_PROJECT_ID,
location=GCP_LOCATION,
cluster_name=CLUSTER_NAME,
kueue_version="v0.5.1",
kueue_version="v0.6.2",
)
# [END howto_operator_gke_install_kueue]

Expand Down

0 comments on commit 287c107

Please sign in to comment.