Spaces:
Sleeping
Sleeping
import pandas as pd | |
import time | |
import random | |
from sentence_transformers import SentenceTransformer | |
from pymilvus import connections, DataType, FieldSchema, CollectionSchema, Collection, utility | |
import configparser | |
from tqdm import tqdm | |
def _truncate_utf8(text, max_bytes):
    """Return *text* cut to at most *max_bytes* UTF-8 encoded bytes.

    A multi-byte character split by the cut is dropped instead of being left
    as an invalid sequence. For pure-ASCII text this equals text[:max_bytes].
    """
    return text.encode("utf-8")[:max_bytes].decode("utf-8", errors="ignore")


# Initialize SentenceTransformer model for embeddings
embedding_model = SentenceTransformer(model_name_or_path="bert-base-uncased")

# Read molecule names and compound ids from CSV
csv_path = 'molecules-small.csv'
df = pd.read_csv(csv_path)

# Upper bound for a stored name; the VARCHAR field in the schema uses 256.
max_name_length = 256

# pandas parses empty cells (and literal names such as "NA" or "null") as NaN
# floats, on which len()/slicing would raise TypeError; store those as ''.
# Truncating on the UTF-8 byte length keeps every name within the limit
# whether the database counts characters or bytes.
molecules = [
    _truncate_utf8(str(name), max_name_length)
    for name in df['cmpdname'].fillna('')
]
cids = df['cid'].tolist()

# Encode all names in a single call so the model can batch them internally,
# then split the result back into one vector per molecule.
embeddings_list = list(
    embedding_model.encode(molecules, show_progress_bar=True)
)
# Load the Milvus endpoint and token from the [example] section of config.ini.
cfp = configparser.RawConfigParser()
# read() skips files it cannot open and returns the list it actually parsed,
# so an empty result means config.ini is missing; fail with a clear message
# instead of a later "No section: 'example'" error.
if not cfp.read('config.ini'):
    raise FileNotFoundError("config.ini not found or unreadable")
milvus_uri = cfp.get('example', 'uri')
token = cfp.get('example', 'token')

# Announce the attempt before connecting so the message is visible even when
# the connection hangs or fails.
print(f"Connecting to DB: {milvus_uri}")
connections.connect("default", uri=milvus_uri, token=token)
# Define the collection name and start from a clean slate: drop any existing
# collection with the same name before recreating it.
collection_name = 'molecule_embeddings'
if utility.has_collection(collection_name):
    utility.drop_collection(collection_name)
print("Success!")

# Take the vector dimensionality from the model instead of hard-coding it, so
# the schema stays correct if the embedding model is swapped
# (768 for bert-base-uncased).
dim = embedding_model.get_sentence_embedding_dimension()

# Define collection schema: primary key (cid), name, embedding vector.
molecule_cid = FieldSchema(
    name="molecule_cid",
    dtype=DataType.INT64,
    description="cid",
    is_primary=True,
)
molecule_name = FieldSchema(
    name="molecule_name",
    dtype=DataType.VARCHAR,
    # Same limit the names were truncated to when they were read.
    max_length=max_name_length,
    description="name",
)
molecule_embeddings = FieldSchema(
    name="molecule_embedding",
    dtype=DataType.FLOAT_VECTOR,
    dim=dim,
)
schema = CollectionSchema(
    fields=[molecule_cid, molecule_name, molecule_embeddings],
    auto_id=False,  # primary keys are supplied explicitly on insert
    description="my first collection!",
)

print(f"Creating example collection: {collection_name}")
collection = Collection(name=collection_name, schema=schema)
print(f"Schema: {schema}")
print("Success!")
# Insert the entities in fixed-size batches and accumulate the time spent
# inside the insert calls only.
batch_size = 1000
total_rt = 0.0
print(f"Inserting {len(embeddings_list)} entities... ")
for i in tqdm(range(0, len(embeddings_list), batch_size), desc="Inserting Embeddings"):
    batch = slice(i, i + batch_size)
    # Column order follows the schema field order: cid, name, embedding.
    entities = [cids[batch], molecules[batch], embeddings_list[batch]]
    # perf_counter is monotonic, so the measured interval is not affected by
    # system clock adjustments.
    t0 = time.perf_counter()
    collection.insert(entities)
    total_rt += time.perf_counter() - t0
print(f"Succeed in inserting {len(embeddings_list)} entities in {round(total_rt, 4)} seconds!")
# Flush the collection once all batches have been inserted.
print("Flushing collection...")
collection.flush()

# Build an AUTOINDEX index with the L2 metric on the embedding field, then
# load the collection.
print("Building index...")
collection.create_index(
    field_name='molecule_embedding',
    index_params={"index_type": "AUTOINDEX", "metric_type": "L2", "params": {}},
)
collection.load()
# Example search: query the collection with randomly generated vectors and
# print the nearest neighbours under the L2 metric.
num_queries = 1
top_k = 5
l2_search_params = {"metric_type": "L2"}

search_vec = []
for _ in range(num_queries):
    # One random query vector with the same dimensionality as the collection.
    search_vec.append([random.random() for _ in range(dim)])
print(f"Searching vector: {search_vec}")

results = collection.search(
    search_vec,
    anns_field='molecule_embedding',
    param=l2_search_params,
    limit=top_k,
)
print(f"Search results: {results}")

# Disconnect from Milvus server
connections.disconnect("default")
print("Disconnected from Milvus server.")