Skip to content

Commit

Permalink
Fix trigger kwarg encryption migration (#39246)
Browse files Browse the repository at this point in the history
Do the encryption in the migration itself, and fix support for offline
migrations as well.

The offline up migration won't actually encrypt the trigger kwargs as there
isn't a safe way to accomplish that, so the decryption processes checks
and short circuits if it isn't encrypted.

The offline down migration will now print out a warning that the offline
migration will fail if there are any running triggers. I think this is
the best we can do for that scenario (and folks willing to do offline
migrations will hopefully be able to understand the situation).

This also solves the "encrypting the already encrypted kwargs" bug in
2.9.0.
  • Loading branch information
jedcunningham committed Apr 25, 2024
1 parent a8afa2e commit adeb7f7
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 52 deletions.
Expand Up @@ -16,18 +16,22 @@
# specific language governing permissions and limitations
# under the License.

"""update trigger kwargs type
"""update trigger kwargs type and encrypt
Revision ID: 1949afb29106
Revises: ee1467d4aa35
Create Date: 2024-03-17 22:09:09.406395
"""
import json
from textwrap import dedent

from alembic import context, op
import sqlalchemy as sa
from sqlalchemy.orm import lazyload

from airflow.serialization.serialized_objects import BaseSerialization
from airflow.models.trigger import Trigger
from alembic import op

from airflow.utils.sqlalchemy import ExtendedJSON

# revision identifiers, used by Alembic.
Expand All @@ -38,13 +42,43 @@
airflow_version = "2.9.0"


def get_session() -> sa.orm.Session:
conn = op.get_bind()
sessionmaker = sa.orm.sessionmaker()
return sessionmaker(bind=conn)

def upgrade():
"""Update trigger kwargs type to string"""
"""Update trigger kwargs type to string and encrypt"""
with op.batch_alter_table("trigger") as batch_op:
batch_op.alter_column("kwargs", type_=sa.Text(), )

if not context.is_offline_mode():
session = get_session()
try:
for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
trigger.kwargs = trigger.kwargs
session.commit()
finally:
session.close()


def downgrade():
"""Unapply update trigger kwargs type to string"""
"""Unapply update trigger kwargs type to string and encrypt"""
if context.is_offline_mode():
print(dedent("""
------------
-- WARNING: Unable to decrypt trigger kwargs automatically in offline mode!
-- If any trigger rows exist when you do an offline downgrade, the migration will fail.
------------
"""))
else:
session = get_session()
try:
for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
session.commit()
finally:
session.close()

with op.batch_alter_table("trigger") as batch_op:
batch_op.alter_column("kwargs", type_=ExtendedJSON(), postgresql_using="kwargs::json")
batch_op.alter_column("kwargs", type_=ExtendedJSON(), postgresql_using='kwargs::json')
10 changes: 9 additions & 1 deletion airflow/models/trigger.py
Expand Up @@ -116,7 +116,15 @@ def _decrypt_kwargs(encrypted_kwargs: str) -> dict[str, Any]:
from airflow.models.crypto import get_fernet
from airflow.serialization.serialized_objects import BaseSerialization

decrypted_kwargs = json.loads(get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8"))
# We weren't able to encrypt the kwargs in all migration paths,
# so we need to handle the case where they are not encrypted.
# Triggers aren't long lasting, so we can skip encrypting them now.
if encrypted_kwargs.startswith("{"):
decrypted_kwargs = json.loads(encrypted_kwargs)
else:
decrypted_kwargs = json.loads(
get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8")
)

return BaseSerialization.deserialize(decrypted_kwargs)

Expand Down
39 changes: 0 additions & 39 deletions airflow/utils/db.py
Expand Up @@ -972,33 +972,6 @@ def synchronize_log_template(*, session: Session = NEW_SESSION) -> None:
session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id))


def encrypt_trigger_kwargs(*, session: Session) -> None:
"""Encrypt trigger kwargs."""
from airflow.models.trigger import Trigger
from airflow.serialization.serialized_objects import BaseSerialization

for trigger in session.query(Trigger):
# convert serialized dict to string and encrypt it
trigger.kwargs = BaseSerialization.deserialize(json.loads(trigger.encrypted_kwargs))
session.commit()


def decrypt_trigger_kwargs(*, session: Session) -> None:
"""Decrypt trigger kwargs."""
from airflow.models.trigger import Trigger
from airflow.serialization.serialized_objects import BaseSerialization

if not inspect(session.bind).has_table(Trigger.__tablename__):
# table does not exist, nothing to do
# this can happen when we downgrade to an old version before the Trigger table was added
return

for trigger in session.scalars(select(Trigger.encrypted_kwargs)):
# decrypt the string and convert it to serialized dict
trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
session.commit()


def check_conn_id_duplicates(session: Session) -> Iterable[str]:
"""
Check unique conn_id in connection table.
Expand Down Expand Up @@ -1666,12 +1639,6 @@ def upgradedb(
_reserialize_dags(session=session)
add_default_pool_if_not_exists(session=session)
synchronize_log_template(session=session)
if _revision_greater(
config,
_REVISION_HEADS_MAP["2.9.0"],
_get_current_revision(session=session),
):
encrypt_trigger_kwargs(session=session)


@provide_session
Expand Down Expand Up @@ -1744,12 +1711,6 @@ def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session:
else:
log.info("Applying downgrade migrations.")
command.downgrade(config, revision=to_revision, sql=show_sql_only)
if _revision_greater(
config,
_REVISION_HEADS_MAP["2.9.0"],
to_revision,
):
decrypt_trigger_kwargs(session=session)


def drop_airflow_models(connection):
Expand Down
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
072fb4b43a86ccb57765ec3f163350519773be83ab38b7ac747d25e1197233e8
77757e21aee500cb7fe7fd75e0f158633a0037d4d74e6f45eb14238f901ebacd
8 changes: 4 additions & 4 deletions docs/apache-airflow/img/airflow_erd.svg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/apache-airflow/migrations-ref.rst
Expand Up @@ -41,7 +41,7 @@ Here's the list of all the Database Migrations that are executed via when you ru
+=================================+===================+===================+==============================================================+
| ``677fdbb7fc54`` (head) | ``1949afb29106`` | ``2.10.0`` | add new executor field to db |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| ``1949afb29106`` | ``ee1467d4aa35`` | ``2.9.0`` | update trigger kwargs type |
| ``1949afb29106`` | ``ee1467d4aa35`` | ``2.9.0`` | update trigger kwargs type and encrypt |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| ``ee1467d4aa35`` | ``b4078ac230a1`` | ``2.9.0`` | add display name for dag and task instance |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
Expand Down
17 changes: 17 additions & 0 deletions tests/models/test_trigger.py
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import datetime
import json
from typing import Any, AsyncIterator

import pytest
Expand All @@ -27,6 +28,7 @@
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models import TaskInstance, Trigger
from airflow.operators.empty import EmptyOperator
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils import timezone
from airflow.utils.session import create_session
Expand Down Expand Up @@ -378,3 +380,18 @@ def test_serialize_sensitive_kwargs():
assert isinstance(trigger_row.encrypted_kwargs, str)
assert "value1" not in trigger_row.encrypted_kwargs
assert "value2" not in trigger_row.encrypted_kwargs


def test_kwargs_not_encrypted():
"""
Tests that we don't decrypt kwargs if they aren't encrypted.
We weren't able to encrypt the kwargs in all migration paths.
"""
trigger = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
# force the `encrypted_kwargs` to be unencrypted, like they would be after an offline upgrade
trigger.encrypted_kwargs = json.dumps(
BaseSerialization.serialize({"param1": "value1", "param2": "value2"})
)

assert trigger.kwargs["param1"] == "value1"
assert trigger.kwargs["param2"] == "value2"

0 comments on commit adeb7f7

Please sign in to comment.