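"""FastAPI backend for an OCR-and-editing service.

Exposes three endpoints: /extract_handwritten_text/ (handwriting recognition via an
ONNX CRNN model loaded with mltu), /extract_text/ (printed-text OCR via Tesseract),
and /chat_prompt/ (instruction-based editing of the last extracted text with
grammarly/coedit-large).
"""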
import io

import cv2
import numpy as np
import pytesseract
from fastapi import FastAPI, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from happytransformer import HappyTextToText, TTSettings
from mltu.configs import BaseModelConfigs
from mltu.inferenceModel import OnnxInferenceModel
from mltu.transformers import ImageResizer
from mltu.utils.text_utils import ctc_decoder
from PIL import Image
from pydantic import BaseModel
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Instruction-following text-editing model (grammarly/coedit-large, a fine-tuned T5).
tokenizer = AutoTokenizer.from_pretrained("grammarly/coedit-large")
chatModel = T5ForConditionalGeneration.from_pretrained("grammarly/coedit-large")

# Handwriting-recognition model configuration (model path, vocabulary, input shape).
configs = BaseModelConfigs.load("./configs.yaml")

# Alternative grammar-correction model, currently disabled:
# happy_tt = HappyTextToText("T5", "vennify/t5-base-grammar-correction")
beam_settings = TTSettings(num_beams=5, min_length=1, max_length=100)

app = FastAPI()

# Development CORS policy: allow every origin, method, and header. Tighten before deploying.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ImageToWordModel(OnnxInferenceModel):
    """CTC-based handwriting recognizer wrapping an ONNX inference session."""

    def __init__(self, char_list, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.char_list = char_list

    def predict(self, image: np.ndarray) -> str:
        # Resize to the model's expected input size while keeping the aspect ratio.
        image = ImageResizer.resize_maintaining_aspect_ratio(
            image, *self.input_shape[:2][::-1]
        )
        # Add a batch dimension, run the ONNX session, and CTC-decode the output.
        image_pred = np.expand_dims(image, axis=0).astype(np.float32)
        preds = self.model.run(None, {self.input_name: image_pred})[0]
        text = ctc_decoder(preds, self.char_list)[0]
        return text


model = ImageToWordModel(model_path=configs.model_path, char_list=configs.vocab)

# Most recent OCR result, shared with /chat_prompt/. Module-level state like this
# only behaves correctly with a single client and a single worker process.
extracted_text = ""
@app.post("/extract_handwritten_text/")
async def predict_text(image: UploadFile):
global extracted_text
# Read the uploaded image
img = await image.read()
nparr = np.frombuffer(img, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
# Make a prediction
extracted_text = model.predict(img)
#corrected_text = happy_tt.generate_text(extracted_text, beam_settings)
return {"text": extracted_text}
@app.post("/extract_text/")
async def extract_text_from_image(image: UploadFile):
global extracted_text
# Check if the uploaded file is an image
if image.content_type.startswith("image/"):
# Read the image from the uploaded file
image_bytes = await image.read()
img = Image.open(io.BytesIO(image_bytes))
# Perform OCR on the image
extracted_text = pytesseract.image_to_string(img)
#corrected_text = happy_tt.generate_text(extracted_text, beam_settings)
return {"text": extracted_text}
else:
return {"error": "Invalid file format. Please upload an image."}


class ChatPrompt(BaseModel):
    prompt: str


@app.post("/chat_prompt/")
async def chat_prompt(request: ChatPrompt):
    # Prepend the user's instruction to the last extracted text in the
    # "instruction: text" format, e.g. "Fix the grammar: <extracted text>".
    input_text = request.prompt + ": " + extracted_text
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    outputs = chatModel.generate(input_ids, max_length=256)
    edited_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"edited_text": edited_text}