0

def skip_update_job_pod_name(dag):
    """
    Build the no-op task the branch falls through to when no pod name update is needed.

    :param dag: Airflow DAG the task is attached to
    :return: DummyOperator whose task_id is "skip_update_job_pod_name"
    """
    skip_task = DummyOperator(dag=dag, task_id="skip_update_job_pod_name")
    return skip_task


def update_pod_name_branch_operator(dag: DAG, job_id: str):
    """
    Build the branch task that decides whether the pod name gets recorded.

    :param dag: Airflow DAG the task is attached to
    :param job_id: value forwarded to ``update_pod_name_func`` as its ``job_id`` argument
    :return: BranchPythonOperator with task_id "update_pod_name"
    """
    callable_kwargs = {"job_id": job_id}
    branch_task = BranchPythonOperator(
        task_id="update_pod_name",
        python_callable=update_pod_name_func,
        op_kwargs=callable_kwargs,
        trigger_rule="all_done",
        dag=dag,
    )
    return branch_task


def update_pod_name_func(job_id: Optional[str], task_group_id: Optional[str] = None) -> str:
    """
    Pick the task_id the branch operator should follow.

    The returned ids must match the task_ids of the downstream tasks exactly,
    otherwise Airflow rejects the branch with "Branch callable must return
    valid task_ids". The skip id therefore mirrors the task_id used by
    ``skip_update_job_pod_name``.

    :param job_id: Airflow Job ID; any falsy value (None, "") selects the skip task
    :param task_group_id: optional TaskGroup id; when given, the returned id is
        prefixed as ``<task_group_id>.<task_id>`` so it resolves for tasks
        created inside that group
    :return: task_id of the task to run next
    """
    task_id = "update_job_pod_name" if job_id else "skip_update_job_pod_name"
    if task_group_id:
        return f"{task_group_id}.{task_id}"
    return task_id


def update_job_pod_name(dag: DAG, job_id: str, process_name: str) -> MySqlOperator:
    """
    Build the task that records the pod name for a job in ``airflow.Pod``.

    The INSERT is guarded by ``WHERE NOT EXISTS`` so a pod name already present
    in the table is not inserted again.

    :param dag: Airflow DAG
    :param job_id: Airflow Job ID
    :param process_name: name of the current running process
    :return: MySqlOperator to update Airflow job ID
    """
    # NOTE(review): job_id and process_name are interpolated straight into the
    # SQL text; confirm they never carry untrusted input.
    # The pod name is read from XCom through a Jinja expression. The quadruple
    # braces are f-string escapes that leave a literal "{{ ... }}" for Airflow
    # to render at run time; both occurrences must use this form.
    return MySqlOperator(
        task_id="update_job_pod_name",
        mysql_conn_id="semantic-search-airflow-sdk",
        autocommit=True,
        sql=[
            f"""
                INSERT INTO airflow.Pod (job_id, pod_name, task_name)
                SELECT * FROM (SELECT '{job_id}', '{{{{ ti.xcom_pull(key="pod_name") }}}}', '{process_name}') AS temp
                WHERE NOT EXISTS (
                    SELECT pod_name FROM airflow.Pod WHERE pod_name = '{{{{ ti.xcom_pull(key="pod_name") }}}}'
                ) LIMIT 1;
            """
        ],
        task_concurrency=1,
        dag=dag,
        trigger_rule="all_done",
    )

def create_k8s_pod_operator_without_volume(dag: DAG,
                                           job_id: int,
                                           process_name: str) -> TaskGroup:
    """
    Create task group for k8 operator without volume

    NOTE(review): the original signature elided its remaining parameters behind
    a placeholder; only ``process_name``, which the body uses, is declared here.
    Restore any other parameters the real function takes.

    :param dag: Airflow DAG
    :param job_id: Airflow Job ID; a falsy value routes the branch to the skip task
    :param process_name: name of the current running process
    :return: TaskGroup holding the branch task and its two target tasks
    """
    with TaskGroup(group_id="k8s_pod_operator_without_volume", dag=dag) as eks_without_volume_group:
        update_pod_name = update_job_pod_name(dag=dag, job_id=job_id, process_name=process_name)
        skip_update_pod_name = skip_update_job_pod_name(dag=dag)

        # Tasks created inside a TaskGroup get their task_id prefixed with the
        # group id. Read the final ids from the operators so the branch always
        # returns ids that exist in the DAG.
        update_task_id = update_pod_name.task_id
        skip_task_id = skip_update_pod_name.task_id

        def choose_pod_name_task(job_id):
            """Return the (group-qualified) task_id to follow for this job_id."""
            return update_task_id if job_id else skip_task_id

        emit_pod_name_branch = BranchPythonOperator(
            dag=dag,
            trigger_rule="all_done",
            task_id="update_pod_name",
            python_callable=choose_pod_name_task,
            op_kwargs={"job_id": job_id},
        )
        emit_pod_name_branch >> [update_pod_name, skip_update_pod_name]
    return eks_without_volume_group

I updated the code based on the comment. I am curious how a TaskGroup works with a branch operator: when I try this, I get airflow.exceptions.AirflowException: Branch callable must return valid task_ids. Invalid tasks found: {'update_job_pod_name'}

WOWpopo
  • 15
  • 4

1 Answers1

0

You can use a BranchPythonOperator whose callable receives the value and returns the task_id of the task to run for each condition.

def choose_job_func(job_id):
    """Return the task_id to follow when job_id is truthy, otherwise None."""
    return "update_pod_name_rds" if job_id else None


# Branch task: job_id is passed as a template string taken from the DAG params
# and handed to choose_job_func through op_kwargs.
choose_update_job = BranchPythonOperator(
    task_id="choose_update_job",
    python_callable=choose_job_func,
    op_kwargs={"job_id": "{{ params.job_id }}"},
)

Or, with the TaskFlow API, it would look like this:

@task.branch
def choose_update_job(job_id):
    """Return the task_id to follow when job_id is truthy, otherwise None."""
    return "update_pod_name_rds" if job_id else None

Full DAG example:

# Example DAG: branch on the optional "job_id" param and either record the pod
# name in MySQL or do nothing.
with DAG(
    dag_id="test_dag",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    render_template_as_native_obj=True,
    params={
        # job_id may be omitted (null) or given as a string.
        "job_id": Param(default=None, type=["null", "string"])
    },
    tags=["test"],
) as dag:

    def update_job_pod_name(job_id: str, process_name: str):
        """
        Build the MySqlOperator that inserts the pod name into ``airflow.Pod``.

        :param job_id: job id (or template string) written into the row
        :param process_name: value written to the ``task_name`` column
        :return: MySqlOperator with task_id "update_pod_name_rds"
        """
        # The pod name comes from XCom through a Jinja expression. The quadruple
        # braces are f-string escapes that leave a literal "{{ ... }}" for
        # Airflow to render at run time; both occurrences must use this form.
        return MySqlOperator(
            task_id="update_pod_name_rds",
            mysql_conn_id="semantic-search-airflow-sdk",
            autocommit=True,
            sql=[
                f"""
                        INSERT INTO airflow.Pod (job_id, pod_name, task_name)
                        SELECT * FROM (SELECT '{job_id}', '{{{{ ti.xcom_pull(key="pod_name") }}}}', '{process_name}') AS temp
                        WHERE NOT EXISTS (
                            SELECT pod_name FROM airflow.Pod WHERE pod_name = '{{{{ ti.xcom_pull(key="pod_name") }}}}'
                        ) LIMIT 1;
                    """
            ],
            task_concurrency=1,
            dag=dag,
            trigger_rule="all_done",
        )

    @task.branch
    def choose_update_job(job_id):
        """Return the task_id of the branch to follow; both ids exist in this DAG."""
        print(job_id)
        if job_id:
            return "update_pod_name_rds"
        return "do_nothing"

    sql_task = update_job_pod_name(
        "{{ params.job_id}}",
        "process_name",
    )
    do_nothing = EmptyOperator(task_id="do_nothing")
    start_dag = EmptyOperator(task_id="start")
    # ONE_SUCCESS lets "end" run even though one of the two branches is skipped.
    end_dag = EmptyOperator(task_id="end", trigger_rule=TriggerRule.ONE_SUCCESS)

    (start_dag >> choose_update_job("{{ params.job_id }}") >> [sql_task, do_nothing] >> end_dag)
ozs
  • 3,051
  • 1
  • 10
  • 19
  • I noticed this fails with the following error: airflow.exceptions.AirflowException: Branch callable must return valid task_ids. Invalid tasks found: {'update_pod_name_rds'} – WOWpopo Jul 07 '22 at 22:39
  • The BranchPythonOperator callable must return an existing task_id. I edited my answer with a full example for better understanding. – ozs Jul 08 '22 at 10:13
  • Got it, thanks a lot! Do you happen to know how this works with a TaskGroup? I want to put everything into a task group, but it also complains: Branch callable must return valid task_ids. Invalid tasks found: {'update_pod_name_rds'}. I updated my question with the TaskGroup code. – WOWpopo Jul 08 '22 at 21:56
  • @WOWpopo, can you please share the full code with the TaskGroup? I succeeded in defining it inside a TaskGroup. Also, if you can accept my answer, I'd appreciate it. – ozs Jul 09 '22 at 06:00
  • @WOWpopo, prefix the task_id in the return value with the task group's id, i.e. return "<group_id>.<task_id>". – ozs Jul 09 '22 at 06:47