diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index f6934ef..e72246d 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -7,6 +7,7 @@ bitwise_not, bitwise_or, bmm, + celu, clamp, conv2d, cos, @@ -17,6 +18,7 @@ ge, gelu, gt, + instance_norm, isinf, isnan, layer_norm, @@ -24,6 +26,7 @@ lt, max_pool2d, mm, + msort, mul, ne, neg, @@ -39,6 +42,7 @@ softmax, sub, tanh, + threshold, ) __all__ = [ @@ -50,6 +54,7 @@ "bitwise_not", "bitwise_or", "bmm", + "celu", "clamp", "conv2d", "cos", @@ -60,6 +65,7 @@ "ge", "gelu", "gt", + "instance_norm", "isinf", "isnan", "layer_norm", @@ -67,6 +73,7 @@ "lt", "max_pool2d", "mm", + "msort", "mul", "ne", "neg", @@ -82,4 +89,5 @@ "softmax", "sub", "tanh", + "threshold", ] diff --git a/src/ntops/kernels/celu.py b/src/ntops/kernels/celu.py new file mode 100644 index 0000000..6ba511d --- /dev/null +++ b/src/ntops/kernels/celu.py @@ -0,0 +1,23 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, alpha, output): + output = max(0.0, input) + min(0.0, alpha * (ntl.exp(input / alpha) - 1)) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(0, dtype=ninetoothed.float64), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/instance_norm.py b/src/ntops/kernels/instance_norm.py new file mode 100644 index 0000000..a8a486f --- /dev/null +++ b/src/ntops/kernels/instance_norm.py @@ -0,0 +1,162 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.reduction import arrangement as reduction_arrangement + + +def arrangement( + input, + mean, + var, + running_mean, + running_var, + weight, + 
bias, + eps, + output, + num_normalized_elements, + use_input_stats, + dims, + block_size=None, +): + if block_size is None: + block_size = ninetoothed.block_size() + + def _arrange_channel_tensor(tensor): + arranged = tensor.tile((1,)) + arranged.dtype = arranged.dtype.squeeze(0) + arranged = arranged.unsqueeze(0) + arranged = arranged.expand((input.shape[0], -1)) + + return arranged + + def _arrange_mean_or_var(tensor): + arranged = tensor.tile((1, 1)) + arranged.dtype = arranged.dtype.squeeze((0, 1)) + + return arranged + + input_arranged, output_arranged = reduction_arrangement( + input, output, dim=dims, block_size=block_size + ) + mean_arranged = _arrange_mean_or_var(mean) + var_arranged = _arrange_mean_or_var(var) + running_mean_arranged = _arrange_channel_tensor(running_mean) + running_var_arranged = _arrange_channel_tensor(running_var) + weight_arranged = _arrange_channel_tensor(weight) + bias_arranged = _arrange_channel_tensor(bias) + eps_arranged = eps + num_normalized_elements_arranged = num_normalized_elements + + if use_input_stats: + return ( + input_arranged, + mean_arranged, + var_arranged, + weight_arranged, + bias_arranged, + eps_arranged, + output_arranged, + num_normalized_elements_arranged, + ) + else: + return ( + input_arranged, + running_mean_arranged, + running_var_arranged, + weight_arranged, + bias_arranged, + eps_arranged, + output_arranged, + ) + + +def application_using_input_stats( + input, + mean, + var, + weight, + bias, + eps, + output, + num_normalized_elements, +): + _mean = ntl.zeros(input.dtype.shape, dtype=ntl.float32) + + for i in range(input.shape[0]): + _mean += ntl.cast(input[i], ntl.float32) + + mean = ntl.sum(_mean, 0) / num_normalized_elements + + _var = ntl.zeros(input.dtype.shape, dtype=ntl.float32) + + for i in range(input.shape[0]): + diff = ntl.cast(input[i], ntl.float32) - mean + diff = ntl.where(input[i].offsets(-1) < input.source.shape[-1], diff, 0) + _var += diff * diff + + var = ntl.sum(_var, 0) / 
num_normalized_elements + + application_with_mean_var(input, mean, var, weight, bias, eps, output) + + +def application_with_mean_var( + input, + mean, + var, + weight, + bias, + eps, + output, +): + std = ntl.sqrt(var + eps) + + for i in range(input.shape[0]): + output[i] = (ntl.cast(input[i], ntl.float32) - mean) / std * weight + bias + + +def premake( + ndim, + use_input_stats, + num_normalized_elements, + dtype=None, + block_size=None, +): + dims = tuple(reversed(range(2, ndim))) + + arrangement_ = functools.partial( + arrangement, + use_input_stats=use_input_stats, + dims=dims, + block_size=block_size, + ) + + input = Tensor(ndim, other=0, dtype=dtype) + mean, var = (Tensor(2, dtype=dtype) for _ in range(2)) + running_mean, running_var, weight, bias = (Tensor(1, dtype=dtype) for _ in range(4)) + eps = Tensor(0, dtype=ninetoothed.float64) + output = Tensor(ndim, dtype=dtype) + num_normalized_elements = Tensor(0, constexpr=True, value=num_normalized_elements) + + if use_input_stats: + application = application_using_input_stats + else: + application = application_with_mean_var + + tensors = ( + input, + mean, + var, + running_mean, + running_var, + weight, + bias, + eps, + output, + num_normalized_elements, + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/msort.py b/src/ntops/kernels/msort.py new file mode 100644 index 0000000..e91f354 --- /dev/null +++ b/src/ntops/kernels/msort.py @@ -0,0 +1,46 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +def arrangement(input, output, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + ndim = input.ndim + dim = 0 + + non_target_dims = tuple(i for i in range(input.ndim) if i != dim) + + def _arrangement(input): + arranged = input.permute(non_target_dims + (dim,)) + + if ndim == 1: + arranged = arranged.unsqueeze(0) + arranged = arranged.flatten(end_dim=-1) + + arranged = arranged.tile((1, 
-1)) + arranged.dtype = arranged.dtype.squeeze(0) + + return arranged + + return _arrangement(input), _arrangement(output) + + +def application(input, output): + output = ntl.sort(input) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor( + ndim, dtype=dtype, other=float("inf"), shape_options={"constexpr": True} + ), + Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/threshold.py b/src/ntops/kernels/threshold.py new file mode 100644 index 0000000..2744871 --- /dev/null +++ b/src/ntops/kernels/threshold.py @@ -0,0 +1,24 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, threshold, value, output): + output = ntl.where(input > threshold, input, value) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(0, dtype=ninetoothed.float64), + Tensor(0, dtype=ninetoothed.float64), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 82fc596..9719b16 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -6,6 +6,7 @@ from ntops.torch.bitwise_not import bitwise_not from ntops.torch.bitwise_or import bitwise_or from ntops.torch.bmm import bmm +from ntops.torch.celu import celu from ntops.torch.clamp import clamp from ntops.torch.conv2d import conv2d from ntops.torch.cos import cos @@ -16,6 +17,7 @@ from ntops.torch.ge import ge from ntops.torch.gelu import gelu from ntops.torch.gt import gt +from ntops.torch.instance_norm import instance_norm from ntops.torch.isinf import isinf 
from ntops.torch.isnan import isnan from ntops.torch.layer_norm import layer_norm @@ -24,6 +26,7 @@ from ntops.torch.matmul import matmul from ntops.torch.max_pool2d import max_pool2d from ntops.torch.mm import mm +from ntops.torch.msort import msort from ntops.torch.mul import mul from ntops.torch.ne import ne from ntops.torch.neg import neg @@ -39,6 +42,7 @@ from ntops.torch.softmax import softmax from ntops.torch.sub import sub from ntops.torch.tanh import tanh +from ntops.torch.threshold import threshold __all__ = [ "abs", @@ -49,6 +53,7 @@ "bitwise_not", "bitwise_or", "bmm", + "celu", "clamp", "conv2d", "cos", @@ -59,6 +64,7 @@ "ge", "gelu", "gt", + "instance_norm", "isinf", "isnan", "layer_norm", @@ -67,6 +73,7 @@ "matmul", "max_pool2d", "mm", + "msort", "mul", "ne", "neg", @@ -82,4 +89,5 @@ "softmax", "sub", "tanh", + "threshold", ] diff --git a/src/ntops/torch/celu.py b/src/ntops/torch/celu.py new file mode 100644 index 0000000..d3d661b --- /dev/null +++ b/src/ntops/torch/celu.py @@ -0,0 +1,17 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def celu(input, alpha=1.0, inplace=False): + if inplace: + output = input + else: + output = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.celu.premake, input.ndim) + + kernel(input, alpha, output) + + return output diff --git a/src/ntops/torch/instance_norm.py b/src/ntops/torch/instance_norm.py new file mode 100644 index 0000000..7c09901 --- /dev/null +++ b/src/ntops/torch/instance_norm.py @@ -0,0 +1,72 @@ +import math + +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def instance_norm( + input, + running_mean=None, + running_var=None, + weight=None, + bias=None, + use_input_stats=True, + momentum=0.1, + eps=1e-05, +): + if weight is None: + weight = torch.ones(input.shape[1], device=input.device, dtype=input.dtype) + + if bias is None: + bias = torch.zeros(input.shape[1], device=input.device, dtype=input.dtype) + + has_running_stats = 
running_mean is not None and running_var is not None + + if use_input_stats: + mean = torch.empty(input.shape[:2], device=input.device, dtype=input.dtype) + var = torch.empty(input.shape[:2], device=input.device, dtype=input.dtype) + + output = torch.empty_like(input) + + num_normalized_elements = math.prod(input.shape[2:]) + kernel = _cached_make( + ntops.kernels.instance_norm.premake, + input.ndim, + use_input_stats, + num_normalized_elements, + dtype=input.dtype, + ) + + if use_input_stats: + kernel( + input, + mean, + var, + weight, + bias, + eps, + output, + num_normalized_elements, + ) + + # We reduce in PyTorch instead of using tl.atomic_add in Triton because: + # 1. Triton blocks cannot synchronize to safely apply the momentum update after all additions finish. + # 2. N blocks atomically adding to the same C addresses creates severe memory contention. + if has_running_stats: + batch_mean = mean.mean(0) + avg_vars = var.mean(0) + + unbiased_var = ( + (avg_vars) * num_normalized_elements / (num_normalized_elements - 1) + if num_normalized_elements > 1 + else avg_vars + ) + + running_mean.mul_(1 - momentum).add_(momentum * batch_mean) + running_var.mul_(1 - momentum).add_(momentum * unbiased_var) + else: + kernel(input, running_mean, running_var, weight, bias, eps, output) + + return output diff --git a/src/ntops/torch/msort.py b/src/ntops/torch/msort.py new file mode 100644 index 0000000..5889f19 --- /dev/null +++ b/src/ntops/torch/msort.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def msort(input, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.msort.premake, input.ndim) + + kernel(input, out) + + return out diff --git a/src/ntops/torch/threshold.py b/src/ntops/torch/threshold.py new file mode 100644 index 0000000..f671391 --- /dev/null +++ b/src/ntops/torch/threshold.py @@ -0,0 +1,17 @@ +import torch + +import ntops +from ntops.torch.utils import 
_cached_make
+
+
+def threshold(input, threshold, value, inplace=False):
+    if inplace:
+        output = input
+    else:
+        output = torch.empty_like(input)
+
+    kernel = _cached_make(ntops.kernels.threshold.premake, input.ndim)
+
+    kernel(input, threshold, value, output)
+
+    return output
diff --git a/tests/test_celu.py b/tests/test_celu.py
new file mode 100644
index 0000000..680cccb
--- /dev/null
+++ b/tests/test_celu.py
@@ -0,0 +1,20 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("inplace", (False, True))
+@pytest.mark.parametrize(*generate_arguments())
+def test_celu(shape, inplace, dtype, device, rtol, atol):
+    input = torch.randn(shape, dtype=dtype, device=device)
+    alpha = 1.0
+
+    ninetoothed_output = ntops.torch.celu(input.clone(), alpha, inplace)
+    reference_output = F.celu(input.clone(), alpha, inplace)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_instance_norm.py b/tests/test_instance_norm.py
new file mode 100644
index 0000000..b3cbbca
--- /dev/null
+++ b/tests/test_instance_norm.py
@@ -0,0 +1,86 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("eps", (1e-8, 1e-5, 1e-3))
+@pytest.mark.parametrize("bias_is_none", (False, True))
+@pytest.mark.parametrize("weight_is_none", (False, True))
+@pytest.mark.parametrize("use_input_stats", (False, True))
+@pytest.mark.parametrize("track_running_stats", (False, True))
+@pytest.mark.parametrize(*generate_arguments())
+def test_instance_norm(
+    shape,
+    dtype,
+    device,
+    rtol,
+    atol,
+    weight_is_none,
+    bias_is_none,
+    use_input_stats,
+    track_running_stats,
+    eps,
+):
+    while len(shape) < 3:
+        shape.insert(0, 1)
+
+    input = 
torch.randn(shape, dtype=dtype, device=device) + + if weight_is_none: + weight = None + else: + weight = torch.randn(shape[1], dtype=dtype, device=device) + + if bias_is_none: + bias = None + else: + bias = torch.randn(shape[1], dtype=dtype, device=device) + + if use_input_stats and not track_running_stats: + reference_running_mean = None + reference_running_var = None + ninetoothed_running_mean = None + ninetoothed_running_var = None + else: + reference_running_mean = torch.randn(shape[1], dtype=dtype, device=device) + reference_running_var = torch.randn(shape[1], dtype=dtype, device=device).abs() + + if use_input_stats: + ninetoothed_running_mean = reference_running_mean.clone() + ninetoothed_running_var = reference_running_var.clone() + else: + ninetoothed_running_mean = reference_running_mean + ninetoothed_running_var = reference_running_var + + ninetoothed_output = ntops.torch.instance_norm( + input, + running_mean=ninetoothed_running_mean, + running_var=ninetoothed_running_var, + weight=weight, + bias=bias, + use_input_stats=use_input_stats, + eps=eps, + ) + reference_output = torch.nn.functional.instance_norm( + input, + running_mean=reference_running_mean, + running_var=reference_running_var, + weight=weight, + bias=bias, + use_input_stats=use_input_stats, + eps=eps, + ) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) + + if use_input_stats and track_running_stats: + assert torch.allclose( + ninetoothed_running_mean, reference_running_mean, rtol=rtol, atol=atol + ) + assert torch.allclose( + ninetoothed_running_var, reference_running_var, rtol=rtol, atol=atol + ) diff --git a/tests/test_msort.py b/tests/test_msort.py new file mode 100644 index 0000000..a2754d2 --- /dev/null +++ b/tests/test_msort.py @@ -0,0 +1,17 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available 
+@pytest.mark.parametrize(*generate_arguments()) +def test_msort(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.msort(input) + reference_output = torch.msort(input) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_threshold.py b/tests/test_threshold.py new file mode 100644 index 0000000..73ddb63 --- /dev/null +++ b/tests/test_threshold.py @@ -0,0 +1,22 @@ +import random + +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_threshold(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + threshold = random.uniform(-1, 1) + value = random.uniform(0, 1) + + ninetoothed_output = ntops.torch.threshold(input, threshold, value) + reference_output = F.threshold(input, threshold, value) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)