# gallary2 / indexer.py
# Page-header residue from the file listing: OzoneAsai's picture,
# commit message "Upload 7 files", commit b4b645c (verified).
import os
import pandas as pd
from PIL import Image, UnidentifiedImageError
import torch
from torchvision import transforms
from transformers import AutoProcessor, FocalNetForImageClassification
import pyarrow as pa
import pyarrow.parquet as pq
# Image folder to scan and the model identifier used for classification.
image_folder = "scraped_images"  # path of the image folder
model_path = "MichalMlodawski/nsfw-image-detection-large"  # NSFW model path

# Recursively collect every .jpg file (case-insensitive), sub-folders included.
jpg_files = []
for root, _dirs, files in os.walk(image_folder):
    for file in files:
        if file.lower().endswith(".jpg"):
            jpg_files.append(os.path.join(root, file))

# Stop early when there is nothing to classify.
if not jpg_files:
    print("No jpg files found in folder:", image_folder)
    # `exit()` is a helper added by the `site` module for interactive use and
    # is not guaranteed to exist (e.g. under `python -S`). Raising SystemExit
    # terminates the script with the same (zero) exit status.
    raise SystemExit
# Load the processor and the classifier from the same checkpoint; the
# classifier is put into evaluation mode right away (eval() returns the module).
feature_extractor = AutoProcessor.from_pretrained(model_path)
model = FocalNetForImageClassification.from_pretrained(model_path).eval()
# Image transformation pipeline: resize to 512x512, convert to a tensor,
# then normalize each channel with the fixed mean/std values below.
transform = transforms.Compose(
    [
        transforms.Resize(size=(512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Mapping from the model's label names to human-readable NSFW categories.
label_to_category = dict(
    zip(
        ("LABEL_0", "LABEL_1", "LABEL_2"),
        ("Safe", "Questionable", "Unsafe"),
    )
)
# Classification records, one dict per successfully processed image.
results = []
# Log file that records the images which had to be skipped.
error_log = "error_log.txt"

# Classify every image and collect the prediction for each one.
for jpg_file in jpg_files:
    skip_message = None
    try:
        # Pillow decodes lazily: open() only identifies the file, convert()
        # forces the actual decode. The context manager closes the file
        # handle once the RGB copy exists.
        with Image.open(jpg_file) as raw_image:
            image = raw_image.convert("RGB")
    except UnidentifiedImageError:
        # The file could not be identified as an image at all.
        skip_message = f"Unidentified image file: {jpg_file}. Skipping..."
    except OSError as err:
        # Truncated, corrupt or unreadable files fail with a plain OSError;
        # skip them instead of aborting the run and losing earlier results.
        skip_message = f"Unreadable image file: {jpg_file} ({err}). Skipping..."

    if skip_message is not None:
        with open(error_log, "a", encoding="utf-8") as log_file:
            log_file.write(f"{skip_message}\n")
        print(skip_message)
        continue

    # Run inference; only the processor output is fed to the model.
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    confidence, predicted = torch.max(probabilities, 1)

    # Translate the predicted class index into its label and category.
    label = model.config.id2label[predicted.item()]
    category = label_to_category.get(label, "Unknown")

    # Store the record; the confidence is kept as a percentage (0-100).
    results.append({
        "file_path": jpg_file,
        "label": label,
        "category": category,
        "confidence": confidence.item() * 100,
    })
# Destination of the Parquet export.
parquet_file = "nsfw_classification_results.parquet"

# Turn the collected records into a DataFrame, then into an Arrow table,
# and write that table to disk in Parquet format.
df = pd.DataFrame(data=results)
table = pa.Table.from_pandas(df)
pq.write_table(table, where=parquet_file)

print(f"Classification completed and saved to {parquet_file}!")