I have a simple DAG of three operators. The first one is a PythonOperator with our own functionality; the other two are standard operators from airflow.contrib (FileToGoogleCloudStorageOperator and GoogleCloudStorageToBigQueryOperator, to be precise). They run in sequence. Our custom task produces a number of files, typically between 2 and 5, depending on the parameters. All of these files have to be processed by subsequent tasks separately. That means I want several downstream branches, but exactly how many isn't known until the DAG runs.
How would you approach this problem?
UPDATE:
Using the BranchPythonOperator that jhnclvr mentioned in another answer as a point of departure, I created an operator that skips or continues executing a branch, depending on a condition. This approach is feasible only because the highest possible number of branches is known and sufficiently small.
The operator:
from datetime import datetime

from airflow import settings
from airflow.models import TaskInstance
from airflow.operators.python_operator import PythonOperator
from airflow.utils.state import State


class SkipOperator(PythonOperator):
    def execute(self, context):
        # Run the wrapped python_callable first; its return value decides whether this branch runs.
        boolean = super(SkipOperator, self).execute(context)
        session = settings.Session()
        if boolean is False:
            # Mark every direct downstream task as SKIPPED for this execution date,
            # which switches the whole branch off for this DAG run.
            for task in context['task'].downstream_list:
                ti = TaskInstance(
                    task, execution_date=context['ti'].execution_date)
                ti.state = State.SKIPPED
                ti.start_date = datetime.now()
                ti.end_date = datetime.now()
                session.merge(ti)
        session.commit()
        session.close()
Usage:
from functools import partial

from airflow import DAG
from airflow.contrib.operators.gcs_to_bq import GoogleCloudStorageToBigQueryOperator

# CustomOperator, CustomFileToGoogleCloudStorageOperator, dag_name, default_args,
# gs_bucket and connection_id are defined elsewhere.

def check(i, templates_dict=None, **kwargs):
    # Branch i runs only if the comma-separated data_list has more than i entries.
    return len(templates_dict["data_list"].split(",")) > i

dag = DAG(
    dag_name,
    default_args=default_args,
    schedule_interval=None
)

load = CustomOperator(
    task_id="load_op",
    bash_command=' '.join([
        './command.sh',
        '--data-list {{ dag_run.conf["data_list"]|join(",") }}'
    ]),
    dag=dag
)

# One branch per possible file; 5 is the known upper bound.
for i in range(0, 5):
    condition = SkipOperator(
        task_id=f"{dag_name}_condition_{i}",
        python_callable=partial(check, i),
        provide_context=True,
        templates_dict={"data_list": '{{ dag_run.conf["data_list"]|join(",") }}'},
        dag=dag
    )

    gs_filename = 'prefix_{{ dag_run.conf["data_list"][%d] }}.json' % i

    load_to_gcs = CustomFileToGoogleCloudStorageOperator(
        task_id=f"{dag_name}_to_gs_{i}",
        src='/tmp/{{ run_id }}_%d.{{ dag_run.conf["file_extension"] }}' % i,
        bucket=gs_bucket,
        dst=gs_filename,
        mime_type='application/json',
        google_cloud_storage_conn_id=connection_id,
        dag=dag
    )

    load_to_bq = GoogleCloudStorageToBigQueryOperator(
        task_id=f"{dag_name}_to_bq_{i}",
        bucket=gs_bucket,
        source_objects=[gs_filename],
        source_format='NEWLINE_DELIMITED_JSON',
        destination_project_dataset_table='myproject.temp_{{ dag_run.conf["data_list"][%d] }}' % i,
        bigquery_conn_id=connection_id,
        schema_fields={},  # actual schema omitted here
        google_cloud_storage_conn_id=connection_id,
        write_disposition='WRITE_TRUNCATE',
        dag=dag
    )

    condition.set_upstream(load)
    load_to_gcs.set_upstream(condition)
    load_to_bq.set_upstream(load_to_gcs)
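Since every templated value comes from dag_run.conf, the DAG has no schedule and is triggered manually with a conf payload whose keys mirror the templates above, for example (the DAG id and values are placeholders):

airflow trigger_dag <dag_name> -c '{"data_list": ["foo", "bar"], "file_extension": "json"}'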