import logging
import os

import faiss
import torch

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class FaissIndex:
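    """Wraps a FAISS index together with a parallel list of string ids,
    persisting both to disk so the index survives restarts."""
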
    def __init__(
        self,
        embedding_size=None,
        faiss_index_location=None,
        indexer=faiss.IndexFlatIP,
    ):
        # The embedding size may come from the caller, or be read back from a
        # previously saved index below.
        if embedding_size or faiss_index_location:
            self.embedding_size = embedding_size
        else:
            raise ValueError("Must provide embedding_size or faiss_index_location")
        self.faiss_index_location = faiss_index_location
        if faiss_index_location and os.path.exists(faiss_index_location):
            self.index = faiss.read_index(faiss_index_location)
            logger.info(f"Setting embedding size ({self.index.d}) to match saved index")
            self.embedding_size = self.index.d
            if os.path.exists(faiss_index_location + ".ids"):
                with open(faiss_index_location + ".ids") as f:
                    # splitlines() avoids the trailing empty id that
                    # split("\n") would yield from the file's final newline.
                    self.id_list = f.read().splitlines()
            elif self.index.ntotal > 0:
                raise ValueError("Index file exists but ids file does not")
            else:
                self.id_list = []
        else:
            if faiss_index_location:
                # Defer index creation to the first add(); just make sure the
                # target directory exists.
                os.makedirs(os.path.dirname(faiss_index_location), exist_ok=True)
            self.index = None
            self.indexer = indexer
            self.id_list = []

    def faiss_init(self):
        index = self.indexer(self.embedding_size)
        if self.faiss_index_location:
            faiss.write_index(index, self.faiss_index_location)
        self.index = index

    def add(self, inputs, ids, normalize=True):
        if not self.index:
            self.faiss_init()
        if normalize:
            # normalize_L2 modifies `inputs` in place; with IndexFlatIP this
            # makes inner-product search equivalent to cosine similarity.
            faiss.normalize_L2(inputs)
        self.index.add(inputs)
        self.id_list.extend(ids)
        if self.faiss_index_location:
            faiss.write_index(self.index, self.faiss_index_location)
            with open(self.faiss_index_location + ".ids", "a") as f:
                f.write("\n".join(ids) + "\n")

    def search(self, embedding, k=10, normalize=True):
        # FAISS expects a 2-D batch of queries; promote a single vector.
        if embedding.ndim == 1:
            embedding = embedding.reshape(1, -1)
        if normalize:
            faiss.normalize_L2(embedding)
        D, I = self.index.search(embedding, k)
        labels = [self.id_list[i] for i in I.squeeze()]
        return D, I, labels

    def reset(self):
        if self.index:
            self.index.reset()
        self.id_list = []
        if self.faiss_index_location:
            try:
                os.remove(self.faiss_index_location)
                os.remove(self.faiss_index_location + ".ids")
            except FileNotFoundError:
                pass

    def __len__(self):
        if self.index:
            return self.index.ntotal
        return 0


class VectorSearch:
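    """Convenience front end holding separate FAISS indices for places and
    objects, plus prompt builders for a downstream language model."""
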
    def __init__(self):
        self.places = self.load("places")
        self.objects = self.load("objects")

    def load(self, index_name):
        return FaissIndex(
            faiss_index_location=f"faiss_indices/{index_name}.index",
        )

    def top_places(self, query_vec, k=5):
        if isinstance(query_vec, torch.Tensor):
            # .cpu() keeps this working for tensors living on a GPU device.
            query_vec = query_vec.detach().cpu().numpy()
        *_, results = self.places.search(query_vec, k=k)
        return results

    def top_objects(self, query_vec, k=5):
        if isinstance(query_vec, torch.Tensor):
            query_vec = query_vec.detach().cpu().numpy()
        *_, results = self.objects.search(query_vec, k=k)
        return results

    def prompt_activities(self, query_vec, k=5, one_shot=False):
        places = self.top_places(query_vec, k=k)
        objects = self.top_objects(query_vec, k=k)
        place_str = f"Places: {', '.join(places)}. "
        object_str = f"Objects: {', '.join(objects)}. "
        act_str = "I might be doing these 3 activities: "
        zs = place_str + object_str + act_str
        # act_str already ends with ": ", so no extra colon is needed here.
        example = (
            "Places: kitchen. Objects: coffee maker. "
            f"{act_str}eating, making breakfast, grinding coffee.\n"
        )
        fs = example + place_str + object_str + act_str
        if one_shot:
            return (zs, fs)
        return zs, places, objects

    def prompt_summary(self, state_history: list, k=5):
        rec_strings = ["Event log:"]
        for rec in state_history:
            rec_strings.append(
                f"Places: {', '.join(rec.places)}. "
                f"Objects: {', '.join(rec.objects)}. "
                f"Activities: {', '.join(rec.activities)} "
            )
        question = "How would you summarize these events in a few full sentences? "
        return "\n".join(rec_strings) + "\n" + question
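

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original API: the 512-dim size,
    # the "demo" path, and the ids are assumptions for illustration only.
    # It round-trips an add/search to show the expected shapes.
    import numpy as np

    demo = FaissIndex(
        embedding_size=512,
        faiss_index_location="faiss_indices/demo.index",
    )
    vectors = np.random.rand(3, 512).astype(np.float32)  # FAISS wants float32
    demo.add(vectors, ids=["clip-a", "clip-b", "clip-c"])
    D, I, labels = demo.search(vectors[0], k=2)
    print(D, I, labels)  # the nearest neighbor should be "clip-a" itself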