I'm assuming you have tensorflow and tensorflow_hub installed, and that you have already downloaded the model.
IMPORTANT: I'm assuming you're looking at version 4, https://tfhub.dev/google/universal-sentence-encoder/4. There's no guarantee the object graph looks the same in other versions; modifications will likely be needed.
Find its location on disk: it's somewhere under /tmp/tfhub_modules unless you set the TFHUB_CACHE_DIR environment variable (Windows/Mac use different locations). The directory should contain a file called saved_model.pb, which is the model serialized using Protocol Buffers.
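If you don't want to dig through the cache by hand, here is a minimal sketch that lists candidate model directories. It assumes tensorflow_hub's default cache is a tfhub_modules folder inside the system temp directory (/tmp/tfhub_modules on Linux) and that TFHUB_CACHE_DIR overrides it. The helper names tfhub_cache_dir and find_saved_models are hypothetical, not part of tensorflow_hub.

```python
import os
import tempfile
from pathlib import Path


def tfhub_cache_dir():
    """Return the assumed TF Hub cache directory.

    Assumption: $TFHUB_CACHE_DIR wins if set; otherwise
    <system temp dir>/tfhub_modules (e.g. /tmp/tfhub_modules on Linux).
    """
    env = os.environ.get("TFHUB_CACHE_DIR")
    if env:
        return Path(env)
    return Path(tempfile.gettempdir()) / "tfhub_modules"


def find_saved_models(cache_dir=None):
    """Return every directory under cache_dir that contains a saved_model.pb."""
    root = Path(cache_dir) if cache_dir is not None else tfhub_cache_dir()
    if not root.is_dir():
        return []
    # The parent of each saved_model.pb is a directory you can use as MODEL_PATH.
    return sorted(p.parent for p in root.rglob("saved_model.pb"))


if __name__ == "__main__":
    for model_dir in find_saved_models():
        print(model_dir)
```

Each printed directory is a candidate for MODEL_PATH in the snippet below; the hashed directory name is what tensorflow_hub generates for the module handle.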
Unfortunately, the dictionary is serialized inside the model's Protocol Buffers file and not as an external asset, so we'll have to load the model and get the variable from it.
The strategy is to use tensorflow's code to deserialize the file, and then travel down the serialized object tree all the way to the dictionary.
import importlib
MODEL_PATH = 'path/to/model/dir' # e.g. '/tmp/tfhub_modules/063d866c06683311b44b4992fd46003be952409c/'
# Use the tensorflow internal Protobuf loader. A regular import statement will fail.
loader_impl = importlib.import_module('tensorflow.python.saved_model.loader_impl')
saved_model = loader_impl.parse_saved_model(MODEL_PATH)
# reach into the object graph to get the tensor
graph = saved_model.meta_graphs[0].graph_def
function = graph.library.function
# the index 5 matches this model version; other versions may need a different index
node_type, node_value = function[5].node_def
# print(node_type) shows it is named "text_preprocessor/hash_table"
# and gives some insight into the branch of the object graph we're looking at
words_tensor = node_value.attr.get("value").tensor
word_list = [i.decode('utf-8') for i in words_tensor.string_val]
print(len(word_list)) # -> 400004
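Hard-coding function[5] ties the snippet to this exact graph layout. Below is a minimal sketch of a more defensive lookup that scans every node of every function instead. It assumes (not verified beyond version 4) that the vocabulary is stored, like above, as a tensor in a node's "value" attr with its words in string_val, and that it is the largest such constant. find_string_constants is a hypothetical helper name.

```python
def find_string_constants(graph_def, min_len=1):
    """Yield (function_index, node_name, strings) for nodes holding string constants.

    Walks graph_def.library.function and, for each node, checks for an attr
    named "value" whose tensor has at least min_len entries in string_val.
    """
    for index, function in enumerate(graph_def.library.function):
        for node in function.node_def:
            # Membership test first: indexing a missing key on a protobuf map
            # would insert an empty entry.
            if "value" not in node.attr:
                continue
            strings = node.attr["value"].tensor.string_val
            if len(strings) >= min_len:
                yield index, node.name, list(strings)
```

With saved_model parsed as above, max(find_string_constants(saved_model.meta_graphs[0].graph_def), key=lambda c: len(c[2])) should give the candidate with the most strings; checking that its function index and node name line up with function[5] is a quick sanity check before decoding the bytes.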
Some resources that helped:
- A GitHub issue relating to changing the vocabulary
- A Tensorflow Google-group thread linked from the issue
Extra Notes
Despite what the GitHub issue may lead you to think, the 400k words here are not the GloVe 400k vocabulary. You can verify this by downloading the GloVe 6B embeddings, extracting glove.6B.50d.txt, and then using the following code to compare the two vocabularies:
with open('/path/to/glove.6B.50d.txt', encoding='utf-8') as f:
    glove_vocabulary = set(line.strip().split(maxsplit=1)[0] for line in f)
USE_vocabulary = set(word_list) # from above
print(len(USE_vocabulary - glove_vocabulary)) # -> 281150
Inspecting the different vocabularies is interesting in and of itself, e.g. why does GloVe have an entry for '287.9'?
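For poking at the differences in both directions, here is a small sketch. It reuses the same first-field parsing as the snippet above and adds a few sorted samples from each side; the function names read_glove_vocabulary and compare_vocabularies are hypothetical.

```python
def read_glove_vocabulary(path):
    """Return the set of tokens in a GloVe text file (first field on each line)."""
    with open(path, encoding="utf-8") as f:
        # Skip blank lines so split() never returns an empty list.
        return {line.strip().split(maxsplit=1)[0] for line in f if line.strip()}


def compare_vocabularies(a, b, sample_size=10):
    """Summarize how two vocabulary sets overlap, with sorted examples of each difference."""
    only_a = a - b
    only_b = b - a
    return {
        "shared": len(a & b),
        "only_in_first": len(only_a),
        "only_in_second": len(only_b),
        "sample_only_in_first": sorted(only_a)[:sample_size],
        "sample_only_in_second": sorted(only_b)[:sample_size],
    }
```

For example, compare_vocabularies(set(word_list), read_glove_vocabulary('/path/to/glove.6B.50d.txt')) reports the same USE-minus-GloVe difference as above under "only_in_first", plus the reverse direction and some examples to eyeball.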