FridayMaster commited on
Commit
f67ae72
·
verified ·
1 Parent(s): e44a872

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -21
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import gradio as gr
2
  import faiss
3
  import numpy as np
@@ -5,12 +7,13 @@ import openai
5
  from sentence_transformers import SentenceTransformer
6
  from nltk.tokenize import sent_tokenize
7
  import nltk
 
 
8
 
9
  # Download the required NLTK data
10
  nltk.download('punkt')
11
- nltk.download('punkt_tab')
12
 
13
- # Paths
14
  faiss_path = "manual_chunked_faiss_index_500.bin"
15
  manual_path = "ubuntu_manual.txt"
16
 
@@ -48,17 +51,19 @@ try:
48
  except Exception as e:
49
  raise RuntimeError(f"Failed to load FAISS index: {e}")
50
 
51
- # Load your embedding model
52
- embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 
 
53
 
54
- # OpenAI API key
55
- openai.api_key = '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'
56
 
57
  # Function to create embeddings
58
  def embed_text(text_list):
59
- embeddings = embedding_model.encode(text_list)
60
- print("Embedding shape:", embeddings.shape) # Debugging: Print shape
61
- return np.array(embeddings, dtype=np.float32)
 
 
62
 
63
  # Function to retrieve relevant chunks for a user query
64
  def retrieve_chunks(query, k=5):
@@ -66,45 +71,44 @@ def retrieve_chunks(query, k=5):
66
 
67
  try:
68
  distances, indices = index.search(query_embedding, k=k)
69
- print("Indices:", indices) # Debugging: Print indices
70
- print("Distances:", distances) # Debugging: Print distances
71
  except Exception as e:
72
  raise RuntimeError(f"FAISS search failed: {e}")
73
-
74
  if len(indices[0]) == 0:
75
  return []
76
 
77
- # Ensure indices are within bounds
78
  valid_indices = [i for i in indices[0] if i < len(manual_chunks)]
79
  if not valid_indices:
80
  return []
81
 
82
- # Retrieve relevant chunks
83
  relevant_chunks = [manual_chunks[i] for i in valid_indices]
84
  return relevant_chunks
85
 
 
 
 
 
86
  # Function to truncate long inputs
87
  def truncate_input(text, max_length=512):
88
- tokens = generator_tokenizer.encode(text, truncation=True, max_length=max_length, return_tensors="pt")
89
- return tokens
90
 
91
  # Function to perform RAG: Retrieve chunks and generate a response
92
  def rag_response(query, k=5, max_new_tokens=150):
93
  try:
94
- # Step 1: Retrieve relevant chunks
95
  relevant_chunks = retrieve_chunks(query, k=k)
96
 
97
  if not relevant_chunks:
98
  return "Sorry, I couldn't find relevant information."
99
 
100
- # Step 2: Combine the query with retrieved chunks
101
  augmented_input = query + "\n" + "\n".join(relevant_chunks)
102
 
103
- # Truncate and encode the input
104
  inputs = truncate_input(augmented_input)
105
 
106
  # Generate response
107
- outputs = generator_model.generate(inputs, max_new_tokens=max_new_tokens)
108
  generated_text = generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
109
 
110
  return generated_text
@@ -128,4 +132,3 @@ if __name__ == "__main__":
128
 
129
 
130
 
131
-
 
1
+
2
+ # OpenAI API key
3
  import gradio as gr
4
  import faiss
5
  import numpy as np
 
7
  from sentence_transformers import SentenceTransformer
8
  from nltk.tokenize import sent_tokenize
9
  import nltk
10
+ from transformers import AutoTokenizer, AutoModel
11
+ import torch
12
 
13
  # Download the required NLTK data
14
  nltk.download('punkt')
 
15
 
16
+ # Paths to your files
17
  faiss_path = "manual_chunked_faiss_index_500.bin"
18
  manual_path = "ubuntu_manual.txt"
19
 
 
51
  except Exception as e:
52
  raise RuntimeError(f"Failed to load FAISS index: {e}")
53
 
54
+ # Load the tokenizer and model for embeddings
55
+ embedding_tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
56
+ embedding_model = AutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
57
+
58
 
 
 
59
 
60
  # Function to create embeddings
61
  def embed_text(text_list):
62
+ inputs = embedding_tokenizer(text_list, padding=True, truncation=True, return_tensors="pt")
63
+ with torch.no_grad():
64
+ outputs = embedding_model(**inputs)
65
+ embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy() # Use the CLS token representation
66
+ return embeddings
67
 
68
  # Function to retrieve relevant chunks for a user query
69
  def retrieve_chunks(query, k=5):
 
71
 
72
  try:
73
  distances, indices = index.search(query_embedding, k=k)
74
+ print("Distances:", distances)
75
+ print("Indices:", indices)
76
  except Exception as e:
77
  raise RuntimeError(f"FAISS search failed: {e}")
78
+
79
  if len(indices[0]) == 0:
80
  return []
81
 
 
82
  valid_indices = [i for i in indices[0] if i < len(manual_chunks)]
83
  if not valid_indices:
84
  return []
85
 
 
86
  relevant_chunks = [manual_chunks[i] for i in valid_indices]
87
  return relevant_chunks
88
 
89
+ # Load the tokenizer and model for generation
90
+ generator_tokenizer = AutoTokenizer.from_pretrained("gpt-3.5-turbo") # Replace with correct tokenizer if needed
91
+ generator_model = AutoModel.from_pretrained("gpt-3.5-turbo") # Replace with correct model if needed
92
+
93
  # Function to truncate long inputs
94
  def truncate_input(text, max_length=512):
95
+ inputs = generator_tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
96
+ return inputs
97
 
98
  # Function to perform RAG: Retrieve chunks and generate a response
99
  def rag_response(query, k=5, max_new_tokens=150):
100
  try:
 
101
  relevant_chunks = retrieve_chunks(query, k=k)
102
 
103
  if not relevant_chunks:
104
  return "Sorry, I couldn't find relevant information."
105
 
 
106
  augmented_input = query + "\n" + "\n".join(relevant_chunks)
107
 
 
108
  inputs = truncate_input(augmented_input)
109
 
110
  # Generate response
111
+ outputs = generator_model.generate(inputs['input_ids'], max_new_tokens=max_new_tokens)
112
  generated_text = generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
113
 
114
  return generated_text
 
132
 
133
 
134