I want to call a TaskGroup with a dynamic sub-task id from a BranchPythonOperator.
This is the DAG flow that I have:
My case is that I want to check whether a table exists in BigQuery or not.
If it exists: do nothing and end the DAG.
If it does not exist: ingest the data from Postgres to Google Cloud Storage.
I know that the way to call a TaskGroup from a BranchPythonOperator is to return the task id in the following format:
group_task_id.task_id
The problem is that my task group's sub-task ids are dynamic, depending on how many times I loop inside the TaskGroup. So the sub-tasks will be:
parent_task_id.sub_task_1
parent_task_id.sub_task_2
parent_task_id.sub_task_3
...
parent_task_id.sub_task_x
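For reference, a BranchPythonOperator callable can return either a single task id or a list of task ids, and a TaskGroup prefixes each child's task id with its group_id. So a sketch of the qualified ids I would need to produce (assuming the worker count, here 3 as in the DAG below, is known to the callable) looks like:

worker = 3
# TaskGroup prefixes every child task id with its group_id, so the
# qualified ids run from 'parent_task_id.sub_task_0' upwards
sub_task_ids = [f'parent_task_id.sub_task_{i}' for i in range(worker)]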
This is the code for the DAG that I have:
import airflow
from airflow.providers.google.cloud.transfers.postgres_to_gcs import PostgresToGCSOperator
from airflow.utils.task_group import TaskGroup
from google.cloud.exceptions import NotFound
from airflow import DAG
from airflow.operators.python import BranchPythonOperator
from airflow.operators.dummy import DummyOperator
from google.cloud import bigquery

default_args = {
    'owner': 'Airflow',
    'start_date': airflow.utils.dates.days_ago(2),
}

with DAG(dag_id='branch_dag', default_args=default_args, schedule_interval=None) as dag:

    def create_task_group(worker=1):
        """Build one PostgresToGCSOperator per worker inside a TaskGroup."""
        var = dict()
        with TaskGroup(group_id='parent_task_id') as tg1:
            for i in range(worker):
                var[f'sub_task_{i}'] = PostgresToGCSOperator(
                    task_id=f'sub_task_{i}',
                    postgres_conn_id='some_postgres_conn_id',
                    sql='test.sql',
                    bucket='test_bucket',
                    filename='test_file.json',
                    export_format='json',
                    gzip=True,
                    # each sub-task gets its own partition index for test.sql
                    params={'worker': i}
                )
        return tg1

    def is_exists_table():
        """Return the id of the task (or group) the branch should follow."""
        client = bigquery.Client()
        try:
            table_name = client.get_table('dataset_id.some_table')
            if table_name:
                return 'end'  # must match the actual task_id, not the variable name
        except NotFound:
            return 'parent_task_id'

    task_start = DummyOperator(
        task_id='start'
    )

    task_branch_table = BranchPythonOperator(
        task_id='check_table_exists_in_bigquery',
        python_callable=is_exists_table
    )

    task_pg_to_gcs_init = create_task_group(worker=3)

    task_end = DummyOperator(
        task_id='end',
        trigger_rule='all_done'
    )

    task_start >> task_branch_table >> task_end
    task_start >> task_branch_table >> task_pg_to_gcs_init >> task_end
When I run the DAG, it returns:
**airflow.exceptions.TaskNotFound: Task parent_task_id not found**
This is expected; what I don't know is how to iterate over the parent_task_id.sub_task_x ids in the is_exists_table function. Or is there any workaround?
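One workaround I can think of (a sketch only, reusing the names from the DAG above and assuming the branch callable can see the same worker count that built the TaskGroup) would be to return the whole list of qualified sub-task ids instead of the group id:

def is_exists_table(worker=3):
    client = bigquery.Client()
    try:
        client.get_table('dataset_id.some_table')  # raises NotFound if missing
        return 'end'
    except NotFound:
        # BranchPythonOperator also accepts a list of task ids, so every
        # dynamic sub-task inside the group can be followed at once
        return [f'parent_task_id.sub_task_{i}' for i in range(worker)]

But I'm not sure whether this is the idiomatic way to do it.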
This is the test.sql file, if it's needed:
SELECT
    id,
    name,
    country
FROM some_table
WHERE 1=1
    AND ABS(MOD(hashtext(id::TEXT), 3)) = {{params.worker}};
-- returns 1M+ rows
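For context on {{params.worker}}: the hashtext/MOD filter buckets every row into 0, 1, or 2, so each sub-task is meant to pull one disjoint partition of the 1M+ rows. A rough Python illustration of the same bucketing (hashtext is Postgres-specific; Python's built-in hash stands in purely for illustration):

ids = ['a', 'b', 'c', 'd', 'e', 'f']
# every id lands in exactly one of the 3 partitions, mirroring
# ABS(MOD(hashtext(id::TEXT), 3)) = {{params.worker}}
partitions = {w: [i for i in ids if abs(hash(i)) % 3 == w] for w in range(3)}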
I have already seen this question as a reference: Question, but I think my case is more specific.