Spaces:
Build error
Build error
import json | |
import os | |
import gradio as gr | |
from distilabel.llms import LlamaCppLLM | |
from distilabel.steps.tasks.argillalabeller import ArgillaLabeller | |
# Local path of the GGUF weights, stored next to this script.
# NOTE: the filename says "f16" but download_url fetches the Q8_0 quantisation;
# the name is kept unchanged so an already-downloaded file is reused.
file_path = os.path.join(os.path.dirname(__file__), "Qwen2-5-0.5B-Instruct-f16.gguf")
download_url = "https://huggingface.co/gaianet/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/Qwen2.5-0.5B-Instruct-Q8_0.gguf?download=true"
if not os.path.exists(file_path):
    # Only needed for the one-off download, so imported lazily.
    import requests
    import tqdm

    # Download into a temporary name first: if the transfer is interrupted,
    # the final path does not exist and the next start retries, instead of
    # loading a truncated model that os.path.exists() would accept.
    _partial_path = file_path + ".part"
    with requests.get(download_url, stream=True, timeout=60) as response:
        # Fail loudly on HTTP errors rather than saving an error page as weights.
        response.raise_for_status()
        # content-length may be missing (e.g. chunked transfer); tqdm accepts
        # total=None and then just shows bytes downloaded so far.
        total_length = int(response.headers.get("content-length", 0)) or None
        with open(_partial_path, "wb") as f, tqdm.tqdm(
            total=total_length,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
                progress.update(len(chunk))
    os.replace(_partial_path, file_path)
# Context size requested from llama.cpp: 131072 tokens.
context_window = 128 * 1024

# Local llama.cpp-backed LLM built from the downloaded GGUF file.
# NOTE(review): max_new_tokens equals the whole context window, leaving no
# headroom for the prompt; presumably the backend clamps it - confirm.
llm = LlamaCppLLM(
    model_path=file_path,
    n_ctx=context_window,
    generation_kwargs=dict(max_new_tokens=context_window),
    # Per llama-cpp-python docs, -1 requests offloading all layers to the GPU.
    n_gpu_layers=-1,
)

# Single labelling task shared by every request; loaded once at import time
# so process_records_gradio reuses the already-initialised model.
task = ArgillaLabeller(llm=llm)
task.load()
def load_examples(path=None):
    """Load the example rows shown under the Gradio interface.

    Args:
        path: JSON file to read. Defaults to ``examples.json`` located next to
            this module (resolved the same way as the model path), so the app
            works regardless of the current working directory.

    Returns:
        The decoded JSON content; presumably a list of rows with one value per
        input component, as passed to ``gr.Interface(examples=...)``.

    Raises:
        FileNotFoundError: if the file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    if path is None:
        path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "examples.json"
        )
    # Explicit UTF-8 so non-ASCII examples do not depend on the host locale.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
# Example rows displayed under the Gradio interface, loaded once at import time
# and passed to gr.Interface(examples=...).
examples = load_examples()
def process_fields(fields):
    """Normalise a field specification into a list of field definitions.

    ``fields`` may be a JSON string, a single dict, or an iterable. A string is
    decoded first; a lone dict is wrapped in a one-element list. Each item that
    is not already a dict is decoded with ``json.loads``.
    """

    def _as_dict(item):
        return item if isinstance(item, dict) else json.loads(item)

    decoded = json.loads(fields) if isinstance(fields, str) else fields
    items = [decoded] if isinstance(decoded, dict) else decoded
    return list(map(_as_dict, items))
def process_records_gradio(records, example_records, fields, question):
    """Label records with the shared ArgillaLabeller task and return suggestions.

    All arguments arrive as JSON text from the Gradio ``gr.Code`` inputs.

    Args:
        records: JSON list of record dicts to label. A single record dict is
            also accepted and treated as a one-element list.
        example_records: Optional JSON records forwarded to the task as the
            ``example_records`` runtime parameter; empty text or an empty list
            means none are used.
        fields: Optional field definition(s); normalised by ``process_fields``.
        question: Optional JSON question definition.

    Returns:
        A pretty-printed JSON string ``{"results": [...]}`` holding one
        serialized suggestion per record that received a suggestion, or the
        string ``"Error: Either fields or question must be provided"`` when
        both ``fields`` and ``question`` are empty.

    Raises:
        gr.Error: wrapping any failure (invalid JSON, labelling errors) so
            Gradio displays the message; the original exception is chained.
    """
    try:
        # Convert the JSON text inputs into Python objects.
        records = json.loads(records)
        if isinstance(records, dict):
            records = [records]
        example_records = json.loads(example_records) if example_records else None
        fields = process_fields(fields) if fields else None
        question = json.loads(question) if question else None
        if not fields and not question:
            return "Error: Either fields or question must be provided"

        runtime_parameters = {"fields": fields, "question": question}
        if example_records:
            runtime_parameters["example_records"] = example_records
        # Reconfigures the single module-level task for this request.
        task.set_runtime_parameters(runtime_parameters)

        results = []
        if records:
            # task.process is a generator of output batches (lists). Flatten
            # every batch rather than calling next() once per record: when one
            # batch covers all inputs, repeated next() calls raise StopIteration.
            batches = task.process(inputs=[{"records": record} for record in records])
            for batch in batches:
                for entry in batch:
                    if entry["suggestions"]:
                        results.append(entry["suggestions"].serialize())
        # Records without a suggestion are skipped, so ``results`` can be
        # shorter than ``records``.
        return json.dumps({"results": results}, indent=2)
    except Exception as e:
        # gr.Error surfaces the message in the UI; `from e` keeps the traceback.
        raise gr.Error(f"Error: {e}") from e
# Markdown shown at the top of the Gradio interface: an end-to-end example of
# calling this Space through gradio_client with records taken from Argilla.
description = """
Label Argilla records with a local LLM. The example below sends a record from an
Argilla dataset to this Space through the Gradio client and returns the
suggestions as JSON.
```python
import json
import os

from gradio_client import Client
import argilla as rg

# Initialize the Argilla client
argilla_client = rg.Argilla(
    api_key=os.environ["ARGILLA_API_KEY"], api_url=os.environ["ARGILLA_API_URL"]
)

# Load the dataset
dataset = argilla_client.datasets(name="my_dataset", workspace="my_workspace")

# Prepare the payload
example_field = dataset.settings.fields["my_input_field"].serialize()
example_question = dataset.settings.questions["my_question_to_predict"].serialize()
payload = {
    "records": [next(dataset.records()).to_dict()],
    # Optional: already-labelled records used as examples; leave empty to skip.
    "example_records": [],
    "fields": [example_field],
    "question": example_question,
}

# Use the Gradio client to process the data
gradio_client = Client("davidberenstein1957/distilabel-argilla-labeller")
result = gradio_client.predict(
    records=json.dumps(payload["records"]),
    example_records=json.dumps(payload["example_records"]),
    fields=json.dumps(payload["fields"]),
    question=json.dumps(payload["question"]),
    api_name="/predict",
)
```
"""
def _json_code(label, **extra):
    """Build a JSON-highlighted ``gr.Code`` component with the given label."""
    return gr.Code(label=label, language="json", **extra)


# Four JSON editors feed process_records_gradio positionally
# (records, example_records, fields, question); a fifth shows the result.
interface = gr.Interface(
    fn=process_records_gradio,
    inputs=[
        _json_code("Records (JSON)", lines=5),
        _json_code("Example Records (JSON, optional)", lines=5),
        _json_code("Fields (JSON, optional)"),
        _json_code("Question (JSON, optional)"),
    ],
    outputs=_json_code("Suggestions", lines=10),
    examples=examples,
    title="Distilabel - ArgillaLabeller - Record Processing Interface",
    description=description,
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()