8 changes: 8 additions & 0 deletions src/ntops/kernels/__init__.py
@@ -7,6 +7,7 @@
    bitwise_not,
    bitwise_or,
    bmm,
    celu,
    clamp,
    conv2d,
    cos,
@@ -17,13 +18,15 @@
    ge,
    gelu,
    gt,
    instance_norm,
    isinf,
    isnan,
    layer_norm,
    le,
    lt,
    max_pool2d,
    mm,
    msort,
    mul,
    ne,
    neg,
@@ -39,6 +42,7 @@
    softmax,
    sub,
    tanh,
    threshold,
)

__all__ = [
@@ -50,6 +54,7 @@
    "bitwise_not",
    "bitwise_or",
    "bmm",
    "celu",
    "clamp",
    "conv2d",
    "cos",
@@ -60,13 +65,15 @@
    "ge",
    "gelu",
    "gt",
    "instance_norm",
    "isinf",
    "isnan",
    "layer_norm",
    "le",
    "lt",
    "max_pool2d",
    "mm",
    "msort",
    "mul",
    "ne",
    "neg",
@@ -82,4 +89,5 @@
    "softmax",
    "sub",
    "tanh",
    "threshold",
]
23 changes: 23 additions & 0 deletions src/ntops/kernels/celu.py
@@ -0,0 +1,23 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, alpha, output):
    output = max(0.0, input) + min(0.0, alpha * (ntl.exp(input / alpha) - 1))  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype),
        Tensor(0, dtype=ninetoothed.float64),
        Tensor(ndim, dtype=dtype),
    )

    return arrangement_, application, tensors
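
For reference, application implements the standard CELU definition, CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)), which is what torch.nn.functional.celu computes. A minimal smoke test, assuming the ntops.torch.celu wrapper added in this PR accepts the same (input, alpha) arguments as the PyTorch functional and that a CUDA device is available (a sketch, not part of the diff):

# Sketch: assumes ntops.torch.celu mirrors torch.nn.functional.celu(input, alpha).
import torch

import ntops.torch

x = torch.randn(4, 1024, device="cuda")
expected = torch.nn.functional.celu(x, alpha=0.5)
actual = ntops.torch.celu(x, alpha=0.5)
torch.testing.assert_close(actual, expected)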
162 changes: 162 additions & 0 deletions src/ntops/kernels/instance_norm.py
@@ -0,0 +1,162 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement as reduction_arrangement


def arrangement(
    input,
    mean,
    var,
    running_mean,
    running_var,
    weight,
    bias,
    eps,
    output,
    num_normalized_elements,
    use_input_stats,
    dims,
    block_size=None,
):
    if block_size is None:
        block_size = ninetoothed.block_size()

    def _arrange_channel_tensor(tensor):
        arranged = tensor.tile((1,))
        arranged.dtype = arranged.dtype.squeeze(0)
        arranged = arranged.unsqueeze(0)
        arranged = arranged.expand((input.shape[0], -1))

        return arranged

    def _arrange_mean_or_var(tensor):
        arranged = tensor.tile((1, 1))
        arranged.dtype = arranged.dtype.squeeze((0, 1))

        return arranged

    input_arranged, output_arranged = reduction_arrangement(
        input, output, dim=dims, block_size=block_size
    )
    mean_arranged = _arrange_mean_or_var(mean)
    var_arranged = _arrange_mean_or_var(var)
    running_mean_arranged = _arrange_channel_tensor(running_mean)
    running_var_arranged = _arrange_channel_tensor(running_var)
    weight_arranged = _arrange_channel_tensor(weight)
    bias_arranged = _arrange_channel_tensor(bias)
    eps_arranged = eps
    num_normalized_elements_arranged = num_normalized_elements

    if use_input_stats:
        return (
            input_arranged,
            mean_arranged,
            var_arranged,
            weight_arranged,
            bias_arranged,
            eps_arranged,
            output_arranged,
            num_normalized_elements_arranged,
        )
    else:
        return (
            input_arranged,
            running_mean_arranged,
            running_var_arranged,
            weight_arranged,
            bias_arranged,
            eps_arranged,
            output_arranged,
        )


def application_using_input_stats(
    input,
    mean,
    var,
    weight,
    bias,
    eps,
    output,
    num_normalized_elements,
):
    _mean = ntl.zeros(input.dtype.shape, dtype=ntl.float32)

    for i in range(input.shape[0]):
        _mean += ntl.cast(input[i], ntl.float32)

    mean = ntl.sum(_mean, 0) / num_normalized_elements

    _var = ntl.zeros(input.dtype.shape, dtype=ntl.float32)

    for i in range(input.shape[0]):
        diff = ntl.cast(input[i], ntl.float32) - mean
        diff = ntl.where(input[i].offsets(-1) < input.source.shape[-1], diff, 0)
        _var += diff * diff

    var = ntl.sum(_var, 0) / num_normalized_elements

    application_with_mean_var(input, mean, var, weight, bias, eps, output)


def application_with_mean_var(
    input,
    mean,
    var,
    weight,
    bias,
    eps,
    output,
):
    std = ntl.sqrt(var + eps)

    for i in range(input.shape[0]):
        output[i] = (ntl.cast(input[i], ntl.float32) - mean) / std * weight + bias


def premake(
    ndim,
    use_input_stats,
    num_normalized_elements,
    dtype=None,
    block_size=None,
):
    dims = tuple(reversed(range(2, ndim)))

    arrangement_ = functools.partial(
        arrangement,
        use_input_stats=use_input_stats,
        dims=dims,
        block_size=block_size,
    )

    input = Tensor(ndim, other=0, dtype=dtype)
    mean, var = (Tensor(2, dtype=dtype) for _ in range(2))
    running_mean, running_var, weight, bias = (Tensor(1, dtype=dtype) for _ in range(4))
    eps = Tensor(0, dtype=ninetoothed.float64)
    output = Tensor(ndim, dtype=dtype)
    num_normalized_elements = Tensor(0, constexpr=True, value=num_normalized_elements)

    if use_input_stats:
        application = application_using_input_stats
    else:
        application = application_with_mean_var

    tensors = (
        input,
        mean,
        var,
        running_mean,
        running_var,
        weight,
        bias,
        eps,
        output,
        num_normalized_elements,
    )

    return arrangement_, application, tensors
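
Two details in application_using_input_stats are worth noting: the loops accumulate partial sums across the blocks produced by the reduction arrangement, and the ntl.where on input[i].offsets(-1) zeroes out-of-bounds lanes (loaded as 0 via other=0) so padding does not contribute to the variance. A usage sketch, assuming the ntops.torch.instance_norm wrapper added in this PR follows the signature of torch.nn.functional.instance_norm:

# Sketch: assumes ntops.torch.instance_norm mirrors
# torch.nn.functional.instance_norm and that a CUDA device is available.
import torch

import ntops.torch

x = torch.randn(2, 3, 32, 32, device="cuda")
weight = torch.randn(3, device="cuda")
bias = torch.randn(3, device="cuda")
expected = torch.nn.functional.instance_norm(x, weight=weight, bias=bias, eps=1e-5)
actual = ntops.torch.instance_norm(x, weight=weight, bias=bias, eps=1e-5)
torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)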
46 changes: 46 additions & 0 deletions src/ntops/kernels/msort.py
@@ -0,0 +1,46 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor


def arrangement(input, output, block_size=None):
    if block_size is None:
        block_size = ninetoothed.block_size()

    ndim = input.ndim
    dim = 0

    non_target_dims = tuple(i for i in range(input.ndim) if i != dim)

    def _arrangement(input):
        arranged = input.permute(non_target_dims + (dim,))

        if ndim == 1:
            arranged = arranged.unsqueeze(0)

        arranged = arranged.flatten(end_dim=-1)

        arranged = arranged.tile((1, -1))
        arranged.dtype = arranged.dtype.squeeze(0)

        return arranged

    return _arrangement(input), _arrangement(output)


def application(input, output):
    output = ntl.sort(input)  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    tensors = (
        Tensor(
            ndim, dtype=dtype, other=float("inf"), shape_options={"constexpr": True}
        ),
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),
    )

    return arrangement_, application, tensors
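
torch.msort sorts along dim 0, so the arrangement permutes that dim to the innermost position and flattens the remaining dims before ntl.sort runs on each row; the other=float("inf") padding keeps out-of-bounds lanes at the end of the sorted block. A usage sketch, assuming ntops.torch.msort matches torch.msort:

# Sketch: assumes ntops.torch.msort mirrors torch.msort (sort along dim 0)
# and that a CUDA device is available.
import torch

import ntops.torch

x = torch.randn(128, 16, device="cuda")
torch.testing.assert_close(ntops.torch.msort(x), torch.msort(x))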
24 changes: 24 additions & 0 deletions src/ntops/kernels/threshold.py
@@ -0,0 +1,24 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, threshold, value, output):
    output = ntl.where(input > threshold, input, value)  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype),
        Tensor(0, dtype=ninetoothed.float64),
        Tensor(0, dtype=ninetoothed.float64),
        Tensor(ndim, dtype=dtype),
    )

    return arrangement_, application, tensors
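
The application is the usual threshold semantics: elements strictly greater than threshold pass through and everything else is replaced by value, matching torch.nn.functional.threshold. A usage sketch, assuming the ntops.torch.threshold wrapper takes the same positional (input, threshold, value) arguments:

# Sketch: assumes ntops.torch.threshold mirrors
# torch.nn.functional.threshold(input, threshold, value)
# and that a CUDA device is available.
import torch

import ntops.torch

x = torch.randn(4, 1024, device="cuda")
expected = torch.nn.functional.threshold(x, 0.1, 20.0)
actual = ntops.torch.threshold(x, 0.1, 20.0)
torch.testing.assert_close(actual, expected)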
8 changes: 8 additions & 0 deletions src/ntops/torch/__init__.py
@@ -6,6 +6,7 @@
from ntops.torch.bitwise_not import bitwise_not
from ntops.torch.bitwise_or import bitwise_or
from ntops.torch.bmm import bmm
from ntops.torch.celu import celu
from ntops.torch.clamp import clamp
from ntops.torch.conv2d import conv2d
from ntops.torch.cos import cos
@@ -16,6 +17,7 @@
from ntops.torch.ge import ge
from ntops.torch.gelu import gelu
from ntops.torch.gt import gt
from ntops.torch.instance_norm import instance_norm
from ntops.torch.isinf import isinf
from ntops.torch.isnan import isnan
from ntops.torch.layer_norm import layer_norm
@@ -24,6 +26,7 @@
from ntops.torch.matmul import matmul
from ntops.torch.max_pool2d import max_pool2d
from ntops.torch.mm import mm
from ntops.torch.msort import msort
from ntops.torch.mul import mul
from ntops.torch.ne import ne
from ntops.torch.neg import neg
@@ -39,6 +42,7 @@
from ntops.torch.softmax import softmax
from ntops.torch.sub import sub
from ntops.torch.tanh import tanh
from ntops.torch.threshold import threshold

__all__ = [
    "abs",
@@ -49,6 +53,7 @@
    "bitwise_not",
    "bitwise_or",
    "bmm",
    "celu",
    "clamp",
    "conv2d",
    "cos",
@@ -59,6 +64,7 @@
    "ge",
    "gelu",
    "gt",
    "instance_norm",
    "isinf",
    "isnan",
    "layer_norm",
@@ -67,6 +73,7 @@
    "matmul",
    "max_pool2d",
    "mm",
    "msort",
    "mul",
    "ne",
    "neg",
@@ -82,4 +89,5 @@
    "softmax",
    "sub",
    "tanh",
    "threshold",
]