fix test case

This commit is contained in:
hiyouga 2024-09-08 01:50:51 +08:00
parent 52a06efaf8
commit b332908ab4
2 changed files with 4 additions and 1 deletions

View File

@ -89,6 +89,9 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
return gradient_checkpointing_func(func, *args, **kwargs)
if hasattr(gradient_checkpointing_func, "__self__"): # fix test case
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
return custom_gradient_checkpointing_func

View File

@ -54,7 +54,7 @@ def test_checkpointing_disable():
def test_unsloth_gradient_checkpointing():
model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing" # classmethod
def test_upcast_layernorm():