I am trying to use GPT-2 in a codebase written for TensorFlow 1.x, but running against TF 2.x binaries with tf.disable_v2_behavior(). Without the tf.disable_v2_behavior() call, the pretrained GPT-2 model loads fine; with it, the model fails to load. Here is my code:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # works fine without this line
from transformers import TFGPT2Model
model = TFGPT2Model.from_pretrained('gpt2')  # fails
Here is the error:
>>> TFGPT2Model.from_pretrained('gpt2')
2022-02-15 10:17:08.792655: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home1/07782/marefin/.local/lib/python3.8/site-packages/transformers/modeling_tf_utils.py", line 1467, in from_pretrained
    model(model.dummy_inputs)  # build the network with dummy inputs
  File "/home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_v1.py", line 783, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "/home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 695, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    /home1/07782/marefin/.local/lib/python3.8/site-packages/transformers/models/gpt2/modeling_tf_gpt2.py:628 call  *
        outputs = self.transformer(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:763 __call__  **
        self._maybe_build(inputs)
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:2084 _maybe_build
        self.build(input_shapes)
    /home1/07782/marefin/.local/lib/python3.8/site-packages/transformers/models/gpt2/modeling_tf_gpt2.py:241 build
        self.wpe = self.add_weight(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:441 add_weight
        variable = self._add_variable_with_custom_getter(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/training/tracking/base.py:810 _add_variable_with_custom_getter
        new_variable = getter(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_utils.py:127 make_variable
        return tf_variables.VariableV1(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:260 __call__
        return cls._variable_v1_call(*args, **kwargs)
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:206 _variable_v1_call
        return previous_getter(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:199 <lambda>
        previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/ops/variable_scope.py:2612 default_variable_creator
        return resource_variable_ops.ResourceVariable(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:264 __call__
        return super(VariableMetaclass, cls).__call__(*args, **kwargs)
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/ops/resource_variable_ops.py:1584 __init__
        self._init_from_args(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/ops/resource_variable_ops.py:1722 _init_from_args
        initial_value = initial_value()
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/keras/initializers/initializers_v2.py:413 __call__
        dtype = _assert_float_dtype(_get_dtype(dtype))
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/keras/initializers/initializers_v2.py:948 _assert_float_dtype
        raise ValueError('Expected floating point type, got %s.' % dtype)

    ValueError: Expected floating point type, got <dtype: 'int32'>.
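My reading of the traceback (an assumption on my part, not verified against the Keras internals) is that with v2 behaviour disabled, Keras instantiates its v1 Layer class, which infers its dtype from the first inputs it sees. The dummy inputs are int32 token ids, so the wpe position-embedding weight gets created as int32, and the float-only initializer rejects it. The following standalone snippet seems to trigger the same check; the explicit dtype='int32' stands in for the inferred int32 dtype:

import tensorflow as tf           # TF 2.x root namespace
import tensorflow.compat.v1 as tf1

tf1.disable_v2_behavior()

# Forcing dtype='int32' mimics what I suspect happens when the v1 Layer
# infers its dtype from the int32 input_ids in the dummy inputs:
layer = tf.keras.layers.Layer(dtype='int32')
layer.add_weight(name='w', shape=(2, 2),
                 initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# -> ValueError: Expected floating point type, got <dtype: 'int32'>.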
I am using TF 2.5 with transformers v4.12.5. Is there any way to make this work with TF v2 behaviour disabled?
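One workaround I am considering (untested beyond the load itself, since the rest of the codebase expects graph mode) is to load the pretrained model while v2 behaviour is still enabled and only disable it afterwards:

import tensorflow.compat.v1 as tf
from transformers import TFGPT2Model

model = TFGPT2Model.from_pretrained('gpt2')  # loads fine in eager mode
tf.disable_v2_behavior()  # switch to v1 behaviour for the TF 1.x code
# The model was built eagerly, so I am not sure it will interoperate
# cleanly with the graph-mode code that follows.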