asadAbdullah commited on
Commit
2a1a9a9
1 Parent(s): 84bb67b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -71
app.py CHANGED
@@ -1,97 +1,62 @@
 
1
  # Import required libraries
2
  import os
3
  import pandas as pd
4
  import streamlit as st
5
- # from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
6
- from transformers import DistilBertTokenizerFast, DistilBertModel
7
- # from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
8
-
9
  from transformers import pipeline
10
  from sentence_transformers import SentenceTransformer, util
11
  import requests
12
  import json
 
13
 
14
- # Configure Hugging Face API token securely
15
  api_key = os.getenv("HF_API_KEY")
16
 
17
  # Load the CSV dataset
18
- try:
19
- data = pd.read_csv('genetic-Final.csv') # Ensure the dataset filename is correct
20
- except FileNotFoundError:
21
- st.error("Dataset file not found. Please upload it to this directory.")
22
-
23
- # Load DistilBERT Tokenizer and Model
24
- # tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
25
- # model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
26
-
27
- # Load DistilBERT tokenizer and model (without classification layer)
28
- tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
29
- model = DistilBertModel.from_pretrained("distilbert-base-uncased")
30
-
31
- query = "What is fructose-1,6-bisphosphatase deficiency?"
32
-
33
- # Tokenize input
34
- inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
35
-
36
- # Get model output (embeddings)
37
- with torch.no_grad():
38
- outputs = model(**inputs)
39
-
40
- # Extract embeddings (last hidden state)
41
- embeddings = outputs.last_hidden_state.mean(dim=1) # Averaging over token embeddings
42
-
43
- # Use the embeddings for further processing or retrieval
44
- print(embeddings)
45
-
46
-
47
- # Preprocessing the dataset (if needed)
48
- if 'combined_description' not in data.columns:
49
- data['combined_description'] = (
50
- data['Symptoms'].fillna('') + " " +
51
- data['Severity Level'].fillna('') + " " +
52
- data['Risk Assessment'].fillna('') + " " +
53
- data['Treatment Options'].fillna('') + " " +
54
- data['Suggested Medical Tests'].fillna('') + " " +
55
- data['Minimum Values for Medical Tests'].fillna('') + " " +
56
- data['Emergency Treatment'].fillna('')
57
- )
58
-
59
- # Initialize Sentence Transformer model for RAG-based retrieval
60
  retriever_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
61
 
62
- # Define a function to get embeddings using DistilBERT
63
  def generate_embedding(description):
64
- if description:
65
- inputs = tokenizer(description, return_tensors='pt', truncation=True, padding=True, max_length=512)
66
- outputs = model(**inputs)
67
- return outputs.logits.detach().numpy().flatten()
68
  else:
69
  return []
70
 
71
  # Generate embeddings for the combined description
72
- if 'embeddings' not in data.columns:
73
- data['embeddings'] = data['combined_description'].apply(generate_embedding)
74
 
75
- # Function to retrieve relevant information based on user query
76
  def get_relevant_info(query, top_k=3):
77
  query_embedding = retriever_model.encode(query)
78
  similarities = [util.cos_sim(query_embedding, doc_emb)[0][0].item() for doc_emb in data['embeddings']]
79
  top_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)[:top_k]
80
  return data.iloc[top_indices]
81
 
82
- # Function to generate response using DistilBERT (integrating with the model)
83
- def generate_response(input_text, relevant_info):
84
- # Concatenate the relevant information as context for the model
85
- context = "\n".join(relevant_info['combined_description'].tolist())
86
- input_with_context = f"Context: {context}\n\nUser Query: {input_text}"
87
 
88
- # Simple logic for generating a response using DistilBERT-based model
89
- inputs = tokenizer(input_with_context, return_tensors='pt', truncation=True, padding=True, max_length=512)
90
- outputs = model(**inputs)
91
- logits = outputs.logits.detach().numpy().flatten()
92
- response = tokenizer.decode(logits.argmax(), skip_special_tokens=True)
93
-
94
- return response
95
 
96
  # Streamlit UI for the Chatbot
97
  def main():
@@ -112,19 +77,20 @@ def main():
112
  relevant_info = get_relevant_info(user_query)
113
  st.write("#### Relevant Medical Information:")
114
  for i, row in relevant_info.iterrows():
115
- st.write(f"- {row['combined_description']}") # Adjust to show meaningful info
116
 
117
- # Generate a response from DistilBERT model
118
- response = generate_response(user_query, relevant_info)
119
  st.write("#### Model's Response:")
120
  st.write(response)
121
 
122
  # Process the uploaded file (if any)
123
  if uploaded_file:
124
- # Display analysis of the uploaded report file (process based on file type)
125
  st.write("### Uploaded Report Analysis:")
126
  report_text = "Extracted report content here" # Placeholder for file processing logic
127
  st.write(report_text)
128
 
 
129
  if __name__ == "__main__":
130
  main()
 
1
+
2
  # Import required libraries
3
  import os
4
  import pandas as pd
5
  import streamlit as st
 
 
 
 
6
  from transformers import pipeline
7
  from sentence_transformers import SentenceTransformer, util
8
  import requests
9
  import json
10
+ from pyngrok import ngrok
11
 
12
+ # Set up Hugging Face API token
13
  api_key = os.getenv("HF_API_KEY")
14
 
15
  # Load the CSV dataset
16
+ data = pd.read_csv('genetic-Final.csv')
17
+
18
+ # Drop unnecessary columns (Unnamed columns)
19
+ data = data.drop(columns=['Unnamed: 0', 'Unnamed: 11', 'Unnamed: 12', 'Unnamed: 13'])
20
+
21
+ # Combine relevant columns into one combined description field
22
+ data['combined_description'] = (
23
+ data['Symptoms'].fillna('') + " " +
24
+ data['Severity Level'].fillna('') + " " +
25
+ data['Risk Assessment'].fillna('') + " " +
26
+ data['Treatment Options'].fillna('') + " " +
27
+ data['Suggested Medical Tests'].fillna('') + " " +
28
+ data['Minimum Values for Medical Tests'].fillna('') + " " +
29
+ data['Emergency Treatment'].fillna('')
30
+ )
31
+
32
+ # Initialize the Sentence Transformer model for embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  retriever_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
34
 
35
+ # Function to safely generate embeddings for each row
36
  def generate_embedding(description):
37
+ if description: # Check if the description is not empty or NaN
38
+ return retriever_model.encode(description).tolist() # Convert the numpy array to list
 
 
39
  else:
40
  return []
41
 
42
  # Generate embeddings for the combined description
43
+ data['embeddings'] = data['combined_description'].apply(generate_embedding)
 
44
 
45
+ # Function to retrieve relevant information from CSV dataset based on user query
46
  def get_relevant_info(query, top_k=3):
47
  query_embedding = retriever_model.encode(query)
48
  similarities = [util.cos_sim(query_embedding, doc_emb)[0][0].item() for doc_emb in data['embeddings']]
49
  top_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)[:top_k]
50
  return data.iloc[top_indices]
51
 
52
+ # Function to generate response using Hugging Face Model API
53
+ def generate_response(input_text):
54
+ api_url = "https://api-inference.huggingface.co/models/m42-health/Llama3-Med42-8B"
55
+ headers = {"Authorization": f"Bearer {os.environ['HUGGINGFACEHUB_API_TOKEN']}"}
56
+ payload = {"inputs": input_text}
57
 
58
+ response = requests.post(api_url, headers=headers, json=payload)
59
+ return json.loads(response.content.decode("utf-8"))[0]["generated_text"]
 
 
 
 
 
60
 
61
  # Streamlit UI for the Chatbot
62
  def main():
 
77
  relevant_info = get_relevant_info(user_query)
78
  st.write("#### Relevant Medical Information:")
79
  for i, row in relevant_info.iterrows():
80
+ st.write(f"- {row['combined_description']}")
81
 
82
+ # Generate a response from the Llama3-Med42-8B model
83
+ response = generate_response(user_query)
84
  st.write("#### Model's Response:")
85
  st.write(response)
86
 
87
  # Process the uploaded file (if any)
88
  if uploaded_file:
89
+ # Display analysis of the uploaded report file
90
  st.write("### Uploaded Report Analysis:")
91
  report_text = "Extracted report content here" # Placeholder for file processing logic
92
  st.write(report_text)
93
 
94
+ # Start Streamlit app in Colab using ngrok
95
  if __name__ == "__main__":
96
  main()