File size: 1,095 Bytes
58974f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from .base_model import BaseModel
import openai
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
class BiomedModel(BaseModel):
    """Chat-completion text generation plus sentence-transformer embeddings.

    Generation goes through ``openai.ChatCompletion.create``; embeddings are
    produced by a ``SentenceTransformer`` instance loaded at construction.
    """

    def __init__(
        self,
        generation_model="gpt-4",
        embedding_model="pritamdeka/S-PubMedBert-MS-MARCO",
        temperature=0,
    ) -> None:
        """Store the generation settings and load the embedding model.

        Args:
            generation_model: Model name passed as ``model`` to the chat
                completion call.
            embedding_model: Name or path handed to ``SentenceTransformer``;
                the model is instantiated here, not lazily.
            temperature: Sampling temperature passed to the chat completion
                call.
        """
        # NOTE(review): BaseModel.__init__ is not invoked here -- confirm
        # that the base class needs no initialisation.
        self.generation_model = generation_model
        self.embedding_model = SentenceTransformer(embedding_model)
        self.temperature = temperature

    def respond(self, messages: list) -> str:
        """Return the content of the first chat-completion choice.

        Args:
            messages: Chat messages forwarded unchanged as the ``messages``
                argument of ``openai.ChatCompletion.create``. That API takes
                a list of role/content dicts, not a single string.

        Returns:
            The ``content`` field of the first returned choice's message.
        """
        completion = openai.ChatCompletion.create(
            messages=messages,
            model=self.generation_model,
            temperature=self.temperature,
        )
        return completion.choices[0]['message']['content']

    def embedding(self, texts: list) -> list:
        """Encode ``texts`` with the embedding model and return plain lists.

        The return shape depends on the input length, and callers must
        account for it:

        * exactly one text: that text alone is encoded and its ``.tolist()``
          result is returned directly (a single vector, not wrapped in an
          outer list);
        * any other length: all texts are encoded together (with a progress
          bar) and a list holding one ``.tolist()`` result per encoded item
          is returned.

        Args:
            texts: The texts to embed.

        Returns:
            A single embedding for one input text, otherwise a list of
            embeddings.
        """
        if len(texts) == 1:
            # Single-text fast path: note the un-nested return value.
            return self.embedding_model.encode(texts[0]).tolist()

        vectors = self.embedding_model.encode(texts, show_progress_bar=True)
        return [vector.tolist() for vector in vectors]