
tfx.components.FnArgs is how values are passed to the run_fn function, which in turn trains the model in a TensorFlow Extended (TFX) pipeline.

Looking at the documentation for tfx.components.FnArgs, I can't help but wonder why there is no attribute for the number of epochs to run the training loop, which is perhaps the most important training parameter. Is this an oversight, or am I supposed to control the number of epochs some other way?

Mehran

2 Answers


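You can pass the number of epochs as an entry in the custom_config dict, as shown in the example notebook. Inside run_fn, the value is then available as fn_args.custom_config['epochs'].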

Example code:

import os

from tfx import v1 as tfx

# Assumes _trainer_module_file, ratings_transform, movies_transform and an
# InteractiveContext named `context` were created earlier in the notebook.
trainer = tfx.components.Trainer(
    module_file=os.path.abspath(_trainer_module_file),
    examples=ratings_transform.outputs['transformed_examples'],
    transform_graph=ratings_transform.outputs['transform_graph'],
    schema=ratings_transform.outputs['post_transform_schema'],
    train_args=tfx.proto.TrainArgs(num_steps=500),
    eval_args=tfx.proto.EvalArgs(num_steps=10),
    custom_config={
        # Read back inside run_fn as fn_args.custom_config['epochs'].
        'epochs': 5,
        'movies': movies_transform.outputs['transformed_examples'],
        'movie_schema': movies_transform.outputs['post_transform_schema'],
        'ratings': ratings_transform.outputs['transformed_examples'],
        'ratings_schema': ratings_transform.outputs['post_transform_schema'],
    })

context.run(trainer, enable_cache=False)
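On the run_fn side, the value comes back out of fn_args.custom_config. The sketch below is a minimal illustration assuming the 'epochs' key used above; the helper name epochs_from_fn_args is hypothetical, and a SimpleNamespace stands in for the FnArgs object so the snippet runs without TFX installed. The commented-out run_fn shows roughly where the value would go in a Keras model.fit call; model, train_dataset and eval_dataset are assumed to be built elsewhere in your trainer module.

```python
from types import SimpleNamespace
from typing import Any


def epochs_from_fn_args(fn_args: Any, default: int = 1) -> int:
    """Return the epoch count carried into run_fn via the Trainer's custom_config.

    Hypothetical helper: falls back to `default` when custom_config is
    missing, None, or has no 'epochs' key.
    """
    config = getattr(fn_args, "custom_config", None) or {}
    epochs = int(config.get("epochs", default))
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    return epochs


# In a real trainer module, run_fn might use it roughly like this
# (model, train_dataset and eval_dataset are assumed to exist):
#
# def run_fn(fn_args):
#     epochs = epochs_from_fn_args(fn_args)
#     model.fit(train_dataset,
#               epochs=epochs,
#               validation_data=eval_dataset,
#               validation_steps=fn_args.eval_steps)

# Stand-in for tfx.components.FnArgs, used only to exercise the helper here.
fake_args = SimpleNamespace(custom_config={"epochs": 5}, train_steps=500)
print(epochs_from_fn_args(fake_args))  # → 5
```

Keeping the lookup in a small helper makes the default explicit, so the module still trains if someone builds the Trainer without an 'epochs' entry.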

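I think you can also control the length of training with train_args and eval_args, which may be a more direct solution. Note that num_steps counts training steps (batches), not epochs, so a desired number of epochs has to be converted into steps: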

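from tfx.v1 import proto
from tfx.v1.components import Trainer

# Assumes trainer_file, transform and tuner were defined earlier in the pipeline.
trainer = Trainer(
    module_file=trainer_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=transform.outputs['post_transform_schema'],
    hyperparameters=tuner.outputs['best_hyperparameters'],
    # num_steps is a count of training steps (batches), not epochs.
    train_args=proto.TrainArgs(splits=['train'], num_steps=50),
    eval_args=proto.EvalArgs(splits=['eval'], num_steps=5))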
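Because TrainArgs.num_steps is a step count, training for a specific number of epochs means converting epochs to steps yourself. A minimal sketch, assuming each step consumes one batch and that your input function repeats the dataset (common in TFX trainer modules); steps_for_epochs is a hypothetical helper, not part of TFX, and the dataset size below is made up for illustration:

```python
import math


def steps_for_epochs(num_examples: int, batch_size: int, epochs: int) -> int:
    """Number of training steps that covers `epochs` passes over the data."""
    if num_examples <= 0 or batch_size <= 0 or epochs <= 0:
        raise ValueError("num_examples, batch_size and epochs must be positive")
    # One epoch needs ceil(num_examples / batch_size) steps; rounding up
    # counts the last, possibly partial, batch as a full step.
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * epochs


# Hypothetical dataset: 10,000 training examples, batch size 64, 5 epochs.
print(steps_for_epochs(10_000, 64, 5))  # → 785
```

You would then pass the result as, for example, proto.TrainArgs(splits=['train'], num_steps=steps_for_epochs(num_train_examples, batch_size, epochs)), where num_train_examples, batch_size and epochs are values you supply. If your input pipeline drops the final partial batch instead, use floor division for steps_per_epoch.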
Pritam Dodeja