
I am trying to use GPT-2 in a codebase that is written for TensorFlow 1.x. However, I am running the code against a TF 2.x installation with the tf.disable_v2_behavior() flag. Without this flag, the pretrained GPT-2 model loads fine, but it fails to load when the flag is used. Here is my code:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # works fine without this line

from transformers import TFGPT2Model
model = TFGPT2Model.from_pretrained('gpt2')  # fails

Here is the error:

>>> TFGPT2Model.from_pretrained('gpt2')
2022-02-15 10:17:08.792655: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home1/07782/marefin/.local/lib/python3.8/site-packages/transformers/modeling_tf_utils.py", line 1467, in from_pretrained
    model(model.dummy_inputs)  # build the network with dummy inputs
  File "/home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_v1.py", line 783, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "/home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 695, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    /home1/07782/marefin/.local/lib/python3.8/site-packages/transformers/models/gpt2/modeling_tf_gpt2.py:628 call  *
        outputs = self.transformer(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:763 __call__  **
        self._maybe_build(inputs)
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:2084 _maybe_build
        self.build(input_shapes)
    /home1/07782/marefin/.local/lib/python3.8/site-packages/transformers/models/gpt2/modeling_tf_gpt2.py:241 build
        self.wpe = self.add_weight(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:441 add_weight
        variable = self._add_variable_with_custom_getter(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/training/tracking/base.py:810 _add_variable_with_custom_getter
        new_variable = getter(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_utils.py:127 make_variable
        return tf_variables.VariableV1(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:260 __call__
        return cls._variable_v1_call(*args, **kwargs)
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:206 _variable_v1_call
        return previous_getter(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:199 <lambda>
        previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/ops/variable_scope.py:2612 default_variable_creator
        return resource_variable_ops.ResourceVariable(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:264 __call__
        return super(VariableMetaclass, cls).__call__(*args, **kwargs)
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/ops/resource_variable_ops.py:1584 __init__
        self._init_from_args(
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/ops/resource_variable_ops.py:1722 _init_from_args
        initial_value = initial_value()
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/keras/initializers/initializers_v2.py:413 __call__
        dtype = _assert_float_dtype(_get_dtype(dtype))
    /home1/07782/marefin/.local/lib/python3.8/site-packages/tensorflow/python/keras/initializers/initializers_v2.py:948 _assert_float_dtype
        raise ValueError('Expected floating point type, got %s.' % dtype)

    ValueError: Expected floating point type, got <dtype: 'int32'>.

I am using TF 2.5 with transformers v4.12.5. Is there any way to make this work with TF 2 behavior disabled?

1 Answer


Calling tf.compat.v1.disable_eager_execution() instead of tf.disable_v2_behavior() makes the pretrained model load fine, and the existing TF 1.x code doesn't raise any errors with this flag either.