tameto commited on
Commit
26844cd
1 Parent(s): f26e683
Files changed (1) hide show
  1. app.py +13 -20
app.py CHANGED
@@ -1,11 +1,7 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- import torch
4
 
5
- model_name = "elyza/Llama-3-ELYZA-JP-8B"
6
-
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
9
 
10
  SYSTEM_MESSAGE = """
11
  あなたは関西弁で話す生命保険の営業マンです。お客様の状況を理解し、適切な保険プランを提案することが仕事です。以下の点に注意してください:
@@ -32,22 +28,19 @@ def create_prompt(message, history):
32
  def respond(message, history, max_tokens, temperature, top_p):
33
  prompt = create_prompt(message, history)
34
 
35
- input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
36
-
37
- with torch.no_grad():
38
- output = model.generate(
39
- input_ids,
40
- max_new_tokens=min(max_tokens, 125), # 約250文字
41
- temperature=temperature,
42
- top_p=top_p,
43
- do_sample=True,
44
- pad_token_id=tokenizer.eos_token_id,
45
- )
46
 
47
- generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
48
- assistant_response = generated_text.split("助手: ")[-1]
 
 
 
 
 
49
 
50
- truncated_response = assistant_response[:250]
 
51
  last_punctuation = max(
52
  truncated_response.rfind('。'),
53
  truncated_response.rfind('!'),
 
1
  import gradio as gr
2
+ from huggingface_hub import InferenceClient
 
3
 
4
+ client = InferenceClient("elyza/Llama-3-ELYZA-JP-8B")
 
 
 
5
 
6
  SYSTEM_MESSAGE = """
7
  あなたは関西弁で話す生命保険の営業マンです。お客様の状況を理解し、適切な保険プランを提案することが仕事です。以下の点に注意してください:
 
28
  def respond(message, history, max_tokens, temperature, top_p):
29
  prompt = create_prompt(message, history)
30
 
31
+ # トークン数を調整して、約250文字になるように設定
32
+ estimated_max_tokens = min(max_tokens, 125) # 日本語の場合、1トークンは約2文字に相当
 
 
 
 
 
 
 
 
 
33
 
34
+ response = client.text_generation(
35
+ prompt,
36
+ max_new_tokens=estimated_max_tokens,
37
+ temperature=temperature,
38
+ top_p=top_p,
39
+ stop_sequences=["\n", "人間:"] # 改行または次の人間の入力で生成を停止
40
+ )
41
 
42
+ # 250文字で切り取り、最後の文が途中で切れないように調整
43
+ truncated_response = response[:250]
44
  last_punctuation = max(
45
  truncated_response.rfind('。'),
46
  truncated_response.rfind('!'),