AhmedIbrahim007 committed on
Commit
275f227
1 Parent(s): 479723d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -60
app.py CHANGED
@@ -9,48 +9,47 @@ import firebase_admin
9
  from firebase_admin import credentials, firestore, storage
10
  from pydantic import BaseModel
11
 
12
- # Load the pre-trained model
13
- learn = load_learner('model.pkl')
14
-
15
- # Define categories and map them to indices
16
- searches = ['formal', 'casual', 'athletic']
17
- searches = sorted(searches) # Ensure the categories are in sorted order
18
- values = [i for i in range(0, len(searches))]
19
- class_dict = dict(zip(searches, values))
20
-
21
- # Set up logging
22
- logging.basicConfig(level=logging.DEBUG,
23
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
24
- logger = logging.getLogger(__name__)
25
-
26
- # Initialize Firebase
27
- try:
28
- cred = credentials.Certificate("serviceAccountKey.json")
29
- firebase_app = firebase_admin.initialize_app(cred, {
30
- 'storageBucket': 'future-forge-60d3f.appspot.com'
31
- })
32
- db = firestore.client()
33
- bucket = storage.bucket(app=firebase_app)
34
- logger.info("Firebase initialized successfully")
35
- except Exception as e:
36
- logger.error(f"Failed to initialize Firebase: {str(e)}")
37
-
38
- app = FastAPI()
39
-
40
- # Add CORS middleware
41
- app.add_middleware(
42
- CORSMiddleware,
43
- allow_origins=["*"],
44
- allow_credentials=True,
45
- allow_methods=["*"],
46
- allow_headers=["*"],
47
- )
48
-
49
-
50
- # Define the input model
51
- class FileProcess(BaseModel):
52
- file_path: str
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  @app.post("/process")
56
  async def process_file(file_data: FileProcess):
@@ -70,9 +69,9 @@ async def process_file(file_data: FileProcess):
70
  file_type = file_data.file_path.split('.')[-1].lower()
71
 
72
  try:
73
- if file_type in ['jpg', 'jpeg', 'png', 'bmp']:
74
- output = process_video(str(tmp_file_path))
75
- result = {"type": "image", "data": {"result": output}}
76
  else:
77
  raise HTTPException(status_code=400, detail="Unsupported file type")
78
 
@@ -96,22 +95,27 @@ async def process_file(file_data: FileProcess):
96
  raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
97
 
98
 
99
- def process_video(video_path):
100
- # Load the image from the provided path
101
- img = PILImage.create(video_path)
102
-
103
- # Make the prediction
104
- classification, _, probs = learn.predict(img)
105
-
106
- # Convert the prediction to a confidence dictionary
107
- confidences = {label: float(probs[i]) for i, label in enumerate(class_dict)}
108
-
109
- # If classification is not formal, return 'informal'
110
- if classification != 'formal':
111
- informal_confidence = sum(confidences[label] for label in class_dict if label != 'formal')
112
- return {'informal': informal_confidence}
113
- else:
114
- return {'formal': confidences['formal']}
 
 
 
 
 
115
 
116
  if __name__ == "__main__":
117
  logger.info("Starting the Face Emotion Recognition API")
 
9
  from firebase_admin import credentials, firestore, storage
10
  from pydantic import BaseModel
11
 
12
+ import torch
13
+ from transformers import AutoImageProcessor, AutoModelForObjectDetection
14
+ from PIL import Image, ImageDraw, ImageFont
15
+ import cv2
16
+ import random
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
# Load model and processor.
# YOLOS object-detection checkpoint fine-tuned on Fashionpedia.
# NOTE(review): from_pretrained downloads weights from the Hugging Face Hub
# at import time — requires network access on first run.
processor = AutoImageProcessor.from_pretrained("valentinafeve/yolos-fashionpedia")
model = AutoModelForObjectDetection.from_pretrained("valentinafeve/yolos-fashionpedia")

# Fashionpedia categories.
# Used by detect_fashion() to translate predicted label indices into names;
# presumably the list order matches the model's id2label mapping — TODO confirm
# against model.config.id2label.
FASHION_CATEGORIES = [
    'shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt',
    'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove',
    'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood',
    'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow',
    'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel'
]
30
+
31
+
32
def detect_fashion(image):
    """Run the YOLOS-Fashionpedia detector on a single image.

    Args:
        image: PIL image; ``image.size`` is ``(width, height)``.

    Returns:
        A list of ``(category_name, confidence, [x0, y0, x1, y1])`` tuples
        for every detection scoring strictly above 0.5.
    """
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: no_grad avoids building an autograd graph and
    # cuts per-call memory use.
    with torch.no_grad():
        outputs = model(**inputs)

    # Convert outputs (bounding boxes and class logits) to COCO API format.
    # post_process expects (height, width), hence the reversed image.size.
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)[0]

    detected_items = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        if score > 0.5:  # stricter secondary cut-off; adjust as needed
            # label is a scalar tensor; .item() makes the list indexing explicit.
            detected_items.append((FASHION_CATEGORIES[label.item()], score.item(), box.tolist()))

    return detected_items
46
+
47
+
48
def check_dress_code(detected_items):
    """Return True if any detected item counts as formal workplace attire.

    Args:
        detected_items: iterable of ``(category_name, score, box)`` tuples as
            produced by ``detect_fashion``; only the name (index 0) is used.

    Returns:
        bool: whether at least one formal-attire category was detected.
    """
    # Categories treated as formal; duplicate "coat" entry removed
    # (a set literal silently collapses duplicates anyway).
    formal_workplace_attire = {
        "shirt, blouse", "jacket", "tie", "coat", "sweater", "cardigan",
    }
    return any(item[0] in formal_workplace_attire for item in detected_items)
53
 
54
  @app.post("/process")
55
  async def process_file(file_data: FileProcess):
 
69
  file_type = file_data.file_path.split('.')[-1].lower()
70
 
71
  try:
72
+ if file_type in ['mp4', 'avi', 'mov', 'wmv']:
73
+ output,testing = process_video(str(tmp_file_path))
74
+ result = {"type": "video", "data": {"result": output}}
75
  else:
76
  raise HTTPException(status_code=400, detail="Unsupported file type")
77
 
 
95
  raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
96
 
97
 
98
def process_video(video_path, num_frames=10):
    """Sample random frames from a video and score dress-code compliance.

    Args:
        video_path: path to a video file readable by OpenCV.
        num_frames: maximum number of frames to sample (default 10).

    Returns:
        ``(average_compliance, compliance_results)`` — the fraction of sampled
        frames judged compliant (0.0 when no frame could be decoded) and the
        per-frame boolean results.
    """
    cap = cv2.VideoCapture(video_path)
    compliance_results = []
    try:
        # CAP_PROP_FRAME_COUNT can report 0 (or a bogus value) for
        # unreadable files; clamp to >= 0 so random.sample never raises.
        total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)

        # Uniformly sample up to num_frames distinct frame indices.
        frame_indices = sorted(random.sample(range(total_frames), min(num_frames, total_frames)))

        for frame_index in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes BGR; convert to RGB for PIL / the detector.
                image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                detected_items = detect_fashion(image)
                is_compliant = check_dress_code(detected_items)
                compliance_results.append(is_compliant)
    finally:
        # Always release the capture handle, even if detection raises.
        cap.release()

    # Guard against ZeroDivisionError when nothing was decoded
    # (empty file, zero frame count, or every read failed).
    if not compliance_results:
        return 0.0, compliance_results

    average_compliance = sum(compliance_results) / len(compliance_results)
    return average_compliance, compliance_results
119
 
120
  if __name__ == "__main__":
121
  logger.info("Starting the Face Emotion Recognition API")