engrharis commited on
Commit
cc52a3a
·
verified ·
1 Parent(s): a3d7b30

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from PIL import Image
4
+ import pytesseract
5
+ from torchvision import transforms
6
+ from model import UTRNet # Assuming the UTRNet model is defined in a file `model.py`
7
+
8
+ # Load model
9
+ def load_model():
10
+ model = UTRNet() # Initialize the model (ensure it is defined in a separate model.py)
11
+ model.load_state_dict(torch.load('saved_models/UTRNet-Large/best_norm_ED.pth'))
12
+ model.eval()
13
+ return model
14
+
15
+ # Image preprocessing
16
+ def preprocess_image(image):
17
+ transform = transforms.Compose([
18
+ transforms.ToTensor(),
19
+ transforms.Resize((320, 320)),
20
+ ])
21
+ return transform(image).unsqueeze(0)
22
+
23
+ # OCR prediction function
24
+ def predict_ocr(image, model):
25
+ image_tensor = preprocess_image(image)
26
+ with torch.no_grad():
27
+ output = model(image_tensor)
28
+ # Post-process the output to get text (This depends on how the model is structured)
29
+ return output # You might need to decode the output to actual text
30
+
31
+ # Streamlit App
32
+ def main():
33
+ st.title("Urdu Text Extraction Using UTRNet")
34
+ st.write("Upload an image containing Urdu text for OCR extraction.")
35
+
36
+ uploaded_image = st.file_uploader("Upload Image", type=["jpg", "png", "jpeg"])
37
+
38
+ if uploaded_image is not None:
39
+ # Load and display the image
40
+ image = Image.open(uploaded_image)
41
+ st.image(image, caption="Uploaded Image", use_column_width=True)
42
+
43
+ # Load the model
44
+ model = load_model()
45
+
46
+ # Get predictions
47
+ if st.button("Extract Text"):
48
+ output = predict_ocr(image, model)
49
+ st.write("Extracted Text:")
50
+ st.write(output) # You will need to process `output` to display text properly
51
+
52
+ if __name__ == "__main__":
53
+ main()