loplopez committed on
Commit
8ad2ef4
·
1 Parent(s): c8df78e

tests on classification results

Browse files
app/app.py CHANGED
@@ -41,9 +41,10 @@ async def rerank_items(input_data: RankingRequest) -> RankingResponse:
41
  items = input_data.items
42
  # TODO consider sampling them?
43
 
44
- print(items)
45
  reranked_ids, first_topic, insertion_pos = redistribute(platform=platform, items=items)
46
  #reranked_ids = [ for id_ in reranked_ids]
 
 
47
 
48
  user_in_db = user_db.get_user(user_id=user)
49
 
@@ -97,6 +98,5 @@ async def rerank_items(input_data: RankingRequest) -> RankingResponse:
97
 
98
  # no civic content to boost on
99
  else:
100
- print("there")
101
  return RankingResponse(ranked_ids=reranked_ids, new_items=[])
102
 
 
41
  items = input_data.items
42
  # TODO consider sampling them?
43
 
 
44
  reranked_ids, first_topic, insertion_pos = redistribute(platform=platform, items=items)
45
  #reranked_ids = [ for id_ in reranked_ids]
46
+ print("Receiving boost on: ", first_topic)
47
+ print("Position: ", insertion_pos)
48
 
49
  user_in_db = user_db.get_user(user_id=user)
50
 
 
98
 
99
  # no civic content to boost on
100
  else:
 
101
  return RankingResponse(ranked_ids=reranked_ids, new_items=[])
102
 
app/modules/classify.py CHANGED
@@ -10,7 +10,7 @@ except:
10
  print("No GPU available, running on CPU")
11
  device = None
12
 
13
- #model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
14
  model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=device)
15
 
16
  label_map = {
@@ -49,6 +49,7 @@ def classify(texts: List[str], labels: List[str]):
49
  # Iterate through each text to check for special cases
50
  for index, text in enumerate(texts):
51
  if text == "NON-VALID":
 
52
  # If text is "X", directly assign the label and score
53
  results.append({
54
  "sequence": text,
@@ -57,16 +58,16 @@ def classify(texts: List[str], labels: List[str]):
57
  })
58
  else:
59
  # Otherwise, prepare for model processing
 
60
  model_texts.append(text)
61
  model_indices.append(index)
62
 
63
  if model_texts:
64
  # Process texts through the model if there are any
65
- predicted_labels = model(model_texts, labels, multi_label=False, batch_size=16)
66
 
67
  # Insert model results into the correct positions
68
  for pred, idx in zip(predicted_labels, model_indices):
69
  results.insert(idx, pred)
70
-
71
- print(results)
72
  return results
 
10
  print("No GPU available, running on CPU")
11
  device = None
12
 
13
+ #model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=device)
14
  model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=device)
15
 
16
  label_map = {
 
49
  # Iterate through each text to check for special cases
50
  for index, text in enumerate(texts):
51
  if text == "NON-VALID":
52
+ print("NON-VALID TEXT!!", text)
53
  # If text is "X", directly assign the label and score
54
  results.append({
55
  "sequence": text,
 
58
  })
59
  else:
60
  # Otherwise, prepare for model processing
61
+ #print("- text =>", text)
62
  model_texts.append(text)
63
  model_indices.append(index)
64
 
65
  if model_texts:
66
  # Process texts through the model if there are any
67
+ predicted_labels = model(model_texts, labels, multi_label=False, batch_size=32)
68
 
69
  # Insert model results into the correct positions
70
  for pred, idx in zip(predicted_labels, model_indices):
71
  results.insert(idx, pred)
72
+ print([(r['labels'][0], r['sequence']) for r in results])
 
73
  return results
app/modules/redistribute.py CHANGED
@@ -24,9 +24,7 @@ def redistribute(platform, items):
24
  mapped_scores = map_scores(predicted_labels=predicted_labels, default_label="something else")
25
  first_topic, insertion_pos = get_first_relevant_label(predicted_labels=predicted_labels, mapped_scores=mapped_scores, default_label="something else")
26
  # TODO include parent linking
27
- print("OK--", predicted_labels)
28
  reranked_ids, _ = distribute_evenly(ids=[item.id for item in items], scores=mapped_scores)
29
- print(reranked_ids)
30
  return reranked_ids, first_topic, insertion_pos
31
 
32
 
 
24
  mapped_scores = map_scores(predicted_labels=predicted_labels, default_label="something else")
25
  first_topic, insertion_pos = get_first_relevant_label(predicted_labels=predicted_labels, mapped_scores=mapped_scores, default_label="something else")
26
  # TODO include parent linking
 
27
  reranked_ids, _ = distribute_evenly(ids=[item.id for item in items], scores=mapped_scores)
 
28
  return reranked_ids, first_topic, insertion_pos
29
 
30