news-analyzer / pipeline.py
elozano's picture
Limit generated sequences
e567dcc
raw
history blame
628 Bytes
from typing import Dict, Optional
from transformers import TextClassificationPipeline
class NewsPipeline(TextClassificationPipeline):
    """Text-classification pipeline that decorates its prediction with an emoji.

    The pipeline classifies a news item given as a headline plus an optional
    body, and adds an ``"emoji"`` entry looked up from the predicted label.
    """

    def __init__(self, emojis: Dict[str, str], **kwargs) -> None:
        """Store the label-to-emoji table and initialise the base pipeline.

        Args:
            emojis: Mapping indexed by the predicted ``"label"``; the matching
                value is returned under the ``"emoji"`` key of each result.
            **kwargs: Forwarded unchanged to ``TextClassificationPipeline``.
        """
        # Assigned before delegating to the base constructor; the original
        # ordering is preserved on purpose.
        self.emojis = emojis
        super().__init__(**kwargs)

    def __call__(
        self, headline: str, content: Optional[str] = None
    ) -> Dict[str, object]:
        """Classify a single news item.

        Args:
            headline: Headline of the news item.
            content: Optional body text. When it is truthy it is appended to
                the headline, separated by the tokenizer's ``sep_token``
                surrounded by single spaces. ``None`` or an empty string
                means only the headline is classified.

        Returns:
            The first prediction produced by the base pipeline, copied into a
            new dict with an additional ``"emoji"`` key.

        Raises:
            KeyError: If the predicted label has no entry in ``self.emojis``.
        """
        if content:
            # Same text as `f" {sep} ".join([headline, content])`.
            text = f"{headline} {self.tokenizer.sep_token} {content}"
        else:
            text = headline

        # Padding/truncation are passed through to the base pipeline call so
        # that over-long inputs are cut rather than rejected; only the first
        # returned prediction is used.
        prediction = super().__call__(text, padding=True, truncation=True)[0]
        return {**prediction, "emoji": self.emojis[prediction["label"]]}