fix ChatGLM lm_head #494

This commit is contained in:
hiyouga 2023-08-14 14:14:48 +08:00
parent 20a29297b1
commit d019956808
3 changed files with 12 additions and 8 deletions

View File

@@ -153,6 +153,10 @@ def load_model_and_tokenizer(
     if "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)

+    # Fix LM head (for ChatGLM2)
+    if not hasattr(model, "lm_head"):
+        setattr(model, "lm_head", model.transformer.output_layer)
+
     # Register auto class to save the custom code files.
     if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()

View File

@@ -32,12 +32,12 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
         Trainer.__init__(self, **kwargs)

-        if ref_model is not None:
-            if hasattr(self, "accelerator"):
-                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
-            else:
-                raise AttributeError("Please update `transformers`.")
+        if not hasattr(self, "accelerator"):
+            raise AttributeError("Please update `transformers`.")
+
+        if ref_model is not None:
+            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

     def concatenated_forward(
         self,
         model: Optional[torch.nn.Module] = None,

View File

@@ -45,7 +45,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
         with gr.Box():
             output_box = gr.Markdown()

-    input_list = [
+    input_components = [
         top_elems["lang"],
         top_elems["model_name"],
         top_elems["checkpoints"],
@@ -62,13 +62,13 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
         predict
     ]

-    output_list = [
+    output_components = [
         output_box,
         process_bar
     ]

-    cmd_preview_btn.click(runner.preview_eval, input_list, output_list)
-    start_btn.click(runner.run_eval, input_list, output_list)
+    cmd_preview_btn.click(runner.preview_eval, input_components, output_components)
+    start_btn.click(runner.run_eval, input_components, output_components)
     stop_btn.click(runner.set_abort, queue=False)

     return dict(