Inni-23 commited on
Commit
96dd2db
1 Parent(s): 67ad1b2

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +59 -0
  2. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ from azure.storage.blob import BlobServiceClient
4
+ from flask import Flask, request, jsonify
5
+
6
+ app = Flask(__name__)
7
+
8
+ # BERT model and tokenizer
9
+ model_name = "textattack/bert-base-uncased-yelp-polarity"
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
12
+
13
+ # Predict the category
14
+ def predict_category(input_text):
15
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
16
+ with torch.no_grad():
17
+ logits = model(**inputs).logits
18
+ probabilities = logits.softmax(dim=1)
19
+ predicted_category = ["Documentation", "Content", "Memes"][torch.argmax(probabilities)]
20
+ return predicted_category
21
+
22
+ # Function to extract text from JSON and predict the category
23
+ def predict_category_from_json(json_data):
24
+ input_text = json_data.get('text', '')
25
+ category = predict_category(input_text)
26
+ return category
27
+
28
+ # Importing data from blob storage
29
+ def import_data_from_blob(blob_service_client, container_name, blob_name):
30
+ blob_client = blob_service_client.get_blob_client(container=container_name, blob=blob_name)
31
+ blob_data = blob_client.download_blob()
32
+ content = blob_data.readall()
33
+ return content
34
+
35
+ @app.route('/predict_category', methods=['POST'])
36
+ def predict_category_api():
37
+ try:
38
+ # Assuming JSON format with a key named 'text' that contains the text data.
39
+ json_data = request.get_json()
40
+ input_text = json_data.get('text', '')
41
+
42
+ # Predict the category
43
+ category = predict_category(input_text)
44
+
45
+ response = {'category': category}
46
+ return jsonify(response)
47
+ except Exception as e:
48
+ return jsonify({'error': str(e)})
49
+
50
+ if __name__ == '__main__':
51
+ # Azure Blob Storage connection string
52
+ connection_string = "DefaultEndpointsProtocol=https;AccountName=keywisestorage;AccountKey=uRzlCQwv/SSF6WgkEz0g83dBjnFrziSNNt8PIY5Nnt+OJic0v5xjPnO8ZMhb7SjyesYSOK79TbJ/+AStdLKiDw==;EndpointSuffix=core.windows.net"
53
+ blob_service_client = BlobServiceClient.from_connection_string(connection_string)
54
+
55
+ # Define your container and blob name
56
+ container_name = "keywisestorage"
57
+ blob_name = "pagescontainer"
58
+
59
+ app.run(host="0.0.0.0", port=5000)
requirements.txt ADDED
Binary file (98 Bytes). View file