# Ubiquant_CharacterHunter / upload_huggingface.py
# Uploaded by Facepalm0 — commit 93313ac (verified):
# "Upload upload_huggingface.py with huggingface_hub" (3.62 kB)
from huggingface_hub import HfApi
import os
from pathlib import Path
from dotenv import load_dotenv
import time
from requests.exceptions import ConnectionError, HTTPError
from pathspec import PathSpec
from pathspec.patterns import GitWildMatchPattern
def get_gitignore_patterns(directory: str = "."):
    """Read the rules from a ``.gitignore`` file.

    Args:
        directory: Directory that contains the ``.gitignore`` file. Defaults
            to ``"."`` (the current working directory), which is where the
            file was always looked up before this parameter existed.

    Returns:
        The file's lines as a list of strings, unfiltered (blank lines and
        ``#`` comment lines are kept). Returns an empty list when there is no
        regular ``.gitignore`` file in ``directory``.
    """
    gitignore_file = Path(directory) / ".gitignore"
    # is_file() rather than exists(): a directory named ".gitignore" would
    # otherwise make the read below fail.
    if not gitignore_file.is_file():
        return []
    return gitignore_file.read_text(encoding="utf-8").splitlines()
def upload_to_huggingface(
local_directory: str,
repo_id: str,
max_retries: int = 3,
retry_delay: int = 5
):
"""
上传整个目录到 Hugging Face
Args:
local_directory: 本地项目目录路径
repo_id: Hugging Face 仓库ID (格式: username/repo_name)
max_retries: 最大重试次数
retry_delay: 重试间隔(秒)
"""
# 加载环境变量
load_dotenv()
token = os.getenv("HUGGINGFACE_TOKEN")
if not token:
raise ValueError("请在 .env 文件中设置 HUGGINGFACE_TOKEN")
# 读取 .gitignore 规则
gitignore_patterns = get_gitignore_patterns()
# 添加一些额外的忽略规则
additional_patterns = [
"VectorScience - 方案介绍 - 策略部分.docx",
".git/",
"__pycache/",
"*.pyc",
".env",
".gitignore",
".DS_Store",
"ewv9ssdcuvg6",
".docx",
"wandb/",
"jk_zfls/"
]
gitignore_patterns.extend(additional_patterns)
# 创建 PathSpec 对象来匹配文件
spec = PathSpec.from_lines(GitWildMatchPattern, gitignore_patterns)
# 初始化 Hugging Face API
api = HfApi()
# 获取所有要上传的文件
files_to_upload = []
for root, _, files in os.walk(local_directory):
for file in files:
file_path = Path(root) / file
relative_path = file_path.relative_to(local_directory)
# 使用 PathSpec 检查文件是否应该被忽略
if not spec.match_file(str(relative_path)):
files_to_upload.append((str(file_path), str(relative_path)))
# 上传文件
for local_path, path_in_repo in files_to_upload:
for attempt in range(max_retries):
try:
print(f"正在上传: {local_path} -> {path_in_repo}")
api.upload_file(
path_or_fileobj=local_path,
path_in_repo=path_in_repo,
repo_id=repo_id,
token=token
)
print(f"成功上传: {path_in_repo}")
break
except (ConnectionError, HTTPError) as e:
if attempt < max_retries - 1:
print(f"上传失败,{retry_delay}秒后重试... ({attempt + 1}/{max_retries})")
print(f"错误信息: {str(e)}")
time.sleep(retry_delay)
else:
print(f"上传失败: {path_in_repo}")
print(f"错误信息: {str(e)}")
raise
if __name__ == "__main__":
    # Route HTTPS traffic through a local proxy, but do not clobber a proxy
    # the user has already configured in the environment.
    os.environ.setdefault("HTTPS_PROXY", "http://127.0.0.1:17890")
    # Example usage
    upload_to_huggingface(
        local_directory=".",  # current directory
        repo_id="Facepalm0/Ubiquant_CharacterHunter",  # replace with your repo ID
        max_retries=3,  # at most 3 upload attempts per file
        retry_delay=5,  # wait 5 seconds between attempts
    )