diff --git a/src/utils/common.py b/src/utils/common.py index 9137e54f..26523286 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -477,13 +477,13 @@ def preprocess_data( desc="Running tokenizer on dataset" ) - if stage == "pt": - print_unsupervised_dataset_example(dataset[0]) - elif stage == "sft": - print_supervised_dataset_example(dataset[0]) - elif stage == "rm": - print_pairwise_dataset_example(dataset[0]) - elif stage == "ppo": - print_unsupervised_dataset_example(dataset[0]) + if stage == "pt": + print_unsupervised_dataset_example(dataset[0]) + elif stage == "sft": + print_supervised_dataset_example(dataset[0]) + elif stage == "rm": + print_pairwise_dataset_example(dataset[0]) + elif stage == "ppo": + print_unsupervised_dataset_example(dataset[0]) - return dataset + return dataset