"""Gradio app that scores whether an uploaded image is a base body.

CLIP embeds the image, the embedding is scored against pairs of text prompts,
and the per-pair scores feed two pickled classifiers (a base-body model and a
hair model). Their probabilities are mapped to AUTO ACCEPT / AUTO REJECT /
MODERATION, and users can flag wrong decisions to a Hugging Face dataset.
"""
import json
import logging
import os
import pickle

import clip
import gradio as gr
import torch
from PIL import Image
from scipy.special import softmax

logger = logging.getLogger("basebody")

CLIP_MODEL_NAME = "ViT-B/16"
TEXT_PROMPTS_FILE_NAME = "text_prompts2.json"
HAIR_TEXT_PROMPTS_FILE_NAME = "text_prompts_hair.json"
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_26.pkl"
HAIR_RF_CLASSIFIER_MODEL_FILE_NAME = "hairclassifier_rf.pkl"

# Decision thresholds; compared against probabilities rounded to 3 decimals.
AUTO_ACCEPT_THRESHOLD = 0.77  # base-body probability strictly above -> accept candidate
AUTO_REJECT_THRESHOLD = 0.4  # base-body probability strictly below -> reject
HAIR_REJECT_THRESHOLD = 0.5  # hair probability strictly below vetoes an accept

HF_TOKEN = os.getenv("HF_TOKEN")
# Flagged samples are written to this Hugging Face dataset.
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "Roblox/basebody_feedback")

clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")


def _asset_path(file_name):
    """Return the path of a file shipped alongside this script."""
    return os.path.join(os.path.dirname(__file__), file_name)


def _load_json(file_name):
    """Load a JSON asset shipped alongside this script (decoded as UTF-8)."""
    with open(_asset_path(file_name), "r", encoding="utf-8") as f:
        return json.load(f)


def _load_pickle(file_name):
    """Load a pickled model shipped alongside this script.

    SECURITY: unpickling can execute arbitrary code. Only ever use this for
    the trusted model files bundled with the app, never for user input.
    """
    with open(_asset_path(file_name), "rb") as f:
        return pickle.load(f)


def _encode_prompt_pairs(prompt_pairs):
    """Encode every two-prompt entry of ``prompt_pairs`` with CLIP's text encoder.

    Args:
        prompt_pairs: mapping of name -> list of exactly two prompt strings.
            Iteration order (JSON file order) fixes the feature order the
            classifiers see, so it must not be changed.

    Returns:
        CPU tensor with ``2 * len(prompt_pairs)`` rows; the two prompts of each
        entry are adjacent, which is the layout get_cosine_similarities expects.

    Raises:
        ValueError: if the mapping is empty or an entry is not a pair.
    """
    if not prompt_pairs:
        raise ValueError("Prompt file contains no prompt pairs.")
    encoded = []
    with torch.no_grad():
        for name, prompts in prompt_pairs.items():
            # Explicit check instead of `assert`, which is stripped under -O.
            if len(prompts) != 2:
                raise ValueError(
                    f"Prompt entry {name!r} must contain exactly 2 prompts, "
                    f"got {len(prompts)}."
                )
            encoded.append(clip_model.encode_text(clip.tokenize(prompts)))
    return torch.cat(encoded, dim=0).cpu()


text_prompts = _load_json(TEXT_PROMPTS_FILE_NAME)
hair_text_prompts = _load_json(HAIR_TEXT_PROMPTS_FILE_NAME)
lr_model = _load_pickle(LOGISTIC_REGRESSION_MODEL_FILE_NAME)
hair_rf_model = _load_pickle(HAIR_RF_CLASSIFIER_MODEL_FILE_NAME)
# The original message promised coefficients but never logged them.
logger.info(
    "Logistic regression model loaded, coefficients: %s",
    getattr(lr_model, "coef_", None),
)

all_text_features = _encode_prompt_pairs(text_prompts)
hair_text_features = _encode_prompt_pairs(hair_text_prompts)


def get_cosine_similarities(image_features, text_features, text_prompts):
    """Score each prompt pair against the image embedding(s).

    For every pair of adjacent rows in ``text_features`` (one pair per entry
    of ``text_prompts``), take the dot product with each image embedding,
    apply a softmax over the two prompts, and keep the first prompt's share.

    NOTE(review): no L2 normalisation is applied here, so despite the name
    these are softmaxed raw dot products rather than cosine similarities. The
    classifiers were presumably fitted on exactly this computation, so do not
    add normalisation without retraining them.

    Args:
        image_features: tensor with one row per image embedding.
        text_features: tensor with ``2 * len(text_prompts)`` rows, pairs adjacent.
        text_prompts: mapping whose length is the number of prompt pairs.

    Returns:
        numpy array of shape ``(len(text_prompts), n_images)``.
    """
    logits = (text_features @ image_features.cpu().T).detach().cpu().numpy()
    pair_logits = logits.reshape(len(text_prompts), 2, -1)
    return softmax(pair_logits, axis=1)[:, 0, :]


def _decide(base_body_probability, hair_probability):
    """Map rounded probabilities to ``(reported_probability, decision)``."""
    if base_body_probability > AUTO_ACCEPT_THRESHOLD:
        if hair_probability < HAIR_REJECT_THRESHOLD:
            logger.info("hair_result_probabilty < %s", HAIR_REJECT_THRESHOLD)
            # NOTE(review): the hair probability replaces the base-body
            # probability in the response here; kept as-is because API
            # consumers may depend on it.
            return hair_probability, "AUTO REJECT"
        return base_body_probability, "AUTO ACCEPT"
    if base_body_probability < AUTO_REJECT_THRESHOLD:
        return base_body_probability, "AUTO REJECT"
    return base_body_probability, "MODERATION"


def predict_fn(input_img):
    """Classify one image and return the decision as UTF-8 encoded JSON bytes.

    Args:
        input_img: image array from the Gradio ``"image"`` input (cast to
            uint8 and read as RGB).

    Returns:
        ``bytes`` of ``{"is_base_body": <float>, "decision": <str>}``.
        NOTE(review): the interface output is ``"text"``, so Gradio displays
        the bytes repr (``b'...'``). Kept because clients may parse that form.

    Raises:
        ValueError: if no image was submitted.
    """
    if input_img is None:
        raise ValueError("No image was provided.")
    pil_image = Image.fromarray(input_img.astype("uint8"), "RGB")
    image_tensor = preprocess(pil_image).unsqueeze(0)
    with torch.no_grad():
        image_features = clip_model.encode_image(image_tensor)
        base_body_scores = get_cosine_similarities(
            image_features, all_text_features, text_prompts
        )
        hair_scores = get_cosine_similarities(
            image_features, hair_text_features, hair_text_prompts
        )
    logger.info("cosine_simlarities: %s", base_body_scores)

    # Each classifier sees one row holding all of its pair scores.
    probabilities = lr_model.predict_proba(base_body_scores.reshape(1, -1))
    hair_probabilities = hair_rf_model.predict_proba(hair_scores.reshape(1, -1))
    logger.info("probabilities: %s", probabilities)

    # Column 1 is taken as the positive class. Rounding happens before the
    # threshold comparisons, matching the reported value.
    base_body_probability = float(probabilities[0][1].round(3))
    hair_probability = float(hair_probabilities[0][1].round(3))

    reported_probability, decision = _decide(base_body_probability, hair_probability)
    logger.info("decision: %s", decision)

    decision_json = json.dumps(
        {"is_base_body": reported_probability, "decision": decision}
    ).encode("utf-8")
    logger.info("decision_json: %s", decision_json)
    return decision_json


DESCRIPTION = f"""
The model returns the probability of the image being a base body.
If probability > {AUTO_ACCEPT_THRESHOLD}, the image can be automatically tagged as a base body.
If probability < {AUTO_REJECT_THRESHOLD}, the image can be automatically REJECTED as NOT a base body.
All other cases will be submitted for moderation.
Please flag if you think the decision is wrong.
"""

iface = gr.Interface(
    fn=predict_fn,
    inputs="image",
    outputs="text",
    description=DESCRIPTION,
    allow_flagging="manual",
    flagging_options=[
        ": decision should be accept",
        ": decision should be reject",
        ": decision should be moderation",
    ],
    flagging_callback=hf_writer,
)

iface.launch()