# app.py: Hugging Face Space by davidberenstein1957
# Last commit: dff7018 "Update app.py" (verified), 2.93 kB
import json
import spaces
import gradio as gr
import torch
from distilabel.llms import TransformersLLM
from distilabel.steps.tasks.argillalabeller import ArgillaLabeller
from transformers import BitsAndBytesConfig
# bitsandbytes 4-bit quantization config: NF4 weights with double (nested)
# quantization; bnb_4bit_compute_dtype sets the dtype used for computation.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
# Single module-level LLM shared by every request handled by this app.
# model_kwargs are presumably forwarded to transformers' model loading
# (quantization + automatic device placement); TODO confirm against distilabel.
# NOTE(review): torch_dtype is "float16" while the bnb compute dtype above is
# bfloat16; confirm this mismatch is intentional.
llm = TransformersLLM(
    model="microsoft/Phi-3-mini-4k-instruct",
    torch_dtype="float16",
    model_kwargs={
        "quantization_config": quantization_config,
        "device_map": "auto",
    },
)
# One labeller task reused across requests (process_records_gradio sets its
# runtime parameters per call). load() runs once at import time, presumably
# loading the model before the interface starts serving.
task = ArgillaLabeller(llm=llm)
task.load()
def load_examples(path="examples.json"):
    """Load the Gradio example rows from a JSON file.

    Args:
        path: Location of the JSON file. Relative paths are resolved against
            the current working directory (the default keeps the original
            ``examples.json`` behaviour).

    Returns:
        The decoded JSON document, used as ``gr.Interface(examples=...)``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit UTF-8: the default text encoding is locale-dependent (e.g. not
    # UTF-8 on many Windows setups), while JSON files are UTF-8 by spec.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
# Load the example rows shown by the Gradio interface (passed as `examples=`
# to gr.Interface). Read once at import time, so the file must exist before
# the app starts.
examples = load_examples()
def process_fields(fields):
    """Normalize a ``fields`` payload into a list of field dictionaries.

    Accepts a JSON string, a single dict, or a list whose items are dicts
    or JSON strings. A lone dict (decoded or given directly) is wrapped in
    a one-element list, and string items are decoded individually.

    Args:
        fields: The raw fields payload.

    Returns:
        A list with one entry per field; dict items are returned as-is.
    """
    parsed = json.loads(fields) if isinstance(fields, str) else fields
    items = [parsed] if isinstance(parsed, dict) else parsed

    normalized = []
    for item in items:
        if not isinstance(item, dict):
            item = json.loads(item)
        normalized.append(item)
    return normalized
@spaces.GPU
def process_records_gradio(records, example_records, fields, question):
    """Label records with the shared ``ArgillaLabeller`` task.

    All arguments are the raw strings from the Gradio ``gr.Code`` inputs.

    Args:
        records: JSON for a single record (``rg.Record.to_dict()``) or a
            list of them.
        example_records: Optional JSON list of example records; an empty
            string means "none".
        fields: Optional JSON field(s), normalized via ``process_fields``.
        question: Optional JSON question.

    Returns:
        A pretty-printed JSON string ``{"results": [...]}`` holding each
        record's ``suggestions`` in input order, or a string starting with
        ``"Error: "`` describing why processing failed.
    """
    try:
        records = json.loads(records)
        # The UI advertises a single record dict as valid input; wrap it so the
        # loop below does not iterate over the dict's keys.
        if isinstance(records, dict):
            records = [records]
        if not isinstance(records, list):
            return "Error: Records must be a JSON object or a list of JSON objects"

        example_records = json.loads(example_records) if example_records else None
        fields = process_fields(fields) if fields else None
        question = json.loads(question) if question else None

        if not fields and not question:
            return "Error: Either fields or question must be provided"

        # `task` is shared across requests. Always pass every parameter, even
        # when None (as fields/question already are), so a value set by an
        # earlier request is not silently reused. This assumes
        # set_runtime_parameters only updates the keys it receives.
        task.set_runtime_parameters(
            {
                "fields": fields,
                "question": question,
                "example_records": example_records,
            }
        )

        results = []
        for record in records:
            output = next(task.process(inputs=[{"records": record}]))
            results.append(output[0]["suggestions"])
        return json.dumps({"results": results}, indent=2)
    except Exception as e:
        # UI boundary: show the failure in the output box rather than
        # crashing the Gradio event handler.
        return f"Error: {str(e)}"
# Gradio UI: four JSON editors in, one JSON editor out. The order of `inputs`
# must match the parameters of process_records_gradio
# (records, example_records, fields, question) and the columns of each row
# in `examples`.
interface = gr.Interface(
    fn=process_records_gradio,
    inputs=[
        gr.Code(label="Records (JSON)", language="json", lines=5),
        gr.Code(label="Example Records (JSON, optional)", language="json", lines=5),
        gr.Code(label="Fields (JSON, optional)", language="json"),
        gr.Code(label="Question (JSON, optional)", language="json"),
    ],
    examples=examples,
    outputs=gr.Code(label="Suggestions", language="json", lines=10),
    title="Record Processing Interface",
    description=(
        "Enter JSON data for `rg.Record.to_dict()`, `List[rg.Record.to_dict()]`, "
        "`List[rg.Field.serialize()]`, or `List[rg.Question.serialize()]`. "
        "At least one of fields or question must be provided."
    ),
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script (`python app.py`).
    interface.launch()