Spaces:
Sleeping
Sleeping
ezequiellopez
committed on
Commit
·
c8df78e
1
Parent(s):
c6dd11e
debugging integration tests
Browse files- app/app.py +2 -9
- app/modules/classify.py +45 -5
- app/modules/redistribute.py +18 -3
app/app.py
CHANGED
@@ -3,9 +3,8 @@ from fastapi import FastAPI, HTTPException
|
|
3 |
#import redis
|
4 |
from dotenv import load_dotenv
|
5 |
import os
|
6 |
-
import torch
|
7 |
|
8 |
-
from modules.redistribute import redistribute, insert_element_at_position
|
9 |
#from modules.models.api import Input, Output, NewItem, UUID
|
10 |
from modules.database import BoostDatabase, UserDatabase, User
|
11 |
from _models.request import RankingRequest
|
@@ -19,10 +18,6 @@ load_dotenv('../.env')
|
|
19 |
redis_port = os.getenv("REDIS_PORT")
|
20 |
fastapi_port = os.getenv("FASTAPI_PORT")
|
21 |
|
22 |
-
|
23 |
-
print(f"Is CUDA available: {torch.cuda.is_available()}")
|
24 |
-
#print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
25 |
-
|
26 |
#print("Redis port:", redis_port)
|
27 |
print("FastAPI port:", fastapi_port)
|
28 |
|
@@ -47,10 +42,8 @@ async def rerank_items(input_data: RankingRequest) -> RankingResponse:
|
|
47 |
# TODO consider sampling them?
|
48 |
|
49 |
print(items)
|
50 |
-
reranked_ids, first_topic, insertion_pos = redistribute(items=items)
|
51 |
#reranked_ids = [ for id_ in reranked_ids]
|
52 |
-
print("here!")
|
53 |
-
print(reranked_ids)
|
54 |
|
55 |
user_in_db = user_db.get_user(user_id=user)
|
56 |
|
|
|
3 |
#import redis
|
4 |
from dotenv import load_dotenv
|
5 |
import os
|
|
|
6 |
|
7 |
+
from modules.redistribute import redistribute, insert_element_at_position, handle_text_content
|
8 |
#from modules.models.api import Input, Output, NewItem, UUID
|
9 |
from modules.database import BoostDatabase, UserDatabase, User
|
10 |
from _models.request import RankingRequest
|
|
|
18 |
redis_port = os.getenv("REDIS_PORT")
|
19 |
fastapi_port = os.getenv("FASTAPI_PORT")
|
20 |
|
|
|
|
|
|
|
|
|
21 |
#print("Redis port:", redis_port)
|
22 |
print("FastAPI port:", fastapi_port)
|
23 |
|
|
|
42 |
# TODO consider sampling them?
|
43 |
|
44 |
print(items)
|
45 |
+
reranked_ids, first_topic, insertion_pos = redistribute(platform=platform, items=items)
|
46 |
#reranked_ids = [ for id_ in reranked_ids]
|
|
|
|
|
47 |
|
48 |
user_in_db = user_db.get_user(user_id=user)
|
49 |
|
app/modules/classify.py
CHANGED
@@ -1,10 +1,17 @@
|
|
1 |
from transformers import pipeline
|
2 |
from typing import List
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
7 |
|
|
|
|
|
8 |
|
9 |
label_map = {
|
10 |
"something else": "non-civic",
|
@@ -13,6 +20,7 @@ label_map = {
|
|
13 |
"health are and public health": "health",
|
14 |
"religious": "news" # CONSCIOUS DECISION
|
15 |
}
|
|
|
16 |
|
17 |
def map_scores(predicted_labels: List[dict], default_label: str):
|
18 |
mapped_scores = [item['scores'][0] if item['labels'][0]!= default_label else 0 for item in predicted_labels]
|
@@ -26,7 +34,39 @@ def get_first_relevant_label(predicted_labels, mapped_scores: List[float], defau
|
|
26 |
|
27 |
|
28 |
def classify(texts: List[str], labels: List[str]):
|
29 |
-
predicted_labels = model(texts, labels, multi_label=False)
|
30 |
print(predicted_labels)
|
31 |
return predicted_labels
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from transformers import pipeline
|
2 |
from typing import List
|
3 |
|
4 |
+
# Decide which device the classification pipeline should run on.
# `device` ends up as 0 (first CUDA GPU) when torch reports a usable GPU,
# otherwise None so the pipeline falls back to its default (CPU) placement.
_cuda_available = False
try:
    import torch

    _cuda_available = torch.cuda.is_available()
    print(f"Is CUDA available: {_cuda_available}")
    if _cuda_available:
        # Querying the device name can still fail on a broken driver setup;
        # that case is handled below and treated as "no GPU".
        print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
except Exception as exc:
    # Best-effort probe: a missing/broken torch install must not stop the
    # service from starting. `Exception` (rather than a bare `except:`) keeps
    # KeyboardInterrupt / SystemExit propagating.
    _cuda_available = False
    print(f"GPU probe failed: {exc!r}")

if _cuda_available:
    device = 0
else:
    print("No GPU available, running on CPU")
    device = None
|
12 |
|
13 |
+
#model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
14 |
+
model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=device)
|
15 |
|
16 |
label_map = {
|
17 |
"something else": "non-civic",
|
|
|
20 |
"health are and public health": "health",
|
21 |
"religious": "news" # CONSCIOUS DECISION
|
22 |
}
|
23 |
+
default_label = "something else"
|
24 |
|
25 |
def map_scores(predicted_labels: List[dict], default_label: str):
|
26 |
mapped_scores = [item['scores'][0] if item['labels'][0]!= default_label else 0 for item in predicted_labels]
|
|
|
34 |
|
35 |
|
36 |
def classify(texts: List[str], labels: List[str]):
    """Zero-shot classify ``texts`` against ``labels``, bypassing placeholders.

    Texts equal to the sentinel ``"NON-VALID"`` are not sent to the model;
    they receive a synthetic result that assigns the module-level
    ``default_label`` with a score of 1.0. All other texts are classified in
    one batched call to the module-level ``model`` pipeline.

    Args:
        texts: Input strings, one per item to classify.
        labels: Candidate labels handed to the zero-shot pipeline.

    Returns:
        A list with one result per input text, in the same order as
        ``texts``. Each result is a dict with ``"sequence"``, ``"labels"``
        and ``"scores"`` keys.
    """
    # Pre-size the output so every result is written straight into the slot
    # of the text it belongs to, instead of relying on insertion order.
    results = [None] * len(texts)

    # Texts (and their original positions) that actually need the model.
    model_texts = []
    model_indices = []

    for index, text in enumerate(texts):
        if text == "NON-VALID":
            # Placeholder input: assign the default label with full
            # confidence without spending a model call on it.
            results[index] = {
                "sequence": text,
                "labels": [default_label],
                "scores": [1.0],
            }
        else:
            model_texts.append(text)
            model_indices.append(index)

    if model_texts:
        predicted_labels = model(model_texts, labels, multi_label=False, batch_size=16)
        # NOTE(review): some transformers releases appear to return a bare
        # dict when only one sequence is passed - normalise to a list so the
        # zip below pairs predictions with indices. Confirm against the
        # pinned transformers version.
        if isinstance(predicted_labels, dict):
            predicted_labels = [predicted_labels]

        for idx, pred in zip(model_indices, predicted_labels):
            results[idx] = pred

    print(results)
    return results
|
app/modules/redistribute.py
CHANGED
@@ -4,12 +4,27 @@ from modules.classify import classify, map_scores, get_first_relevant_label
|
|
4 |
labels = ["something else", "headlines, news channels, news articles, breaking news", "politics, policy and politicians", "health care and public health", "religious"]
|
5 |
|
6 |
|
7 |
-
def
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
mapped_scores = map_scores(predicted_labels=predicted_labels, default_label="something else")
|
10 |
first_topic, insertion_pos = get_first_relevant_label(predicted_labels=predicted_labels, mapped_scores=mapped_scores, default_label="something else")
|
11 |
# TODO include parent linking
|
12 |
-
print("OK
|
13 |
reranked_ids, _ = distribute_evenly(ids=[item.id for item in items], scores=mapped_scores)
|
14 |
print(reranked_ids)
|
15 |
return reranked_ids, first_topic, insertion_pos
|
|
|
4 |
labels = ["something else", "headlines, news channels, news articles, breaking news", "politics, policy and politicians", "health care and public health", "religious"]
|
5 |
|
6 |
|
7 |
+
def handle_text_content(platform, items):
    """Build one classifier input string per item.

    For ``platform == "reddit"`` an item with a non-empty ``title`` yields
    ``title + "\\n" + text``; every other item yields its ``text`` alone.
    Any resulting string of five characters or fewer is replaced by the
    sentinel ``"NON-VALID"``.

    Args:
        platform: Platform identifier; only ``"reddit"`` triggers title
            prefixing.
        items: Iterable of objects exposing ``text`` (and ``title`` for
            reddit items).

    Returns:
        A list of strings, one per item, in input order.
    """
    texts = []
    for item in items:
        # A missing body (None) is treated as empty rather than crashing on
        # string concatenation / len().
        body = item.text or ""

        # Only reddit items are inspected for a title, so items from other
        # platforms need not define the attribute at all.
        title = getattr(item, "title", None) if platform == "reddit" else None
        text = title + "\n" + body if title else body

        # Too little content to classify meaningfully: emit the sentinel.
        if len(text) <= 5:
            text = "NON-VALID"

        texts.append(text)
    return texts
|
20 |
+
|
21 |
+
|
22 |
+
def redistribute(platform, items):
    """Classify ``items`` and return their ids in redistributed order.

    The item texts are prepared with ``handle_text_content``, classified
    against the module-level ``labels``, scored with ``map_scores`` and then
    reordered by ``distribute_evenly``.

    Returns:
        A ``(reranked_ids, first_topic, insertion_pos)`` tuple, where
        ``first_topic`` and ``insertion_pos`` come from
        ``get_first_relevant_label``.
    """
    fallback_label = "something else"

    prepared_texts = handle_text_content(platform=platform, items=items)
    predicted_labels = classify(texts=prepared_texts, labels=labels)

    mapped_scores = map_scores(
        predicted_labels=predicted_labels,
        default_label=fallback_label,
    )
    first_topic, insertion_pos = get_first_relevant_label(
        predicted_labels=predicted_labels,
        mapped_scores=mapped_scores,
        default_label=fallback_label,
    )
    # TODO include parent linking
    print("OK--", predicted_labels)

    item_ids = [item.id for item in items]
    reranked_ids, _ = distribute_evenly(ids=item_ids, scores=mapped_scores)
    print(reranked_ids)
    return reranked_ids, first_topic, insertion_pos
|