"""Gradio app serving a fine-tuned ViT image classifier ("NeckLine Classifier").

The model weights (``model.safetensors``), the model configuration
(``config.json``) and the preprocessor configuration
(``preprocessor_config.json``) are all read from ``model_dir``. Loading
happens at import time, so a missing artifact fails before the UI starts.
"""
import logging
import os

import gradio as gr
import torch
from PIL import Image
from safetensors.torch import load_file  # Import safetensors loading function
from transformers import AutoConfig, ViTFeatureExtractor, ViTForImageClassification

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# Directory containing the model artifacts (the current working directory).
model_dir = "."

# Paths to the specific model files.
model_path = os.path.join(model_dir, "model.safetensors")
config_path = os.path.join(model_dir, "config.json")
preprocessor_path = os.path.join(model_dir, "preprocessor_config.json")

# Fail fast at startup if any required artifact is missing.
for path in [model_path, config_path, preprocessor_path]:
    if not os.path.exists(path):
        logging.error(f"File not found: {path}")
        raise FileNotFoundError(f"Required file not found: {path}")
    logging.info(f"Found file: {path}")

# Load the model configuration (label mapping, architecture hyper-parameters).
config = AutoConfig.from_pretrained(config_path)

# Build the label list ordered by class id so that labels[i] names logit i.
# Iterating id2label.values() would instead follow the JSON key order, which
# is not guaranteed to be sorted by id. A config whose ids are not 0..n-1
# raises KeyError here, at startup, rather than mislabelling predictions.
labels = [config.id2label[i] for i in range(config.num_labels)]
logging.info(f"Labels: {labels}")

# Load the preprocessing settings (resize / normalisation) for the ViT input.
feature_extractor = ViTFeatureExtractor.from_pretrained(preprocessor_path)

# Load the weights from the safetensors file and build the model from the
# explicit config + state dict (no model name/path is passed).
state_dict = load_file(model_path)
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path=None,
    config=config,
    state_dict=state_dict,
)

# Ensure the model is in evaluation mode (disables dropout).
model.eval()
logging.info("Model set to evaluation mode")


def predict(image):
    """Classify a single image and return per-label probabilities.

    Args:
        image: A ``PIL.Image.Image`` as supplied by ``gr.Image(type="pil")``,
            or ``None`` when the user submits without uploading an image.

    Returns:
        dict mapping every label name (from ``config.id2label``) to its
        softmax probability as a Python ``float``.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    logging.info("Starting prediction")
    logging.info("Input image size (width, height): %s", image.size)

    # Force three channels so RGBA, greyscale or palette uploads match the
    # RGB input the ViT preprocessing expects.
    image = image.convert("RGB")

    # Preprocess the image into a batch of one pixel_values tensor.
    logging.info("Preprocessing image")
    inputs = feature_extractor(images=image, return_tensors="pt")
    logging.info("Preprocessed input shape: %s", inputs["pixel_values"].shape)

    logging.info("Running inference")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Batch size is 1, so take row 0 and normalise over the class dimension.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    logging.info("Raw logits: %s", logits)
    logging.info("Probabilities: %s", probabilities)

    # labels is ordered by class id, so position i pairs with logit i.
    result = {
        label: float(prob)
        for label, prob in zip(labels, probabilities.tolist())
    }
    logging.info("Prediction result: %s", result)
    return result


# Set up the Gradio interface.
logging.info("Setting up Gradio interface")
gradio_app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=6),
    title="NeckLine Classifier",
)

# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    logging.info("Launching the app")
    gradio_app.launch()