import operator
import re

import comfy
from impact.utils import *

# Hugging Face repo ids offered as presets in the provider node's dropdown;
# the selected id is passed as ``model=`` to ``transformers.pipeline``.
hf_transformer_model_urls = [
    "rizvandwiki/gender-classification-2",
    "NTQAI/pedestrian_gender_recognition",
    "Leilab/gender_class",
    "ProjectPersonal/GenderClassifier",
    "crangana/trained-gender",
    "cledoux42/GenderNew_v002",
    "ivensamdh/genderage2"
]


class HF_TransformersClassifierProvider:
    """Node that builds a ``transformers`` image-classification pipeline.

    The model is chosen either from ``hf_transformer_model_urls`` or from a
    manually entered repo id.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "preset_repo_id": (hf_transformer_model_urls + ['Manual repo id'],),
                    "manual_repo_id": ("STRING", {"multiline": False}),
                    "device_mode": (["AUTO", "Prefer GPU", "CPU"],),
                    },
                }

    RETURN_TYPES = ("TRANSFORMERS_CLASSIFIER",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/HuggingFace"

    def doit(self, preset_repo_id, manual_repo_id, device_mode):
        """Create the classifier pipeline.

        Args:
            preset_repo_id: One of the preset repo ids, or ``'Manual repo id'``
                to use ``manual_repo_id`` instead.
            manual_repo_id: Repo id used when the preset is ``'Manual repo id'``.
            device_mode: ``'CPU'`` forces the CPU; any other value uses the
                device reported by ``comfy.model_management.get_torch_device()``.

        Returns:
            A 1-tuple holding the created pipeline object.
        """
        # Imported lazily so the module can be loaded without transformers
        # until this node is actually executed.
        from transformers import pipeline

        if preset_repo_id == 'Manual repo id':
            # Pasted repo ids often carry stray whitespace; it is never part
            # of a valid id.
            url = manual_repo_id.strip()
        else:
            url = preset_repo_id

        if device_mode != 'CPU':
            device = comfy.model_management.get_torch_device()
        else:
            device = "cpu"

        classifier = pipeline('image-classification', model=url, device=device)

        return (classifier,)


# Expressions offered as presets in the SEGS_Classify dropdown.
preset_classify_expr = [
    '#Female > #Male',
    '#Female < #Male',
    'female > 0.5',
    'male > 0.5',
    'Age16to25 > 0.1',
    'Age50to69 > 0.1',
]

# Symbolic operands (prefixed with '#') and the set of concrete classifier
# labels each one stands for.
symbolic_label_map = {
    '#Female': {'female', 'Female', 'Human Female', 'woman', 'women', 'girl'},
    '#Male': {'male', 'Male', 'Human Male', 'man', 'men', 'boy'}
}

# Accepts everything the previous pattern did (``-?\d+(\.\d+)?``) plus the
# other plain float spellings that ``float()`` parses: a leading '+', a bare
# leading or trailing dot, and an exponent.
_numeric_pattern = re.compile(r'^[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?$')


def is_numeric_string(input_str):
    """Return True if ``input_str`` is a numeric literal (e.g. ``0.5``, ``-1``, ``1e-3``)."""
    return _numeric_pattern.match(input_str) is not None


# ``<operand> <op> <operand>``; operands may not contain spaces or any of the
# characters '>', '<', '='. '==' is accepted as a synonym for '='.
classify_expr_pattern = r'([^><= ]+)\s*(>=|<=|==|>|<|=)\s*([^><= ]+)'

# Comparison function for every operator ``classify_expr_pattern`` can capture.
_classify_comparators = {
    '>': operator.gt,
    '<': operator.lt,
    '>=': operator.ge,
    '<=': operator.le,
    '=': operator.eq,
    '==': operator.eq,
}


class SEGS_Classify:
    """Node that splits SEGS by a comparison over classifier label scores."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "classifier": ("TRANSFORMERS_CLASSIFIER",),
                    "segs": ("SEGS",),
                    "preset_expr": (preset_classify_expr + ['Manual expr'],),
                    "manual_expr": ("STRING", {"multiline": False}),
                    },
                "optional": {
                     "ref_image_opt": ("IMAGE", ),
                    }
                }

    RETURN_TYPES = ("SEGS", "SEGS", "STRING")
    RETURN_NAMES = ("filtered_SEGS", "remained_SEGS", "detected_labels")
    OUTPUT_IS_LIST = (False, False, True)

    FUNCTION = "doit"

    CATEGORY = "ImpactPack/HuggingFace"

    @staticmethod
    def lookup_classified_label_score(score_infos, label):
        """Return the score of ``label`` within one classification result.

        Args:
            score_infos: Iterable of dicts with ``'label'`` and ``'score'`` keys.
            label: A concrete label, or a symbolic one starting with ``'#'``
                that is expanded through ``symbolic_label_map``.

        Returns:
            The score of the first entry whose label matches, or ``None`` when
            the symbolic label is unknown or no entry matches.
        """
        if label.startswith('#'):
            candidates = symbolic_label_map.get(label)
            if candidates is None:
                return None
        else:
            candidates = {label}

        for info in score_infos:
            if info['label'] in candidates:
                return info['score']

        return None

    def doit(self, classifier, segs, preset_expr, manual_expr, ref_image_opt=None):
        """Classify every seg and split them by the selected expression.

        Args:
            classifier: Callable applied to a PIL image; its result is iterated
                as dicts with ``'label'`` and ``'score'`` keys.
            segs: Pair whose first element is passed through unchanged and
                whose second element is the list of segs to classify.
            preset_expr: A preset expression, or ``'Manual expr'`` to use
                ``manual_expr`` instead.
            manual_expr: Expression used when the preset is ``'Manual expr'``.
            ref_image_opt: Optional image that is cropped with the seg's
                ``crop_region`` when the seg has no ``cropped_image``.

        Returns:
            ``(filtered_SEGS, remained_SEGS, detected_labels)``. Segs for which
            the expression holds go to the first output; segs without an image,
            with an unresolvable operand, or for which the expression is false
            go to the second. ``detected_labels`` is the sorted list of every
            label the classifier reported. If the expression cannot be parsed,
            nothing is filtered and ``segs`` is returned as the remainder.
        """
        expr_str = manual_expr if preset_expr == 'Manual expr' else preset_expr

        # Surrounding whitespace would otherwise make re.match fail at the
        # first character and silently disable the filter.
        match = re.match(classify_expr_pattern, expr_str.strip())

        if match is None:
            return (segs[0], []), segs, []

        lhs, op, rhs = match.group(1), match.group(2), match.group(3)
        compare = _classify_comparators[op]

        lhs_is_label = not is_numeric_string(lhs)
        rhs_is_label = not is_numeric_string(rhs)

        # Constant operands are converted once instead of once per seg.
        lhs_const = None if lhs_is_label else float(lhs)
        rhs_const = None if rhs_is_label else float(rhs)

        classified = []
        remained_SEGS = []
        provided_labels = set()

        for seg in segs[1]:
            cropped_image = None

            if seg.cropped_image is not None:
                cropped_image = seg.cropped_image
            elif ref_image_opt is not None:
                # take from original image
                cropped_image = crop_image(ref_image_opt, seg.crop_region)

            if cropped_image is None:
                # Nothing to classify for this seg.
                remained_SEGS.append(seg)
                continue

            res = classifier(to_pil(cropped_image))
            classified.append((seg, res))
            provided_labels.update(x['label'] for x in res)

        filtered_SEGS = []
        for seg, res in classified:
            if lhs_is_label:
                lhs_value = SEGS_Classify.lookup_classified_label_score(res, lhs)
            else:
                lhs_value = lhs_const

            if rhs_is_label:
                rhs_value = SEGS_Classify.lookup_classified_label_score(res, rhs)
            else:
                rhs_value = rhs_const

            if lhs_value is None or rhs_value is None:
                # A label operand was not reported for this seg.
                remained_SEGS.append(seg)
                continue

            if compare(float(lhs_value), float(rhs_value)):
                filtered_SEGS.append(seg)
            else:
                remained_SEGS.append(seg)

        # Sorted so the list output does not depend on set iteration order.
        return (segs[0], filtered_SEGS), (segs[0], remained_SEGS), sorted(provided_labels)