mgmtprofessor committed on
Commit
38d73ce
·
verified ·
1 Parent(s): 7f19cee

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +177 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import dropbox
3
+ import streamlit as st
4
+ import torch
5
+ import pandas as pd
6
+ import time
7
+ from tqdm import tqdm
8
+ from simpletransformers.classification import ClassificationModel
9
+
10
+ # Set up Streamlit app
11
+ st.title("Document Scoring App for Various Categories")
12
+
13
+ # Model directories and corresponding Dropbox paths
14
# Root Dropbox folder under which every "<category>_model" folder is stored.
_DROPBOX_ROOT = '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK'

# Checkpoint sub-folder name used for every category.
_CHECKPOINT_NAME = 'checkpoint-174-epoch-3'

# Categories in the order the local-directory mapping lists them.
_LOCAL_ORDER = (
    'finance',
    'accounting',
    'technology',
    'international',
    'operations',
    'marketing',
    'management',
    'legal',
)

# Same categories, in the order the Dropbox mappings list them.
_DROPBOX_ORDER = (
    'international',
    'finance',
    'accounting',
    'technology',
    'operations',
    'marketing',
    'management',
    'legal',
)

# Local directory for each category's model files.
model_directories = {name: f'models/{name}_model' for name in _LOCAL_ORDER}

# Dropbox paths to main model directories
dropbox_model_paths = {
    name: f'{_DROPBOX_ROOT}/{name}_model' for name in _DROPBOX_ORDER
}

# Dropbox paths to model checkpoints (all 8 models)
dropbox_checkpoint_paths = {
    name: f'{_DROPBOX_ROOT}/{name}_model/{_CHECKPOINT_NAME}'
    for name in _DROPBOX_ORDER
}
48
+
49
+ # Check if CUDA is available
50
+ use_cuda = torch.cuda.is_available()
51
+
52
+ # Function to download files from Dropbox recursively, including checkpoint directories
53
def download_files_from_dropbox(dbx, dropbox_path, local_dir):
    """Recursively copy the contents of a Dropbox folder into ``local_dir``.

    Args:
        dbx: Dropbox client providing ``files_list_folder``,
            ``files_list_folder_continue`` and ``files_download``.
        dropbox_path: Dropbox folder whose entries are copied.
        local_dir: Existing local directory that receives the files;
            sub-folders are created as needed.

    Returns:
        None. A ``dropbox.exceptions.ApiError`` is reported with ``st.error``
        instead of being raised, and stops the listing of this folder only.
    """
    try:
        listing = dbx.files_list_folder(dropbox_path)
        while True:
            for entry in listing.entries:
                local_path = os.path.join(local_dir, entry.name)
                if isinstance(entry, dropbox.files.FileMetadata):
                    # Fetch before opening the destination so a failed
                    # download does not leave an empty file behind.
                    _metadata, response = dbx.files_download(path=entry.path_lower)
                    with open(local_path, "wb") as f:
                        f.write(response.content)
                elif isinstance(entry, dropbox.files.FolderMetadata):
                    # It's a folder, create it locally and download its contents
                    os.makedirs(local_path, exist_ok=True)
                    download_files_from_dropbox(dbx, entry.path_lower, local_path)

            # Folder listings are paginated; keep following the cursor until
            # the API reports there is nothing left, otherwise large folders
            # are silently truncated to the first page.
            if not listing.has_more:
                break
            listing = dbx.files_list_folder_continue(listing.cursor)
    except dropbox.exceptions.ApiError as err:
        st.error(f"Dropbox API error: {err}")
69
+
70
+ # Function to download models and checkpoints from Dropbox
71
def download_model(category):
    """Download one category's model files and checkpoint from Dropbox.

    Args:
        category: Key into ``model_directories`` and ``dropbox_model_paths``.

    Raises:
        KeyError: If ``category`` is missing from ``model_directories`` or
            ``dropbox_model_paths``.

    NOTE(review): the model folder is downloaded on every call; nothing here
    skips the download when a local copy already exists.
    """
    model_path = model_directories[category]
    # exist_ok makes a separate existence check unnecessary.
    os.makedirs(model_path, exist_ok=True)

    dbx = dropbox.Dropbox(st.secrets["dropbox_api_key"])

    # Download the main model files (recurses into sub-folders)
    st.write(f"Downloading {category} model...")
    download_files_from_dropbox(dbx, dropbox_model_paths[category], model_path)

    # Download the checkpoint files if available
    if category in dropbox_checkpoint_paths:
        checkpoint_path = os.path.join(model_path, "checkpoint-174-epoch-3")
        os.makedirs(checkpoint_path, exist_ok=True)
        # The configured checkpoint folders sit inside the model folders, so
        # the recursive download above has normally fetched them already.
        # Fetch them separately only when that left the directory empty,
        # instead of downloading every checkpoint file a second time.
        if not os.listdir(checkpoint_path):
            st.write(f"Downloading checkpoint for {category} model...")
            download_files_from_dropbox(dbx, dropbox_checkpoint_paths[category], checkpoint_path)

    st.success(f"{category} model and checkpoints downloaded successfully.")
90
+
91
+ # Function to load a model, skipping if it can't be loaded
92
def load_model(category):
    """Download and construct the BERT classifier for ``category``.

    Args:
        category: Key into ``model_directories``.

    Returns:
        The ``ClassificationModel`` instance, or ``None`` when constructing
        it raises; in that case the error is shown with ``st.error``.
    """
    weights_dir = model_directories[category]

    # Runs outside the try block, so download errors propagate to the caller.
    download_model(category)

    try:
        classifier = ClassificationModel(
            "bert",
            weights_dir,
            use_cuda=use_cuda,
            args={"silent": True},  # Suppress output
        )
    except Exception as exc:
        st.error(f"Failed to load model for {category}: {exc}")
        return None
    return classifier
107
+
108
+ # Function to score a document and return the prediction and probability for class '1'
109
def score_document(model, text_data):
    """Classify text with ``model`` and report the probability of class 1.

    Args:
        model: Object whose ``predict(list_of_texts)`` returns a
            ``(predictions, raw_outputs)`` pair.
        text_data: A single string (wrapped in a one-element list) or a
            value passed to ``predict`` unchanged.

    Returns:
        ``(prediction, probability)`` for the first input only: the first
        predicted label and the softmax probability at index 1 of the first
        raw output.
    """
    batch = [text_data] if isinstance(text_data, str) else text_data

    labels, logits = model.predict(batch)

    # Softmax over the first item's raw scores; index 1 is class '1'.
    first_scores = torch.tensor(logits[0])
    class_probabilities = torch.softmax(first_scores, dim=0)

    return labels[0], class_probabilities[1].item()
119
+
120
+ # Let the user upload a file
121
doc_file = st.file_uploader("Upload a document (.txt)", type=["txt"])

# Track the start time
start_time = time.time()

# Make predictions when a file is uploaded
if doc_file is not None:
    # Read the content of the uploaded .txt file. errors="replace" keeps a
    # stray non-UTF-8 byte from aborting the run with UnicodeDecodeError.
    text_data = doc_file.read().decode("utf-8", errors="replace")

    # Collect one dict per scored category and build the DataFrame once at
    # the end, instead of concatenating onto an empty frame each iteration.
    result_rows = []

    # Progress bar
    progress_bar = st.progress(0)
    total_categories = len(model_directories)

    for i, category in enumerate(tqdm(model_directories, desc="Scoring documents")):
        # Load the pre-trained model for the current category
        model = load_model(category)

        # Skip the category if model loading fails
        if model is not None:
            prediction, probability = score_document(model, text_data)
            result_rows.append({
                "Category": category,
                "Prediction": prediction,
                "Probability": probability,
            })

        # Update the progress bar
        progress_bar.progress((i + 1) / total_categories)

        # Estimate remaining time from the average time per category so far
        elapsed_time = time.time() - start_time
        estimated_total_time = (elapsed_time / (i + 1)) * total_categories
        st.write(f"Elapsed time: {elapsed_time:.2f}s, Estimated time remaining: {estimated_total_time - elapsed_time:.2f}s")

    # Explicit columns keep the CSV header even when no model could be loaded.
    result_df = pd.DataFrame(result_rows, columns=["Category", "Prediction", "Probability"])

    # Save results to CSV
    csv_bytes = result_df.to_csv(index=False).encode('utf-8')
    st.download_button(
        label="Download results as CSV",
        data=csv_bytes,
        file_name="document_scoring_results.csv",
        mime="text/csv",
    )

    # Display completion message
    st.success("Document scoring complete!")

st.write("Note: Ensure the uploaded document is in .txt format containing text data.")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ pandas
4
+ tqdm
5
+ simpletransformers
6
+ dropbox==11.34.0