from .base_model import BaseModel

import openai
from tqdm import tqdm
from sentence_transformers import SentenceTransformer


class BiomedModel(BaseModel):
    """Biomedical assistant backed by an OpenAI chat model for generation and a
    PubMedBERT-based SentenceTransformer for embeddings."""

    def __init__(self,
                 generation_model="gpt-4",
                 embedding_model="pritamdeka/S-PubMedBert-MS-MARCO",
                 temperature=0,
                 ) -> None:
        self.generation_model = generation_model
        self.embedding_model = SentenceTransformer(embedding_model)
        self.temperature = temperature

    def respond(self, messages: list) -> str:
        """Send a chat message history to the OpenAI API and return the reply text."""
        response = openai.ChatCompletion.create(
            messages=messages,
            model=self.generation_model,
            temperature=self.temperature,
        ).choices[0]["message"]["content"]

        return response

    def embedding(self, texts: list) -> list:
        """Embed texts: a single-item list returns one vector, a batch returns a list of vectors."""
        if len(texts) == 1:
            return self.embedding_model.encode(texts[0]).tolist()
        else:
            data = self.embedding_model.encode(texts, show_progress_bar=True)
            data = [d.tolist() for d in data]
            return data
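

# Minimal usage sketch (illustrative only): assumes an OpenAI API key is available via the
# OPENAI_API_KEY environment variable and that BaseModel imposes no extra constructor requirements.
if __name__ == "__main__":
    model = BiomedModel()

    # Chat-style generation: messages follow the OpenAI chat format.
    reply = model.respond([{"role": "user", "content": "What is the BRCA1 gene?"}])
    print(reply)

    # Sentence embedding: a single-item list returns one flat vector of floats.
    vector = model.embedding(["BRCA1 is a tumor suppressor gene."])
    print(len(vector))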