Spaces:
Sleeping
Sleeping
loplopez
committed on
Commit
·
8ad2ef4
1
Parent(s):
c8df78e
tests on classification results
Browse files
- app/app.py +2 -2
- app/modules/classify.py +5 -4
- app/modules/redistribute.py +0 -2
app/app.py
CHANGED
@@ -41,9 +41,10 @@ async def rerank_items(input_data: RankingRequest) -> RankingResponse:
|
|
41 |
items = input_data.items
|
42 |
# TODO consider sampling them?
|
43 |
|
44 |
-
print(items)
|
45 |
reranked_ids, first_topic, insertion_pos = redistribute(platform=platform, items=items)
|
46 |
#reranked_ids = [ for id_ in reranked_ids]
|
|
|
|
|
47 |
|
48 |
user_in_db = user_db.get_user(user_id=user)
|
49 |
|
@@ -97,6 +98,5 @@ async def rerank_items(input_data: RankingRequest) -> RankingResponse:
|
|
97 |
|
98 |
# no civic content to boost on
|
99 |
else:
|
100 |
-
print("there")
|
101 |
return RankingResponse(ranked_ids=reranked_ids, new_items=[])
|
102 |
|
|
|
41 |
items = input_data.items
|
42 |
# TODO consider sampling them?
|
43 |
|
|
|
44 |
reranked_ids, first_topic, insertion_pos = redistribute(platform=platform, items=items)
|
45 |
#reranked_ids = [ for id_ in reranked_ids]
|
46 |
+
print("Receiving boost on: ", first_topic)
|
47 |
+
print("Position: ", insertion_pos)
|
48 |
|
49 |
user_in_db = user_db.get_user(user_id=user)
|
50 |
|
|
|
98 |
|
99 |
# no civic content to boost on
|
100 |
else:
|
|
|
101 |
return RankingResponse(ranked_ids=reranked_ids, new_items=[])
|
102 |
|
app/modules/classify.py
CHANGED
@@ -10,7 +10,7 @@ except:
|
|
10 |
print("No GPU available, running on CPU")
|
11 |
device = None
|
12 |
|
13 |
-
#model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
14 |
model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=device)
|
15 |
|
16 |
label_map = {
|
@@ -49,6 +49,7 @@ def classify(texts: List[str], labels: List[str]):
|
|
49 |
# Iterate through each text to check for special cases
|
50 |
for index, text in enumerate(texts):
|
51 |
if text == "NON-VALID":
|
|
|
52 |
# If text is "X", directly assign the label and score
|
53 |
results.append({
|
54 |
"sequence": text,
|
@@ -57,16 +58,16 @@ def classify(texts: List[str], labels: List[str]):
|
|
57 |
})
|
58 |
else:
|
59 |
# Otherwise, prepare for model processing
|
|
|
60 |
model_texts.append(text)
|
61 |
model_indices.append(index)
|
62 |
|
63 |
if model_texts:
|
64 |
# Process texts through the model if there are any
|
65 |
-
predicted_labels = model(model_texts, labels, multi_label=False, batch_size=
|
66 |
|
67 |
# Insert model results into the correct positions
|
68 |
for pred, idx in zip(predicted_labels, model_indices):
|
69 |
results.insert(idx, pred)
|
70 |
-
|
71 |
-
print(results)
|
72 |
return results
|
|
|
10 |
print("No GPU available, running on CPU")
|
11 |
device = None
|
12 |
|
13 |
+
#model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=device)
|
14 |
model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=device)
|
15 |
|
16 |
label_map = {
|
|
|
49 |
# Iterate through each text to check for special cases
|
50 |
for index, text in enumerate(texts):
|
51 |
if text == "NON-VALID":
|
52 |
+
print("NON-VALID TEXT!!", text)
|
53 |
# If text is "X", directly assign the label and score
|
54 |
results.append({
|
55 |
"sequence": text,
|
|
|
58 |
})
|
59 |
else:
|
60 |
# Otherwise, prepare for model processing
|
61 |
+
#print("- text =>", text)
|
62 |
model_texts.append(text)
|
63 |
model_indices.append(index)
|
64 |
|
65 |
if model_texts:
|
66 |
# Process texts through the model if there are any
|
67 |
+
predicted_labels = model(model_texts, labels, multi_label=False, batch_size=32)
|
68 |
|
69 |
# Insert model results into the correct positions
|
70 |
for pred, idx in zip(predicted_labels, model_indices):
|
71 |
results.insert(idx, pred)
|
72 |
+
print([(r['labels'][0], r['sequence']) for r in results])
|
|
|
73 |
return results
|
app/modules/redistribute.py
CHANGED
@@ -24,9 +24,7 @@ def redistribute(platform, items):
|
|
24 |
mapped_scores = map_scores(predicted_labels=predicted_labels, default_label="something else")
|
25 |
first_topic, insertion_pos = get_first_relevant_label(predicted_labels=predicted_labels, mapped_scores=mapped_scores, default_label="something else")
|
26 |
# TODO include parent linking
|
27 |
-
print("OK--", predicted_labels)
|
28 |
reranked_ids, _ = distribute_evenly(ids=[item.id for item in items], scores=mapped_scores)
|
29 |
-
print(reranked_ids)
|
30 |
return reranked_ids, first_topic, insertion_pos
|
31 |
|
32 |
|
|
|
24 |
mapped_scores = map_scores(predicted_labels=predicted_labels, default_label="something else")
|
25 |
first_topic, insertion_pos = get_first_relevant_label(predicted_labels=predicted_labels, mapped_scores=mapped_scores, default_label="something else")
|
26 |
# TODO include parent linking
|
|
|
27 |
reranked_ids, _ = distribute_evenly(ids=[item.id for item in items], scores=mapped_scores)
|
|
|
28 |
return reranked_ids, first_topic, insertion_pos
|
29 |
|
30 |
|