I am currently engaged in the process of fine-tuning the google/mT5-small model on Google Colab for a translation task. To assess the quality of translations, I am utilizing the keras_nlp.metrics.Bleu metric to compute the Bleu score. However, I have encountered an error:
ValueError: y_pred must be of rank 0, 1 or 2. Found rank: 3
I attempted to search for the solution to this issue on Google, but regrettably, I was unable to obtain a resolution.
Here is the details of the error:
ValueError Traceback (most recent call last)
<ipython-input-10-3c7a368af456> in <cell line: 10>()
8 ]
9
---> 10 model.fit(x=tf_train_set, validation_data=tf_val_set, epochs=5, callbacks=callbacks)
/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
59 def error_handler(*args, **kwargs):
60 if not tf.debugging.is_traceback_filtering_enabled():
---> 61 return fn(*args, **kwargs)
62
63 filtered_tb = None
/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1740 ):
1741 callbacks.on_train_batch_begin(step)
-> 1742 tmp_logs = self.train_function(iterator)
1743 if data_handler.should_sync:
1744 context.async_wait()
/usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
139 try:
140 if not is_traceback_filtering_enabled():
--> 141 return fn(*args, **kwargs)
142 except NameError:
143 # In some very rare cases,
/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py in __call__(self, *args, **kwds)
823
824 with OptionalXlaContext(self._jit_compile):
--> 825 result = self._call(*args, **kwds)
826
827 new_tracing_count = self.experimental_get_tracing_count()
/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py in _call(self, *args, **kwds)
871 # This is the first call of __call__, so we have to initialize.
872 initializers = []
--> 873 self._initialize(args, kwds, add_initializers_to=initializers)
874 finally:
875 # At this point we know that the initialization is complete (or less
/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py in _initialize(self, args, kwds, add_initializers_to)
692 # Force the definition of the function for these arguments
693 self._concrete_variable_creation_fn = (
--> 694 self._variable_creation_fn # pylint: disable=protected-access
695 ._get_concrete_function_internal_garbage_collected(
696 *args, **kwds))
/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
174 """Returns a concrete function which cleans up its graph function."""
175 with self._lock:
--> 176 concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)
177 return concrete_function
178
/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py in _maybe_define_concrete_function(self, args, kwargs)
169 kwargs = {}
170
--> 171 return self._maybe_define_function(args, kwargs)
172
173 def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py in _maybe_define_function(self, args, kwargs)
396 kwargs = placeholder_bound_args.kwargs
397
--> 398 concrete_function = self._create_concrete_function(
399 args, kwargs, func_graph)
400
/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py in _create_concrete_function(self, args, kwargs, func_graph)
303
304 concrete_function = monomorphic_function.ConcreteFunction(
--> 305 func_graph_module.func_graph_from_py_func(
306 self._name,
307 self._python_function,
/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, create_placeholders)
1053
1054 _, original_func = tf_decorator.unwrap(python_func)
-> 1055 func_outputs = python_func(*func_args, **func_kwargs)
1056
1057 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py in wrapped_fn(*args, **kwds)
595 # the function a weak reference to itself to avoid a reference cycle.
596 with OptionalXlaContext(compile_with_xla):
--> 597 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
598 return out
599
/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py in autograph_handler(*args, **kwargs)
50 except Exception as e: # pylint:disable=broad-except
51 if hasattr(e, "ag_error_metadata"):
---> 52 raise e.ag_error_metadata.to_exception(e)
53 else:
54 raise
/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py in autograph_handler(*args, **kwargs)
39 """Calls a converted version of original_func."""
40 try:
---> 41 return api.converted_call(
42 original_func,
43 args,
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
437 try:
438 if kwargs is not None:
--> 439 result = converted_f(*effective_args, **kwargs)
440 else:
441 result = converted_f(*effective_args)
/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py in tf__train_function(iterator)
13 try:
14 do_return = True
---> 15 retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
16 except:
17 do_return = False
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
375
376 if not options.user_requested and conversion.is_allowlisted(f):
--> 377 return _call_unconverted(f, args, kwargs, options)
378
379 # internal_convert_user_code is for example turned off when issuing a dynamic
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py in _call_unconverted(f, args, kwargs, options, update_cache)
458 if kwargs is not None:
459 return f(*args, **kwargs)
--> 460 return f(*args)
461
462
/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py in step_function(model, iterator)
1320 )
1321 data = next(iterator)
-> 1322 outputs = model.distribute_strategy.run(run_step, args=(data,))
1323 outputs = reduce_per_replica(
1324 outputs,
/usr/local/lib/python3.10/dist-packages/tensorflow/python/distribute/distribute_lib.py in run(***failed resolving arguments***)
1671 fn = autograph.tf_convert(
1672 fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
-> 1673 return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
1674
1675 def reduce(self, reduce_op, value, axis):
/usr/local/lib/python3.10/dist-packages/tensorflow/python/distribute/distribute_lib.py in call_for_each_replica(self, fn, args, kwargs)
3248 kwargs = {}
3249 with self._container_strategy().scope():
-> 3250 return self._call_for_each_replica(fn, args, kwargs)
3251
3252 def _call_for_each_replica(self, fn, args, kwargs):
/usr/local/lib/python3.10/dist-packages/tensorflow/python/distribute/distribute_lib.py in _call_for_each_replica(self, fn, args, kwargs)
4046 def _call_for_each_replica(self, fn, args, kwargs):
4047 with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0):
-> 4048 return fn(*args, **kwargs)
4049
4050 def _reduce_to(self, reduce_op, value, destinations, options):
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
691 except Exception as e: # pylint:disable=broad-except
692 if hasattr(e, 'ag_error_metadata'):
--> 693 raise e.ag_error_metadata.to_exception(e)
694 else:
695 raise
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
688 try:
689 with conversion_ctx:
--> 690 return converted_call(f, args, kwargs, options=options)
691 except Exception as e: # pylint:disable=broad-except
692 if hasattr(e, 'ag_error_metadata'):
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
375
376 if not options.user_requested and conversion.is_allowlisted(f):
--> 377 return _call_unconverted(f, args, kwargs, options)
378
379 # internal_convert_user_code is for example turned off when issuing a dynamic
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py in _call_unconverted(f, args, kwargs, options, update_cache)
457
458 if kwargs is not None:
--> 459 return f(*args, **kwargs)
460 return f(*args)
461
/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py in run_step(data)
1301
1302 def run_step(data):
-> 1303 outputs = model.train_step(data)
1304 # Ensure counter is updated only if `train_step` succeeds.
1305 with tf.control_dependencies(_minimum_control_deps(outputs)):
/usr/local/lib/python3.10/dist-packages/transformers/modeling_tf_utils.py in train_step(self, data)
1679 self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
1680
-> 1681 self.compiled_metrics.update_state(y, y_pred, sample_weight)
1682 # Collect metrics to return
1683 return_metrics = {}
/usr/local/lib/python3.10/dist-packages/keras/src/engine/compile_utils.py in update_state(self, y_true, y_pred, sample_weight)
603 if metric_obj is None:
604 continue
--> 605 metric_obj.update_state(y_t, y_p, sample_weight=mask)
606
607 for weighted_metric_obj in weighted_metric_objs:
/usr/local/lib/python3.10/dist-packages/keras/src/utils/metrics_utils.py in decorated(metric_obj, *args, **kwargs)
75
76 with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
---> 77 update_op = update_state_fn(*args, **kwargs)
78 if update_op is not None: # update_op will be None in eager execution.
79 metric_obj.add_update(update_op)
/usr/local/lib/python3.10/dist-packages/keras/src/metrics/base_metric.py in update_state_fn(*args, **kwargs)
138 obj_update_state, control_status
139 )
--> 140 return ag_update_state(*args, **kwargs)
141
142 else:
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
691 except Exception as e: # pylint:disable=broad-except
692 if hasattr(e, 'ag_error_metadata'):
--> 693 raise e.ag_error_metadata.to_exception(e)
694 else:
695 raise
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
688 try:
689 with conversion_ctx:
--> 690 return converted_call(f, args, kwargs, options=options)
691 except Exception as e: # pylint:disable=broad-except
692 if hasattr(e, 'ag_error_metadata'):
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
437 try:
438 if kwargs is not None:
--> 439 result = converted_f(*effective_args, **kwargs)
440 else:
441 result = converted_f(*effective_args)
/usr/local/lib/python3.10/dist-packages/keras_nlp/src/metrics/bleu.py in tf__update_state(self, y_true, y_pred, sample_weight)
104 return fscope_1.ret(retval__1, do_return_1)
105 y_true = ag__.converted_call(ag__.ld(validate_and_fix_rank), (ag__.ld(y_true), 'y_true', 1), None, fscope)
--> 106 y_pred = ag__.converted_call(ag__.ld(validate_and_fix_rank), (ag__.ld(y_pred), 'y_pred', 0), None, fscope)
107 y_true = ag__.converted_call(ag__.ld(self)._tokenizer, (ag__.ld(y_true),), None, fscope)
108 y_pred = ag__.converted_call(ag__.ld(self)._tokenizer, (ag__.ld(y_pred),), None, fscope)
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
329 if conversion.is_in_allowlist_cache(f, options):
330 logging.log(2, 'Allowlisted %s: from cache', f)
--> 331 return _call_unconverted(f, args, kwargs, options, False)
332
333 if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py in _call_unconverted(f, args, kwargs, options, update_cache)
458 if kwargs is not None:
459 return f(*args, **kwargs)
--> 460 return f(*args)
461
462
/usr/local/lib/python3.10/dist-packages/keras_nlp/src/metrics/bleu.py in validate_and_fix_rank(inputs, tensor_name, base_rank)
101 ag__.if_stmt(ag__.ld(inputs).shape.rank == ag__.ld(base_rank) + 2, if_body_2, else_body_2, get_state_2, set_state_2, ('do_return_1', 'retval__1'), 2)
102 ag__.if_stmt(ag__.ld(inputs).shape.rank == ag__.ld(base_rank) + 1, if_body_3, else_body_3, get_state_3, set_state_3, ('do_return_1', 'retval__1'), 2)
--> 103 ag__.if_stmt(ag__.ld(inputs).shape.rank == ag__.ld(base_rank), if_body_4, else_body_4, get_state_4, set_state_4, ('do_return_1', 'retval__1'), 2)
104 return fscope_1.ret(retval__1, do_return_1)
105 y_true = ag__.converted_call(ag__.ld(validate_and_fix_rank), (ag__.ld(y_true), 'y_true', 1), None, fscope)
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/operators/control_flow.py in if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
1267 _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
1268 else:
-> 1269 _py_if_stmt(cond, body, orelse)
1270
1271
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/operators/control_flow.py in _py_if_stmt(cond, body, orelse)
1320 def _py_if_stmt(cond, body, orelse):
1321 """Overload of if_stmt that executes a Python if statement."""
-> 1322 return body() if cond else orelse()
/usr/local/lib/python3.10/dist-packages/keras_nlp/src/metrics/bleu.py in else_body_4()
100 raise ag__.converted_call(ag__.ld(ValueError), (f'{ag__.ld(tensor_name)} must be of rank {ag__.ld(base_rank)}, {ag__.ld(base_rank) + 1} or {ag__.ld(base_rank) + 2}. Found rank: {ag__.ld(inputs).shape.rank}',), None, fscope_1)
101 ag__.if_stmt(ag__.ld(inputs).shape.rank == ag__.ld(base_rank) + 2, if_body_2, else_body_2, get_state_2, set_state_2, ('do_return_1', 'retval__1'), 2)
--> 102 ag__.if_stmt(ag__.ld(inputs).shape.rank == ag__.ld(base_rank) + 1, if_body_3, else_body_3, get_state_3, set_state_3, ('do_return_1', 'retval__1'), 2)
103 ag__.if_stmt(ag__.ld(inputs).shape.rank == ag__.ld(base_rank), if_body_4, else_body_4, get_state_4, set_state_4, ('do_return_1', 'retval__1'), 2)
104 return fscope_1.ret(retval__1, do_return_1)
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/operators/control_flow.py in if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
1267 _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
1268 else:
-> 1269 _py_if_stmt(cond, body, orelse)
1270
1271
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/operators/control_flow.py in _py_if_stmt(cond, body, orelse)
1320 def _py_if_stmt(cond, body, orelse):
1321 """Overload of if_stmt that executes a Python if statement."""
-> 1322 return body() if cond else orelse()
/usr/local/lib/python3.10/dist-packages/keras_nlp/src/metrics/bleu.py in else_body_3()
99 nonlocal do_return_1, retval__1
100 raise ag__.converted_call(ag__.ld(ValueError), (f'{ag__.ld(tensor_name)} must be of rank {ag__.ld(base_rank)}, {ag__.ld(base_rank) + 1} or {ag__.ld(base_rank) + 2}. Found rank: {ag__.ld(inputs).shape.rank}',), None, fscope_1)
--> 101 ag__.if_stmt(ag__.ld(inputs).shape.rank == ag__.ld(base_rank) + 2, if_body_2, else_body_2, get_state_2, set_state_2, ('do_return_1', 'retval__1'), 2)
102 ag__.if_stmt(ag__.ld(inputs).shape.rank == ag__.ld(base_rank) + 1, if_body_3, else_body_3, get_state_3, set_state_3, ('do_return_1', 'retval__1'), 2)
103 ag__.if_stmt(ag__.ld(inputs).shape.rank == ag__.ld(base_rank), if_body_4, else_body_4, get_state_4, set_state_4, ('do_return_1', 'retval__1'), 2)
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/operators/control_flow.py in if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
1267 _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
1268 else:
-> 1269 _py_if_stmt(cond, body, orelse)
1270
1271
/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/operators/control_flow.py in _py_if_stmt(cond, body, orelse)
1320 def _py_if_stmt(cond, body, orelse):
1321 """Overload of if_stmt that executes a Python if statement."""
-> 1322 return body() if cond else orelse()
/usr/local/lib/python3.10/dist-packages/keras_nlp/src/metrics/bleu.py in else_body_2()
98 def else_body_2():
99 nonlocal do_return_1, retval__1
--> 100 raise ag__.converted_call(ag__.ld(ValueError), (f'{ag__.ld(tensor_name)} must be of rank {ag__.ld(base_rank)}, {ag__.ld(base_rank) + 1} or {ag__.ld(base_rank) + 2}. Found rank: {ag__.ld(inputs).shape.rank}',), None, fscope_1)
101 ag__.if_stmt(ag__.ld(inputs).shape.rank == ag__.ld(base_rank) + 2, if_body_2, else_body_2, get_state_2, set_state_2, ('do_return_1', 'retval__1'), 2)
102 ag__.if_stmt(ag__.ld(inputs).shape.rank == ag__.ld(base_rank) + 1, if_body_3, else_body_3, get_state_3, set_state_3, ('do_return_1', 'retval__1'), 2)
ValueError: in user code:
File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 1338, in train_function *
return step_function(self, iterator)
File "/usr/local/lib/python3.10/dist-packages/keras_nlp/src/metrics/bleu.py", line 342, in validate_and_fix_rank *
raise ValueError(
ValueError: y_pred must be of rank 0, 1 or 2. Found rank: 3
And my code is here:
checkpoint = "google/mt5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = TFAutoModelForSeq2SeqLM.from_pretrained(checkpoint)
optimizer = AdamWeightDecay(learning_rate=2e-5, weight_decay_rate=0.01)
model.compile(optimizer=optimizer, metrics=["accuracy", keras_nlp.metrics.Bleu(tokenizer)])
callbacks = [
tf.keras.callbacks.ModelCheckpoint(f'models/{model_name}',save_best_only=True, save_weights_only=True),
tf.keras.callbacks.TensorBoard(log_dir=f'logs/{model_name}_logs')
]
model.fit(x=tf_train_set, validation_data=tf_val_set, epochs=5, callbacks=callbacks)
How to fix this problem, or is there something else can compute Bleu score?
Please help me, thanks!!ヾ(•ω•`)o