add test case
This commit is contained in:
parent
fb72a3adb0
commit
52a06efaf8
|
@ -19,7 +19,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from functools import partial
|
from functools import partial, wraps
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -73,6 +73,25 @@ class UnslothGradientCheckpointing(torch.autograd.Function):
|
||||||
return (None, hidden_states.grad) + (None,) * len(ctx.args)
|
return (None, hidden_states.grad) + (None,) * len(ctx.args)
|
||||||
|
|
||||||
|
|
||||||
|
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
    r"""
    Wraps a gradient checkpointing function so that inputs of trainable layers require gradients.

    The returned callable always forwards to ``gradient_checkpointing_func``. Before doing so, if the
    module that owns ``func`` has at least one parameter with ``requires_grad=True``, every positional
    floating-point tensor argument is switched to ``requires_grad=True`` in place.

    Args:
        gradient_checkpointing_func: the callable that actually performs checkpointing; it is invoked
            as ``gradient_checkpointing_func(func, *args, **kwargs)``.

    Returns:
        A wrapper with the same call signature. Metadata of the wrapped callable is copied onto it,
        including ``__self__`` when the wrapped callable is a bound method, so callers can still
        inspect which object the original checkpointing function was bound to.
    """
    # Local import keeps this block self-contained; the module only imports `partial` and `wraps`.
    from functools import WRAPPER_ASSIGNMENTS

    # `wraps` does not copy `__self__` by default; add it so the owner of a bound callable stays
    # visible on the wrapper. Attributes missing on the wrapped object are silently skipped.
    @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
        # The layer callable may arrive wrapped in `functools.partial`; unwrap it to reach the
        # bound method whose `__self__` is the module.
        bound_method = func.func if isinstance(func, partial) else func
        module: "torch.nn.Module" = bound_method.__self__

        if any(param.requires_grad for param in module.parameters()):
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)

        return gradient_checkpointing_func(func, *args, **kwargs)

    return custom_gradient_checkpointing_func
|
||||||
|
|
||||||
|
|
||||||
def _gradient_checkpointing_enable(
|
def _gradient_checkpointing_enable(
|
||||||
self: "PreTrainedModel",
|
self: "PreTrainedModel",
|
||||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
|
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
@ -96,22 +115,13 @@ def _gradient_checkpointing_enable(
|
||||||
else:
|
else:
|
||||||
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
|
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
|
||||||
|
|
||||||
def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
|
gradient_checkpointing_func = get_custom_gradient_checkpointing_func(gradient_checkpointing_func)
|
||||||
module: "torch.nn.Module" = func.__self__
|
|
||||||
|
|
||||||
if any(param.requires_grad for param in module.parameters()):
|
|
||||||
for arg in args:
|
|
||||||
if torch.is_tensor(arg) and torch.is_floating_point(arg):
|
|
||||||
arg.requires_grad_(True)
|
|
||||||
|
|
||||||
return gradient_checkpointing_func(func, *args, **kwargs)
|
|
||||||
|
|
||||||
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
|
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
|
||||||
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
||||||
self.enable_input_require_grads()
|
self.enable_input_require_grads()
|
||||||
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
||||||
else: # have already enabled input require gradients
|
else: # have already enabled input require gradients
|
||||||
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
|
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
|
||||||
|
|
||||||
|
|
||||||
def _fp32_forward_post_hook(
|
def _fp32_forward_post_hook(
|
||||||
|
|
|
@ -51,6 +51,12 @@ def test_checkpointing_disable():
|
||||||
assert getattr(module, "gradient_checkpointing") is False
|
assert getattr(module, "gradient_checkpointing") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_unsloth_gradient_checkpointing():
    # Load a model with `use_unsloth_gc=True` and check every module that exposes a
    # `gradient_checkpointing` attribute: its checkpointing function must be bound to an
    # object named "UnslothGradientCheckpointing".
    model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
    for module in model.modules():
        if not hasattr(module, "gradient_checkpointing"):
            continue

        assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
|
||||||
|
|
||||||
|
|
||||||
def test_upcast_layernorm():
|
def test_upcast_layernorm():
|
||||||
model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
|
model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
|
|
Loading…
Reference in New Issue