Spaces:
Build error
Build error
File size: 3,735 Bytes
36821d3 c9bd449 d2e7f91 c9bd449 6199610 dff7018 36821d3 dff7018 aaaeb76 6199610 36821d3 d3fc1a4 aa23dc4 d3fc1a4 aa23dc4 d3fc1a4 aa23dc4 d3fc1a4 aa23dc4 36821d3 dff7018 607292d 36821d3 dff7018 78f9744 dff7018 36821d3 6199610 78f9744 6199610 36821d3 d3fc1a4 36821d3 aaaeb76 21cc5dc 778c655 21cc5dc 36821d3 dff7018 36821d3 dff7018 607292d 36821d3 21cc5dc 36821d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import json
import gradio as gr
from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks.argillalabeller import ArgillaLabeller
# One repository id serves as both the model and the tokenizer.
_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"

llm = InferenceEndpointsLLM(
    model_id=_MODEL_ID,
    tokenizer_id=_MODEL_ID,
    generation_kwargs={"max_new_tokens": 4000},
)

# Labelling task shared by every request handled by process_records_gradio;
# it is loaded a single time, when this module is imported.
task = ArgillaLabeller(llm=llm)
task.load()
def load_examples(path="examples.json"):
    """Load the Gradio example rows from a JSON file.

    Args:
        path: Location of the JSON file. Relative paths are resolved against
            the current working directory (the default matches the previous
            hard-coded ``"examples.json"``).

    Returns:
        The decoded JSON content, used as ``gr.Interface(examples=...)``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit UTF-8 so non-ASCII examples do not depend on the host locale.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
# Example rows for gr.Interface(examples=...) below, read once at import time.
examples = load_examples()
def process_fields(fields):
    """Normalize a fields payload into a list of field dictionaries.

    Accepted shapes: a JSON string, a single dict, or a list whose items are
    dicts or JSON strings that each encode a dict.

    Args:
        fields: The raw fields payload. ``None``, an empty/blank string, or
            JSON ``null`` yields an empty list.

    Returns:
        A list of field dictionaries.

    Raises:
        json.JSONDecodeError: If a non-blank string (or string item) is not
            valid JSON.
    """
    if isinstance(fields, str):
        if not fields.strip():
            return []
        fields = json.loads(fields)
    # Covers both a ``None`` argument and a decoded JSON ``null``.
    if fields is None:
        return []
    if isinstance(fields, dict):
        return [fields]
    return [field if isinstance(field, dict) else json.loads(field) for field in fields]
def _parse_json_input(value):
    """Decode ``value`` if it is a JSON string; blank strings become ``None``."""
    if isinstance(value, str):
        return json.loads(value) if value.strip() else None
    return value


def process_records_gradio(records, fields, question, example_records=None):
    """Label records with the ArgillaLabeller task and return its suggestions.

    Every argument may be a JSON string (as delivered by the ``gr.Code``
    inputs) or an already-decoded Python object.

    Args:
        records: A record dict or a list of record dicts to label.
        fields: Field definition(s); normalized to a list via
            ``process_fields``.
        question: The question definition (a dict) to predict.
        example_records: Optional record dict or list of record dicts,
            forwarded as the task's ``example_records`` runtime parameter.

    Returns:
        A JSON string ``{"results": [...]}`` with the non-empty
        ``suggestions`` produced for the records, in output order.

    Raises:
        gr.Error: If decoding, validation, or the labelling task fails.
    """
    try:
        records = _parse_json_input(records)
        example_records = _parse_json_input(example_records)
        question = _parse_json_input(question)
        fields = process_fields(fields) if fields else None

        # Accept a single record as well as a list of records.
        if isinstance(records, dict):
            records = [records]
        if isinstance(example_records, dict):
            example_records = [example_records]

        if not fields and not question:
            raise ValueError("Either fields or question must be provided")
        if not records:
            # Nothing to label: skip the LLM call entirely.
            return json.dumps({"results": []}, indent=2)

        runtime_parameters = {"fields": fields, "question": question}
        if example_records:
            runtime_parameters["example_records"] = example_records
        task.set_runtime_parameters(runtime_parameters)

        # task.process returns an iterator of lists of output entries, and a
        # single list can cover many inputs. Drain it completely instead of
        # calling next() once per record, which raised StopIteration (and
        # dropped entries) as soon as more than one record was sent.
        results = []
        for batch in task.process(inputs=[{"record": record} for record in records]):
            for entry in batch:
                suggestions = entry.get("suggestions")
                if suggestions:
                    results.append(suggestions)
        return json.dumps({"results": results}, indent=2)
    except Exception as e:
        raise gr.Error(f"Error: {str(e)}") from e
description = """
An example workflow that labels Argilla records by sending a JSON payload to this Space.

```python
import json
import os

import argilla as rg
from gradio_client import Client

# Initialize the Argilla client
client = rg.Argilla(
    api_key=os.environ["ARGILLA_API_KEY"], api_url=os.environ["ARGILLA_API_URL"]
)

# Load the dataset
dataset = client.datasets(name="my_dataset", workspace="my_workspace")

# Prepare the payload
example_field = dataset.settings.fields["my_input_field"].serialize()
example_question = dataset.settings.questions["my_question_to_predict"].serialize()
payload = {
    "records": [next(dataset.records()).to_dict()],
    "example_records": [],  # optional: previously annotated records to use as examples
    "fields": [example_field],
    "question": example_question,
}

# Use the Gradio client to process the data
labeller = Client("davidberenstein1957/distilabel-argilla-labeller")
result = labeller.predict(
    records=json.dumps(payload["records"]),
    example_records=json.dumps(payload["example_records"]),
    fields=json.dumps(payload["fields"]),
    question=json.dumps(payload["question"]),
    api_name="/predict",
)
suggestions = json.loads(result)["results"]
```
"""
def _predict_from_ui(records, example_records, fields, question):
    """Adapt the UI input order to ``process_records_gradio``'s signature.

    ``gr.Interface`` passes input values positionally, in the order of its
    ``inputs`` list (records, example records, fields, question), whereas
    ``process_records_gradio`` takes ``(records, fields, question,
    example_records)``. Forwarding by keyword keeps each code box wired to
    the parameter its label names.
    """
    return process_records_gradio(
        records=records,
        fields=fields,
        question=question,
        example_records=example_records,
    )


interface = gr.Interface(
    fn=_predict_from_ui,
    inputs=[
        gr.Code(label="Records (JSON)", language="json", lines=5),
        gr.Code(label="Example Records (JSON, optional)", language="json", lines=5),
        gr.Code(label="Fields (JSON, optional)", language="json"),
        gr.Code(label="Question (JSON, optional)", language="json"),
    ],
    # NOTE(review): each row in examples.json must follow the order of
    # ``inputs`` above (records, example records, fields, question) — confirm.
    examples=examples,
    # Cached example outputs are computed by running the model on every
    # example ahead of time.
    cache_examples=True,
    outputs=gr.Code(label="Suggestions", language="json", lines=10),
    title="Distilabel - ArgillaLabeller - Record Processing Interface",
    description=description,
)
if __name__ == "__main__":
    # Serve the interface only when this file is run directly, not on import.
    interface.launch()
|