# app.py
import gradio as gr
from PIL import Image
import numpy as np
import pandas as pd
import io
import os
import cv2
from pathlib import Path
from typing import List, Dict, Any, Tuple

# ----------------------------
# Lightweight image utilities
# ----------------------------
import torch
import torch.nn as nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download, list_repo_files, login
from transformers import CLIPModel, CLIPProcessor

try:
    from safetensors.torch import load_file as load_safetensors
except Exception:  # optional dependency
    load_safetensors = None


def _ensure_rgb(img: Image.Image) -> Image.Image:
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img


def _to_np(img: Image.Image) -> np.ndarray:
    return np.array(_ensure_rgb(img))


def _hsv_hist_features(img: Image.Image) -> np.ndarray:
    """Return simple features for matching and scoring.

    - Hue histogram (18 bins)
    - Mean saturation, mean value
    - Edge density via Canny
    """
    arr = _to_np(img)
    hsv = cv2.cvtColor(arr, cv2.COLOR_RGB2HSV)
    h = hsv[:, :, 0]
    s = hsv[:, :, 1] / 255.0
    v = hsv[:, :, 2] / 255.0
    # Hue histogram (0..179 in OpenCV HSV). 18 bins of width 10.
    hist, _ = np.histogram(h.flatten(), bins=18, range=(0, 180), density=True)
    mean_s = float(s.mean())
    mean_v = float(v.mean())
    # Edge density
    gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, 100, 200)
    edge_density = edges.mean() / 255.0
    return np.concatenate([hist, np.array([mean_s, mean_v, edge_density])])


def _complementary_hue_score(q_hist: np.ndarray, w_hist: np.ndarray) -> float:
    """Simple compatibility score favoring complementary hues and balanced
    saturation/value.

    We approximate the complement by circularly shifting the query's hue
    histogram by half the bins and taking its similarity with the wardrobe
    item's histogram.
    """
    hb = 18
    q_h = q_hist[:hb]
    w_h = w_hist[:hb]
    q_shift = np.roll(q_h, hb // 2)
    hue_sim = float(np.dot(q_shift, w_h) /
                    (np.linalg.norm(q_shift) * np.linalg.norm(w_h) + 1e-8))
    # Texture term: 1 - |delta edge density| rewards *similar* edge density
    q_ed = q_hist[-1]
    w_ed = w_hist[-1]
    texture_contrast = 1.0 - abs(q_ed - w_ed)
    # Encourage a similar brightness range
    q_v = q_hist[-2]
    w_v = w_hist[-2]
    value_affinity = 1.0 - min(1.0, abs(q_v - w_v))
    # Weighted sum -> [0, 1]
    score = 0.5 * hue_sim + 0.3 * texture_contrast + 0.2 * value_affinity
    return max(0.0, min(1.0, score))
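# Why the half-histogram roll in _complementary_hue_score approximates a
# complement: OpenCV stores hue as degrees/2 (0..179), so each of the 18 bins
# spans 10 OpenCV units (20 true degrees), and np.roll by hb // 2 = 9 bins
# rotates the histogram by 180 true degrees, i.e. onto the complementary hue.
# Illustrative check (hypothetical, not executed at import time):
#
#   red = Image.new("RGB", (64, 64), (255, 0, 0))
#   cyan = Image.new("RGB", (64, 64), (0, 255, 255))   # red's complement
#   _complementary_hue_score(_hsv_hist_features(red), _hsv_hist_features(cyan))
#   # -> hue_sim term near its maximum: the rolled red mass lands in cyan's bin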
""" def __init__(self, repo_id: str = os.getenv("MODEL_REPO", "danjung9/resnet18-approval-regressor"), device: str = "cpu", output_scale: str = os.getenv("MODEL_OUTPUT_SCALE", "1-10")): # "1-10"|"0-1"|"1-100" self.repo_id = repo_id self.device = torch.device(device) self.output_scale = output_scale self.model = self._load_model() self.model.eval().to(self.device) self.tf = transforms.Compose([ transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def _load_model(self) -> nn.Module: hf_token = os.getenv("HF_TOKEN") if hf_token: try: login(token=hf_token) except Exception: pass net = models.resnet18(weights=None) in_features = net.fc.in_features net.fc = nn.Linear(in_features, 1) candidates = [ "model.safetensors", "pytorch_model.safetensors", "pytorch_model.bin", "model.pth", "model.pt", "weights.pth", "checkpoint.pth" ] repo_files = [] try: repo_files = list_repo_files(self.repo_id, token=hf_token) except Exception: pass chosen = None for c in candidates: if (repo_files and c in repo_files) or not repo_files: try: fpath = hf_hub_download(self.repo_id, c, token=hf_token) chosen = (c, fpath) break except Exception: continue if chosen is None: return net name, fpath = chosen try: if name.endswith(".safetensors") and load_safetensors is not None: state = load_safetensors(fpath, device="cpu") else: state = torch.load(fpath, map_location="cpu") if isinstance(state, dict) and "state_dict" in state: state = state["state_dict"] if isinstance(state, dict) and "model_state_dict" in state: state = state["model_state_dict"] new_state = {} for k, v in state.items(): nk = k.replace("module.", "") new_state[nk] = v net.load_state_dict(new_state, strict=False) except Exception: pass return net def predict_1to100(self, img: Image.Image) -> int: img = _ensure_rgb(img) x = self.tf(img).unsqueeze(0).to(self.device) with torch.no_grad(): y = self.model(x).float().squeeze().item() if self.output_scale == "1-10": score = y * 10.0 elif self.output_scale == "0-1": score = y * 99.0 + 1.0 else: # "1-100" score = y score = float(np.clip(score, 1.0, 100.0)) return int(round(score)) # Embedding model for AI-based matching (separate model from rating) class _Embedder: def __init__(self, repo_id: str = os.getenv("MATCH_REPO", "openai/clip-vit-base-patch32"), device: str = "cpu"): self.repo_id = repo_id self.device = torch.device(os.getenv("MATCH_DEVICE", device)) self.model = CLIPModel.from_pretrained(self.repo_id).to(self.device).eval() self.processor = CLIPProcessor.from_pretrained(self.repo_id) @torch.no_grad() def embed(self, img: Image.Image) -> np.ndarray: img = _ensure_rgb(img) inputs = self.processor(images=img, return_tensors="pt") pixel_values = inputs["pixel_values"].to(self.device) feats = self.model.get_image_features(pixel_values=pixel_values) # [1, D] feats = feats / (feats.norm(p=2, dim=-1, keepdim=True) + 1e-8) return feats.squeeze(0).cpu().numpy().astype(np.float32) @torch.no_grad() def classify(self, img: Image.Image, labels: List[str]) -> Tuple[str, np.ndarray]: img = _ensure_rgb(img) prompts = [f"a photo of a {l}" for l in labels] batch = self.processor(text=prompts, images=img, return_tensors="pt", padding=True) batch = {k: v.to(self.device) for k, v in batch.items()} img_feats = self.model.get_image_features(pixel_values=batch["pixel_values"]) # [1,D] txt_feats = self.model.get_text_features(input_ids=batch["input_ids"], 
attention_mask=batch["attention_mask"]) # [L,D] img_feats = img_feats / (img_feats.norm(dim=-1, keepdim=True) + 1e-8) txt_feats = txt_feats / (txt_feats.norm(dim=-1, keepdim=True) + 1e-8) logits = img_feats @ txt_feats.T # [1,L] idx = int(logits.argmax(dim=-1).item()) return labels[idx], logits.squeeze(0).detach().cpu().numpy().astype(np.float32) # Category config _CAT_ALIASES = { "shirt": ["shirt", "t-shirt", "tee", "top", "blouse", "polo", "sweater"], "pants": ["pants", "trousers", "jeans", "chinos", "shorts", "sweatpants"], "jacket": ["jacket", "coat", "blazer", "overcoat", "parka", "cardigan", "hoodie"], "shoes": ["shoes", "sneakers", "boots", "loafers", "heels", "sandals"] } _CAT_LABELS = sorted({a for lst in _CAT_ALIASES.values() for a in lst}) _ALIAS_TO_CANON = {alias: canon for canon, lst in _CAT_ALIASES.items() for alias in lst} _ALLOWED_PAIRS = { "shirt": ["pants", "jacket", "shoes"], "pants": ["shirt", "jacket", "shoes"], "jacket": ["shirt", "pants", "shoes"], "shoes": ["shirt", "pants", "jacket"], } # Lazy globals _SCORER = None _EMBEDDER = None def _get_scorer() -> _Resnet18Regressor: global _SCORER if _SCORER is None: _SCORER = _Resnet18Regressor() return _SCORER def _get_embedder() -> _Embedder: global _EMBEDDER if _EMBEDDER is None: _EMBEDDER = _Embedder() return _EMBEDDER # ---------------------------- # In-memory session store # ---------------------------- # State schema: # { # "wardrobe": [ {"id": int, "name": str, "image": PIL.Image, "features": np.ndarray, "embedding": np.ndarray, "rating": int|None} ], # "selected_idx": int|None # } def _blank_state() -> Dict[str, Any]: return {"wardrobe": [], "selected_idx": None} # ---------------------------- def add_wardrobe(files: List[Any], state: Dict[str, Any]): if state is None: state = _blank_state() next_id = 0 if not state["wardrobe"] else max(w["id"] for w in state["wardrobe"]) + 1 if files is None: return state, _render_gallery(state), _ratings_df(state), "" scorer = _get_scorer() embedder = _get_embedder() warnings = [] added = 0 allowed_exts = {".png", ".jpg", ".jpeg"} for f in files: try: fname = os.path.basename(getattr(f, 'name', f)) ext = Path(fname).suffix.lower() if ext not in allowed_exts: warnings.append(fname) continue img = Image.open(f.name if hasattr(f, "name") else f) img = _ensure_rgb(img) feats = _hsv_hist_features(img) emb = embedder.embed(img) # category via zero-shot CLIP try: alias, _ = embedder.classify(img, _CAT_LABELS) category = _ALIAS_TO_CANON.get(alias, "shirt") except Exception: category = "shirt" name = fname rating = scorer.predict_1to100(img) state["wardrobe"].append({ "id": next_id, "name": name, "image": img, "features": feats, "embedding": emb, "category": category, "rating": int(rating), }) added += 1 next_id += 1 except Exception: continue gallery = _render_gallery(state) status_lines = [] if added: status_lines.append(f"✅ Added {added} item(s) to your wardrobe.") if warnings: status_lines.append( f"⚠️ Skipped {len(warnings)} file(s) (not PNG/JPG): " + ", ".join(warnings[:5]) + ("..." if len(warnings) > 5 else "") ) status_lines.append("Please upload .png, .jpg, or .jpeg files.") status_msg = "\n\n".join(status_lines) return state, gallery, _ratings_df(state), status_msg def add_wardrobe_from_dir(example_dir: str, state: Dict[str, Any]): """Load all images in a folder into the wardrobe and auto-rate/classify them. Used by gr.Examples. Accepts relative paths in the Space repo. 
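# Note: add_wardrobe accepts both Gradio file objects and plain path strings
# (it falls back to `f` itself when there is no `.name` attribute); the
# directory loader below and the RUN_TESTS block rely on this. Hypothetical
# call with a local path:
#
#   state = _blank_state()
#   state, gallery, df, msg = add_wardrobe(["wardrobe/tee.png"], state)
#   # msg -> "✅ Added 1 item(s) to your wardrobe." for a valid PNG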
""" if not example_dir: return state, _render_gallery(state), _ratings_df(state), "" p = Path(example_dir) patterns = ["*.jpg", "*.jpeg", "*.png"] # keep examples aligned with allowed types files = [] for pat in patterns: files.extend([str(x) for x in p.glob(pat)]) st, gal, df = add_wardrobe(files, state)[:3] # add_wardrobe returns 4 now; reuse message-less return for examples return st, gal, df, "Loaded example wardrobe." def clear_wardrobe(state: Dict[str, Any]): state = _blank_state() return state, [], _ratings_df(state), "Wardrobe cleared." def _render_gallery(state: Dict[str, Any]) -> List[Image.Image]: return [w["image"] for w in state.get("wardrobe", [])] def on_select_item(item_label: str, state: Dict[str, Any]): if not item_label: return state, None, gr.update(value=50) # label format: "# · " try: item_id = int(item_label.split(" ")[0][1:]) except Exception: return state, None, gr.update(value=50) idx = next((i for i, w in enumerate(state["wardrobe"]) if w["id"] == item_id), None) state["selected_idx"] = idx if idx is None: return state, None, gr.update(value=50) w = state["wardrobe"][idx] current_rating = w["rating"] if w["rating"] is not None else 50 return state, w["image"], gr.update(value=int(current_rating)) def save_rating(rating: int, state: Dict[str, Any]): idx = state.get("selected_idx", None) if idx is None: return state, _ratings_df(state) state["wardrobe"][idx]["rating"] = int(rating) return state, _ratings_df(state) def _ratings_df(state: Dict[str, Any]) -> pd.DataFrame: rows = [] for w in state.get("wardrobe", []): rows.append({ "id": w["id"], "name": w["name"], "category": w.get("category"), "model_rating": w["rating"], }) df = pd.DataFrame(rows) if not df.empty: df = df.sort_values(by=["category", "model_rating", "id"], ascending=[True, False, True]) return df def export_ratings(state: Dict[str, Any]): df = _ratings_df(state) buf = io.BytesIO() df.to_csv(buf, index=False) buf.seek(0) return buf buf = io.BytesIO() df.to_csv(buf, index=False) buf.seek(0) return buf def import_ratings(file_obj, state: Dict[str, Any]): # Deprecated in auto-rating flow; keep no-op for compatibility return state, _ratings_df(state) try: df = pd.read_csv(file_obj.name if hasattr(file_obj, "name") else file_obj) names_to_rating = {str(row["name"]): int(row["rating"]) if not pd.isna(row["rating"]) else None for _, row in df.iterrows()} for w in state.get("wardrobe", []): if w["name"] in names_to_rating: w["rating"] = names_to_rating[w["name"]] except Exception: pass return state, _ratings_df(state) # ---------------------------- # Rating + Recommendation # ---------------------------- def rate_and_recommend(query_img: Image.Image, top_k: int, matching_mode: str, state: Dict[str, Any]): if query_img is None: return 0, "Please upload a PNG/JPG image to get a rating and matches.", [] query_img = _ensure_rgb(query_img) # Model rating try: scorer = _get_scorer() pred = scorer.predict_1to100(query_img) except Exception: pred = 50 # Features and category for matching embedder = _get_embedder() qfeat = _hsv_hist_features(query_img) qemb = embedder.embed(query_img) # Determine query category (allow override for tests) forced = os.getenv("FORCE_QUERY_CATEGORY") if forced and forced in _ALLOWED_PAIRS: qcat = forced else: try: alias, _ = embedder.classify(query_img, _CAT_LABELS) qcat = _ALIAS_TO_CANON.get(alias, "shirt") except Exception: qcat = "shirt" allowed_targets = _ALLOWED_PAIRS.get(qcat, ["pants", "jacket", "shoes"]) # default avoid same candidates = [] for w in state.get("wardrobe", []): if 
w.get("category") not in allowed_targets: continue # Cosine similarity on embeddings wemb = w.get("embedding") if wemb is None: wemb = embedder.embed(w["image"]) w["embedding"] = wemb cos = float(np.dot(qemb, wemb) / (np.linalg.norm(qemb) * np.linalg.norm(wemb) + 1e-8)) cos = max(-1.0, min(1.0, cos)) # Color complement score comp = _complementary_hue_score(qfeat, w["features"]) # 0..1 if matching_mode == "Similar style": final = (cos + 1.0) / 2.0 # 0..1 else: # Complementary color + style final = 0.5 * ((cos + 1.0) / 2.0) + 0.5 * comp # Quality prior if w.get("rating") is not None: qual = 0.5 + 0.5 * (w["rating"] / 100.0) final *= qual candidates.append((final, w)) # Fallback: if nothing matched allowed categories, allow any category except same if not candidates: for w in state.get("wardrobe", []): if w.get("category") == qcat: continue wemb = w.get("embedding") or embedder.embed(w["image"]) cos = float(np.dot(qemb, wemb) / (np.linalg.norm(qemb) * np.linalg.norm(wemb) + 1e-8)) comp = _complementary_hue_score(qfeat, w["features"]) final = 0.5 * ((cos + 1.0) / 2.0) + 0.5 * comp candidates.append((final, w)) candidates.sort(key=lambda x: x[0], reverse=True) top = candidates[: max(0, top_k)] recs = [] for score, w in top: caption = f"{w['name']} · {w.get('category','?')} · match {int(round(100*score))}%" recs.append((w["image"], caption)) if len(recs) == 0: txt = f"**Predicted wear score:** {pred}/100\n\n_No compatible matches found in your wardrobe._" return pred, txt, [] top_names = ", ".join([f"{w['name']} ({w.get('category')})" for _, w in top]) txt = ( f"**Predicted wear score:** {pred}/100 \n" f"**Detected category:** {qcat} \n" f"**Top suggestions:** {top_names} \n" f"_Matching mode:_ {matching_mode.lower()} with category filtering and quality prior." ) return pred, txt, recs # (Unreached legacy block kept to minimize the overall diff) try: scorer = _get_scorer() pred = scorer.predict_1to100(query_img) except Exception: pred = 50 qfeat = _hsv_hist_features(query_img) candidates = [] for w in state.get("wardrobe", []): comp = _complementary_hue_score(qfeat, w["features"]) user_w = 1.0 if w["rating"] is not None: user_w = 0.5 + 0.5 * (w["rating"] / 100.0) final = comp * user_w candidates.append((final, w)) candidates.sort(key=lambda x: x[0], reverse=True) recs = [w["image"] for _, w in candidates[: max(0, top_k)]] if len(recs) == 0: txt = f"Predicted rating: {pred}/100. No matches found in your wardrobe." return pred, txt, [] top_names = ", ".join([w["name"] for _, w in candidates[: max(0, top_k)]]) txt = ( f"Predicted rating: {pred}/100. Suggested pairings from your wardrobe: {top_names}." f"Logic: complementary hues + texture contrast + your cached ratings." ) return pred, txt, recs # ---------------------------- # Gradio UI # ---------------------------- with gr.Blocks(title="Wardrobe Rater + Recommender", css=""" .gradio-container {max-width: 1200px} """) as demo: gr.Markdown( "# Wardrobe Rater + Recommender\n" "**What this app does:** Scores how likely you are to wear a new item (1–100) and suggests compatible pieces from your wardrobe. \n" "**How to use it:** (1) Upload a few wardrobe images first. (PNG/JPG only.) (2) Go to *Rate + Recommend* and upload a new item to get a score and matches." 
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(title="Wardrobe Rater + Recommender", css="""
.gradio-container {max-width: 1200px}
""") as demo:
    gr.Markdown(
        "# Wardrobe Rater + Recommender\n"
        "**What this app does:** Scores how likely you are to wear a new item (1–100) "
        "and suggests compatible pieces from your wardrobe.  \n"
        "**How to use it:** (1) Upload a few wardrobe images first (PNG/JPG only). "
        "(2) Go to *Rate + Recommend* and upload a new item to get a score and matches."
    )
    app_state = gr.State(_blank_state())

    with gr.Tab("1) Wardrobe Manager"):
        with gr.Row():
            wardrobe_uploader = gr.File(
                label="Upload wardrobe images (PNG/JPG)",
                file_types=[".png", ".jpg", ".jpeg"],  # enforce png/jpg
                file_count="multiple",
            )
        gr.Markdown("Tip: crop to the item; solid backgrounds work well.")
        with gr.Row():
            add_btn = gr.Button("Add to wardrobe (auto-rate + auto-category)")
            clear_btn = gr.Button("Clear wardrobe")
        status_md = gr.Markdown("")  # status / reminders (e.g., wrong file types)
        gallery = gr.Gallery(
            label="Current wardrobe", columns=6, height=220,
            object_fit="contain", allow_preview=False
        )
        ratings_table = gr.Dataframe(
            headers=["id", "name", "category", "model_rating"], interactive=False
        )
        gr.Markdown("### Or load an example wardrobe")
        example_dir = gr.Textbox(
            label="Example folder path", value="examples/wardrobe_basic", visible=False
        )
        gr.Examples(
            examples=[["examples/wardrobe_basic", None]],
            inputs=[example_dir, app_state],
            outputs=[app_state, gallery, ratings_table, status_md],
            fn=add_wardrobe_from_dir,
            cache_examples=False,
            run_on_click=True,
        )

    with gr.Tab("2) Rate + Recommend New Item"):
        with gr.Row():
            query_img = gr.Image(
                label="Upload or take photo (PNG/JPG)",
                sources=["upload", "webcam", "clipboard"],
                type="pil",
                image_mode="RGB",
            )
            topk = gr.Slider(1, 6, value=3, step=1, label="# of matches to return")
            matching_mode = gr.Radio(
                ["Complementary color+style", "Similar style"],
                value="Complementary color+style",
                label="Matching mode",
            )
        gr.Markdown("Upload a clear photo of the item to score and match.")
        go_btn = gr.Button("Rate + Recommend")
        with gr.Row():
            pred_score = gr.Number(label="Predicted wear score (1–100)")
            rec_text = gr.Markdown()
        rec_gallery = gr.Gallery(
            label="Matches in your wardrobe", columns=6, height=220,
            object_fit="contain", allow_preview=False
        )

    # --- Wiring ---
    add_btn.click(
        add_wardrobe,
        inputs=[wardrobe_uploader, app_state],
        outputs=[app_state, gallery, ratings_table, status_md],
    )
    clear_btn.click(
        clear_wardrobe,
        inputs=[app_state],
        outputs=[app_state, gallery, ratings_table, status_md],
    )
    go_btn.click(
        rate_and_recommend,
        inputs=[query_img, topk, matching_mode, app_state],
        outputs=[pred_score, rec_text, rec_gallery],
    )

# Lightweight tests. Run only when RUN_TESTS=1.
if __name__ == "__main__":
    if os.getenv("RUN_TESTS") == "1":
        # Synthetic solid-color images
        def solid(c):
            return Image.fromarray(np.full((224, 224, 3), c, dtype=np.uint8))

        # Test scorer bounds
        s = _get_scorer()
        for c in [0, 64, 128, 192, 255]:
            r = s.predict_1to100(solid(c))
            assert 1 <= r <= 100, f"rating out of range: {r}"

        # Test that the embedder produces a nonzero vector and a cosine in range
        e = _get_embedder()
        v1 = e.embed(solid(0))
        v2 = e.embed(solid(255))
        assert v1.ndim == 1 and v1.size > 0
        cos = float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-8))
        assert -1.0 <= cos <= 1.0

        # Test the recommend path with a small wardrobe
        st = _blank_state()
        img_b = solid(32)
        st["wardrobe"].append({
            "id": 0,
            "name": "test.png",
            "image": img_b,
            "features": _hsv_hist_features(img_b),
            "embedding": _get_embedder().embed(img_b),
            "rating": 50,
        })
        pred, txt, recs = rate_and_recommend(solid(200), 1, "Similar style", st)
        assert isinstance(pred, int) and isinstance(txt, str) and isinstance(recs, list)
        print("Tests passed.")
    else:
        demo.launch()
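# Assumed local usage (the file is laid out as a Hugging Face Space app):
#   RUN_TESTS=1 python app.py   -> run the lightweight self-checks above
#   python app.py               -> launch the Gradio UI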