|
import numpy as np
import pytesseract
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms

from model import UTRNet
|
|
|
|
|
def load_model():
    """Load the pretrained UTRNet model and prepare it for inference.

    Returns:
        The UTRNet model in eval mode with weights restored from the
        ``best_norm_ED.pth`` checkpoint.
    """
    model = UTRNet()
    # map_location='cpu' lets a checkpoint that was saved on a GPU load on
    # CPU-only machines; the original call raised on hosts without CUDA.
    state_dict = torch.load(
        'saved_models/UTRNet-Large/best_norm_ED.pth',
        map_location='cpu',
    )
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout / freeze batch-norm statistics for inference
    return model
|
|
|
|
|
def preprocess_image(image):
    """Convert an image into a normalized, resized batch tensor for the model.

    Args:
        image: A PIL image or an HxW / HxWxC uint8 array (anything
            ``np.asarray`` understands), as ``transforms.ToTensor`` accepted.

    Returns:
        A float tensor of shape (1, C, 320, 320) with values in [0, 1].
    """
    # Equivalent to transforms.ToTensor(): scale uint8 [0, 255] -> float [0, 1]
    # and move channels first. Accepts PIL images and ndarrays alike.
    arr = np.asarray(image, dtype=np.float32) / 255.0
    if arr.ndim == 2:
        # Grayscale input: add an explicit channel axis.
        arr = arr[:, :, None]
    batch = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
    # Explicit bilinear resize; avoids torchvision's version-dependent
    # `antialias` default for tensor inputs and the per-call Compose build.
    return torch.nn.functional.interpolate(
        batch, size=(320, 320), mode='bilinear', align_corners=False
    )
|
|
|
|
|
def predict_ocr(image, model):
    """Run the OCR model on a single image and return its raw output."""
    batch = preprocess_image(image)
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        prediction = model(batch)
    return prediction
|
|
|
|
|
def main():
    """Streamlit entry point: upload an image and run UTRNet OCR on it."""
    st.title("Urdu Text Extraction Using UTRNet")
    st.write("Upload an image containing Urdu text for OCR extraction.")

    uploaded_image = st.file_uploader("Upload Image", type=["jpg", "png", "jpeg"])

    # Guard clause: nothing to do until the user uploads a file.
    if uploaded_image is None:
        return

    image = Image.open(uploaded_image)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if st.button("Extract Text"):
        # Load the model only when extraction is actually requested.
        # Previously it was reloaded on every Streamlit rerun as soon as a
        # file was uploaded, even if the button was never pressed.
        model = load_model()
        output = predict_ocr(image, model)
        st.write("Extracted Text:")
        # NOTE(review): this displays the raw model output tensor, not
        # decoded text — a decoding step (e.g. CTC/greedy) is presumably
        # still needed; confirm against UTRNet's inference pipeline.
        st.write(output)


if __name__ == "__main__":
    main()
|
|