Spaces:
Running
Running
import asyncio
import os
import re

from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from groq import Groq
from huggingface_hub import InferenceClient
from pydantic import BaseModel
# Initialize the FastAPI app.
app = FastAPI()

# Serve files under ./static (including the frontend page) at /static.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Groq client used for chat completions. It is constructed with no arguments,
# so credentials come from the SDK's defaults (presumably the GROQ_API_KEY
# environment variable -- confirm against the Groq SDK docs).
client = Groq()
# Pydantic model for API input | |
class InfographicRequest(BaseModel): | |
description: str | |
# Prompt configuration is read from environment variables at import time.
# Each constant is None when its variable is unset.
SYSTEM_INSTRUCT = os.environ.get("SYSTEM_INSTRUCTOR")  # system message for the chat model
PROMPT_TEMPLATE = os.environ.get("PROMPT_TEMPLATE")  # str.format template with a {description} field
async def extract_code_blocks(markdown_text):
    """Collect the bodies of all triple-backtick fenced blocks in Markdown.

    The opening fence may carry an info string (such as a language tag) up to
    the end of its line. Only the text between that line and the next closing
    fence is captured.

    Args:
        markdown_text (str): Markdown content to scan.

    Returns:
        list: Code block bodies in order of appearance; empty if none found.
    """
    # DOTALL lets a block body span multiple lines; both quantifiers are lazy
    # so each match stops at the first newline / first closing fence.
    fence = re.compile(r'```.*?\n(.*?)```', re.DOTALL)
    return [match.group(1) for match in fence.finditer(markdown_text)]
async def generate_code_blocks(description):
    """Ask the Groq chat model for an infographic and return its code blocks.

    The user prompt is built by filling ``PROMPT_TEMPLATE``'s ``{description}``
    field; ``SYSTEM_INSTRUCT`` is sent as the system message.

    Args:
        description (str): Free-text description of the infographic.

    Returns:
        list: Bodies of the fenced code blocks in the model's reply, in order
        of appearance (empty when the reply contains none).

    Raises:
        RuntimeError: If the ``PROMPT_TEMPLATE`` environment variable is unset.
        Exception: Errors raised by the Groq client propagate unchanged.
    """
    if PROMPT_TEMPLATE is None:
        raise RuntimeError("PROMPT_TEMPLATE environment variable is not set")
    prompt = PROMPT_TEMPLATE.format(description=description)
    # The Groq client call is synchronous; running it in a worker thread keeps
    # it from blocking the event loop (and every other request) while the
    # model generates. NOTE(review): confirm this model ID is still served.
    completion = await asyncio.to_thread(
        client.chat.completions.create,
        model="llama-3.1-70b-versatile",
        messages=[
            {"role": "system", "content": SYSTEM_INSTRUCT},
            {"role": "user", "content": prompt},
        ],
        temperature=0.5,
        max_tokens=5000,
        top_p=1,
        stream=False,
        stop=None,
    )
    # Guard against a missing message body (the SDK presumably types content
    # as optional -- confirm) so extraction always receives a string.
    generated_text = completion.choices[0].message.content or ""
    return await extract_code_blocks(generated_text)


# Route to serve the HTML template.
# NOTE(review): no route decorator is visible here, so this function is not
# registered with `app`; presumably it needs `@app.get("/", ...)` -- confirm.
async def serve_frontend():
    """Return the frontend page ``static/infographic_gen.html`` as HTML."""
    with open("static/infographic_gen.html", encoding="utf-8") as page:
        return HTMLResponse(page.read())


# Route to handle infographic generation.
# NOTE(review): no route decorator is visible here either; register it with
# the POST path that static/infographic_gen.html calls -- confirm the path.
async def generate_infographic(request: InfographicRequest):
    """Generate an infographic from the request's description.

    Args:
        request: Parsed request body carrying ``description``.

    Returns:
        JSONResponse: ``{"html": <first code block>}`` on success. Otherwise
        ``{"error": <message>}`` with HTTP status 500: ``"No generation"``
        when the model reply has no code block, or the text of the exception
        that aborted generation.
    """
    try:
        code_blocks = await generate_code_blocks(request.description)
    except Exception as exc:
        # Top-level request boundary: report any failure as a JSON error.
        return JSONResponse(content={"error": str(exc)}, status_code=500)
    if not code_blocks:
        return JSONResponse(content={"error": "No generation"}, status_code=500)
    return JSONResponse(content={"html": code_blocks[0]})