from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import re
from fastapi import FastAPI
from pydantic import BaseModel, Field

app = FastAPI()

# Note: Most code is identical to the notebook
tokenizer = AutoTokenizer.from_pretrained('google/mt5-small')
model = AutoModelForSeq2SeqLM.from_pretrained('yonatan-h/amharic-summarizer')  # .cuda()
model.eval()  # inference only: disable dropout

text_len = 512  # ideally 512
summary_len = 128  # ideally 128
batch_size = 8  # 64, 24, 16, 8 depending on GPU memory (notebook leftover; unused here)


class SummarizeDto(BaseModel):
    text: str = Field(
        ...,
        description="The text you want to summarize",
        examples=["ግጭት በሚካሄድባቸው የአማራ እና ኦሮሚያ ክልሎች ከፍርድ ውጭ የሚፈጸሙ ግድያዎች በአሳሳቢነት መቀጠላቸውን የኢትዮጵያ ሰብዓዊ መብቶች ኮሚሽን ..."],
    )


def encode(text, length):
    # return_tensors='pt' yields shape (1, length); [0] drops the batch dim.
    encoded = tokenizer.encode(
        text,
        return_tensors='pt',
        padding="max_length",
        max_length=length,
        truncation=True,
    )  # .cuda()
    return encoded[0]


def decode(encoded, skip_special=False):
    decoded = tokenizer.decode(encoded, skip_special_tokens=skip_special)
    if skip_special:
        # Also strip any angle-bracket sentinel tokens that survive decoding.
        decoded = re.sub(r"<[^>]+>", "", decoded).strip()
    return decoded


def summarize_multiple(text_encodeds, summary_len=summary_len, model=model):
    # Beam search, with the summary constrained to 0.5x-2x the target length.
    outputs = model.generate(
        text_encodeds,
        min_length=int(summary_len * 0.5),
        max_length=int(summary_len * 2),
        num_beams=10,
        no_repeat_ngram_size=2,
    )
    return [decode(output, skip_special=True) for output in outputs]


def summarize(text, text_len=text_len, summary_len=summary_len, model=model):
    # Batch of one: add a leading batch dimension for generate().
    encoded = encode(text, text_len).unsqueeze(0)  # .cuda()
    return summarize_multiple(encoded, summary_len, model)[0]


# In case a text larger than the average training example is being summarized,
# split it into sentence-aligned chunks of roughly text_len characters.
def chunkify(text, text_len):
    texts = []
    chunk = ""
    for sentence in text.split("."):
        if not sentence.strip():
            continue
        if chunk and len(chunk) + len(sentence) > text_len:
            texts.append(chunk)
            chunk = ""
        chunk += sentence + "."  # restore the delimiter consumed by split()
    if chunk:
        texts.append(chunk)
    return texts


@app.get("/")
def greet_json():
    return {"Hello": "World!"}


@app.post("/summarise")
def summarise(request: SummarizeDto):
    chunks = chunkify(request.text, text_len)
    summaries = [summarize(chunk) for chunk in chunks]
    return {"summary": " ".join(summaries)}
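

# --- Usage sketch (not part of the app) ---
# Assumptions: this file is saved as app.py, and uvicorn plus httpx (required
# by FastAPI's TestClient) are installed; the Amharic text below is a
# placeholder, not real input. A server would normally be started with:
#   uvicorn app:app --host 0.0.0.0 --port 7860
# The block below instead exercises /summarise in-process via TestClient,
# which makes for a quick smoke test without a running server.
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    client = TestClient(app)
    resp = client.post(
        "/summarise",
        json={"text": "ግጭት በሚካሄድባቸው የአማራ እና ኦሮሚያ ክልሎች ..."},
    )
    print(resp.json())  # e.g. {"summary": "..."}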