
I have two sets of operators in Airflow that each run in parallel, with the second set downstream of the first:

chain([task_1a, task_2a, task_3a],  [task_1b, task_2b, task_3b], end_task)

I used the chain() helper since the >> bitshift operator isn't supported between two lists, i.e., the following doesn't work:

[task_1a, task_2a, task_3a] >>  [task_1b, task_2b, task_3b] >> end_task

Now I want to add a second variation of this operator pipeline based on some condition. I figured I could do this via branching and the BranchPythonOperator. AFAIK the BranchPythonOperator will return either one task ID string or a list of task ID strings. However, I have not found any public documentation or successful examples of using the BranchPythonOperator to return a chained sequence of tasks involving parallel tasks.

I've tried the method below as well as other variations, but so far I've encountered issues with operators downstream of 'option1' or 'option2' being skipped entirely.

Another Stack Overflow post mentioned needing to alter trigger rules when branching, so I've also tried setting the trigger_rule of end_task to 'all_success', but that has no effect either.

from datetime import datetime

from airflow import DAG
from airflow.operators.python import BranchPythonOperator
from airflow.operators.dummy import DummyOperator
from airflow.utils.helpers import chain

# Placeholder default_args so the snippet parses on its own
default_args = {'start_date': datetime(2023, 1, 1)}


def _choose_best_model():
    value = 6

    if value > 10:
        return 'option1'
    else:
        return 'option2'


with DAG('branching', schedule_interval='@daily', default_args=default_args, catchup=False) as dag:

    choose_best_model = BranchPythonOperator(
        task_id='choose_best_model',
        python_callable=_choose_best_model
    )
    option1 = DummyOperator(
        task_id='option1'
    )
    option2 = DummyOperator(
        task_id='option2'
    )

    # Parallel tasks
    task_1a = DummyOperator(
        task_id='task_1a'
    )
    task_2a = DummyOperator(
        task_id='task_2a'
    )
    task_3a = DummyOperator(
        task_id='task_3a'
    )

    task_1b = DummyOperator(
        task_id='task_1b'
    )
    task_2b = DummyOperator(
        task_id='task_2b'
    )
    task_3b = DummyOperator(
        task_id='task_3b'
    )

    end_task = DummyOperator(
        task_id='end_task'
    )

    choose_best_model >> [option1, option2]

    chain(option1, [task_1a, task_2a, task_3a], [task_1b, task_2b, task_3b], end_task)
    chain(option2, [task_1a, task_2a], [task_1b, task_2b], end_task)


hamhung
1 Answer


To me, the example you shared feels like a wrong usage of branching, because option1 invokes everything. This doesn't seem to be a case where, based on a condition, the workflow branches between separate flows of tasks. It looks more like a series of questions such as "should I execute task_3a?". For that you should use ShortCircuitOperator, which allows you to skip execution based on a condition. I believe this would simplify the DAG significantly.
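For illustration, here is a minimal sketch of that ShortCircuitOperator pattern. The DAG id, dates and the callable's condition are placeholders, not taken from your DAG:

from datetime import datetime

from airflow import DAG
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import ShortCircuitOperator


def _should_run_task_3a():
    # Placeholder condition - replace with your real check
    return False


with DAG('short_circuit_example', schedule_interval='@daily',
         start_date=datetime(2023, 1, 1), catchup=False) as dag:

    check_task_3a = ShortCircuitOperator(
        task_id='check_task_3a',
        python_callable=_should_run_task_3a
    )
    task_3a = DummyOperator(task_id='task_3a')

    # When the callable returns a falsy value, ShortCircuitOperator skips all
    # downstream tasks; when it returns a truthy value, execution continues normally.
    check_task_3a >> task_3a

This way each "should I run X?" decision is a small check placed in front of X, instead of one branch that has to enumerate whole chains of tasks.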

However, since you asked about branching, I'll answer about that. Before I do, I'd like to highlight that the code you shared shows odd branching behavior (and sometimes illogical behavior), so if my assumptions are incorrect please adapt them to your real use case. From here on, I'm ignoring the odd branching behavior to keep the explanation simple.

The issue is that you did not set trigger rules on the junction nodes. task_2a is a junction node with the default trigger rule of all_success, which means all of its direct upstream tasks must be in the success state. task_2a has 2 direct upstream tasks (option1 and option2), but both of them are downstream of the branching task, so there will never be a case where both succeed. Only one of them can succeed; the other will be skipped. This means that task_2a will always be skipped, as shown by executing your code:

[screenshot: DAG run of the original code showing task_2a skipped]

The same concept applies to every other node that has both option1 and option2 as direct upstream tasks.

What you need to do is change the trigger_rule of task_1a, task_2a, and task_3a:

task_1a = DummyOperator(
    task_id='task_1a',
    trigger_rule="none_failed_min_one_success"
)

Now if we execute the modified code we get:

[screenshot: DAG run after the trigger_rule change, with end_task still skipped]

But as you can see, that did not solve every issue. We also need to handle end_task. This task has a similar problem: there will never be a case where all of its direct upstream tasks are in the success state, so we need to change its trigger_rule as well:

end_task = DummyOperator(
    task_id='end_task',
    trigger_rule="none_failed"
)
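For clarity, here are the two trigger_rule changes side by side; everything else in your DAG (the branch operator, the b tasks and the chain() calls) stays exactly as you wrote it:

task_1a = DummyOperator(
    task_id='task_1a',
    trigger_rule="none_failed_min_one_success"
)
task_2a = DummyOperator(
    task_id='task_2a',
    trigger_rule="none_failed_min_one_success"
)
task_3a = DummyOperator(
    task_id='task_3a',
    trigger_rule="none_failed_min_one_success"
)

end_task = DummyOperator(
    task_id='end_task',
    trigger_rule="none_failed"
)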

After applying these changes, if you trigger the DAG you will get:

[screenshot: DAG run after both trigger_rule changes]

Now, back to the beginning, where I explained that the branching logic you set up doesn't make sense. You can actually see this if you change the condition to value = 11 in the Python callable, which shows what happens when option1 is selected. This gives:

[screenshot: DAG run with value = 11, i.e. when option1 is selected]

Which makes little sense. I assume the code you shared is something you constructed as a learning exercise and not a representation of your actual workflow logic.

Note: DummyOperator is deprecated since Airflow 2.3.0; you should use EmptyOperator instead. See this answer for more details.
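If you switch, only the import and the class name change; the trigger_rule arguments discussed above stay the same, for example:

from airflow.operators.empty import EmptyOperator

# Same behavior as the DummyOperator version, using the non-deprecated class
end_task = EmptyOperator(
    task_id='end_task',
    trigger_rule="none_failed"
)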

Elad Kalif
  • Hi, thanks for the answer. On your note: end_task = DummyOperator( task_id='end_task', trigger_rule="none_failed_min_one_success" ). Let's say end_task also requires all tasks that are not skipped to finish before the end_task operation can begin, and the tasks running in parallel may finish at different times (e.g., task_2b finishes 1 hour before task_1b). Would the end_task operator need to use the "none_failed" trigger rule instead of 'none_failed_min_one_success'? – hamhung Mar 07 '23 at 07:50
  • @hamhung yes. Let me update the answer. – Elad Kalif Mar 07 '23 at 08:12