Files changed (1)
  1. README.md +24 -7
README.md CHANGED
@@ -5,7 +5,9 @@ tags: []
 
 # Model Card for Model ID
 
-ProtST for binary localization
+ProtST for binary localization.
+
+The following script shows how to finetune ProtST on Gaudi.
 
 ## Running script
 ```python
@@ -22,6 +24,9 @@ import torch
 import logging
 import datasets
 import transformers
+import habana_frameworks.torch
+from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
+
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -73,7 +78,8 @@ def create_optimizer(opt_model, lr_ratio=0.1):
             "lr": training_args.learning_rate * lr_ratio
         },
     ]
-    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+    optimizer_cls, optimizer_kwargs = GaudiTrainer.get_optimizer_cls_and_kwargs(training_args)
     optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
 
     return optimizer
@@ -98,7 +104,8 @@ def preprocess_logits_for_metrics(logits, labels):
 
 
 if __name__ == "__main__":
-    device = torch.device("cpu")
+    device = torch.device("hpu")
     raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
     model = AutoModel.from_pretrained("Jiqing/protst-esm1b-for-sequential-classification", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
     tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
@@ -108,8 +115,10 @@ if __name__ == "__main__":
         'learning_rate': 5e-05, 'weight_decay': 0, 'num_train_epochs': 100, 'max_steps': -1, 'lr_scheduler_type': 'constant', 'do_eval': True, \
         'evaluation_strategy': 'epoch', 'per_device_eval_batch_size': 32, 'logging_strategy': 'epoch', 'save_strategy': 'epoch', 'save_steps': 820, \
         'dataloader_num_workers': 0, 'run_name': 'downstream_esm1b_localization_fix', 'optim': 'adamw_torch', 'resume_from_checkpoint': False, \
-        'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3}
-    training_args = HfArgumentParser(TrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]
+        'label_names': ['labels'], 'load_best_model_at_end': True, 'metric_for_best_model': 'accuracy', 'bf16': True, "save_total_limit": 3, "use_habana": True, "use_lazy_mode": True, "use_hpu_graphs_for_inference": True}
+    training_args = HfArgumentParser(GaudiTrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]
 
     def tokenize_protein(example, tokenizer=None):
         protein_seq = example["prot_seq"]
@@ -125,7 +134,8 @@ if __name__ == "__main__":
     for split in ["train", "validation", "test"]:
         raw_dataset[split] = raw_dataset[split].map(func_tokenize_protein, batched=False, remove_columns=["Unnamed: 0", "prot_seq", "localization"])
 
-    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length", max_length=1024)
 
     transformers.utils.logging.set_verbosity_info()
     log_level = training_args.get_process_log_level()
@@ -134,9 +144,16 @@ if __name__ == "__main__":
     optimizer = create_optimizer(model)
     scheduler = create_scheduler(training_args, optimizer)
 
+    gaudi_config = GaudiConfig()
+    gaudi_config.use_fused_adam = True
+    gaudi_config.use_fused_clip_norm = True
+
+
     # build trainer
-    trainer = Trainer(
+    trainer = GaudiTrainer(
         model=model,
+        gaudi_config=gaudi_config,
         args=training_args,
         train_dataset=raw_dataset["train"],
         eval_dataset=raw_dataset["validation"],
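
For convenience, here is a minimal sketch of how the Gaudi-specific pieces fit together once the diff above is applied. It assumes the unchanged parts of the original script (`tokenize_protein`, `create_optimizer`, `create_scheduler`, and the metric functions) are still in place; the argument dict is trimmed to the Gaudi-relevant keys, the output directory is illustrative, and passing the collator to the trainer is an assumption, since the hunk cuts off before the remaining `GaudiTrainer` arguments.

```python
# Sketch of the Gaudi-specific setup after the diff is applied.
# Assumes the Habana software stack and optimum-habana are installed and that
# the rest of the original script (tokenization, optimizer, metrics) is reused.
import torch
import habana_frameworks.torch  # noqa: F401  (registers the "hpu" device)
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding, HfArgumentParser
from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments

device = torch.device("hpu")

raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
model = AutoModel.from_pretrained(
    "Jiqing/protst-esm1b-for-sequential-classification",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device)
tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")

# Trimmed, illustrative argument dict; the diff keeps every original key and
# only appends the three Habana flags.
training_args = HfArgumentParser(GaudiTrainingArguments).parse_dict(
    {
        "output_dir": "./protst-localization",  # illustrative path
        "bf16": True,
        "use_habana": True,
        "use_lazy_mode": True,
        "use_hpu_graphs_for_inference": True,
    },
    allow_extra_keys=False,
)[0]

# Padding every batch to a fixed length keeps tensor shapes static, which
# generally suits Gaudi's lazy execution mode.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length", max_length=1024)

# Enable Habana's fused AdamW and fused gradient-clipping kernels.
gaudi_config = GaudiConfig()
gaudi_config.use_fused_adam = True
gaudi_config.use_fused_clip_norm = True

# In the full script the dataset splits are tokenized with tokenize_protein
# before being handed to the trainer.
trainer = GaudiTrainer(
    model=model,
    gaudi_config=gaudi_config,
    args=training_args,
    train_dataset=raw_dataset["train"],
    eval_dataset=raw_dataset["validation"],
    data_collator=data_collator,  # assumed, per the collator defined above
)
```

The fused kernels and the fixed-length padding are the main performance-oriented changes; everything else is a one-for-one swap of `Trainer`, `TrainingArguments`, and the CPU device for their Gaudi counterparts.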