Spaces:
Build error
Build error
| import json | |
| import os | |
| import gradio as gr | |
| from distilabel.llms import LlamaCppLLM | |
| from distilabel.steps.tasks.argillalabeller import ArgillaLabeller | |
# Local cache path for the GGUF model, resolved next to this file.
# NOTE(review): the filename says "f16" but download_url fetches the Q8_0
# quantization; the name is kept so previously cached downloads are reused.
file_path = os.path.join(os.path.dirname(__file__), "Qwen2-5-0.5B-Instruct-f16.gguf")
download_url = "https://huggingface.co/gaianet/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/Qwen2.5-0.5B-Instruct-Q8_0.gguf?download=true"
if not os.path.exists(file_path):
    # Imported lazily: only needed on the first run, before the model is cached.
    import requests
    import tqdm

    # Stream into a temporary sibling file and rename it only after the
    # download completes, so an interrupted run cannot leave a truncated
    # model behind that the exists() check above would then accept.
    partial_path = file_path + ".part"
    # timeout bounds connect/read stalls, not the total download time.
    with requests.get(download_url, stream=True, timeout=60) as response:
        # Fail loudly on HTTP errors instead of saving an error page as the model.
        response.raise_for_status()
        # content-length may be missing; fall back to a progress bar without
        # a total instead of crashing on int(None).
        content_length = response.headers.get("content-length")
        total_bytes = int(content_length) if content_length else None
        with open(partial_path, "wb") as f, tqdm.tqdm(
            total=total_bytes, unit="B", unit_scale=True, unit_divisor=1024
        ) as progress:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
                progress.update(len(chunk))
    os.replace(partial_path, file_path)
# distilabel LLM wrapper around the locally cached GGUF model.
# NOTE(review): max_new_tokens (128 * 1024) is far larger than any context
# window configured here (an `n_ctx=1024 * 128` argument was left commented
# out), so generation length is presumably bounded by the backend's default
# context size instead; confirm the intended limits.
llm = LlamaCppLLM(
    model_path=file_path,
    n_gpu_layers=-1,  # presumably "offload all layers" in llama.cpp; confirm
    generation_kwargs=dict(max_new_tokens=128 * 1024),
)

# Build and load the labelling task once at import time, so every call to
# process_records_gradio reuses the same module-level instance.
task = ArgillaLabeller(llm=llm)
task.load()
def load_examples(path=None):
    """Load the Gradio example inputs from a JSON file.

    Args:
        path: File to read. Defaults to ``examples.json`` located next to
            this module, which is consistent with how the model file path is
            resolved and independent of the current working directory.

    Returns:
        The decoded JSON document.

    Raises:
        FileNotFoundError: if the file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    if path is None:
        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples.json")
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
# Example input rows for the Gradio interface, read once when the module is imported.
examples = load_examples()
def process_fields(fields):
    """Normalise a field specification into a list of dicts.

    ``fields`` may be a JSON string, a single dict, or a list whose items are
    dicts or JSON-encoded strings. A JSON string is decoded first; a lone
    dict is wrapped in a one-element list; every remaining string item is
    decoded with ``json.loads``.
    """
    decoded = json.loads(fields) if isinstance(fields, str) else fields
    items = [decoded] if isinstance(decoded, dict) else decoded

    def _as_dict(item):
        # Items that are already dicts pass through; anything else is JSON text.
        return item if isinstance(item, dict) else json.loads(item)

    return list(map(_as_dict, items))
def process_records_gradio(records, example_records, fields, question):
    """Label records with the module-level ArgillaLabeller task.

    All arguments arrive as JSON text from the Gradio code inputs.

    Args:
        records: JSON list of records to label. A single JSON object is
            treated as a one-record list.
        example_records: Optional JSON, forwarded to the task as the
            ``example_records`` runtime parameter when non-empty.
        fields: Optional field definition(s); normalised by ``process_fields``.
        question: Optional JSON question definition.

    Returns:
        A JSON string ``{"results": [...]}`` with one serialized suggestion
        per output row that produced one, or a string starting with
        ``"Error: "`` when the input is invalid or labelling fails.
    """
    try:
        records = json.loads(records)
        # A lone record object would otherwise be iterated key by key.
        if isinstance(records, dict):
            records = [records]
        example_records = json.loads(example_records) if example_records else None
        fields = process_fields(fields) if fields else None
        question = json.loads(question) if question else None

        if not fields and not question:
            return "Error: Either fields or question must be provided"

        runtime_parameters = {"fields": fields, "question": question}
        if example_records:
            runtime_parameters["example_records"] = example_records
        task.set_runtime_parameters(runtime_parameters)

        results = []
        # task.process() yields batches (indexable collections of output rows).
        # Consume every batch and every row instead of calling next(...)[0]
        # once per record. If the task yields all rows in a single batch
        # (presumably distilabel's default; confirm for ArgillaLabeller),
        # that pattern raised StopIteration on the second record and
        # returned an empty "Error: " message.
        output = task.process(inputs=[{"records": record} for record in records])
        for batch in output:
            for entry in batch:
                suggestion = entry.get("suggestions")
                if suggestion:
                    results.append(suggestion.serialize())
        return json.dumps({"results": results}, indent=2)
    except Exception as e:
        # Top-level UI boundary: report failures as text for Gradio to display.
        return f"Error: {str(e)}"
# Markdown shown at the top of the Gradio page: an example of calling this
# Space from Python with an Argilla dataset as the data source.
description = """
An example workflow for JSON payload.
```python
import json
import os

from gradio_client import Client
import argilla as rg

# Initialize Argilla client
client = rg.Argilla(
    api_key=os.environ["ARGILLA_API_KEY"], api_url=os.environ["ARGILLA_API_URL"]
)

# Load the dataset
dataset = client.datasets(name="my_dataset", workspace="my_workspace")

# Prepare example data
example_field = dataset.settings.fields["my_input_field"].serialize()
example_question = dataset.settings.questions["my_question_to_predict"].serialize()
payload = {
    "records": [next(dataset.records()).to_dict()],
    "example_records": [],  # optional: already-labelled records, or leave empty
    "fields": [example_field],
    "question": example_question,
}

# Use gradio client to process the data
labeller_client = Client("davidberenstein1957/distilabel-argilla-labeller")
result = labeller_client.predict(
    records=json.dumps(payload["records"]),
    example_records=json.dumps(payload["example_records"]),
    fields=json.dumps(payload["fields"]),
    question=json.dumps(payload["question"]),
    api_name="/predict"
)
```
"""
# Gradio UI: one input component per parameter of process_records_gradio,
# in the same order, plus a JSON code view for the returned suggestions.
interface = gr.Interface(
    fn=process_records_gradio,
    # A leading gr.Number("Number of Records") input was removed because
    # process_records_gradio has no matching parameter, so Gradio would have
    # passed five values to a four-argument function. The four inputs below
    # line up with the records / example_records / fields / question
    # arguments that the client example in `description` sends.
    # NOTE(review): each row in examples.json must hold exactly these four
    # values; confirm the file matches.
    inputs=[
        gr.Code(label="Records (JSON)", language="json", lines=5),
        gr.Code(label="Example Records (JSON, optional)", language="json", lines=5),
        gr.Code(label="Fields (JSON, optional)", language="json"),
        gr.Code(label="Question (JSON, optional)", language="json"),
    ],
    examples=examples,
    outputs=gr.Code(label="Suggestions", language="json", lines=10),
    title="Distilabel - ArgillaLabeller - Record Processing Interface",
    description=description,
)

if __name__ == "__main__":
    interface.launch()