add test case
This commit is contained in:
parent
fb72a3adb0
commit
52a06efaf8
|
@ -19,7 +19,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from functools import partial
|
||||
from functools import partial, wraps
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
|
@ -73,6 +73,25 @@ class UnslothGradientCheckpointing(torch.autograd.Function):
|
|||
return (None, hidden_states.grad) + (None,) * len(ctx.args)
|
||||
|
||||
|
||||
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
|
||||
r"""
|
||||
Only applies gradient checkpointing to trainable layers.
|
||||
"""
|
||||
|
||||
@wraps(gradient_checkpointing_func)
|
||||
def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
|
||||
module: "torch.nn.Module" = func.__self__
|
||||
|
||||
if any(param.requires_grad for param in module.parameters()):
|
||||
for arg in args:
|
||||
if torch.is_tensor(arg) and torch.is_floating_point(arg):
|
||||
arg.requires_grad_(True)
|
||||
|
||||
return gradient_checkpointing_func(func, *args, **kwargs)
|
||||
|
||||
return custom_gradient_checkpointing_func
|
||||
|
||||
|
||||
def _gradient_checkpointing_enable(
|
||||
self: "PreTrainedModel",
|
||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
|
||||
|
@ -96,22 +115,13 @@ def _gradient_checkpointing_enable(
|
|||
else:
|
||||
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
|
||||
|
||||
def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
|
||||
module: "torch.nn.Module" = func.__self__
|
||||
|
||||
if any(param.requires_grad for param in module.parameters()):
|
||||
for arg in args:
|
||||
if torch.is_tensor(arg) and torch.is_floating_point(arg):
|
||||
arg.requires_grad_(True)
|
||||
|
||||
return gradient_checkpointing_func(func, *args, **kwargs)
|
||||
|
||||
gradient_checkpointing_func = get_custom_gradient_checkpointing_func(gradient_checkpointing_func)
|
||||
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
|
||||
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
||||
self.enable_input_require_grads()
|
||||
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
||||
else: # have already enabled input require gradients
|
||||
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
|
||||
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
|
||||
|
||||
|
||||
def _fp32_forward_post_hook(
|
||||
|
|
|
@ -51,6 +51,12 @@ def test_checkpointing_disable():
|
|||
assert getattr(module, "gradient_checkpointing") is False
|
||||
|
||||
|
||||
def test_unsloth_gradient_checkpointing():
|
||||
model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
|
||||
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
|
||||
assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
|
||||
|
||||
|
||||
def test_upcast_layernorm():
|
||||
model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
|
||||
for name, param in model.named_parameters():
|
||||
|
|
Loading…
Reference in New Issue