import base64
import os
import pickle
from io import BytesIO

import cv2
import faiss
import numpy as np
import pandas as pd
import requests
import torch
from geopy.distance import geodesic
from PIL import Image
from requests.exceptions import RequestException
from tqdm import tqdm
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPModel, CLIPProcessor

os.environ["CUDA_VISIBLE_DEVICES"] = "0"


class GPT4o:
    """
    Geo-localize images with the OpenAI chat-completions API, optionally
    grounded by StreetCLIP nearest-neighbour retrieval.

    Typical flow: call ``set_image_app`` to embed and Base64-encode an image
    (and optionally retrieve reference locations from the Faiss index), then
    call ``get_location`` to ask the model for coordinates.
    """

    OPENAI_CHAT_URL = "https://api.openai.com/v1/chat/completions"

    def __init__(self, device="cpu",
                 embeddings_path="StreetCLIP_1m_merged.pkl",
                 index_path="StreetCLIP_1m_merged.bin",
                 model_name="geolocal/StreetCLIP") -> None:
        """
        Load the CLIP model/processor, the location metadata and the Faiss index.

        Args:
            device (str, optional): Torch device for the CLIP model. Defaults to "cpu".
            embeddings_path (str, optional): Path to the pickled dict whose values
                carry a 'location' entry. Defaults to 'StreetCLIP_1m_merged.pkl'.
            index_path (str, optional): Path to the Faiss index file.
                Defaults to 'StreetCLIP_1m_merged.bin'.
            model_name (str, optional): Hugging Face id of the CLIP checkpoint.
                Defaults to 'geolocal/StreetCLIP'.
        """
        self.base64_image = None
        self.img_emb = None
        # Reference locations from the most recent database search; None until
        # set_image_app runs with use_database_search=True for the current image.
        self.neighbor_locations = None
        self.farthest_locations = None

        self.device = torch.device(device)

        # Load the CLIP model (inference mode) and its matching processor.
        self.model = CLIPModel.from_pretrained(model_name).eval()
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model.to(self.device)

        # NOTE(review): pickle.load can execute arbitrary code; only point this
        # at trusted files.
        with open(embeddings_path, 'rb') as f:
            self.Embeddings = pickle.load(f)
        # Presumably row i of the Faiss index corresponds to the i-th entry of
        # this dict (insertion order) — TODO confirm against the index builder.
        self.locations = [value['location'] for value in self.Embeddings.values()]

        # Despite the attribute name, the index is used as loaded; it is not
        # moved to a GPU here.
        self.gpu_index = faiss.read_index(index_path)

    def read_image(self, image_path):
        """
        Read an image file into an RGB numpy array.

        Args:
            image_path (str): The path to the image file.

        Returns:
            np.ndarray: The image in RGB channel order.

        Raises:
            ValueError: If OpenCV cannot read the file.
        """
        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise ValueError(f"Could not read image: {image_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def search_neighbors(self, faiss_index, k_nearest, k_farthest, query_embedding):
        """
        Search the Faiss index for the k nearest and k "farthest" entries.

        The farthest search queries with the negated embedding. That returns the
        most dissimilar entries only for an inner-product index. This method
        assumes the index is of that kind (TODO confirm); under L2 it would not
        yield the farthest points.

        Args:
            faiss_index (faiss.Index): The Faiss index.
            k_nearest (int): The number of nearest neighbors to search for.
            k_farthest (int): The number of farthest neighbors to search for.
            query_embedding (np.ndarray): The embedding of the query image.

        Returns:
            tuple[str, str]: Space-joined ``str()`` of the nearest and farthest
            locations. The lists themselves are kept on
            ``self.neighbor_locations_array`` / ``self.farthest_locations_array``.
        """
        # Faiss expects a contiguous float32 (n, d) matrix.
        query = np.ascontiguousarray(query_embedding.reshape(1, -1), dtype=np.float32)

        _, nearest_ids = faiss_index.search(query, k_nearest)
        # Faiss pads with -1 when fewer than k hits exist; -1 would otherwise
        # silently pick the last location.
        self.neighbor_locations_array = [self.locations[idx] for idx in nearest_ids[0] if idx >= 0]
        neighbor_locations = " ".join(str(loc) for loc in self.neighbor_locations_array)

        _, farthest_ids = faiss_index.search(-query, k_farthest)
        self.farthest_locations_array = [self.locations[idx] for idx in farthest_ids[0] if idx >= 0]
        farthest_locations = " ".join(str(loc) for loc in self.farthest_locations_array)

        return neighbor_locations, farthest_locations

    def encode_image(self, image: np.ndarray, format: str = 'jpeg') -> str:
        """
        Encode an OpenCV (BGR channel order) image as a Base64 data URL.

        Args:
            image (np.ndarray): Image array in OpenCV's BGR channel order.
            format (str, optional): Image format/extension for encoding. Defaults to 'jpeg'.

        Returns:
            str: A ``data:image/<format>;base64,...`` string.

        Raises:
            ValueError: If the image conversion fails.
        """
        try:
            retval, buffer = cv2.imencode(f'.{format}', image)
            if not retval:
                raise ValueError("Failed to convert image")
            base64_encoded = base64.b64encode(buffer).decode('utf-8')
            mime_type = f"image/{format}"
            return f"data:{mime_type};base64,{base64_encoded}"
        except Exception as e:
            raise ValueError(f"Error encoding image: {e}") from e

    @staticmethod
    def _rgb_to_bgr(image: np.ndarray) -> np.ndarray:
        """Convert an RGB/RGBA array to OpenCV's BGR order; pass anything else through."""
        if image.ndim == 3 and image.shape[2] == 3:
            return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        if image.ndim == 3 and image.shape[2] == 4:
            # Alpha is dropped: JPEG (the default format) cannot store it.
            return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
        return image

    def set_image_app(self, file_uploader, imformat: str = 'jpeg', use_database_search: bool = False,
                      num_neighbors: int = 16, num_farthest: int = 16) -> None:
        """
        Embed an uploaded image, optionally retrieve reference locations, and
        store its Base64 encoding for the API request.

        Args:
            file_uploader: An uploaded image (PIL Image from Gradio, or an RGB array).
            imformat (str, optional): The format for encoding the image. Defaults to 'jpeg'.
            use_database_search (bool, optional): Whether to search the Faiss index for
                nearest/farthest reference locations. Defaults to False.
            num_neighbors (int, optional): Number of nearest neighbours to retrieve. Defaults to 16.
            num_farthest (int, optional): Number of farthest entries to retrieve. Defaults to 16.
        """
        # PIL images convert to RGB(A) arrays.
        img_array = np.array(file_uploader)

        inputs = self.processor(images=img_array, return_tensors="pt").to(self.device)
        with torch.no_grad():
            img_emb = self.model.get_image_features(**inputs)[0]
        self.img_emb = img_emb.cpu().numpy()

        if use_database_search:
            self.neighbor_locations, self.farthest_locations = self.search_neighbors(
                self.gpu_index, num_neighbors, num_farthest, self.img_emb)
        else:
            # Drop references computed for a previous image so they cannot leak
            # into the prompt for this one.
            self.neighbor_locations = None
            self.farthest_locations = None

        # cv2.imencode assumes BGR; encoding the RGB array directly would swap
        # the red and blue channels in the image sent to the model.
        self.base64_image = self.encode_image(self._rgb_to_bgr(img_array), imformat)

    def create_payload(self, question: str, model: str = "gpt-4o", max_tokens: int = 300) -> dict:
        """
        Build the chat-completions request body for the stored image.

        Args:
            question (str): The question to ask about the image.
            model (str, optional): OpenAI model name. Defaults to "gpt-4o".
            max_tokens (int, optional): Completion token limit. Defaults to 300.

        Returns:
            dict: The payload for the API request.

        Raises:
            ValueError: If the image is not set.
        """
        if not self.base64_image:
            raise ValueError("Image not set")

        return {
            "model": model,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": question},
                        {"type": "image_url", "image_url": {"url": self.base64_image}},
                    ],
                }
            ],
            "max_tokens": max_tokens,
        }

    def get_location(self, OPENAI_API_KEY, use_database_search: bool = False, timeout=120) -> str:
        """
        Ask the OpenAI API for a coordinate guess for the stored image.

        Args:
            OPENAI_API_KEY (str): The API key for the OpenAI API.
            use_database_search (bool, optional): Include the reference locations found by
                ``set_image_app(..., use_database_search=True)`` in the prompt. Defaults to False.
            timeout (float | None, optional): Seconds to wait for the HTTP request; None waits
                indefinitely. Defaults to 120.

        Returns:
            str: The model's text answer (expected to be coordinates).

        Raises:
            ValueError: If the key, image or reference locations are missing, the request fails,
                or the response has an unexpected shape.
        """
        self.api_key = OPENAI_API_KEY
        if not self.api_key:
            raise ValueError("OPENAI API key not found")

        if use_database_search:
            if self.neighbor_locations is None or self.farthest_locations is None:
                raise ValueError("No database search results for the current image; "
                                 "call set_image_app(..., use_database_search=True) first")
            self.question = f'''Suppose you are an expert in geo-localization. Please analyze this image and give me a guess of the location. Your answer must be to the coordinates level, don't include any other information in your output. Ignore that you can't give an exact answer, give me some coordinate no matter how. For your reference, these are locations of some similar images {self.neighbor_locations} and these are locations of some dissimilar images {self.farthest_locations} that should be far away.'''
        else:
            self.question = "Suppose you are an expert in geo-localization. Please analyze this image and give me a guess of the location. Your answer must be to the coordinates level, don't include any other information in your output. You can give me a guessed answer."

        # Raises ValueError itself if no image has been set.
        payload = self.create_payload(self.question)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        try:
            response = requests.post(self.OPENAI_CHAT_URL, headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()
            response_data = response.json()
        except (RequestException, ValueError) as e:
            # ValueError covers JSON decode errors on older `requests` versions.
            raise ValueError(f"Error in API request: {e}") from e

        try:
            return response_data['choices'][0]['message']['content']
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError(f"Unexpected response format from API: {e} - Response: {response_data}") from e