from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import re
from fastapi import FastAPI
from pydantic import BaseModel, Field

app = FastAPI()

# Note: Most code is identical to the training notebook.
tokenizer = AutoTokenizer.from_pretrained('google/mt5-small')
model = AutoModelForSeq2SeqLM.from_pretrained('yonatan-h/amharic-summarizer')  # .cuda() to run on GPU

text_len = 512     # max input length in tokens, matching training
summary_len = 128  # max summary length in tokens

class SummarizeDto(BaseModel):
    text: str = Field(..., description="The text you want to summarize", examples=["ግጭት በሚካሄድባቸው የአማራ እና ኦሮሚያ ክልሎች ከፍርድ ውጭ የሚፈጸሙ ግድያዎች በአሳሳቢነት መቀጠላቸውን የኢትዮጵያ ሰብዓዊ መብቶች ኮሚሽን ..."])

def encode(text, length):
    """Tokenize text, padded/truncated to a fixed length; returns a 1-D id tensor."""
    encoded = tokenizer.encode(
        text, return_tensors='pt', padding="max_length", max_length=length, truncation=True
    )  # .cuda() to move the ids to GPU
    return encoded[0]

def decode(encoded, skip_special=False):
    """Turn token ids back into text, optionally stripping special tokens."""
    decoded = tokenizer.decode(encoded, skip_special_tokens=skip_special)
    if skip_special:
        # Also drop any leftover sentinel tokens such as <extra_id_0>
        decoded = re.sub(r"<[^>]+>", "", decoded).strip()
    return decoded

def summarize_multiple(text_encodeds, summary_len=summary_len, model=model):
    """Generate summaries for a batch of encoded texts with beam search."""
    outputs = model.generate(
        text_encodeds,
        min_length=int(summary_len * 0.5),
        max_length=int(summary_len * 2),
        num_beams=10,
        no_repeat_ngram_size=2,
    )
    return [decode(output, skip_special=True) for output in outputs]

def summarize(text, text_len=text_len, summary_len=summary_len, model=model):
    """Summarize a single text by running it through the model as a batch of one."""
    encodeds = [encode(text, text_len).unsqueeze(0)]
    encodeds = torch.cat(encodeds)  # .cuda() to move the batch to GPU
    return summarize_multiple(encodeds, summary_len, model)[0]
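
# A minimal sketch of batched use of summarize_multiple, kept commented out so
# nothing runs at import time; the Amharic strings are placeholders:
#
#   batch = torch.cat([encode(t, text_len).unsqueeze(0)
#                      for t in ["ጽሑፍ አንድ ...", "ጽሑፍ ሁለት ..."]])
#   summaries = summarize_multiple(batch)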

# In case a text longer than the average training example is submitted, split
# it into sentence-bounded chunks. Character count stands in as a rough proxy
# for token count, and splitting is on "." only (the Amharic full stop "።"
# is not handled).
def chunkify(text, text_len):
    texts = []
    sentences = text.split(".")
    chunk = ""
    for sentence in sentences:
        if not sentence:
            continue
        if len(chunk) + len(sentence) > text_len:
            texts.append(chunk)
            chunk = ""
        chunk += sentence + "."  # restore the "." dropped by split()
    if chunk:
        texts.append(chunk)
    return texts
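
# For instance (illustrative):
#   chunkify("aaa. bbb. ccc.", 8)  ->  ["aaa. bbb.", " ccc."]
# Chunks can slightly exceed text_len, since the restored "." is not counted
# and a single overlong sentence is never split.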

@app.get("/")
def greet_json():
    return {"Hello": "World!"}

@app.post("/summarize")  # route path is an assumption
def summarise(request: SummarizeDto):
    text = request.text
    summary = ""
    for chunk in chunkify(text, text_len):
        summary += " " + summarize(chunk)
    return {"summary": summary.strip()}