Commit d69879c
Parent(s): initial commit
This view is limited to 50 files because it contains too many changes.
- .dockerignore +17 -0
- .gitattributes +13 -0
- .gitignore +31 -0
- Dockerfile +67 -0
- LICENSE +56 -0
- README.md +119 -0
- app.py +264 -0
- build.sh +3 -0
- client/.gitignore +175 -0
- client/README.md +13 -0
- client/bun.lockb +0 -0
- client/package.json +35 -0
- client/src/app.tsx +190 -0
- client/src/components/DoubleCard.tsx +18 -0
- client/src/components/PoweredBy.tsx +17 -0
- client/src/components/Spinner.tsx +7 -0
- client/src/components/Title.tsx +8 -0
- client/src/components/ui/alert.tsx +59 -0
- client/src/hooks/landmarks.ts +520 -0
- client/src/hooks/useFaceLandmarkDetection.tsx +632 -0
- client/src/hooks/useFacePokeAPI.ts +44 -0
- client/src/hooks/useMainStore.ts +58 -0
- client/src/index.tsx +6 -0
- client/src/layout.tsx +14 -0
- client/src/lib/circularBuffer.ts +31 -0
- client/src/lib/convertImageToBase64.ts +19 -0
- client/src/lib/facePoke.ts +398 -0
- client/src/lib/throttle.ts +32 -0
- client/src/lib/utils.ts +15 -0
- client/src/styles/globals.css +81 -0
- client/tailwind.config.js +86 -0
- client/tsconfig.json +32 -0
- engine.py +300 -0
- liveportrait/config/__init__.py +0 -0
- liveportrait/config/argument_config.py +44 -0
- liveportrait/config/base_config.py +29 -0
- liveportrait/config/crop_config.py +18 -0
- liveportrait/config/inference_config.py +53 -0
- liveportrait/config/models.yaml +43 -0
- liveportrait/gradio_pipeline.py +140 -0
- liveportrait/live_portrait_pipeline.py +193 -0
- liveportrait/live_portrait_wrapper.py +307 -0
- liveportrait/modules/__init__.py +0 -0
- liveportrait/modules/appearance_feature_extractor.py +48 -0
- liveportrait/modules/convnextv2.py +149 -0
- liveportrait/modules/dense_motion.py +104 -0
- liveportrait/modules/motion_extractor.py +35 -0
- liveportrait/modules/spade_generator.py +59 -0
- liveportrait/modules/stitching_retargeting_network.py +38 -0
- liveportrait/modules/util.py +441 -0
.dockerignore
ADDED
@@ -0,0 +1,17 @@
# The .dockerignore file excludes files from the container build process.
#
# https://docs.docker.com/engine/reference/builder/#dockerignore-file

# Exclude Git files
.git
.github
.gitignore

# Exclude Python cache files
__pycache__
.mypy_cache
.pytest_cache
.ruff_cache

# Exclude Python virtual environment
/venv
.gitattributes
ADDED
@@ -0,0 +1,13 @@
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.xml filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.pdf filter=lfs diff=lfs merge=lfs -text
*.mp3 filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
*.mpg filter=lfs diff=lfs merge=lfs -text
*.webp filter=lfs diff=lfs merge=lfs -text
*.webm filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,31 @@
# Byte-compiled / optimized / DLL files
__pycache__/
**/__pycache__/
*.py[cod]
**/*.py[cod]
*$py.class

# Model weights
**/*.pth
**/*.onnx

# Ipython notebook
*.ipynb

# Temporary files or benchmark resources
animations/*
tmp/*

# more ignores
.DS_Store
*.log
.idea/
.vscode/
*.pyc
.ipynb_checkpoints
results/
data/audio/*.wav
data/video/*.mp4
ffmpeg-7.0-amd64-static
venv/
.cog/
Dockerfile
ADDED
@@ -0,0 +1,67 @@
FROM nvidia/cuda:12.4.0-devel-ubuntu22.04

ARG DEBIAN_FRONTEND=noninteractive

ENV PYTHONUNBUFFERED=1

RUN apt-get update && apt-get install --no-install-recommends -y \
    build-essential \
    python3.11 \
    python3-pip \
    python3-dev \
    git \
    curl \
    ffmpeg \
    libglib2.0-0 \
    libsm6 \
    libxrender1 \
    libxext6 \
    && apt-get clean && rm -rf /var/lib/apt/lists/*

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt

# Install pget as root
RUN echo "Installing pget" && \
    curl -o /usr/local/bin/pget -L 'https://github.com/replicate/pget/releases/download/v0.2.1/pget' && \
    chmod +x /usr/local/bin/pget

# Set up a new user named "user" with user ID 1000
RUN useradd -m -u 1000 user
# Switch to the "user" user
USER user
# Set home to the user's home directory
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH


# Set app-specific environment variables
ENV PYTHONPATH=$HOME/app \
    PYTHONUNBUFFERED=1 \
    DATA_ROOT=/tmp/data

RUN echo "Installing requirements.txt"
RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt

# yeah.. this is manual for now
#RUN cd client
#RUN bun i
#RUN bun build ./src/index.tsx --outdir ../public/

RUN echo "Installing openmim and mim dependencies"
RUN pip3 install --no-cache-dir -U openmim
RUN mim install mmengine
RUN mim install "mmcv>=2.0.1"
RUN mim install "mmdet>=3.3.0"
RUN mim install "mmpose>=1.3.2"

WORKDIR $HOME/app

COPY --chown=user . $HOME/app

EXPOSE 8080

ENV PORT 8080

CMD python3 app.py
LICENSE
ADDED
@@ -0,0 +1,56 @@
## For FacePoke (the modifications I made + the server itself)

MIT License

Copyright (c) 2024 Julian Bilcke

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

## For LivePortrait

MIT License

Copyright (c) 2024 Kuaishou Visual Generation and Interaction Center

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

---

The code of InsightFace is released under the MIT License.
The models of InsightFace are for non-commercial research purposes only.

If you want to use the LivePortrait project for commercial purposes, you
should remove and replace InsightFace’s detection models to fully comply with
the MIT license.
README.md
ADDED
@@ -0,0 +1,119 @@
---
title: FacePoke
emoji: 💬
colorFrom: yellow
colorTo: red
sdk: docker
pinned: true
license: mit
header: mini
app_file: app.py
app_port: 8080
---

# FacePoke

![FacePoke Demo](https://your-demo-image-url-here.gif)

## Table of Contents

- [Introduction](#introduction)
- [Acknowledgements](#acknowledgements)
- [Installation](#installation)
  - [Local Setup](#local-setup)
  - [Docker Deployment](#docker-deployment)
- [Development](#development)
- [Contributing](#contributing)
- [License](#license)

## Introduction

A real-time head transformation app.

For best performance, please run the app from your own machine (local or in the cloud).

**Repository**: [GitHub - jbilcke-hf/FacePoke](https://github.com/jbilcke-hf/FacePoke)

You can try the demo, but it is a shared space: latency may be high if there are multiple users or if you live far from the datacenter hosting the Hugging Face Space.

**Live Demo**: [FacePoke on Hugging Face Spaces](https://huggingface.co/spaces/jbilcke-hf/FacePoke)

## Acknowledgements

This project is based on LivePortrait: https://arxiv.org/abs/2407.03168

It uses the face transformation routines from https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait

## Installation

### Local Setup

1. Clone the repository:
   ```bash
   git clone https://github.com/jbilcke-hf/FacePoke.git
   cd FacePoke
   ```

2. Install Python dependencies:
   ```bash
   pip install -r requirements.txt
   ```

3. Install frontend dependencies:
   ```bash
   cd client
   bun install
   ```

4. Build the frontend:
   ```bash
   bun build ./src/index.tsx --outdir ../public/
   ```

5. Start the backend server:
   ```bash
   python app.py
   ```

6. Open `http://localhost:8080` in your web browser.

### Docker Deployment

1. Build the Docker image:
   ```bash
   docker build -t facepoke .
   ```

2. Run the container:
   ```bash
   docker run -p 8080:8080 facepoke
   ```

3. To deploy to Hugging Face Spaces:
   - Fork the repository on GitHub.
   - Create a new Space on Hugging Face.
   - Connect your GitHub repository to the Space.
   - Configure the Space to use the Docker runtime.

## Development

The project structure is organized as follows:

- `app.py`: Main backend server handling WebSocket connections.
- `engine.py`: Core logic.
- `loader.py`: Initializes and loads AI models.
- `client/`: Frontend React application.
  - `src/`: TypeScript source files.
  - `public/`: Static assets and built files.

## Contributing

Contributions to FacePoke are welcome! Please read our [Contributing Guidelines](CONTRIBUTING.md) for details on how to submit pull requests, report issues, or request features.

## License

FacePoke is released under the MIT License. See the [LICENSE](LICENSE) file for details.

---

Developed with ❤️ by Julian Bilcke at Hugging Face
app.py
ADDED
@@ -0,0 +1,264 @@
"""
FacePoke API

Author: Julian Bilcke
Date: September 30, 2024
"""

import sys
import asyncio
import hashlib
from aiohttp import web, WSMsgType
import json
import uuid
import logging
import os
import zipfile
import signal
from typing import Dict, Any, List, Optional
import base64
import io
from PIL import Image
import numpy as np

# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set asyncio logger to DEBUG level
logging.getLogger("asyncio").setLevel(logging.DEBUG)

logger.debug(f"Python version: {sys.version}")

# SIGSEGV handler
def SIGSEGV_signal_arises(signalNum, stack):
    logger.critical(f"{signalNum} : SIGSEGV arises")
    logger.critical(f"Stack trace: {stack}")

signal.signal(signal.SIGSEGV, SIGSEGV_signal_arises)

from loader import initialize_models
from engine import Engine, base64_data_uri_to_PIL_Image, create_engine

# Global constants
DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
MODELS_DIR = os.path.join(DATA_ROOT, "models")

image_cache: Dict[str, Image.Image] = {}

async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
    """
    Handle WebSocket connections for the FacePoke application.

    Args:
        request (web.Request): The incoming request object.

    Returns:
        web.WebSocketResponse: The WebSocket response object.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)

    session: Optional[FacePokeSession] = None
    try:
        logger.info("New WebSocket connection established")

        while True:
            msg = await ws.receive()

            if msg.type == WSMsgType.TEXT:
                data = json.loads(msg.data)

                # let's not log user requests, they are heavy
                #logger.debug(f"Received message: {data}")

                if data['type'] == 'modify_image':
                    uuid = data.get('uuid')
                    if not uuid:
                        logger.warning("Received message without UUID")

                    await handle_modify_image(request, ws, data, uuid)


            elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
                logger.warning(f"WebSocket connection closed: {msg.type}")
                break

    except Exception as e:
        logger.error(f"Error in websocket_handler: {str(e)}")
        logger.exception("Full traceback:")
    finally:
        if session:
            await session.stop()
            del active_sessions[session.session_id]
        logger.info("WebSocket connection closed")
    return ws

async def handle_modify_image(request: web.Request, ws: web.WebSocketResponse, msg: Dict[str, Any], uuid: str):
    """
    Handle the 'modify_image' request.

    Args:
        request (web.Request): The incoming request object.
        ws (web.WebSocketResponse): The WebSocket response object.
        msg (Dict[str, Any]): The message containing the image or image_hash and modification parameters.
        uuid: A unique identifier for the request.
    """
    logger.info("Received modify_image request")
    try:
        engine = request.app['engine']
        image_hash = msg.get('image_hash')

        if image_hash:
            image_or_hash = image_hash
        else:
            image_data = msg['image']
            image_or_hash = image_data

        modified_image_base64 = await engine.modify_image(image_or_hash, msg['params'])

        await ws.send_json({
            "type": "modified_image",
            "image": modified_image_base64,
            "image_hash": engine.get_image_hash(image_or_hash),
            "success": True,
            "uuid": uuid  # Include the UUID in the response
        })
        logger.info("Successfully sent modified image")
    except Exception as e:
        logger.error(f"Error in modify_image: {str(e)}")
        await ws.send_json({
            "type": "modified_image",
            "success": False,
            "error": str(e),
            "uuid": uuid  # Include the UUID even in error responses
        })

async def index(request: web.Request) -> web.Response:
    """Serve the index.html file"""
    content = open(os.path.join(os.path.dirname(__file__), "public", "index.html"), "r").read()
    return web.Response(content_type="text/html", text=content)

async def js_index(request: web.Request) -> web.Response:
    """Serve the index.js file"""
    content = open(os.path.join(os.path.dirname(__file__), "public", "index.js"), "r").read()
    return web.Response(content_type="application/javascript", text=content)

async def hf_logo(request: web.Request) -> web.Response:
    """Serve the hf-logo.svg file"""
    content = open(os.path.join(os.path.dirname(__file__), "public", "hf-logo.svg"), "r").read()
    return web.Response(content_type="image/svg+xml", text=content)

async def on_shutdown(app: web.Application):
    """Cleanup function to be called on server shutdown."""
    logger.info("Server shutdown initiated, cleaning up resources...")
    for session in list(active_sessions.values()):
        await session.stop()
    active_sessions.clear()
    logger.info("All active sessions have been closed")

    if 'engine' in app:
        await app['engine'].cleanup()
        logger.info("Engine instance cleaned up")

    logger.info("Server shutdown complete")

async def initialize_app() -> web.Application:
    """Initialize and configure the web application."""
    try:
        logger.info("Initializing application...")
        models = await initialize_models()
        logger.info("🚀 Creating Engine instance...")
        engine = create_engine(models)
        logger.info("✅ Engine instance created.")

        app = web.Application()
        app['engine'] = engine

        app.on_shutdown.append(on_shutdown)

        # Configure routes
        app.router.add_get("/", index)
        app.router.add_get("/index.js", js_index)
        app.router.add_get("/hf-logo.svg", hf_logo)
        app.router.add_get("/ws", websocket_handler)

        logger.info("Application routes configured")

        return app
    except Exception as e:
        logger.error(f"🚨 Error during application initialization: {str(e)}")
        logger.exception("Full traceback:")
        raise

async def start_background_tasks(app: web.Application):
    """
    Start background tasks for the application.

    Args:
        app (web.Application): The web application instance.
    """
    app['cleanup_task'] = asyncio.create_task(periodic_cleanup(app))

async def cleanup_background_tasks(app: web.Application):
    """
    Clean up background tasks when the application is shutting down.

    Args:
        app (web.Application): The web application instance.
    """
    app['cleanup_task'].cancel()
    await app['cleanup_task']

async def periodic_cleanup(app: web.Application):
    """
    Perform periodic cleanup tasks for the application.

    Args:
        app (web.Application): The web application instance.
    """
    while True:
        try:
            await asyncio.sleep(3600)  # Run cleanup every hour
            await cleanup_inactive_sessions(app)
        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.error(f"Error in periodic cleanup: {str(e)}")
            logger.exception("Full traceback:")

async def cleanup_inactive_sessions(app: web.Application):
    """
    Clean up inactive sessions.

    Args:
        app (web.Application): The web application instance.
    """
    logger.info("Starting cleanup of inactive sessions")
    inactive_sessions = [
        session_id for session_id, session in active_sessions.items()
        if not session.is_running.is_set()
    ]
    for session_id in inactive_sessions:
        session = active_sessions.pop(session_id)
        await session.stop()
        logger.info(f"Cleaned up inactive session: {session_id}")
    logger.info(f"Cleaned up {len(inactive_sessions)} inactive sessions")

def main():
    """
    Main function to start the FacePoke application.
    """
    try:
        logger.info("Starting FacePoke application")
        app = asyncio.run(initialize_app())
        app.on_startup.append(start_background_tasks)
        app.on_cleanup.append(cleanup_background_tasks)
        logger.info("Application initialized, starting web server")
        web.run_app(app, host="0.0.0.0", port=8080)
    except Exception as e:
        logger.critical(f"🚨 FATAL: Failed to start the app: {str(e)}")
        logger.exception("Full traceback:")

if __name__ == "__main__":
    main()
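For reference, `handle_modify_image` above fixes the wire format: the client sends a JSON text frame with `type: 'modify_image'`, a `uuid`, either a base64 `image` or the `image_hash` returned by a previous reply, plus a `params` object, and the server answers with a `modified_image` frame. The TypeScript sketch below illustrates that round trip from a browser. Only the envelope fields come from `app.py`; the element id, the keys inside `params`, and the assumption that the returned `image` is directly usable as an `<img>` source are illustrative guesses.

```typescript
// Sketch of the /ws round trip defined by handle_modify_image (app.py above).
// Envelope fields (type, uuid, image / image_hash, params) mirror the server;
// the contents of `params` and the 'preview' element are hypothetical.
const ws = new WebSocket(`wss://${location.host}/ws`);

let imageHash: string | undefined;

ws.onmessage = (event) => {
  const msg = JSON.parse(event.data as string);
  if (msg.type === 'modified_image' && msg.success) {
    imageHash = msg.image_hash; // remember the hash so later edits can skip re-uploading
    // Assumption: the returned image can be assigned directly as a data URI.
    (document.getElementById('preview') as HTMLImageElement).src = msg.image;
  }
};

function requestModification(imageDataUri: string, params: Record<string, number>) {
  ws.send(JSON.stringify({
    type: 'modify_image',
    uuid: crypto.randomUUID(),
    // Send the full image only once; afterwards the cached hash is enough.
    ...(imageHash ? { image_hash: imageHash } : { image: imageDataUri }),
    params,
  }));
}
```

Reusing `image_hash` on subsequent edits keeps each per-drag message small, which lines up with the server's choice not to log these heavy payloads.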
build.sh
ADDED
@@ -0,0 +1,3 @@
cd client
bun i
bun build ./src/index.tsx --outdir ../public/
client/.gitignore
ADDED
@@ -0,0 +1,175 @@
# Based on https://raw.githubusercontent.com/github/gitignore/main/Node.gitignore

# Logs

logs
_.log
npm-debug.log_
yarn-debug.log*
yarn-error.log*
lerna-debug.log*
.pnpm-debug.log*

# Caches

.cache

# Diagnostic reports (https://nodejs.org/api/report.html)

report.[0-9]_.[0-9]_.[0-9]_.[0-9]_.json

# Runtime data

pids
_.pid
_.seed
*.pid.lock

# Directory for instrumented libs generated by jscoverage/JSCover

lib-cov

# Coverage directory used by tools like istanbul

coverage
*.lcov

# nyc test coverage

.nyc_output

# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)

.grunt

# Bower dependency directory (https://bower.io/)

bower_components

# node-waf configuration

.lock-wscript

# Compiled binary addons (https://nodejs.org/api/addons.html)

build/Release

# Dependency directories

node_modules/
jspm_packages/

# Snowpack dependency directory (https://snowpack.dev/)

web_modules/

# TypeScript cache

*.tsbuildinfo

# Optional npm cache directory

.npm

# Optional eslint cache

.eslintcache

# Optional stylelint cache

.stylelintcache

# Microbundle cache

.rpt2_cache/
.rts2_cache_cjs/
.rts2_cache_es/
.rts2_cache_umd/

# Optional REPL history

.node_repl_history

# Output of 'npm pack'

*.tgz

# Yarn Integrity file

.yarn-integrity

# dotenv environment variable files

.env
.env.development.local
.env.test.local
.env.production.local
.env.local

# parcel-bundler cache (https://parceljs.org/)

.parcel-cache

# Next.js build output

.next
out

# Nuxt.js build / generate output

.nuxt
dist

# Gatsby files

# Comment in the public line in if your project uses Gatsby and not Next.js

# https://nextjs.org/blog/next-9-1#public-directory-support

# public

# vuepress build output

.vuepress/dist

# vuepress v2.x temp and cache directory

.temp

# Docusaurus cache and generated files

.docusaurus

# Serverless directories

.serverless/

# FuseBox cache

.fusebox/

# DynamoDB Local files

.dynamodb/

# TernJS port file

.tern-port

# Stores VSCode versions used for testing VSCode extensions

.vscode-test

# yarn v2

.yarn/cache
.yarn/unplugged
.yarn/build-state.yml
.yarn/install-state.gz
.pnp.*

# IntelliJ based IDEs
.idea

# Finder (MacOS) folder config
.DS_Store
client/README.md
ADDED
@@ -0,0 +1,13 @@
# FacePoke.js

To install dependencies:

```bash
bun i
```

To build:

```bash
bun build ./src/index.tsx --outdir ../public
```
client/bun.lockb
ADDED
Binary file (54.9 kB).
client/package.json
ADDED
@@ -0,0 +1,35 @@
{
  "name": "@aitube/facepoke",
  "module": "src/index.ts",
  "type": "module",
  "scripts": {
    "build": "bun build ./src/index.tsx --outdir ../public/"
  },
  "devDependencies": {
    "@types/bun": "latest"
  },
  "peerDependencies": {
    "typescript": "^5.0.0"
  },
  "dependencies": {
    "@mediapipe/tasks-vision": "^0.10.16",
    "@radix-ui/react-icons": "^1.3.0",
    "@types/lodash": "^4.17.10",
    "@types/react": "^18.3.9",
    "@types/react-dom": "^18.3.0",
    "@types/uuid": "^10.0.0",
    "beautiful-react-hooks": "^5.0.2",
    "class-variance-authority": "^0.7.0",
    "clsx": "^2.1.1",
    "lodash": "^4.17.21",
    "lucide-react": "^0.446.0",
    "react": "^18.3.1",
    "react-dom": "^18.3.1",
    "tailwind-merge": "^2.5.2",
    "tailwindcss": "^3.4.13",
    "tailwindcss-animate": "^1.0.7",
    "usehooks-ts": "^3.1.0",
    "uuid": "^10.0.0",
    "zustand": "^5.0.0-rc.2"
  }
}
client/src/app.tsx
ADDED
@@ -0,0 +1,190 @@
import React, { useState, useEffect, useRef, useCallback, useMemo } from 'react';
import { RotateCcw } from 'lucide-react';

import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert';
import { truncateFileName } from './lib/utils';
import { useFaceLandmarkDetection } from './hooks/useFaceLandmarkDetection';
import { PoweredBy } from './components/PoweredBy';
import { Spinner } from './components/Spinner';
import { DoubleCard } from './components/DoubleCard';
import { useFacePokeAPI } from './hooks/useFacePokeAPI';
import { Layout } from './layout';
import { useMainStore } from './hooks/useMainStore';
import { convertImageToBase64 } from './lib/convertImageToBase64';

export function App() {
  const error = useMainStore(s => s.error);
  const setError = useMainStore(s => s.setError);
  const imageFile = useMainStore(s => s.imageFile);
  const setImageFile = useMainStore(s => s.setImageFile);
  const originalImage = useMainStore(s => s.originalImage);
  const setOriginalImage = useMainStore(s => s.setOriginalImage);
  const previewImage = useMainStore(s => s.previewImage);
  const setPreviewImage = useMainStore(s => s.setPreviewImage);
  const resetImage = useMainStore(s => s.resetImage);

  const {
    status,
    setStatus,
    isDebugMode,
    setIsDebugMode,
    interruptMessage,
  } = useFacePokeAPI()

  // State for face detection
  const {
    canvasRef,
    canvasRefCallback,
    mediaPipeRef,
    faceLandmarks,
    isMediaPipeReady,
    blendShapes,

    setFaceLandmarks,
    setBlendShapes,

    handleMouseDown,
    handleMouseUp,
    handleMouseMove,
    handleMouseEnter,
    handleMouseLeave,
    currentOpacity
  } = useFaceLandmarkDetection()

  // Refs
  const videoRef = useRef<HTMLDivElement>(null);

  // Handle file change
  const handleFileChange = useCallback(async (event: React.ChangeEvent<HTMLInputElement>) => {
    const files = event.target.files;
    if (files && files[0]) {
      setImageFile(files[0]);
      setStatus(`File selected: ${truncateFileName(files[0].name, 16)}`);

      try {
        const image = await convertImageToBase64(files[0]);
        setPreviewImage(image);
        setOriginalImage(image);
      } catch (err) {
        console.log(`failed to convert the image: `, err);
        setImageFile(null);
        setStatus('');
        setPreviewImage('');
        setOriginalImage('');
        setFaceLandmarks([]);
        setBlendShapes([]);
      }
    } else {
      setImageFile(null);
      setStatus('');
      setPreviewImage('');
      setOriginalImage('');
      setFaceLandmarks([]);
      setBlendShapes([]);
    }
  }, [isMediaPipeReady, setImageFile, setPreviewImage, setOriginalImage, setFaceLandmarks, setBlendShapes, setStatus]);

  const canDisplayBlendShapes = false

  // Display blend shapes
  const displayBlendShapes = useMemo(() => (
    <div className="mt-4">
      <h3 className="text-lg font-semibold mb-2">Blend Shapes</h3>
      <ul className="space-y-1">
        {(blendShapes?.[0]?.categories || []).map((shape, index) => (
          <li key={index} className="flex items-center">
            <span className="w-32 text-sm">{shape.categoryName || shape.displayName}</span>
            <div className="w-full bg-gray-200 rounded-full h-2.5">
              <div
                className="bg-blue-600 h-2.5 rounded-full"
                style={{ width: `${shape.score * 100}%` }}
              ></div>
            </div>
            <span className="ml-2 text-sm">{shape.score.toFixed(2)}</span>
          </li>
        ))}
      </ul>
    </div>
  ), [JSON.stringify(blendShapes)])

  // JSX
  return (
    <Layout>
      {error && (
        <Alert variant="destructive">
          <AlertTitle>Error</AlertTitle>
          <AlertDescription>{error}</AlertDescription>
        </Alert>
      )}
      {interruptMessage && (
        <Alert>
          <AlertTitle>Notice</AlertTitle>
          <AlertDescription>{interruptMessage}</AlertDescription>
        </Alert>
      )}
      <div className="mb-5 relative">
        <div className="flex flex-row items-center justify-between w-full">
          <div className="relative">
            <input
              id="imageInput"
              type="file"
              accept="image/*"
              onChange={handleFileChange}
              className="hidden"
              disabled={!isMediaPipeReady}
            />
            <label
              htmlFor="imageInput"
              className={`cursor-pointer inline-flex items-center px-3 py-1.5 border border-transparent text-sm font-medium rounded-md text-white ${
                isMediaPipeReady ? 'bg-gray-600 hover:bg-gray-500' : 'bg-gray-500 cursor-not-allowed'
              } focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-gray-500 shadow-xl`}
            >
              <Spinner />
              {imageFile ? truncateFileName(imageFile.name, 32) : (isMediaPipeReady ? 'Choose an image' : 'Initializing...')}
            </label>
          </div>
          {previewImage && <label className="mt-4 flex items-center">
            <input
              type="checkbox"
              checked={isDebugMode}
              onChange={(e) => setIsDebugMode(e.target.checked)}
              className="mr-2"
            />
            Show face landmarks on hover
          </label>}
        </div>
        {previewImage && (
          <div className="mt-5 relative shadow-2xl rounded-xl overflow-hidden">
            <img
              src={previewImage}
              alt="Preview"
              className="w-full"
            />
            <canvas
              ref={canvasRefCallback}
              className="absolute top-0 left-0 w-full h-full select-none"
              onMouseEnter={handleMouseEnter}
              onMouseLeave={handleMouseLeave}
              onMouseDown={handleMouseDown}
              onMouseUp={handleMouseUp}
              onMouseMove={handleMouseMove}
              style={{
                position: 'absolute',
                top: 0,
                left: 0,
                width: '100%',
                height: '100%',
                opacity: isDebugMode ? currentOpacity : 0.0,
                transition: 'opacity 0.2s ease-in-out'
              }}

            />
          </div>
        )}
        {canDisplayBlendShapes && displayBlendShapes}
      </div>
      <PoweredBy />

    </Layout>
  );
}
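`handleFileChange` above awaits `convertImageToBase64(files[0])` and passes the result straight to `setPreviewImage`/`setOriginalImage`, so the helper evidently maps a `File` to a data-URI string. `client/src/lib/convertImageToBase64.ts` appears in the file list but its body is not shown in this view; the snippet below is only a plausible stand-in with the same signature, not the actual implementation.

```typescript
// Hypothetical stand-in for client/src/lib/convertImageToBase64.ts:
// resolve a File into a base64 data URI ("data:image/...;base64,...")
// that can be assigned to an <img> src, as app.tsx does with the result.
export async function convertImageToBase64(file: File): Promise<string> {
  return new Promise<string>((resolve, reject) => {
    const reader = new FileReader();
    reader.onload = () => resolve(reader.result as string);
    reader.onerror = () => reject(reader.error);
    reader.readAsDataURL(file);
  });
}
```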
client/src/components/DoubleCard.tsx
ADDED
@@ -0,0 +1,18 @@
import React, { type ReactNode } from 'react';

export function DoubleCard({ children }: { children: ReactNode }) {
  return (
    <>
      <div className="absolute inset-0 bg-gradient-to-r from-cyan-200 to-sky-300 shadow-2xl transform -skew-y-6 sm:skew-y-0 sm:-rotate-6 sm:rounded-3xl" style={{ borderTop: "solid 2px rgba(255, 255, 255, 0.2)" }}></div>
      <div className="relative px-5 py-8 bg-gradient-to-r from-cyan-100 to-sky-200 shadow-2xl sm:rounded-3xl sm:p-12" style={{ borderTop: "solid 2px #ffffff33" }}>
        <div className="max-w-lg mx-auto">
          <div className="divide-y divide-gray-200">
            <div className="text-lg leading-7 space-y-5 text-gray-700 sm:text-xl sm:leading-8">
              {children}
            </div>
          </div>
        </div>
      </div>
    </>
  );
}
client/src/components/PoweredBy.tsx
ADDED
@@ -0,0 +1,17 @@
export function PoweredBy() {
  return (
    <div className="flex flex-row items-center justify-center font-sans mt-4 w-full">
      {/*<span className="text-neutral-900 text-sm"
        style={{ textShadow: "rgb(255 255 255 / 80%) 0px 0px 2px" }}>
        Best hosted on
      </span>*/}
      <span className="ml-2 mr-1">
        <img src="/hf-logo.svg" alt="Hugging Face" className="w-5 h-5" />
      </span>
      <span className="text-neutral-900 text-sm font-semibold"
        style={{ textShadow: "rgb(255 255 255 / 80%) 0px 0px 2px" }}>
        Hugging Face
      </span>
    </div>
  )
}
client/src/components/Spinner.tsx
ADDED
@@ -0,0 +1,7 @@
export function Spinner() {
  return (
    <svg className="mr-3 h-6 w-6" fill="none" viewBox="0 0 24 24" stroke="currentColor">
      <path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M4 16l4.586-4.586a2 2 0 012.828 0L16 16m-2-2l1.586-1.586a2 2 0 012.828 0L20 14m-6-6h.01M6 20h12a2 2 0 002-2V6a2 2 0 00-2-2H6a2 2 0 00-2 2v12a2 2 0 002 2z" />
    </svg>
  )
}
client/src/components/Title.tsx
ADDED
@@ -0,0 +1,8 @@
export function Title() {
  return (
    <h2 className="bg-gradient-to-bl from-sky-500 to-sky-800 bg-clip-text text-5xl font-extrabold text-transparent leading-normal text-center"
      style={{ textShadow: "rgb(176 229 255 / 61%) 0px 0px 2px" }}>
      💬 FacePoke
    </h2>
  )
}
client/src/components/ui/alert.tsx
ADDED
@@ -0,0 +1,59 @@
import * as React from "react"
import { cva, type VariantProps } from "class-variance-authority"

import { cn } from "@/lib/utils"

const alertVariants = cva(
  "relative w-full rounded-lg border p-4 [&>svg~*]:pl-7 [&>svg+div]:translate-y-[-3px] [&>svg]:absolute [&>svg]:left-4 [&>svg]:top-4 [&>svg]:text-foreground",
  {
    variants: {
      variant: {
        default: "bg-background text-foreground",
        destructive:
          "border-destructive/50 text-destructive dark:border-destructive [&>svg]:text-destructive",
      },
    },
    defaultVariants: {
      variant: "default",
    },
  }
)

const Alert = React.forwardRef<
  HTMLDivElement,
  React.HTMLAttributes<HTMLDivElement> & VariantProps<typeof alertVariants>
>(({ className, variant, ...props }, ref) => (
  <div
    ref={ref}
    role="alert"
    className={cn(alertVariants({ variant }), className)}
    {...props}
  />
))
Alert.displayName = "Alert"

const AlertTitle = React.forwardRef<
  HTMLParagraphElement,
  React.HTMLAttributes<HTMLHeadingElement>
>(({ className, ...props }, ref) => (
  <h5
    ref={ref}
    className={cn("mb-1 font-medium leading-none tracking-tight", className)}
    {...props}
  />
))
AlertTitle.displayName = "AlertTitle"

const AlertDescription = React.forwardRef<
  HTMLParagraphElement,
  React.HTMLAttributes<HTMLParagraphElement>
>(({ className, ...props }, ref) => (
  <div
    ref={ref}
    className={cn("text-sm [&_p]:leading-relaxed", className)}
    {...props}
  />
))
AlertDescription.displayName = "AlertDescription"

export { Alert, AlertTitle, AlertDescription }
client/src/hooks/landmarks.ts
ADDED
@@ -0,0 +1,520 @@
import * as vision from '@mediapipe/tasks-vision';

// Define unique colors for each landmark group
export const landmarkColors: { [key: string]: string } = {
  lips: '#FF0000',
  leftEye: '#00FF00',
  leftEyebrow: '#0000FF',
  leftIris: '#FFFF00',
  rightEye: '#FF00FF',
  rightEyebrow: '#00FFFF',
  rightIris: '#FFA500',
  faceOval: '#800080',
  tesselation: '#C0C0C0',
};

// Define landmark groups with their semantic names
export const landmarkGroups: { [key: string]: any } = {
  lips: vision.FaceLandmarker.FACE_LANDMARKS_LIPS,
  leftEye: vision.FaceLandmarker.FACE_LANDMARKS_LEFT_EYE,
  leftEyebrow: vision.FaceLandmarker.FACE_LANDMARKS_LEFT_EYEBROW,
  leftIris: vision.FaceLandmarker.FACE_LANDMARKS_LEFT_IRIS,
  rightEye: vision.FaceLandmarker.FACE_LANDMARKS_RIGHT_EYE,
  rightEyebrow: vision.FaceLandmarker.FACE_LANDMARKS_RIGHT_EYEBROW,
  rightIris: vision.FaceLandmarker.FACE_LANDMARKS_RIGHT_IRIS,
  faceOval: vision.FaceLandmarker.FACE_LANDMARKS_FACE_OVAL,
  // tesselation: vision.FaceLandmarker.FACE_LANDMARKS_TESSELATION,
};

export const FACEMESH_LIPS = Object.freeze(new Set([[61, 146], [146, 91], [91, 181], [181, 84], [84, 17],
  [17, 314], [314, 405], [405, 321], [321, 375],
  [375, 291], [61, 185], [185, 40], [40, 39], [39, 37],
  [37, 0], [0, 267],
  [267, 269], [269, 270], [270, 409], [409, 291],
  [78, 95], [95, 88], [88, 178], [178, 87], [87, 14],
  [14, 317], [317, 402], [402, 318], [318, 324],
  [324, 308], [78, 191], [191, 80], [80, 81], [81, 82],
  [82, 13], [13, 312], [312, 311], [311, 310],
  [310, 415], [415, 308]]))

export const FACEMESH_LEFT_EYE = Object.freeze(new Set([[263, 249], [249, 390], [390, 373], [373, 374],
  [374, 380], [380, 381], [381, 382], [382, 362],
  [263, 466], [466, 388], [388, 387], [387, 386],
  [386, 385], [385, 384], [384, 398], [398, 362]]))

export const FACEMESH_LEFT_IRIS = Object.freeze(new Set([[474, 475], [475, 476], [476, 477],
  [477, 474]]))

export const FACEMESH_LEFT_EYEBROW = Object.freeze(new Set([[276, 283], [283, 282], [282, 295],
  [295, 285], [300, 293], [293, 334],
  [334, 296], [296, 336]]))

export const FACEMESH_RIGHT_EYE = Object.freeze(new Set([[33, 7], [7, 163], [163, 144], [144, 145],
  [145, 153], [153, 154], [154, 155], [155, 133],
  [33, 246], [246, 161], [161, 160], [160, 159],
  [159, 158], [158, 157], [157, 173], [173, 133]]))

export const FACEMESH_RIGHT_EYEBROW = Object.freeze(new Set([[46, 53], [53, 52], [52, 65], [65, 55],
  [70, 63], [63, 105], [105, 66], [66, 107]]))

export const FACEMESH_RIGHT_IRIS = Object.freeze(new Set([[469, 470], [470, 471], [471, 472],
  [472, 469]]))

export const FACEMESH_FACE_OVAL = Object.freeze(new Set([[10, 338], [338, 297], [297, 332], [332, 284],
  [284, 251], [251, 389], [389, 356], [356, 454],
  [454, 323], [323, 361], [361, 288], [288, 397],
  [397, 365], [365, 379], [379, 378], [378, 400],
  [400, 377], [377, 152], [152, 148], [148, 176],
  [176, 149], [149, 150], [150, 136], [136, 172],
  [172, 58], [58, 132], [132, 93], [93, 234],
  [234, 127], [127, 162], [162, 21], [21, 54],
  [54, 103], [103, 67], [67, 109], [109, 10]]))

export const FACEMESH_NOSE = Object.freeze(new Set([[168, 6], [6, 197], [197, 195], [195, 5],
  [5, 4], [4, 1], [1, 19], [19, 94], [94, 2], [98, 97],
  [97, 2], [2, 326], [326, 327], [327, 294],
  [294, 278], [278, 344], [344, 440], [440, 275],
  [275, 4], [4, 45], [45, 220], [220, 115], [115, 48],
  [48, 64], [64, 98]]))

export const FACEMESH_CONTOURS = Object.freeze(new Set([
  ...FACEMESH_LIPS,
  ...FACEMESH_LEFT_EYE,
  ...FACEMESH_LEFT_EYEBROW,
  ...FACEMESH_RIGHT_EYE,
  ...FACEMESH_RIGHT_EYEBROW,
  ...FACEMESH_FACE_OVAL
]));

export const FACEMESH_IRISES = Object.freeze(new Set([
  ...FACEMESH_LEFT_IRIS,
  ...FACEMESH_RIGHT_IRIS
]));

export const FACEMESH_TESSELATION = Object.freeze(new Set([
  [127, 34], [34, 139], [139, 127], [11, 0], [0, 37], [37, 11],
  [232, 231], [231, 120], [120, 232], [72, 37], [37, 39], [39, 72],
  [128, 121], [121, 47], [47, 128], [232, 121], [121, 128], [128, 232],
  [104, 69], [69, 67], [67, 104], [175, 171], [171, 148], [148, 175],
  [118, 50], [50, 101], [101, 118], [73, 39], [39, 40], [40, 73],
  [9, 151], [151, 108], [108, 9], [48, 115], [115, 131], [131, 48],
  [194, 204], [204, 211], [211, 194], [74, 40], [40, 185], [185, 74],
  [80, 42], [42, 183], [183, 80], [40, 92], [92, 186], [186, 40],
  [230, 229], [229, 118], [118, 230], [202, 212], [212, 214], [214, 202],
  [83, 18], [18, 17], [17, 83], [76, 61], [61, 146], [146, 76],
  [160, 29], [29, 30], [30, 160], [56, 157], [157, 173], [173, 56],
  [106, 204], [204, 194], [194, 106], [135, 214], [214, 192], [192, 135],
  [203, 165], [165, 98], [98, 203], [21, 71], [71, 68], [68, 21],
  [51, 45], [45, 4], [4, 51], [144, 24], [24, 23], [23, 144],
  [77, 146], [146, 91], [91, 77], [205, 50], [50, 187], [187, 205],
  [201, 200], [200, 18], [18, 201], [91, 106], [106, 182], [182, 91],
  [90, 91], [91, 181], [181, 90], [85, 84], [84, 17], [17, 85],
  [206, 203], [203, 36], [36, 206], [148, 171], [171, 140], [140, 148],
  [92, 40], [40, 39], [39, 92], [193, 189], [189, 244], [244, 193],
  [159, 158], [158, 28], [28, 159], [247, 246], [246, 161], [161, 247],
  [236, 3], [3, 196], [196, 236], [54, 68], [68, 104], [104, 54],
  [193, 168], [168, 8], [8, 193], [117, 228], [228, 31], [31, 117],
  [189, 193], [193, 55], [55, 189], [98, 97], [97, 99], [99, 98],
  [126, 47], [47, 100], [100, 126], [166, 79], [79, 218], [218, 166],
  [155, 154], [154, 26], [26, 155], [209, 49], [49, 131], [131, 209],
  [135, 136], [136, 150], [150, 135], [47, 126], [126, 217], [217, 47],
  [223, 52], [52, 53], [53, 223], [45, 51], [51, 134], [134, 45],
  [211, 170], [170, 140], [140, 211], [67, 69], [69, 108], [108, 67],
  [43, 106], [106, 91], [91, 43], [230, 119], [119, 120], [120, 230],
  [226, 130], [130, 247], [247, 226], [63, 53], [53, 52], [52, 63],
  [238, 20], [20, 242], [242, 238], [46, 70], [70, 156], [156, 46],
  [78, 62], [62, 96], [96, 78], [46, 53], [53, 63], [63, 46],
  [143, 34], [34, 227], [227, 143], [123, 117], [117, 111], [111, 123],
  [44, 125], [125, 19], [19, 44], [236, 134], [134, 51], [51, 236],
  [216, 206], [206, 205], [205, 216], [154, 153], [153, 22], [22, 154],
  [39, 37], [37, 167], [167, 39], [200, 201], [201, 208], [208, 200],
  [36, 142], [142, 100], [100, 36], [57, 212], [212, 202], [202, 57],
  [20, 60], [60, 99], [99, 20], [28, 158], [158, 157], [157, 28],
  [35, 226], [226, 113], [113, 35], [160, 159], [159, 27], [27, 160],
  [204, 202], [202, 210], [210, 204], [113, 225], [225, 46], [46, 113],
  [43, 202], [202, 204], [204, 43], [62, 76], [76, 77], [77, 62],
  [137, 123], [123, 116], [116, 137], [41, 38], [38, 72], [72, 41],
  [203, 129], [129, 142], [142, 203], [64, 98], [98, 240], [240, 64],
  [49, 102], [102, 64], [64, 49], [41, 73], [73, 74], [74, 41],
  [212, 216], [216, 207], [207, 212], [42, 74], [74, 184], [184, 42],
  [169, 170], [170, 211], [211, 169], [170, 149], [149, 176], [176, 170],
  [105, 66], [66, 69], [69, 105], [122, 6], [6, 168], [168, 122],
  [123, 147], [147, 187], [187, 123], [96, 77], [77, 90], [90, 96],
  [65, 55], [55, 107], [107, 65], [89, 90], [90, 180], [180, 89],
  [101, 100], [100, 120], [120, 101], [63, 105], [105, 104], [104, 63],
  [93, 137], [137, 227], [227, 93], [15, 86], [86, 85], [85, 15],
  [129, 102], [102, 49], [49, 129], [14, 87], [87, 86], [86, 14],
  [55, 8], [8, 9], [9, 55], [100, 47], [47, 121], [121, 100],
  [145, 23], [23, 22], [22, 145], [88, 89], [89, 179], [179, 88],
  [6, 122], [122, 196], [196, 6], [88, 95], [95, 96], [96, 88],
  [138, 172], [172, 136], [136, 138], [215, 58], [58, 172], [172, 215],
  [115, 48], [48, 219], [219, 115], [42, 80], [80, 81], [81, 42],
  [195, 3], [3, 51], [51, 195], [43, 146], [146, 61], [61, 43],
  [171, 175], [175, 199], [199, 171], [81, 82], [82, 38], [38, 81],
  [53, 46], [46, 225], [225, 53], [144, 163], [163, 110], [110, 144],
  [52, 65], [65, 66], [66, 52], [229, 228], [228, 117], [117, 229],
  [34, 127], [127, 234], [234, 34], [107, 108], [108, 69], [69, 107],
  [109, 108], [108, 151], [151, 109], [48, 64], [64, 235], [235, 48],
  [62, 78], [78, 191], [191, 62], [129, 209], [209, 126], [126, 129],
  [111, 35], [35, 143], [143, 111], [117, 123], [123, 50], [50, 117],
  [222, 65], [65, 52], [52, 222], [19, 125], [125, 141], [141, 19],
  [221, 55], [55, 65], [65, 221], [3, 195], [195, 197], [197, 3],
  [25, 7], [7, 33], [33, 25], [220, 237], [237, 44], [44, 220],
  [70, 71], [71, 139], [139, 70], [122, 193], [193, 245], [245, 122],
  [247, 130], [130, 33], [33, 247], [71, 21], [21, 162], [162, 71],
  [170, 169], [169, 150], [150, 170], [188, 174], [174, 196], [196, 188],
  [216, 186], [186, 92], [92, 216], [2, 97], [97, 167], [167, 2],
  [141, 125], [125, 241], [241, 141], [164, 167], [167, 37], [37, 164],
  [72, 38], [38, 12], [12, 72], [38, 82], [82, 13], [13, 38],
  [63, 68], [68, 71], [71, 63], [226, 35], [35, 111], [111, 226],
  [101, 50], [50, 205], [205, 101], [206, 92], [92, 165], [165, 206],
  [209, 198], [198, 217], [217, 209], [165, 167], [167, 97], [97, 165],
  [220, 115], [115, 218], [218, 220], [133, 112], [112, 243], [243, 133],
  [239, 238], [238, 241], [241, 239], [214, 135], [135, 169], [169, 214],
  [190, 173], [173, 133], [133, 190], [171, 208], [208, 32], [32, 171],
  [125, 44], [44, 237], [237, 125], [86, 87], [87, 178], [178, 86],
  [85, 86], [86, 179], [179, 85], [84, 85], [85, 180], [180, 84],
  [83, 84], [84, 181], [181, 83], [201, 83], [83, 182], [182, 201],
  [137, 93], [93, 132], [132, 137], [76, 62], [62, 183], [183, 76],
  [61, 76], [76, 184], [184, 61], [57, 61], [61, 185], [185, 57],
  [212, 57], [57, 186], [186, 212], [214, 207], [207, 187], [187, 214],
  [34, 143], [143, 156], [156, 34], [79, 239], [239, 237], [237, 79],
  [123, 137], [137, 177], [177, 123], [44, 1], [1, 4], [4, 44],
  [201, 194], [194, 32], [32, 201], [64, 102], [102, 129], [129, 64],
  [213, 215], [215, 138], [138, 213], [59, 166], [166, 219], [219, 59],
  [242, 99], [99, 97], [97, 242], [2, 94], [94, 141], [141, 2],
  [75, 59], [59, 235], [235, 75], [24, 110], [110, 228], [228, 24],
  [25, 130], [130, 226], [226, 25], [23, 24], [24, 229], [229, 23],
  [22, 23], [23, 230], [230, 22], [26, 22], [22, 231], [231, 26],
  [112, 26], [26, 232], [232, 112], [189, 190], [190, 243], [243, 189],
  [221, 56], [56, 190], [190, 221], [28, 56], [56, 221], [221, 28],
  [27, 28], [28, 222], [222, 27], [29, 27], [27, 223], [223, 29],
  [30, 29], [29, 224], [224, 30], [247, 30], [30, 225], [225, 247],
  [238, 79], [79, 20], [20, 238], [166, 59], [59, 75], [75, 166],
  [60, 75], [75, 240], [240, 60], [147, 177], [177, 215], [215, 147],
  [20, 79], [79, 166], [166, 20], [187, 147], [147, 213], [213, 187],
  [112, 233], [233, 244], [244, 112], [233, 128], [128, 245], [245, 233],
  [128, 114], [114, 188], [188, 128], [114, 217], [217, 174], [174, 114],
  [131, 115], [115, 220], [220, 131], [217, 198], [198, 236], [236, 217],
  [198, 131], [131, 134], [134, 198], [177, 132], [132, 58], [58, 177],
  [143, 35], [35, 124], [124, 143], [110, 163], [163, 7], [7, 110],
  [228, 110], [110, 25], [25, 228], [356, 389], [389, 368], [368, 356],
  [11, 302], [302, 267], [267, 11], [452, 350], [350, 349], [349, 452],
  [302, 303], [303, 269], [269, 302], [357, 343], [343, 277], [277, 357],
  [452, 453], [453, 357], [357, 452], [333, 332], [332, 297], [297, 333],
  [175, 152], [152, 377], [377, 175], [347, 348], [348, 330], [330, 347],
  [303, 304], [304, 270], [270, 303], [9, 336], [336, 337], [337, 9],
  [278, 279], [279, 360], [360, 278], [418, 262], [262, 431], [431, 418],
  [304, 408], [408, 409], [409, 304], [310, 415], [415, 407], [407, 310],
  [270, 409], [409, 410], [410, 270], [450, 348], [348, 347], [347, 450],
  [422, 430], [430, 434], [434, 422], [313, 314], [314, 17], [17, 313],
  [306, 307], [307, 375], [375, 306], [387, 388], [388, 260], [260, 387],
  [286, 414], [414, 398], [398, 286], [335, 406], [406, 418], [418, 335],
  [364, 367], [367, 416], [416, 364], [423, 358], [358, 327], [327, 423],
  [251, 284], [284, 298], [298, 251], [281, 5], [5, 4], [4, 281],
  [373, 374], [374, 253], [253, 373], [307, 320], [320, 321], [321, 307],
  [425, 427], [427, 411], [411, 425], [421, 313], [313, 18], [18, 421],
  [321, 405], [405, 406], [406, 321], [320, 404], [404, 405], [405, 320],
  [315, 16], [16, 17], [17, 315], [426, 425], [425, 266], [266, 426],
  [377, 400], [400, 369], [369, 377], [322, 391], [391, 269], [269, 322],
  [417, 465], [465, 464], [464, 417], [386, 257], [257, 258], [258, 386],
  [466, 260], [260, 388], [388, 466], [456, 399], [399, 419], [419, 456],
+
[284, 332], [332, 333], [333, 284], [417, 285], [285, 8], [8, 417],
|
223 |
+
[346, 340], [340, 261], [261, 346], [413, 441], [441, 285], [285, 413],
|
224 |
+
[327, 460], [460, 328], [328, 327], [355, 371], [371, 329], [329, 355],
|
225 |
+
[392, 439], [439, 438], [438, 392], [382, 341], [341, 256], [256, 382],
|
226 |
+
[429, 420], [420, 360], [360, 429], [364, 394], [394, 379], [379, 364],
|
227 |
+
[277, 343], [343, 437], [437, 277], [443, 444], [444, 283], [283, 443],
|
228 |
+
[275, 440], [440, 363], [363, 275], [431, 262], [262, 369], [369, 431],
|
229 |
+
[297, 338], [338, 337], [337, 297], [273, 375], [375, 321], [321, 273],
|
230 |
+
[450, 451], [451, 349], [349, 450], [446, 342], [342, 467], [467, 446],
|
231 |
+
[293, 334], [334, 282], [282, 293], [458, 461], [461, 462], [462, 458],
|
232 |
+
[276, 353], [353, 383], [383, 276], [308, 324], [324, 325], [325, 308],
|
233 |
+
[276, 300], [300, 293], [293, 276], [372, 345], [345, 447], [447, 372],
|
234 |
+
[352, 345], [345, 340], [340, 352], [274, 1], [1, 19], [19, 274],
|
235 |
+
[456, 248], [248, 281], [281, 456], [436, 427], [427, 425], [425, 436],
|
236 |
+
[381, 256], [256, 252], [252, 381], [269, 391], [391, 393], [393, 269],
|
237 |
+
[200, 199], [199, 428], [428, 200], [266, 330], [330, 329], [329, 266],
|
238 |
+
[287, 273], [273, 422], [422, 287], [250, 462], [462, 328], [328, 250],
|
239 |
+
[258, 286], [286, 384], [384, 258], [265, 353], [353, 342], [342, 265],
|
240 |
+
[387, 259], [259, 257], [257, 387], [424, 431], [431, 430], [430, 424],
|
241 |
+
[342, 353], [353, 276], [276, 342], [273, 335], [335, 424], [424, 273],
|
242 |
+
[292, 325], [325, 307], [307, 292], [366, 447], [447, 345], [345, 366],
|
243 |
+
[271, 303], [303, 302], [302, 271], [423, 266], [266, 371], [371, 423],
|
244 |
+
[294, 455], [455, 460], [460, 294], [279, 278], [278, 294], [294, 279],
|
245 |
+
[271, 272], [272, 304], [304, 271], [432, 434], [434, 427], [427, 432],
|
246 |
+
[272, 407], [407, 408], [408, 272], [394, 430], [430, 431], [431, 394],
|
247 |
+
[395, 369], [369, 400], [400, 395], [334, 333], [333, 299], [299, 334],
|
248 |
+
[351, 417], [417, 168], [168, 351], [352, 280], [280, 411], [411, 352],
|
249 |
+
[325, 319], [319, 320], [320, 325], [295, 296], [296, 336], [336, 295],
|
250 |
+
[319, 403], [403, 404], [404, 319], [330, 348], [348, 349], [349, 330],
|
251 |
+
[293, 298], [298, 333], [333, 293], [323, 454], [454, 447], [447, 323],
|
252 |
+
[15, 16], [16, 315], [315, 15], [358, 429], [429, 279], [279, 358],
|
253 |
+
[14, 15], [15, 316], [316, 14], [285, 336], [336, 9], [9, 285],
|
254 |
+
[329, 349], [349, 350], [350, 329], [374, 380], [380, 252], [252, 374],
|
255 |
+
[318, 402], [402, 403], [403, 318], [6, 197], [197, 419], [419, 6],
|
256 |
+
[318, 319], [319, 325], [325, 318], [367, 364], [364, 365], [365, 367],
|
257 |
+
[435, 367], [367, 397], [397, 435], [344, 438], [438, 439], [439, 344],
|
258 |
+
[272, 271], [271, 311], [311, 272], [195, 5], [5, 281], [281, 195],
|
259 |
+
[273, 287], [287, 291], [291, 273], [396, 428], [428, 199], [199, 396],
|
260 |
+
[311, 271], [271, 268], [268, 311], [283, 444], [444, 445], [445, 283],
|
261 |
+
[373, 254], [254, 339], [339, 373], [282, 334], [334, 296], [296, 282],
|
262 |
+
[449, 347], [347, 346], [346, 449], [264, 447], [447, 454], [454, 264],
|
263 |
+
[336, 296], [296, 299], [299, 336], [338, 10], [10, 151], [151, 338],
|
264 |
+
[278, 439], [439, 455], [455, 278], [292, 407], [407, 415], [415, 292],
|
265 |
+
[358, 371], [371, 355], [355, 358], [340, 345], [345, 372], [372, 340],
|
266 |
+
[346, 347], [347, 280], [280, 346], [442, 443], [443, 282], [282, 442],
|
267 |
+
[19, 94], [94, 370], [370, 19], [441, 442], [442, 295], [295, 441],
|
268 |
+
[248, 419], [419, 197], [197, 248], [263, 255], [255, 359], [359, 263],
|
269 |
+
[440, 275], [275, 274], [274, 440], [300, 383], [383, 368], [368, 300],
|
270 |
+
[351, 412], [412, 465], [465, 351], [263, 467], [467, 466], [466, 263],
|
271 |
+
[301, 368], [368, 389], [389, 301], [395, 378], [378, 379], [379, 395],
|
272 |
+
[412, 351], [351, 419], [419, 412], [436, 426], [426, 322], [322, 436],
|
273 |
+
[2, 164], [164, 393], [393, 2], [370, 462], [462, 461], [461, 370],
|
274 |
+
[164, 0], [0, 267], [267, 164], [302, 11], [11, 12], [12, 302],
|
275 |
+
[268, 12], [12, 13], [13, 268], [293, 300], [300, 301], [301, 293],
|
276 |
+
[446, 261], [261, 340], [340, 446], [330, 266], [266, 425], [425, 330],
|
277 |
+
[426, 423], [423, 391], [391, 426], [429, 355], [355, 437], [437, 429],
|
278 |
+
[391, 327], [327, 326], [326, 391], [440, 457], [457, 438], [438, 440],
|
279 |
+
[341, 382], [382, 362], [362, 341], [459, 457], [457, 461], [461, 459],
|
280 |
+
[434, 430], [430, 394], [394, 434], [414, 463], [463, 362], [362, 414],
|
281 |
+
[396, 369], [369, 262], [262, 396], [354, 461], [461, 457], [457, 354],
|
282 |
+
[316, 403], [403, 402], [402, 316], [315, 404], [404, 403], [403, 315],
|
283 |
+
[314, 405], [405, 404], [404, 314], [313, 406], [406, 405], [405, 313],
|
284 |
+
[421, 418], [418, 406], [406, 421], [366, 401], [401, 361], [361, 366],
|
285 |
+
[306, 408], [408, 407], [407, 306], [291, 409], [409, 408], [408, 291],
|
286 |
+
[287, 410], [410, 409], [409, 287], [432, 436], [436, 410], [410, 432],
|
287 |
+
[434, 416], [416, 411], [411, 434], [264, 368], [368, 383], [383, 264],
|
288 |
+
[309, 438], [438, 457], [457, 309], [352, 376], [376, 401], [401, 352],
|
289 |
+
[274, 275], [275, 4], [4, 274], [421, 428], [428, 262], [262, 421],
|
290 |
+
[294, 327], [327, 358], [358, 294], [433, 416], [416, 367], [367, 433],
|
291 |
+
[289, 455], [455, 439], [439, 289], [462, 370], [370, 326], [326, 462],
|
292 |
+
[2, 326], [326, 370], [370, 2], [305, 460], [460, 455], [455, 305],
|
293 |
+
[254, 449], [449, 448], [448, 254], [255, 261], [261, 446], [446, 255],
|
294 |
+
[253, 450], [450, 449], [449, 253], [252, 451], [451, 450], [450, 252],
|
295 |
+
[256, 452], [452, 451], [451, 256], [341, 453], [453, 452], [452, 341],
|
296 |
+
[413, 464], [464, 463], [463, 413], [441, 413], [413, 414], [414, 441],
|
297 |
+
[258, 442], [442, 441], [441, 258], [257, 443], [443, 442], [442, 257],
|
298 |
+
[259, 444], [444, 443], [443, 259], [260, 445], [445, 444], [444, 260],
|
299 |
+
[467, 342], [342, 445], [445, 467], [459, 458], [458, 250], [250, 459],
|
300 |
+
[289, 392], [392, 290], [290, 289], [290, 328], [328, 460], [460, 290],
|
301 |
+
[376, 433], [433, 435], [435, 376], [250, 290], [290, 392], [392, 250],
|
302 |
+
[411, 416], [416, 433], [433, 411], [341, 463], [463, 464], [464, 341],
|
303 |
+
[453, 464], [464, 465], [465, 453], [357, 465], [465, 412], [412, 357],
|
304 |
+
[343, 412], [412, 399], [399, 343], [360, 363], [363, 440], [440, 360],
|
305 |
+
[437, 399], [399, 456], [456, 437], [420, 456], [456, 363], [363, 420],
|
306 |
+
[401, 435], [435, 288], [288, 401], [372, 383], [383, 353], [353, 372],
|
307 |
+
[339, 255], [255, 249], [249, 339], [448, 261], [261, 255], [255, 448],
|
308 |
+
[133, 243], [243, 190], [190, 133], [133, 155], [155, 112], [112, 133],
|
309 |
+
[33, 246], [246, 247], [247, 33], [33, 130], [130, 25], [25, 33],
|
310 |
+
[398, 384], [384, 286], [286, 398], [362, 398], [398, 414], [414, 362],
|
311 |
+
[362, 463], [463, 341], [341, 362], [263, 359], [359, 467], [467, 263],
|
312 |
+
[263, 249], [249, 255], [255, 263], [466, 467], [467, 260], [260, 466],
|
313 |
+
[75, 60], [60, 166], [166, 75], [238, 239], [239, 79], [79, 238],
|
314 |
+
[162, 127], [127, 139], [139, 162], [72, 11], [11, 37], [37, 72],
|
315 |
+
[121, 232], [232, 120], [120, 121], [73, 72], [72, 39], [39, 73],
|
316 |
+
[114, 128], [128, 47], [47, 114], [233, 232], [232, 128], [128, 233],
|
317 |
+
[103, 104], [104, 67], [67, 103], [152, 175], [175, 148], [148, 152],
|
318 |
+
[119, 118], [118, 101], [101, 119], [74, 73], [73, 40], [40, 74],
|
319 |
+
[107, 9], [9, 108], [108, 107], [49, 48], [48, 131], [131, 49],
|
320 |
+
[32, 194], [194, 211], [211, 32], [184, 74], [74, 185], [185, 184],
|
321 |
+
[191, 80], [80, 183], [183, 191], [185, 40], [40, 186], [186, 185],
|
322 |
+
[119, 230], [230, 118], [118, 119], [210, 202], [202, 214], [214, 210],
|
323 |
+
[84, 83], [83, 17], [17, 84], [77, 76], [76, 146], [146, 77],
|
324 |
+
[161, 160], [160, 30], [30, 161], [190, 56], [56, 173], [173, 190],
|
325 |
+
[182, 106], [106, 194], [194, 182], [138, 135], [135, 192], [192, 138],
|
326 |
+
[129, 203], [203, 98], [98, 129], [54, 21], [21, 68], [68, 54],
|
327 |
+
[5, 51], [51, 4], [4, 5], [145, 144], [144, 23], [23, 145],
|
328 |
+
[90, 77], [77, 91], [91, 90], [207, 205], [205, 187], [187, 207],
|
329 |
+
[83, 201], [201, 18], [18, 83], [181, 91], [91, 182], [182, 181],
|
330 |
+
[180, 90], [90, 181], [181, 180], [16, 85], [85, 17], [17, 16],
|
331 |
+
[205, 206], [206, 36], [36, 205], [176, 148], [148, 140], [140, 176],
|
332 |
+
[165, 92], [92, 39], [39, 165], [245, 193], [193, 244], [244, 245],
|
333 |
+
[27, 159], [159, 28], [28, 27], [30, 247], [247, 161], [161, 30],
|
334 |
+
[174, 236], [236, 196], [196, 174], [103, 54], [54, 104], [104, 103],
|
335 |
+
[55, 193], [193, 8], [8, 55], [111, 117], [117, 31], [31, 111],
|
336 |
+
[221, 189], [189, 55], [55, 221], [240, 98], [98, 99], [99, 240],
|
337 |
+
[142, 126], [126, 100], [100, 142], [219, 166], [166, 218], [218, 219],
|
338 |
+
[112, 155], [155, 26], [26, 112], [198, 209], [209, 131], [131, 198],
|
339 |
+
[169, 135], [135, 150], [150, 169], [114, 47], [47, 217], [217, 114],
|
340 |
+
[224, 223], [223, 53], [53, 224], [220, 45], [45, 134], [134, 220],
|
341 |
+
[32, 211], [211, 140], [140, 32], [109, 67], [67, 108], [108, 109],
|
342 |
+
[146, 43], [43, 91], [91, 146], [231, 230], [230, 120], [120, 231],
|
343 |
+
[113, 226], [226, 247], [247, 113], [105, 63], [63, 52], [52, 105],
|
344 |
+
[241, 238], [238, 242], [242, 241], [124, 46], [46, 156], [156, 124],
|
345 |
+
[95, 78], [78, 96], [96, 95], [70, 46], [46, 63], [63, 70],
|
346 |
+
[116, 143], [143, 227], [227, 116], [116, 123], [123, 111], [111, 116],
|
347 |
+
[1, 44], [44, 19], [19, 1], [3, 236], [236, 51], [51, 3],
|
348 |
+
[207, 216], [216, 205], [205, 207], [26, 154], [154, 22], [22, 26],
|
349 |
+
[165, 39], [39, 167], [167, 165], [199, 200], [200, 208], [208, 199],
|
350 |
+
[101, 36], [36, 100], [100, 101], [43, 57], [57, 202], [202, 43],
|
351 |
+
[242, 20], [20, 99], [99, 242], [56, 28], [28, 157], [157, 56],
|
352 |
+
[124, 35], [35, 113], [113, 124], [29, 160], [160, 27], [27, 29],
|
353 |
+
[211, 204], [204, 210], [210, 211], [124, 113], [113, 46], [46, 124],
|
354 |
+
[106, 43], [43, 204], [204, 106], [96, 62], [62, 77], [77, 96],
|
355 |
+
[227, 137], [137, 116], [116, 227], [73, 41], [41, 72], [72, 73],
|
356 |
+
[36, 203], [203, 142], [142, 36], [235, 64], [64, 240], [240, 235],
|
357 |
+
[48, 49], [49, 64], [64, 48], [42, 41], [41, 74], [74, 42],
|
358 |
+
[214, 212], [212, 207], [207, 214], [183, 42], [42, 184], [184, 183],
|
359 |
+
[210, 169], [169, 211], [211, 210], [140, 170], [170, 176], [176, 140],
|
360 |
+
[104, 105], [105, 69], [69, 104], [193, 122], [122, 168], [168, 193],
|
361 |
+
[50, 123], [123, 187], [187, 50], [89, 96], [96, 90], [90, 89],
|
362 |
+
[66, 65], [65, 107], [107, 66], [179, 89], [89, 180], [180, 179],
|
363 |
+
[119, 101], [101, 120], [120, 119], [68, 63], [63, 104], [104, 68],
|
364 |
+
[234, 93], [93, 227], [227, 234], [16, 15], [15, 85], [85, 16],
|
365 |
+
[209, 129], [129, 49], [49, 209], [15, 14], [14, 86], [86, 15],
|
366 |
+
[107, 55], [55, 9], [9, 107], [120, 100], [100, 121], [121, 120],
|
367 |
+
[153, 145], [145, 22], [22, 153], [178, 88], [88, 179], [179, 178],
|
368 |
+
[197, 6], [6, 196], [196, 197], [89, 88], [88, 96], [96, 89],
|
369 |
+
[135, 138], [138, 136], [136, 135], [138, 215], [215, 172], [172, 138],
|
370 |
+
[218, 115], [115, 219], [219, 218], [41, 42], [42, 81], [81, 41],
|
371 |
+
[5, 195], [195, 51], [51, 5], [57, 43], [43, 61], [61, 57],
|
372 |
+
[208, 171], [171, 199], [199, 208], [41, 81], [81, 38], [38, 41],
|
373 |
+
[224, 53], [53, 225], [225, 224], [24, 144], [144, 110], [110, 24],
|
374 |
+
[105, 52], [52, 66], [66, 105], [118, 229], [229, 117], [117, 118],
|
375 |
+
[227, 34], [34, 234], [234, 227], [66, 107], [107, 69], [69, 66],
|
376 |
+
[10, 109], [109, 151], [151, 10], [219, 48], [48, 235], [235, 219],
|
377 |
+
[183, 62], [62, 191], [191, 183], [142, 129], [129, 126], [126, 142],
|
378 |
+
[116, 111], [111, 143], [143, 116], [118, 117], [117, 50], [50, 118],
|
379 |
+
[223, 222], [222, 52], [52, 223], [94, 19], [19, 141], [141, 94],
|
380 |
+
[222, 221], [221, 65], [65, 222], [196, 3], [3, 197], [197, 196],
|
381 |
+
[45, 220], [220, 44], [44, 45], [156, 70], [70, 139], [139, 156],
|
382 |
+
[188, 122], [122, 245], [245, 188], [139, 71], [71, 162], [162, 139],
|
383 |
+
[149, 170], [170, 150], [150, 149], [122, 188], [188, 196], [196, 122],
|
384 |
+
[206, 216], [216, 92], [92, 206], [164, 2], [2, 167], [167, 164],
|
385 |
+
[242, 141], [141, 241], [241, 242], [0, 164], [164, 37], [37, 0],
|
386 |
+
[11, 72], [72, 12], [12, 11], [12, 38], [38, 13], [13, 12],
|
387 |
+
[70, 63], [63, 71], [71, 70], [31, 226], [226, 111], [111, 31],
|
388 |
+
[36, 101], [101, 205], [205, 36], [203, 206], [206, 165], [165, 203],
|
389 |
+
[126, 209], [209, 217], [217, 126], [98, 165], [165, 97], [97, 98],
|
390 |
+
[237, 220], [220, 218], [218, 237], [237, 239], [239, 241], [241, 237],
|
391 |
+
[210, 214], [214, 169], [169, 210], [140, 171], [171, 32], [32, 140],
|
392 |
+
[241, 125], [125, 237], [237, 241], [179, 86], [86, 178], [178, 179],
|
393 |
+
[180, 85], [85, 179], [179, 180], [181, 84], [84, 180], [180, 181],
|
394 |
+
[182, 83], [83, 181], [181, 182], [194, 201], [201, 182], [182, 194],
|
395 |
+
[177, 137], [137, 132], [132, 177], [184, 76], [76, 183], [183, 184],
|
396 |
+
[185, 61], [61, 184], [184, 185], [186, 57], [57, 185], [185, 186],
|
397 |
+
[216, 212], [212, 186], [186, 216], [192, 214], [214, 187], [187, 192],
|
398 |
+
[139, 34], [34, 156], [156, 139], [218, 79], [79, 237], [237, 218],
|
399 |
+
[147, 123], [123, 177], [177, 147], [45, 44], [44, 4], [4, 45],
|
400 |
+
[208, 201], [201, 32], [32, 208], [98, 64], [64, 129], [129, 98],
|
401 |
+
[192, 213], [213, 138], [138, 192], [235, 59], [59, 219], [219, 235],
|
402 |
+
[141, 242], [242, 97], [97, 141], [97, 2], [2, 141], [141, 97],
|
403 |
+
[240, 75], [75, 235], [235, 240], [229, 24], [24, 228], [228, 229],
|
404 |
+
[31, 25], [25, 226], [226, 31], [230, 23], [23, 229], [229, 230],
|
405 |
+
[231, 22], [22, 230], [230, 231], [232, 26], [26, 231], [231, 232],
|
406 |
+
[233, 112], [112, 232], [232, 233], [244, 189], [189, 243], [243, 244],
|
407 |
+
[189, 221], [221, 190], [190, 189], [222, 28], [28, 221], [221, 222],
|
408 |
+
[223, 27], [27, 222], [222, 223], [224, 29], [29, 223], [223, 224],
|
409 |
+
[225, 30], [30, 224], [224, 225], [113, 247], [247, 225], [225, 113],
|
410 |
+
[99, 60], [60, 240], [240, 99], [213, 147], [147, 215], [215, 213],
|
411 |
+
[60, 20], [20, 166], [166, 60], [192, 187], [187, 213], [213, 192],
|
412 |
+
[243, 112], [112, 244], [244, 243], [244, 233], [233, 245], [245, 244],
|
413 |
+
[245, 128], [128, 188], [188, 245], [188, 114], [114, 174], [174, 188],
|
414 |
+
[134, 131], [131, 220], [220, 134], [174, 217], [217, 236], [236, 174],
|
415 |
+
[236, 198], [198, 134], [134, 236], [215, 177], [177, 58], [58, 215],
|
416 |
+
[156, 143], [143, 124], [124, 156], [25, 110], [110, 7], [7, 25],
|
417 |
+
[31, 228], [228, 25], [25, 31], [264, 356], [356, 368], [368, 264],
|
418 |
+
[0, 11], [11, 267], [267, 0], [451, 452], [452, 349], [349, 451],
|
419 |
+
[267, 302], [302, 269], [269, 267], [350, 357], [357, 277], [277, 350],
|
420 |
+
[350, 452], [452, 357], [357, 350], [299, 333], [333, 297], [297, 299],
|
421 |
+
[396, 175], [175, 377], [377, 396], [280, 347], [347, 330], [330, 280],
|
422 |
+
[269, 303], [303, 270], [270, 269], [151, 9], [9, 337], [337, 151],
|
423 |
+
[344, 278], [278, 360], [360, 344], [424, 418], [418, 431], [431, 424],
|
424 |
+
[270, 304], [304, 409], [409, 270], [272, 310], [310, 407], [407, 272],
|
425 |
+
[322, 270], [270, 410], [410, 322], [449, 450], [450, 347], [347, 449],
|
426 |
+
[432, 422], [422, 434], [434, 432], [18, 313], [313, 17], [17, 18],
|
427 |
+
[291, 306], [306, 375], [375, 291], [259, 387], [387, 260], [260, 259],
|
428 |
+
[424, 335], [335, 418], [418, 424], [434, 364], [364, 416], [416, 434],
|
429 |
+
[391, 423], [423, 327], [327, 391], [301, 251], [251, 298], [298, 301],
|
430 |
+
[275, 281], [281, 4], [4, 275], [254, 373], [373, 253], [253, 254],
|
431 |
+
[375, 307], [307, 321], [321, 375], [280, 425], [425, 411], [411, 280],
|
432 |
+
[200, 421], [421, 18], [18, 200], [335, 321], [321, 406], [406, 335],
|
433 |
+
[321, 320], [320, 405], [405, 321], [314, 315], [315, 17], [17, 314],
|
434 |
+
[423, 426], [426, 266], [266, 423], [396, 377], [377, 369], [369, 396],
|
435 |
+
[270, 322], [322, 269], [269, 270], [413, 417], [417, 464], [464, 413],
|
436 |
+
[385, 386], [386, 258], [258, 385], [248, 456], [456, 419], [419, 248],
|
437 |
+
[298, 284], [284, 333], [333, 298], [168, 417], [417, 8], [8, 168],
|
438 |
+
[448, 346], [346, 261], [261, 448], [417, 413], [413, 285], [285, 417],
|
439 |
+
[326, 327], [327, 328], [328, 326], [277, 355], [355, 329], [329, 277],
|
440 |
+
[309, 392], [392, 438], [438, 309], [381, 382], [382, 256], [256, 381],
|
441 |
+
[279, 429], [429, 360], [360, 279], [365, 364], [364, 379], [379, 365],
|
442 |
+
[355, 277], [277, 437], [437, 355], [282, 443], [443, 283], [283, 282],
|
443 |
+
[281, 275], [275, 363], [363, 281], [395, 431], [431, 369], [369, 395],
|
444 |
+
[299, 297], [297, 337], [337, 299], [335, 273], [273, 321], [321, 335],
|
445 |
+
[348, 450], [450, 349], [349, 348], [359, 446], [446, 467], [467, 359],
|
446 |
+
[283, 293], [293, 282], [282, 283], [250, 458], [458, 462], [462, 250],
|
447 |
+
[300, 276], [276, 383], [383, 300], [292, 308], [308, 325], [325, 292],
|
448 |
+
[283, 276], [276, 293], [293, 283], [264, 372], [372, 447], [447, 264],
|
449 |
+
[346, 352], [352, 340], [340, 346], [354, 274], [274, 19], [19, 354],
|
450 |
+
[363, 456], [456, 281], [281, 363], [426, 436], [436, 425], [425, 426],
|
451 |
+
[380, 381], [381, 252], [252, 380], [267, 269], [269, 393], [393, 267],
|
452 |
+
[421, 200], [200, 428], [428, 421], [371, 266], [266, 329], [329, 371],
|
453 |
+
[432, 287], [287, 422], [422, 432], [290, 250], [250, 328], [328, 290],
|
454 |
+
[385, 258], [258, 384], [384, 385], [446, 265], [265, 342], [342, 446],
|
455 |
+
[386, 387], [387, 257], [257, 386], [422, 424], [424, 430], [430, 422],
|
456 |
+
[445, 342], [342, 276], [276, 445], [422, 273], [273, 424], [424, 422],
|
457 |
+
[306, 292], [292, 307], [307, 306], [352, 366], [366, 345], [345, 352],
|
458 |
+
[268, 271], [271, 302], [302, 268], [358, 423], [423, 371], [371, 358],
|
459 |
+
[327, 294], [294, 460], [460, 327], [331, 279], [279, 294], [294, 331],
|
460 |
+
[303, 271], [271, 304], [304, 303], [436, 432], [432, 427], [427, 436],
|
461 |
+
[304, 272], [272, 408], [408, 304], [395, 394], [394, 431], [431, 395],
|
462 |
+
[378, 395], [395, 400], [400, 378], [296, 334], [334, 299], [299, 296],
|
463 |
+
[6, 351], [351, 168], [168, 6], [376, 352], [352, 411], [411, 376],
|
464 |
+
[307, 325], [325, 320], [320, 307], [285, 295], [295, 336], [336, 285],
|
465 |
+
[320, 319], [319, 404], [404, 320], [329, 330], [330, 349], [349, 329],
|
466 |
+
[334, 293], [293, 333], [333, 334], [366, 323], [323, 447], [447, 366],
|
467 |
+
[316, 15], [15, 315], [315, 316], [331, 358], [358, 279], [279, 331],
|
468 |
+
[317, 14], [14, 316], [316, 317], [8, 285], [285, 9], [9, 8],
|
469 |
+
[277, 329], [329, 350], [350, 277], [253, 374], [374, 252], [252, 253],
|
470 |
+
[319, 318], [318, 403], [403, 319], [351, 6], [6, 419], [419, 351],
|
471 |
+
[324, 318], [318, 325], [325, 324], [397, 367], [367, 365], [365, 397],
|
472 |
+
[288, 435], [435, 397], [397, 288], [278, 344], [344, 439], [439, 278],
|
473 |
+
[310, 272], [272, 311], [311, 310], [248, 195], [195, 281], [281, 248],
|
474 |
+
[375, 273], [273, 291], [291, 375], [175, 396], [396, 199], [199, 175],
|
475 |
+
[312, 311], [311, 268], [268, 312], [276, 283], [283, 445], [445, 276],
|
476 |
+
[390, 373], [373, 339], [339, 390], [295, 282], [282, 296], [296, 295],
|
477 |
+
[448, 449], [449, 346], [346, 448], [356, 264], [264, 454], [454, 356],
|
478 |
+
[337, 336], [336, 299], [299, 337], [337, 338], [338, 151], [151, 337],
|
479 |
+
[294, 278], [278, 455], [455, 294], [308, 292], [292, 415], [415, 308],
|
480 |
+
[429, 358], [358, 355], [355, 429], [265, 340], [340, 372], [372, 265],
|
481 |
+
[352, 346], [346, 280], [280, 352], [295, 442], [442, 282], [282, 295],
|
482 |
+
[354, 19], [19, 370], [370, 354], [285, 441], [441, 295], [295, 285],
|
483 |
+
[195, 248], [248, 197], [197, 195], [457, 440], [440, 274], [274, 457],
|
484 |
+
[301, 300], [300, 368], [368, 301], [417, 351], [351, 465], [465, 417],
|
485 |
+
[251, 301], [301, 389], [389, 251], [394, 395], [395, 379], [379, 394],
|
486 |
+
[399, 412], [412, 419], [419, 399], [410, 436], [436, 322], [322, 410],
|
487 |
+
[326, 2], [2, 393], [393, 326], [354, 370], [370, 461], [461, 354],
|
488 |
+
[393, 164], [164, 267], [267, 393], [268, 302], [302, 12], [12, 268],
|
489 |
+
[312, 268], [268, 13], [13, 312], [298, 293], [293, 301], [301, 298],
|
490 |
+
[265, 446], [446, 340], [340, 265], [280, 330], [330, 425], [425, 280],
|
491 |
+
[322, 426], [426, 391], [391, 322], [420, 429], [429, 437], [437, 420],
|
492 |
+
[393, 391], [391, 326], [326, 393], [344, 440], [440, 438], [438, 344],
|
493 |
+
[458, 459], [459, 461], [461, 458], [364, 434], [434, 394], [394, 364],
|
494 |
+
[428, 396], [396, 262], [262, 428], [274, 354], [354, 457], [457, 274],
|
495 |
+
[317, 316], [316, 402], [402, 317], [316, 315], [315, 403], [403, 316],
|
496 |
+
[315, 314], [314, 404], [404, 315], [314, 313], [313, 405], [405, 314],
|
497 |
+
[313, 421], [421, 406], [406, 313], [323, 366], [366, 361], [361, 323],
|
498 |
+
[292, 306], [306, 407], [407, 292], [306, 291], [291, 408], [408, 306],
|
499 |
+
[291, 287], [287, 409], [409, 291], [287, 432], [432, 410], [410, 287],
|
500 |
+
[427, 434], [434, 411], [411, 427], [372, 264], [264, 383], [383, 372],
|
501 |
+
[459, 309], [309, 457], [457, 459], [366, 352], [352, 401], [401, 366],
|
502 |
+
[1, 274], [274, 4], [4, 1], [418, 421], [421, 262], [262, 418],
|
503 |
+
[331, 294], [294, 358], [358, 331], [435, 433], [433, 367], [367, 435],
|
504 |
+
[392, 289], [289, 439], [439, 392], [328, 462], [462, 326], [326, 328],
|
505 |
+
[94, 2], [2, 370], [370, 94], [289, 305], [305, 455], [455, 289],
|
506 |
+
[339, 254], [254, 448], [448, 339], [359, 255], [255, 446], [446, 359],
|
507 |
+
[254, 253], [253, 449], [449, 254], [253, 252], [252, 450], [450, 253],
|
508 |
+
[252, 256], [256, 451], [451, 252], [256, 341], [341, 452], [452, 256],
|
509 |
+
[414, 413], [413, 463], [463, 414], [286, 441], [441, 414], [414, 286],
|
510 |
+
[286, 258], [258, 441], [441, 286], [258, 257], [257, 442], [442, 258],
|
511 |
+
[257, 259], [259, 443], [443, 257], [259, 260], [260, 444], [444, 259],
|
512 |
+
[260, 467], [467, 445], [445, 260], [309, 459], [459, 250], [250, 309],
|
513 |
+
[305, 289], [289, 290], [290, 305], [305, 290], [290, 460], [460, 305],
|
514 |
+
[401, 376], [376, 435], [435, 401], [309, 250], [250, 392], [392, 309],
|
515 |
+
[376, 411], [411, 433], [433, 376], [453, 341], [341, 464], [464, 453],
|
516 |
+
[357, 453], [453, 465], [465, 357], [343, 357], [357, 412], [412, 343],
|
517 |
+
[437, 343], [343, 399], [399, 437], [344, 360], [360, 440], [440, 344],
|
518 |
+
[420, 437], [437, 456], [456, 420], [360, 420], [420, 363], [363, 360],
|
519 |
+
[361, 401], [401, 288], [288, 361], [265, 372], [372, 353], [353, 265],
|
520 |
+
[390, 339], [339, 249], [249, 390], [339, 448], [448, 255], [255, 339]]))
|
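The connection list above encodes the face-mesh tesselation as directed index pairs, three per triangle (for example `[39, 37], [37, 167], [167, 39]` is one triangle). As a minimal sketch of how such pair data could be post-processed on the client, the helper below (illustrative only, not part of the repo) collapses the directed pairs into unique undirected edges:

```ts
// Illustrative helper (not part of the repo): collapse directed edge pairs
// like [a, b], [b, c], [c, a] into a set of unique undirected edges.
type EdgePair = [number, number];

function uniqueEdges(pairs: EdgePair[]): EdgePair[] {
  const seen = new Set<string>();
  const edges: EdgePair[] = [];
  for (const [a, b] of pairs) {
    const key = a < b ? `${a}-${b}` : `${b}-${a}`;
    if (!seen.has(key)) {
      seen.add(key);
      edges.push(a < b ? [a, b] : [b, a]);
    }
  }
  return edges;
}

// Example: one triangle expressed as three directed pairs (plus a duplicate).
console.log(uniqueEdges([[39, 37], [37, 167], [167, 39], [167, 39]]));
// -> [[37, 39], [37, 167], [39, 167]]
```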
client/src/hooks/useFaceLandmarkDetection.tsx
ADDED
@@ -0,0 +1,632 @@
1 |
+
import React, { useState, useEffect, useRef, useCallback } from 'react';
|
2 |
+
import * as vision from '@mediapipe/tasks-vision';
|
3 |
+
|
4 |
+
import { facePoke } from '@/lib/facePoke';
|
5 |
+
import { useMainStore } from './useMainStore';
|
6 |
+
import useThrottledCallback from 'beautiful-react-hooks/useThrottledCallback';
|
7 |
+
|
8 |
+
import { landmarkGroups, FACEMESH_LIPS, FACEMESH_LEFT_EYE, FACEMESH_LEFT_EYEBROW, FACEMESH_RIGHT_EYE, FACEMESH_RIGHT_EYEBROW, FACEMESH_FACE_OVAL } from './landmarks';
|
9 |
+
|
10 |
+
// New types for improved type safety
|
11 |
+
export type LandmarkGroup = 'lips' | 'leftEye' | 'leftEyebrow' | 'rightEye' | 'rightEyebrow' | 'faceOval' | 'background';
|
12 |
+
export type LandmarkCenter = { x: number; y: number; z: number };
|
13 |
+
export type ClosestLandmark = { group: LandmarkGroup; distance: number; vector: { x: number; y: number; z: number } };
|
14 |
+
|
15 |
+
export type MediaPipeResources = {
|
16 |
+
faceLandmarker: vision.FaceLandmarker | null;
|
17 |
+
drawingUtils: vision.DrawingUtils | null;
|
18 |
+
};
|
19 |
+
|
20 |
+
export function useFaceLandmarkDetection() {
|
21 |
+
const error = useMainStore(s => s.error);
|
22 |
+
const setError = useMainStore(s => s.setError);
|
23 |
+
const imageFile = useMainStore(s => s.imageFile);
|
24 |
+
const setImageFile = useMainStore(s => s.setImageFile);
|
25 |
+
const originalImage = useMainStore(s => s.originalImage);
|
26 |
+
const originalImageHash = useMainStore(s => s.originalImageHash);
|
27 |
+
const setOriginalImageHash = useMainStore(s => s.setOriginalImageHash);
|
28 |
+
const previewImage = useMainStore(s => s.previewImage);
|
29 |
+
const setPreviewImage = useMainStore(s => s.setPreviewImage);
|
30 |
+
const resetImage = useMainStore(s => s.resetImage);
|
31 |
+
|
32 |
+
;(window as any).debugJuju = useMainStore;
|
33 |
+
////////////////////////////////////////////////////////////////////////
|
34 |
+
// ok so apparently I cannot vary the latency, or else there is a bug
|
35 |
+
// const averageLatency = useMainStore(s => s.averageLatency);
|
36 |
+
const averageLatency = 220
|
37 |
+
////////////////////////////////////////////////////////////////////////
|
38 |
+
|
39 |
+
// State for face detection
|
40 |
+
const [faceLandmarks, setFaceLandmarks] = useState<vision.NormalizedLandmark[][]>([]);
|
41 |
+
const [isMediaPipeReady, setIsMediaPipeReady] = useState(false);
|
42 |
+
const [isDrawingUtilsReady, setIsDrawingUtilsReady] = useState(false);
|
43 |
+
const [blendShapes, setBlendShapes] = useState<vision.Classifications[]>([]);
|
44 |
+
|
45 |
+
// State for mouse interaction
|
46 |
+
const [dragStart, setDragStart] = useState<{ x: number; y: number } | null>(null);
|
47 |
+
const [dragEnd, setDragEnd] = useState<{ x: number; y: number } | null>(null);
|
48 |
+
|
49 |
+
const [isDragging, setIsDragging] = useState(false);
|
50 |
+
const [isWaitingForResponse, setIsWaitingForResponse] = useState(false);
|
51 |
+
const dragStartRef = useRef<{ x: number; y: number } | null>(null);
|
52 |
+
const currentMousePosRef = useRef<{ x: number; y: number } | null>(null);
|
53 |
+
const lastModifiedImageHashRef = useRef<string | null>(null);
|
54 |
+
|
55 |
+
const [currentLandmark, setCurrentLandmark] = useState<ClosestLandmark | null>(null);
|
56 |
+
const [previousLandmark, setPreviousLandmark] = useState<ClosestLandmark | null>(null);
|
57 |
+
const [currentOpacity, setCurrentOpacity] = useState(0);
|
58 |
+
const [previousOpacity, setPreviousOpacity] = useState(0);
|
59 |
+
|
60 |
+
const [isHovering, setIsHovering] = useState(false);
|
61 |
+
|
62 |
+
// Refs
|
63 |
+
const canvasRef = useRef<HTMLCanvasElement>(null);
|
64 |
+
const mediaPipeRef = useRef<MediaPipeResources>({
|
65 |
+
faceLandmarker: null,
|
66 |
+
drawingUtils: null,
|
67 |
+
});
|
68 |
+
|
69 |
+
const setActiveLandmark = useCallback((newLandmark: ClosestLandmark | undefined) => {
|
70 |
+
//if (newLandmark && (!currentLandmark || newLandmark.group !== currentLandmark.group)) {
|
71 |
+
setPreviousLandmark(currentLandmark || null);
|
72 |
+
setCurrentLandmark(newLandmark || null);
|
73 |
+
setCurrentOpacity(0);
|
74 |
+
setPreviousOpacity(1);
|
75 |
+
//}
|
76 |
+
}, [currentLandmark, setPreviousLandmark, setCurrentLandmark, setCurrentOpacity, setPreviousOpacity]);
|
77 |
+
|
78 |
+
// Initialize MediaPipe
|
79 |
+
useEffect(() => {
|
80 |
+
console.log('Initializing MediaPipe...');
|
81 |
+
let isMounted = true;
|
82 |
+
|
83 |
+
const initializeMediaPipe = async () => {
|
84 |
+
const { FaceLandmarker, FilesetResolver, DrawingUtils } = vision;
|
85 |
+
|
86 |
+
try {
|
87 |
+
console.log('Initializing FilesetResolver...');
|
88 |
+
const filesetResolver = await FilesetResolver.forVisionTasks(
|
89 |
+
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@0.10.3/wasm"
|
90 |
+
);
|
91 |
+
|
92 |
+
console.log('Creating FaceLandmarker...');
|
93 |
+
const faceLandmarker = await FaceLandmarker.createFromOptions(filesetResolver, {
|
94 |
+
baseOptions: {
|
95 |
+
modelAssetPath: `https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task`,
|
96 |
+
delegate: "GPU"
|
97 |
+
},
|
98 |
+
outputFaceBlendshapes: true,
|
99 |
+
runningMode: "IMAGE",
|
100 |
+
numFaces: 1
|
101 |
+
});
|
102 |
+
|
103 |
+
if (isMounted) {
|
104 |
+
console.log('FaceLandmarker created successfully.');
|
105 |
+
mediaPipeRef.current.faceLandmarker = faceLandmarker;
|
106 |
+
setIsMediaPipeReady(true);
|
107 |
+
} else {
|
108 |
+
faceLandmarker.close();
|
109 |
+
}
|
110 |
+
} catch (error) {
|
111 |
+
console.error('Error during MediaPipe initialization:', error);
|
112 |
+
setError('Failed to initialize face detection. Please try refreshing the page.');
|
113 |
+
}
|
114 |
+
};
|
115 |
+
|
116 |
+
initializeMediaPipe();
|
117 |
+
|
118 |
+
|
119 |
+
return () => {
|
120 |
+
isMounted = false;
|
121 |
+
if (mediaPipeRef.current.faceLandmarker) {
|
122 |
+
mediaPipeRef.current.faceLandmarker.close();
|
123 |
+
}
|
124 |
+
};
|
125 |
+
}, []);
|
126 |
+
|
127 |
+
// New state for storing landmark centers
|
128 |
+
const [landmarkCenters, setLandmarkCenters] = useState<Record<LandmarkGroup, LandmarkCenter>>({} as Record<LandmarkGroup, LandmarkCenter>);
|
129 |
+
|
130 |
+
// Function to compute the center of each landmark group
|
131 |
+
const computeLandmarkCenters = useCallback((landmarks: vision.NormalizedLandmark[]) => {
|
132 |
+
const centers: Record<LandmarkGroup, LandmarkCenter> = {} as Record<LandmarkGroup, LandmarkCenter>;
|
133 |
+
|
134 |
+
const computeGroupCenter = (group: Readonly<Set<number[]>>): LandmarkCenter => {
|
135 |
+
let sumX = 0, sumY = 0, sumZ = 0, count = 0;
|
136 |
+
group.forEach(([index]) => {
|
137 |
+
if (landmarks[index]) {
|
138 |
+
sumX += landmarks[index].x;
|
139 |
+
sumY += landmarks[index].y;
|
140 |
+
sumZ += landmarks[index].z || 0;
|
141 |
+
count++;
|
142 |
+
}
|
143 |
+
});
|
144 |
+
return { x: sumX / count, y: sumY / count, z: sumZ / count };
|
145 |
+
};
|
146 |
+
|
147 |
+
centers.lips = computeGroupCenter(FACEMESH_LIPS);
|
148 |
+
centers.leftEye = computeGroupCenter(FACEMESH_LEFT_EYE);
|
149 |
+
centers.leftEyebrow = computeGroupCenter(FACEMESH_LEFT_EYEBROW);
|
150 |
+
centers.rightEye = computeGroupCenter(FACEMESH_RIGHT_EYE);
|
151 |
+
centers.rightEyebrow = computeGroupCenter(FACEMESH_RIGHT_EYEBROW);
|
152 |
+
centers.faceOval = computeGroupCenter(FACEMESH_FACE_OVAL);
|
153 |
+
centers.background = { x: 0.5, y: 0.5, z: 0 };
|
154 |
+
|
155 |
+
setLandmarkCenters(centers);
|
156 |
+
// console.log('Landmark centers computed:', centers);
|
157 |
+
}, []);
|
158 |
+
|
159 |
+
// Function to find the closest landmark to the mouse position
|
160 |
+
const findClosestLandmark = useCallback((mouseX: number, mouseY: number, isGroup?: LandmarkGroup): ClosestLandmark => {
|
161 |
+
const defaultLandmark: ClosestLandmark = {
|
162 |
+
group: 'background',
|
163 |
+
distance: 0,
|
164 |
+
vector: {
|
165 |
+
x: mouseX,
|
166 |
+
y: mouseY,
|
167 |
+
z: 0
|
168 |
+
}
|
169 |
+
}
|
170 |
+
|
171 |
+
if (Object.keys(landmarkCenters).length === 0) {
|
172 |
+
console.warn('Landmark centers not computed yet');
|
173 |
+
return defaultLandmark;
|
174 |
+
}
|
175 |
+
|
176 |
+
let closestGroup: LandmarkGroup | null = null;
|
177 |
+
let minDistance = Infinity;
|
178 |
+
let closestVector = { x: 0, y: 0, z: 0 };
|
179 |
+
let faceOvalDistance = Infinity;
|
180 |
+
let faceOvalVector = { x: 0, y: 0, z: 0 };
|
181 |
+
|
182 |
+
Object.entries(landmarkCenters).forEach(([group, center]) => {
|
183 |
+
const dx = mouseX - center.x;
|
184 |
+
const dy = mouseY - center.y;
|
185 |
+
const distance = Math.sqrt(dx * dx + dy * dy);
|
186 |
+
|
187 |
+
if (group === 'faceOval') {
|
188 |
+
faceOvalDistance = distance;
|
189 |
+
faceOvalVector = { x: dx, y: dy, z: 0 };
|
190 |
+
}
|
191 |
+
|
192 |
+
// when a group filter (`isGroup`) is provided, only consider that group
|
193 |
+
if (isGroup) {
|
194 |
+
if (group !== isGroup) {
|
195 |
+
return
|
196 |
+
}
|
197 |
+
}
|
198 |
+
|
199 |
+
if (distance < minDistance) {
|
200 |
+
minDistance = distance;
|
201 |
+
closestGroup = group as LandmarkGroup;
|
202 |
+
closestVector = { x: dx, y: dy, z: 0 }; // Z is 0 as mouse interaction is 2D
|
203 |
+
}
|
204 |
+
});
|
205 |
+
|
206 |
+
// Fallback to the background group (using the faceOval offset) if the distance is too large
|
207 |
+
if (minDistance > 0.05) {
|
208 |
+
// console.log('Distance is too high, so we use the faceOval group');
|
209 |
+
closestGroup = 'background';
|
210 |
+
minDistance = faceOvalDistance;
|
211 |
+
closestVector = faceOvalVector;
|
212 |
+
}
|
213 |
+
|
214 |
+
if (closestGroup) {
|
215 |
+
// console.log(`Closest landmark: ${closestGroup}, distance: ${minDistance.toFixed(4)}`);
|
216 |
+
return { group: closestGroup, distance: minDistance, vector: closestVector };
|
217 |
+
} else {
|
218 |
+
// console.log('No group found, returning fallback');
|
219 |
+
return defaultLandmark
|
220 |
+
}
|
221 |
+
}, [landmarkCenters]);
|
222 |
+
|
223 |
+
// Detect face landmarks
|
224 |
+
const detectFaceLandmarks = useCallback(async (imageDataUrl: string) => {
|
225 |
+
// console.log('Attempting to detect face landmarks...');
|
226 |
+
if (!isMediaPipeReady) {
|
227 |
+
console.log('MediaPipe not ready. Skipping detection.');
|
228 |
+
return;
|
229 |
+
}
|
230 |
+
|
231 |
+
const faceLandmarker = mediaPipeRef.current.faceLandmarker;
|
232 |
+
|
233 |
+
if (!faceLandmarker) {
|
234 |
+
console.error('FaceLandmarker is not initialized.');
|
235 |
+
return;
|
236 |
+
}
|
237 |
+
|
238 |
+
const drawingUtils = mediaPipeRef.current.drawingUtils;
|
239 |
+
|
240 |
+
const image = new Image();
|
241 |
+
image.src = imageDataUrl;
|
242 |
+
await new Promise((resolve) => { image.onload = resolve; });
|
243 |
+
|
244 |
+
const faceLandmarkerResult = faceLandmarker.detect(image);
|
245 |
+
// console.log("Face landmarks detected:", faceLandmarkerResult);
|
246 |
+
|
247 |
+
setFaceLandmarks(faceLandmarkerResult.faceLandmarks);
|
248 |
+
setBlendShapes(faceLandmarkerResult.faceBlendshapes || []);
|
249 |
+
|
250 |
+
if (faceLandmarkerResult.faceLandmarks && faceLandmarkerResult.faceLandmarks[0]) {
|
251 |
+
computeLandmarkCenters(faceLandmarkerResult.faceLandmarks[0]);
|
252 |
+
}
|
253 |
+
|
254 |
+
if (canvasRef.current && drawingUtils) {
|
255 |
+
drawLandmarks(faceLandmarkerResult.faceLandmarks[0], canvasRef.current, drawingUtils);
|
256 |
+
}
|
257 |
+
}, [isMediaPipeReady, isDrawingUtilsReady, computeLandmarkCenters]);
|
258 |
+
|
259 |
+
const drawLandmarks = useCallback((
|
260 |
+
landmarks: vision.NormalizedLandmark[],
|
261 |
+
canvas: HTMLCanvasElement,
|
262 |
+
drawingUtils: vision.DrawingUtils
|
263 |
+
) => {
|
264 |
+
const ctx = canvas.getContext('2d');
|
265 |
+
if (!ctx) return;
|
266 |
+
|
267 |
+
ctx.clearRect(0, 0, canvas.width, canvas.height);
|
268 |
+
|
269 |
+
if (canvasRef.current && previewImage) {
|
270 |
+
const img = new Image();
|
271 |
+
img.onload = () => {
|
272 |
+
canvas.width = img.width;
|
273 |
+
canvas.height = img.height;
|
274 |
+
|
275 |
+
const drawLandmarkGroup = (landmark: ClosestLandmark | null, opacity: number) => {
|
276 |
+
if (!landmark) return;
|
277 |
+
const connections = landmarkGroups[landmark.group];
|
278 |
+
if (connections) {
|
279 |
+
ctx.globalAlpha = opacity;
|
280 |
+
drawingUtils.drawConnectors(
|
281 |
+
landmarks,
|
282 |
+
connections,
|
283 |
+
{ color: 'orange', lineWidth: 4 }
|
284 |
+
);
|
285 |
+
}
|
286 |
+
};
|
287 |
+
|
288 |
+
drawLandmarkGroup(previousLandmark, previousOpacity);
|
289 |
+
drawLandmarkGroup(currentLandmark, currentOpacity);
|
290 |
+
|
291 |
+
ctx.globalAlpha = 1;
|
292 |
+
};
|
293 |
+
img.src = previewImage;
|
294 |
+
}
|
295 |
+
}, [previewImage, currentLandmark, previousLandmark, currentOpacity, previousOpacity]);
|
296 |
+
|
297 |
+
useEffect(() => {
|
298 |
+
if (isMediaPipeReady && isDrawingUtilsReady && faceLandmarks.length > 0 && canvasRef.current && mediaPipeRef.current.drawingUtils) {
|
299 |
+
drawLandmarks(faceLandmarks[0], canvasRef.current, mediaPipeRef.current.drawingUtils);
|
300 |
+
}
|
301 |
+
}, [isMediaPipeReady, isDrawingUtilsReady, faceLandmarks, currentLandmark, previousLandmark, currentOpacity, previousOpacity, drawLandmarks]);
|
302 |
+
useEffect(() => {
|
303 |
+
let animationFrame: number;
|
304 |
+
const animate = () => {
|
305 |
+
setCurrentOpacity((prev) => Math.min(prev + 0.2, 1));
|
306 |
+
setPreviousOpacity((prev) => Math.max(prev - 0.2, 0));
|
307 |
+
|
308 |
+
if (currentOpacity < 1 || previousOpacity > 0) {
|
309 |
+
animationFrame = requestAnimationFrame(animate);
|
310 |
+
}
|
311 |
+
};
|
312 |
+
animationFrame = requestAnimationFrame(animate);
|
313 |
+
return () => cancelAnimationFrame(animationFrame);
|
314 |
+
}, [currentLandmark]);
|
315 |
+
|
316 |
+
// Canvas ref callback
|
317 |
+
const canvasRefCallback = useCallback((node: HTMLCanvasElement | null) => {
|
318 |
+
if (node !== null) {
|
319 |
+
const ctx = node.getContext('2d');
|
320 |
+
if (ctx) {
|
321 |
+
// Get device pixel ratio
|
322 |
+
const pixelRatio = window.devicePixelRatio || 1;
|
323 |
+
|
324 |
+
// Scale canvas based on the pixel ratio
|
325 |
+
node.width = node.clientWidth * pixelRatio;
|
326 |
+
node.height = node.clientHeight * pixelRatio;
|
327 |
+
ctx.scale(pixelRatio, pixelRatio);
|
328 |
+
|
329 |
+
mediaPipeRef.current.drawingUtils = new vision.DrawingUtils(ctx);
|
330 |
+
setIsDrawingUtilsReady(true);
|
331 |
+
} else {
|
332 |
+
console.error('Failed to get 2D context from canvas.');
|
333 |
+
}
|
334 |
+
canvasRef.current = node;
|
335 |
+
}
|
336 |
+
}, []);
|
337 |
+
|
338 |
+
|
339 |
+
useEffect(() => {
|
340 |
+
if (!isMediaPipeReady) {
|
341 |
+
console.log('MediaPipe not ready. Skipping landmark detection.');
|
342 |
+
return
|
343 |
+
}
|
344 |
+
if (!previewImage) {
|
345 |
+
console.log('Preview image not ready. Skipping landmark detection.');
|
346 |
+
return
|
347 |
+
}
|
348 |
+
if (!isDrawingUtilsReady) {
|
349 |
+
console.log('DrawingUtils not ready. Skipping landmark detection.');
|
350 |
+
return
|
351 |
+
}
|
352 |
+
detectFaceLandmarks(previewImage);
|
353 |
+
}, [isMediaPipeReady, isDrawingUtilsReady, previewImage])
|
354 |
+
|
355 |
+
|
356 |
+
|
357 |
+
const modifyImage = useCallback(({ landmark, vector }: {
|
358 |
+
landmark: ClosestLandmark
|
359 |
+
vector: { x: number; y: number; z: number }
|
360 |
+
}) => {
|
361 |
+
|
362 |
+
const {
|
363 |
+
originalImage,
|
364 |
+
originalImageHash,
|
365 |
+
params: previousParams,
|
366 |
+
setParams,
|
367 |
+
setError
|
368 |
+
} = useMainStore.getState()
|
369 |
+
|
370 |
+
|
371 |
+
if (!originalImage) {
|
372 |
+
console.error('Image file or facePoke not available');
|
373 |
+
return;
|
374 |
+
}
|
375 |
+
|
376 |
+
const params = {
|
377 |
+
...previousParams
|
378 |
+
}
|
379 |
+
|
380 |
+
const minX = -0.50;
|
381 |
+
const maxX = 0.50;
|
382 |
+
const minY = -0.50;
|
383 |
+
const maxY = 0.50;
|
384 |
+
|
385 |
+
// Function to map a value from one range to another
|
386 |
+
const mapRange = (value: number, inMin: number, inMax: number, outMin: number, outMax: number): number => {
|
387 |
+
return Math.min(outMax, Math.max(outMin, ((value - inMin) * (outMax - outMin)) / (inMax - inMin) + outMin));
|
388 |
+
};
|
389 |
+
|
390 |
+
console.log("modifyImage:", {
|
391 |
+
originalImage,
|
392 |
+
originalImageHash,
|
393 |
+
landmark,
|
394 |
+
vector,
|
395 |
+
minX,
|
396 |
+
maxX,
|
397 |
+
minY,
|
398 |
+
maxY,
|
399 |
+
})
|
400 |
+
|
401 |
+
// Map landmarks to ImageModificationParams
|
402 |
+
switch (landmark.group) {
|
403 |
+
case 'leftEye':
|
404 |
+
case 'rightEye':
|
405 |
+
// eyes (min: -20, max: 5, default: 0)
|
406 |
+
const eyesMin = 210
|
407 |
+
const eyesMax = 5
|
408 |
+
params.eyes = mapRange(vector.x, minX, maxX, eyesMin, eyesMax);
|
409 |
+
|
410 |
+
break;
|
411 |
+
case 'leftEyebrow':
|
412 |
+
case 'rightEyebrow':
|
413 |
+
// moving the mouse vertically for the eyebrow
|
414 |
+
// should make them up/down
|
415 |
+
// eyebrow (min: -10, max: 15, default: 0)
|
416 |
+
const eyebrowMin = -10
|
417 |
+
const eyebrowMax = 15
|
418 |
+
params.eyebrow = mapRange(vector.y, minY, maxY, eyebrowMin, eyebrowMax);
|
419 |
+
|
420 |
+
break;
|
421 |
+
case 'lips':
|
422 |
+
// aaa (min: -30, max: 120, default: 0)
|
423 |
+
//const aaaMin = -30
|
424 |
+
//const aaaMax = 120
|
425 |
+
//params.aaa = mapRange(vector.x, minY, maxY, aaaMin, aaaMax);
|
426 |
+
|
427 |
+
// eee (min: -20, max: 15, default: 0)
|
428 |
+
const eeeMin = -20
|
429 |
+
const eeeMax = 15
|
430 |
+
params.eee = mapRange(vector.y, minY, maxY, eeeMin, eeeMax);
|
431 |
+
|
432 |
+
|
433 |
+
// woo (min: -20, max: 15, default: 0)
|
434 |
+
const wooMin = -20
|
435 |
+
const wooMax = 15
|
436 |
+
params.woo = mapRange(vector.x, minX, maxX, wooMin, wooMax);
|
437 |
+
|
438 |
+
break;
|
439 |
+
case 'faceOval':
|
440 |
+
// displacing the face horizontally by moving the mouse on the X axis
|
441 |
+
// should perform a roll rotation
|
442 |
+
// rotate_roll (min: -20, max: 20, default: 0)
|
443 |
+
const rollMin = -40
|
444 |
+
const rollMax = 40
|
445 |
+
|
446 |
+
// note: unlike yaw below, the axis is not inverted here
|
447 |
+
params.rotate_roll = mapRange(vector.x, minX, maxX, rollMin, rollMax);
|
448 |
+
break;
|
449 |
+
|
450 |
+
case 'background':
|
451 |
+
// displacing the face horizontally by moving the mouse on the X axis
|
452 |
+
// should perform a yaw rotation
|
453 |
+
// rotate_yaw (min: -20, max: 20, default: 0)
|
454 |
+
const yawMin = -40
|
455 |
+
const yawMax = 40
|
456 |
+
|
457 |
+
// note: we invert the axis here
|
458 |
+
params.rotate_yaw = mapRange(-vector.x, minX, maxX, yawMin, yawMax);
|
459 |
+
|
460 |
+
// displacing the face vertically by moving the mouse on the Y axis
|
461 |
+
// should perform a pitch rotation
|
462 |
+
// rotate_pitch (min: -20, max: 20, default: 0)
|
463 |
+
const pitchMin = -40
|
464 |
+
const pitchMax = 40
|
465 |
+
params.rotate_pitch = mapRange(vector.y, minY, maxY, pitchMin, pitchMax);
|
466 |
+
break;
|
467 |
+
default:
|
468 |
+
return
|
469 |
+
}
|
470 |
+
|
471 |
+
for (const [key, value] of Object.entries(params)) {
|
472 |
+
if (isNaN(value as any) || !isFinite(value as any)) {
|
473 |
+
console.log(`${key} is NaN, aborting`)
|
474 |
+
return
|
475 |
+
}
|
476 |
+
}
|
477 |
+
console.log(`PITCH=${params.rotate_pitch || 0}, YAW=${params.rotate_yaw || 0}, ROLL=${params.rotate_roll || 0}`);
|
478 |
+
|
479 |
+
setParams(params)
|
480 |
+
try {
|
481 |
+
// For the first request or when the image file changes, send the full image
|
482 |
+
if (!lastModifiedImageHashRef.current || lastModifiedImageHashRef.current !== originalImageHash) {
|
483 |
+
lastModifiedImageHashRef.current = originalImageHash;
|
484 |
+
facePoke.modifyImage(originalImage, null, params);
|
485 |
+
} else {
|
486 |
+
// For subsequent requests, send only the hash
|
487 |
+
facePoke.modifyImage(null, lastModifiedImageHashRef.current, params);
|
488 |
+
}
|
489 |
+
} catch (error) {
|
490 |
+
// console.error('Error modifying image:', error);
|
491 |
+
setError('Failed to modify image');
|
492 |
+
}
|
493 |
+
}, []);
|
494 |
+
|
495 |
+
// this is throttled by our average latency
|
496 |
+
const modifyImageWithRateLimit = useThrottledCallback((params: {
|
497 |
+
landmark: ClosestLandmark
|
498 |
+
vector: { x: number; y: number; z: number }
|
499 |
+
}) => {
|
500 |
+
modifyImage(params);
|
501 |
+
}, [modifyImage], averageLatency);
|
502 |
+
|
503 |
+
const handleMouseEnter = useCallback(() => {
|
504 |
+
setIsHovering(true);
|
505 |
+
}, []);
|
506 |
+
|
507 |
+
const handleMouseLeave = useCallback(() => {
|
508 |
+
setIsHovering(false);
|
509 |
+
}, []);
|
510 |
+
|
511 |
+
// Update mouse event handlers
|
512 |
+
const handleMouseDown = useCallback((event: React.MouseEvent<HTMLCanvasElement>) => {
|
513 |
+
if (!canvasRef.current) return;
|
514 |
+
|
515 |
+
const rect = canvasRef.current.getBoundingClientRect();
|
516 |
+
const x = (event.clientX - rect.left) / rect.width;
|
517 |
+
const y = (event.clientY - rect.top) / rect.height;
|
518 |
+
|
519 |
+
const landmark = findClosestLandmark(x, y);
|
520 |
+
console.log(`Mouse down on ${landmark.group}`);
|
521 |
+
setActiveLandmark(landmark);
|
522 |
+
setDragStart({ x, y });
|
523 |
+
dragStartRef.current = { x, y };
|
524 |
+
}, [findClosestLandmark, setActiveLandmark, setDragStart]);
|
525 |
+
|
526 |
+
const handleMouseMove = useCallback((event: React.MouseEvent<HTMLCanvasElement>) => {
|
527 |
+
if (!canvasRef.current) return;
|
528 |
+
|
529 |
+
const rect = canvasRef.current.getBoundingClientRect();
|
530 |
+
const x = (event.clientX - rect.left) / rect.width;
|
531 |
+
const y = (event.clientY - rect.top) / rect.height;
|
532 |
+
|
533 |
+
// only send an API request to modify the image if we are actively dragging
|
534 |
+
if (dragStart && dragStartRef.current) {
|
535 |
+
|
536 |
+
const landmark = findClosestLandmark(x, y, currentLandmark?.group);
|
537 |
+
|
538 |
+
console.log(`Dragging mouse (was over ${currentLandmark?.group || 'nothing'}, now over ${landmark.group})`);
|
539 |
+
|
540 |
+
// Compute the vector from the landmark center to the current mouse position
|
541 |
+
modifyImageWithRateLimit({
|
542 |
+
landmark: currentLandmark || landmark, // this will still use the initially selected landmark
|
543 |
+
vector: {
|
544 |
+
x: x - landmarkCenters[landmark.group].x,
|
545 |
+
y: y - landmarkCenters[landmark.group].y,
|
546 |
+
z: 0 // Z is 0 as mouse interaction is 2D
|
547 |
+
}
|
548 |
+
});
|
549 |
+
setIsDragging(true);
|
550 |
+
} else {
|
551 |
+
const landmark = findClosestLandmark(x, y);
|
552 |
+
|
553 |
+
//console.log(`Moving mouse over ${landmark.group}`);
|
554 |
+
// console.log(`Simple mouse move over ${landmark.group}`);
|
555 |
+
|
556 |
+
// we need to be careful here, we don't want to change the active
|
557 |
+
// landmark dynamically if we are busy dragging
|
558 |
+
|
559 |
+
if (!currentLandmark || (currentLandmark?.group !== landmark?.group)) {
|
560 |
+
// console.log("setting activeLandmark to ", landmark);
|
561 |
+
setActiveLandmark(landmark);
|
562 |
+
}
|
563 |
+
setIsHovering(true); // Ensure hovering state is maintained during movement
|
564 |
+
}
|
565 |
+
}, [currentLandmark, dragStart, setIsHovering, setActiveLandmark, setIsDragging, modifyImageWithRateLimit, landmarkCenters]);
|
566 |
+
|
567 |
+
const handleMouseUp = useCallback((event: React.MouseEvent<HTMLCanvasElement>) => {
|
568 |
+
if (!canvasRef.current) return;
|
569 |
+
|
570 |
+
const rect = canvasRef.current.getBoundingClientRect();
|
571 |
+
const x = (event.clientX - rect.left) / rect.width;
|
572 |
+
const y = (event.clientY - rect.top) / rect.height;
|
573 |
+
|
574 |
+
// only send an API request to modify the image if we are actively dragging
|
575 |
+
if (dragStart && dragStartRef.current) {
|
576 |
+
|
577 |
+
const landmark = findClosestLandmark(x, y, currentLandmark?.group);
|
578 |
+
|
579 |
+
console.log(`Mouse up (was over ${currentLandmark?.group || 'nothing'}, now over ${landmark.group})`);
|
580 |
+
|
581 |
+
// Compute the vector from the landmark center to the current mouse position
|
582 |
+
modifyImageWithRateLimit({
|
583 |
+
landmark: currentLandmark || landmark, // this will still use the initially selected landmark
|
584 |
+
vector: {
|
585 |
+
x: x - landmarkCenters[landmark.group].x,
|
586 |
+
y: y - landmarkCenters[landmark.group].y,
|
587 |
+
z: 0 // Z is 0 as mouse interaction is 2D
|
588 |
+
}
|
589 |
+
});
|
590 |
+
}
|
591 |
+
|
592 |
+
setIsDragging(false);
|
593 |
+
dragStartRef.current = null;
|
594 |
+
setActiveLandmark(undefined);
|
595 |
+
}, [currentLandmark, isDragging, modifyImageWithRateLimit, findClosestLandmark, setActiveLandmark, landmarkCenters, setIsDragging]);
|
596 |
+
|
597 |
+
useEffect(() => {
|
598 |
+
facePoke.setOnModifiedImage((image: string, image_hash: string) => {
|
599 |
+
if (image) {
|
600 |
+
setPreviewImage(image);
|
601 |
+
}
|
602 |
+
setOriginalImageHash(image_hash);
|
603 |
+
lastModifiedImageHashRef.current = image_hash;
|
604 |
+
});
|
605 |
+
}, [setPreviewImage, setOriginalImageHash]);
|
606 |
+
|
607 |
+
return {
|
608 |
+
canvasRef,
|
609 |
+
canvasRefCallback,
|
610 |
+
mediaPipeRef,
|
611 |
+
faceLandmarks,
|
612 |
+
isMediaPipeReady,
|
613 |
+
isDrawingUtilsReady,
|
614 |
+
blendShapes,
|
615 |
+
|
616 |
+
//dragStart,
|
617 |
+
//setDragStart,
|
618 |
+
//dragEnd,
|
619 |
+
//setDragEnd,
|
620 |
+
setFaceLandmarks,
|
621 |
+
setBlendShapes,
|
622 |
+
|
623 |
+
handleMouseDown,
|
624 |
+
handleMouseUp,
|
625 |
+
handleMouseMove,
|
626 |
+
handleMouseEnter,
|
627 |
+
handleMouseLeave,
|
628 |
+
|
629 |
+
currentLandmark,
|
630 |
+
currentOpacity,
|
631 |
+
}
|
632 |
+
}
|
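To show how the hook above is meant to be consumed, here is a minimal sketch of a component that wires its callbacks onto a canvas overlay. The real wiring lives in `client/src/app.tsx`; the component name and import path here are illustrative.

```tsx
// Hypothetical consumer of useFaceLandmarkDetection (names/paths illustrative).
import React from 'react';
import { useFaceLandmarkDetection } from './hooks/useFaceLandmarkDetection';

export function LandmarkCanvas() {
  const {
    canvasRefCallback,   // sets up the 2D context and DrawingUtils
    handleMouseDown,
    handleMouseUp,
    handleMouseMove,
    handleMouseEnter,
    handleMouseLeave,
    isMediaPipeReady,
  } = useFaceLandmarkDetection();

  return (
    <canvas
      ref={canvasRefCallback}
      onMouseDown={handleMouseDown}
      onMouseUp={handleMouseUp}
      onMouseMove={handleMouseMove}
      onMouseEnter={handleMouseEnter}
      onMouseLeave={handleMouseLeave}
      // dim the overlay until the FaceLandmarker model has loaded
      style={{ opacity: isMediaPipeReady ? 1 : 0.5 }}
    />
  );
}
```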
client/src/hooks/useFacePokeAPI.ts
ADDED
@@ -0,0 +1,44 @@
import { useEffect, useState } from "react";

import { facePoke } from "../lib/facePoke";
import { useMainStore } from "./useMainStore";

export function useFacePokeAPI() {

  // State for FacePoke
  const [status, setStatus] = useState('');
  const [isDebugMode, setIsDebugMode] = useState(false);
  const [interruptMessage, setInterruptMessage] = useState<string | null>(null);

  const [isLoading, setIsLoading] = useState(false);

  // Initialize FacePoke
  useEffect(() => {
    const urlParams = new URLSearchParams(window.location.search);
    setIsDebugMode(urlParams.get('debug') === 'true');
  }, []);

  // Handle WebSocket interruptions
  useEffect(() => {
    const handleInterruption = (event: CustomEvent) => {
      setInterruptMessage(event.detail.message);
    };

    window.addEventListener('websocketInterruption' as any, handleInterruption);

    return () => {
      window.removeEventListener('websocketInterruption' as any, handleInterruption);
    };
  }, []);

  return {
    facePoke,
    status,
    setStatus,
    isDebugMode,
    setIsDebugMode,
    interruptMessage,
    isLoading,
    setIsLoading,
  }
}
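A minimal usage sketch for this hook; the component name and markup are illustrative, only the returned fields come from the hook above.

// Sketch only: surface loading, debug and interruption state in the UI.
import { useFacePokeAPI } from '@/hooks/useFacePokeAPI';

function StatusBarSketch() {
  const { status, isLoading, isDebugMode, interruptMessage } = useFacePokeAPI();

  if (interruptMessage) {
    return <p>{interruptMessage}</p>;
  }
  return <p>{isLoading ? 'Processing…' : status}{isDebugMode ? ' (debug)' : ''}</p>;
}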
client/src/hooks/useMainStore.ts
ADDED
@@ -0,0 +1,58 @@
import { create } from 'zustand'
import type { ClosestLandmark } from './useFaceLandmarkDetection'
import type { ImageModificationParams } from '@/lib/facePoke'

interface ImageState {
  error: string
  imageFile: File | null
  originalImage: string
  previewImage: string
  originalImageHash: string
  minLatency: number
  averageLatency: number
  maxLatency: number
  activeLandmark?: ClosestLandmark
  params: Partial<ImageModificationParams>
  setError: (error?: string) => void
  setImageFile: (file: File | null) => void
  setOriginalImage: (url: string) => void
  setOriginalImageHash: (hash: string) => void
  setPreviewImage: (url: string) => void
  resetImage: () => void
  setAverageLatency: (averageLatency: number) => void
  setActiveLandmark: (activeLandmark?: ClosestLandmark) => void
  setParams: (params: Partial<ImageModificationParams>) => void
}

export const useMainStore = create<ImageState>((set, get) => ({
  error: '',
  imageFile: null,
  originalImage: '',
  originalImageHash: '',
  previewImage: '',
  minLatency: 20, // min time between requests
  averageLatency: 190, // this should be the average for most people
  maxLatency: 4000, // max time between requests
  activeLandmark: undefined,
  params: {},
  setError: (error: string = '') => set({ error }),
  setImageFile: (file) => set({ imageFile: file }),
  setOriginalImage: (url) => set({ originalImage: url }),
  setOriginalImageHash: (originalImageHash) => set({ originalImageHash }),
  setPreviewImage: (url) => set({ previewImage: url }),
  resetImage: () => {
    const { originalImage } = get()
    if (originalImage) {
      set({ previewImage: originalImage })
    }
  },
  setAverageLatency: (averageLatency: number) => set({ averageLatency }),
  setActiveLandmark: (activeLandmark?: ClosestLandmark) => set({ activeLandmark }),
  setParams: (params: Partial<ImageModificationParams>) => {
    const { params: previousParams } = get()
    set({ params: {
      ...previousParams,
      ...params
    }})
  },
}))
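A quick sketch of how this zustand store is consumed: components subscribe through the hook with a selector, while non-React code (such as the FacePoke class below) reads and writes it through useMainStore.getState(). The concrete values are illustrative.

// Sketch only: store access from inside and outside React.
import { useMainStore } from '@/hooks/useMainStore';

function ParamsPanelSketch() {
  const previewImage = useMainStore(s => s.previewImage);
  const setParams = useMainStore(s => s.setParams);
  // setParams shallow-merges into the existing params object
  return <img src={previewImage} onClick={() => setParams({ smile: 0.8 })} />;
}

// Outside React (this is what facePoke.ts does): read and write through getState().
const { averageLatency, setAverageLatency } = useMainStore.getState();
setAverageLatency(Math.max(20, Math.min(4000, averageLatency)));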
client/src/index.tsx
ADDED
@@ -0,0 +1,6 @@
import { createRoot } from 'react-dom/client';

import { App } from './app';

const root = createRoot(document.getElementById('root')!);
root.render(<App />);
client/src/layout.tsx
ADDED
@@ -0,0 +1,14 @@
import React, { type ReactNode } from 'react';

export function Layout({ children }: { children: ReactNode }) {
  return (
    <div className="fixed min-h-screen w-full flex items-center justify-center bg-gradient-to-br from-gray-300 to-stone-300"
      style={{ boxShadow: "inset 0 0 10vh 0 rgb(0 0 0 / 30%)" }}>
      <div className="min-h-screen w-full py-8 flex flex-col justify-center">
        <div className="relative p-4 sm:max-w-5xl sm:mx-auto">
          {children}
        </div>
      </div>
    </div>
  );
}
client/src/lib/circularBuffer.ts
ADDED
@@ -0,0 +1,31 @@
/**
 * Circular buffer for storing and managing response times.
 */
export class CircularBuffer<T> {
  private buffer: T[];
  private pointer: number;

  constructor(private capacity: number) {
    this.buffer = new Array<T>(capacity);
    this.pointer = 0;
  }

  /**
   * Adds an item to the buffer, overwriting the oldest item if full.
   * @param item - The item to add to the buffer.
   */
  push(item: T): void {
    this.buffer[this.pointer] = item;
    this.pointer = (this.pointer + 1) % this.capacity;
  }

  /**
   * Retrieves all items currently in the buffer.
   * @returns An array of all items in the buffer.
   */
  getAll(): T[] {
    return this.buffer.filter(item => item !== undefined);
  }
}
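Usage sketch: once the buffer wraps, the oldest slot is overwritten in place, so getAll() returns items in slot order rather than insertion order. The values below are illustrative.

// Sketch only:
import { CircularBuffer } from '@/lib/circularBuffer';

const lastTimes = new CircularBuffer<number>(3);
[120, 150, 90, 200].forEach(ms => lastTimes.push(ms));
console.log(lastTimes.getAll()); // [200, 150, 90] — the 120 was overwritten by 200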
client/src/lib/convertImageToBase64.ts
ADDED
@@ -0,0 +1,19 @@
export async function convertImageToBase64(imageFile: File): Promise<string> {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();

    reader.onload = () => {
      if (typeof reader.result === 'string') {
        resolve(reader.result);
      } else {
        reject(new Error('Failed to convert image to base64'));
      }
    };

    reader.onerror = () => {
      reject(new Error('Error reading file'));
    };

    reader.readAsDataURL(imageFile);
  });
}
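A small usage sketch, assuming this helper is called from a file-input handler; the handler name is made up, only convertImageToBase64 and the store setters exist above.

// Sketch only: convert a picked file to a data URI and push it into the store.
import { convertImageToBase64 } from '@/lib/convertImageToBase64';
import { useMainStore } from '@/hooks/useMainStore';

async function onFileSelectedSketch(file: File) {
  const dataUri = await convertImageToBase64(file); // "data:image/...;base64,..."
  useMainStore.getState().setOriginalImage(dataUri);
  useMainStore.getState().setPreviewImage(dataUri);
}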
client/src/lib/facePoke.ts
ADDED
@@ -0,0 +1,398 @@
1 |
+
import { v4 as uuidv4 } from 'uuid';
|
2 |
+
import { CircularBuffer } from './circularBuffer';
|
3 |
+
import { useMainStore } from '@/hooks/useMainStore';
|
4 |
+
|
5 |
+
/**
|
6 |
+
* Represents a tracked request with its UUID and timestamp.
|
7 |
+
*/
|
8 |
+
export interface TrackedRequest {
|
9 |
+
uuid: string;
|
10 |
+
timestamp: number;
|
11 |
+
}
|
12 |
+
|
13 |
+
/**
|
14 |
+
* Represents the parameters for image modification.
|
15 |
+
*/
|
16 |
+
export interface ImageModificationParams {
|
17 |
+
eyes: number;
|
18 |
+
eyebrow: number;
|
19 |
+
wink: number;
|
20 |
+
pupil_x: number;
|
21 |
+
pupil_y: number;
|
22 |
+
aaa: number;
|
23 |
+
eee: number;
|
24 |
+
woo: number;
|
25 |
+
smile: number;
|
26 |
+
rotate_pitch: number;
|
27 |
+
rotate_yaw: number;
|
28 |
+
rotate_roll: number;
|
29 |
+
}
|
30 |
+
|
31 |
+
/**
|
32 |
+
* Represents a message to modify an image.
|
33 |
+
*/
|
34 |
+
export interface ModifyImageMessage {
|
35 |
+
type: 'modify_image';
|
36 |
+
image?: string;
|
37 |
+
image_hash?: string;
|
38 |
+
params: Partial<ImageModificationParams>;
|
39 |
+
}
|
40 |
+
|
41 |
+
|
42 |
+
/**
|
43 |
+
* Callback type for handling modified images.
|
44 |
+
*/
|
45 |
+
type OnModifiedImage = (image: string, image_hash: string) => void;
|
46 |
+
|
47 |
+
/**
|
48 |
+
* Enum representing the different states of a WebSocket connection.
|
49 |
+
*/
|
50 |
+
enum WebSocketState {
|
51 |
+
CONNECTING = 0,
|
52 |
+
OPEN = 1,
|
53 |
+
CLOSING = 2,
|
54 |
+
CLOSED = 3
|
55 |
+
}
|
56 |
+
|
57 |
+
/**
|
58 |
+
* FacePoke class manages the WebSocket connection
|
59 |
+
*/
|
60 |
+
export class FacePoke {
|
61 |
+
private ws: WebSocket | null = null;
|
62 |
+
private readonly connectionId: string = uuidv4();
|
63 |
+
private isUnloading: boolean = false;
|
64 |
+
private onModifiedImage: OnModifiedImage = () => {};
|
65 |
+
private reconnectAttempts: number = 0;
|
66 |
+
private readonly maxReconnectAttempts: number = 5;
|
67 |
+
private readonly reconnectDelay: number = 5000;
|
68 |
+
private readonly eventListeners: Map<string, Set<Function>> = new Map();
|
69 |
+
|
70 |
+
private requestTracker: Map<string, TrackedRequest> = new Map();
|
71 |
+
private responseTimeBuffer: CircularBuffer<number>;
|
72 |
+
private readonly MAX_TRACKED_TIMES = 5; // Number of recent response times to track
|
73 |
+
|
74 |
+
/**
|
75 |
+
* Creates an instance of FacePoke.
|
76 |
+
* Initializes the WebSocket connection.
|
77 |
+
*/
|
78 |
+
constructor() {
|
79 |
+
console.log(`[FacePoke] Initializing FacePoke instance with connection ID: ${this.connectionId}`);
|
80 |
+
this.initializeWebSocket();
|
81 |
+
this.setupUnloadHandler();
|
82 |
+
|
83 |
+
this.responseTimeBuffer = new CircularBuffer<number>(this.MAX_TRACKED_TIMES);
|
84 |
+
console.log(`[FacePoke] Initialized response time tracker with capacity: ${this.MAX_TRACKED_TIMES}`);
|
85 |
+
}
|
86 |
+
|
87 |
+
|
88 |
+
/**
|
89 |
+
* Generates a unique UUID for a request and starts tracking it.
|
90 |
+
* @returns The generated UUID for the request.
|
91 |
+
*/
|
92 |
+
private trackRequest(): string {
|
93 |
+
const uuid = uuidv4();
|
94 |
+
this.requestTracker.set(uuid, { uuid, timestamp: Date.now() });
|
95 |
+
// console.log(`[FacePoke] Started tracking request with UUID: ${uuid}`);
|
96 |
+
return uuid;
|
97 |
+
}
|
98 |
+
|
99 |
+
/**
|
100 |
+
* Completes tracking for a request and updates response time statistics.
|
101 |
+
* @param uuid - The UUID of the completed request.
|
102 |
+
*/
|
103 |
+
private completeRequest(uuid: string): void {
|
104 |
+
const request = this.requestTracker.get(uuid);
|
105 |
+
if (request) {
|
106 |
+
const responseTime = Date.now() - request.timestamp;
|
107 |
+
this.responseTimeBuffer.push(responseTime);
|
108 |
+
this.requestTracker.delete(uuid);
|
109 |
+
this.updateThrottleTime();
|
110 |
+
console.log(`[FacePoke] Completed request ${uuid}. Response time: ${responseTime}ms`);
|
111 |
+
} else {
|
112 |
+
console.warn(`[FacePoke] Attempted to complete unknown request: ${uuid}`);
|
113 |
+
}
|
114 |
+
}
|
115 |
+
|
116 |
+
/**
|
117 |
+
* Calculates the average response time from recent requests.
|
118 |
+
* @returns The average response time in milliseconds.
|
119 |
+
*/
|
120 |
+
private calculateAverageResponseTime(): number {
|
121 |
+
const times = this.responseTimeBuffer.getAll();
|
122 |
+
|
123 |
+
const averageLatency = useMainStore.getState().averageLatency;
|
124 |
+
|
125 |
+
if (times.length === 0) return averageLatency;
|
126 |
+
const sum = times.reduce((acc, time) => acc + time, 0);
|
127 |
+
return sum / times.length;
|
128 |
+
}
|
129 |
+
|
130 |
+
/**
|
131 |
+
* Updates the throttle time based on recent response times.
|
132 |
+
*/
|
133 |
+
private updateThrottleTime(): void {
|
134 |
+
const { minLatency, maxLatency, averageLatency, setAverageLatency } = useMainStore.getState();
|
135 |
+
const avgResponseTime = this.calculateAverageResponseTime();
|
136 |
+
const newLatency = Math.min(maxLatency, Math.max(minLatency, avgResponseTime)); // clamp between minLatency and maxLatency
|
137 |
+
|
138 |
+
if (newLatency !== averageLatency) {
|
139 |
+
setAverageLatency(newLatency)
|
140 |
+
console.log(`[FacePoke] Updated throttle time (latency is ${newLatency}ms)`);
|
141 |
+
}
|
142 |
+
}
|
143 |
+
|
144 |
+
/**
|
145 |
+
* Sets the callback function for handling modified images.
|
146 |
+
* @param handler - The function to be called when a modified image is received.
|
147 |
+
*/
|
148 |
+
public setOnModifiedImage(handler: OnModifiedImage): void {
|
149 |
+
this.onModifiedImage = handler;
|
150 |
+
console.log(`[FacePoke] onModifiedImage handler set`);
|
151 |
+
}
|
152 |
+
|
153 |
+
/**
|
154 |
+
* Starts or restarts the WebSocket connection.
|
155 |
+
*/
|
156 |
+
public async startWebSocket(): Promise<void> {
|
157 |
+
console.log(`[FacePoke] Starting WebSocket connection.`);
|
158 |
+
if (!this.ws || this.ws.readyState !== WebSocketState.OPEN) {
|
159 |
+
await this.initializeWebSocket();
|
160 |
+
}
|
161 |
+
}
|
162 |
+
|
163 |
+
/**
|
164 |
+
* Initializes the WebSocket connection.
|
165 |
+
* Implements exponential backoff for reconnection attempts.
|
166 |
+
*/
|
167 |
+
private async initializeWebSocket(): Promise<void> {
|
168 |
+
console.log(`[FacePoke][${this.connectionId}] Initializing WebSocket connection`);
|
169 |
+
|
170 |
+
const connect = () => {
|
171 |
+
this.ws = new WebSocket(`wss://${window.location.host}/ws`);
|
172 |
+
|
173 |
+
this.ws.onopen = this.handleWebSocketOpen.bind(this);
|
174 |
+
this.ws.onmessage = this.handleWebSocketMessage.bind(this);
|
175 |
+
this.ws.onclose = this.handleWebSocketClose.bind(this);
|
176 |
+
this.ws.onerror = this.handleWebSocketError.bind(this);
|
177 |
+
};
|
178 |
+
|
179 |
+
// const debouncedConnect = debounce(connect, this.reconnectDelay, { leading: true, trailing: false });
|
180 |
+
|
181 |
+
connect(); // Initial connection attempt
|
182 |
+
}
|
183 |
+
|
184 |
+
/**
|
185 |
+
* Handles the WebSocket open event.
|
186 |
+
*/
|
187 |
+
private handleWebSocketOpen(): void {
|
188 |
+
console.log(`[FacePoke][${this.connectionId}] WebSocket connection opened`);
|
189 |
+
this.reconnectAttempts = 0; // Reset reconnect attempts on successful connection
|
190 |
+
this.emitEvent('websocketOpen');
|
191 |
+
}
|
192 |
+
|
193 |
+
// Update handleWebSocketMessage to complete request tracking
|
194 |
+
private handleWebSocketMessage(event: MessageEvent): void {
|
195 |
+
try {
|
196 |
+
const data = JSON.parse(event.data);
|
197 |
+
// console.log(`[FacePoke][${this.connectionId}] Received JSON data:`, data);
|
198 |
+
|
199 |
+
if (data.uuid) {
|
200 |
+
this.completeRequest(data.uuid);
|
201 |
+
}
|
202 |
+
|
203 |
+
if (data.type === 'modified_image') {
|
204 |
+
if (data?.image) {
|
205 |
+
this.onModifiedImage(data.image, data.image_hash);
|
206 |
+
}
|
207 |
+
}
|
208 |
+
|
209 |
+
this.emitEvent('message', data);
|
210 |
+
} catch (error) {
|
211 |
+
console.error(`[FacePoke][${this.connectionId}] Error parsing WebSocket message:`, error);
|
212 |
+
}
|
213 |
+
}
|
214 |
+
|
215 |
+
/**
|
216 |
+
* Handles WebSocket close events.
|
217 |
+
* Implements reconnection logic with exponential backoff.
|
218 |
+
* @param event - The CloseEvent containing close information.
|
219 |
+
*/
|
220 |
+
private handleWebSocketClose(event: CloseEvent): void {
|
221 |
+
if (event.wasClean) {
|
222 |
+
console.log(`[FacePoke][${this.connectionId}] WebSocket connection closed cleanly, code=${event.code}, reason=${event.reason}`);
|
223 |
+
} else {
|
224 |
+
console.warn(`[FacePoke][${this.connectionId}] WebSocket connection abruptly closed`);
|
225 |
+
}
|
226 |
+
|
227 |
+
this.emitEvent('websocketClose', event);
|
228 |
+
|
229 |
+
// Attempt to reconnect after a delay, unless the page is unloading or max attempts reached
|
230 |
+
if (!this.isUnloading && this.reconnectAttempts < this.maxReconnectAttempts) {
|
231 |
+
this.reconnectAttempts++;
|
232 |
+
const delay = Math.min(1000 * (2 ** this.reconnectAttempts), 30000); // Exponential backoff, max 30 seconds
|
233 |
+
console.log(`[FacePoke][${this.connectionId}] Attempting to reconnect in ${delay}ms (Attempt ${this.reconnectAttempts}/${this.maxReconnectAttempts})...`);
|
234 |
+
setTimeout(() => this.initializeWebSocket(), delay);
|
235 |
+
} else if (this.reconnectAttempts >= this.maxReconnectAttempts) {
|
236 |
+
console.error(`[FacePoke][${this.connectionId}] Max reconnect attempts reached. Please refresh the page.`);
|
237 |
+
this.emitEvent('maxReconnectAttemptsReached');
|
238 |
+
}
|
239 |
+
}
|
240 |
+
|
241 |
+
/**
|
242 |
+
* Handles WebSocket errors.
|
243 |
+
* @param error - The error event.
|
244 |
+
*/
|
245 |
+
private handleWebSocketError(error: Event): void {
|
246 |
+
console.error(`[FacePoke][${this.connectionId}] WebSocket error:`, error);
|
247 |
+
this.emitEvent('websocketError', error);
|
248 |
+
}
|
249 |
+
|
250 |
+
/**
|
251 |
+
* Handles interruption messages from the server.
|
252 |
+
* @param message - The interruption message.
|
253 |
+
*/
|
254 |
+
private handleInterruption(message: string): void {
|
255 |
+
console.warn(`[FacePoke] Interruption: ${message}`);
|
256 |
+
this.emitEvent('interruption', message);
|
257 |
+
}
|
258 |
+
|
259 |
+
/**
|
260 |
+
* Toggles the microphone on or off.
|
261 |
+
* @param isOn - Whether to turn the microphone on (true) or off (false).
|
262 |
+
*/
|
263 |
+
public async toggleMicrophone(isOn: boolean): Promise<void> {
|
264 |
+
console.log(`[FacePoke] Attempting to ${isOn ? 'start' : 'stop'} microphone`);
|
265 |
+
try {
|
266 |
+
if (isOn) {
|
267 |
+
await this.startMicrophone();
|
268 |
+
} else {
|
269 |
+
this.stopMicrophone();
|
270 |
+
}
|
271 |
+
this.emitEvent('microphoneToggled', isOn);
|
272 |
+
} catch (error) {
|
273 |
+
console.error(`[FacePoke] Error toggling microphone:`, error);
|
274 |
+
this.emitEvent('microphoneError', error);
|
275 |
+
throw error;
|
276 |
+
}
|
277 |
+
}
|
278 |
+
|
279 |
+
|
280 |
+
/**
|
281 |
+
* Cleans up resources and closes connections.
|
282 |
+
*/
|
283 |
+
public cleanup(): void {
|
284 |
+
console.log('[FacePoke] Starting cleanup process');
|
285 |
+
if (this.ws) {
|
286 |
+
this.ws.close();
|
287 |
+
this.ws = null;
|
288 |
+
}
|
289 |
+
this.eventListeners.clear();
|
290 |
+
console.log('[FacePoke] Cleanup completed');
|
291 |
+
this.emitEvent('cleanup');
|
292 |
+
}
|
293 |
+
|
294 |
+
/**
|
295 |
+
* Modifies an image based on the provided parameters
|
296 |
+
* @param image - The data-uri base64 image to modify.
|
297 |
+
* @param imageHash - The hash of the image to modify.
|
298 |
+
* @param params - The parameters for image modification.
|
299 |
+
*/
|
300 |
+
public modifyImage(image: string | null, imageHash: string | null, params: Partial<ImageModificationParams>): void {
|
301 |
+
try {
|
302 |
+
const message: ModifyImageMessage = {
|
303 |
+
type: 'modify_image',
|
304 |
+
params: params
|
305 |
+
};
|
306 |
+
|
307 |
+
if (image) {
|
308 |
+
message.image = image;
|
309 |
+
} else if (imageHash) {
|
310 |
+
message.image_hash = imageHash;
|
311 |
+
} else {
|
312 |
+
throw new Error('Either image or imageHash must be provided');
|
313 |
+
}
|
314 |
+
|
315 |
+
this.sendJsonMessage(message);
|
316 |
+
// console.log(`[FacePoke] Sent modify image request with UUID: ${uuid}`);
|
317 |
+
} catch (err) {
|
318 |
+
console.error(`[FacePoke] Failed to modify the image:`, err);
|
319 |
+
}
|
320 |
+
}
|
321 |
+
|
322 |
+
/**
|
323 |
+
* Sends a JSON message through the WebSocket connection with request tracking.
|
324 |
+
* @param message - The message to send.
|
325 |
+
* @throws Error if the WebSocket is not open.
|
326 |
+
*/
|
327 |
+
private sendJsonMessage<T>(message: T): void {
|
328 |
+
if (!this.ws || this.ws.readyState !== WebSocketState.OPEN) {
|
329 |
+
const error = new Error('WebSocket connection is not open');
|
330 |
+
console.error('[FacePoke] Error sending JSON message:', error);
|
331 |
+
this.emitEvent('sendJsonMessageError', error);
|
332 |
+
throw error;
|
333 |
+
}
|
334 |
+
|
335 |
+
const uuid = this.trackRequest();
|
336 |
+
const messageWithUuid = { ...message, uuid };
|
337 |
+
// console.log(`[FacePoke] Sending JSON message with UUID ${uuid}:`, messageWithUuid);
|
338 |
+
this.ws.send(JSON.stringify(messageWithUuid));
|
339 |
+
}
|
340 |
+
|
341 |
+
/**
|
342 |
+
* Sets up the unload handler to clean up resources when the page is unloading.
|
343 |
+
*/
|
344 |
+
private setupUnloadHandler(): void {
|
345 |
+
window.addEventListener('beforeunload', () => {
|
346 |
+
console.log('[FacePoke] Page is unloading, cleaning up resources');
|
347 |
+
this.isUnloading = true;
|
348 |
+
if (this.ws) {
|
349 |
+
this.ws.close(1000, 'Page is unloading');
|
350 |
+
}
|
351 |
+
this.cleanup();
|
352 |
+
});
|
353 |
+
}
|
354 |
+
|
355 |
+
/**
|
356 |
+
* Adds an event listener for a specific event type.
|
357 |
+
* @param eventType - The type of event to listen for.
|
358 |
+
* @param listener - The function to be called when the event is emitted.
|
359 |
+
*/
|
360 |
+
public addEventListener(eventType: string, listener: Function): void {
|
361 |
+
if (!this.eventListeners.has(eventType)) {
|
362 |
+
this.eventListeners.set(eventType, new Set());
|
363 |
+
}
|
364 |
+
this.eventListeners.get(eventType)!.add(listener);
|
365 |
+
console.log(`[FacePoke] Added event listener for '${eventType}'`);
|
366 |
+
}
|
367 |
+
|
368 |
+
/**
|
369 |
+
* Removes an event listener for a specific event type.
|
370 |
+
* @param eventType - The type of event to remove the listener from.
|
371 |
+
* @param listener - The function to be removed from the listeners.
|
372 |
+
*/
|
373 |
+
public removeEventListener(eventType: string, listener: Function): void {
|
374 |
+
const listeners = this.eventListeners.get(eventType);
|
375 |
+
if (listeners) {
|
376 |
+
listeners.delete(listener);
|
377 |
+
console.log(`[FacePoke] Removed event listener for '${eventType}'`);
|
378 |
+
}
|
379 |
+
}
|
380 |
+
|
381 |
+
/**
|
382 |
+
* Emits an event to all registered listeners for that event type.
|
383 |
+
* @param eventType - The type of event to emit.
|
384 |
+
* @param data - Optional data to pass to the event listeners.
|
385 |
+
*/
|
386 |
+
private emitEvent(eventType: string, data?: any): void {
|
387 |
+
const listeners = this.eventListeners.get(eventType);
|
388 |
+
if (listeners) {
|
389 |
+
console.log(`[FacePoke] Emitting event '${eventType}' with data:`, data);
|
390 |
+
listeners.forEach(listener => listener(data));
|
391 |
+
}
|
392 |
+
}
|
393 |
+
}
|
394 |
+
|
395 |
+
/**
|
396 |
+
* Singleton instance of the FacePoke class.
|
397 |
+
*/
|
398 |
+
export const facePoke = new FacePoke();
|
client/src/lib/throttle.ts
ADDED
@@ -0,0 +1,32 @@
/**
 * Custom throttle function that allows the first call to go through immediately
 * and then limits subsequent calls.
 * @param func - The function to throttle.
 * @param limit - The minimum time between function calls in milliseconds.
 * @returns A throttled version of the function.
 */
export function throttle<T extends (...args: any[]) => any>(func: T, limit: number): T {
  let lastCall = 0;
  let timeoutId: NodeJS.Timer | null = null;

  return function (this: any, ...args: Parameters<T>) {
    const context = this;
    const now = Date.now();

    if (now - lastCall >= limit) {
      if (timeoutId !== null) {
        clearTimeout(timeoutId);
        timeoutId = null;
      }
      lastCall = now;
      return func.apply(context, args);
    } else if (!timeoutId) {
      timeoutId = setTimeout(() => {
        lastCall = Date.now();
        timeoutId = null;
        func.apply(context, args);
      }, limit - (now - lastCall));
    }
  } as T;
}
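Usage sketch: with this implementation the first call runs immediately, and calls made inside the limit window collapse into a single trailing call. The callback and values are illustrative.

// Sketch only:
import { throttle } from '@/lib/throttle';

const logMove = throttle((x: number, y: number) => {
  console.log(`pointer at ${x},${y}`);
}, 200);

logMove(1, 1); // runs now
logMove(2, 2); // not run now; scheduled to fire ~200ms after the first call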
client/src/lib/utils.ts
ADDED
@@ -0,0 +1,15 @@
import { clsx, type ClassValue } from "clsx"
import { twMerge } from "tailwind-merge"

export function cn(...inputs: ClassValue[]) {
  return twMerge(clsx(inputs))
}

export function truncateFileName(fileName: string, maxLength: number = 16) {
  if (fileName.length <= maxLength) return fileName;

  const start = fileName.slice(0, maxLength / 2 - 1);
  const end = fileName.slice(-maxLength / 2 + 2);

  return `${start}...${end}`;
};
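Usage sketch for both helpers, with illustrative inputs.

// Sketch only:
import { cn, truncateFileName } from '@/lib/utils';

const isActive = true;
cn('px-2 py-1', isActive && 'bg-primary', 'px-4');    // tailwind-merge keeps the later "px-4"
truncateFileName('a-very-long-picture-filename.png'); // => "a-very-...me.png" with the default maxLength of 16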
client/src/styles/globals.css
ADDED
@@ -0,0 +1,81 @@
1 |
+
@tailwind base;
|
2 |
+
@tailwind components;
|
3 |
+
@tailwind utilities;
|
4 |
+
|
5 |
+
@layer base {
|
6 |
+
:root {
|
7 |
+
--background: 0 0% 100%;
|
8 |
+
--foreground: 222.2 47.4% 11.2%;
|
9 |
+
|
10 |
+
--muted: 210 40% 96.1%;
|
11 |
+
--muted-foreground: 215.4 16.3% 46.9%;
|
12 |
+
|
13 |
+
--popover: 0 0% 100%;
|
14 |
+
--popover-foreground: 222.2 47.4% 11.2%;
|
15 |
+
|
16 |
+
--border: 214.3 31.8% 91.4%;
|
17 |
+
--input: 214.3 31.8% 91.4%;
|
18 |
+
|
19 |
+
--card: 0 0% 100%;
|
20 |
+
--card-foreground: 222.2 47.4% 11.2%;
|
21 |
+
|
22 |
+
--primary: 222.2 47.4% 11.2%;
|
23 |
+
--primary-foreground: 210 40% 98%;
|
24 |
+
|
25 |
+
--secondary: 210 40% 96.1%;
|
26 |
+
--secondary-foreground: 222.2 47.4% 11.2%;
|
27 |
+
|
28 |
+
--accent: 210 40% 96.1%;
|
29 |
+
--accent-foreground: 222.2 47.4% 11.2%;
|
30 |
+
|
31 |
+
--destructive: 0 100% 50%;
|
32 |
+
--destructive-foreground: 210 40% 98%;
|
33 |
+
|
34 |
+
--ring: 215 20.2% 65.1%;
|
35 |
+
|
36 |
+
--radius: 0.5rem;
|
37 |
+
}
|
38 |
+
|
39 |
+
.dark {
|
40 |
+
--background: 224 71% 4%;
|
41 |
+
--foreground: 213 31% 91%;
|
42 |
+
|
43 |
+
--muted: 223 47% 11%;
|
44 |
+
--muted-foreground: 215.4 16.3% 56.9%;
|
45 |
+
|
46 |
+
--accent: 216 34% 17%;
|
47 |
+
--accent-foreground: 210 40% 98%;
|
48 |
+
|
49 |
+
--popover: 224 71% 4%;
|
50 |
+
--popover-foreground: 215 20.2% 65.1%;
|
51 |
+
|
52 |
+
--border: 216 34% 17%;
|
53 |
+
--input: 216 34% 17%;
|
54 |
+
|
55 |
+
--card: 224 71% 4%;
|
56 |
+
--card-foreground: 213 31% 91%;
|
57 |
+
|
58 |
+
--primary: 210 40% 98%;
|
59 |
+
--primary-foreground: 222.2 47.4% 1.2%;
|
60 |
+
|
61 |
+
--secondary: 222.2 47.4% 11.2%;
|
62 |
+
--secondary-foreground: 210 40% 98%;
|
63 |
+
|
64 |
+
--destructive: 0 63% 31%;
|
65 |
+
--destructive-foreground: 210 40% 98%;
|
66 |
+
|
67 |
+
--ring: 216 34% 17%;
|
68 |
+
|
69 |
+
--radius: 0.5rem;
|
70 |
+
}
|
71 |
+
}
|
72 |
+
|
73 |
+
@layer base {
|
74 |
+
* {
|
75 |
+
@apply border-border;
|
76 |
+
}
|
77 |
+
body {
|
78 |
+
@apply bg-background text-foreground;
|
79 |
+
font-feature-settings: "rlig" 1, "calt" 1;
|
80 |
+
}
|
81 |
+
}
|
client/tailwind.config.js
ADDED
@@ -0,0 +1,86 @@
1 |
+
const { fontFamily } = require("tailwindcss/defaultTheme")
|
2 |
+
|
3 |
+
/** @type {import('tailwindcss').Config} */
|
4 |
+
module.exports = {
|
5 |
+
darkMode: ["class"],
|
6 |
+
content: [
|
7 |
+
"app/**/*.{ts,tsx}",
|
8 |
+
"components/**/*.{ts,tsx}",
|
9 |
+
'../public/index.html'
|
10 |
+
],
|
11 |
+
theme: {
|
12 |
+
container: {
|
13 |
+
center: true,
|
14 |
+
padding: "2rem",
|
15 |
+
screens: {
|
16 |
+
"2xl": "1400px",
|
17 |
+
},
|
18 |
+
},
|
19 |
+
extend: {
|
20 |
+
colors: {
|
21 |
+
border: "hsl(var(--border))",
|
22 |
+
input: "hsl(var(--input))",
|
23 |
+
ring: "hsl(var(--ring))",
|
24 |
+
background: "hsl(var(--background))",
|
25 |
+
foreground: "hsl(var(--foreground))",
|
26 |
+
primary: {
|
27 |
+
DEFAULT: "hsl(var(--primary))",
|
28 |
+
foreground: "hsl(var(--primary-foreground))",
|
29 |
+
},
|
30 |
+
secondary: {
|
31 |
+
DEFAULT: "hsl(var(--secondary))",
|
32 |
+
foreground: "hsl(var(--secondary-foreground))",
|
33 |
+
},
|
34 |
+
destructive: {
|
35 |
+
DEFAULT: "hsl(var(--destructive))",
|
36 |
+
foreground: "hsl(var(--destructive-foreground))",
|
37 |
+
},
|
38 |
+
muted: {
|
39 |
+
DEFAULT: "hsl(var(--muted))",
|
40 |
+
foreground: "hsl(var(--muted-foreground))",
|
41 |
+
},
|
42 |
+
accent: {
|
43 |
+
DEFAULT: "hsl(var(--accent))",
|
44 |
+
foreground: "hsl(var(--accent-foreground))",
|
45 |
+
},
|
46 |
+
popover: {
|
47 |
+
DEFAULT: "hsl(var(--popover))",
|
48 |
+
foreground: "hsl(var(--popover-foreground))",
|
49 |
+
},
|
50 |
+
card: {
|
51 |
+
DEFAULT: "hsl(var(--card))",
|
52 |
+
foreground: "hsl(var(--card-foreground))",
|
53 |
+
},
|
54 |
+
},
|
55 |
+
borderRadius: {
|
56 |
+
lg: `var(--radius)`,
|
57 |
+
md: `calc(var(--radius) - 2px)`,
|
58 |
+
sm: "calc(var(--radius) - 4px)",
|
59 |
+
},
|
60 |
+
fontFamily: {
|
61 |
+
sans: ["var(--font-sans)", ...fontFamily.sans],
|
62 |
+
},
|
63 |
+
fontSize: {
|
64 |
+
"5xs": "8px",
|
65 |
+
"4xs": "9px",
|
66 |
+
"3xs": "10px",
|
67 |
+
"2xs": "11px"
|
68 |
+
},
|
69 |
+
keyframes: {
|
70 |
+
"accordion-down": {
|
71 |
+
from: { height: "0" },
|
72 |
+
to: { height: "var(--radix-accordion-content-height)" },
|
73 |
+
},
|
74 |
+
"accordion-up": {
|
75 |
+
from: { height: "var(--radix-accordion-content-height)" },
|
76 |
+
to: { height: "0" },
|
77 |
+
},
|
78 |
+
},
|
79 |
+
animation: {
|
80 |
+
"accordion-down": "accordion-down 0.2s ease-out",
|
81 |
+
"accordion-up": "accordion-up 0.2s ease-out",
|
82 |
+
},
|
83 |
+
},
|
84 |
+
},
|
85 |
+
plugins: [require("tailwindcss-animate")],
|
86 |
+
}
|
client/tsconfig.json
ADDED
@@ -0,0 +1,32 @@
{
  "compilerOptions": {
    // Enable latest features
    "lib": ["ESNext", "DOM", "DOM.Iterable"],
    "target": "ESNext",
    "module": "ESNext",
    "moduleDetection": "force",
    "jsx": "react-jsx",
    "allowJs": true,

    // Bundler mode
    "moduleResolution": "bundler",
    "allowImportingTsExtensions": true,
    "verbatimModuleSyntax": true,
    "noEmit": true,

    "baseUrl": ".",
    "paths": {
      "@/*": ["./src/*"]
    },

    // Best practices
    "strict": true,
    "skipLibCheck": true,
    "noFallthroughCasesInSwitch": true,

    // Some stricter flags (disabled by default)
    "noUnusedLocals": false,
    "noUnusedParameters": false,
    "noPropertyAccessFromIndexSignature": false
  }
}
engine.py
ADDED
@@ -0,0 +1,300 @@
1 |
+
import time
|
2 |
+
import logging
|
3 |
+
import hashlib
|
4 |
+
import uuid
|
5 |
+
import os
|
6 |
+
import io
|
7 |
+
import shutil
|
8 |
+
import asyncio
|
9 |
+
import base64
|
10 |
+
from concurrent.futures import ThreadPoolExecutor
|
11 |
+
from queue import Queue
|
12 |
+
from typing import Dict, Any, List, Optional, AsyncGenerator, Tuple, Union
|
13 |
+
from functools import lru_cache
|
14 |
+
import av
|
15 |
+
import numpy as np
|
16 |
+
import cv2
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from PIL import Image
|
20 |
+
|
21 |
+
from liveportrait.config.argument_config import ArgumentConfig
|
22 |
+
from liveportrait.utils.camera import get_rotation_matrix
|
23 |
+
from liveportrait.utils.io import load_image_rgb, load_driving_info, resize_to_limit
|
24 |
+
from liveportrait.utils.crop import prepare_paste_back, paste_back
|
25 |
+
|
26 |
+
# Configure logging
|
27 |
+
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
|
30 |
+
# Global constants
|
31 |
+
DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
|
32 |
+
MODELS_DIR = os.path.join(DATA_ROOT, "models")
|
33 |
+
|
34 |
+
def base64_data_uri_to_PIL_Image(base64_string: str) -> Image.Image:
|
35 |
+
"""
|
36 |
+
Convert a base64 data URI to a PIL Image.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
base64_string (str): The base64 encoded image data.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
Image.Image: The decoded PIL Image.
|
43 |
+
"""
|
44 |
+
if ',' in base64_string:
|
45 |
+
base64_string = base64_string.split(',')[1]
|
46 |
+
img_data = base64.b64decode(base64_string)
|
47 |
+
return Image.open(io.BytesIO(img_data))
|
48 |
+
|
49 |
+
class Engine:
|
50 |
+
"""
|
51 |
+
The main engine class for FacePoke
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(self, live_portrait):
|
55 |
+
"""
|
56 |
+
Initialize the FacePoke engine with necessary models and processors.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
live_portrait (LivePortraitPipeline): The LivePortrait model for video generation.
|
60 |
+
"""
|
61 |
+
self.live_portrait = live_portrait
|
62 |
+
|
63 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
64 |
+
|
65 |
+
# cache for the "modify image" workflow
|
66 |
+
self.image_cache = {} # Stores the original images
|
67 |
+
self.processed_cache = {} # Stores the processed image data
|
68 |
+
|
69 |
+
logger.info("✅ FacePoke Engine initialized successfully.")
|
70 |
+
|
71 |
+
def get_image_hash(self, image: Union[Image.Image, str, bytes]) -> str:
|
72 |
+
"""
|
73 |
+
Compute or retrieve the hash for an image.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
image (Union[Image.Image, str, bytes]): The input image, either as a PIL Image,
|
77 |
+
base64 string, or bytes.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
str: The computed hash of the image.
|
81 |
+
"""
|
82 |
+
if isinstance(image, str):
|
83 |
+
# Assume it's already a hash if it's a string of the right length
|
84 |
+
if len(image) == 32:
|
85 |
+
return image
|
86 |
+
# Otherwise, assume it's a base64 string
|
87 |
+
image = base64_data_uri_to_PIL_Image(image)
|
88 |
+
|
89 |
+
if isinstance(image, Image.Image):
|
90 |
+
return hashlib.md5(image.tobytes()).hexdigest()
|
91 |
+
elif isinstance(image, bytes):
|
92 |
+
return hashlib.md5(image).hexdigest()
|
93 |
+
else:
|
94 |
+
raise ValueError("Unsupported image type")
|
95 |
+
|
96 |
+
@lru_cache(maxsize=128)
|
97 |
+
def _process_image(self, image_hash: str) -> Dict[str, Any]:
|
98 |
+
"""
|
99 |
+
Process the input image and cache the results.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
image_hash (str): Hash of the input image.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
Dict[str, Any]: Processed image data.
|
106 |
+
"""
|
107 |
+
logger.info(f"Processing image with hash: {image_hash}")
|
108 |
+
if image_hash not in self.image_cache:
|
109 |
+
raise ValueError(f"Image with hash {image_hash} not found in cache")
|
110 |
+
|
111 |
+
image = self.image_cache[image_hash]
|
112 |
+
img_rgb = np.array(image)
|
113 |
+
|
114 |
+
inference_cfg = self.live_portrait.live_portrait_wrapper.cfg
|
115 |
+
img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
|
116 |
+
crop_info = self.live_portrait.cropper.crop_single_image(img_rgb)
|
117 |
+
img_crop_256x256 = crop_info['img_crop_256x256']
|
118 |
+
|
119 |
+
I_s = self.live_portrait.live_portrait_wrapper.prepare_source(img_crop_256x256)
|
120 |
+
x_s_info = self.live_portrait.live_portrait_wrapper.get_kp_info(I_s)
|
121 |
+
f_s = self.live_portrait.live_portrait_wrapper.extract_feature_3d(I_s)
|
122 |
+
x_s = self.live_portrait.live_portrait_wrapper.transform_keypoint(x_s_info)
|
123 |
+
|
124 |
+
processed_data = {
|
125 |
+
'img_rgb': img_rgb,
|
126 |
+
'crop_info': crop_info,
|
127 |
+
'x_s_info': x_s_info,
|
128 |
+
'f_s': f_s,
|
129 |
+
'x_s': x_s,
|
130 |
+
'inference_cfg': inference_cfg
|
131 |
+
}
|
132 |
+
|
133 |
+
self.processed_cache[image_hash] = processed_data
|
134 |
+
|
135 |
+
return processed_data
|
136 |
+
|
137 |
+
async def modify_image(self, image_or_hash: Union[Image.Image, str, bytes], params: Dict[str, float]) -> str:
|
138 |
+
"""
|
139 |
+
Modify the input image based on the provided parameters, using caching for efficiency
|
140 |
+
and outputting the result as a WebP image.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
image_or_hash (Union[Image.Image, str, bytes]): Input image as a PIL Image, base64-encoded string,
|
144 |
+
image bytes, or a hash string.
|
145 |
+
params (Dict[str, float]): Parameters for face transformation.
|
146 |
+
|
147 |
+
Returns:
|
148 |
+
str: Modified image as a base64-encoded WebP data URI.
|
149 |
+
|
150 |
+
Raises:
|
151 |
+
ValueError: If there's an error modifying the image or WebP is not supported.
|
152 |
+
"""
|
153 |
+
logger.info("Starting image modification")
|
154 |
+
logger.debug(f"Modification parameters: {params}")
|
155 |
+
|
156 |
+
try:
|
157 |
+
image_hash = self.get_image_hash(image_or_hash)
|
158 |
+
|
159 |
+
# If we don't have the image in cache yet, add it
|
160 |
+
if image_hash not in self.image_cache:
|
161 |
+
if isinstance(image_or_hash, (Image.Image, bytes)):
|
162 |
+
self.image_cache[image_hash] = image_or_hash
|
163 |
+
elif isinstance(image_or_hash, str) and len(image_or_hash) != 32:
|
164 |
+
# It's a base64 string, not a hash
|
165 |
+
self.image_cache[image_hash] = base64_data_uri_to_PIL_Image(image_or_hash)
|
166 |
+
else:
|
167 |
+
raise ValueError("Image not found in cache and no valid image provided")
|
168 |
+
|
169 |
+
# Process the image (this will use the cache if available)
|
170 |
+
if image_hash not in self.processed_cache:
|
171 |
+
processed_data = await asyncio.to_thread(self._process_image, image_hash)
|
172 |
+
else:
|
173 |
+
processed_data = self.processed_cache[image_hash]
|
174 |
+
|
175 |
+
# Apply modifications based on params
|
176 |
+
x_d_new = processed_data['x_s_info']['kp'].clone()
|
177 |
+
await self._apply_facial_modifications(x_d_new, params)
|
178 |
+
|
179 |
+
# Apply rotation
|
180 |
+
R_new = get_rotation_matrix(
|
181 |
+
processed_data['x_s_info']['pitch'] + params.get('rotate_pitch', 0),
|
182 |
+
processed_data['x_s_info']['yaw'] + params.get('rotate_yaw', 0),
|
183 |
+
processed_data['x_s_info']['roll'] + params.get('rotate_roll', 0)
|
184 |
+
)
|
185 |
+
x_d_new = processed_data['x_s_info']['scale'] * (x_d_new @ R_new) + processed_data['x_s_info']['t']
|
186 |
+
|
187 |
+
# Apply stitching
|
188 |
+
x_d_new = await asyncio.to_thread(self.live_portrait.live_portrait_wrapper.stitching, processed_data['x_s'], x_d_new)
|
189 |
+
|
190 |
+
# Generate the output
|
191 |
+
out = await asyncio.to_thread(self.live_portrait.live_portrait_wrapper.warp_decode, processed_data['f_s'], processed_data['x_s'], x_d_new)
|
192 |
+
I_p = self.live_portrait.live_portrait_wrapper.parse_output(out['out'])[0]
|
193 |
+
|
194 |
+
# Paste back to full size
|
195 |
+
mask_ori = await asyncio.to_thread(
|
196 |
+
prepare_paste_back,
|
197 |
+
processed_data['inference_cfg'].mask_crop, processed_data['crop_info']['M_c2o'],
|
198 |
+
dsize=(processed_data['img_rgb'].shape[1], processed_data['img_rgb'].shape[0])
|
199 |
+
)
|
200 |
+
I_p_to_ori_blend = await asyncio.to_thread(
|
201 |
+
paste_back,
|
202 |
+
I_p, processed_data['crop_info']['M_c2o'], processed_data['img_rgb'], mask_ori
|
203 |
+
)
|
204 |
+
|
205 |
+
# Convert the result to a PIL Image
|
206 |
+
result_image = Image.fromarray(I_p_to_ori_blend)
|
207 |
+
|
208 |
+
# Save as WebP
|
209 |
+
buffered = io.BytesIO()
|
210 |
+
result_image.save(buffered, format="WebP", quality=85) # Adjust quality as needed
|
211 |
+
modified_image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
212 |
+
|
213 |
+
logger.info("Image modification completed successfully")
|
214 |
+
return f"data:image/webp;base64,{modified_image_base64}"
|
215 |
+
|
216 |
+
except Exception as e:
|
217 |
+
logger.error(f"Error in modify_image: {str(e)}")
|
218 |
+
logger.exception("Full traceback:")
|
219 |
+
raise ValueError(f"Failed to modify image: {str(e)}")
|
220 |
+
|
221 |
+
async def _apply_facial_modifications(self, x_d_new: torch.Tensor, params: Dict[str, float]) -> None:
|
222 |
+
"""
|
223 |
+
Apply facial modifications to the keypoints based on the provided parameters.
|
224 |
+
|
225 |
+
Args:
|
226 |
+
x_d_new (torch.Tensor): Tensor of facial keypoints to be modified.
|
227 |
+
params (Dict[str, float]): Parameters for face transformation.
|
228 |
+
"""
|
229 |
+
modifications = [
|
230 |
+
('smile', [
|
231 |
+
(0, 20, 1, -0.01), (0, 14, 1, -0.02), (0, 17, 1, 0.0065), (0, 17, 2, 0.003),
|
232 |
+
(0, 13, 1, -0.00275), (0, 16, 1, -0.00275), (0, 3, 1, -0.0035), (0, 7, 1, -0.0035)
|
233 |
+
]),
|
234 |
+
('aaa', [
|
235 |
+
(0, 19, 1, 0.001), (0, 19, 2, 0.0001), (0, 17, 1, -0.0001)
|
236 |
+
]),
|
237 |
+
('eee', [
|
238 |
+
(0, 20, 2, -0.001), (0, 20, 1, -0.001), (0, 14, 1, -0.001)
|
239 |
+
]),
|
240 |
+
('woo', [
|
241 |
+
(0, 14, 1, 0.001), (0, 3, 1, -0.0005), (0, 7, 1, -0.0005), (0, 17, 2, -0.0005)
|
242 |
+
]),
|
243 |
+
('wink', [
|
244 |
+
(0, 11, 1, 0.001), (0, 13, 1, -0.0003), (0, 17, 0, 0.0003),
|
245 |
+
(0, 17, 1, 0.0003), (0, 3, 1, -0.0003)
|
246 |
+
]),
|
247 |
+
('pupil_x', [
|
248 |
+
(0, 11, 0, 0.0007 if params.get('pupil_x', 0) > 0 else 0.001),
|
249 |
+
(0, 15, 0, 0.001 if params.get('pupil_x', 0) > 0 else 0.0007)
|
250 |
+
]),
|
251 |
+
('pupil_y', [
|
252 |
+
(0, 11, 1, -0.001), (0, 15, 1, -0.001)
|
253 |
+
]),
|
254 |
+
('eyes', [
|
255 |
+
(0, 11, 1, -0.001), (0, 13, 1, 0.0003), (0, 15, 1, -0.001), (0, 16, 1, 0.0003),
|
256 |
+
(0, 1, 1, -0.00025), (0, 2, 1, 0.00025)
|
257 |
+
]),
|
258 |
+
('eyebrow', [
|
259 |
+
(0, 1, 1, 0.001 if params.get('eyebrow', 0) > 0 else 0.0003),
|
260 |
+
(0, 2, 1, -0.001 if params.get('eyebrow', 0) > 0 else -0.0003),
|
261 |
+
(0, 1, 0, -0.001 if params.get('eyebrow', 0) <= 0 else 0),
|
262 |
+
(0, 2, 0, 0.001 if params.get('eyebrow', 0) <= 0 else 0)
|
263 |
+
])
|
264 |
+
]
|
265 |
+
|
266 |
+
for param_name, adjustments in modifications:
|
267 |
+
param_value = params.get(param_name, 0)
|
268 |
+
for i, j, k, factor in adjustments:
|
269 |
+
x_d_new[i, j, k] += param_value * factor
|
270 |
+
|
271 |
+
# Special case for pupil_y affecting eyes
|
272 |
+
x_d_new[0, 11, 1] -= params.get('pupil_y', 0) * 0.001
|
273 |
+
x_d_new[0, 15, 1] -= params.get('pupil_y', 0) * 0.001
|
274 |
+
params['eyes'] = params.get('eyes', 0) - params.get('pupil_y', 0) / 2.
|
275 |
+
|
276 |
+
async def cleanup(self):
|
277 |
+
"""
|
278 |
+
Perform cleanup operations for the Engine.
|
279 |
+
This method should be called when shutting down the application.
|
280 |
+
"""
|
281 |
+
logger.info("Starting Engine cleanup")
|
282 |
+
try:
|
283 |
+
# TODO: Add any additional cleanup operations here
|
284 |
+
logger.info("Engine cleanup completed successfully")
|
285 |
+
except Exception as e:
|
286 |
+
logger.error(f"Error during Engine cleanup: {str(e)}")
|
287 |
+
logger.exception("Full traceback:")
|
288 |
+
|
289 |
+
def create_engine(models):
|
290 |
+
logger.info("Creating Engine instance...")
|
291 |
+
|
292 |
+
live_portrait = models
|
293 |
+
|
294 |
+
engine = Engine(
|
295 |
+
live_portrait=live_portrait,
|
296 |
+
# we might have more in the future
|
297 |
+
)
|
298 |
+
|
299 |
+
logger.info("Engine instance created successfully")
|
300 |
+
return engine
|
liveportrait/config/__init__.py
ADDED
File without changes
|
liveportrait/config/argument_config.py
ADDED
@@ -0,0 +1,44 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
config for user
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os.path as osp
|
8 |
+
from dataclasses import dataclass
|
9 |
+
import tyro
|
10 |
+
from typing_extensions import Annotated
|
11 |
+
from .base_config import PrintableConfig, make_abs_path
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass(repr=False) # use repr from PrintableConfig
|
15 |
+
class ArgumentConfig(PrintableConfig):
|
16 |
+
########## input arguments ##########
|
17 |
+
source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait
|
18 |
+
driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
|
19 |
+
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
|
20 |
+
#####################################
|
21 |
+
|
22 |
+
########## inference arguments ##########
|
23 |
+
device_id: int = 0
|
24 |
+
flag_lip_zero : bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
|
25 |
+
flag_eye_retargeting: bool = False
|
26 |
+
flag_lip_retargeting: bool = False
|
27 |
+
flag_stitching: bool = True # we recommend setting it to True!
|
28 |
+
flag_relative: bool = True # whether to use relative motion
|
29 |
+
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
|
30 |
+
flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space
|
31 |
+
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
|
32 |
+
#########################################
|
33 |
+
|
34 |
+
########## crop arguments ##########
|
35 |
+
dsize: int = 512
|
36 |
+
scale: float = 2.3
|
37 |
+
vx_ratio: float = 0 # vx ratio
|
38 |
+
vy_ratio: float = -0.125 # vy ratio +up, -down
|
39 |
+
####################################
|
40 |
+
|
41 |
+
########## gradio arguments ##########
|
42 |
+
server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890
|
43 |
+
share: bool = True
|
44 |
+
server_name: str = "0.0.0.0"
|
liveportrait/config/base_config.py
ADDED
@@ -0,0 +1,29 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
pretty printing class
|
5 |
+
"""
|
6 |
+
|
7 |
+
from __future__ import annotations
|
8 |
+
import os.path as osp
|
9 |
+
from typing import Tuple
|
10 |
+
|
11 |
+
|
12 |
+
def make_abs_path(fn):
|
13 |
+
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
|
14 |
+
|
15 |
+
|
16 |
+
class PrintableConfig: # pylint: disable=too-few-public-methods
|
17 |
+
"""Printable Config defining str function"""
|
18 |
+
|
19 |
+
def __repr__(self):
|
20 |
+
lines = [self.__class__.__name__ + ":"]
|
21 |
+
for key, val in vars(self).items():
|
22 |
+
if isinstance(val, Tuple):
|
23 |
+
flattened_val = "["
|
24 |
+
for item in val:
|
25 |
+
flattened_val += str(item) + "\n"
|
26 |
+
flattened_val = flattened_val.rstrip("\n")
|
27 |
+
val = flattened_val + "]"
|
28 |
+
lines += f"{key}: {str(val)}".split("\n")
|
29 |
+
return "\n ".join(lines)
|
liveportrait/config/crop_config.py
ADDED
@@ -0,0 +1,18 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
parameters used for crop faces
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os.path as osp
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from typing import Union, List
|
10 |
+
from .base_config import PrintableConfig
|
11 |
+
|
12 |
+
|
13 |
+
@dataclass(repr=False) # use repr from PrintableConfig
|
14 |
+
class CropConfig(PrintableConfig):
|
15 |
+
dsize: int = 512 # crop size
|
16 |
+
scale: float = 2.3 # scale factor
|
17 |
+
vx_ratio: float = 0 # vx ratio
|
18 |
+
vy_ratio: float = -0.125 # vy ratio +up, -down
|
liveportrait/config/inference_config.py
ADDED
@@ -0,0 +1,53 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
config dataclass used for inference
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
import os.path as osp
|
9 |
+
from dataclasses import dataclass
|
10 |
+
from typing import Literal, Tuple
|
11 |
+
from .base_config import PrintableConfig, make_abs_path
|
12 |
+
|
13 |
+
# Configuration
|
14 |
+
DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
|
15 |
+
MODELS_DIR = os.path.join(DATA_ROOT, "models")
|
16 |
+
|
17 |
+
@dataclass(repr=False) # use repr from PrintableConfig
|
18 |
+
class InferenceConfig(PrintableConfig):
|
19 |
+
models_config: str = make_abs_path('./models.yaml') # portrait animation config
|
20 |
+
checkpoint_F = os.path.join(MODELS_DIR, "liveportrait", "appearance_feature_extractor.pth")
|
21 |
+
checkpoint_M = os.path.join(MODELS_DIR, "liveportrait", "motion_extractor.pth")
|
22 |
+
checkpoint_W = os.path.join(MODELS_DIR, "liveportrait", "warping_module.pth")
|
23 |
+
checkpoint_G = os.path.join(MODELS_DIR, "liveportrait", "spade_generator.pth")
|
24 |
+
checkpoint_S = os.path.join(MODELS_DIR, "liveportrait", "stitching_retargeting_module.pth")
|
25 |
+
|
26 |
+
flag_use_half_precision: bool = True # whether to use half precision
|
27 |
+
|
28 |
+
flag_lip_zero: bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
|
29 |
+
lip_zero_threshold: float = 0.03
|
30 |
+
|
31 |
+
flag_eye_retargeting: bool = False
|
32 |
+
flag_lip_retargeting: bool = False
|
33 |
+
flag_stitching: bool = True # we recommend setting it to True!
|
34 |
+
|
35 |
+
flag_relative: bool = True # whether to use relative motion
|
36 |
+
anchor_frame: int = 0 # set this value if find_best_frame is True
|
37 |
+
|
38 |
+
input_shape: Tuple[int, int] = (256, 256) # input shape
|
39 |
+
output_format: Literal['mp4', 'gif'] = 'mp4' # output video format
|
40 |
+
output_fps: int = 25 # MuseTalk prefers 25 fps, so we use 25 as default fps for output video
|
41 |
+
crf: int = 15 # crf for output video
|
42 |
+
|
43 |
+
flag_write_result: bool = True # whether to write output video
|
44 |
+
flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
|
45 |
+
mask_crop = None
|
46 |
+
flag_write_gif: bool = False
|
47 |
+
size_gif: int = 256
|
48 |
+
ref_max_shape: int = 1280
|
49 |
+
ref_shape_n: int = 2
|
50 |
+
|
51 |
+
device_id: int = 0
|
52 |
+
flag_do_crop: bool = False # whether to crop the source portrait to the face-cropping space
|
53 |
+
flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
|
liveportrait/config/models.yaml
ADDED
@@ -0,0 +1,43 @@
1 |
+
model_params:
|
2 |
+
appearance_feature_extractor_params: # the F in the paper
|
3 |
+
image_channel: 3
|
4 |
+
block_expansion: 64
|
5 |
+
num_down_blocks: 2
|
6 |
+
max_features: 512
|
7 |
+
reshape_channel: 32
|
8 |
+
reshape_depth: 16
|
9 |
+
num_resblocks: 6
|
10 |
+
motion_extractor_params: # the M in the paper
|
11 |
+
num_kp: 21
|
12 |
+
backbone: convnextv2_tiny
|
13 |
+
warping_module_params: # the W in the paper
|
14 |
+
num_kp: 21
|
15 |
+
block_expansion: 64
|
16 |
+
max_features: 512
|
17 |
+
num_down_blocks: 2
|
18 |
+
reshape_channel: 32
|
19 |
+
estimate_occlusion_map: True
|
20 |
+
dense_motion_params:
|
21 |
+
block_expansion: 32
|
22 |
+
max_features: 1024
|
23 |
+
num_blocks: 5
|
24 |
+
reshape_depth: 16
|
25 |
+
compress: 4
|
26 |
+
spade_generator_params: # the G in the paper
|
27 |
+
upscale: 2 # represents upsample factor 256x256 -> 512x512
|
28 |
+
block_expansion: 64
|
29 |
+
max_features: 512
|
30 |
+
num_down_blocks: 2
|
31 |
+
stitching_retargeting_module_params: # the S in the paper
|
32 |
+
stitching:
|
33 |
+
input_size: 126 # (21*3)*2
|
34 |
+
hidden_sizes: [128, 128, 64]
|
35 |
+
output_size: 65 # (21*3)+2(tx,ty)
|
36 |
+
lip:
|
37 |
+
input_size: 65 # (21*3)+2
|
38 |
+
hidden_sizes: [128, 128, 64]
|
39 |
+
output_size: 63 # (21*3)
|
40 |
+
eye:
|
41 |
+
input_size: 66 # (21*3)+3
|
42 |
+
hidden_sizes: [256, 256, 128, 128, 64]
|
43 |
+
output_size: 63 # (21*3)
|
liveportrait/gradio_pipeline.py
ADDED
@@ -0,0 +1,140 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
Pipeline for gradio
|
5 |
+
"""
|
6 |
+
import gradio as gr
|
7 |
+
from .config.argument_config import ArgumentConfig
|
8 |
+
from .live_portrait_pipeline import LivePortraitPipeline
|
9 |
+
from .utils.io import load_img_online
|
10 |
+
from .utils.rprint import rlog as log
|
11 |
+
from .utils.crop import prepare_paste_back, paste_back
|
12 |
+
from .utils.camera import get_rotation_matrix
|
13 |
+
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
|
14 |
+
|
15 |
+
def update_args(args, user_args):
|
16 |
+
"""update the args according to user inputs
|
17 |
+
"""
|
18 |
+
for k, v in user_args.items():
|
19 |
+
if hasattr(args, k):
|
20 |
+
setattr(args, k, v)
|
21 |
+
return args
|
22 |
+
|
23 |
+
class GradioPipeline(LivePortraitPipeline):
|
24 |
+
|
25 |
+
def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
|
26 |
+
super().__init__(inference_cfg, crop_cfg)
|
27 |
+
# self.live_portrait_wrapper = self.live_portrait_wrapper
|
28 |
+
self.args = args
|
29 |
+
# for single image retargeting
|
30 |
+
self.start_prepare = False
|
31 |
+
self.f_s_user = None
|
32 |
+
self.x_c_s_info_user = None
|
33 |
+
self.x_s_user = None
|
34 |
+
self.source_lmk_user = None
|
35 |
+
self.mask_ori = None
|
36 |
+
self.img_rgb = None
|
37 |
+
self.crop_M_c2o = None
|
38 |
+
|
39 |
+
|
40 |
+
def execute_video(
|
41 |
+
self,
|
42 |
+
input_image_path,
|
43 |
+
input_video_path,
|
44 |
+
flag_relative_input,
|
45 |
+
flag_do_crop_input,
|
46 |
+
flag_remap_input,
|
47 |
+
):
|
48 |
+
""" for video driven potrait animation
|
49 |
+
"""
|
50 |
+
if input_image_path is not None and input_video_path is not None:
|
51 |
+
args_user = {
|
52 |
+
'source_image': input_image_path,
|
53 |
+
'driving_info': input_video_path,
|
54 |
+
'flag_relative': flag_relative_input,
|
55 |
+
'flag_do_crop': flag_do_crop_input,
|
56 |
+
'flag_pasteback': flag_remap_input,
|
57 |
+
}
|
58 |
+
# update config from user input
|
59 |
+
self.args = update_args(self.args, args_user)
|
60 |
+
self.live_portrait_wrapper.update_config(self.args.__dict__)
|
61 |
+
self.cropper.update_config(self.args.__dict__)
|
62 |
+
# video driven animation
|
63 |
+
video_path, video_path_concat = self.execute(self.args)
|
64 |
+
gr.Info("Run successfully!", duration=2)
|
65 |
+
return video_path, video_path_concat,
|
66 |
+
else:
|
67 |
+
raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
|
68 |
+
|
69 |
+
def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
|
70 |
+
""" for single image retargeting
|
71 |
+
"""
|
72 |
+
if input_eye_ratio is None or input_lip_ratio is None:
|
73 |
+
raise gr.Error("Invalid ratio input 💥!", duration=5)
|
74 |
+
elif self.f_s_user is None:
|
75 |
+
if self.start_prepare:
|
76 |
+
raise gr.Error(
|
77 |
+
"The source portrait is under processing 💥! Please wait for a second.",
|
78 |
+
duration=5
|
79 |
+
)
|
80 |
+
else:
|
81 |
+
raise gr.Error(
|
82 |
+
"The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
|
83 |
+
duration=5
|
84 |
+
)
|
85 |
+
else:
|
86 |
+
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
|
87 |
+
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
|
88 |
+
eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
|
89 |
+
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
|
90 |
+
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
|
91 |
+
lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
|
92 |
+
num_kp = self.x_s_user.shape[1]
|
93 |
+
# default: use x_s
|
94 |
+
x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
|
95 |
+
# D(W(f_s; x_s, x′_d))
|
96 |
+
out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
|
97 |
+
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
98 |
+
out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
|
99 |
+
gr.Info("Run successfully!", duration=2)
|
100 |
+
return out, out_to_ori_blend
|
101 |
+
|
102 |
+
|
103 |
+
def prepare_retargeting(self, input_image_path, flag_do_crop=True):
|
104 |
+
""" for single image retargeting
|
105 |
+
"""
|
106 |
+
if input_image_path is not None:
|
107 |
+
gr.Info("Upload successfully!", duration=2)
|
108 |
+
self.start_prepare = True
|
109 |
+
inference_cfg = self.live_portrait_wrapper.cfg
|
110 |
+
######## process source portrait ########
|
111 |
+
img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
|
112 |
+
log(f"Load source image from {input_image_path}.")
|
113 |
+
crop_info = self.cropper.crop_single_image(img_rgb)
|
114 |
+
if flag_do_crop:
|
115 |
+
I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
|
116 |
+
else:
|
117 |
+
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
|
118 |
+
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
|
119 |
+
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
120 |
+
############################################
|
121 |
+
|
122 |
+
# record global info for next time use
|
123 |
+
self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
124 |
+
self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
|
125 |
+
self.x_s_info_user = x_s_info
|
126 |
+
self.source_lmk_user = crop_info['lmk_crop']
|
127 |
+
self.img_rgb = img_rgb
|
128 |
+
self.crop_M_c2o = crop_info['M_c2o']
|
129 |
+
self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
|
130 |
+
# update slider
|
131 |
+
eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
|
132 |
+
eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
|
133 |
+
lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
|
134 |
+
lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
|
135 |
+
# for vis
|
136 |
+
self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
|
137 |
+
return eye_close_ratio, lip_close_ratio, self.I_s_vis
|
138 |
+
else:
|
139 |
+
# when press the clear button, go here
|
140 |
+
return 0.8, 0.8, self.I_s_vis
|
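For context, a hypothetical usage sketch of the class above; it assumes the three config objects can be built with their default values, and the image path is a placeholder, not part of this commit:

from liveportrait.config.inference_config import InferenceConfig
from liveportrait.config.crop_config import CropConfig
from liveportrait.config.argument_config import ArgumentConfig
from liveportrait.gradio_pipeline import GradioPipeline

# Assumption: default-constructed configs point at valid model checkpoints.
pipeline = GradioPipeline(InferenceConfig(), CropConfig(), ArgumentConfig())

# Prepare the source portrait once, then retarget eyes/lips as often as needed.
eye_ratio, lip_ratio, preview = pipeline.prepare_retargeting("portrait.jpg", flag_do_crop=True)
out_crop, out_full = pipeline.execute_image(input_eye_ratio=eye_ratio, input_lip_ratio=lip_ratio)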
liveportrait/live_portrait_pipeline.py
ADDED
@@ -0,0 +1,193 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
Pipeline of LivePortrait
|
5 |
+
"""
|
6 |
+
|
7 |
+
# TODO:
|
8 |
+
# 1. Currently we assume all templates are already cropped; this needs to be revised
|
9 |
+
# 2. Pick example source + driving images
|
10 |
+
|
11 |
+
import cv2
|
12 |
+
import numpy as np
|
13 |
+
import pickle
|
14 |
+
import os.path as osp
|
15 |
+
from rich.progress import track
|
16 |
+
|
17 |
+
from .config.argument_config import ArgumentConfig
|
18 |
+
from .config.inference_config import InferenceConfig
|
19 |
+
from .config.crop_config import CropConfig
|
20 |
+
from .utils.cropper import Cropper
|
21 |
+
from .utils.camera import get_rotation_matrix
|
22 |
+
from .utils.video import images2video, concat_frames
|
23 |
+
from .utils.crop import _transform_img, prepare_paste_back, paste_back
|
24 |
+
from .utils.retargeting_utils import calc_lip_close_ratio
|
25 |
+
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
|
26 |
+
from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template
|
27 |
+
from .utils.rprint import rlog as log
|
28 |
+
from .live_portrait_wrapper import LivePortraitWrapper
|
29 |
+
|
30 |
+
|
31 |
+
def make_abs_path(fn):
|
32 |
+
return osp.join(osp.dirname(osp.realpath(__file__)), fn)
|
33 |
+
|
34 |
+
|
35 |
+
class LivePortraitPipeline(object):
|
36 |
+
|
37 |
+
def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
|
38 |
+
self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
|
39 |
+
self.cropper = Cropper(crop_cfg=crop_cfg)
|
40 |
+
|
41 |
+
def execute(self, args: ArgumentConfig):
|
42 |
+
inference_cfg = self.live_portrait_wrapper.cfg # for convenience
|
43 |
+
######## process source portrait ########
|
44 |
+
img_rgb = load_image_rgb(args.source_image)
|
45 |
+
img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
|
46 |
+
log(f"Load source image from {args.source_image}")
|
47 |
+
crop_info = self.cropper.crop_single_image(img_rgb)
|
48 |
+
source_lmk = crop_info['lmk_crop']
|
49 |
+
img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
|
50 |
+
if inference_cfg.flag_do_crop:
|
51 |
+
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
|
52 |
+
else:
|
53 |
+
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
|
54 |
+
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
|
55 |
+
x_c_s = x_s_info['kp']
|
56 |
+
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
57 |
+
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
|
58 |
+
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
|
59 |
+
|
60 |
+
if inference_cfg.flag_lip_zero:
|
61 |
+
# let lip-open scalar to be 0 at first
|
62 |
+
c_d_lip_before_animation = [0.]
|
63 |
+
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
|
64 |
+
if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
|
65 |
+
inference_cfg.flag_lip_zero = False
|
66 |
+
else:
|
67 |
+
lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
|
68 |
+
############################################
|
69 |
+
|
70 |
+
######## process driving info ########
|
71 |
+
if is_video(args.driving_info):
|
72 |
+
log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
|
73 |
+
# TODO: track the driving video here -> build a template
|
74 |
+
driving_rgb_lst = load_driving_info(args.driving_info)
|
75 |
+
driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
|
76 |
+
I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256)
|
77 |
+
n_frames = I_d_lst.shape[0]
|
78 |
+
if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
|
79 |
+
driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
|
80 |
+
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
|
81 |
+
elif is_template(args.driving_info):
|
82 |
+
log(f"Load from video templates {args.driving_info}")
|
83 |
+
with open(args.driving_info, 'rb') as f:
|
84 |
+
template_lst, driving_lmk_lst = pickle.load(f)
|
85 |
+
n_frames = template_lst[0]['n_frames']
|
86 |
+
input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
|
87 |
+
else:
|
88 |
+
raise Exception("Unsupported driving types!")
|
89 |
+
#########################################
|
90 |
+
|
91 |
+
######## prepare for pasteback ########
|
92 |
+
if inference_cfg.flag_pasteback:
|
93 |
+
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
|
94 |
+
I_p_paste_lst = []
|
95 |
+
#########################################
|
96 |
+
|
97 |
+
I_p_lst = []
|
98 |
+
R_d_0, x_d_0_info = None, None
|
99 |
+
for i in track(range(n_frames), description='Animating...', total=n_frames):
|
100 |
+
if is_video(args.driving_info):
|
101 |
+
# extract kp info by M
|
102 |
+
I_d_i = I_d_lst[i]
|
103 |
+
x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
|
104 |
+
R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
|
105 |
+
else:
|
106 |
+
# from template
|
107 |
+
x_d_i_info = template_lst[i]
|
108 |
+
x_d_i_info = dct2cuda(x_d_i_info, inference_cfg.device_id)
|
109 |
+
R_d_i = x_d_i_info['R_d']
|
110 |
+
|
111 |
+
if i == 0:
|
112 |
+
R_d_0 = R_d_i
|
113 |
+
x_d_0_info = x_d_i_info
|
114 |
+
|
115 |
+
if inference_cfg.flag_relative:
|
116 |
+
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
|
117 |
+
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
|
118 |
+
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
|
119 |
+
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
|
120 |
+
else:
|
121 |
+
R_new = R_d_i
|
122 |
+
delta_new = x_d_i_info['exp']
|
123 |
+
scale_new = x_s_info['scale']
|
124 |
+
t_new = x_d_i_info['t']
|
125 |
+
|
126 |
+
t_new[..., 2].fill_(0) # zero tz
|
127 |
+
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
|
128 |
+
|
129 |
+
# Algorithm 1:
|
130 |
+
if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
|
131 |
+
# without stitching or retargeting
|
132 |
+
if inference_cfg.flag_lip_zero:
|
133 |
+
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
134 |
+
else:
|
135 |
+
pass
|
136 |
+
elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
|
137 |
+
# with stitching and without retargeting
|
138 |
+
if inference_cfg.flag_lip_zero:
|
139 |
+
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
|
140 |
+
else:
|
141 |
+
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
|
142 |
+
else:
|
143 |
+
eyes_delta, lip_delta = None, None
|
144 |
+
if inference_cfg.flag_eye_retargeting:
|
145 |
+
c_d_eyes_i = input_eye_ratio_lst[i]
|
146 |
+
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
|
147 |
+
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
|
148 |
+
eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
|
149 |
+
if inference_cfg.flag_lip_retargeting:
|
150 |
+
c_d_lip_i = input_lip_ratio_lst[i]
|
151 |
+
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
|
152 |
+
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
|
153 |
+
lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
|
154 |
+
|
155 |
+
if inference_cfg.flag_relative: # use x_s
|
156 |
+
x_d_i_new = x_s + \
|
157 |
+
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
|
158 |
+
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
|
159 |
+
else: # use x_d,i
|
160 |
+
x_d_i_new = x_d_i_new + \
|
161 |
+
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
|
162 |
+
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
|
163 |
+
|
164 |
+
if inference_cfg.flag_stitching:
|
165 |
+
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
|
166 |
+
|
167 |
+
out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
|
168 |
+
I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
|
169 |
+
I_p_lst.append(I_p_i)
|
170 |
+
|
171 |
+
if inference_cfg.flag_pasteback:
|
172 |
+
I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
|
173 |
+
I_p_paste_lst.append(I_p_i_to_ori_blend)
|
174 |
+
|
175 |
+
mkdir(args.output_dir)
|
176 |
+
wfp_concat = None
|
177 |
+
|
178 |
+
# note by @jbilcke-hf:
|
179 |
+
# I have disabled this block, since we don't need to debug it
|
180 |
+
#if is_video(args.driving_info):
|
181 |
+
# frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
|
182 |
+
# # save (driving frames, source image, driven frames) result
|
183 |
+
# wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
|
184 |
+
# images2video(frames_concatenated, wfp=wfp_concat)#
|
185 |
+
|
186 |
+
# save the driven result
|
187 |
+
wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
|
188 |
+
if inference_cfg.flag_pasteback:
|
189 |
+
images2video(I_p_paste_lst, wfp=wfp)
|
190 |
+
else:
|
191 |
+
images2video(I_p_lst, wfp=wfp)
|
192 |
+
|
193 |
+
return wfp, wfp_concat
|
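The relative-motion branch in execute() drives the source by the change of each driving frame relative to the first one. A self-contained sketch with synthetic tensors (shapes only, not repository code) of that composition:

import torch

bs, num_kp = 1, 21
x_c_s = torch.randn(bs, num_kp, 3)                 # canonical source keypoints
R_s = torch.eye(3).expand(bs, 3, 3)                # source rotation
R_d_0 = torch.eye(3).expand(bs, 3, 3)              # rotation of the first driving frame
R_d_i = torch.eye(3).expand(bs, 3, 3)              # rotation of the current driving frame
exp_s, exp_d_0 = torch.zeros(bs, num_kp, 3), torch.zeros(bs, num_kp, 3)
exp_d_i = torch.randn(bs, num_kp, 3)               # current driving expression
scale_s = scale_d_0 = scale_d_i = torch.ones(bs, 1, 1)
t_s = t_d_0 = torch.zeros(bs, 1, 3)
t_d_i = torch.tensor([[[0.05, -0.02, 0.3]]])

R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s     # compose rotations relative to frame 0
delta_new = exp_s + (exp_d_i - exp_d_0)            # relative expression
scale_new = scale_s * (scale_d_i / scale_d_0)      # relative scale
t_new = t_s + (t_d_i - t_d_0)                      # relative translation
t_new[..., 2] = 0                                  # tz is zeroed, as in the pipeline

x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
print(x_d_i_new.shape)                             # torch.Size([1, 21, 3])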
liveportrait/live_portrait_wrapper.py
ADDED
@@ -0,0 +1,307 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
Wrapper for LivePortrait core functions
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os.path as osp
|
8 |
+
import numpy as np
|
9 |
+
import cv2
|
10 |
+
import torch
|
11 |
+
import yaml
|
12 |
+
|
13 |
+
from .utils.timer import Timer
|
14 |
+
from .utils.helper import load_model, concat_feat
|
15 |
+
from .utils.camera import headpose_pred_to_degree, get_rotation_matrix
|
16 |
+
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
|
17 |
+
from .config.inference_config import InferenceConfig
|
18 |
+
from .utils.rprint import rlog as log
|
19 |
+
|
20 |
+
|
21 |
+
class LivePortraitWrapper(object):
|
22 |
+
|
23 |
+
def __init__(self, cfg: InferenceConfig):
|
24 |
+
|
25 |
+
model_config = yaml.load(open(cfg.models_config, 'r'), Loader=yaml.SafeLoader)
|
26 |
+
|
27 |
+
# init F
|
28 |
+
self.appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
|
29 |
+
#log(f'Load appearance_feature_extractor done.')
|
30 |
+
# init M
|
31 |
+
self.motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
|
32 |
+
#log(f'Load motion_extractor done.')
|
33 |
+
# init W
|
34 |
+
self.warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
|
35 |
+
#log(f'Load warping_module done.')
|
36 |
+
# init G
|
37 |
+
self.spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
|
38 |
+
#log(f'Load spade_generator done.')
|
39 |
+
# init S and R
|
40 |
+
if cfg.checkpoint_S is not None and osp.exists(cfg.checkpoint_S):
|
41 |
+
self.stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
|
42 |
+
#log(f'Load stitching_retargeting_module done.')
|
43 |
+
else:
|
44 |
+
self.stitching_retargeting_module = None
|
45 |
+
|
46 |
+
self.cfg = cfg
|
47 |
+
self.device_id = cfg.device_id
|
48 |
+
self.timer = Timer()
|
49 |
+
|
50 |
+
def update_config(self, user_args):
|
51 |
+
for k, v in user_args.items():
|
52 |
+
if hasattr(self.cfg, k):
|
53 |
+
setattr(self.cfg, k, v)
|
54 |
+
|
55 |
+
def prepare_source(self, img: np.ndarray) -> torch.Tensor:
|
56 |
+
""" construct the input as standard
|
57 |
+
img: HxWx3, uint8, 256x256
|
58 |
+
"""
|
59 |
+
h, w = img.shape[:2]
|
60 |
+
if h != self.cfg.input_shape[0] or w != self.cfg.input_shape[1]:
|
61 |
+
x = cv2.resize(img, (self.cfg.input_shape[0], self.cfg.input_shape[1]))
|
62 |
+
else:
|
63 |
+
x = img.copy()
|
64 |
+
|
65 |
+
if x.ndim == 3:
|
66 |
+
x = x[np.newaxis].astype(np.float32) / 255. # HxWx3 -> 1xHxWx3, normalized to 0~1
|
67 |
+
elif x.ndim == 4:
|
68 |
+
x = x.astype(np.float32) / 255. # BxHxWx3, normalized to 0~1
|
69 |
+
else:
|
70 |
+
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
|
71 |
+
x = np.clip(x, 0, 1) # clip to 0~1
|
72 |
+
x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
|
73 |
+
x = x.cuda(self.device_id)
|
74 |
+
return x
|
75 |
+
|
76 |
+
def prepare_driving_videos(self, imgs) -> torch.Tensor:
|
77 |
+
""" construct the input as standard
|
78 |
+
imgs: NxBxHxWx3, uint8
|
79 |
+
"""
|
80 |
+
if isinstance(imgs, list):
|
81 |
+
_imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1
|
82 |
+
elif isinstance(imgs, np.ndarray):
|
83 |
+
_imgs = imgs
|
84 |
+
else:
|
85 |
+
raise ValueError(f'imgs type error: {type(imgs)}')
|
86 |
+
|
87 |
+
y = _imgs.astype(np.float32) / 255.
|
88 |
+
y = np.clip(y, 0, 1) # clip to 0~1
|
89 |
+
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
|
90 |
+
y = y.cuda(self.device_id)
|
91 |
+
|
92 |
+
return y
|
93 |
+
|
94 |
+
def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
|
95 |
+
""" get the appearance feature of the image by F
|
96 |
+
x: Bx3xHxW, normalized to 0~1
|
97 |
+
"""
|
98 |
+
with torch.no_grad():
|
99 |
+
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
|
100 |
+
feature_3d = self.appearance_feature_extractor(x)
|
101 |
+
|
102 |
+
return feature_3d.float()
|
103 |
+
|
104 |
+
def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
|
105 |
+
""" get the implicit keypoint information
|
106 |
+
x: Bx3xHxW, normalized to 0~1
|
107 |
+
flag_refine_info: whether to transform the pose into degrees and reshape the tensor dimensions
|
108 |
+
return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
|
109 |
+
"""
|
110 |
+
with torch.no_grad():
|
111 |
+
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
|
112 |
+
kp_info = self.motion_extractor(x)
|
113 |
+
|
114 |
+
if self.cfg.flag_use_half_precision:
|
115 |
+
# float the dict
|
116 |
+
for k, v in kp_info.items():
|
117 |
+
if isinstance(v, torch.Tensor):
|
118 |
+
kp_info[k] = v.float()
|
119 |
+
|
120 |
+
flag_refine_info: bool = kwargs.get('flag_refine_info', True)
|
121 |
+
if flag_refine_info:
|
122 |
+
bs = kp_info['kp'].shape[0]
|
123 |
+
kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None] # Bx1
|
124 |
+
kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None] # Bx1
|
125 |
+
kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None] # Bx1
|
126 |
+
kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3) # BxNx3
|
127 |
+
kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3) # BxNx3
|
128 |
+
|
129 |
+
return kp_info
|
130 |
+
|
131 |
+
def get_pose_dct(self, kp_info: dict) -> dict:
|
132 |
+
pose_dct = dict(
|
133 |
+
pitch=headpose_pred_to_degree(kp_info['pitch']).item(),
|
134 |
+
yaw=headpose_pred_to_degree(kp_info['yaw']).item(),
|
135 |
+
roll=headpose_pred_to_degree(kp_info['roll']).item(),
|
136 |
+
)
|
137 |
+
return pose_dct
|
138 |
+
|
139 |
+
def get_fs_and_kp_info(self, source_prepared, driving_first_frame):
|
140 |
+
|
141 |
+
# get the canonical keypoints of source image by M
|
142 |
+
source_kp_info = self.get_kp_info(source_prepared, flag_refine_info=True)
|
143 |
+
source_rotation = get_rotation_matrix(source_kp_info['pitch'], source_kp_info['yaw'], source_kp_info['roll'])
|
144 |
+
|
145 |
+
# get the canonical keypoints of first driving frame by M
|
146 |
+
driving_first_frame_kp_info = self.get_kp_info(driving_first_frame, flag_refine_info=True)
|
147 |
+
driving_first_frame_rotation = get_rotation_matrix(
|
148 |
+
driving_first_frame_kp_info['pitch'],
|
149 |
+
driving_first_frame_kp_info['yaw'],
|
150 |
+
driving_first_frame_kp_info['roll']
|
151 |
+
)
|
152 |
+
|
153 |
+
# get feature volume by F
|
154 |
+
source_feature_3d = self.extract_feature_3d(source_prepared)
|
155 |
+
|
156 |
+
return source_kp_info, source_rotation, source_feature_3d, driving_first_frame_kp_info, driving_first_frame_rotation
|
157 |
+
|
158 |
+
def transform_keypoint(self, kp_info: dict):
|
159 |
+
"""
|
160 |
+
transform the implicit keypoints with the pose, shift, and expression deformation
|
161 |
+
kp: BxNx3
|
162 |
+
"""
|
163 |
+
kp = kp_info['kp'] # (bs, k, 3)
|
164 |
+
pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']
|
165 |
+
|
166 |
+
t, exp = kp_info['t'], kp_info['exp']
|
167 |
+
scale = kp_info['scale']
|
168 |
+
|
169 |
+
pitch = headpose_pred_to_degree(pitch)
|
170 |
+
yaw = headpose_pred_to_degree(yaw)
|
171 |
+
roll = headpose_pred_to_degree(roll)
|
172 |
+
|
173 |
+
bs = kp.shape[0]
|
174 |
+
if kp.ndim == 2:
|
175 |
+
num_kp = kp.shape[1] // 3 # Bx(num_kpx3)
|
176 |
+
else:
|
177 |
+
num_kp = kp.shape[1] # Bxnum_kpx3
|
178 |
+
|
179 |
+
rot_mat = get_rotation_matrix(pitch, yaw, roll) # (bs, 3, 3)
|
180 |
+
|
181 |
+
# Eqn.2: s * (R * x_c,s + exp) + t
|
182 |
+
kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
|
183 |
+
kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
|
184 |
+
kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty
|
185 |
+
|
186 |
+
return kp_transformed
|
187 |
+
|
188 |
+
def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor:
|
189 |
+
"""
|
190 |
+
kp_source: BxNx3
|
191 |
+
eye_close_ratio: Bx3
|
192 |
+
Return: Bx(3*num_kp+2)
|
193 |
+
"""
|
194 |
+
feat_eye = concat_feat(kp_source, eye_close_ratio)
|
195 |
+
|
196 |
+
with torch.no_grad():
|
197 |
+
delta = self.stitching_retargeting_module['eye'](feat_eye)
|
198 |
+
|
199 |
+
return delta
|
200 |
+
|
201 |
+
def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
|
202 |
+
"""
|
203 |
+
kp_source: BxNx3
|
204 |
+
lip_close_ratio: Bx2
|
205 |
+
"""
|
206 |
+
feat_lip = concat_feat(kp_source, lip_close_ratio)
|
207 |
+
|
208 |
+
with torch.no_grad():
|
209 |
+
delta = self.stitching_retargeting_module['lip'](feat_lip)
|
210 |
+
|
211 |
+
return delta
|
212 |
+
|
213 |
+
def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
|
214 |
+
"""
|
215 |
+
kp_source: BxNx3
|
216 |
+
kp_driving: BxNx3
|
217 |
+
Return: Bx(3*num_kp+2)
|
218 |
+
"""
|
219 |
+
feat_stitching = concat_feat(kp_source, kp_driving)
|
220 |
+
|
221 |
+
with torch.no_grad():
|
222 |
+
delta = self.stitching_retargeting_module['stitching'](feat_stitching)
|
223 |
+
|
224 |
+
return delta
|
225 |
+
|
226 |
+
def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
|
227 |
+
""" conduct the stitching
|
228 |
+
kp_source: Bxnum_kpx3
|
229 |
+
kp_driving: Bxnum_kpx3
|
230 |
+
"""
|
231 |
+
|
232 |
+
if self.stitching_retargeting_module is not None:
|
233 |
+
|
234 |
+
bs, num_kp = kp_source.shape[:2]
|
235 |
+
|
236 |
+
kp_driving_new = kp_driving.clone()
|
237 |
+
delta = self.stitch(kp_source, kp_driving_new)
|
238 |
+
|
239 |
+
delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3) # 1x20x3
|
240 |
+
delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2) # 1x1x2
|
241 |
+
|
242 |
+
kp_driving_new += delta_exp
|
243 |
+
kp_driving_new[..., :2] += delta_tx_ty
|
244 |
+
|
245 |
+
return kp_driving_new
|
246 |
+
|
247 |
+
return kp_driving
|
248 |
+
|
249 |
+
def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
|
250 |
+
""" get the image after the warping of the implicit keypoints
|
251 |
+
feature_3d: Bx32x16x64x64, feature volume
|
252 |
+
kp_source: BxNx3
|
253 |
+
kp_driving: BxNx3
|
254 |
+
"""
|
255 |
+
# The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
|
256 |
+
with torch.no_grad():
|
257 |
+
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
|
258 |
+
# get decoder input
|
259 |
+
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
|
260 |
+
# decode
|
261 |
+
ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
|
262 |
+
|
263 |
+
# float the dict
|
264 |
+
if self.cfg.flag_use_half_precision:
|
265 |
+
for k, v in ret_dct.items():
|
266 |
+
if isinstance(v, torch.Tensor):
|
267 |
+
ret_dct[k] = v.float()
|
268 |
+
|
269 |
+
return ret_dct
|
270 |
+
|
271 |
+
def parse_output(self, out: torch.Tensor) -> np.ndarray:
|
272 |
+
""" construct the output as standard
|
273 |
+
return: 1xHxWx3, uint8
|
274 |
+
"""
|
275 |
+
out = np.transpose(out.data.cpu().numpy(), [0, 2, 3, 1]) # 1x3xHxW -> 1xHxWx3
|
276 |
+
out = np.clip(out, 0, 1) # clip to 0~1
|
277 |
+
out = np.clip(out * 255, 0, 255).astype(np.uint8) # 0~1 -> 0~255
|
278 |
+
|
279 |
+
return out
|
280 |
+
|
281 |
+
def calc_retargeting_ratio(self, source_lmk, driving_lmk_lst):
|
282 |
+
input_eye_ratio_lst = []
|
283 |
+
input_lip_ratio_lst = []
|
284 |
+
for lmk in driving_lmk_lst:
|
285 |
+
# for eyes retargeting
|
286 |
+
input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
|
287 |
+
# for lip retargeting
|
288 |
+
input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
|
289 |
+
return input_eye_ratio_lst, input_lip_ratio_lst
|
290 |
+
|
291 |
+
def calc_combined_eye_ratio(self, input_eye_ratio, source_lmk):
|
292 |
+
eye_close_ratio = calc_eye_close_ratio(source_lmk[None])
|
293 |
+
eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(self.device_id)
|
294 |
+
input_eye_ratio_tensor = torch.Tensor([input_eye_ratio[0][0]]).reshape(1, 1).cuda(self.device_id)
|
295 |
+
# [c_s,eyes, c_d,eyes,i]
|
296 |
+
combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
|
297 |
+
return combined_eye_ratio_tensor
|
298 |
+
|
299 |
+
def calc_combined_lip_ratio(self, input_lip_ratio, source_lmk):
|
300 |
+
lip_close_ratio = calc_lip_close_ratio(source_lmk[None])
|
301 |
+
lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(self.device_id)
|
302 |
+
# [c_s,lip, c_d,lip,i]
|
303 |
+
input_lip_ratio_tensor = torch.Tensor([input_lip_ratio[0]]).cuda(self.device_id)
|
304 |
+
if input_lip_ratio_tensor.shape != (1, 1):  # compare against a tuple; torch.Size never equals a list
|
305 |
+
input_lip_ratio_tensor = input_lip_ratio_tensor.reshape(1, 1)
|
306 |
+
combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
|
307 |
+
return combined_lip_ratio_tensor
|
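transform_keypoint above implements Eqn. 2 of the paper, x_s = s * (x_c,s @ R + exp) + t, with tz dropped. A CPU-only sketch with synthetic values (not repository code) that mirrors the same shape handling:

import torch

bs, num_kp = 1, 21
kp = torch.randn(bs, num_kp, 3)            # canonical keypoints x_c,s
rot = torch.eye(3).expand(bs, 3, 3)        # rotation built from pitch/yaw/roll
exp = torch.zeros(bs, num_kp, 3)           # expression deformation
scale = torch.full((bs, 1), 1.2)           # per-sample scale
t = torch.tensor([[0.1, -0.2, 0.5]])       # translation; tz is ignored below

kp_transformed = kp @ rot + exp                       # (bs, k, 3)
kp_transformed = kp_transformed * scale[..., None]    # broadcast with (bs, 1, 1)
kp_transformed[:, :, 0:2] += t[:, None, 0:2]          # apply only tx, ty
print(kp_transformed.shape)                           # torch.Size([1, 21, 3])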
liveportrait/modules/__init__.py
ADDED
File without changes
|
liveportrait/modules/appearance_feature_extractor.py
ADDED
@@ -0,0 +1,48 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
Appearance feature extractor (F) defined in the paper, which maps the source image s to a 3D appearance feature volume.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from .util import SameBlock2d, DownBlock2d, ResBlock3d
|
10 |
+
|
11 |
+
|
12 |
+
class AppearanceFeatureExtractor(nn.Module):
|
13 |
+
|
14 |
+
def __init__(self, image_channel, block_expansion, num_down_blocks, max_features, reshape_channel, reshape_depth, num_resblocks):
|
15 |
+
super(AppearanceFeatureExtractor, self).__init__()
|
16 |
+
self.image_channel = image_channel
|
17 |
+
self.block_expansion = block_expansion
|
18 |
+
self.num_down_blocks = num_down_blocks
|
19 |
+
self.max_features = max_features
|
20 |
+
self.reshape_channel = reshape_channel
|
21 |
+
self.reshape_depth = reshape_depth
|
22 |
+
|
23 |
+
self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))
|
24 |
+
|
25 |
+
down_blocks = []
|
26 |
+
for i in range(num_down_blocks):
|
27 |
+
in_features = min(max_features, block_expansion * (2 ** i))
|
28 |
+
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
|
29 |
+
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
30 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
31 |
+
|
32 |
+
self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
|
33 |
+
|
34 |
+
self.resblocks_3d = torch.nn.Sequential()
|
35 |
+
for i in range(num_resblocks):
|
36 |
+
self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
|
37 |
+
|
38 |
+
def forward(self, source_image):
|
39 |
+
out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256
|
40 |
+
|
41 |
+
for i in range(len(self.down_blocks)):
|
42 |
+
out = self.down_blocks[i](out)
|
43 |
+
out = self.second(out)
|
44 |
+
bs, c, h, w = out.shape # ->Bx512x64x64
|
45 |
+
|
46 |
+
f_s = out.view(bs, self.reshape_channel, self.reshape_depth, h, w) # ->Bx32x16x64x64
|
47 |
+
f_s = self.resblocks_3d(f_s) # ->Bx32x16x64x64
|
48 |
+
return f_s
|
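A quick shape trace of the extractor above using the values from models.yaml (image_channel=3 is assumed, the other parameters are visible in the config); random weights, CPU only:

import torch
from liveportrait.modules.appearance_feature_extractor import AppearanceFeatureExtractor

f_net = AppearanceFeatureExtractor(
    image_channel=3, block_expansion=64, num_down_blocks=2, max_features=512,
    reshape_channel=32, reshape_depth=16, num_resblocks=6)

x = torch.randn(1, 3, 256, 256)            # a fake prepared source image
with torch.no_grad():
    f_s = f_net(x)
print(f_s.shape)                           # torch.Size([1, 32, 16, 64, 64]): the feature volume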
liveportrait/modules/convnextv2.py
ADDED
@@ -0,0 +1,149 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
This module adapts ConvNeXtV2 for the extraction of implicit keypoints, poses, and expression deformation.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
# from timm.models.layers import trunc_normal_, DropPath
|
10 |
+
from .util import LayerNorm, DropPath, trunc_normal_, GRN
|
11 |
+
|
12 |
+
__all__ = ['convnextv2_tiny']
|
13 |
+
|
14 |
+
|
15 |
+
class Block(nn.Module):
|
16 |
+
""" ConvNeXtV2 Block.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
dim (int): Number of input channels.
|
20 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, dim, drop_path=0.):
|
24 |
+
super().__init__()
|
25 |
+
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
|
26 |
+
self.norm = LayerNorm(dim, eps=1e-6)
|
27 |
+
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
|
28 |
+
self.act = nn.GELU()
|
29 |
+
self.grn = GRN(4 * dim)
|
30 |
+
self.pwconv2 = nn.Linear(4 * dim, dim)
|
31 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
input = x
|
35 |
+
x = self.dwconv(x)
|
36 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
37 |
+
x = self.norm(x)
|
38 |
+
x = self.pwconv1(x)
|
39 |
+
x = self.act(x)
|
40 |
+
x = self.grn(x)
|
41 |
+
x = self.pwconv2(x)
|
42 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
43 |
+
|
44 |
+
x = input + self.drop_path(x)
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
class ConvNeXtV2(nn.Module):
|
49 |
+
""" ConvNeXt V2
|
50 |
+
|
51 |
+
Args:
|
52 |
+
in_chans (int): Number of input image channels. Default: 3
|
53 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
54 |
+
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
|
55 |
+
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
|
56 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.
|
57 |
+
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
in_chans=3,
|
63 |
+
depths=[3, 3, 9, 3],
|
64 |
+
dims=[96, 192, 384, 768],
|
65 |
+
drop_path_rate=0.,
|
66 |
+
**kwargs
|
67 |
+
):
|
68 |
+
super().__init__()
|
69 |
+
self.depths = depths
|
70 |
+
self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
|
71 |
+
stem = nn.Sequential(
|
72 |
+
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
|
73 |
+
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
|
74 |
+
)
|
75 |
+
self.downsample_layers.append(stem)
|
76 |
+
for i in range(3):
|
77 |
+
downsample_layer = nn.Sequential(
|
78 |
+
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
79 |
+
nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
|
80 |
+
)
|
81 |
+
self.downsample_layers.append(downsample_layer)
|
82 |
+
|
83 |
+
self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
|
84 |
+
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
85 |
+
cur = 0
|
86 |
+
for i in range(4):
|
87 |
+
stage = nn.Sequential(
|
88 |
+
*[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
|
89 |
+
)
|
90 |
+
self.stages.append(stage)
|
91 |
+
cur += depths[i]
|
92 |
+
|
93 |
+
self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
|
94 |
+
|
95 |
+
# NOTE: the output semantic items
|
96 |
+
num_bins = kwargs.get('num_bins', 66)
|
97 |
+
num_kp = kwargs.get('num_kp', 24) # the number of implicit keypoints
|
98 |
+
self.fc_kp = nn.Linear(dims[-1], 3 * num_kp) # implicit keypoints
|
99 |
+
|
100 |
+
# print('dims[-1]: ', dims[-1])
|
101 |
+
self.fc_scale = nn.Linear(dims[-1], 1) # scale
|
102 |
+
self.fc_pitch = nn.Linear(dims[-1], num_bins) # pitch bins
|
103 |
+
self.fc_yaw = nn.Linear(dims[-1], num_bins) # yaw bins
|
104 |
+
self.fc_roll = nn.Linear(dims[-1], num_bins) # roll bins
|
105 |
+
self.fc_t = nn.Linear(dims[-1], 3) # translation
|
106 |
+
self.fc_exp = nn.Linear(dims[-1], 3 * num_kp) # expression / delta
|
107 |
+
|
108 |
+
def _init_weights(self, m):
|
109 |
+
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
110 |
+
trunc_normal_(m.weight, std=.02)
|
111 |
+
nn.init.constant_(m.bias, 0)
|
112 |
+
|
113 |
+
def forward_features(self, x):
|
114 |
+
for i in range(4):
|
115 |
+
x = self.downsample_layers[i](x)
|
116 |
+
x = self.stages[i](x)
|
117 |
+
return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
x = self.forward_features(x)
|
121 |
+
|
122 |
+
# implicit keypoints
|
123 |
+
kp = self.fc_kp(x)
|
124 |
+
|
125 |
+
# pose and expression deformation
|
126 |
+
pitch = self.fc_pitch(x)
|
127 |
+
yaw = self.fc_yaw(x)
|
128 |
+
roll = self.fc_roll(x)
|
129 |
+
t = self.fc_t(x)
|
130 |
+
exp = self.fc_exp(x)
|
131 |
+
scale = self.fc_scale(x)
|
132 |
+
|
133 |
+
ret_dct = {
|
134 |
+
'pitch': pitch,
|
135 |
+
'yaw': yaw,
|
136 |
+
'roll': roll,
|
137 |
+
't': t,
|
138 |
+
'exp': exp,
|
139 |
+
'scale': scale,
|
140 |
+
|
141 |
+
'kp': kp, # canonical keypoint
|
142 |
+
}
|
143 |
+
|
144 |
+
return ret_dct
|
145 |
+
|
146 |
+
|
147 |
+
def convnextv2_tiny(**kwargs):
|
148 |
+
model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
|
149 |
+
return model
|
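A short sanity-check sketch for the backbone above (random weights, CPU only); num_kp=21 mirrors models.yaml, while num_bins falls back to the code default of 66:

import torch
from liveportrait.modules.convnextv2 import convnextv2_tiny

net = convnextv2_tiny(num_kp=21)
with torch.no_grad():
    out = net(torch.randn(1, 3, 256, 256))
for k, v in out.items():
    print(k, tuple(v.shape))   # pitch/yaw/roll: (1, 66), t: (1, 3), exp/kp: (1, 63), scale: (1, 1)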
liveportrait/modules/dense_motion.py
ADDED
@@ -0,0 +1,104 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
The module that predicts a dense motion field from the sparse motion representation given by kp_source and kp_driving
|
5 |
+
"""
|
6 |
+
|
7 |
+
from torch import nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torch
|
10 |
+
from .util import Hourglass, make_coordinate_grid, kp2gaussian
|
11 |
+
|
12 |
+
|
13 |
+
class DenseMotionNetwork(nn.Module):
|
14 |
+
def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=True):
|
15 |
+
super(DenseMotionNetwork, self).__init__()
|
16 |
+
self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) # ~60+G
|
17 |
+
|
18 |
+
self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) # 65G! NOTE: computation cost is large
|
19 |
+
self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) # 0.8G
|
20 |
+
self.norm = nn.BatchNorm3d(compress, affine=True)
|
21 |
+
self.num_kp = num_kp
|
22 |
+
self.flag_estimate_occlusion_map = estimate_occlusion_map
|
23 |
+
|
24 |
+
if self.flag_estimate_occlusion_map:
|
25 |
+
self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
|
26 |
+
else:
|
27 |
+
self.occlusion = None
|
28 |
+
|
29 |
+
def create_sparse_motions(self, feature, kp_driving, kp_source):
|
30 |
+
bs, _, d, h, w = feature.shape # (bs, 4, 16, 64, 64)
|
31 |
+
identity_grid = make_coordinate_grid((d, h, w), ref=kp_source) # (16, 64, 64, 3)
|
32 |
+
identity_grid = identity_grid.view(1, 1, d, h, w, 3) # (1, 1, d=16, h=64, w=64, 3)
|
33 |
+
coordinate_grid = identity_grid - kp_driving.view(bs, self.num_kp, 1, 1, 1, 3)
|
34 |
+
|
35 |
+
k = coordinate_grid.shape[1]
|
36 |
+
|
37 |
+
# NOTE: a first-order flow term is missing here
|
38 |
+
driving_to_source = coordinate_grid + kp_source.view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3)
|
39 |
+
|
40 |
+
# adding background feature
|
41 |
+
identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
|
42 |
+
sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) # (bs, 1+num_kp, d, h, w, 3)
|
43 |
+
return sparse_motions
|
44 |
+
|
45 |
+
def create_deformed_feature(self, feature, sparse_motions):
|
46 |
+
bs, _, d, h, w = feature.shape
|
47 |
+
feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w)
|
48 |
+
feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w)
|
49 |
+
sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3)
|
50 |
+
sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False)
|
51 |
+
sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w)
|
52 |
+
|
53 |
+
return sparse_deformed
|
54 |
+
|
55 |
+
def create_heatmap_representations(self, feature, kp_driving, kp_source):
|
56 |
+
spatial_size = feature.shape[3:] # (d=16, h=64, w=64)
|
57 |
+
gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
|
58 |
+
gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
|
59 |
+
heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w)
|
60 |
+
|
61 |
+
# adding background feature
|
62 |
+
zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()).to(heatmap.device)
|
63 |
+
heatmap = torch.cat([zeros, heatmap], dim=1)
|
64 |
+
heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w)
|
65 |
+
return heatmap
|
66 |
+
|
67 |
+
def forward(self, feature, kp_driving, kp_source):
|
68 |
+
bs, _, d, h, w = feature.shape # (bs, 32, 16, 64, 64)
|
69 |
+
|
70 |
+
feature = self.compress(feature) # (bs, 4, 16, 64, 64)
|
71 |
+
feature = self.norm(feature) # (bs, 4, 16, 64, 64)
|
72 |
+
feature = F.relu(feature) # (bs, 4, 16, 64, 64)
|
73 |
+
|
74 |
+
out_dict = dict()
|
75 |
+
|
76 |
+
# 1. deform 3d feature
|
77 |
+
sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) # (bs, 1+num_kp, d, h, w, 3)
|
78 |
+
deformed_feature = self.create_deformed_feature(feature, sparse_motion) # (bs, 1+num_kp, c=4, d=16, h=64, w=64)
|
79 |
+
|
80 |
+
# 2. (bs, 1+num_kp, d, h, w)
|
81 |
+
heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) # (bs, 1+num_kp, 1, d, h, w)
|
82 |
+
|
83 |
+
input = torch.cat([heatmap, deformed_feature], dim=2) # (bs, 1+num_kp, c=5, d=16, h=64, w=64)
|
84 |
+
input = input.view(bs, -1, d, h, w) # (bs, (1+num_kp)*c=105, d=16, h=64, w=64)
|
85 |
+
|
86 |
+
prediction = self.hourglass(input)
|
87 |
+
|
88 |
+
mask = self.mask(prediction)
|
89 |
+
mask = F.softmax(mask, dim=1) # (bs, 1+num_kp, d=16, h=64, w=64)
|
90 |
+
out_dict['mask'] = mask
|
91 |
+
mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
|
92 |
+
sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w)
|
93 |
+
deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) mask take effect in this place
|
94 |
+
deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3)
|
95 |
+
|
96 |
+
out_dict['deformation'] = deformation
|
97 |
+
|
98 |
+
if self.flag_estimate_occlusion_map:
|
99 |
+
bs, _, d, h, w = prediction.shape
|
100 |
+
prediction_reshape = prediction.view(bs, -1, h, w)
|
101 |
+
occlusion_map = torch.sigmoid(self.occlusion(prediction_reshape)) # Bx1x64x64
|
102 |
+
out_dict['occlusion_map'] = occlusion_map
|
103 |
+
|
104 |
+
return out_dict
|
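The key step in forward() above is collapsing the (num_kp + 1) candidate flows with the predicted softmax mask. A toy sketch with synthetic tensors (not repository code):

import torch

bs, num_kp, d, h, w = 1, 21, 16, 64, 64
sparse_motion = torch.randn(bs, num_kp + 1, 3, d, h, w)           # background + per-keypoint flows
mask = torch.softmax(torch.randn(bs, num_kp + 1, d, h, w), dim=1)

deformation = (sparse_motion * mask.unsqueeze(2)).sum(dim=1)      # (bs, 3, d, h, w)
deformation = deformation.permute(0, 2, 3, 4, 1)                  # (bs, d, h, w, 3), ready for grid_sample
print(deformation.shape)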
liveportrait/modules/motion_extractor.py
ADDED
@@ -0,0 +1,35 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
Motion extractor(M), which directly predicts the canonical keypoints, head pose and expression deformation of the input image
|
5 |
+
"""
|
6 |
+
|
7 |
+
from torch import nn
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from .convnextv2 import convnextv2_tiny
|
11 |
+
from .util import filter_state_dict
|
12 |
+
|
13 |
+
model_dict = {
|
14 |
+
'convnextv2_tiny': convnextv2_tiny,
|
15 |
+
}
|
16 |
+
|
17 |
+
|
18 |
+
class MotionExtractor(nn.Module):
|
19 |
+
def __init__(self, **kwargs):
|
20 |
+
super(MotionExtractor, self).__init__()
|
21 |
+
|
22 |
+
# default is convnextv2_base
|
23 |
+
backbone = kwargs.get('backbone', 'convnextv2_tiny')
|
24 |
+
self.detector = model_dict.get(backbone)(**kwargs)
|
25 |
+
|
26 |
+
def load_pretrained(self, init_path: str):
|
27 |
+
if init_path not in (None, ''):
|
28 |
+
state_dict = torch.load(init_path, map_location=lambda storage, loc: storage)['model']
|
29 |
+
state_dict = filter_state_dict(state_dict, remove_name='head')
|
30 |
+
ret = self.detector.load_state_dict(state_dict, strict=False)
|
31 |
+
print(f'Load pretrained model from {init_path}, ret: {ret}')
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
out = self.detector(x)
|
35 |
+
return out
|
liveportrait/modules/spade_generator.py
ADDED
@@ -0,0 +1,59 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
SPADE decoder (G) defined in the paper, which takes the warped feature as input and generates the animated image.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from .util import SPADEResnetBlock
|
11 |
+
|
12 |
+
|
13 |
+
class SPADEDecoder(nn.Module):
|
14 |
+
def __init__(self, upscale=1, max_features=256, block_expansion=64, out_channels=64, num_down_blocks=2):
|
15 |
+
for i in range(num_down_blocks):
|
16 |
+
input_channels = min(max_features, block_expansion * (2 ** (i + 1)))
|
17 |
+
self.upscale = upscale
|
18 |
+
super().__init__()
|
19 |
+
norm_G = 'spadespectralinstance'
|
20 |
+
label_num_channels = input_channels # 256
|
21 |
+
|
22 |
+
self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1)
|
23 |
+
self.G_middle_0 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
|
24 |
+
self.G_middle_1 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
|
25 |
+
self.G_middle_2 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
|
26 |
+
self.G_middle_3 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
|
27 |
+
self.G_middle_4 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
|
28 |
+
self.G_middle_5 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
|
29 |
+
self.up_0 = SPADEResnetBlock(2 * input_channels, input_channels, norm_G, label_num_channels)
|
30 |
+
self.up_1 = SPADEResnetBlock(input_channels, out_channels, norm_G, label_num_channels)
|
31 |
+
self.up = nn.Upsample(scale_factor=2)
|
32 |
+
|
33 |
+
if self.upscale is None or self.upscale <= 1:
|
34 |
+
self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1)
|
35 |
+
else:
|
36 |
+
self.conv_img = nn.Sequential(
|
37 |
+
nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
|
38 |
+
nn.PixelShuffle(upscale_factor=2)
|
39 |
+
)
|
40 |
+
|
41 |
+
def forward(self, feature):
|
42 |
+
seg = feature # Bx256x64x64
|
43 |
+
x = self.fc(feature) # Bx512x64x64
|
44 |
+
x = self.G_middle_0(x, seg)
|
45 |
+
x = self.G_middle_1(x, seg)
|
46 |
+
x = self.G_middle_2(x, seg)
|
47 |
+
x = self.G_middle_3(x, seg)
|
48 |
+
x = self.G_middle_4(x, seg)
|
49 |
+
x = self.G_middle_5(x, seg)
|
50 |
+
|
51 |
+
x = self.up(x) # Bx512x64x64 -> Bx512x128x128
|
52 |
+
x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128
|
53 |
+
x = self.up(x) # Bx256x128x128 -> Bx256x256x256
|
54 |
+
x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256
|
55 |
+
|
56 |
+
x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW
|
57 |
+
x = torch.sigmoid(x) # Bx3xHxW
|
58 |
+
|
59 |
+
return x
|
liveportrait/modules/stitching_retargeting_network.py
ADDED
@@ -0,0 +1,38 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
Stitching module(S) and two retargeting modules(R) defined in the paper.
|
5 |
+
|
6 |
+
- The stitching module pastes the animated portrait back into the original image space without pixel misalignment, such as in
|
7 |
+
the stitching region.
|
8 |
+
|
9 |
+
- The eyes retargeting module is designed to address the issue of incomplete eye closure during cross-id reenactment, especially
|
10 |
+
when a person with small eyes drives a person with larger eyes.
|
11 |
+
|
12 |
+
- The lip retargeting module is designed similarly to the eye retargeting module, and can also normalize the input by ensuring that
|
13 |
+
the lips are in a closed state, which facilitates better animation driving.
|
14 |
+
"""
|
15 |
+
from torch import nn
|
16 |
+
|
17 |
+
|
18 |
+
class StitchingRetargetingNetwork(nn.Module):
|
19 |
+
def __init__(self, input_size, hidden_sizes, output_size):
|
20 |
+
super(StitchingRetargetingNetwork, self).__init__()
|
21 |
+
layers = []
|
22 |
+
for i in range(len(hidden_sizes)):
|
23 |
+
if i == 0:
|
24 |
+
layers.append(nn.Linear(input_size, hidden_sizes[i]))
|
25 |
+
else:
|
26 |
+
layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
|
27 |
+
layers.append(nn.ReLU(inplace=True))
|
28 |
+
layers.append(nn.Linear(hidden_sizes[-1], output_size))
|
29 |
+
self.mlp = nn.Sequential(*layers)
|
30 |
+
|
31 |
+
def initialize_weights_to_zero(self):
|
32 |
+
for m in self.modules():
|
33 |
+
if isinstance(m, nn.Linear):
|
34 |
+
nn.init.zeros_(m.weight)
|
35 |
+
nn.init.zeros_(m.bias)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return self.mlp(x)
|
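A hedged instantiation example for the three MLPs above, with the sizes copied from models.yaml (random weights, no checkpoint loading):

import torch
from liveportrait.modules.stitching_retargeting_network import StitchingRetargetingNetwork

stitching = StitchingRetargetingNetwork(input_size=126, hidden_sizes=[128, 128, 64], output_size=65)
lip = StitchingRetargetingNetwork(input_size=65, hidden_sizes=[128, 128, 64], output_size=63)
eye = StitchingRetargetingNetwork(input_size=66, hidden_sizes=[256, 256, 128, 128, 64], output_size=63)

feat = torch.randn(1, 126)         # concatenated source + driving keypoints
print(stitching(feat).shape)       # torch.Size([1, 65]): 63 keypoint deltas + (tx, ty)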
liveportrait/modules/util.py
ADDED
@@ -0,0 +1,441 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
This file defines various neural network modules and utility functions, including convolutional and residual blocks,
|
5 |
+
normalizations, and functions for spatial transformation and tensor manipulation.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from torch import nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch
|
11 |
+
import torch.nn.utils.spectral_norm as spectral_norm
|
12 |
+
import math
|
13 |
+
import warnings
|
14 |
+
|
15 |
+
|
16 |
+
def kp2gaussian(kp, spatial_size, kp_variance):
|
17 |
+
"""
|
18 |
+
Transform a keypoint into gaussian like representation
|
19 |
+
"""
|
20 |
+
mean = kp
|
21 |
+
|
22 |
+
coordinate_grid = make_coordinate_grid(spatial_size, mean)
|
23 |
+
number_of_leading_dimensions = len(mean.shape) - 1
|
24 |
+
shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
|
25 |
+
coordinate_grid = coordinate_grid.view(*shape)
|
26 |
+
repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
|
27 |
+
coordinate_grid = coordinate_grid.repeat(*repeats)
|
28 |
+
|
29 |
+
# Preprocess kp shape
|
30 |
+
shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
|
31 |
+
mean = mean.view(*shape)
|
32 |
+
|
33 |
+
mean_sub = (coordinate_grid - mean)
|
34 |
+
|
35 |
+
out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
|
36 |
+
|
37 |
+
return out
|
38 |
+
|
39 |
+
|
40 |
+
def make_coordinate_grid(spatial_size, ref, **kwargs):
|
41 |
+
d, h, w = spatial_size
|
42 |
+
x = torch.arange(w).type(ref.dtype).to(ref.device)
|
43 |
+
y = torch.arange(h).type(ref.dtype).to(ref.device)
|
44 |
+
z = torch.arange(d).type(ref.dtype).to(ref.device)
|
45 |
+
|
46 |
+
# NOTE: must be right-down-in
|
47 |
+
x = (2 * (x / (w - 1)) - 1) # the x axis faces to the right
|
48 |
+
y = (2 * (y / (h - 1)) - 1) # the y axis faces to the bottom
|
49 |
+
z = (2 * (z / (d - 1)) - 1) # the z axis faces to the inner
|
50 |
+
|
51 |
+
yy = y.view(1, -1, 1).repeat(d, 1, w)
|
52 |
+
xx = x.view(1, 1, -1).repeat(d, h, 1)
|
53 |
+
zz = z.view(-1, 1, 1).repeat(1, h, w)
|
54 |
+
|
55 |
+
meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
|
56 |
+
|
57 |
+
return meshed
|
58 |
+
|
59 |
+
|
60 |
+
class ConvT2d(nn.Module):
    """
    Upsampling block for use in decoder.
    """

    def __init__(self, in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1):
        super(ConvT2d, self).__init__()

        self.convT = nn.ConvTranspose2d(in_features, out_features, kernel_size=kernel_size, stride=stride,
                                        padding=padding, output_padding=output_padding)
        self.norm = nn.InstanceNorm2d(out_features)

    def forward(self, x):
        out = self.convT(x)
        out = self.norm(out)
        out = F.leaky_relu(out)
        return out


class ResBlock3d(nn.Module):
    """
    Res block, preserve spatial resolution.
    """

    def __init__(self, in_features, kernel_size, padding):
        super(ResBlock3d, self).__init__()
        self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
        self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
        self.norm1 = nn.BatchNorm3d(in_features, affine=True)
        self.norm2 = nn.BatchNorm3d(in_features, affine=True)

    def forward(self, x):
        out = self.norm1(x)
        out = F.relu(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = F.relu(out)
        out = self.conv2(out)
        out += x
        return out


class UpBlock3d(nn.Module):
    """
    Upsampling block for use in decoder.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super(UpBlock3d, self).__init__()

        self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        self.norm = nn.BatchNorm3d(out_features, affine=True)

    def forward(self, x):
        out = F.interpolate(x, scale_factor=(1, 2, 2))
        out = self.conv(out)
        out = self.norm(out)
        out = F.relu(out)
        return out


class DownBlock2d(nn.Module):
    """
    Downsampling block for use in encoder.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super(DownBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
        self.norm = nn.BatchNorm2d(out_features, affine=True)
        self.pool = nn.AvgPool2d(kernel_size=(2, 2))

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = F.relu(out)
        out = self.pool(out)
        return out


class DownBlock3d(nn.Module):
    """
    Downsampling block for use in encoder.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super(DownBlock3d, self).__init__()
        '''
        self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups, stride=(1, 2, 2))
        '''
        self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        self.norm = nn.BatchNorm3d(out_features, affine=True)
        self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = F.relu(out)
        out = self.pool(out)
        return out


class SameBlock2d(nn.Module):
    """
    Simple block, preserve spatial resolution.
    """

    def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
        super(SameBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
        self.norm = nn.BatchNorm2d(out_features, affine=True)
        if lrelu:
            self.ac = nn.LeakyReLU()
        else:
            self.ac = nn.ReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = self.ac(out)
        return out


class Encoder(nn.Module):
    """
    Hourglass Encoder
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Encoder, self).__init__()

        down_blocks = []
        for i in range(num_blocks):
            down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1))
        self.down_blocks = nn.ModuleList(down_blocks)

    def forward(self, x):
        outs = [x]
        for down_block in self.down_blocks:
            outs.append(down_block(outs[-1]))
        return outs


class Decoder(nn.Module):
    """
    Hourglass Decoder
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Decoder, self).__init__()

        up_blocks = []

        for i in range(num_blocks)[::-1]:
            in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
            out_filters = min(max_features, block_expansion * (2 ** i))
            up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))

        self.up_blocks = nn.ModuleList(up_blocks)
        self.out_filters = block_expansion + in_features

        self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
        self.norm = nn.BatchNorm3d(self.out_filters, affine=True)

    def forward(self, x):
        out = x.pop()
        for up_block in self.up_blocks:
            out = up_block(out)
            skip = x.pop()
            out = torch.cat([out, skip], dim=1)
        out = self.conv(out)
        out = self.norm(out)
        out = F.relu(out)
        return out


class Hourglass(nn.Module):
    """
    Hourglass architecture.
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Hourglass, self).__init__()
        self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
        self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
        self.out_filters = self.decoder.out_filters

    def forward(self, x):
        return self.decoder(self.encoder(x))


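# --- Illustrative usage sketch (added for clarity; not part of the original file). It
# checks the Hourglass channel bookkeeping: with 3 blocks, H and W must be divisible by
# 2**3, and the output has out_filters = block_expansion + in_features channels. The
# helper name and sizes are examples only.
def _demo_hourglass():
    net = Hourglass(block_expansion=32, in_features=4, num_blocks=3, max_features=256)
    feat = torch.randn(1, 4, 16, 64, 64)                   # (bs, C, D, H, W) feature volume
    out = net(feat)
    assert net.out_filters == 32 + 4
    assert out.shape == (1, net.out_filters, 16, 64, 64)   # same D/H/W thanks to the skip connections

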
class SPADE(nn.Module):
    """
    Spatially-adaptive normalization: predicts per-pixel scale (gamma) and bias (beta)
    from a modulation map and applies them after a parameter-free InstanceNorm.
    """

    def __init__(self, norm_nc, label_nc):
        super().__init__()

        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        nhidden = 128

        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
            nn.ReLU())
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)

    def forward(self, x, segmap):
        normalized = self.param_free_norm(x)
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)
        out = normalized * (1 + gamma) + beta
        return out


class SPADEResnetBlock(nn.Module):
    """
    Residual block whose normalization layers are SPADE-modulated.
    """

    def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
        super().__init__()
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)
        self.use_se = use_se
        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
        # apply spectral norm if specified
        if 'spectral' in norm_G:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)
        # define normalization layers
        self.norm_0 = SPADE(fin, label_nc)
        self.norm_1 = SPADE(fmiddle, label_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(fin, label_nc)

    def forward(self, x, seg1):
        x_s = self.shortcut(x, seg1)
        dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
        out = x_s + dx
        return out

    def shortcut(self, x, seg1):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg1))
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)


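# --- Illustrative usage sketch (added for clarity; not part of the original file). It shows
# how a SPADEResnetBlock modulates features with a spatial map; the norm_G string only needs
# to contain 'spectral' to enable spectral norm. Names and sizes below are examples only.
def _demo_spade_block():
    block = SPADEResnetBlock(fin=64, fout=64, norm_G='spadespectralinstance', label_nc=32)
    x = torch.randn(2, 64, 32, 32)         # features to be modulated
    seg = torch.randn(2, 32, 64, 64)       # modulation map; SPADE resizes it to x's spatial size
    out = block(x, seg)
    assert out.shape == x.shape            # residual block preserves shape when fin == fout

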
def filter_state_dict(state_dict, remove_name='fc'):
    """
    Return a copy of state_dict with every entry whose key contains remove_name dropped.
    """
    new_state_dict = {}
    for key in state_dict:
        if remove_name in key:
            continue
        new_state_dict[key] = state_dict[key]
    return new_state_dict


class GRN(nn.Module):
    """ GRN (Global Response Normalization) layer
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    data_format specifies the ordering of the dimensions in the inputs: channels_last corresponds
    to inputs with shape (batch_size, height, width, channels), while channels_first corresponds
    to inputs with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


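# --- Illustrative usage sketch (added for clarity; not part of the original file). It
# contrasts the expected layouts: this LayerNorm can normalize (N, C, H, W) tensors directly
# with data_format="channels_first", while GRN expects channels-last (N, H, W, C) input.
def _demo_layernorm_grn():
    ln = LayerNorm(64, data_format="channels_first")
    feat = torch.randn(2, 64, 32, 32)                  # (N, C, H, W)
    normed = ln(feat)
    grn = GRN(dim=64)
    tokens = torch.randn(2, 32, 32, 64)                # (N, H, W, C)
    out = grn(tokens)
    assert normed.shape == feat.shape and out.shape == tokens.shape

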
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def drop_path(x, drop_prob=0., training=False, scale_by_keep=True):
    """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
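

# --- Illustrative usage sketch (added for clarity; not part of the original file). It shows
# stochastic depth dropping whole samples during training and the truncated-normal init
# clamping values to [a, b]; the helper name and numbers are examples only.
def _demo_droppath_trunc_normal():
    dp = DropPath(drop_prob=0.1)
    dp.train()
    x = torch.randn(8, 64)
    y = dp(x)                          # ~10% of the 8 samples are zeroed, the rest rescaled by 1/0.9
    assert y.shape == x.shape

    w = torch.empty(64, 64)
    trunc_normal_(w, std=0.02)         # in-place init; values are clamped to the default range [-2, 2]
    assert w.abs().max() <= 2.0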