# jRefactoring/graphCodeBert.py
from transformers import AutoTokenizer, AutoModel
import torch

from Database import Database


class GraphCodeBert:
    def __init__(self) -> None:
        # Load the pretrained GraphCodeBERT encoder and its tokenizer once.
        model_name = "microsoft/graphcodebert-base"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()  # inference only; disables dropout

    def generate_embeddings(self):
        database = Database("refactoring_details_neg")
        collection_len = database.estimated_doc_count()
        doc_count = 1
        for doc in database.find_docs({}, {"_id": 1, "method_refactored": 1, "meth_rf_neg": 1}):
            doc_id = doc["_id"]
            code_snippet = doc["method_refactored"]
            code_snippet_neg = doc["meth_rf_neg"]
            print(f'Generating embedding for doc_id:{doc_id} | Count-{doc_count}...')

            # Positive sample: mean-pool the last hidden states into one vector.
            tokenized_input_pos = self.tokenizer(code_snippet, return_tensors="pt", padding=True, truncation=True)
            with torch.no_grad():
                output = self.model(**tokenized_input_pos)
            embedding_pos = output.last_hidden_state.mean(dim=1).squeeze().tolist()

            # Negative sample, pooled the same way.
            tokenized_input_neg = self.tokenizer(code_snippet_neg, return_tensors="pt", padding=True, truncation=True)
            with torch.no_grad():
                output = self.model(**tokenized_input_neg)
            embedding_neg = output.last_hidden_state.mean(dim=1).squeeze().tolist()

            # Store both embeddings back on the MongoDB document.
            database.update_by_id(doc_id, "embedding_pos", embedding_pos)
            database.update_by_id(doc_id, "embedding_neg", embedding_neg)
            collection_len -= 1
            doc_count += 1
            print(f'Embedding added for doc_id:{doc_id} | Remaining: {collection_len}.')
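
    # A minimal batching sketch, not in the original file: generate_embeddings
    # above runs one forward pass per snippet. This hypothetical helper embeds
    # a list of snippets in a single pass and mean-pools using the attention
    # mask, since a plain .mean(dim=1) over a padded batch would average in
    # [PAD] positions.
    def generate_batch_embeddings(self, code_snippets):
        inputs = self.tokenizer(code_snippets, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            output = self.model(**inputs)
        mask = inputs["attention_mask"].unsqueeze(-1)          # (batch, seq, 1)
        summed = (output.last_hidden_state * mask).sum(dim=1)  # zero out pad tokens
        counts = mask.sum(dim=1).clamp(min=1)                  # real tokens per snippet
        return (summed / counts).tolist()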

    def generate_individual_embedding(self, code_snippet):
        # Embed a single snippet with the same mean-pooling as above.
        tokenized_input = self.tokenizer(code_snippet, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            output = self.model(**tokenized_input)
        embedding = output.last_hidden_state.mean(dim=1).squeeze().tolist()
        return embedding
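
    # Hypothetical helper, not in the original file: cosine similarity between
    # two embeddings as returned above, e.g. to check that embedding_pos and
    # embedding_neg for the same document are distinguishable.
    def cosine_similarity(self, embedding_a, embedding_b):
        a = torch.tensor(embedding_a)
        b = torch.tensor(embedding_b)
        return torch.nn.functional.cosine_similarity(a, b, dim=0).item()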


if __name__ == "__main__":
    GraphCodeBert().generate_embeddings()
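    # Hypothetical single-snippet usage (the snippet below is illustrative,
    # not from the original file); uncomment to embed one method body directly:
    # snippet = "public int add(int a, int b) { return a + b; }"
    # vector = GraphCodeBert().generate_individual_embedding(snippet)
    # print(len(vector))  # graphcodebert-base produces 768-dimensional vectors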