commit 7a11a42dfd (parent cb0edd2302)
hiyouga · 2023-10-19 16:17:41 +08:00
4 changed files with 11 additions and 6 deletions

Binary file changed (image, not shown): 140 KiB before, 146 KiB after.

Changed file (path not shown):

@@ -88,7 +88,11 @@ def get_dataset(
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
             logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
-        stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
-        return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
+        return interleave_datasets(
+            datasets=all_datasets,
+            probabilities=data_args.interleave_probs,
+            seed=data_args.seed,
+            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
+        )
     else:
         raise ValueError("Unknown mixing strategy.")
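The new call uses the Hugging Face `datasets` API, where `seed` makes the probabilistic sampling reproducible and `stopping_strategy` decides whether interleaving stops once the smallest dataset is exhausted ("first_exhausted", i.e. undersampling) or keeps oversampling until the largest one is ("all_exhausted"). A minimal sketch of the behavior, not part of this commit (dataset contents are illustrative):

    # Requires the `datasets` library; fixing the seed fixes the sampling order across runs.
    from datasets import Dataset, interleave_datasets

    ds_small = Dataset.from_dict({"text": ["a1", "a2"]})
    ds_large = Dataset.from_dict({"text": ["b1", "b2", "b3", "b4"]})

    mixed = interleave_datasets(
        datasets=[ds_small, ds_large],
        probabilities=[0.5, 0.5],
        seed=42,                             # same seed -> same mixed order
        stopping_strategy="first_exhausted"  # stop when ds_small runs out ("interleave_under")
    )
    print(mixed["text"])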

Changed file (path not shown):

@@ -60,7 +60,7 @@ class DataArguments:
     )
     mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
         default="concat",
-        metadata={"help": "Strategy to use in dataset mixing."}
+        metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
     )
     interleave_probs: Optional[str] = field(
         default=None,
@@ -106,7 +106,8 @@ class DataArguments:
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
 
-    def init_for_training(self): # support mixing multiple datasets
+    def init_for_training(self, seed: int): # support mixing multiple datasets
+        self.seed = seed
         dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
         try:
             with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
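A hypothetical usage sketch of the updated signature (field values are illustrative, and `init_for_training` assumes a valid `dataset_info.json` under `dataset_dir`):

    data_args = DataArguments(
        dataset="alpaca_en,alpaca_zh",    # comma-separated dataset names
        mix_strategy="interleave_under",
        interleave_probs="0.7,0.3",       # declared as Optional[str]; presumably parsed to floats before use
    )
    data_args.init_for_training(seed=42)  # new: stores the seed later passed to interleave_datasets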

Changed file (path not shown):

@@ -88,8 +88,8 @@ def get_train_args(
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()
 
-    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
-    data_args.init_for_training()
+    # Check arguments
+    data_args.init_for_training(training_args.seed)
 
     if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
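With this change, the trainer's global seed also drives dataset mixing. A short sketch of the wiring, assuming a standard `transformers` training-arguments object (every TrainingArguments subclass exposes a `seed` field, default 42):

    from transformers import Seq2SeqTrainingArguments

    training_args = Seq2SeqTrainingArguments(output_dir="out", seed=1234)
    data_args.init_for_training(training_args.seed)  # the interleaving in get_dataset now reuses this seed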