File size: 2,718 Bytes
82b1566
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d0ef0f
 
82b1566
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from transformers import AutoTokenizer, pipeline

import gradio as gr

model_name = "gpt2-large"

# GPT-2 is a built-in transformers architecture, so no repository-provided code
# has to run. ``trust_remote_code=True`` added nothing here except the ability
# for a (changed) ``model_name`` to execute arbitrary Python from the Hub.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The GPT-2 tokenizer defines no padding token; reuse end-of-sequence so that
# generation has a pad id available.
tokenizer.pad_token = tokenizer.eos_token
# Text-generation pipeline sharing the tokenizer configured above.
generator = pipeline(
    task="text-generation",
    model=model_name,
    tokenizer=tokenizer,
)

def nb_tokens(input):
    """Return how many token ids the module-level ``tokenizer`` yields for ``input``."""
    encoding = tokenizer(input)
    token_ids = encoding['input_ids']
    return len(token_ids)

def _truncate_at_stops(text, stop_sequences):
    """Return ``text`` cut just before the earliest occurrence of any stop sequence.

    ``stop_sequences`` may be a list/tuple of strings or a single string; any
    other value (including ``None``) leaves ``text`` untouched. Empty strings
    are ignored.
    """
    if isinstance(stop_sequences, str):
        stop_sequences = [stop_sequences]
    if not isinstance(stop_sequences, (list, tuple)):
        return text
    for seq in stop_sequences:
        # ``str.find`` returns -1 when ``seq`` is absent; slicing with -1 would
        # silently drop the last character, so only cut on a real match.
        idx = text.find(seq) if seq else -1
        if idx != -1:
            text = text[:idx]
    return text


def client_generate(input, max_new_tokens=256, stop_sequences=None):
    """Generate a continuation of ``input`` with the module-level ``generator``.

    Args:
        input: Prompt text sent to the model.
        max_new_tokens: Maximum number of tokens generated after the prompt.
        stop_sequences: Optional list/tuple of strings (or a single string).
            The reply is cut before the first occurrence of any of them.
            Defaults to no stop sequences.

    Returns:
        A dict with ``'text'`` (the prompt) and ``'generated_text'``, which is
        the stripped, truncated continuation. ``'generated_text'`` is ``''``
        when the pipeline returned nothing usable.
    """
    output = generator(
        input,
        # Bound only the continuation; no need to re-tokenize the prompt to
        # compute a total ``max_length``.
        max_new_tokens=max_new_tokens,
        # Return only the added text instead of splitting the prompt back off:
        # ``full_text.split(input)[1]`` raised IndexError when the decoded text
        # did not reproduce the prompt verbatim (ValueError for an empty prompt).
        return_full_text=False,
        # Padding uses end-of-sequence (50256 for GPT-2), matching the tokenizer setup.
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1,
    )
    if not output or "generated_text" not in output[0]:
        return {'text': input, 'generated_text': ''}
    response = output[0]['generated_text'].strip()
    response = _truncate_at_stops(response, stop_sequences)
    return {'text': input, 'generated_text': response}

def respond(message, chat_history, modelname=model_name, max_tokens=128):
    """Answer ``message`` and append the exchange to ``chat_history``.

    This is the Gradio handler wired to both the Submit button and the
    textbox's Enter key.

    Args:
        message: User question taken from the textbox.
        chat_history: Chat history list; a ``(message, reply)`` tuple is
            appended to it in place.
        modelname: Forwarded to ``reshape_prompt`` when that helper exists.
        max_tokens: Maximum number of new tokens for the reply.

    Returns:
        ``("", chat_history)``: clears the textbox and refreshes the chat.
    """
    # NOTE(review): ``reshape_prompt`` is not defined anywhere in this file, so
    # calling it unconditionally raised NameError on every message. Use it when
    # a definition is supplied; otherwise send the raw question to the model.
    reshape = globals().get("reshape_prompt")
    prompt = reshape(message, modelname) if callable(reshape) else message
    bot_message = client_generate(
        prompt,
        max_new_tokens=max_tokens,
        # Stop at the first sentence end so the model does not go on to write
        # the user's next turn; a trailing "." is appended below.
        stop_sequences=["."],
    )['generated_text']
    chat_history.append((message, f"{bot_message}."))
    return "", chat_history

# Chat UI: a title, the conversation view, a question box, Submit/Clear
# buttons, and clickable example questions.
# Optional theme (see https://www.gradio.app/guides/theming-guide):
#     theme='sudeepshouche/minimalist'
with gr.Blocks(title='RugbyXpert') as demo:
    gr.Markdown("\n    # RugbyXpert\n    ")
    chatbot = gr.Chatbot(height=500)  # height chosen to fit a notebook cell
    msg = gr.Textbox(label="Pose-moi une question sur le rugby pendant la saison 2022-2023")

    with gr.Row():
        with gr.Column():
            btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            clear = gr.ClearButton(components=[msg, chatbot], value="Clear console")

    _example_questions = [
        "Tu peux me donner le 21 de Vannes lors du match les opposant à Aurillac du vendredi 24 février 2023 ?",
        "Tu peux me retrouver le score final du match opposant Soyaux-Angoulême à Grenoble le vendredi 17 mars 2023 ?",
        "Dis-moi le score final du match opposant Vannes à Aurillac le vendredi 24 février 2023 ?",
    ]
    gr.Examples(examples=_example_questions, inputs=[msg])

    # Clicking Submit and pressing Enter in the textbox run the same handler.
    for trigger in (btn.click, msg.submit):
        trigger(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])

demo.launch()