import torch
from transformers import AutoTokenizer, AutoModel

from Database import Database

class GraphCodeBert:
    """Compute GraphCodeBERT embeddings for code snippets.

    An embedding is the mean of the model's last hidden state over the
    non-padding tokens of the input, returned as a plain Python list.
    """

    def __init__(self, model_name: str = "microsoft/graphcodebert-base") -> None:
        """Load the tokenizer and model named *model_name* via ``from_pretrained``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def _embed(self, code_snippet):
        """Tokenize *code_snippet*, run the model, and mean-pool the last hidden state.

        Args:
            code_snippet: text accepted by the tokenizer (a single string, or a
                list of strings for a batch).

        Returns:
            The pooled embedding converted with ``tolist()``; size-1 dimensions
            are squeezed away, so a single snippet yields a flat list of floats.
        """
        encoded = self.tokenizer(code_snippet, return_tensors="pt", padding=True, truncation=True)
        # Inference only: avoid building an autograd graph that is never used.
        with torch.no_grad():
            output = self.model(**encoded)
        hidden = output.last_hidden_state
        mask = encoded.get("attention_mask")
        if mask is None:
            pooled = hidden.mean(dim=1)
        else:
            # Average only over real tokens so padding positions (present when a
            # batch of different-length snippets is passed) do not dilute the result.
            mask = mask.unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return pooled.squeeze().tolist()

    def generate_embeddings(self, db_name: str = "refactoring_details_neg"):
        """Embed the positive and negative snippet of every document and store them.

        For each document in *db_name*, ``method_refactored`` is embedded into
        ``embedding_pos`` and ``meth_rf_neg`` into ``embedding_neg``. Both are
        written back with ``update_by_id``. Progress is printed per document.

        Raises:
            KeyError: if a document lacks ``method_refactored`` or ``meth_rf_neg``.
        """
        database = Database(db_name)
        # Estimated count, used only for the "Remaining" progress message.
        remaining = database.estimated_doc_count()
        projection = {"_id": 1, "method_refactored": 1, "meth_rf_neg": 1}

        for doc_count, doc in enumerate(database.find_docs({}, projection), start=1):
            doc_id = doc["_id"]
            code_snippet = doc["method_refactored"]
            code_snippet_neg = doc["meth_rf_neg"]
            print(f'Generating embedding for doc_id:{doc_id} | Count-{doc_count}...')

            embedding_pos = self._embed(code_snippet)
            embedding_neg = self._embed(code_snippet_neg)

            database.update_by_id(doc_id, "embedding_pos", embedding_pos)
            database.update_by_id(doc_id, "embedding_neg", embedding_neg)

            remaining -= 1
            print(f'Embedding added for doc_id:{doc_id} | Remaining: {remaining}.')

    def generate_individual_embedding(self, code_snippet):
        """Return the pooled embedding of a single *code_snippet* as a list of floats."""
        return self._embed(code_snippet)



if __name__ == "__main__":
    # Script entry point: load the model, then embed every stored document.
    embedder = GraphCodeBert()
    embedder.generate_embeddings()