File size: 3,535 Bytes
c19c108
4639678
c19c108
 
 
 
 
 
 
5fa1d46
c19c108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Load the packages and configurations
import os
import pandas as pd
from retriever import BM25Plus
import streamlit as st
from groq import Groq

# Groq API client. The key comes from the GROQ_API_KEY environment variable;
# os.getenv yields None when that variable is not set.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# Right-to-left layout for Persian text. The rules are injected as a raw
# <style> tag, which Streamlit only renders when unsafe_allow_html is enabled.
_RTL_CSS = """
<style>
body, html {
    direction: RTL;
    unicode-bidi: plaintext;
    text-align: right;
}
p, div, input, label, h1, h2, h3, h4, h5, h6 {
    direction: RTL;
    unicode-bidi: plaintext;
    text-align: right;
}
</style>
"""
st.markdown(_RTL_CSS, unsafe_allow_html=True)

# Page heading, then the sidebar heading followed by a divider.
st.title("DarooGAP")
_sidebar = st.sidebar
_sidebar.title("💊 داروگپ 💊")
_sidebar.divider()

# Load the Q&A dataset. Streamlit re-executes this whole script on every user
# interaction, so the CSV parse is cached instead of being repeated for every
# chat message. The cache is keyed on the path argument: edits to the CSV
# while the app is running are not picked up until the cache is cleared.
@st.cache_data
def _load_dataset(path):
    """Read the CSV file at *path* into a pandas DataFrame."""
    return pd.read_csv(path)


df = _load_dataset('darooyab_qa.csv')
# One retrieval document per row, in row order, so a document's position in
# `corpus` matches its row position in `df`.
corpus = df.loc[:, 'Corpus'].to_list()


# Build the retrieval-augmented prompt sent to the language model
def get_prompt(query, query_bm25):
    """Build the Persian RAG prompt for the language model.

    The prompt does three things:
    - tells the model to answer the user's question from the passages that
      follow;
    - gives a fixed fallback sentence to output when it cannot answer;
    - lists the question, then each passage numbered from 1.

    Args:
        query: The user's question text.
        query_bm25: Iterable of retrieved passages. Iterating it must yield
            the passage strings; if it is a dict, its keys are used.

    Returns:
        The full prompt as one string, ending with a newline.
    """
    lines = [
        'به سوال کاربر بر اساس متن هایی که در ادامه آمده است پاسخ بدهید',
        'اگر قادر به جواب دادن به سوال نبودی، عبارت زیر را خروجی بده:',
        'متاسفم در حال حاضر اطلاعات زیادی درباره سوال شما نمی دانم!',
        '',  # blank line between the instructions and the question
        'سوال:',
        query,
        '',  # blank line between the question and the passages
    ]
    lines.extend(
        f'متن {idx}: {topic}' for idx, topic in enumerate(query_bm25, start=1)
    )
    # join + trailing '\n' gives every line, including the last, a newline.
    return '\n'.join(lines) + '\n'


def _tokenize(text):
    """Split *text* into whitespace-delimited tokens for BM25 scoring.

    ``str.split()`` with no argument splits on any run of whitespace
    (spaces, tabs, newlines) and never yields empty tokens, unlike
    ``split(" ")``. Non-string input, such as a pandas NaN from an empty CSV
    cell, yields no tokens instead of raising AttributeError.
    """
    if not isinstance(text, str):
        return []
    return text.split()


def get_relevant_topics(query, corpus, n=3):
    """Return the ``n`` documents of ``corpus`` that best match ``query``.

    A ``BM25Plus`` index is built over ``corpus``, and the query is scored
    against it. The corpus and the query are tokenized the same way (see
    ``_tokenize``).

    Args:
        query: The user's question text.
        corpus: List of document texts, in the order they should be indexed.
        n: Number of top documents to request.

    Returns:
        Whatever ``BM25Plus.get_top_n`` returns. NOTE(review): the caller in
        this file calls ``.values()`` on the result and uses those as row
        positions. Presumably it is a dict mapping each document to its
        index in ``corpus``; confirm against ``retriever.BM25Plus``.
    """
    # The index is rebuilt on every call.
    bm25 = BM25Plus([_tokenize(doc) for doc in corpus])
    return bm25.get_top_n(_tokenize(query), corpus, n=n)

# Per-session chat state, seeded only on the first run of a browser session:
# - messages: the conversation history, starting with the assistant greeting;
# - drugs: parallel lists of drug names and their links.
_session_defaults = {
    "messages": [{"role": "assistant", "content": "چطور می‌توانم به شما کمک کنم؟"}],
    "drugs": {"drug_name": [], "drug_link": []},
}
for _state_key, _state_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Redraw the whole stored conversation; Streamlit re-runs the script on each
# interaction, so earlier messages must be written out again every time.
for entry in st.session_state.messages:
    st.chat_message(entry["role"]).write(entry["content"])


# Handle one submitted question:
# 1. retrieve the best-matching passages;
# 2. show links for the drugs the question names;
# 3. ask the LLM to answer from the retrieved passages.
if prompt := st.chat_input():
    # Store and echo the user's turn so it is replayed on later reruns.
    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)

    topics = get_relevant_topics(prompt, corpus, n=5)
    # NOTE(review): assumes `topics` is a dict whose values are row positions
    # in `df` (as returned by retriever.BM25Plus.get_top_n); confirm.
    relevant_drugs = df.iloc[list(topics.values())]

    st.sidebar.header('داروها')
    drugs = st.session_state.drugs
    # Register each retrieved drug once. The registry lives in session state,
    # so it keeps growing across questions in the same session.
    for _, row in relevant_drugs.iterrows():
        drug_name, drug_link = row['Farsi_generic_name'], row['Link']
        # pandas reads blank CSV cells as float NaN. `NaN in prompt` below
        # would raise TypeError, and link_button expects a string URL. An
        # empty name would match every question. Skip all such rows.
        if not (isinstance(drug_name, str) and drug_name and isinstance(drug_link, str)):
            continue
        if drug_name not in drugs['drug_name']:
            drugs['drug_name'].append(drug_name)
            drugs['drug_link'].append(drug_link)

    # Offer a link button for each registered drug whose name appears
    # verbatim in the question.
    for drug_name, drug_link in zip(drugs['drug_name'], drugs['drug_link']):
        if drug_name in prompt:
            st.sidebar.link_button(drug_name, drug_link)

    # From here on `prompt` holds the full retrieval-augmented LLM prompt,
    # not the user's raw text.
    prompt = get_prompt(prompt, topics)
    response = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "تو یک دستیار سودمند هستی."
            },
            {
                "role": "user",
                "content": prompt,
            }
        ],

        # The language model which will generate the completion.
        # NOTE(review): confirm this model ID is still served by Groq.
        model="llama3-70b-8192",
    )
    msg = response.choices[0].message.content
    st.session_state.messages.append({"role": "assistant", "content": msg})
    st.chat_message("assistant").write(msg)