File size: 3,734 Bytes
e2c48d1
899c524
77fb271
5910f04
b723189
 
0184aed
b723189
0184aed
41b1b4e
5910f04
b723189
77fb271
 
 
 
 
 
 
b723189
77fb271
 
 
b723189
 
636e0ba
78c79c8
d0e1651
78c79c8
 
 
 
 
 
 
0184aed
 
 
b723189
ba05218
b723189
 
 
 
caa7434
b723189
899c524
b723189
2398e1d
899c524
b5d23b1
899c524
65a4898
899c524
78c79c8
 
ce805af
e2c48d1
 
b5d23b1
2398e1d
b723189
 
72e01bc
b5d23b1
15d93c6
b5a31f2
 
 
3c886fa
 
 
b5d23b1
3c886fa
 
 
 
b5a31f2
 
 
afda8c0
b5a31f2
7918186
94e7509
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import json
import argparse
import operator
import gradio as gr
import torchvision
from typing import Tuple, Dict
from facetorch import FaceAnalyzer
from facetorch.datastruct import ImageData
from omegaconf import OmegaConf
from torch.nn.functional import cosine_similarity

# Command-line options: the only setting exposed is the location of the OmegaConf YAML config.
parser = argparse.ArgumentParser(description="App")
parser.add_argument(
    "--path-conf",
    default="config.merged.yml",
    help="Path to the config file",
    type=str,
)
args = parser.parse_args()

# Load the config and build the analyzer once, at import time. `inference` reuses these
# module-level objects on every call (it reads `cfg.analyzer`-built `analyzer` plus
# `cfg.batch_size`, `cfg.fix_img_size`, `cfg.return_img_data`, `cfg.include_tensors`).
cfg = OmegaConf.load(args.path_conf)
analyzer = FaceAnalyzer(cfg.analyzer)


def gen_sim_dict_str(response: ImageData, pred_name: str = "verify", index: int = 0) -> str:
    """Rank every detected face by cosine similarity to a reference face.

    Args:
        response: Analyzer output; each entry of ``response.faces`` must expose
            ``indx`` and ``preds[pred_name].logits``.
        pred_name: Key of the prediction whose logits are compared
            (this file uses ``"embed"`` and ``"verify"``).
        index: Position in ``response.faces`` of the reference face.

    Returns:
        ``str`` of a dict mapping each face's ``indx`` to its cosine similarity
        with the reference face, ordered from most to least similar, or ``""``
        when no faces were detected.
    """
    faces = response.faces
    if not faces:
        return ""

    reference = faces[index].preds[pred_name].logits

    # Build the dict first so a repeated `indx` keeps the same slot/value semantics
    # as a dict comprehension, then reorder it by similarity.
    # dim=0 compares along the first axis; presumably the logits are 1-D embeddings — confirm.
    similarities = {}
    for face in faces:
        candidate = face.preds[pred_name].logits
        similarities[face.indx] = cosine_similarity(reference, candidate, dim=0).item()

    ranked = sorted(similarities.items(), key=operator.itemgetter(1), reverse=True)
    return str(dict(ranked))

def inference(path_image: str) -> Tuple:
    """Run the face analyzer on one image and format its predictions for the UI.

    The file at ``path_image`` is deleted when this function finishes, whether
    analysis succeeds or raises. Callers must therefore pass a disposable copy
    (presumably the temporary file Gradio writes for a ``type="filepath"``
    input — confirm for the installed Gradio version).

    Args:
        path_image: Path to the image file to analyze.

    Returns:
        A 7-tuple ordered to match the interface's output components:
        (``response.img`` as a PIL image,
         str of {face indx: "fer" label},
         str of {face indx: "au" ``other["multi"]``},
         str of {face indx: "deepfake" label},
         cosine-similarity ranking for the "embed" logits,
         cosine-similarity ranking for the "verify" logits,
         str of the whole analyzer response).
    """
    try:
        response = analyzer.run(
            path_image=path_image,
            batch_size=cfg.batch_size,
            fix_img_size=cfg.fix_img_size,
            return_img_data=cfg.return_img_data,
            include_tensors=cfg.include_tensors,
            path_output=None,
        )

        pil_image = torchvision.transforms.functional.to_pil_image(response.img)

        # Per-face summaries, keyed by each face's `indx`.
        fer_dict_str = str({face.indx: face.preds["fer"].label for face in response.faces})
        au_dict_str = str({face.indx: face.preds["au"].other["multi"] for face in response.faces})
        deepfake_dict_str = str({face.indx: face.preds["deepfake"].label for face in response.faces})
        response_str = str(response)

        # Similarity of every face to the face at position 0 of `response.faces`.
        sim_dict_str_embed = gen_sim_dict_str(response, pred_name="embed", index=0)
        sim_dict_str_verify = gen_sim_dict_str(response, pred_name="verify", index=0)
    finally:
        # Always delete the input so a failed request does not leave its file behind.
        # A file that is already gone is not worth reporting, and raising here would
        # mask any exception that is already propagating.
        try:
            os.remove(path_image)
        except FileNotFoundError:
            pass

    return (
        pil_image,
        fer_dict_str,
        au_dict_str,
        deepfake_dict_str,
        sim_dict_str_embed,
        sim_dict_str_verify,
        response_str,
    )


# Text rendered around the interface: heading, introduction, and footer link.
title = "Face Analysis"
description = (
    "Demo of facetorch, a face analysis Python library that implements open-source "
    "pre-trained neural networks for face detection, representation learning, "
    "verification, expression recognition, action unit detection, deepfake detection, "
    "and 3D alignment. Try selecting one of the example images or upload your own. "
    "Feel free to duplicate this space and run it faster on a GPU instance. "
    "This work would not be possible without the researchers and engineers who "
    "trained the models (sources and credits can be found in the facetorch repository)."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://github.com/tomas-gajarsky/facetorch' style='text-align:center' target='_blank'>"
    "facetorch GitHub repository</a></p>"
)

# Example images, in display order; paths are relative to the working directory.
_EXAMPLE_IMAGES = (
    "test5.jpg",
    "test.jpg",
    "test4.jpg",
    "test8.jpg",
    "test6.jpg",
    "test3.jpg",
    "test10.jpg",
)

# The input widget hands `inference` a file path (type="filepath").
_input_component = gr.Image(label="Input", type="filepath")

# One output component per element of the tuple returned by `inference`, same order.
_output_components = [
    gr.Image(type="pil", label="Face Detection and 3D Landmarks"),
    gr.Textbox(label="Facial Expression Recognition"),
    gr.Textbox(label="Facial Action Unit Detection"),
    gr.Textbox(label="DeepFake Detection"),
    gr.Textbox(label="Cosine similarity of Face Representation Embeddings"),
    gr.Textbox(label="Cosine similarity of Face Verification Embeddings"),
    gr.Textbox(label="Response"),
]

demo = gr.Interface(
    fn=inference,
    inputs=[_input_component],
    outputs=_output_components,
    title=title,
    description=description,
    article=article,
    examples=[[f"./{name}"] for name in _EXAMPLE_IMAGES],
)

# A single queue worker (concurrency_count=1); api_open=False keeps the REST API closed.
demo.queue(concurrency_count=1, api_open=False)
# Bind to all network interfaces on port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)