From 1d8a1878ea053d1dbfc570eea868d2514ce75a51 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 2 Aug 2023 19:10:23 +0800 Subject: [PATCH] fix PPO trainer --- src/llmtuner/tuner/ppo/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index d1f47850..35d36787 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -161,7 +161,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): unwrapped_model.pretrained_model.generation_config._from_model_config = False queries, responses = [], [] - query, response = inputs["input_ids"], response[:, inputs["input_ids"].size(-1):].detach().cpu() + query, response = inputs["input_ids"].detach().cpu(), response[:, inputs["input_ids"].size(-1):].detach().cpu() for i in range(len(query)): query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0] response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1