Update README.md
#2
by
Jiqing
- opened
README.md
CHANGED
@@ -5,7 +5,9 @@ tags: []
|
|
5 |
|
6 |
# Model Card for Model ID
|
7 |
|
8 |
-
ProtST for binary localization
|
|
|
|
|
9 |
|
10 |
## Running script
|
11 |
```python
|
@@ -22,6 +24,9 @@ import torch
|
|
22 |
import logging
|
23 |
import datasets
|
24 |
import transformers
|
|
|
|
|
|
|
25 |
|
26 |
logging.basicConfig(level=logging.INFO)
|
27 |
logger = logging.getLogger(__name__)
|
@@ -73,7 +78,8 @@ def create_optimizer(opt_model, lr_ratio=0.1):
|
|
73 |
"lr": training_args.learning_rate * lr_ratio
|
74 |
},
|
75 |
]
|
76 |
-
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
|
|
|
77 |
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
78 |
|
79 |
return optimizer
|
@@ -98,7 +104,8 @@ def preprocess_logits_for_metrics(logits, labels):
|
|
98 |
|
99 |
|
100 |
if __name__ == "__main__":
|
101 |
-
device = torch.device("cpu")
|
|
|
102 |
raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
|
103 |
model = AutoModel.from_pretrained("Jiqing/protst-esm1b-for-sequential-classification", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
|
104 |
tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
|
@@ -108,8 +115,10 @@ if __name__ == "__main__":
|
|
108 |
'learning_rate': 5e-05, 'weight_decay': 0, 'num_train_epochs': 100, 'max_steps': -1, 'lr_scheduler_type': 'constant', 'do_eval': True, \
|
109 |
'evaluation_strategy': 'epoch', 'per_device_eval_batch_size': 32, 'logging_strategy': 'epoch', 'save_strategy': 'epoch', 'save_steps': 820, \
|
110 |
'dataloader_num_workers': 0, 'run_name': 'downstream_esm1b_localization_fix', 'optim': 'adamw_torch', 'resume_from_checkpoint': False, \
|
111 |
-
'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3}
|
112 |
-
|
|
|
|
|
113 |
|
114 |
def tokenize_protein(example, tokenizer=None):
|
115 |
protein_seq = example["prot_seq"]
|
@@ -125,7 +134,8 @@ if __name__ == "__main__":
|
|
125 |
for split in ["train", "validation", "test"]:
|
126 |
raw_dataset[split] = raw_dataset[split].map(func_tokenize_protein, batched=False, remove_columns=["Unnamed: 0", "prot_seq", "localization"])
|
127 |
|
128 |
-
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
|
|
129 |
|
130 |
transformers.utils.logging.set_verbosity_info()
|
131 |
log_level = training_args.get_process_log_level()
|
@@ -134,9 +144,16 @@ if __name__ == "__main__":
|
|
134 |
optimizer = create_optimizer(model)
|
135 |
scheduler = create_scheduler(training_args, optimizer)
|
136 |
|
|
|
|
|
|
|
|
|
|
|
137 |
# build trainer
|
138 |
-
trainer = Trainer(
|
|
|
139 |
model=model,
|
|
|
140 |
args=training_args,
|
141 |
train_dataset=raw_dataset["train"],
|
142 |
eval_dataset=raw_dataset["validation"],
|
|
|
5 |
|
6 |
# Model Card for Model ID
|
7 |
|
8 |
+
ProtST for binary localization.
|
9 |
+
|
10 |
+
The following script shows how to finetune ProtST on Gaudi.
|
11 |
|
12 |
## Running script
|
13 |
```python
|
|
|
24 |
import logging
|
25 |
import datasets
|
26 |
import transformers
|
27 |
+
+ import habana_frameworks.torch
|
28 |
+
+ from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
|
29 |
+
|
30 |
|
31 |
logging.basicConfig(level=logging.INFO)
|
32 |
logger = logging.getLogger(__name__)
|
|
|
78 |
"lr": training_args.learning_rate * lr_ratio
|
79 |
},
|
80 |
]
|
81 |
+
- optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
|
82 |
+
+ optimizer_cls, optimizer_kwargs = GaudiTrainer.get_optimizer_cls_and_kwargs(training_args)
|
83 |
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
84 |
|
85 |
return optimizer
|
|
|
104 |
|
105 |
|
106 |
if __name__ == "__main__":
|
107 |
+
- device = torch.device("cpu")
|
108 |
+
+ device = torch.device("hpu")
|
109 |
raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
|
110 |
model = AutoModel.from_pretrained("Jiqing/protst-esm1b-for-sequential-classification", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
|
111 |
tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
|
|
|
115 |
'learning_rate': 5e-05, 'weight_decay': 0, 'num_train_epochs': 100, 'max_steps': -1, 'lr_scheduler_type': 'constant', 'do_eval': True, \
|
116 |
'evaluation_strategy': 'epoch', 'per_device_eval_batch_size': 32, 'logging_strategy': 'epoch', 'save_strategy': 'epoch', 'save_steps': 820, \
|
117 |
'dataloader_num_workers': 0, 'run_name': 'downstream_esm1b_localization_fix', 'optim': 'adamw_torch', 'resume_from_checkpoint': False, \
|
118 |
+
- 'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3}
|
119 |
+
+ 'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3, "use_habana":True, "use_lazy_mode": True, "use_hpu_graphs_for_inference": True}
|
120 |
+
- training_args = HfArgumentParser(TrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]
|
121 |
+
+ training_args = HfArgumentParser(GaudiTrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]
|
122 |
|
123 |
def tokenize_protein(example, tokenizer=None):
|
124 |
protein_seq = example["prot_seq"]
|
|
|
134 |
for split in ["train", "validation", "test"]:
|
135 |
raw_dataset[split] = raw_dataset[split].map(func_tokenize_protein, batched=False, remove_columns=["Unnamed: 0", "prot_seq", "localization"])
|
136 |
|
137 |
+
- data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
138 |
+
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length", max_length=1024)
|
139 |
|
140 |
transformers.utils.logging.set_verbosity_info()
|
141 |
log_level = training_args.get_process_log_level()
|
|
|
144 |
optimizer = create_optimizer(model)
|
145 |
scheduler = create_scheduler(training_args, optimizer)
|
146 |
|
147 |
+
+ gaudi_config = GaudiConfig()
|
148 |
+
+ gaudi_config.use_fused_adam = True
|
149 |
+
+ gaudi_config.use_fused_clip_norm =True
|
150 |
+
|
151 |
+
|
152 |
# build trainer
|
153 |
+
- trainer = Trainer(
|
154 |
+
+ trainer = GaudiTrainer(
|
155 |
model=model,
|
156 |
+
+ gaudi_config=gaudi_config,
|
157 |
args=training_args,
|
158 |
train_dataset=raw_dataset["train"],
|
159 |
eval_dataset=raw_dataset["validation"],
|