I'm training a model in PyTorch 2.0.0. I built the BERT + Linear model shown below and set `device = torch.device("mps")`. An error occurs at the lines where `input_ids`, `attention_mask`, `token_type_ids`, and `labels` are moved to the device. Thanks for your help!
I expected the model to run on MPS. It does run on the CPU, but that takes a long time, so I want to run it on MPS to make it faster. Thanks!
Model (BERT + Linear):
class Model(torch.nn.Module):
    """BERT feature extractor followed by a linear sentiment classifier.

    The BERT encoder is run under ``torch.no_grad()``, so only ``self.fc``
    receives gradients through this module.
    """

    def __init__(self, pretrained_model=None):
        """Build the classification head.

        Args:
            pretrained_model: Optional BERT-style encoder. When given, it is
                stored as ``self.bert``. If it is a ``torch.nn.Module``, it
                becomes a registered submodule, so ``Model(...).to(device)``
                (e.g. ``"mps"``) and ``.eval()`` also apply to it. When omitted
                (the original behaviour), ``forward`` uses the module-level
                global ``pretrained`` instead. That global is not a submodule,
                so it must be moved to the same device separately
                (``pretrained.to(device)``). Otherwise its weights stay where
                they were while the inputs live on the new device.
        """
        super().__init__()
        self.bert = pretrained_model
        # Fully connected layer: input is 768-dim (the bert-base-chinese output
        # size), output is 2-dim (the number of sentiment classes).
        self.fc = torch.nn.Linear(768, 2)

    # Forward pass
    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return class probabilities for a batch.

        Args:
            input_ids, attention_mask, token_type_ids: Passed straight to the
                BERT encoder as keyword arguments.

        Returns:
            Tensor of softmax probabilities over the 2 classes, one row per
            batch element.
        """
        bert = self.bert if self.bert is not None else pretrained
        # The encoder is used frozen: no gradients are tracked through it.
        with torch.no_grad():
            out = bert(input_ids=input_ids,
                       attention_mask=attention_mask,
                       token_type_ids=token_type_ids)
        # Use the hidden state of token 0 as the classifier input.
        # Assumes last_hidden_state is (batch, seq_len, 768) as for HF BERT.
        out = self.fc(out.last_hidden_state[:, 0])
        # Normalise logits into probabilities (argmax is taken by the caller).
        # NOTE(review): if training uses CrossEntropyLoss on this output, the
        # softmax is applied twice — TODO confirm the training loss.
        out = out.softmax(dim=1)
        return out
Predicting the sentiment of text:
# Prediction loop: run the sentiment model over every batch of bar_predict.
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(bar_predict):
    print('当前批次:', i)
    # Cast each batch tensor to int64 and move it to config.device (e.g. MPS),
    # in the same order as the batch unpacking above.
    # NOTE(review): every weight sentimodel uses must also be on config.device.
    # The BERT encoder that Model.forward calls is the global `pretrained`
    # unless it is registered on the model, and sentimodel.to(...) does not
    # move an unregistered module — confirm pretrained.to(config.device) ran.
    input_ids, attention_mask, token_type_ids, labels = (
        t.to(device=config.device, dtype=torch.long)
        for t in (input_ids, attention_mask, token_type_ids, labels)
    )
    # Prediction only: no gradient tracking is needed.
    with torch.no_grad():
        probs = sentimodel(input_ids=input_ids,
                           attention_mask=attention_mask,
                           token_type_ids=token_type_ids)
    # Index of the highest-probability class for each row.
    # NOTE(review): labels is moved to the device but not used in this
    # snippet — presumably compared against `out` further on; verify.
    out = probs.argmax(dim=1)