Spaces:
Sleeping
Sleeping
update
Browse files
app.py
CHANGED
@@ -1,11 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
-
from
|
3 |
-
import torch
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
8 |
-
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
|
9 |
|
10 |
SYSTEM_MESSAGE = """
|
11 |
あなたは関西弁で話す生命保険の営業マンです。お客様の状況を理解し、適切な保険プランを提案することが仕事です。以下の点に注意してください:
|
@@ -32,22 +28,19 @@ def create_prompt(message, history):
|
|
32 |
def respond(message, history, max_tokens, temperature, top_p):
|
33 |
prompt = create_prompt(message, history)
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
with torch.no_grad():
|
38 |
-
output = model.generate(
|
39 |
-
input_ids,
|
40 |
-
max_new_tokens=min(max_tokens, 125), # 約250文字
|
41 |
-
temperature=temperature,
|
42 |
-
top_p=top_p,
|
43 |
-
do_sample=True,
|
44 |
-
pad_token_id=tokenizer.eos_token_id,
|
45 |
-
)
|
46 |
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
|
|
|
51 |
last_punctuation = max(
|
52 |
truncated_response.rfind('。'),
|
53 |
truncated_response.rfind('!'),
|
|
|
1 |
import gradio as gr
|
2 |
+
from huggingface_hub import InferenceClient
|
|
|
3 |
|
4 |
+
client = InferenceClient("elyza/Llama-3-ELYZA-JP-8B")
|
|
|
|
|
|
|
5 |
|
6 |
SYSTEM_MESSAGE = """
|
7 |
あなたは関西弁で話す生命保険の営業マンです。お客様の状況を理解し、適切な保険プランを提案することが仕事です。以下の点に注意してください:
|
|
|
28 |
def respond(message, history, max_tokens, temperature, top_p):
|
29 |
prompt = create_prompt(message, history)
|
30 |
|
31 |
+
# トークン数を調整して、約250文字になるように設定
|
32 |
+
estimated_max_tokens = min(max_tokens, 125) # 日本語の場合、1トークンは約2文字に相当
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
+
response = client.text_generation(
|
35 |
+
prompt,
|
36 |
+
max_new_tokens=estimated_max_tokens,
|
37 |
+
temperature=temperature,
|
38 |
+
top_p=top_p,
|
39 |
+
stop_sequences=["\n", "人間:"] # 改行または次の人間の入力で生成を停止
|
40 |
+
)
|
41 |
|
42 |
+
# 250文字で切り取り、最後の文が途中で切れないように調整
|
43 |
+
truncated_response = response[:250]
|
44 |
last_punctuation = max(
|
45 |
truncated_response.rfind('。'),
|
46 |
truncated_response.rfind('!'),
|