
The following is my Sentiment Analyser:

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import DistilBertTokenizer, DistilBertModel

PRE_TRAINED_MODEL_NAME = 'distilbert-base-cased'
db_model = DistilBertModel.from_pretrained(PRE_TRAINED_MODEL_NAME, return_dict=False)
tokenizer = DistilBertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME, return_dict=False, return_tensors="pt")

class SentimentClassifier(nn.Module):
  def __init__(self, n_classes):
    super(SentimentClassifier, self).__init__()
    self.db = DistilBertModel.from_pretrained(PRE_TRAINED_MODEL_NAME, return_dict=False)
    self.drop = nn.Dropout(p=0.3)
    self.out = nn.Linear(self.db.config.hidden_size, n_classes)
  
  def forward(self, input_ids, attention_mask):
    pooled_output = self.db(
      input_ids=input_ids,
      attention_mask=attention_mask
    )
    output = self.drop(pooled_output)
    return self.out(output)
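
For context, model, input_ids, and attention_mask are created roughly as follows (the example sentence and n_classes=3 are placeholders, not my actual values):

model = SentimentClassifier(n_classes=3)

encoding = tokenizer(
    "The movie was great!",  # placeholder sentence
    truncation=True,
    return_tensors="pt"
)
input_ids = encoding["input_ids"]            # shape: (1, seq_len)
attention_mask = encoding["attention_mask"]  # shape: (1, seq_len)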

When I try to run:

F.softmax(model(input_ids, attention_mask), dim=1)

I am getting the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-191-96f6522cbd43> in <module>
----> 1 F.softmax(model(input_ids, attention_mask), dim=1)

(4 frames elided)
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py in dropout(input, p, training, inplace)
   1250     if p < 0.0 or p > 1.0:
   1251         raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
-> 1252     return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
   1253 
   1254 

TypeError: dropout(): argument 'input' (position 1) must be Tensor, not tuple

I applied the solution used in the BERT model (i.e., return_dict=False and return_tensors="pt"), but I am still running into this error. Any solution to this would be highly appreciated.

– wamika
  • It seems `pooled_output` is a tuple but `self.drop` requires it to be a tensor. Have you checked what the type of `pooled_output` is? – kmkurn Dec 08 '22 at 03:06
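
A minimal sketch of the fix the comment hints at, assuming the standard behaviour of DistilBertModel with return_dict=False: the forward call returns a tuple whose first element is the last hidden state, and DistilBERT has no pooler layer, so the [CLS] token embedding is taken manually here (one common pooling choice, not the only one):

  def forward(self, input_ids, attention_mask):
    last_hidden_state = self.db(
      input_ids=input_ids,
      attention_mask=attention_mask
    )[0]                                     # unpack the tuple returned with return_dict=False
    pooled_output = last_hidden_state[:, 0]  # [CLS] token embedding, shape (batch, hidden_size)
    output = self.drop(pooled_output)
    return self.out(output)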

0 Answers