0

I am trying to build a pipeline to parallelize writing tfrecords files on datasets that are too large to fit into memory. I have successfully used dask to do this many times in the past, but I have a new dataset requiring that TextVectorization and StringIndexers be applied outside of the model (running them inside of the model was choking the GPU's). Ultimately I'm trying to apply the vectorizers/string indexers in series within a single partition and then process each of the partitions in parallel using Dask's computation engine. I have tried about 100 different ways to write the functions and have attempted to apply tf.keras.layers.deserialize and tf.keras.layers.serialize to the vectorizers without luck.

Here is the pipeline as it stands:


from dask.distributed import LocalCluster, Client
import logging
import dask
import json
import os
import pickle
import tensorflow
from tensorflow.keras.layers.experimental.preprocessing import StringLookup, TextVectorization

cluster = LocalCluster()

client = Client(cluster)

# Setup logging
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
 
def load_and_broadcast_vectorizers(model_dir):
    logging.info("Loading and broadcasting vectorizers...")
    path = os.path.join(model_dir, 'vectorizers_saved')
    vectorizers = {}
   
    for file in os.listdir(path):
        if file.endswith("_vectorizer_config.json"):
            key = file[:-23]
            with open(os.path.join(path, f'{key}_vectorizer_config.json'), 'r') as f:
                config = json.load(f)
           
            vectorizer_class = config['vectorizer_class']
            config.pop('vectorizer_class', None)  # Remove the 'vectorizer_class' key
 
            if vectorizer_class == 'StringLookup':
                vectorizer = StringLookup.from_config(config)
            elif vectorizer_class == 'TextVectorization':
                vectorizer = TextVectorization.from_config(config)
            else:
                raise ValueError(f"Unknown vectorizer class: {vectorizer_class}")
 
            vectorizers[key] = vectorizer
    logging.info("Vectorizers loaded and broadcasted successfully.")
    return client.scatter(vectorizers, broadcast=True)
 
def apply_vectorizers_to_partition(partition, vectorizers):
    logging.info("Applying vectorizers to partition...")
    for vectorizer_name, vectorizer in vectorizers.items():
        partition[vectorizer_name] = vectorizer.transform(partition[vectorizer_name])
    return partition
 
def apply_vectorizers_in_parallel(vectorizers, df):
    logging.info("Applying vectorizers in parallel...")
    return df.map_partitions(apply_vectorizers_to_partition, vectorizers)
 
def write_partition_to_tfrecord(partition, output_path, partition_label, partition_id):
    logging.info(f"Writing {partition_label} partition {partition_id} to TFRecord...")
    file_name_tfrecord = f'{partition_label}_partition_{partition_id}.tfrecord'
    output_file_path_tfrecord = os.path.join(output_path, file_name_tfrecord)
    with tf.io.TFRecordWriter(output_file_path_tfrecord) as writer:
        for row in partition.itertuples():
                        # Extract features and label from each row
            features = {
                'input_1': tf.train.Feature(int64_list=tf.train.Int64List(value=[row['input_1']])),
                'input_2': tf.train.Feature(float_list=tf.train.FloatList(value=[row['input_2']])),
                'input_3': tf.train.Feature(int64_list=tf.train.Int64List(value=row['input_3'])),
                'input_4': tf.train.Feature(int64_list=tf.train.Int64List(value=[row['input_4']])),
                'input_5': tf.train.Feature(int64_list=tf.train.Int64List(value=[row['input_5']]))
    }

            label = tf.train.Feature(float_list=tf.train.FloatList(value=[row['label_col']]))
            example = tf.train.Example(features=tf.train.Features(feature={**features, **{'label': label}}))
           writer.write(example.SerializeToString())
    logging.info(f"{partition_label} partition {partition_id} written to TFRecord.")
 
def write_partition_to_parquet(partition, output_path, partition_label, partition_id):
    logging.info(f"Writing {partition_label} partition {partition_id} to Parquet...")
    selected_columns = partition[['personuuid', 'claimnum', 'claimrowid']]
    file_name_parquet = f'{partition_label}_partition_{partition_id}.parquet.snappy'
    output_file_path_parquet = os.path.join(output_path, file_name_parquet)
    selected_columns.to_parquet(output_file_path_parquet, compression='snappy')
    logging.info(f"{partition_label} partition {partition_id} written to Parquet.")
 
def write_vectorized_partitions_to_files(vectorizers, df, output_path, partition_label):
    logging.info(f"Writing {partition_label} vectorized partitions to files...")
    dask_tasks = []
    for i, partition in enumerate(df.to_delayed()):
        tfrecord_task = dask.delayed(write_partition_to_tfrecord)(partition, output_path, partition_label, i)
        parquet_task = dask.delayed(write_partition_to_parquet)(partition, output_path, partition_label, i)
        dask_tasks.extend([tfrecord_task, parquet_task])
    dask.compute(*dask_tasks)
    logging.info(f"{partition_label} vectorized partitions written to files successfully.")
 
def process_data(model_dir, df, output_path):
    logging.info("Processing data...")
    vectorizers = load_and_broadcast_vectorizers(model_dir)
   
    if vectorizers is None:
        logging.error("Data processing failed due to missing vectorizers.")
        return
 
    train_df, test_df = df.random_split([0.8, 0.2], random_state=42)
    train_df_vectorized = apply_vectorizers_in_parallel(vectorizers, train_df)
    test_df_vectorized = apply_vectorizers_in_parallel(vectorizers, test_df)
    write_vectorized_partitions_to_files(vectorizers, train_df_vectorized, os.path.join(output_path, 'train'), 'train')
    write_vectorized_partitions_to_files(vectorizers, test_df_vectorized, os.path.join(output_path, 'test'), 'test')
    logging.info("Data processed successfully.")

Error Message:

TypeError: ('Could not serialize object of type StringLookup', '<keras.layers.preprocessing.string_lookup.StringLookup object at 0x7f7f1d79e9a0>')
scribbles
  • 4,089
  • 7
  • 22
  • 29

0 Answers0