Text Classification
Transformers
Safetensors
English
HHEMv2Config
custom_code
Miaoran000 commited on
Commit
6f7b340
1 Parent(s): ade58fc

update for pipeline

Browse files
Files changed (2) hide show
  1. config.json +2 -1
  2. modeling_hhem_v2.py +9 -2
config.json CHANGED
@@ -8,5 +8,6 @@
8
  },
9
  "model_type": "HHEMv2Config",
10
  "torch_dtype": "float32",
11
- "transformers_version": "4.39.3"
 
12
  }
 
8
  },
9
  "model_type": "HHEMv2Config",
10
  "torch_dtype": "float32",
11
+ "transformers_version": "4.39.3",
12
+ "id2label": {"0": "hallucinated", "1": "consistent"}
13
  }
modeling_hhem_v2.py CHANGED
@@ -45,8 +45,15 @@ class HHEMv2ForSequenceClassification(PreTrainedModel):
45
  # combined_model = PeftModel.from_pretrained(base_model, checkpoint, is_trainable=False)
46
  # self.t5 = combined_model
47
 
48
- def forward(self, **kwargs):
49
- return self.t5(**kwargs)
 
 
 
 
 
 
 
50
 
51
  def predict(self, text_pairs):
52
  tokenizer = self.tokenzier
 
45
  # combined_model = PeftModel.from_pretrained(base_model, checkpoint, is_trainable=False)
46
  # self.t5 = combined_model
47
 
48
+ def forward(self, **kwargs): # To cope with `text-classiication` pipeline
49
+ self.t5.eval()
50
+ with torch.no_grad():
51
+ outputs = self.t5(**kwargs)
52
+ logits = outputs.logits
53
+ logits = logits[:, 0, :]
54
+ outputs.logits = logits
55
+ return outputs
56
+ # return self.t5(**kwargs)
57
 
58
  def predict(self, text_pairs):
59
  tokenizer = self.tokenzier