JURAN / app.py
minoD's picture
Update app.py
560c76a verified
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
import os
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
model_name = "minoD/JURAN"
# モデルのロード
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="cpu",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
# ウォームアップフラグ
warmup_done = False
def generate_prompt(F):
result = f"""### 指示:
あなたは企業の面接官です。以下の就活生のエントリーシート内容を読んで、深掘りする質問を1つ考えてください。
### エントリーシート:
{F}
### 面接官の質問:"""
result = result.replace('\n', '<NL>')
return result
@spaces.GPU(duration=60)
def warmup_model():
"""モデルのウォームアップ処理"""
global warmup_done
if not warmup_done:
print("ウォームアップ中...")
model.to("cuda")
# ダミー推論を実行
dummy_input = tokenizer("テスト", return_tensors="pt").input_ids.to("cuda")
with torch.no_grad():
_ = model.generate(
dummy_input,
max_new_tokens=10,
do_sample=False
)
model.to("cpu")
torch.cuda.empty_cache()
warmup_done = True
print("ウォームアップ完了")
@spaces.GPU(duration=60)
def generate2(F=None, maxTokens=256):
try:
# ウォームアップ(初回のみ)
if not warmup_done:
warmup_model()
# 乱数シードを固定(オプション)
torch.manual_seed(42)
model.to("cuda")
prompt = generate_prompt(F)
input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.to("cuda")
with torch.no_grad():
outputs = model.generate(
input_ids=input_ids,
max_new_tokens=maxTokens,
do_sample=True,
temperature=0.7,
top_p=0.75,
top_k=40,
no_repeat_ngram_size=2,
)
model.to("cpu")
torch.cuda.empty_cache()
outputs = outputs[0].tolist()
decoded = tokenizer.decode(outputs)
if tokenizer.eos_token_id in outputs:
eos_index = outputs.index(tokenizer.eos_token_id)
decoded = tokenizer.decode(outputs[:eos_index])
sentinel = "### 面接官の質問:"
sentinelLoc = decoded.find(sentinel)
if sentinelLoc >= 0:
result = decoded[sentinelLoc + len(sentinel):]
result = result.split('\n')[0] if '\n' in result else result
return result.replace("<NL>", "\n").strip()
else:
return 'Warning: Expected prompt template to be emitted. Ignoring output.'
except Exception as e:
return f"エラーが発生しました: {str(e)}"
def inference(input_text):
return generate2(input_text)
iface = gr.Interface(
fn=inference,
inputs=gr.Textbox(lines=5, label="学生時代に打ち込んだこと、研究、ESを入力", placeholder="半導体の研究に打ち込んだ"),
outputs=gr.Textbox(label="想定される質問"),
title="JURAN🌺",
description="面接官モデルが回答を生成します。",
api_name="ask",
flagging_mode="never"
)
iface.launch(
server_name="0.0.0.0",
server_port=7860
)