0

I'm training a model in PyTorch 2.0.0. I built the BERT + Linear model shown below and set device = torch.device("mps"). The error occurs at the code handling input_ids, attention_mask, token_type_ids, and labels. Thanks for your help!

I expected the model to run on MPS. It does run on the CPU, but that takes a long time, so I want to run it on MPS to reduce the running time. Thanks.

Model: BERT + Linear

class Model(torch.nn.Module):
    """Sentiment classifier: a linear head on top of the external ``pretrained`` encoder.

    The encoder is read from the enclosing scope as ``pretrained`` and is
    called under ``torch.no_grad()``, so only ``self.fc`` receives gradients
    from this module's forward pass.
    """

    def __init__(self):
        super().__init__()
        # Fully connected layer.
        # Input 768: output dimension of bert-base-chinese.
        # Output 2: number of sentiment classes.
        self.fc = torch.nn.Linear(768, 2)

    # Forward pass
    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return softmax scores over the 2 output classes for a batch.

        Args:
            input_ids: forwarded unchanged to ``pretrained``.
            attention_mask: forwarded unchanged to ``pretrained``.
            token_type_ids: forwarded unchanged to ``pretrained``.

        Returns:
            The output of ``self.fc`` after a softmax along dimension 1.
        """
        # NOTE(review): ``pretrained`` is defined outside this class; presumably
        # it has to live on the same device as the inputs (e.g. moved with
        # ``pretrained.to(device)``) -- confirm at the call site.
        with torch.no_grad():
            # Run the BERT encoder without tracking gradients.
            out = pretrained(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)

        # Use the features at position 0 of each sequence as the input of the
        # fully connected layer.
        out = self.fc(out.last_hidden_state[:, 0])
        # Normalise the scores along dimension 1; picking the most likely
        # class (argmax) is left to the caller.
        out = out.softmax(dim=1)
        return out

Predicting the sentiment of text:

# Run the sentiment model over every batch yielded by ``bar_predict``.
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(bar_predict):
    print('当前批次:', i)

    # Move the batch to the configured device (MPS): each tensor is cast to
    # int64 first and then transferred to ``config.device``.
    input_ids = input_ids.to(torch.long).to(config.device)
    attention_mask = attention_mask.to(torch.long).to(config.device)
    token_type_ids = token_type_ids.to(torch.long).to(config.device)
    labels = labels.to(torch.long).to(config.device)

    with torch.no_grad():  # No gradients are computed during prediction.
        out = sentimodel(input_ids=input_ids,
                         attention_mask=attention_mask,
                         token_type_ids=token_type_ids)
        # Index of the largest value along dimension 1, i.e. the predicted class.
        out = out.argmax(dim=1)
ErShi
  • 1

0 Answers0