Source code for eazygrad
from __future__ import annotations
from typing import Any, Callable
from .functions import *
from .tensor_factories import *
from .utils import check
from ._tensor import _Tensor
from . import nn
from . import data
from .optimizer import SGD, Adam, AdamW
# TODO : add more robust matmul tests
[docs]
class no_grad_ctx:
"""
Context manager that temporarily disables gradient tracking.
Examples
--------
>>> with eazygrad.no_grad_ctx():
... y = x + 1.0
"""
def __enter__(self) -> None:
self.prev_state = dag.grad_enable
dag.grad_enable = False
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
dag.grad_enable = self.prev_state
[docs]
def no_grad(func: Callable[..., Any]) -> Callable[..., Any]:
"""
Decorator that disables gradient tracking inside a function.
Parameters
----------
func : callable
Function to execute with gradient tracking disabled.
Returns
-------
callable
Wrapped function.
"""
def wrapper(*args, **kwargs):
with no_grad_ctx():
return func(*args, **kwargs)
return wrapper