fix #1218
This commit is contained in:
parent
cb0edd2302
commit
7a11a42dfd
Binary file not shown.
Before Width: | Height: | Size: 140 KiB After Width: | Height: | Size: 146 KiB |
|
@ -88,7 +88,11 @@ def get_dataset(
|
|||
elif data_args.mix_strategy.startswith("interleave"):
|
||||
if not data_args.streaming:
|
||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
||||
return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
|
||||
return interleave_datasets(
|
||||
datasets=all_datasets,
|
||||
probabilities=data_args.interleave_probs,
|
||||
seed=data_args.seed,
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown mixing strategy.")
|
||||
|
|
|
@ -60,7 +60,7 @@ class DataArguments:
|
|||
)
|
||||
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
|
||||
default="concat",
|
||||
metadata={"help": "Strategy to use in dataset mixing."}
|
||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
|
||||
)
|
||||
interleave_probs: Optional[str] = field(
|
||||
default=None,
|
||||
|
@ -106,7 +106,8 @@ class DataArguments:
|
|||
if self.streaming and self.max_samples is not None:
|
||||
raise ValueError("`max_samples` is incompatible with `streaming`.")
|
||||
|
||||
def init_for_training(self): # support mixing multiple datasets
|
||||
def init_for_training(self, seed: int): # support mixing multiple datasets
|
||||
self.seed = seed
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
|
||||
try:
|
||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
|
|
|
@ -88,8 +88,8 @@ def get_train_args(
|
|||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
|
||||
data_args.init_for_training()
|
||||
# Check arguments
|
||||
data_args.init_for_training(training_args.seed)
|
||||
|
||||
if finetuning_args.stage != "pt" and data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
|
Loading…
Reference in New Issue