socm / embeddings.py
spencer's picture
add normal files
6df828c
raw
history blame
4.54 kB
import logging
import os
import faiss
import torch
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Configure the root logger at import time with level INFO.
# NOTE(review): basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
class FaissIndex:
    """A FAISS index paired with a parallel list of string ids.

    The index is optionally persisted at ``faiss_index_location``; the ids are
    persisted next to it in ``<faiss_index_location>.ids``, one id per line, in
    insertion order (position ``i`` in the file labels vector ``i`` in the index).

    Args:
        embedding_size: Dimensionality of the vectors. May be omitted when
            ``faiss_index_location`` is given; it is then taken from the saved
            index, or from the first batch passed to :meth:`add`.
        faiss_index_location: Path of the index file, or ``None`` for a purely
            in-memory index.
        indexer: Callable that builds an empty index from the embedding size
            (``faiss.IndexFlatIP`` by default).

    Raises:
        ValueError: If neither ``embedding_size`` nor ``faiss_index_location``
            is provided, or if a non-empty saved index has no ids file.
    """

    def __init__(
        self,
        embedding_size=None,
        faiss_index_location=None,
        indexer=faiss.IndexFlatIP,
    ):
        if not (embedding_size or faiss_index_location):
            raise ValueError("Must provide embedding_size")

        self.embedding_size = embedding_size
        self.faiss_index_location = faiss_index_location
        # Always keep the factory so a fresh index can be built on any code path.
        self.indexer = indexer
        self.index = None
        self.id_list = []

        if faiss_index_location and os.path.exists(faiss_index_location):
            self.index = faiss.read_index(faiss_index_location)
            logger.info(f"Setting embedding size ({self.index.d}) to match saved index")
            self.embedding_size = self.index.d
            ids_path = faiss_index_location + ".ids"
            if os.path.exists(ids_path):
                self.id_list = self._read_ids(ids_path)
            elif self.index.ntotal > 0:
                raise ValueError("Index file exists but ids file does not")
        elif faiss_index_location:
            # dirname is "" for a bare file name; os.makedirs("") would raise.
            parent_dir = os.path.dirname(faiss_index_location)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

    @staticmethod
    def _read_ids(ids_path):
        """Read one id per line, ignoring only the final line terminator."""
        with open(ids_path) as f:
            raw = f.read()
        if not raw:
            return []
        # The file is written with a trailing newline; a plain split("\n")
        # would append a phantom "" id and shift every id added afterwards.
        if raw.endswith("\n"):
            raw = raw[:-1]
        return raw.split("\n")

    def faiss_init(self):
        """Create an empty index with ``self.indexer`` and save it if a location is set."""
        index = self.indexer(self.embedding_size)
        if self.faiss_index_location:
            faiss.write_index(index, self.faiss_index_location)
        self.index = index

    def add(self, inputs, ids, normalize=True):
        """Append vectors and their ids, then persist both if a location is set.

        Args:
            inputs: Vectors to add. Expected to be a 2-D float32 NumPy array of
                shape ``(n, embedding_size)`` as FAISS requires -- TODO confirm
                against callers. When ``normalize`` is true the array is
                L2-normalised in place by ``faiss.normalize_L2``.
            ids: Iterable of ``n`` string ids, one per row of ``inputs``.
            normalize: Whether to L2-normalise ``inputs`` before adding.

        Raises:
            ValueError: If the number of ids differs from the number of vectors.
        """
        ids = list(ids)  # materialise once; a generator would be exhausted by extend()
        if len(ids) != len(inputs):
            raise ValueError(
                f"Got {len(inputs)} vectors but {len(ids)} ids; they must match one-to-one"
            )
        if not ids:
            return

        if self.index is None:
            if self.embedding_size is None:
                # Only a location was given and nothing was saved there yet.
                self.embedding_size = inputs.shape[1]
            self.faiss_init()

        if normalize:
            faiss.normalize_L2(inputs)
        self.index.add(inputs)
        self.id_list.extend(ids)

        if self.faiss_index_location:
            faiss.write_index(self.index, self.faiss_index_location)
            # Rewrite the whole ids file so it always mirrors the in-memory
            # list, even if a stale ids file was left behind on disk.
            with open(self.faiss_index_location + ".ids", "w") as f:
                f.write("\n".join(self.id_list) + "\n")

    def search(self, embedding, k=10, normalize=True):
        """Return the ``k`` nearest stored vectors for a single query vector.

        Args:
            embedding: Query vector; it is reshaped to ``(1, -1)``. When
                ``normalize`` is true it is L2-normalised in place.
            k: Number of neighbours to request.
            normalize: Whether to L2-normalise the query first.

        Returns:
            ``(D, I, labels)`` where ``D`` and ``I`` are the score and index
            arrays returned by FAISS and ``labels`` holds the ids of the hits.
            ``labels`` may be shorter than ``k``: FAISS pads missing results
            with index ``-1``, and those entries are skipped.

        Raises:
            RuntimeError: If no index has been created or loaded yet.
        """
        if self.index is None:
            raise RuntimeError("Cannot search: the index has not been created or loaded")
        embedding = embedding.reshape(1, -1)
        if normalize:
            faiss.normalize_L2(embedding)
        D, I = self.index.search(embedding, k)
        # I has shape (1, k); indexing row 0 (rather than squeeze()) keeps it
        # iterable when k == 1. Skip the -1 padding so it cannot alias the
        # last element of id_list.
        labels = [self.id_list[i] for i in I[0] if i >= 0]
        return D, I, labels

    def reset(self):
        """Empty the index and the id list, and delete the persisted files if any."""
        if self.index is not None:
            self.index.reset()
        self.id_list = []
        if not self.faiss_index_location:
            return
        # Remove each file independently so a missing index file does not
        # leave a stale ids file behind.
        for path in (self.faiss_index_location, self.faiss_index_location + ".ids"):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    def __len__(self):
        """Return the number of vectors in the index (0 if none exists yet)."""
        if self.index is not None:
            return self.index.ntotal
        return 0
class VectorSearch:
    """Looks up place and object labels for an embedding and builds text prompts.

    Two indices are loaded on construction, ``places`` and ``objects``, from
    ``faiss_indices/<name>.index``.
    """

    def __init__(self):
        self.places = self.load("places")
        self.objects = self.load("objects")

    def load(self, index_name):
        """Open the FaissIndex stored at ``faiss_indices/<index_name>.index``."""
        return FaissIndex(
            faiss_index_location=f"faiss_indices/{index_name}.index",
        )

    @staticmethod
    def _as_numpy(query_vec):
        """Convert a torch tensor to a NumPy array; pass anything else through."""
        if isinstance(query_vec, torch.Tensor):
            # .cpu() is a no-op for CPU tensors and is required before
            # .numpy() for tensors living on another device.
            return query_vec.detach().cpu().numpy()
        return query_vec

    def _top_labels(self, index, query_vec, k):
        """Return only the label list from ``index.search`` for ``query_vec``."""
        *_, labels = index.search(self._as_numpy(query_vec), k=k)
        return labels

    def top_places(self, query_vec, k=5):
        """Return the labels of the ``k`` best matches in the places index."""
        return self._top_labels(self.places, query_vec, k)

    def top_objects(self, query_vec, k=5):
        """Return the labels of the ``k`` best matches in the objects index."""
        return self._top_labels(self.objects, query_vec, k)

    def prompt_activities(self, query_vec, k=5, one_shot=False):
        """Build an activity-guessing prompt from the top places and objects.

        Args:
            query_vec: Query embedding (torch tensor or array accepted by the index).
            k: Number of places and of objects to include.
            one_shot: Selects the return shape, see below.

        Returns:
            If ``one_shot`` is true, ``(zero_shot_prompt, one_shot_prompt)``;
            otherwise ``(zero_shot_prompt, places, objects)``.
        """
        places = self.top_places(query_vec, k=k)
        objects = self.top_objects(query_vec, k=k)
        place_str = f"Places: {', '.join(places)}. "
        object_str = f"Objects: {', '.join(objects)}. "
        act_str = "I might be doing these 3 activities: "
        zs = place_str + object_str + act_str
        if one_shot:
            # act_str already ends with ": ", so the example answer follows it
            # directly (an extra ": " here produced "activities: : eating").
            example = (
                "Places: kitchen. Objects: coffee maker. "
                f"{act_str}eating, making breakfast, grinding coffee.\n "
            )
            fs = example + zs
            return (zs, fs)
        return zs, places, objects

    def prompt_summary(self, state_history: list, k=5):
        """Build a summarisation prompt from a list of state records.

        Args:
            state_history: Records exposing ``places``, ``objects`` and
                ``activities`` as iterables of strings.
            k: Unused; kept so existing callers keep working.

        Returns:
            An "Event log:" header, one line per record, then the question.
        """
        rec_strings = ["Event log:"]
        rec_strings.extend(
            f"Places: {', '.join(rec.places)}. "
            f"Objects: {', '.join(rec.objects)}. "
            f"Activities: {', '.join(rec.activities)} "
            for rec in state_history
        )
        question = "How would you summarize these events in a few full sentences? "
        return "\n".join(rec_strings) + "\n" + question