Ariamehr commited on
Commit
caff9fa
1 Parent(s): 9b28203

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -1,18 +1,31 @@
1
- import gradio as gr
2
- from transformers import AutoImageProcessor, SapiensForImageSegmentation
3
  from PIL import Image
 
 
4
 
 
 
 
5
 
6
- model_name = "facebook/sapiens"
7
- processor = AutoImageProcessor.from_pretrained(model_name)
8
- model = SapiensForImageSegmentation.from_pretrained(model_name)
9
 
10
  def segment_image(image):
 
11
  inputs = processor(images=image, return_tensors="pt")
12
- outputs = model(**inputs)
13
- segmentation = outputs.logits.argmax(dim=1).detach().cpu().numpy()[0]
 
 
 
 
 
 
14
  return Image.fromarray(segmentation)
15
 
 
 
16
  interface = gr.Interface(
17
  fn=segment_image,
18
  inputs=gr.Image(type="pil"),
 
1
+ import torch
2
+ from transformers import AutoImageProcessor
3
  from PIL import Image
4
+ import requests
5
+ from io import BytesIO
6
 
7
+ # آدرس مدل
8
+ model_url = "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/normal/checkpoints/sapiens_0.3b"
9
+ model = torch.jit.load(model_url)
10
 
11
+ # بارگذاری پردازشگر تصویر
12
+ processor = AutoImageProcessor.from_pretrained("facebook/sapiens")
 
13
 
14
  def segment_image(image):
15
+ # پردازش تصویر
16
  inputs = processor(images=image, return_tensors="pt")
17
+
18
+ # اجرای مدل روی تصویر پردازش شده
19
+ with torch.no_grad():
20
+ outputs = model(inputs['pixel_values'])
21
+
22
+ # فرض می‌کنیم خروجی یک ماسک است
23
+ segmentation = outputs.argmax(dim=1).detach().cpu().numpy()[0]
24
+
25
  return Image.fromarray(segmentation)
26
 
27
+ # رابط Gradio
28
+ import gradio as gr
29
  interface = gr.Interface(
30
  fn=segment_image,
31
  inputs=gr.Image(type="pil"),