Skip to content

Commit

Permalink
Feature: Support using content of kubeconfig to create KubernetesHook
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyuliuyin committed Apr 25, 2024
1 parent d08f893 commit e4689de
Show file tree
Hide file tree
Showing 13 changed files with 94 additions and 5 deletions.
11 changes: 7 additions & 4 deletions airflow/providers/cncf/kubernetes/hooks/kubernetes.py
Expand Up @@ -66,8 +66,8 @@ class KubernetesHook(BaseHook, PodOperatorHookProtocol):
- use in cluster configuration by using extra field ``in_cluster`` in connection
- use custom config by providing path to the file using extra field ``kube_config_path`` in connection
- use custom configuration by providing content of kubeconfig file via
extra field ``kube_config`` in connection
- use custom configuration by providing content of kubeconfig file via the ``kube_config`` parameter
or via extra field ``kube_config`` in connection
- use default config by providing no extras
This hook checks for configuration options in the above order. Once an option is present it will
Expand All @@ -84,6 +84,7 @@ class KubernetesHook(BaseHook, PodOperatorHookProtocol):
:param cluster_context: Optionally specify a context to use (e.g. if you have multiple contexts
in your kubeconfig).
:param config_file: Path to kubeconfig file.
:param kube_config: content of kubeconfig file.
:param in_cluster: Set to ``True`` if running from within a kubernetes cluster.
:param disable_verify_ssl: Set to ``True`` if SSL verification should be disabled.
:param disable_tcp_keepalive: Set to ``True`` if you want to disable keepalive logic.
Expand Down Expand Up @@ -135,6 +136,7 @@ def __init__(
client_configuration: client.Configuration | None = None,
cluster_context: str | None = None,
config_file: str | None = None,
kube_config: str | None = None,
in_cluster: bool | None = None,
disable_verify_ssl: bool | None = None,
disable_tcp_keepalive: bool | None = None,
Expand All @@ -144,6 +146,7 @@ def __init__(
self.client_configuration = client_configuration
self.cluster_context = cluster_context
self.config_file = config_file
self.kube_config = kube_config
self.in_cluster = in_cluster
self.disable_verify_ssl = disable_verify_ssl
self.disable_tcp_keepalive = disable_tcp_keepalive
Expand Down Expand Up @@ -203,7 +206,7 @@ def get_conn(self) -> client.ApiClient:
in_cluster = self._coalesce_param(self.in_cluster, self._get_field("in_cluster"))
cluster_context = self._coalesce_param(self.cluster_context, self._get_field("cluster_context"))
kubeconfig_path = self._coalesce_param(self.config_file, self._get_field("kube_config_path"))
kubeconfig = self._get_field("kube_config")
kubeconfig = self._coalesce_param(self.kube_config, self._get_field("kube_config"))
num_selected_configuration = sum(1 for o in [in_cluster, kubeconfig, kubeconfig_path] if o)

if num_selected_configuration > 1:
Expand Down Expand Up @@ -645,7 +648,7 @@ async def _load_config(self):
in_cluster = self._coalesce_param(self.in_cluster, await self._get_field("in_cluster"))
cluster_context = self._coalesce_param(self.cluster_context, await self._get_field("cluster_context"))
kubeconfig_path = self._coalesce_param(self.config_file, await self._get_field("kube_config_path"))
kubeconfig = await self._get_field("kube_config")
kubeconfig = await self._coalesce_param(self.kube_config, self._get_field("kube_config"))

num_selected_configuration = sum(1 for o in [in_cluster, kubeconfig, kubeconfig_path] if o)

Expand Down
10 changes: 10 additions & 0 deletions airflow/providers/cncf/kubernetes/operators/job.py
Expand Up @@ -130,6 +130,7 @@ def hook(self) -> KubernetesHook:
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
config_file=self.config_file,
kube_config=self.kube_config,
cluster_context=self.cluster_context,
)
return hook
Expand Down Expand Up @@ -185,6 +186,7 @@ def execute_deferrable(self):
kubernetes_conn_id=self.kubernetes_conn_id,
cluster_context=self.cluster_context,
config_file=self.config_file,
kube_config=self.kube_config,
in_cluster=self.in_cluster,
poll_interval=self.job_poll_interval,
),
Expand Down Expand Up @@ -363,6 +365,7 @@ class KubernetesDeleteJobOperator(BaseOperator):
for the Kubernetes cluster.
:param config_file: The path to the Kubernetes config file. (templated)
If not specified, default value is ``~/.kube/config``
:param kube_config: content of kubeconfig file.
:param in_cluster: run kubernetes client with in_cluster configuration.
:param cluster_context: context that points to kubernetes cluster.
Ignored when in_cluster is True. If None, current-context is used. (templated)
Expand All @@ -388,6 +391,7 @@ def __init__(
namespace: str,
kubernetes_conn_id: str | None = KubernetesHook.default_conn_name,
config_file: str | None = None,
kube_config: str | None = None,
in_cluster: bool | None = None,
cluster_context: str | None = None,
delete_on_status: str | None = None,
Expand All @@ -400,6 +404,7 @@ def __init__(
self.namespace = namespace
self.kubernetes_conn_id = kubernetes_conn_id
self.config_file = config_file
self.kube_config = kube_config
self.in_cluster = in_cluster
self.cluster_context = cluster_context
self.delete_on_status = delete_on_status
Expand All @@ -412,6 +417,7 @@ def hook(self) -> KubernetesHook:
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
config_file=self.config_file,
kube_config=self.kube_config,
cluster_context=self.cluster_context,
)

Expand Down Expand Up @@ -473,6 +479,7 @@ class KubernetesPatchJobOperator(BaseOperator):
for the Kubernetes cluster.
:param config_file: The path to the Kubernetes config file. (templated)
If not specified, default value is ``~/.kube/config``
:param kube_config: content of kubeconfig file.
:param in_cluster: run kubernetes client with in_cluster configuration.
:param cluster_context: context that points to kubernetes cluster.
Ignored when in_cluster is True. If None, current-context is used. (templated)
Expand All @@ -494,6 +501,7 @@ def __init__(
body: object,
kubernetes_conn_id: str | None = KubernetesHook.default_conn_name,
config_file: str | None = None,
kube_config: str | None = None,
in_cluster: bool | None = None,
cluster_context: str | None = None,
**kwargs,
Expand All @@ -504,6 +512,7 @@ def __init__(
self.body = body
self.kubernetes_conn_id = kubernetes_conn_id
self.config_file = config_file
self.kube_config = kube_config
self.in_cluster = in_cluster
self.cluster_context = cluster_context

Expand All @@ -513,6 +522,7 @@ def hook(self) -> KubernetesHook:
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
config_file=self.config_file,
kube_config=self.kube_config,
cluster_context=self.cluster_context,
)

Expand Down
5 changes: 5 additions & 0 deletions airflow/providers/cncf/kubernetes/operators/pod.py
Expand Up @@ -164,6 +164,7 @@ class KubernetesPodOperator(BaseOperator):
:param affinity: affinity scheduling rules for the launched pod.
:param config_file: The path to the Kubernetes config file. (templated)
If not specified, default value is ``~/.kube/config``
:param kube_config: content of kubeconfig file.
:param node_selector: A dict containing a group of scheduling rules.
:param image_pull_secrets: Any image pull secrets to be given to the pod.
If more than one secret is required, provide a
Expand Down Expand Up @@ -284,6 +285,7 @@ def __init__(
container_resources: k8s.V1ResourceRequirements | None = None,
affinity: k8s.V1Affinity | None = None,
config_file: str | None = None,
kube_config: str | None = None,
node_selector: dict | None = None,
image_pull_secrets: list[k8s.V1LocalObjectReference] | None = None,
service_account_name: str | None = None,
Expand Down Expand Up @@ -359,6 +361,7 @@ def __init__(
self.affinity = convert_affinity(affinity) if affinity else {}
self.container_resources = container_resources
self.config_file = config_file
self.kube_config = kube_config
self.image_pull_secrets = convert_image_pull_secrets(image_pull_secrets) if image_pull_secrets else []
self.service_account_name = service_account_name
self.hostnetwork = hostnetwork
Expand Down Expand Up @@ -510,6 +513,7 @@ def hook(self) -> PodOperatorHookProtocol:
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
config_file=self.config_file,
kube_config=self.kube_config,
cluster_context=self.cluster_context,
)
return hook
Expand Down Expand Up @@ -676,6 +680,7 @@ def invoke_defer_method(self, last_log_time: DateTime | None = None):
kubernetes_conn_id=self.kubernetes_conn_id,
cluster_context=self.cluster_context,
config_file=self.config_file,
kube_config=self.kube_config,
in_cluster=self.in_cluster,
poll_interval=self.poll_interval,
get_logs=self.get_logs,
Expand Down
6 changes: 5 additions & 1 deletion airflow/providers/cncf/kubernetes/operators/resource.py
Expand Up @@ -66,6 +66,7 @@ def __init__(
custom_resource_definition: bool = False,
namespaced: bool = True,
config_file: str | None = None,
kube_config: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -76,6 +77,7 @@ def __init__(
self.custom_resource_definition = custom_resource_definition
self.namespaced = namespaced
self.config_file = config_file
self.kube_config = kube_config

if not any([self.yaml_conf, self.yaml_conf_file]):
raise AirflowException("One of `yaml_conf` or `yaml_conf_file` arguments must be provided")
Expand All @@ -90,7 +92,9 @@ def custom_object_client(self) -> CustomObjectsApi:

@cached_property
def hook(self) -> KubernetesHook:
hook = KubernetesHook(conn_id=self.kubernetes_conn_id, config_file=self.config_file)
hook = KubernetesHook(
conn_id=self.kubernetes_conn_id, config_file=self.config_file, kube_config=self.kube_config
)
return hook

def get_namespace(self) -> str:
Expand Down
Expand Up @@ -265,6 +265,7 @@ def hook(self) -> KubernetesHook:
in_cluster=self.in_cluster or self.template_body.get("kubernetes", {}).get("in_cluster", False),
config_file=self.config_file
or self.template_body.get("kubernetes", {}).get("kube_config_file", None),
kube_config=self.kube_config,
cluster_context=self.cluster_context
or self.template_body.get("kubernetes", {}).get("cluster_context", None),
)
Expand Down
5 changes: 5 additions & 0 deletions airflow/providers/cncf/kubernetes/triggers/job.py
Expand Up @@ -36,6 +36,7 @@ class KubernetesJobTrigger(BaseTrigger):
for the Kubernetes cluster.
:param cluster_context: Context that points to kubernetes cluster.
:param config_file: Path to kubeconfig file.
:param kube_config: content of kubeconfig file.
:param poll_interval: Polling period in seconds to check for the status.
:param in_cluster: run kubernetes client with in_cluster configuration.
"""
Expand All @@ -48,6 +49,7 @@ def __init__(
poll_interval: float = 10.0,
cluster_context: str | None = None,
config_file: str | None = None,
kube_config: str | None = None,
in_cluster: bool | None = None,
):
super().__init__()
Expand All @@ -57,6 +59,7 @@ def __init__(
self.poll_interval = poll_interval
self.cluster_context = cluster_context
self.config_file = config_file
self.kube_config = kube_config
self.in_cluster = in_cluster

def serialize(self) -> tuple[str, dict[str, Any]]:
Expand All @@ -70,6 +73,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"poll_interval": self.poll_interval,
"cluster_context": self.cluster_context,
"config_file": self.config_file,
"kube_config": self.kube_config,
"in_cluster": self.in_cluster,
},
)
Expand Down Expand Up @@ -97,5 +101,6 @@ def hook(self) -> AsyncKubernetesHook:
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
config_file=self.config_file,
kube_config=self.kube_config,
cluster_context=self.cluster_context,
)
5 changes: 5 additions & 0 deletions airflow/providers/cncf/kubernetes/triggers/pod.py
Expand Up @@ -62,6 +62,7 @@ class KubernetesPodTrigger(BaseTrigger):
for the Kubernetes cluster.
:param cluster_context: Context that points to kubernetes cluster.
:param config_file: Path to kubeconfig file.
:param kube_config: content of kubeconfig file.
:param poll_interval: Polling period in seconds to check for the status.
:param trigger_start_time: time in Datetime format when the trigger was started
:param in_cluster: run kubernetes client with in_cluster configuration.
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
poll_interval: float = 2,
cluster_context: str | None = None,
config_file: str | None = None,
kube_config: str | None = None,
in_cluster: bool | None = None,
get_logs: bool = True,
startup_timeout: int = 120,
Expand All @@ -108,6 +110,7 @@ def __init__(
self.poll_interval = poll_interval
self.cluster_context = cluster_context
self.config_file = config_file
self.kube_config = kube_config
self.in_cluster = in_cluster
self.get_logs = get_logs
self.startup_timeout = startup_timeout
Expand Down Expand Up @@ -143,6 +146,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"poll_interval": self.poll_interval,
"cluster_context": self.cluster_context,
"config_file": self.config_file,
"kube_config": self.kube_config,
"in_cluster": self.in_cluster,
"get_logs": self.get_logs,
"startup_timeout": self.startup_timeout,
Expand Down Expand Up @@ -281,6 +285,7 @@ def _get_async_hook(self) -> AsyncKubernetesHook:
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
config_file=self.config_file,
kube_config=self.kube_config,
cluster_context=self.cluster_context,
)

Expand Down
2 changes: 2 additions & 0 deletions tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
Expand Up @@ -92,6 +92,7 @@ def f():
mock_hook.assert_called_once_with(
conn_id="kubernetes_default",
in_cluster=False,
kube_config=None,
cluster_context="default",
config_file="/tmp/fake_file",
)
Expand Down Expand Up @@ -142,6 +143,7 @@ def f(arg1, arg2, kwarg1=None, kwarg2=None):
in_cluster=False,
cluster_context="default",
config_file="/tmp/fake_file",
kube_config=None,
)
assert mock_create_pod.call_count == 1
assert mock_hook.return_value.get_xcom_sidecar_container_image.call_count == 1
Expand Down
49 changes: 49 additions & 0 deletions tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
Expand Up @@ -279,6 +279,55 @@ def test_kube_config_path(
mock_kube_config_loader.assert_called_once()
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

@pytest.mark.parametrize(
    "conn_id, kube_config, has_conn_id, has_kube_config",
    (
        (None, None, False, False),
        (None, "content of kubeconfig file", False, True),
        ("kube_config", None, True, False),
        ("kube_config", "content of kubeconfig file", True, True),
    ),
)
@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
@patch.object(tempfile, "NamedTemporaryFile")
def test_kube_config(
    self,
    mock_tempfile,
    mock_kube_config_merger,
    mock_kube_config_loader,
    conn_id,
    kube_config,
    has_conn_id,
    has_kube_config,
):
    """
    Verify that a temporary kubeconfig file is created when kubeconfig content is supplied.

    The content may come from the ``kube_config`` hook argument, from the connection
    identified by ``conn_id``, or from both. When neither is given, no temporary file
    must be created and the default ``KUBE_CONFIG_PATH`` is loaded instead.

    :param conn_id: connection id passed to the hook, or ``None``.
    :param kube_config: kubeconfig content passed to the hook, or ``None``.
    :param has_conn_id: whether a connection id is supplied in this case.
    :param has_kube_config: whether kubeconfig content is supplied in this case.
    """
    # NOTE(review): assumes the ``kube_config`` test connection carries a ``kube_config``
    # extra, so supplying only the conn_id still goes through the temp-file path — confirm
    # against the connection fixtures of this test class.
    mock_tempfile.return_value.__enter__.return_value.name = "fake-temp-file"
    mock_kube_config_merger.return_value.config = {"fake_config": "value"}

    kubernetes_hook = KubernetesHook(conn_id=conn_id, kube_config=kube_config)
    api_conn = kubernetes_hook.get_conn()

    if has_conn_id or has_kube_config:
        # ``assert_called_once`` is a real Mock assertion; ``is_called_once`` is just an
        # auto-created child mock whose call always succeeds and verifies nothing.
        mock_tempfile.assert_called_once()
        mock_kube_config_merger.assert_called_once_with("fake-temp-file")
    else:
        mock_tempfile.assert_not_called()
        mock_kube_config_merger.assert_called_once_with(KUBE_CONFIG_PATH)

    # The loader is expected to be invoked exactly once in every case.
    mock_kube_config_loader.assert_called_once()
    assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

@pytest.mark.parametrize(
"conn_id, has_config",
(
Expand Down
1 change: 1 addition & 0 deletions tests/providers/cncf/kubernetes/operators/test_job.py
Expand Up @@ -587,6 +587,7 @@ def test_execute_deferrable(self, mock_trigger, mock_execute_deferrable):
cluster_context=mock_cluster_context,
config_file=mock_config_file,
in_cluster=mock_in_cluster,
kube_config=None,
poll_interval=POLL_INTERVAL,
)
assert actual_result is None
Expand Down
1 change: 1 addition & 0 deletions tests/providers/cncf/kubernetes/operators/test_pod.py
Expand Up @@ -232,6 +232,7 @@ def test_config_path(self, hook_mock):
conn_id="kubernetes_default",
config_file=file_path,
in_cluster=None,
kube_config=None
)

@pytest.mark.parametrize(
Expand Down
2 changes: 2 additions & 0 deletions tests/providers/cncf/kubernetes/triggers/test_job.py
Expand Up @@ -61,6 +61,7 @@ def test_serialize(self, trigger):
"poll_interval": POLL_INTERVAL,
"cluster_context": CLUSTER_CONTEXT,
"config_file": CONFIG_FILE,
"kube_config": None,
"in_cluster": IN_CLUSTER,
}

Expand Down Expand Up @@ -130,6 +131,7 @@ def test_hook(self, mock_hook, trigger):
conn_id=CONN_ID,
in_cluster=IN_CLUSTER,
config_file=CONFIG_FILE,
kube_config=None,
cluster_context=CLUSTER_CONTEXT,
)
assert hook_actual == hook_expected
1 change: 1 addition & 0 deletions tests/providers/cncf/kubernetes/triggers/test_pod.py
Expand Up @@ -102,6 +102,7 @@ def test_serialize(self, trigger):
"poll_interval": POLL_INTERVAL,
"cluster_context": CLUSTER_CONTEXT,
"config_file": CONFIG_FILE,
"kube_config": None,
"in_cluster": IN_CLUSTER,
"get_logs": GET_LOGS,
"startup_timeout": STARTUP_TIMEOUT_SECS,
Expand Down

0 comments on commit e4689de

Please sign in to comment.