Source code for eazygrad.functions.specials

from __future__ import annotations

import numpy as np
from ..utils import check
from .._tensor import _Tensor
from ..grad import operations, dag
from .math import exp
import numba as nb
# import line_profiler

def _validate_dim_arg(dim: int) -> None:
	if not isinstance(dim, int):
		raise ValueError("Dim argument should be an integer, got {}".format(type(dim)))

# generic but temporary implementation of logsumexp
# Slow, will be replaced with a numba friendly version
def _logsumexp_generic(f64_array: np.ndarray, dim: int) -> np.ndarray:
	M = f64_array.max(axis=dim, keepdims=True)
	logsumexp = np.exp(f64_array-M).sum(axis=dim, keepdims=True)
	logsumexp = np.log(logsumexp)
	logsumexp += M
	return logsumexp

@nb.njit(
    cache=True,
    # Deliberately not ``fastmath=True``: that enables LLVM's ``nnan``/``ninf``
    # flags, which allow the compiler to assume no NaN/inf values occur. The
    # compiler may then fold away the explicit NaN/infinity checks below.
    # Only flags that remain valid for non-finite values are kept.
    fastmath={"reassoc", "contract", "arcp", "afn", "nsz"},
    parallel=True,
)
def _fast_logsumexp(x2d: np.ndarray) -> np.ndarray:
    """Row-wise, numerically stable log-sum-exp of a 2-d array.

    Parameters
    ----------
    x2d : np.ndarray
        Array of shape ``(N, K)``; the reduction runs over the last axis.
        ``logsumexp`` passes a C-contiguous float64 array here.

    Returns
    -------
    np.ndarray
        Float64 array of shape ``(N,)``. A row containing NaN yields NaN. A row
        whose maximum is ``+inf`` or ``-inf`` (including an empty row,
        ``K == 0``) yields that infinity.
    """
    n_rows, n_cols = x2d.shape
    out = np.empty(n_rows, dtype=np.float64)

    for i in nb.prange(n_rows):
        # 1) Row maximum. NaN is tracked separately because ``v > m`` is
        #    always False for NaN and would otherwise be skipped silently.
        m = -np.inf
        has_nan = False
        for j in range(n_cols):
            v = x2d[i, j]
            if np.isnan(v):
                has_nan = True
            elif v > m:
                m = v

        if has_nan:
            out[i] = np.nan
            continue

        # A non-finite max is the answer itself (all -inf/empty row -> -inf,
        # any +inf -> +inf). Shifting by it would compute inf - inf = nan.
        if np.isinf(m):
            out[i] = m
            continue

        # 2) log(sum(exp(x - m))) + m, stable because every exponent is <= 0.
        s = 0.0
        for j in range(n_cols):
            s += np.exp(x2d[i, j] - m)
        out[i] = np.log(s) + m

    return out

def logsumexp(input: _Tensor, dim: int, keepdims: bool = False) -> _Tensor:
    """
    Compute ``log(sum(exp(input)))`` in a numerically stable way.

    Parameters
    ----------
    input : _Tensor
        Input tensor.
    dim : int
        Axis along which the reduction is performed. Ignored for 0-d tensors.
    keepdims : bool, default=False
        Whether to keep the reduced dimension in the output.

    Returns
    -------
    _Tensor
        Reduced tensor after applying the log-sum-exp operation, cast back to
        the input dtype. The reduced axis is removed (a 1-d input yields a 0-d
        tensor) unless ``keepdims`` is true, in which case it has size 1.

    Raises
    ------
    TypeError
        If ``input`` is not an eazygrad tensor.
    ValueError
        If ``dim`` is not an integer or is out of range for ``input``.

    See Also
    --------
    `torch.logsumexp <https://pytorch.org/docs/stable/generated/torch.logsumexp.html>`_
    """
    if not isinstance(input, _Tensor):
        raise TypeError(f"Expected input to be an eazygrad tensor, got {type(input)}")

    ndim = input.ndim
    requires_grad = input.requires_grad

    if ndim == 0:
        # Nothing to reduce: return a copy, recorded in the graph as a Copy op.
        result = _Tensor(input.numpy(), requires_grad=requires_grad)
        if requires_grad:
            result.node_id = dag.create_node(
                parents_id=[input.node_id],
                operation=operations.Copy(dtype=input.dtype),
                result=result,
            )
        return result

    _validate_dim_arg(dim)
    dtype = input.dtype

    # Compute in float64 for the exp/sum steps; the result is recast below.
    f64_array = input._array.astype(np.float64, copy=False)

    # Move the reduced axis last and flatten the leading axes, so the numba
    # kernel always receives a C-contiguous (rows, K) matrix. np.moveaxis also
    # raises AxisError (a ValueError subclass) for an out-of-range ``dim``.
    moved = np.moveaxis(f64_array, dim, -1)
    lead_shape = moved.shape[:-1]
    # Explicit row count (np.prod(()) == 1 for 1-d input); avoids the
    # ambiguous reshape(-1, 0) when the reduced axis is empty.
    n_rows = int(np.prod(lead_shape))
    rows = np.ascontiguousarray(moved.reshape(n_rows, moved.shape[-1]))

    lse = _fast_logsumexp(rows).reshape(lead_shape)
    if keepdims:
        lse = np.expand_dims(lse, axis=dim)

    # Recast to input dtype
    result = _Tensor(lse, requires_grad=requires_grad, dtype=dtype)
    if requires_grad:
        result.node_id = dag.create_node(
            parents_id=[input.node_id],
            operation=operations.LogSumExp(arr=f64_array, logsumexp=lse, dim=dim),
            result=result,
        )
    return result
def softmax(input: _Tensor, dim: int) -> _Tensor:
    """
    Compute the softmax of a tensor along a given axis.

    Computed as ``exp(input - logsumexp(input, dim, keepdims=True))`` in
    float64, then cast back to the input dtype.

    Parameters
    ----------
    input : _Tensor
        Input tensor.
    dim : int
        Axis along which the softmax is computed.

    Returns
    -------
    _Tensor
        Tensor of normalized exponentials, in the input's dtype.

    Raises
    ------
    TypeError
        If ``input`` is not an eazygrad tensor.
    ValueError
        If ``dim`` is not an integer.

    See Also
    --------
    `torch.softmax <https://pytorch.org/docs/stable/generated/torch.softmax.html>`_
    """
    if not isinstance(input, _Tensor):
        raise TypeError(f"Expected input to be an eazygrad tensor, got {type(input)}")
    _validate_dim_arg(dim)

    original_dtype = input.dtype
    # Promote to float64 for numerical stability of the exponentials.
    as_f64 = input.double()
    log_norm = logsumexp(as_f64, dim, keepdims=True)
    probs = exp(as_f64 - log_norm)
    return probs.to(original_dtype)
def log_softmax(input: _Tensor, dim: int) -> _Tensor:
    """
    Compute the logarithm of the softmax along a given axis.

    Computed as ``input - logsumexp(input, dim, keepdims=True)`` in float64,
    then cast back to the input dtype.

    Parameters
    ----------
    input : _Tensor
        Input tensor.
    dim : int
        Axis along which the log-softmax is computed.

    Returns
    -------
    _Tensor
        Tensor containing log-probabilities, in the input's dtype.

    Raises
    ------
    TypeError
        If ``input`` is not an eazygrad tensor.
    ValueError
        If ``dim`` is not an integer.

    See Also
    --------
    `torch.log_softmax <https://pytorch.org/docs/stable/generated/torch.log_softmax.html>`_
    """
    if not isinstance(input, _Tensor):
        raise TypeError(f"Expected input to be an eazygrad tensor, got {type(input)}")
    _validate_dim_arg(dim)

    original_dtype = input.dtype
    # Promote to float64 for numerical stability.
    as_f64 = input.to(np.float64)
    log_norm = logsumexp(as_f64, dim, keepdims=True)
    return (as_f64 - log_norm).to(original_dtype)