Spaces:
Running
Running
from abc import ABC, abstractmethod | |
from typing import Type, TypeVar | |
import base64 | |
import os | |
import json | |
from doc2json import process_docx | |
import fitz | |
from PIL import Image | |
import io | |
import boto3 | |
from botocore.config import Config | |
import re | |
from PIL import Image | |
import io | |
import math | |
import gradio | |
# Module-level switches and typing helpers.

# When True, diagnostic output (such as token usage) is printed to stdout.
log_to_console = False

# When True, supported files are sent using the AWS document message type
# instead of being converted into text/image parts.
use_document_message_type = False

# Type variable bound to the LLM class, for use in annotations.
LLMClass = TypeVar("LLMClass", bound="LLM")
class LLM: | |
@staticmethod
def create_llm(model: str) -> "LLM":
    """Create the LLM client for the given model.

    Declared as a static method because it takes no ``self``: it works both
    as ``LLM.create_llm(name)`` and when called on an existing instance.

    Args:
        model: Model identifier. Currently unused; the same ``LLM``
            implementation is returned for every value.

    Returns:
        A new ``LLM`` instance (an object, not a class).
    """
    return LLM()
def generate_body(self, message, history):
    """Build the message list for the model from chat history and new input.

    Consecutive user entries are merged into a single user message because
    the AWS API requires a strict user, assistant, user, ... sequence.

    Args:
        message: The new user input. Either a dict with optional ``"text"``
            (str) and ``"files"`` (list of file paths) entries, a plain
            string, or a falsy value when there is nothing new to add.
        history: Iterable of dicts with ``"role"`` and ``"content"`` keys.
            User content may be a string, a tuple whose first element is a
            file path, or a ``gradio.File`` / ``gradio.Image`` component;
            any other value is passed through as the text of the part.
            Every non-user entry is emitted with the ``"assistant"`` role.

    Returns:
        A list of ``{"role": ..., "content": [parts]}`` dicts.
    """
    messages = []

    def take_open_user_parts():
        """Pop the trailing user turn, if any, so new parts merge into it."""
        if messages and messages[-1]["role"] == "user":
            return messages.pop()["content"]
        return []

    for msg in history:
        content = msg['content']
        if msg['role'] != "user":
            messages.append({"role": "assistant", "content": [{"text": content}]})
            continue

        parts = take_open_user_parts()
        if isinstance(content, str):
            # Plain text is the common case; handle it before touching gradio.
            parts.append({"text": content})
        elif isinstance(content, tuple):
            parts.extend(self._process_file(content[0]))
        elif isinstance(content, (gradio.File, gradio.Image)):
            parts.extend(self._process_file(content.value['path']))
        else:
            parts.append({"text": content})
        messages.append({"role": "user", "content": parts})

    parts = take_open_user_parts()
    if message:
        if isinstance(message, str):
            # Accept a bare string as shorthand for a text-only message.
            message = {"text": message}
        text = message.get("text")
        if text:
            parts.append({"text": text})
        # A missing or None "files" entry is treated as "no attachments".
        for file in message.get("files") or ():
            parts.extend(self._process_file(file))
    if parts:
        messages.append({"role": "user", "content": parts})
    return messages
def _process_file(self, file_path):
    """Convert one file into a list of message content parts.

    When the module-level ``use_document_message_type`` switch is on and the
    file has a supported document extension, the file is wrapped as a single
    document part; otherwise it is encoded via ``_encode_file``.
    """
    send_as_document = (
        use_document_message_type
        and self._is_supported_document_type(file_path)
    )
    if not send_as_document:
        return self._encode_file(file_path)
    return [self._create_document_message(file_path)]
def _is_supported_document_type(self, file_path): | |
supported_extensions = ['.pdf', '.csv', '.doc', '.docx', '.xls', '.xlsx', '.html', '.txt', '.md'] | |
return os.path.splitext(file_path)[1].lower() in supported_extensions | |
def _create_document_message(self, file_path): | |
with open(file_path, 'rb') as file: | |
file_content = file.read() | |
file_name = re.sub(r'[^a-zA-Z0-9\s\-\(\)\[\]]', '', os.path.basename(file_path))[:200].strip() or "unnamed_file" | |
file_extension = os.path.splitext(file_path)[1][1:] # Remove the dot | |
return { | |
"document": { | |
"name": file_name, | |
"format": file_extension, | |
"source": { | |
"bytes": file_content | |
} | |
} | |
} | |
def _encode_file(self, fn: str) -> list: | |
if fn.endswith(".docx"): | |
return [{"text": process_docx(fn)}] | |
elif fn.endswith(".pdf"): | |
return self._process_pdf_img(fn) | |
else: | |
with open(fn, mode="rb") as f: | |
content = f.read() | |
if isinstance(content, bytes): | |
try: | |
# try to add as image | |
image_data = self._encode_image(content) | |
return [{"image": image_data}] | |
except: | |
# not an image, try text | |
content = content.decode('utf-8', 'replace') | |
else: | |
content = str(content) | |
fname = os.path.basename(fn) | |
return [{"text": f"``` {fname}\n{content}\n```"}] | |
def _process_pdf_img(self, pdf_fn: str):
    """Render every page of a PDF as a caption plus a WebP image part.

    Each page is rendered at a scale that is reduced step by step until the
    pixmap's long edge and its estimated token count (width * height / 750)
    are within the limits below. The scale found for a page size is cached
    and used as the starting point for later pages of the same size. The
    image is then saved as WebP, lowering quality (down to 20) while the
    result exceeds 5 MiB.

    Args:
        pdf_fn: Path of the PDF file.

    Returns:
        A flat list alternating ``{"text": "Page N of file '...'"}`` and
        ``{"image": {"format": "webp", "source": {"bytes": ...}}}`` parts.
    """
    # NOTE(review): presumably the same vision limits that _encode_image
    # applies; confirm against the model documentation.
    max_long_edge = 1568
    max_tokens = 1600
    max_bytes = 5 * 1024 * 1024
    min_quality = 20

    def calculate_tokens(width, height):
        return (width * height) / 750

    message_parts = []
    page_scales = {}  # Cache for similar page sizes
    pdf = fitz.open(pdf_fn)
    try:
        for page in pdf.pages():
            page_rect = page.rect
            page_key = (page_rect.width, page_rect.height)

            # Use cached scale as starting point if available
            scale = page_scales.get(page_key, 1.0)
            while True:
                # Render with current scale
                pix = page.get_pixmap(matrix=fitz.Matrix(scale, scale), alpha=False)

                # Check actual rendered dimensions
                actual_tokens = calculate_tokens(pix.width, pix.height)
                actual_long_edge = max(pix.width, pix.height)
                if actual_long_edge <= max_long_edge and actual_tokens <= max_tokens:
                    # Found a good scale; cache it unless one is already stored
                    page_scales.setdefault(page_key, scale)
                    break

                # Shrink by at least 10% per step so the loop always converges
                if actual_long_edge > max_long_edge:
                    scale_factor = min(max_long_edge / actual_long_edge, 0.9)
                else:
                    scale_factor = min(math.sqrt(max_tokens / actual_tokens), 0.9)
                scale *= scale_factor

            # Convert to PIL Image
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

            # Handle compression: lower quality until the size limit is met
            # or the minimum quality is reached.
            quality = 95
            while True:
                buffer = io.BytesIO()
                img.save(buffer, format="webp", quality=quality)
                img_bytes = buffer.getvalue()
                if len(img_bytes) <= max_bytes or quality <= min_quality:
                    break
                quality = max(int(quality * 0.9), min_quality)

            message_parts.append({"text": f"Page {page.number + 1} of file '{pdf_fn}'"})
            message_parts.append({"image": {
                "format": "webp",
                "source": {"bytes": img_bytes}
            }})
    finally:
        # Always release the document, even if rendering or encoding fails.
        pdf.close()
    return message_parts
def _encode_image(self, image_data):
    """Turn raw image bytes into an image content part within size limits.

    If the image is already JPEG/PNG/WebP, its long edge is at most 1568 px,
    its estimated token count (width * height / 750) is at most 1600 and it
    is no larger than 5 MiB, the original bytes are returned untouched.
    Otherwise the image is scaled down and re-encoded (PNG stays PNG, every
    other format becomes WebP); if the result still exceeds 5 MiB it is
    re-encoded as WebP with decreasing quality and, if needed, shrunk
    further.

    Args:
        image_data: Raw bytes of the image file.

    Returns:
        ``{"format": <"png" | "jpeg" | "webp" | ...>, "source": {"bytes": ...}}``
        where ``format`` always names the encoding of the returned bytes.

    Raises:
        Exception: ``"Unknown image type"`` when the bytes cannot be opened
            as an image.
    """
    max_long_edge = 1568
    max_tokens = 1600
    max_bytes = 5 * 1024 * 1024
    min_quality = 20

    try:
        # Open the image using Pillow
        img = Image.open(io.BytesIO(image_data))
        original_format = (img.format or "").lower()
    except IOError as err:
        raise Exception("Unknown image type") from err

    # Ensure correct orientation based on EXIF
    try:
        exif = img._getexif()
        if exif:
            orientation = exif.get(274)  # 274 is the orientation tag
            if orientation == 3:
                img = img.rotate(180, expand=True)
            elif orientation == 6:
                img = img.rotate(270, expand=True)
            elif orientation == 8:
                img = img.rotate(90, expand=True)
    except Exception:
        pass  # Best effort: if EXIF processing fails, use the image as-is

    # check if within the limits for Claude as per https://docs.anthropic.com/en/docs/build-with-claude/vision
    def calculate_tokens(width, height):
        return (width * height) / 750

    def scaled(source, factor):
        """Resize ``source`` by ``factor``, never below 1 px per side."""
        size = (max(1, int(source.width * factor)), max(1, int(source.height * factor)))
        return source.resize(size, Image.LANCZOS)

    tokens = calculate_tokens(img.width, img.height)
    long_edge = max(img.width, img.height)
    format_ok = original_format in ["jpg", "jpeg", "png", "webp"]

    # Check if the image already meets all requirements
    if format_ok and long_edge <= max_long_edge and tokens <= max_tokens and len(image_data) <= max_bytes:
        return {
            "format": original_format,
            "source": {"bytes": image_data}
        }

    # Resize from the original each time (cumulative factor) so repeated
    # resampling does not compound quality loss.
    cumulative_scale = 1
    orig_img = img
    while long_edge > max_long_edge or tokens > max_tokens:
        if long_edge > max_long_edge:
            step = min(max_long_edge / long_edge, 0.9)
        else:
            step = min(math.sqrt(max_tokens / tokens), 0.9)
        cumulative_scale *= step
        img = scaled(orig_img, cumulative_scale)
        long_edge = max(img.width, img.height)
        tokens = calculate_tokens(img.width, img.height)

    # Try to save in original format first (PNG stays PNG, the rest -> WebP)
    out_fmt = "png" if original_format == "png" else "webp"
    buffer = io.BytesIO()
    if out_fmt == "webp":
        img.save(buffer, format="webp", quality=95)
    else:
        img.save(buffer, format="png")
    image_data = buffer.getvalue()

    # If the image is still too large, switch to WebP and compress
    quality = 95
    while len(image_data) > max_bytes:
        out_fmt = "webp"
        quality = max(int(quality * 0.9), min_quality)
        buffer = io.BytesIO()
        img.save(buffer, format="webp", quality=quality)
        image_data = buffer.getvalue()
        if quality == min_quality and len(image_data) > max_bytes:
            # Minimum quality reached and still too large: shrink and retry
            img = scaled(img, 0.9)
            quality = 95  # Reset quality for the resized image

    # Report the format actually used; a re-encoded PNG must not be
    # labelled as WebP.
    return {
        "format": out_fmt,
        "source": {"bytes": image_data}
    }
def read_response(self, response_stream):
    """
    Consume a response stream that may contain both text and tool use requests.

    This is a generator yielding 2-tuples:
    - ``(None, text_delta)`` for each text fragment as it arrives;
    - ``(stop_reason, message)`` when a ``messageStop`` event arrives, where
      ``message`` is the assembled ``{'role': ..., 'content': [...]}`` dict.
      ``content`` holds one ``{'text': ...}`` entry per text block and one
      ``{'toolUse': {'toolUseId', 'name', 'input'}}`` entry per tool call,
      with ``input`` parsed from JSON (``{}`` when no input was streamed).

    When the module-level ``log_to_console`` flag is set, token usage from
    the ``metadata`` event is printed.
    """
    message = {}
    content = []
    message['content'] = content
    tool_use = {}
    text = ''
    for chunk in response_stream:
        if 'messageStart' in chunk:
            message['role'] = chunk['messageStart']['role']
        elif 'contentBlockStart' in chunk:
            # Only tool-use blocks carry start data; tolerate other blocks.
            tool = chunk['contentBlockStart'].get('start', {}).get('toolUse')
            if tool:
                tool_use = {'toolUseId': tool['toolUseId'], 'name': tool['name']}
        elif 'contentBlockDelta' in chunk:
            delta = chunk['contentBlockDelta']['delta']
            if 'toolUse' in delta:
                # Tool input arrives as JSON fragments; concatenate them.
                tool_use['input'] = tool_use.get('input', '') + delta['toolUse']['input']
            elif 'text' in delta:
                text += delta['text']
                yield None, delta['text']
        elif 'contentBlockStop' in chunk:
            if tool_use:
                raw_input = tool_use.get('input', '')
                # A tool call without arguments streams no (or empty) input.
                tool_use['input'] = json.loads(raw_input) if raw_input else {}
                content.append({'toolUse': tool_use})
                tool_use = {}
            else:
                content.append({'text': text})
                # Start fresh so a later text block does not repeat this one.
                text = ''
        elif 'messageStop' in chunk:
            stop_reason = chunk['messageStop']['stopReason']
            yield stop_reason, message
        elif 'metadata' in chunk and 'usage' in chunk['metadata'] and log_to_console:
            usage = chunk['metadata']['usage']
            print("\nToken usage:")
            print(f"Input tokens: {usage['inputTokens']}")
            print(f"Output tokens: {usage['outputTokens']}")
            print(f"Total tokens: {usage['totalTokens']}")