Spaces:
Sleeping
Sleeping
import cv2 | |
import base64 | |
import requests | |
from tqdm import tqdm | |
from requests.exceptions import RequestException | |
from PIL import Image | |
from transformers import CLIPModel, CLIPProcessor | |
import torch | |
import faiss | |
import pickle | |
import numpy as np | |
import pandas as pd | |
from geopy.distance import geodesic | |
from transformers import AutoTokenizer, BitsAndBytesConfig | |
import torch | |
from PIL import Image | |
import requests | |
from io import BytesIO | |
import os | |
# Expose only GPU 0 to CUDA-using libraries in this process. The variable is
# only honoured if CUDA has not been initialised yet when it is set.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
class GPT4o:
    """Ask OpenAI's GPT-4o for a coordinate guess of where an image was taken.

    Typical use: call ``set_image_app`` with an image. It computes a StreetCLIP
    embedding, optionally looks up similar and dissimilar reference locations in
    a Faiss index, and Base64-encodes the image. Then call ``get_location`` to
    query the chat-completions API.
    """

    def __init__(
        self,
        device="cpu",
        embeddings_path: str = "StreetCLIP_1m_merged.pkl",
        index_path: str = "StreetCLIP_1m_merged.bin",
        model_name: str = "geolocal/StreetCLIP",
    ) -> None:
        """Load the CLIP model, the reference locations and the Faiss index.

        Args:
            device: Torch device spec the CLIP model runs on (e.g. "cpu", "cuda:0").
            embeddings_path: Pickle file holding a dict whose values carry a
                'location' entry.
            index_path: Faiss index file searched by ``search_neighbors``.
            model_name: Hugging Face id of the CLIP checkpoint used for embeddings.
        """
        self.base64_image = None
        self.img_emb = None
        # Filled by set_image_app(..., use_database_search=True); read by get_location.
        self.neighbor_locations = None
        self.farthest_locations = None
        self.device = torch.device(device)
        self.model = CLIPModel.from_pretrained(model_name).eval()
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model.to(self.device)
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(embeddings_path, "rb") as f:
            self.Embeddings = pickle.load(f)
        # NOTE(review): Faiss result ids index into this list, so this assumes the
        # dict's insertion order matches the order vectors were added to the index.
        self.locations = [value["location"] for value in self.Embeddings.values()]
        # Historical name: faiss.read_index yields a CPU index; nothing moves it to a GPU.
        self.gpu_index = faiss.read_index(index_path)

    def read_image(self, image_path):
        """Read an image file into an RGB numpy array.

        Args:
            image_path (str): The path to the image file.

        Returns:
            np.ndarray: The image in RGB channel order.

        Raises:
            ValueError: If OpenCV cannot read or decode the file.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise ValueError(f"Could not read image file: {image_path!r}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def search_neighbors(self, faiss_index, k_nearest, k_farthest, query_embedding):
        """Look up the nearest and "farthest" database entries for a query embedding.

        Args:
            faiss_index (faiss.Index): The Faiss index to search.
            k_nearest (int): Number of nearest neighbours to retrieve.
            k_farthest (int): Number of farthest neighbours to retrieve.
            query_embedding (np.ndarray): Embedding of the query image.

        Returns:
            tuple[str, str]: Space-separated string forms of the nearest and the
            farthest locations. The raw lists are also stored on
            ``self.neighbor_locations_array`` and ``self.farthest_locations_array``.
        """
        # Faiss expects a C-contiguous float32 matrix of shape (n_queries, dim).
        query = np.ascontiguousarray(query_embedding, dtype=np.float32).reshape(1, -1)
        _, nearest_ids = faiss_index.search(query, k_nearest)
        self.neighbor_locations_array = self._ids_to_locations(nearest_ids)
        # The negated query ranks entries by lowest inner product with the original.
        # NOTE(review): that equals "farthest" only for an inner-product index; for
        # an L2 index it returns the neighbours of -query instead.
        _, farthest_ids = faiss_index.search(-query, k_farthest)
        self.farthest_locations_array = self._ids_to_locations(farthest_ids)
        neighbor_locations = " ".join(str(loc) for loc in self.neighbor_locations_array)
        farthest_locations = " ".join(str(loc) for loc in self.farthest_locations_array)
        return neighbor_locations, farthest_locations

    def _ids_to_locations(self, ids):
        """Map the first row of Faiss result ids to stored locations, dropping -1 padding."""
        # Faiss pads with -1 when fewer than k hits exist. Left unfiltered,
        # self.locations[-1] would silently return the last database entry.
        return [self.locations[idx] for idx in ids[0] if idx >= 0]

    def encode_image(self, image: np.ndarray, format: str = 'jpeg') -> str:
        """Encode an OpenCV (BGR) image as a Base64 data URL.

        Args:
            image (np.ndarray): Image in OpenCV channel order (BGR / BGRA / gray).
            format (str, optional): Encoding format / file extension. Defaults to 'jpeg'.

        Returns:
            str: ``data:image/<format>;base64,<payload>``.

        Raises:
            ValueError: If the image cannot be encoded.
        """
        try:
            ok, buffer = cv2.imencode(f".{format}", image)
            if not ok:
                raise ValueError("Failed to convert image")
            base64_encoded = base64.b64encode(buffer).decode("utf-8")
        except Exception as e:
            raise ValueError(f"Error encoding image: {e}") from e
        return f"data:image/{format};base64,{base64_encoded}"

    @staticmethod
    def _rgb_to_bgr(img_array):
        """Reorder an RGB(A) array into OpenCV's BGR(A) order; other shapes pass through."""
        if img_array.ndim == 3 and img_array.shape[2] in (3, 4):
            order = [2, 1, 0] + list(range(3, img_array.shape[2]))
            # Fancy indexing copies; ascontiguousarray guarantees a layout OpenCV accepts.
            return np.ascontiguousarray(img_array[..., order])
        return img_array

    def set_image_app(self, file_uploader, imformat: str = 'jpeg', use_database_search: bool = False,
                      num_neighbors: int = 16, num_farthest: int = 16) -> None:
        """Embed an uploaded image, optionally retrieve reference locations, and encode it.

        Args:
            file_uploader: The uploaded image (a PIL Image from Gradio, or an
                RGB numpy array).
            imformat (str, optional): Format used to encode the image. Defaults to 'jpeg'.
            use_database_search (bool, optional): Whether to look up similar and
                dissimilar reference locations in the Faiss index. Defaults to False.
            num_neighbors (int, optional): Number of nearest neighbours to retrieve.
            num_farthest (int, optional): Number of farthest neighbours to retrieve.
        """
        # Drop reference locations from a previous image so they cannot be reused by mistake.
        self.neighbor_locations = None
        self.farthest_locations = None
        # Normalise PIL uploads (RGBA, palette, grayscale, ...) to 3-channel RGB.
        if isinstance(file_uploader, Image.Image):
            file_uploader = file_uploader.convert("RGB")
        img_array = np.array(file_uploader)
        inputs = self.processor(images=img_array, return_tensors="pt").to(self.device)
        with torch.no_grad():
            img_emb = self.model.get_image_features(**inputs)[0]
        self.img_emb = img_emb.cpu().numpy()
        if use_database_search:
            self.neighbor_locations, self.farthest_locations = self.search_neighbors(
                self.gpu_index, num_neighbors, num_farthest, self.img_emb)
        # cv2.imencode treats 3-channel data as BGR while this array is RGB; without
        # the reorder, red and blue are swapped in the image sent to the API.
        self.base64_image = self.encode_image(self._rgb_to_bgr(img_array), imformat)

    def create_payload(self, question: str, model: str = "gpt-4o", max_tokens: int = 300) -> dict:
        """Build the chat-completions request body for the current image.

        Args:
            question (str): Prompt text sent alongside the image.
            model (str, optional): OpenAI model id. Defaults to "gpt-4o".
            max_tokens (int, optional): Completion token limit. Defaults to 300.

        Returns:
            dict: The JSON payload for the API request.

        Raises:
            ValueError: If no image has been set.
        """
        if not self.base64_image:
            raise ValueError("Image not set")
        return {
            "model": model,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": question},
                        {"type": "image_url", "image_url": {"url": self.base64_image}},
                    ],
                }
            ],
            "max_tokens": max_tokens,
        }

    def _build_question(self, use_database_search: bool) -> str:
        """Return the prompt, optionally embedding the retrieved reference locations.

        Raises:
            ValueError: If database search is requested but no reference
                locations were computed for the current image.
        """
        if not use_database_search:
            return ("Suppose you are an expert in geo-localization. Please analyze this image and give me a "
                    "guess of the location. Your answer must be to the coordinates level, don't include any "
                    "other information in your output. You can give me a guessed answer.")
        neighbors = getattr(self, "neighbor_locations", None)
        farthest = getattr(self, "farthest_locations", None)
        if neighbors is None or farthest is None:
            raise ValueError("Database search requested but no reference locations are available; "
                             "call set_image_app(..., use_database_search=True) first")
        return (
            "Suppose you are an expert in geo-localization. Please analyze this image and give me a guess of the location.\n"
            "Your answer must be to the coordinates level, don't include any other information in your output.\n"
            "Ignore that you can't give an exact answer, give me some coordinate no matter how.\n"
            f"For your reference, these are locations of some similar images {neighbors} and these are "
            f"locations of some dissimilar images {farthest} that should be far away."
        )

    def get_location(self, OPENAI_API_KEY, use_database_search: bool = False,
                     timeout: float = 120.0) -> str:
        """Ask the OpenAI chat-completions API for a coordinate guess for the current image.

        Args:
            OPENAI_API_KEY (str): The OpenAI API key.
            use_database_search (bool, optional): Include the reference locations
                found by ``set_image_app`` in the prompt. Defaults to False.
            timeout (float, optional): Seconds to wait for the HTTP request
                before giving up. Defaults to 120.

        Returns:
            str: The model's reply text.

        Raises:
            ValueError: On a missing key, missing image or reference locations,
                a failed request, or an unexpected response shape.
        """
        self.api_key = OPENAI_API_KEY
        # Validate locally before the request. Previously these errors were masked by
        # an UnboundLocalError because the handler referenced response_data too early.
        if not self.api_key:
            raise ValueError("OPENAI API key not found")
        self.question = self._build_question(use_database_search)
        payload = self.create_payload(self.question)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        try:
            response = requests.post("https://api.openai.com/v1/chat/completions",
                                     headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()
            response_data = response.json()
        except (RequestException, ValueError) as e:  # ValueError: undecodable JSON body
            raise ValueError(f"Error in API request: {e}") from e
        choices = response_data.get("choices") if isinstance(response_data, dict) else None
        if not choices:
            raise ValueError(f"Value error: Unexpected response format from API - Response: {response_data}")
        try:
            return choices[0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError(f"Key error in response: {e} - Response: {response_data}") from e