From be99799413e1ba37807a02838bf2d87fd966bf55 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sun, 10 Mar 2024 13:35:20 +0800 Subject: [PATCH] update parser --- data/example_dataset/example_dataset.py | 4 +-- src/llmtuner/hparams/parser.py | 44 +++++++++++++++---------- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/data/example_dataset/example_dataset.py b/data/example_dataset/example_dataset.py index d7492b44..5d6cfa22 100644 --- a/data/example_dataset/example_dataset.py +++ b/data/example_dataset/example_dataset.py @@ -1,6 +1,6 @@ import json import datasets -from typing import Any, Dict, List +from typing import Any, Dict, Generator, List, Tuple _DESCRIPTION = "An example of dataset." @@ -40,7 +40,7 @@ class ExampleDataset(datasets.GeneratorBasedBuilder): ) ] - def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]: + def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]: example_dataset = json.load(open(filepath, "r", encoding="utf-8")) for key, example in enumerate(example_dataset): yield key, example diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py index f275a1cb..74bcac2f 100644 --- a/src/llmtuner/hparams/parser.py +++ b/src/llmtuner/hparams/parser.py @@ -73,19 +73,6 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1: raise ValueError("Quantized model only accepts a single adapter. Merge them first.") - if model_args.infer_backend == "vllm": - if finetuning_args.stage != "sft": - raise ValueError("vLLM engine only supports auto-regressive models.") - - if model_args.adapter_name_or_path is not None: - raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.") - - if model_args.quantization_bit is not None: - raise ValueError("vLLM engine does not support quantization.") - - if model_args.rope_scaling is not None: - raise ValueError("vLLM engine does not support RoPE scaling.") - def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: parser = HfArgumentParser(_TRAIN_ARGS) @@ -154,6 +141,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if training_args.fp16 or training_args.bf16: raise ValueError("Turn off mixed precision training when using `pure_bf16`.") + if model_args.infer_backend == "vllm": + raise ValueError("vLLM backend is only available for API, CLI and Web.") + _verify_model_args(model_args, finetuning_args) if ( @@ -252,12 +242,27 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args) _set_transformers_logging() - _verify_model_args(model_args, finetuning_args) - model_args.device_map = "auto" if data_args.template is None: raise ValueError("Please specify which `template` to use.") + if model_args.infer_backend == "vllm": + if finetuning_args.stage != "sft": + raise ValueError("vLLM engine only supports auto-regressive models.") + + if model_args.adapter_name_or_path is not None: + raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.") + + if model_args.quantization_bit is not None: + raise ValueError("vLLM engine does not support quantization.") + + if model_args.rope_scaling is not None: + raise ValueError("vLLM engine does not support RoPE scaling.") + + _verify_model_args(model_args, finetuning_args) + + model_args.device_map = "auto" + return model_args, data_args, finetuning_args, generating_args @@ -265,12 +270,17 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args) _set_transformers_logging() - _verify_model_args(model_args, finetuning_args) - model_args.device_map = "auto" if data_args.template is None: raise ValueError("Please specify which `template` to use.") + if model_args.infer_backend == "vllm": + raise ValueError("vLLM backend is only available for API, CLI and Web.") + + _verify_model_args(model_args, finetuning_args) + + model_args.device_map = "auto" + transformers.set_seed(eval_args.seed) return model_args, data_args, eval_args, finetuning_args