I am trying to run the official BERT pre-training scripts by following this tutorial: https://towardsdatascience.com/pre-training-bert-from-scratch-with-cloud-tpu-6e2f71028379. The main difference is that I am using Azure Databricks.

When I call the TensorFlow estimator to train the network, training starts normally and the first checkpoint (iteration) of the model is saved. When it tries to save the second checkpoint, however, I get an `Input/output error`. The error appears to be caused by renaming a temporary file. As the traceback below shows, the failure happens in `file_io.atomic_write_string_to_file`, while renaming `/dbfs/tmp/model/checkpoint.tmp…` to the final `checkpoint` file.

Does anybody know how to resolve this problem?
---------------------------------------------------------------------------
UnknownError Traceback (most recent call last)
<command-3548884146162520> in <module>()
----> 1 estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)
/databricks/python/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py in train(self, input_fn, hooks, steps, max_steps, saving_listeners)
2874 finally:
2875 rendezvous.record_done('training_loop')
-> 2876 rendezvous.raise_errors()
2877
2878 def evaluate(self,
/databricks/python/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/error_handling.py in raise_errors(self, timeout_sec)
129 else:
130 logging.warn('Reraising captured error')
--> 131 six.reraise(typ, value, traceback)
132
133 for k, (typ, value, traceback) in kept_errors:
/databricks/python/lib/python3.6/site-packages/six.py in reraise(tp, value, tb)
691 if value.__traceback__ is not tb:
692 raise value.with_traceback(tb)
--> 693 raise value
694 finally:
695 value = None
/databricks/python/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py in train(self, input_fn, hooks, steps, max_steps, saving_listeners)
2869 steps=steps,
2870 max_steps=max_steps,
-> 2871 saving_listeners=saving_listeners)
2872 except Exception: # pylint: disable=broad-except
2873 rendezvous.record_error('training_loop', sys.exc_info())
/databricks/python/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py in train(self, input_fn, hooks, steps, max_steps, saving_listeners)
365
366 saving_listeners = _check_listeners_type(saving_listeners)
--> 367 loss = self._train_model(input_fn, hooks, saving_listeners)
368 logging.info('Loss for final step: %s.', loss)
369 return self
/databricks/python/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py in _train_model(self, input_fn, hooks, saving_listeners)
1156 return self._train_model_distributed(input_fn, hooks, saving_listeners)
1157 else:
-> 1158 return self._train_model_default(input_fn, hooks, saving_listeners)
1159
1160 def _train_model_default(self, input_fn, hooks, saving_listeners):
/databricks/python/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py in _train_model_default(self, input_fn, hooks, saving_listeners)
1190 return self._train_with_estimator_spec(estimator_spec, worker_hooks,
1191 hooks, global_step_tensor,
-> 1192 saving_listeners)
1193
1194 def _train_model_distributed(self, input_fn, hooks, saving_listeners):
/databricks/python/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py in _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks, global_step_tensor, saving_listeners)
1482 any_step_done = False
1483 while not mon_sess.should_stop():
-> 1484 _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
1485 any_step_done = True
1486 if not any_step_done:
/databricks/python/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in run(self, fetches, feed_dict, options, run_metadata)
752 feed_dict=feed_dict,
753 options=options,
--> 754 run_metadata=run_metadata)
755
756 def run_step_fn(self, step_fn):
/databricks/python/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in run(self, fetches, feed_dict, options, run_metadata)
1250 feed_dict=feed_dict,
1251 options=options,
-> 1252 run_metadata=run_metadata)
1253 except _PREEMPTION_ERRORS as e:
1254 logging.info(
/databricks/python/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in run(self, *args, **kwargs)
1351 raise six.reraise(*original_exc_info)
1352 else:
-> 1353 raise six.reraise(*original_exc_info)
1354
1355
/databricks/python/lib/python3.6/site-packages/six.py in reraise(tp, value, tb)
691 if value.__traceback__ is not tb:
692 raise value.with_traceback(tb)
--> 693 raise value
694 finally:
695 value = None
/databricks/python/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in run(self, *args, **kwargs)
1336 def run(self, *args, **kwargs):
1337 try:
-> 1338 return self._sess.run(*args, **kwargs)
1339 except _PREEMPTION_ERRORS:
1340 raise
/databricks/python/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in run(self, fetches, feed_dict, options, run_metadata)
1417 results=outputs[hook] if hook in outputs else None,
1418 options=options,
-> 1419 run_metadata=run_metadata))
1420 self._should_stop = self._should_stop or run_context.stop_requested
1421
/databricks/python/lib/python3.6/site-packages/tensorflow/python/training/basic_session_run_hooks.py in after_run(self, run_context, run_values)
592 if self._timer.should_trigger_for_step(global_step):
593 self._timer.update_last_triggered_step(global_step)
--> 594 if self._save(run_context.session, global_step):
595 run_context.request_stop()
596
/databricks/python/lib/python3.6/site-packages/tensorflow/python/training/basic_session_run_hooks.py in _save(self, session, step)
609 l.before_save(session, step)
610
--> 611 self._get_saver().save(session, self._save_path, global_step=step)
612 self._summary_writer.add_session_log(
613 SessionLog(
/databricks/python/lib/python3.6/site-packages/tensorflow/python/training/saver.py in save(self, sess, save_path, global_step, latest_filename, meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs, save_debug_info)
1181 all_model_checkpoint_paths=self.last_checkpoints,
1182 latest_filename=latest_filename,
-> 1183 save_relative_paths=self._save_relative_paths)
1184 self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
1185 except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
/databricks/python/lib/python3.6/site-packages/tensorflow/python/training/checkpoint_management.py in update_checkpoint_state_internal(save_dir, model_checkpoint_path, all_model_checkpoint_paths, latest_filename, save_relative_paths, all_model_checkpoint_timestamps, last_preserved_timestamp)
240 # file.
241 file_io.atomic_write_string_to_file(coord_checkpoint_filename,
--> 242 text_format.MessageToString(ckpt))
243
244
/databricks/python/lib/python3.6/site-packages/tensorflow/python/lib/io/file_io.py in atomic_write_string_to_file(filename, contents, overwrite)
538 write_string_to_file(temp_pathname, contents)
539 try:
--> 540 rename(temp_pathname, filename, overwrite)
541 except errors.OpError:
542 delete_file(temp_pathname)
/databricks/python/lib/python3.6/site-packages/tensorflow/python/lib/io/file_io.py in rename(oldname, newname, overwrite)
500 errors.OpError: If the operation fails.
501 """
--> 502 rename_v2(oldname, newname, overwrite)
503
504
/databricks/python/lib/python3.6/site-packages/tensorflow/python/lib/io/file_io.py in rename_v2(src, dst, overwrite)
517 """
518 pywrap_tensorflow.RenameFile(
--> 519 compat.as_bytes(src), compat.as_bytes(dst), overwrite)
520
521
UnknownError: /dbfs/tmp/model/checkpoint.tmp2feb8d7a932249e7ba1a11f96d3cb334; Input/output error