asadAbdullah commited on
Commit
e0c3387
1 Parent(s): 6feb2e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -3
app.py CHANGED
@@ -2,7 +2,9 @@
2
  import os
3
  import pandas as pd
4
  import streamlit as st
5
- from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
 
 
6
  from transformers import pipeline
7
  from sentence_transformers import SentenceTransformer, util
8
  import requests
@@ -18,8 +20,28 @@ except FileNotFoundError:
18
  st.error("Dataset file not found. Please upload it to this directory.")
19
 
20
  # Load DistilBERT Tokenizer and Model
21
- tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
22
- model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Preprocessing the dataset (if needed)
25
  if 'combined_description' not in data.columns:
 
2
import os

# Third-party imports, grouped and alphabetized (PEP 8).
# FIX: `torch` and `DistilBertModel` are used further down in this script
# (`torch.no_grad()`, `DistilBertModel.from_pretrained(...)`) but were never
# imported, causing a NameError at runtime — both are added here.
import pandas as pd
import requests
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import (
    DistilBertForSequenceClassification,  # kept: may be used elsewhere in the file
    DistilBertModel,
    DistilBertTokenizerFast,
    pipeline,
)
 
20
  st.error("Dataset file not found. Please upload it to this directory.")
21
 
22
# ---------------------------------------------------------------------------
# Load DistilBERT tokenizer and base model (no classification head) and
# compute a single query embedding by mean-pooling the last hidden state.
# ---------------------------------------------------------------------------
# FIX: the original code referenced `DistilBertModel` and `torch` without
# importing them (only DistilBertForSequenceClassification was imported),
# which raises NameError at runtime. Import them here so this section is
# self-contained; re-importing at module level is a harmless no-op.
import torch
from transformers import DistilBertModel, DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

query = "What is fructose-1,6-bisphosphatase deficiency?"

# Tokenize the query into PyTorch tensors (input_ids, attention_mask).
inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)

# Inference only — disable gradient tracking to save memory and time.
with torch.no_grad():
    outputs = model(**inputs)

# Mean-pool over the token axis (dim=1) of the last hidden state to obtain
# one fixed-size embedding vector for the whole query.
embeddings = outputs.last_hidden_state.mean(dim=1)

# Use the embeddings for further processing or retrieval.
# NOTE(review): `print` goes to the server log, not the Streamlit page —
# consider `st.write(embeddings)` if this output is meant for the UI.
print(embeddings)
44
+
45
 
46
  # Preprocessing the dataset (if needed)
47
  if 'combined_description' not in data.columns: