3v324v23 committed on
Commit
370a5dd
·
1 Parent(s): 654a1b6

first commit

Browse files
Files changed (2) hide show
  1. app.py +118 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ License:
3
+ Unless required by applicable law or agreed to in writing, software
4
+ distributed under the License is distributed on an "AS IS" BASIS,
5
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6
+ In no event shall the authors or copyright holders be liable
7
+ for any claim, damages or other liability, whether in an action of contract or otherwise,
8
+ arising from, out of or in connection with the software or the use or
9
+ other dealings in the software.
10
+
11
+ Copyright (c) 2024 pi19404. All rights reserved.
12
+
13
+ Authors:
14
+ pi19404 <pi19404@gmail.com>
15
+ """
16
+
17
+
18
+ """
19
+ Gradio Interface for Shield Gemma LLM Evaluator
20
+
21
+ This module provides a Gradio interface to interact with the Shield Gemma LLM Evaluator.
22
+ It allows users to input JSON data and select various options to evaluate the content
23
+ for policy violations.
24
+
25
+ Functions:
26
+ my_inference_function: The main inference function to process input data and return results.
27
+ """
28
+
29
+ import gradio as gr
30
+ from gradio_client import Client
31
+ import torch
32
+ import json
33
+ import threading
34
+ import os
35
+
36
# Access token for the remote worker, taken from the environment
# (None when the variable is not set).
API_TOKEN = os.environ.get("API_TOKEN")

# Guards the shared client so only one request goes through it at a time.
lock = threading.Lock()

# Shared gradio_client handle for the "pi19404/ai-worker" app.
client = Client("pi19404/ai-worker", hf_token=API_TOKEN)
40
+
41
def my_inference_function(input_data, output_data, mode, max_length, max_new_tokens, model_size):
    """
    Forward an evaluation request to the remote worker and return its result.

    The call is made through the module-level ``client`` while holding the
    module-level ``lock``, so requests are sent one at a time.

    Args:
        input_data (str or dict): The input data to evaluate.
        output_data (str): The response text to evaluate alongside the input.
        mode (str): The mode of operation ("scoring" or "generative").
        max_length (int): The maximum length of the input prompt.
        max_new_tokens (int): The maximum number of new tokens to generate.
        model_size (str): The size of the model to be used.

    Returns:
        The value returned by the remote ``/my_inference_function`` endpoint,
        unchanged. If the call fails, a JSON string of the form
        ``{"error": "<message>"}`` is returned instead.
    """
    with lock:
        try:
            result = client.predict(
                input_data=input_data,
                output_data=output_data,
                mode=mode,
                max_length=max_length,
                max_new_tokens=max_new_tokens,
                model_size=model_size,
                api_name="/my_inference_function",
            )
        except json.JSONDecodeError:
            return json.dumps({"error": "Invalid JSON input"})
        except KeyError:
            return json.dumps({"error": "Missing 'input' key in JSON"})
        except ValueError as e:
            return json.dumps({"error": str(e)})
        except Exception as e:
            # UI callback boundary: any other failure of the remote call
            # (e.g. connection problems or an error raised by the remote app)
            # is logged and reported in the same JSON error shape as above
            # instead of escaping to the caller.
            print("inference request failed:", repr(e))
            return json.dumps({"error": f"{type(e).__name__}: {e}"})

    # Log the raw result once; it is returned as received, not reformatted.
    print("entering return", result)
    return result
78
+
79
# Gradio UI: a single tab that collects the text to evaluate plus the
# evaluation options, and shows whatever my_inference_function returns.
with gr.Blocks() as demo:
    gr.Markdown("## LLM Safety Evaluation")

    with gr.Tab("ShieldGemma2"):
        input_text = gr.Textbox(label="Input Text")
        output_text = gr.Textbox(
            label="Response Text",
            lines=5,
            max_lines=10,
            show_copy_button=True,
            elem_classes=["wrap-text"],
        )
        # Explicit defaults so a valid choice is always submitted, even when
        # the user never touches the dropdowns.
        mode_input = gr.Dropdown(
            choices=["scoring", "generative"],
            value="scoring",
            label="Prediction Mode",
        )
        # precision=0 makes these fields yield integers, matching the
        # int parameters documented on my_inference_function.
        max_length_input = gr.Number(label="Max Length", value=150, precision=0)
        max_new_tokens_input = gr.Number(label="Max New Tokens", value=1024, precision=0)
        model_size_input = gr.Dropdown(
            choices=["2B", "9B", "27B"],
            value="2B",
            label="Model Size",
        )
        response_text = gr.Textbox(
            label="Output Text",
            lines=10,
            max_lines=20,
            show_copy_button=True,
            elem_classes=["wrap-text"],
        )
        text_button = gr.Button("Submit")
        text_button.click(
            fn=my_inference_function,
            inputs=[
                input_text,
                output_text,
                mode_input,
                max_length_input,
                max_new_tokens_input,
                model_size_input,
            ],
            outputs=response_text,
        )

    # NOTE(review): disabled JSON/API tab kept for reference; it is not part
    # of the running UI.
    # with gr.Tab("API Input"):
    #     api_input = gr.JSON(label="Input JSON")
    #     mode_input_api = gr.Dropdown(choices=["scoring", "generative"], label="Mode")
    #     max_length_input_api = gr.Number(label="Max Length", value=150)
    #     max_new_tokens_input_api = gr.Number(label="Max New Tokens", value=None)
    #     model_size_input_api = gr.Dropdown(choices=["2B", "9B", "27B"], label="Model Size")
    #     api_output = gr.JSON(label="Output JSON")
    #     api_button = gr.Button("Submit")
    #     api_button.click(fn=my_inference_function, inputs=[api_input, api_output,mode_input_api, max_length_input_api, max_new_tokens_input_api, model_size_input_api], outputs=api_output)

demo.launch(share=True)
116
+
117
+
118
+
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ gradio_client