# ndurner's picture
# don't bail on reconnect
# 09f6359
from abc import ABC, abstractmethod
from typing import Type, TypeVar
import base64
import os
import json
from doc2json import process_docx
import fitz
from PIL import Image
import io
import boto3
from botocore.config import Config
import re
from PIL import Image
import io
import math
import gradio
# constants
log_to_console = False
use_document_message_type = False # AWS document message type usage
LLMClass = TypeVar('LLMClass', bound='LLM')
class LLM:
@staticmethod
def create_llm(model: str) -> Type[LLMClass]:
return LLM()
def generate_body(self, message, history):
messages = []
# AWS API requires strict user, assi, user, ... sequence
lastTypeHuman = False
for msg in history:
if msg['role'] == "user":
if lastTypeHuman:
last_msg = messages.pop()
user_msg_parts = last_msg["content"]
else:
user_msg_parts = []
content = msg['content']
if isinstance(content, gradio.File) or isinstance(content, gradio.Image):
user_msg_parts.extend(self._process_file(content.value['path']))
elif isinstance(content, tuple):
user_msg_parts.extend(self._process_file(content[0]))
else:
user_msg_parts.extend([{"text": content}])
messages.append({"role": "user", "content": user_msg_parts})
lastTypeHuman = True
else:
messages.append({
"role": "assistant",
"content":[{"text": msg['content']}]
})
lastTypeHuman = False
if lastTypeHuman:
last_msg = messages.pop()
user_msg_parts = last_msg["content"]
else:
user_msg_parts = []
if message:
if message["text"]:
user_msg_parts.append({"text": message["text"]})
if message["files"]:
for file in message["files"]:
user_msg_parts.extend(self._process_file(file))
if user_msg_parts:
messages.append({"role": "user", "content": user_msg_parts})
return messages
def _process_file(self, file_path):
if use_document_message_type and self._is_supported_document_type(file_path):
return [self._create_document_message(file_path)]
else:
return self._encode_file(file_path)
def _is_supported_document_type(self, file_path):
supported_extensions = ['.pdf', '.csv', '.doc', '.docx', '.xls', '.xlsx', '.html', '.txt', '.md']
return os.path.splitext(file_path)[1].lower() in supported_extensions
def _create_document_message(self, file_path):
with open(file_path, 'rb') as file:
file_content = file.read()
file_name = re.sub(r'[^a-zA-Z0-9\s\-\(\)\[\]]', '', os.path.basename(file_path))[:200].strip() or "unnamed_file"
file_extension = os.path.splitext(file_path)[1][1:] # Remove the dot
return {
"document": {
"name": file_name,
"format": file_extension,
"source": {
"bytes": file_content
}
}
}
def _encode_file(self, fn: str) -> list:
if fn.endswith(".docx"):
return [{"text": process_docx(fn)}]
elif fn.endswith(".pdf"):
return self._process_pdf_img(fn)
else:
with open(fn, mode="rb") as f:
content = f.read()
if isinstance(content, bytes):
try:
# try to add as image
image_data = self._encode_image(content)
return [{"image": image_data}]
except:
# not an image, try text
content = content.decode('utf-8', 'replace')
else:
content = str(content)
fname = os.path.basename(fn)
return [{"text": f"``` {fname}\n{content}\n```"}]
def _process_pdf_img(self, pdf_fn: str):
pdf = fitz.open(pdf_fn)
message_parts = []
page_scales = {} # Cache for similar page sizes
def calculate_tokens(width, height):
return (width * height) / 750
for page in pdf.pages():
page_rect = page.rect
orig_width = page_rect.width
orig_height = page_rect.height
page_key = (orig_width, orig_height)
# Use cached scale as starting point if available
scale = page_scales.get(page_key, 1.0)
while True:
# Render with current scale
mat = fitz.Matrix(scale, scale)
pix = page.get_pixmap(matrix=mat, alpha=False)
# Check actual rendered dimensions
actual_tokens = calculate_tokens(pix.width, pix.height)
actual_long_edge = max(pix.width, pix.height)
if actual_long_edge <= 1568 and actual_tokens <= 1600:
# We found a good scale, cache it
if page_key not in page_scales:
page_scales[page_key] = scale
break
# Calculate new scale factor based on both constraints
if actual_long_edge > 1568:
scale_factor = min(1568 / actual_long_edge, 0.9)
else:
scale_factor = min(math.sqrt(1600 / actual_tokens), 0.9)
scale *= scale_factor
# Convert to PIL Image
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
# Handle compression
quality = 95
while True:
buffer = io.BytesIO()
img.save(buffer, format="webp", quality=quality)
img_bytes = buffer.getvalue()
if len(img_bytes) <= 5 * 1024 * 1024 or quality <= 20:
break
quality = max(int(quality * 0.9), 20)
message_parts.append({"text": f"Page {page.number + 1} of file '{pdf_fn}'"})
message_parts.append({"image": {
"format": "webp",
"source": {"bytes": img_bytes}
}})
pdf.close()
return message_parts
def _encode_image(self, image_data):
try:
# Open the image using Pillow
img = Image.open(io.BytesIO(image_data))
original_format = img.format.lower()
except IOError:
raise Exception("Unknown image type")
# Ensure correct orientation based on EXIF
try:
exif = img._getexif()
if exif:
orientation = exif.get(274) # 274 is the orientation tag
if orientation:
# Rotate or flip based on EXIF orientation
if orientation == 3:
img = img.rotate(180, expand=True)
elif orientation == 6:
img = img.rotate(270, expand=True)
elif orientation == 8:
img = img.rotate(90, expand=True)
except:
pass # If EXIF processing fails, use image as-is
# check if within the limits for Claude as per https://docs.anthropic.com/en/docs/build-with-claude/vision
def calculate_tokens(width, height):
return (width * height) / 750
tokens = calculate_tokens(img.width, img.height)
long_edge = max(img.width, img.height)
format_ok = original_format in ["jpg", "jpeg", "png", "webp"]
# Check if the image already meets all requirements
if format_ok and (long_edge <= 1568 and tokens <= 1600 and len(image_data) <= 5 * 1024 * 1024):
return {
"format": original_format,
"source": {"bytes": image_data}
}
# If we need to modify the image, proceed with resizing and/or compression
orig_scale_factor = 1
orig_img = img
while long_edge > 1568 or tokens > 1600:
if long_edge > 1568:
scale_factor = min(1568 / long_edge, 0.9)
else:
scale_factor = min(math.sqrt(1600 / tokens), 0.9)
scale_factor = orig_scale_factor * scale_factor
orig_scale_factor = scale_factor
new_width = int(orig_img.width * scale_factor)
new_height = int(orig_img.height * scale_factor)
img = orig_img.resize((new_width, new_height), Image.LANCZOS)
long_edge = max(img.width, img.height)
tokens = calculate_tokens(img.width, img.height)
# Try to save in original format first
buffer = io.BytesIO()
out_fmt = "png" if original_format == "png" else "webp"
img.save(buffer, format=out_fmt, quality=95 if out_fmt == "webp" else None)
image_data = buffer.getvalue()
# If the image is still too large, switch to WebP and compress
if len(image_data) > 5 * 1024 * 1024:
quality = 95
while len(image_data) > 5 * 1024 * 1024:
quality = max(int(quality * 0.9), 20)
buffer = io.BytesIO()
img.save(buffer, format="webp", quality=quality)
image_data = buffer.getvalue()
if quality == 20:
# If we've reached quality 20 and it's still too large, resize
scale_factor = 0.9
new_width = int(img.width * scale_factor)
new_height = int(img.height * scale_factor)
img = img.resize((new_width, new_height), Image.LANCZOS)
quality = 95 # Reset quality for the resized image
return {
"format": "webp",
"source": {"bytes": image_data}
}
def read_response(self, response_stream):
"""
Handles response stream that may contain both regular text and tool use requests.
Yields tuples of (text, tool_request, stop_reason) where:
- text: accumulated text response
- tool_request: dict with tool use details if present, None otherwise
- stop_reason: string indicating why stream stopped, None while streaming
"""
message = {}
content = []
message['content'] = content
tool_use = {}
text = ''
stop_reason = None
for chunk in response_stream:
if 'messageStart' in chunk:
message['role'] = chunk['messageStart']['role']
elif 'contentBlockStart' in chunk:
tool = chunk['contentBlockStart']['start']['toolUse']
tool_use['toolUseId'] = tool['toolUseId']
tool_use['name'] = tool['name']
elif 'contentBlockDelta' in chunk:
delta = chunk['contentBlockDelta']['delta']
if 'toolUse' in delta:
if 'input' not in tool_use:
tool_use['input'] = ''
tool_use['input'] += delta['toolUse']['input']
elif 'text' in delta:
text += delta['text']
yield None, delta['text']
elif 'contentBlockStop' in chunk:
if 'input' in tool_use:
tool_use['input'] = json.loads(tool_use['input'])
content.append({'toolUse': tool_use})
tool_use = {}
else:
content.append({'text': text})
elif 'messageStop' in chunk:
stop_reason = chunk['messageStop']['stopReason']
yield stop_reason, message
elif 'metadata' in chunk and 'usage' in chunk['metadata'] and log_to_console:
print("\nToken usage:")
print(f"Input tokens: {metadata['usage']['inputTokens']}")
print(f"Output tokens: {metadata['usage']['outputTokens']}")
print(f"Total tokens: {metadata['usage']['totalTokens']}")