ctheodoris
commited on
Commit
•
c48e37c
1
Parent(s):
748f48a
Update minor formatting
Browse files
examples/hyperparam_optimiz_for_disease_classifier.py
CHANGED
@@ -50,7 +50,7 @@ def initialize_ray_with_check(ip_address):
|
|
50 |
# Usage:
|
51 |
ip = 'your_ip:xxxx' # Replace with your actual IP address and port
|
52 |
if initialize_ray_with_check(ip):
|
53 |
-
print("Ray initialized successfully
|
54 |
else:
|
55 |
print("Error during Ray initialization.")
|
56 |
|
@@ -62,7 +62,7 @@ import seaborn as sns; sns.set()
|
|
62 |
from collections import Counter
|
63 |
from datasets import load_from_disk
|
64 |
from scipy.stats import ranksums
|
65 |
-
from sklearn.metrics import accuracy_score
|
66 |
from transformers import BertForSequenceClassification
|
67 |
from transformers import Trainer
|
68 |
from transformers.training_args import TrainingArguments
|
@@ -155,6 +155,7 @@ def model_init():
|
|
155 |
return model
|
156 |
|
157 |
# define metrics
|
|
|
158 |
def compute_metrics(pred):
|
159 |
labels = pred.label_ids
|
160 |
preds = pred.predictions.argmax(-1)
|
|
|
50 |
# Usage:
|
51 |
ip = 'your_ip:xxxx' # Replace with your actual IP address and port
|
52 |
if initialize_ray_with_check(ip):
|
53 |
+
print("Ray initialized successfully.")
|
54 |
else:
|
55 |
print("Error during Ray initialization.")
|
56 |
|
|
|
62 |
from collections import Counter
|
63 |
from datasets import load_from_disk
|
64 |
from scipy.stats import ranksums
|
65 |
+
from sklearn.metrics import accuracy_score
|
66 |
from transformers import BertForSequenceClassification
|
67 |
from transformers import Trainer
|
68 |
from transformers.training_args import TrainingArguments
|
|
|
155 |
return model
|
156 |
|
157 |
# define metrics
|
158 |
+
# note: macro f1 score recommended for imbalanced multiclass classifiers
|
159 |
def compute_metrics(pred):
|
160 |
labels = pred.label_ids
|
161 |
preds = pred.predictions.argmax(-1)
|