diff --git a/assets/wechat.jpg b/assets/wechat.jpg
index a37873c1..d3d6c314 100644
Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ
diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/dsets/loader.py
index 3b42f17d..826b548c 100644
--- a/src/llmtuner/dsets/loader.py
+++ b/src/llmtuner/dsets/loader.py
@@ -88,7 +88,11 @@ def get_dataset(
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
             logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
-        stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
-        return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
+        return interleave_datasets(
+            datasets=all_datasets,
+            probabilities=data_args.interleave_probs,
+            seed=data_args.seed,
+            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
+        )
     else:
         raise ValueError("Unknown mixing strategy.")
diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index 9d432c56..839dec8f 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -60,7 +60,7 @@ class DataArguments:
     )
     mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
         default="concat",
-        metadata={"help": "Strategy to use in dataset mixing."}
+        metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
     )
     interleave_probs: Optional[str] = field(
         default=None,
@@ -106,7 +106,8 @@ class DataArguments:
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
 
-    def init_for_training(self): # support mixing multiple datasets
+    def init_for_training(self, seed: int): # support mixing multiple datasets
+        self.seed = seed
         dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
         try:
             with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index 50e96bb0..f4da7712 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -88,8 +88,8 @@ def get_train_args(
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()
 
-    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
-    data_args.init_for_training()
+    # Check arguments
+    data_args.init_for_training(training_args.seed)
 
     if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
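
Note (illustration only, not part of the patch): a minimal sketch of the seeded interleaving that the loader.py change enables, assuming Hugging Face `datasets`; the toy datasets `ds_a`/`ds_b`, the probabilities, and the seed value are illustrative assumptions, not LLaMA-Factory code.

    # Sketch of interleave_datasets with an explicit seed and stopping strategy.
    from datasets import Dataset, interleave_datasets

    ds_a = Dataset.from_dict({"text": ["a1", "a2", "a3"]})
    ds_b = Dataset.from_dict({"text": ["b1", "b2"]})

    mixed = interleave_datasets(
        datasets=[ds_a, ds_b],
        probabilities=[0.7, 0.3],
        seed=42,  # fixing the seed makes the sampling order reproducible across runs
        stopping_strategy="first_exhausted"  # "interleave_under": stop once the smaller dataset is exhausted
    )
    print([example["text"] for example in mixed])

With `stopping_strategy="all_exhausted"` ("interleave_over"), sampling instead continues until every dataset has been fully seen, oversampling the smaller ones.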