fix ChatGLM lm_head #494

parent 20a29297b1
commit d019956808
@@ -153,6 +153,10 @@ def load_model_and_tokenizer(
     if "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)

+    # Fix LM head (for ChatGLM2)
+    if not hasattr(model, "lm_head"):
+        setattr(model, "lm_head", model.transformer.output_layer)
+
     # Register auto class to save the custom code files.
     if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()
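Why this works: ChatGLM2's remote modeling code keeps its output projection at model.transformer.output_layer instead of exposing the lm_head attribute that generic Transformers tooling looks up, so the added block simply aliases one to the other. A minimal sketch of the pattern, with stand-in module names (not ChatGLM2's real classes):

import torch

class Transformer(torch.nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        # ChatGLM2 stores its output projection here rather than as `lm_head`.
        self.output_layer = torch.nn.Linear(hidden_size, vocab_size, bias=False)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = Transformer(hidden_size=8, vocab_size=16)

model = Model()

# The patch: expose the projection under the name generic code expects.
if not hasattr(model, "lm_head"):
    setattr(model, "lm_head", model.transformer.output_layer)

# The alias binds the same module instance: weights are shared, not copied.
assert model.lm_head is model.transformer.output_layer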
@@ -32,12 +32,12 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
         self._stored_metrics = defaultdict(lambda: defaultdict(list))

         Trainer.__init__(self, **kwargs)
-        if ref_model is not None:
-            if hasattr(self, "accelerator"):
-                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
-            else:
-                raise AttributeError("Please update `transformers`.")
+        if not hasattr(self, "accelerator"):
+            raise AttributeError("Please update `transformers`.")
+
+        if ref_model is not None:
+            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

     def concatenated_forward(
         self,
         model: Optional[torch.nn.Module] = None,
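Why this refactor: the old code only reached the raise when a ref_model was passed, so a too-old transformers (whose Trainer never creates self.accelerator) could fail later and less clearly. The new version checks the precondition right after Trainer.__init__ and keeps the reference-model preparation unnested. A sketch of the guard-clause shape; attach_ref_model is a hypothetical helper, while prepare_model(..., evaluation_mode=True) is the Accelerate call the diff itself uses:

def attach_ref_model(trainer, ref_model) -> None:
    # Fail fast: older `transformers` Trainers have no `accelerator` attribute.
    if not hasattr(trainer, "accelerator"):
        raise AttributeError("Please update `transformers`.")

    if ref_model is not None:
        # evaluation_mode=True prepares the model for inference only
        # (device placement without optimizer/gradient bookkeeping).
        trainer.ref_model = trainer.accelerator.prepare_model(
            ref_model, evaluation_mode=True
        )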
@@ -45,7 +45,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
     with gr.Box():
         output_box = gr.Markdown()

-    input_list = [
+    input_components = [
         top_elems["lang"],
         top_elems["model_name"],
         top_elems["checkpoints"],
@@ -62,13 +62,13 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
         predict
     ]

-    output_list = [
+    output_components = [
         output_box,
         process_bar
     ]

-    cmd_preview_btn.click(runner.preview_eval, input_list, output_list)
-    start_btn.click(runner.run_eval, input_list, output_list)
+    cmd_preview_btn.click(runner.preview_eval, input_components, output_components)
+    start_btn.click(runner.run_eval, input_components, output_components)
     stop_btn.click(runner.set_abort, queue=False)

     return dict(
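The last two hunks are a pure rename (input_list/output_list to input_components/output_components), matching Gradio's own vocabulary: Button.click(fn, inputs, outputs) feeds the current values of the input components to fn in order and routes its return values to the output components. A toy sketch of that wiring, with made-up components and a made-up handler standing in for runner.run_eval:

import gradio as gr

def run_eval(lang: str, model_name: str) -> str:
    # Hypothetical handler; stands in for runner.run_eval.
    return f"Evaluating {model_name} ({lang})..."

with gr.Blocks() as demo:
    lang = gr.Textbox(label="lang")
    model_name = gr.Textbox(label="model_name")
    output_box = gr.Markdown()

    input_components = [lang, model_name]   # values passed to the handler, in order
    output_components = [output_box]        # where the handler's return value lands

    start_btn = gr.Button("Start")
    start_btn.click(run_eval, input_components, output_components)

# demo.launch()  # uncomment to serve the UI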