randomshit11 commited on
Commit
7fefcad
1 Parent(s): db5782f

Rename app.txt to app.py

Browse files
Files changed (2) hide show
  1. app.py +32 -0
  2. app.txt +0 -44
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.responses import HTMLResponse
3
+ from transformers import pipeline
4
+ import gradio as gr
5
+
6
+ # Load the model pipeline
7
+ pipe = pipeline("image-classification", "dima806/medicinal_plants_image_detection")
8
+
9
+ # Define the image classification function
10
+ def image_classifier(image):
11
+ # Perform image classification
12
+ outputs = pipe(image)
13
+ results = {}
14
+ for result in outputs:
15
+ results[result['label']] = result['score']
16
+ return results
17
+
18
+ # Define Gradio Interface
19
+ gr_interface = gr.Interface(fn=image_classifier, inputs=gr.Image(type="pil"), outputs="label")
20
+
21
+ # Define FastAPI app
22
+ app = FastAPI()
23
+
24
+ # Define route for Gradio interface
25
+ @app.get("/")
26
+ async def gr_interface_route():
27
+ return HTMLResponse(gr_interface.launch(inline=False, inbrowser=True))
28
+
29
+ # Expose the FastAPI app using Uvicorn
30
+ if __name__ == "__main__":
31
+ import uvicorn
32
+ uvicorn.run(app, host="0.0.0.0", port=8000)
app.txt DELETED
@@ -1,44 +0,0 @@
1
- import streamlit as st
2
- import torch
3
- from PIL import Image
4
- from torchvision import transforms
5
- from model import ResNet50 # Assuming your model architecture is defined in a separate file called model.py
6
-
7
- # Load the model
8
- model = ResNet50()
9
- model.load_state_dict(torch.load('best_modelv2.pth', map_location=torch.device('cpu')))
10
- model.eval()
11
-
12
- # Define transform for input images
13
- data_transforms = transforms.Compose([
14
- transforms.Resize((224, 224)),
15
- transforms.ToTensor(),
16
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
17
- ])
18
-
19
- # Function to predict image label
20
- def predict_image_label(image):
21
- # Preprocess the image
22
- image = data_transforms(image).unsqueeze(0)
23
-
24
- # Make prediction
25
- with torch.no_grad():
26
- output = model(image)
27
- _, predicted = torch.max(output, 1)
28
-
29
- return predicted.item()
30
-
31
- # Streamlit app
32
- st.title("Leaf or Plant Classifier")
33
-
34
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
35
-
36
- if uploaded_file is not None:
37
- # Display the uploaded image
38
- image = Image.open(uploaded_file)
39
- st.image(image, caption='Uploaded Image', use_column_width=True)
40
-
41
- # Classify the image
42
- prediction = predict_image_label(image)
43
- label = 'Leaf' if prediction == 0 else 'Plant'
44
- st.write(f"Prediction: {label}")