RishuD7's picture
first commit
4fe8a8c
import re
import gradio as gr
import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# Checkpoint identifier passed to both `from_pretrained` calls below, so the
# feature extractor and the classifier always come from the same model.
MODEL_ID = "DunnBC22/dit-base-Business_Documents_Classified_v2"

# Pick the GPU when torch reports one is available, otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.to(device)
def classify_documents(image):
    """Classify a document image and return its predicted class label.

    Args:
        image: The image supplied by the Gradio ``"image"`` input component.
            NOTE(review): presumably an RGB NumPy array of shape (H, W, 3),
            which is Gradio's default for that component -- confirm.

    Returns:
        dict: ``{"result": <label>}``, where the label is looked up in
        ``model.config.id2label`` for the highest-scoring class.

    Raises:
        ValueError: If no image was supplied.
    """
    if image is None:
        raise ValueError("No image provided for classification.")

    # The keyword is `return_tensors` (plural). The previous spelling,
    # `return_tensor`, was not a recognised argument, so the request for
    # PyTorch tensors never took effect.
    inputs = extractor(images=image, return_tensors="pt")

    # The model was moved to GPU when one is available; the input must live on
    # the same device or the forward pass fails with a device mismatch.
    pixel_values = inputs["pixel_values"].to(model.device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits

    # One image in, so one row of logits; take the best-scoring class index.
    predicted_index = logits.argmax(dim=-1).item()
    return {"result": str(model.config.id2label[predicted_index])}
# Footer HTML handed to the interface via its `article` argument.
article = "<p style='text-align: center'><a href='https://www.xelpmoc.in/' target='_blank'>Made by Xelpmoc</a></p>"

# Example image paths handed to the interface via its `examples` argument.
example_inputs = [
    ["./test_images/email_image_2.jpg"],
    ["./test_images/form_image_3.jpg"],
]

# NOTE(review): `enable_queue` is believed to be deprecated/removed in newer
# Gradio releases -- confirm against the Gradio version this app is pinned to.
demo = gr.Interface(
    classify_documents,
    "image",
    "json",
    title="Document Classification",
    article=article,
    examples=example_inputs,
    cache_examples=False,
    enable_queue=True,
)

demo.launch()