|
from models import BaseModel |
|
from .base_refiner import BaseRefiner |
|
from utils.image_encoder import encode_image |
|
import asyncio |
|
|
|
class SimpleRefiner(BaseRefiner):
    """Refiner that sends a single user turn (text plus optional images) to a model.

    The context handed to the model is, in order: the system prompt, the prior
    conversation taken from ``memory`` (without its first stored message), and a
    new user message built from ``message`` and any ``image_paths``.
    """

    def __init__(self, sys_prompt: str, model: BaseModel) -> None:
        """Create a refiner.

        Args:
            sys_prompt: Text placed in the system message at the start of
                every context this refiner builds.
            model: Model whose ``respond`` / ``respond_async`` methods are
                called with the built context.
        """
        # NOTE(review): the methods below read self.sys_prompt and self.model;
        # presumably BaseRefiner.__init__ stores them under those names — confirm.
        super().__init__(sys_prompt=sys_prompt, model=model)

    def _build_context(self, message: str, memory, image_paths=None) -> list:
        """Assemble the list of chat messages sent to the model.

        Args:
            message: Text of the new user turn (converted with an f-string).
            memory: ``None`` for no history, otherwise an object exposing a
                ``messages`` sequence of prior chat messages.
            image_paths: Optional single image or list/tuple of images. Each
                item is either a path string or an object whose ``.name``
                attribute holds the path (e.g. an uploaded/opened file).

        Returns:
            A new list: system message, prior history, then the user message
            whose ``content`` holds a text part followed by one ``image_url``
            part per image.
        """
        if memory is None:
            history = []
        else:
            # The first stored message is skipped — presumably it is the system
            # prompt, which is re-added from self.sys_prompt below. TODO confirm
            # against the memory class.
            history = list(memory.messages[1:])

        content = [{"type": "text", "text": f"{message}"}]

        if image_paths:
            if not isinstance(image_paths, (list, tuple)):
                image_paths = [image_paths]
            for image_path in image_paths:
                # Plain strings are used as paths directly; anything else keeps
                # the original behaviour of reading the path from `.name`.
                path = image_path if isinstance(image_path, str) else image_path.name
                # NOTE(review): the MIME type is hard-coded to JPEG regardless of
                # the actual file format — confirm encode_image/model accept that.
                content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{encode_image(path)}",
                    },
                })

        # Images must live inside the user message's content list; appending
        # them to the outer message list (as before) dropped them entirely.
        return (
            [{"role": "system", "content": self.sys_prompt}]
            + history
            + [{"role": "user", "content": content}]
        )

    async def refine_async(self, message: str, memory, image_paths=None) -> str:
        """Asynchronously ask the model to respond to ``message``.

        Args:
            message: Text of the new user turn.
            memory: ``None`` or an object with a ``messages`` sequence holding
                prior conversation (its first entry is not sent).
            image_paths: Optional image path(s) / file object(s) to attach.

        Returns:
            Whatever ``self.model.respond_async(context)`` resolves to.
        """
        context = self._build_context(message, memory, image_paths)
        return await self.model.respond_async(context)

    def refine(self, message: str, memory, image_paths=None) -> str:
        """Synchronously ask the model to respond to ``message``.

        Args:
            message: Text of the new user turn.
            memory: ``None`` or an object with a ``messages`` sequence holding
                prior conversation (its first entry is not sent).
            image_paths: Optional image path(s) / file object(s) to attach.

        Returns:
            Whatever ``self.model.respond(context)`` returns.
        """
        context = self._build_context(message, memory, image_paths)
        return self.model.respond(context)