When I create a PyTorch DataLoader and try to train the model, I get this UserWarning:

    /usr/local/lib/python3.10/dist-packages/sentence_transformers/SentenceTransformer.py:547: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:245.)
      labels = torch.tensor(labels)

    from torch.utils.data import DataLoader
    from sentence_transformers import losses
    from sentence_transformers import ParallelSentencesDataset
    from sentence_transformers import models
    from sentence_transformers import SentenceTransformer
    
    xlmr = models.Transformer('xlm-roberta-base')
    pooler = models.Pooling(
        xlmr.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True
    )
    
    student = SentenceTransformer(modules=[xlmr, pooler])
    teacher = SentenceTransformer('paraphrase-distilroberta-base-v2')
    
    data = ParallelSentencesDataset(
        student_model=student,
        teacher_model=teacher,
        batch_size=32,
        use_embedding_cache=True
    )
    data.load_data('/path/to/somefile', max_sentence_length=512)
    
    loader = DataLoader(data, shuffle=True, batch_size=32)
    loss = losses.MSELoss(model=student)
      
    epochs = 1
    student.fit(
        train_objectives=[(loader, loss)],
        epochs=epochs,
        warmup_steps=int(len(loader) * epochs * 0.1),  # warm up over the first 10% of training steps
        output_path='./xlmr-ted',
        optimizer_params={'lr': 2e-5, 'eps': 1e-6},
        save_best_model=True,
        show_progress_bar=True
    )

Can I convert the DataLoader() dataset to a tensor, or is there a better way of approaching the issue?

LeMoussel

1 Answer

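The correct way to create a tensor from a single numpy array is tensor = torch.from_numpy(array). For a list of arrays, as in this warning, first combine them into one numpy.ndarray (for example with numpy.array()), then convert that.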
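As a rough illustration of that conversion, here is a minimal sketch. The `embeddings` list and its 32 x 768 shape are made-up stand-ins for whatever the library collects internally. It assumes all arrays in the list have the same shape and dtype.

```python
import numpy as np
import torch

# Hypothetical stand-in for a list of per-sentence embedding arrays.
embeddings = [np.random.rand(768).astype(np.float32) for _ in range(32)]

# torch.tensor(embeddings) would trigger the "extremely slow" warning.
# Instead, merge the list into one contiguous ndarray first...
batch = np.stack(embeddings)  # shape (32, 768)

# ...then wrap it; from_numpy shares memory with `batch` instead of copying.
labels = torch.from_numpy(batch)
```

Because torch.from_numpy shares memory with the source array, in-place changes to `batch` show up in `labels`. Use torch.tensor(batch) instead if an independent copy is needed.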

The problem is inside the sentence_transformers library itself, though, not in your code, so either you learn to live with this warning or you patch that line in their code yourself.
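If living with the warning is the choice, one option is to filter it out with Python's standard warnings module. This is only a sketch: the helper name `silence_slow_tensor_warning` is made up for illustration, and it assumes the message text matches the one quoted in the question.

```python
import warnings

# Regex matched against the start of the warning message (text taken from the question).
SLOW_TENSOR_MSG = r"Creating a tensor from a list of numpy\.ndarrays is extremely slow"


def silence_slow_tensor_warning() -> None:
    """Ignore only this specific UserWarning; other warnings still show."""
    warnings.filterwarnings("ignore", message=SLOW_TENSOR_MSG, category=UserWarning)


# Call once before training starts, e.g. before student.fit(...).
silence_slow_tensor_warning()
```

This hides the message but does not make the underlying conversion any faster.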

qmeeus