davidberenstein1957 HF staff committed on
Commit
dff7018
1 Parent(s): f39e1f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -14
app.py CHANGED
@@ -1,32 +1,70 @@
1
  import json
2
 
 
3
  import gradio as gr
 
4
  from distilabel.llms import TransformersLLM
5
  from distilabel.steps.tasks.argillalabeller import ArgillaLabeller
 
6
 
7
- llm = TransformersLLM(model="microsoft/Phi-3-mini-4k-instruct")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  task = ArgillaLabeller(llm=llm)
9
  task.load()
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  @spaces.GPU
12
- def process_records_gradio(records, example_records, field, question):
13
  try:
14
  # Convert string inputs to dictionaries
15
  records = json.loads(records)
16
  example_records = json.loads(example_records) if example_records else None
17
- field = json.loads(field) if field else None
18
  question = json.loads(question) if question else None
19
 
20
- if not field and not question:
21
- return "Error: Either field or question must be provided"
 
 
 
 
22
 
23
- task.set_runtime_parameters(
24
- {
25
- "fields": [field] if field else None,
26
- "question": question,
27
- "example_records": example_records,
28
- }
29
- )
30
 
31
  results = []
32
  for record in records:
@@ -43,12 +81,13 @@ interface = gr.Interface(
43
  inputs=[
44
  gr.Code(label="Records (JSON)", language="json", lines=5),
45
  gr.Code(label="Example Records (JSON, optional)", language="json", lines=5),
46
- gr.Code(label="Field (JSON, optional)", language="json"),
47
  gr.Code(label="Question (JSON, optional)", language="json"),
48
  ],
 
49
  outputs=gr.Code(label="Suggestions", language="json", lines=10),
50
  title="Record Processing Interface",
51
- description="Enter JSON data for records, example records, field, and question. At least one of field or question must be provided.",
52
  )
53
 
54
  if __name__ == "__main__":
 
1
  import json
2
 
3
+ import spaces
4
  import gradio as gr
5
+ import torch
6
  from distilabel.llms import TransformersLLM
7
  from distilabel.steps.tasks.argillalabeller import ArgillaLabeller
8
+ from transformers import BitsAndBytesConfig
9
 
10
+ quantization_config = BitsAndBytesConfig(
11
+ load_in_4bit=True,
12
+ bnb_4bit_compute_dtype=torch.bfloat16,
13
+ bnb_4bit_use_double_quant=True,
14
+ bnb_4bit_quant_type="nf4",
15
+ )
16
+
17
+
18
+ llm = TransformersLLM(
19
+ model="microsoft/Phi-3-mini-4k-instruct",
20
+ torch_dtype="float16",
21
+ model_kwargs={
22
+ "quantization_config": quantization_config,
23
+ "device_map": "auto",
24
+ },
25
+ )
26
  task = ArgillaLabeller(llm=llm)
27
  task.load()
28
 
29
+
30
def load_examples():
    """Load the Gradio example inputs from ``examples.json``.

    The path is relative, so the file is resolved against the current
    working directory of the process.

    Returns:
        The decoded JSON content of the file.

    Raises:
        FileNotFoundError: If ``examples.json`` does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Pin the encoding: without it, open() falls back to the platform's
    # locale default, which can mis-decode non-ASCII example text.
    with open("examples.json", "r", encoding="utf-8") as f:
        return json.load(f)
33
+
34
+
35
+ # Create Gradio examples
36
+ examples = load_examples()
37
+
38
+
39
def process_fields(fields):
    """Normalize the ``fields`` input into a list of dictionaries.

    Args:
        fields: A JSON string, a single dict, or an iterable whose items
            are dicts or JSON strings.

    Returns:
        A list in which every item that was not already a dict has been
        decoded with ``json.loads``.
    """
    # A string is treated as a JSON document and decoded first.
    decoded = json.loads(fields) if isinstance(fields, str) else fields
    # A lone dict is wrapped so that the loop below always sees a sequence.
    entries = [decoded] if isinstance(decoded, dict) else decoded

    normalized = []
    for entry in entries:
        if isinstance(entry, dict):
            normalized.append(entry)
        else:
            normalized.append(json.loads(entry))
    return normalized
45
+
46
+
47
  @spaces.GPU
48
+ def process_records_gradio(records, example_records, fields, question):
49
  try:
50
  # Convert string inputs to dictionaries
51
  records = json.loads(records)
52
  example_records = json.loads(example_records) if example_records else None
53
+ fields = process_fields(fields) if fields else None
54
  question = json.loads(question) if question else None
55
 
56
+ print(fields)
57
+ print(question)
58
+ print(example_records)
59
+
60
+ if not fields and not question:
61
+ return "Error: Either fields or question must be provided"
62
 
63
+ runtime_parameters = {"fields": fields, "question": question}
64
+ if example_records:
65
+ runtime_parameters["example_records"] = example_records
66
+ print(runtime_parameters)
67
+ task.set_runtime_parameters(runtime_parameters)
 
 
68
 
69
  results = []
70
  for record in records:
 
81
  inputs=[
82
  gr.Code(label="Records (JSON)", language="json", lines=5),
83
  gr.Code(label="Example Records (JSON, optional)", language="json", lines=5),
84
+ gr.Code(label="Fields (JSON, optional)", language="json"),
85
  gr.Code(label="Question (JSON, optional)", language="json"),
86
  ],
87
+ examples=examples,
88
  outputs=gr.Code(label="Suggestions", language="json", lines=10),
89
  title="Record Processing Interface",
90
+ description="Enter JSON data for `rg.Record.to_dict()`, `List[rg.Record.to_dict()]`, `List[Field].serialize()`, or `List[rg.Question.serialize()]` At least one of fields or question must be provided.",
91
  )
92
 
93
  if __name__ == "__main__":