Add example DAG for Databricks Delta support #1585

Open · wants to merge 7 commits into main
97 changes: 97 additions & 0 deletions python-sdk/example_dags/example_s3_to_databricks_delta.py
@@ -0,0 +1,97 @@
# This example uses data from https://www.kaggle.com/c/shelter-animal-outcomes
import os
import time
from datetime import datetime, timedelta

import pandas as pd

from airflow.decorators import dag

from astro import sql as aql
from astro.files import File
from astro.table import Table


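# Combine the two adoption-center tables with a SQL UNION; the query runs against
# the Databricks connection attached to the input tables.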
@aql.transform()
def combine_data(center_1: Table, center_2: Table):
    return """SELECT * FROM {{center_1}}
    UNION SELECT * FROM {{center_2}}"""


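# Drop rows whose type is 'Guinea Pig' before aggregating.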
@aql.transform()
def clean_data(input_table: Table):
    return """SELECT *
    FROM {{input_table}} WHERE type NOT LIKE 'Guinea Pig'
    """


# Note that this function pulls the Delta table into a local pandas DataFrame, not a Spark DataFrame.
# This is not recommended for large tables.
@aql.dataframe(columns_names_capitalization="original")
def aggregate_data(df: pd.DataFrame):
    new_df = df.pivot_table(index="date", values="name", columns=["type"], aggfunc="count").reset_index()
    new_df.columns = new_df.columns.str.lower()
    return new_df


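# Daily example DAG: load two adoption-center CSVs from S3 into Databricks Delta tables,
# combine and clean them, then aggregate daily adoption counts per animal type.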
@dag(
    start_date=datetime(2021, 1, 1),
    max_active_runs=1,
    schedule_interval="@daily",
    default_args={
        "email_on_failure": False,
        "retries": 0,
        "retry_delay": timedelta(minutes=5),
    },
    catchup=False,
)
def example_amazon_s3_delta_transform():
    s3_bucket = os.getenv("S3_BUCKET", "s3://tmp9")

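    # Temporary target tables on Databricks; temp=True lets aql.cleanup() drop them at the end of the run.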
    input_table_1 = Table(
        name="ADOPTION_CENTER_1_" + str(int(time.time())),
        conn_id="databricks_conn",
        temp=True,
    )
    # [START metadata_example_delta]
    input_table_2 = Table(
        name="ADOPTION_CENTER_2_" + str(int(time.time())),
        conn_id="databricks_conn",
        temp=True,
    )
    # [END metadata_example_delta]

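    # Load the adoption-center CSV files from S3 into the temporary Delta tables defined above.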
    temp_table_1 = aql.load_file(
        input_file=File(path=f"{s3_bucket}/ADOPTION_CENTER_1_unquoted.csv", conn_id="aws_default"),
        output_table=input_table_1,
    )
    temp_table_2 = aql.load_file(
        input_file=File(path=f"{s3_bucket}/ADOPTION_CENTER_2_unquoted.csv", conn_id="aws_default"),
        output_table=input_table_2,
    )

    combined_data = combine_data(
        center_1=temp_table_1,
        center_2=temp_table_2,
    )

    cleaned_data = clean_data(combined_data)
    # [START dataframe_example_2]
    delta_output_table = Table(
        name="aggregated_adoptions_" + str(int(time.time())),
        conn_id="delta_conn",
        temp=True,
    )
    aggregate_data(
        cleaned_data,
        output_table=delta_output_table,
    )
    # [END dataframe_example_2]
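    # Remove the temporary tables created during this run.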
    aql.cleanup()


dag = example_amazon_s3_delta_transform()

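# When executed directly, run the DAG once with dag.test(); connections can be supplied
# via a file pointed to by the CONN_FILE_PATH environment variable.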
if __name__ == "__main__":
    conn_file = os.getenv("CONN_FILE_PATH")
    dag.test(conn_file_path=conn_file)