better data streaming

This commit is contained in:
hiyouga 2023-11-19 23:32:47 +08:00
parent 211b2db5a8
commit 00baaa990e
2 changed files with 4 additions and 1 deletion

Binary file not shown.

Before

Width:  |  Height:  |  Size: 141 KiB

After

Width:  |  Height:  |  Size: 140 KiB

View File

@@ -60,9 +60,12 @@ def get_dataset(
split=data_args.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
-streaming=data_args.streaming
+streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
)
if data_args.streaming and (dataset_attr.load_from == "file"):
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if max_samples is not None: # truncate dataset
dataset = dataset.select(range(min(len(dataset), max_samples)))