
I have an SSHOperator that writes a filepath to stdout. I'd like to get the os.path.basename of that filepath so that I can pass it as a parameter to my next task (which is an sftp pull). The idea is to download a remote file into the current working directory. This is what I have so far:

with DAG('my_dag',
         default_args = dict(...
                             xcom_push = True,
                             )
         ) as dag:

    # there is a get_update_id task here, which has been snipped for brevity

    get_results  = SSHOperator(task_id = 'get_results',
                               ssh_conn_id = 'my_remote_server',
                               command = """cd ~/path/to/dir && python results.py -t p -u {{ task_instance.xcom_pull(task_ids='get_update_id') }}""",
                               cmd_timeout = -1,
                               )

    download_results = SFTPOperator(task_id = 'download_results',
                                    ssh_conn_id = 'my_remote_server',
                                    remote_filepath = base64.b64decode("""{{ task_instance.xcom_pull(task_ids='get_results') }}"""),
                                    local_filepath = os.path.basename(base64.b64decode("""{{ task_instance.xcom_pull(task_ids='get_results') }}""").decode()),
                                    operation = 'get',
                                    )

Airflow tells me there's an error on the remote_filepath = line. Investigating further, I see that the value passed to base64.b64decode is not the XCom value from the get_results task, but rather the raw template string starting with {{.

My feeling is that since operator fields are templated, there's some under-the-hood magic that resolves the templated string, whereas os.path.basename gets no such treatment. Would I need to create an intermediate task to get the basename, or is there some way to shorthand this the way I've tried?

I'd appreciate any help on this

inspectorG4dget
  • Tasks may run on different workers, and workers don't share a disk, so pushing the location of a file downloaded by the first task to XCom may not solve your issue (unless you are running with LocalExecutor). Or by "download" do you mean the disk on the remote machine? Can you please clarify? – Elad Kalif Jun 21 '22 at 19:44
  • I /am/ running with LocalExecutor! :) – inspectorG4dget Jun 21 '22 at 19:53

1 Answer


You want to decode the XCom return value when Airflow renders the remote_filepath property for the task instance.

This means that the b64decode function must be invoked within the template string. There is a catch, though: we have to make this function available in the template context, either by providing it as a parameter or at the DAG level as a user-defined filter or macro.

import base64
import os

def basename_b64decode(value):
    # the XCom payload is base64-encoded; decode it, then take the basename
    return os.path.basename(base64.b64decode(value)).decode()

download_results = SFTPOperator(
    task_id = 'download_results',
    ssh_conn_id = 'my_remote_server',
    remote_filepath = """{{ params.b64decode(ti.xcom_pull(task_ids='get_results')).decode() }}""",
    local_filepath = """{{ params.basename_b64decode(ti.xcom_pull(task_ids='get_results')) }}""",
    operation = 'get',
    params = {
        'b64decode': base64.b64decode,
        'basename_b64decode': basename_b64decode,
    },
)
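
Outside of Airflow, the helper can be sanity-checked in isolation. A minimal sketch, using a hypothetical path standing in for the SSHOperator's stdout:

```python
import base64
import os

def basename_b64decode(value):
    # decode the base64-encoded XCom payload, then take the path's basename
    return os.path.basename(base64.b64decode(value)).decode()

# hypothetical XCom value: the remote path, base64-encoded
encoded = base64.b64encode(b"/home/user/results/output_2022.csv").decode()
print(basename_b64decode(encoded))  # -> output_2022.csv
```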

For the DAG user-defined macro approach, you can write:


with DAG('my_dag',
         default_args = dict(...
                             xcom_push = True,
                             ),
         user_defined_macros = dict(
             basename_b64decode = basename_b64decode,
             b64decode = base64.b64decode,
         ),
         ) as dag:

    download_results = SFTPOperator(
        task_id = 'download_results',
        ssh_conn_id = 'my_remote_server',
        remote_filepath = """{{ b64decode(ti.xcom_pull(task_ids='get_results')).decode() }}""",
        local_filepath = """{{ basename_b64decode(ti.xcom_pull(task_ids='get_results')) }}""",
        operation = 'get',
    )
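
Why the decode step at all: with XCom pickling disabled (the default in recent Airflow versions), SSHOperator pushes its aggregated stdout to XCom base64-encoded, which is what xcom_pull hands back. A sketch of the round trip, with a hypothetical path:

```python
import base64

# Simulate what SSHOperator does: stdout is base64-encoded before it
# lands in XCom, so the consumer must base64-decode it to get the path.
stdout = b"/home/user/results/output_2022.csv\n"
xcom_value = base64.b64encode(stdout).decode()            # what xcom_pull returns
remote_path = base64.b64decode(xcom_value).decode().strip()
print(remote_path)  # -> /home/user/results/output_2022.csv
```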

Oluwafemi Sule