import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
import os

os.environ["BITSANDBYTES_NOWELCOME"] = "1"

model_name = "minoD/JURAN"
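
# Load the fine-tuned model on CPU at startup. On a ZeroGPU Space the GPU is
# only attached inside functions decorated with @spaces.GPU, so the weights
# are moved to CUDA per request and back to CPU afterwards.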
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
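
# One-shot flag: the GPU warm-up generation runs only on the first request.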
warmup_done = False
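
# Build the instruction-style prompt the model was fine-tuned on. Newlines are
# encoded as the literal "<NL>" token to match the training format; they are
# converted back to real newlines after generation.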
def generate_prompt(F):
    result = f"""### 指示:
あなたは企業の面接官です。以下の就活生のエントリーシート内容を読んで、深掘りする質問を1つ考えてください。

### エントリーシート:
{F}

### 面接官の質問:"""
    result = result.replace('\n', '<NL>')
    return result
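
# Warm-up: run a tiny throwaway generation once so the weights and CUDA
# kernels are initialized before the first real request.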
@spaces.GPU(duration=60)
def warmup_model():
    """Warm up the model with a short dummy generation on the GPU."""
    global warmup_done
    if not warmup_done:
        print("Warming up...")
        model.to("cuda")

        dummy_input = tokenizer("テスト", return_tensors="pt").input_ids.to("cuda")
        with torch.no_grad():
            _ = model.generate(
                dummy_input,
                max_new_tokens=10,
                do_sample=False,
            )

        model.to("cpu")
        torch.cuda.empty_cache()
        warmup_done = True
        print("Warm-up finished")
@spaces.GPU(duration=60)
def generate2(F=None, maxTokens=256):
    try:
        if not warmup_done:
            warmup_model()

        # Fixed seed for reproducible sampling across requests.
        torch.manual_seed(42)

        model.to("cuda")

        prompt = generate_prompt(F)
        input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.to("cuda")

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                max_new_tokens=maxTokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.75,
                top_k=40,
                no_repeat_ngram_size=2,
            )

        model.to("cpu")
        torch.cuda.empty_cache()

        outputs = outputs[0].tolist()
        decoded = tokenizer.decode(outputs)

        # Truncate the decoded text at the first EOS token, if one was produced.
        if tokenizer.eos_token_id in outputs:
            eos_index = outputs.index(tokenizer.eos_token_id)
            decoded = tokenizer.decode(outputs[:eos_index])

        # Keep only the text after the sentinel (the generated question) and
        # restore real newlines from the <NL> placeholder tokens.
        sentinel = "### 面接官の質問:"
        sentinelLoc = decoded.find(sentinel)
        if sentinelLoc >= 0:
            result = decoded[sentinelLoc + len(sentinel):]
            result = result.split('\n')[0] if '\n' in result else result
            return result.replace("<NL>", "\n").strip()
        else:
            return 'Warning: Expected prompt template to be emitted. Ignoring output.'

    except Exception as e:
        return f"エラーが発生しました: {str(e)}"

def inference(input_text):
    return generate2(input_text)
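
# Gradio UI: one text box in, one text box out. api_name="ask" also exposes
# the function as the /ask endpoint of the Space's API.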
iface = gr.Interface(
    fn=inference,
    inputs=gr.Textbox(lines=5, label="学生時代に打ち込んだこと、研究、ESを入力", placeholder="半導体の研究に打ち込んだ"),
    outputs=gr.Textbox(label="想定される質問"),
    title="JURAN🌺",
    description="面接官モデルが回答を生成します。",
    api_name="ask",
    flagging_mode="never",
)

iface.launch(
    server_name="0.0.0.0",
    server_port=7860,
)
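
# Example client call (a sketch: assumes the Space ID is "minoD/JURAN" and that
# the gradio_client package is installed; adjust the Space ID if it differs):
#
#   from gradio_client import Client
#   client = Client("minoD/JURAN")
#   question = client.predict("半導体の研究に打ち込んだ", api_name="/ask")
#   print(question)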