minchul's picture
Upload directory
f4a7259 verified
|
raw
history blame
2.35 kB
metadata
language: en
license: mit
arxiv: 1801.07698

CVLFace Pretrained Model (ARCFACE IR101 WEBFACE4M)

🌎 GitHub • 🤗 Hugging Face


1. Introduction

Model Name: ARCFACE IR101 WEBFACE4M

Related Paper: ArcFace: Additive Angular Margin Loss for Deep Face Recognition (https://arxiv.org/abs/1801.07698)

Please cite the original paper and comply with the license of the training dataset.

2. Quick Start

from transformers import AutoModel
from huggingface_hub import hf_hub_download
import shutil
import os
import torch


# helper function to download a Hugging Face model repo into a local directory
def download(repo_id, path, HF_TOKEN=None):
    """Fetch the files of a Hugging Face repo into ``path``, skipping ones already present.

    The repo is expected to provide a ``files.txt`` manifest listing files to
    fetch, one per line. The manifest itself is downloaded first if it is not
    already in ``path``. Then every file it lists, plus ``config.json``,
    ``wrapper.py`` and ``model.safetensors``, is downloaded unless a file of
    that name already exists under ``path``.

    Args:
        repo_id: Hugging Face repo id, e.g. ``'minchul/cvlface_arcface_ir101_webface4m'``.
        path: Local directory to download into.
        HF_TOKEN: Optional Hugging Face access token, forwarded to ``hf_hub_download``.
    """
    files_path = os.path.join(path, 'files.txt')
    # NOTE(review): newer huggingface_hub releases deprecate ``local_dir_use_symlinks``
    # (files under local_dir are written as real files); kept for older versions.
    if not os.path.exists(files_path):
        hf_hub_download(repo_id, 'files.txt', token=HF_TOKEN, local_dir=path, local_dir_use_symlinks=False)
    with open(files_path, 'r', encoding='utf-8') as f:
        # Strip stray whitespace so a padded manifest line does not turn into a
        # filename that never exists locally (and would be re-fetched every run).
        listed = [line.strip() for line in f.read().splitlines()]
    required = ['config.json', 'wrapper.py', 'model.safetensors']
    # dict.fromkeys drops blank entries' duplicates and repeated names while
    # preserving order, so a file both listed and required is checked once.
    for file in dict.fromkeys([name for name in listed if name] + required):
        if not os.path.exists(os.path.join(path, file)):
            hf_hub_download(repo_id, file, token=HF_TOKEN, local_dir=path, local_dir_use_symlinks=False)

            
# helper function to load a model from a locally downloaded repo
def load_model_from_local_path(path, HF_TOKEN=None):
    """Load a ``trust_remote_code`` model from a local directory.

    The working directory is switched to ``path`` while loading. Presumably
    this is so the repo's custom code can resolve files relative to the model
    directory (TODO: confirm against wrapper.py). The original working
    directory is always restored afterwards, even if loading raises.

    Args:
        path: Local directory containing the downloaded repo files; may be relative.
        HF_TOKEN: Optional Hugging Face access token, forwarded to ``from_pretrained``.

    Returns:
        The model returned by ``AutoModel.from_pretrained``.
    """
    # Resolve before chdir: otherwise a relative ``path`` would be looked up
    # relative to itself once we are inside it (path/path), and transformers
    # would fall back to treating it as a Hub repo id.
    path = os.path.abspath(path)
    cwd = os.getcwd()
    os.chdir(path)
    try:
        return AutoModel.from_pretrained(path, trust_remote_code=True, token=HF_TOKEN)
    finally:
        os.chdir(cwd)


# helper function to download a Hugging Face repo (if needed) and load the model from it
def load_model_by_repo_id(repo_id, save_path, HF_TOKEN=None, force_download=False):
    """Download ``repo_id`` into ``save_path`` if needed, then load the model from there.

    Args:
        repo_id: Hugging Face repo id to fetch.
        save_path: Local directory that holds the downloaded repo files.
        HF_TOKEN: Optional Hugging Face access token, forwarded to both steps.
        force_download: If True, delete ``save_path`` first so that every
            file is considered missing and downloaded again.

    Returns:
        The model produced by ``load_model_from_local_path``.
    """
    # Wiping the directory makes download() treat every file as missing.
    if force_download and os.path.exists(save_path):
        shutil.rmtree(save_path)
    download(repo_id, save_path, HF_TOKEN)
    model = load_model_from_local_path(save_path, HF_TOKEN)
    return model


if __name__ == '__main__':
    HF_TOKEN = 'YOUR_HUGGINGFACE_TOKEN'
    # Use an absolute path: the loader temporarily changes directory into the
    # model folder, so a relative path would no longer point at it.
    path = os.path.abspath('path/to/store/model/locally')
    repo_id = 'minchul/cvlface_arcface_ir101_webface4m'
    model = load_model_by_repo_id(repo_id, path, HF_TOKEN)
    # Dummy input of shape (1, 3, 112, 112): presumably one RGB 112x112 face
    # crop in NCHW layout (random values, smoke test only). Named ``x`` to
    # avoid shadowing the builtin ``input``.
    x = torch.randn(1, 3, 112, 112)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        out = model(x)