prithivMLmods commited on
Commit
134672c
·
verified ·
1 Parent(s): 7204fab

Update roop/face_analyser.py

Browse files
Files changed (1) hide show
  1. roop/face_analyser.py +55 -34
roop/face_analyser.py CHANGED
@@ -1,34 +1,55 @@
1
- import threading
2
- from typing import Any
3
- import insightface
4
-
5
- import roop.globals
6
- from roop.typing import Frame
7
-
8
- FACE_ANALYSER = None
9
- THREAD_LOCK = threading.Lock()
10
-
11
-
12
- def get_face_analyser() -> Any:
13
- global FACE_ANALYSER
14
-
15
- with THREAD_LOCK:
16
- if FACE_ANALYSER is None:
17
- FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.execution_providers)
18
- FACE_ANALYSER.prepare(ctx_id=0, det_size=(640, 640))
19
- return FACE_ANALYSER
20
-
21
-
22
- def get_one_face(frame: Frame) -> Any:
23
- face = get_face_analyser().get(frame)
24
- try:
25
- return min(face, key=lambda x: x.bbox[0])
26
- except ValueError:
27
- return None
28
-
29
-
30
- def get_many_faces(frame: Frame) -> Any:
31
- try:
32
- return get_face_analyser().get(frame)
33
- except IndexError:
34
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from typing import Any, Optional, List
3
+ import insightface
4
+ import numpy
5
+ import spaces
6
+
7
+ import roop.globals
8
+ from roop.typing import Frame, Face
9
+
10
+ FACE_ANALYSER = None
11
+ THREAD_LOCK = threading.Lock()
12
+
13
+ @spaces.GPU()
14
+ def get_face_analyser() -> Any:
15
+ global FACE_ANALYSER
16
+
17
+ with THREAD_LOCK:
18
+ if FACE_ANALYSER is None:
19
+ FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.execution_providers)
20
+ FACE_ANALYSER.prepare(ctx_id=0)
21
+ return FACE_ANALYSER
22
+
23
+
24
+ def clear_face_analyser() -> Any:
25
+ global FACE_ANALYSER
26
+
27
+ FACE_ANALYSER = None
28
+
29
+
30
+ def get_one_face(frame: Frame, position: int = 0) -> Optional[Face]:
31
+ many_faces = get_many_faces(frame)
32
+ if many_faces:
33
+ try:
34
+ return many_faces[position]
35
+ except IndexError:
36
+ return many_faces[-1]
37
+ return None
38
+
39
+
40
+ def get_many_faces(frame: Frame) -> Optional[List[Face]]:
41
+ try:
42
+ return get_face_analyser().get(frame)
43
+ except ValueError:
44
+ return None
45
+
46
+
47
+ def find_similar_face(frame: Frame, reference_face: Face) -> Optional[Face]:
48
+ many_faces = get_many_faces(frame)
49
+ if many_faces:
50
+ for face in many_faces:
51
+ if hasattr(face, 'normed_embedding') and hasattr(reference_face, 'normed_embedding'):
52
+ distance = numpy.sum(numpy.square(face.normed_embedding - reference_face.normed_embedding))
53
+ if distance < roop.globals.similar_face_distance:
54
+ return face
55
+ return None