fix PPO trainer
parent b5ba87952a
commit 1d8a1878ea
@@ -161,7 +161,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         unwrapped_model.pretrained_model.generation_config._from_model_config = False
 
         queries, responses = [], []
-        query, response = inputs["input_ids"], response[:, inputs["input_ids"].size(-1):].detach().cpu()
+        query, response = inputs["input_ids"].detach().cpu(), response[:, inputs["input_ids"].size(-1):].detach().cpu()
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
             response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
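The one-line change detaches `query` to the CPU alongside `response`; before the fix, `query` stayed on the generation device while `response` was moved to CPU, so the padding-trim loop below mixed devices and kept generation outputs alive in GPU memory. A minimal sketch of the fixed pattern, using toy tensors and an assumed `pad_token_id = 0` rather than the repository's actual code:

import torch

pad_token_id = 0  # assumed for illustration; the trainer uses self.tokenizer.pad_token_id

# Stand-ins for model.generate() outputs: the generated sequence holds the
# left-padded prompt followed by the response tokens.
input_ids = torch.tensor([[pad_token_id, 5, 6, 7]])
generated = torch.tensor([[pad_token_id, 5, 6, 7, 8, 9, pad_token_id]])

# As in the patched line: detach BOTH tensors to CPU before any indexing,
# so the per-sample slices below live on one device.
query = input_ids.detach().cpu()
response = generated[:, input_ids.size(-1):].detach().cpu()

for i in range(len(query)):
    query_length = (query[i] != pad_token_id).nonzero()[0]             # index of first non-pad token
    response_length = (response[i] != pad_token_id).nonzero()[-1] + 1  # one past the last non-pad token
    print(query[i][query_length:], response[i][:response_length])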