from models import BaseModel
from .base_refiner import BaseRefiner
from utils.image_encoder import encode_image


class SimpleRefiner(BaseRefiner):
    """Refiner that sends a single user message (optionally with images) to the model."""

    def __init__(self,
                 sys_prompt: str,
                 model: BaseModel,
                 ) -> None:
        super().__init__(sys_prompt=sys_prompt, model=model)

    async def refine_async(self, message: str, memory, image_paths=None) -> str:
        # Reuse prior conversation turns if provided, skipping the stored system message.
        if memory is None:
            memory = []
        else:
            memory = memory.messages[1:]
        user_content = [{"type": "text", "text": message}]
        if image_paths:
            if not isinstance(image_paths, list):
                image_paths = [image_paths]
            # Attach each image to the user message as a base64-encoded data URL.
            for image_path in image_paths:
                user_content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path.name)}"},
                })
        context = ([{"role": "system", "content": self.sys_prompt}]
                   + memory
                   + [{"role": "user", "content": user_content}])
        return await self.model.respond_async(context)

    def refine(self, message: str, memory, image_paths=None) -> str:
        # Synchronous variant of refine_async with the same context construction.
        if memory is None:
            memory = []
        else:
            memory = memory.messages[1:]
        user_content = [{"type": "text", "text": message}]
        if image_paths:
            if not isinstance(image_paths, list):
                image_paths = [image_paths]
            # Attach each image to the user message as a base64-encoded data URL.
            for image_path in image_paths:
                user_content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path.name)}"},
                })
        context = ([{"role": "system", "content": self.sys_prompt}]
                   + memory
                   + [{"role": "user", "content": user_content}])
        return self.model.respond(context)
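

# Usage sketch (illustrative only, not part of the original module). It assumes a
# concrete BaseModel implementation named `SomeChatModel` that provides `respond` /
# `respond_async`, and a memory object exposing a `.messages` list; both names are
# assumptions made for this example.
#
#   import asyncio
#   from pathlib import Path
#
#   refiner = SimpleRefiner(sys_prompt="You refine rough answers.", model=SomeChatModel())
#
#   # Synchronous call without prior memory or images.
#   text = refiner.refine("Tighten this paragraph: ...", memory=None)
#
#   # Async call attaching one image, sent as a base64 data URL.
#   text = asyncio.run(refiner.refine_async(
#       "Describe the attached screenshot.", memory=None, image_paths=Path("shot.jpg")))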