diff --git a/musecoco/1-text2attribute_model/model.py b/musecoco/1-text2attribute_model/model.py index 3dfef69..1f32955 100644 --- a/musecoco/1-text2attribute_model/model.py +++ b/musecoco/1-text2attribute_model/model.py @@ -436,7 +436,7 @@ class BertForAttributModel(BertPreTrainedModel): i += 1 logits[k] = self.classifieratt[idx](self.dropout(pooled_outputs[k])) - total_loss = torch.Tensor(0).to(input_ids.device) + total_loss = torch.Tensor(0.0).to(input_ids.device) loss = {} if labels is not None: for k in labels.keys():