import re

import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

app = FastAPI()
# Note: most of this code is identical to the training notebook.
tokenizer = AutoTokenizer.from_pretrained('google/mt5-small')
model = AutoModelForSeq2SeqLM.from_pretrained('yonatan-h/amharic-summarizer')  # .cuda()

text_len = 512     # maximum input length in tokens (ideally 512)
summary_len = 128  # target summary length in tokens (ideally 128)
batch_size = 8     # 64, 24, 16 or 8 depending on GPU usage (carried over from the notebook; unused here)
class SummarizeDto(BaseModel):
    text: str = Field(..., description="The text you want to summarize",
                      examples=["ግጭት በሚካሄድባቸው የአማራ እና ኦሮሚያ ክልሎች ከፍርድ ውጭ የሚፈጸሙ ግድያዎች በአሳሳቢነት መቀጠላቸውን የኢትዮጵያ ሰብዓዊ መብቶች ኮሚሽን ..."])
def encode(text, length):
    """Tokenize text to a fixed-length (padded/truncated) tensor of token ids."""
    encoded = tokenizer.encode(
        text, return_tensors='pt', padding="max_length", max_length=length, truncation=True
    )  # .cuda()
    return encoded[0]
def decode(encoded, skip_special=False):
    """Detokenize ids back to text, optionally removing special tokens."""
    decoded = tokenizer.decode(encoded, skip_special_tokens=skip_special)
    if skip_special:
        # Strip any leftover angle-bracket tokens the tokenizer did not remove.
        decoded = re.sub(r"<[^>]+>", "", decoded).strip()
    return decoded
def summarize_multiple(text_encodeds, summary_len=summary_len, model=model):
    """Generate summaries for a batch of encoded texts with beam search."""
    outputs = model.generate(
        text_encodeds,
        min_length=int(summary_len * 0.5),
        max_length=int(summary_len * 2),
        num_beams=10,
        no_repeat_ngram_size=2,
    )
    return [decode(output, skip_special=True) for output in outputs]
def summarize(text, text_len=text_len, summary_len=summary_len, model=model):
    """Summarize a single text (a batch of one)."""
    encodeds = torch.cat([encode(text, text_len).unsqueeze(0)])  # .cuda()
    return summarize_multiple(encodeds, summary_len, model)[0]
# In case a text longer than the average training example is being summarized,
# split it into sentence-aligned chunks of at most ~text_len characters
# (a character-based approximation of the token limit used above).
def chunkify(text, text_len):
    texts = []
    # Split after sentence delimiters, keeping them: "።" for Amharic, "." as a fallback.
    sentences = re.split(r"(?<=[።.])", text)
    chunk = ""
    for sentence in sentences:
        if not sentence.strip():
            continue
        if chunk and len(chunk) + len(sentence) > text_len:
            texts.append(chunk)
            chunk = ""
        chunk += sentence
    if chunk:
        texts.append(chunk)
    return texts
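# Illustrative sketch of the chunking behaviour (hypothetical input):
#   chunkify(long_article_text, text_len)
# returns a list of pieces that break after "።"/"." and each stay under roughly
# text_len characters, so every piece can be summarized independently by /summarise below.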
@app.get("/")
def greet_json():
return {"Hello": "World!"}
@app.post("/summarise")
def summarise(request: SummarizeDto):
text = request.text
summary = ""
for chunk in chunkify(text, text_len):
summary += " "+summarize(chunk)
return {"summary": summary}