""" module for TRL PPO training """ import torch from tqdm import tqdm from trl import PPOTrainer class TRLPPOTrainer(PPOTrainer): """ wrapper for ppo trainer to handle customizations """ def train( self, reward_pipe, resume_from_checkpoint=None, # pylint: disable=unused-argument ): generation_kwargs = { "min_length": -1, "top_k": 0.0, "top_p": 1.0, "do_sample": True, "pad_token_id": self.tokenizer.eos_token_id, "max_new_tokens": 32, } sent_kwargs = { "return_all_scores": True, "function_to_apply": "none", "batch_size": 16, } for epoch, batch in tqdm( # pylint: disable=unused-variable enumerate(self.dataloader) ): query_tensors = batch["input_ids"] # generate model response response_tensors, ref_response_tensors = self.generate( query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs ) batch["response"] = self.tokenizer.batch_decode(response_tensors) batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors) # Compute sentiment score texts = [q + r for q, r in zip(batch["query"], batch["response"])] pipe_outputs = reward_pipe(texts, **sent_kwargs) rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])] ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs) ref_rewards = [ torch.tensor(output[1]["score"]) for output in ref_pipe_outputs ] batch["ref_rewards"] = ref_rewards # Run PPO step stats = self.step(query_tensors, response_tensors, rewards) self.log_stats( stats, batch, rewards, columns_to_log=["query", "response", "ref_response", "ref_rewards"], )