Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions include/matx/core/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,24 +128,36 @@ struct MemTracker {

matxMemoryStats.currentBytesAllocated -= bytes;

// Check if the CUDA context is still valid before attempting to free.
// During static destruction at program exit, the CUDA context may already
// be destroyed, making cudaFree/cudaFreeAsync calls fail with
// CUDA_ERROR_CONTEXT_IS_DESTROYED.
auto is_cuda_free = [&]() {
if (iter->second.kind == MATX_HOST_MALLOC_MEMORY) return true; // not CUDA
int dev;
return cudaGetDevice(&dev) == cudaSuccess;
};

switch (iter->second.kind) {
case MATX_MANAGED_MEMORY:
[[fallthrough]];
case MATX_DEVICE_MEMORY:
cudaFree(ptr);
if (is_cuda_free()) cudaFree(ptr);
break;
case MATX_HOST_MEMORY:
cudaFreeHost(ptr);
if (is_cuda_free()) cudaFreeHost(ptr);
break;
case MATX_HOST_MALLOC_MEMORY:
free(ptr);
break;
case MATX_ASYNC_DEVICE_MEMORY:
if constexpr (std::is_same_v<no_stream_t, StreamType>) {
cudaFreeAsync(ptr, iter->second.stream);
}
else {
cudaFreeAsync(ptr, st.stream);
if (is_cuda_free()) {
if constexpr (std::is_same_v<no_stream_t, StreamType>) {
cudaFreeAsync(ptr, iter->second.stream);
}
else {
cudaFreeAsync(ptr, st.stream);
}
}
break;
default:
Expand Down
24 changes: 24 additions & 0 deletions include/matx/core/half_complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,30 @@ tanh(const matxHalfComplex<T> &x)
using matxFp16Complex = matxHalfComplex<matxFp16>; ///< Alias for a MatX fp16 complex wrapper
using matxBf16Complex = matxHalfComplex<matxBf16>; ///< Alias for a MatXbf16 complex wrapper

// Tag type for fp16 complex values held in planar (split real/imag) storage.
// Element-wise it behaves exactly like matxFp16Complex; the distinct type only
// exists so tensor code can dispatch to planar load/store paths at compile time.
struct matxFp16ComplexPlanar : public matxFp16Complex {
  using matxFp16Complex::matxFp16Complex;
  __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxFp16ComplexPlanar() = default;
  // Allow implicit conversion from the interleaved type
  __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxFp16ComplexPlanar(const matxFp16Complex &rhs) : matxFp16Complex(rhs) {}
  // Assign from the interleaved type by copying both scalar components
  __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxFp16ComplexPlanar &operator=(const matxFp16Complex &rhs)
  {
    x = rhs.x;
    y = rhs.y;
    return *this;
  }
};

// Tag type for bf16 complex values held in planar (split real/imag) storage.
// Mirrors matxFp16ComplexPlanar: identical element behavior to matxBf16Complex,
// distinct type so planar-aware code can be selected at compile time.
struct matxBf16ComplexPlanar : public matxBf16Complex {
  using matxBf16Complex::matxBf16Complex;
  __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxBf16ComplexPlanar() = default;
  // Allow implicit conversion from the interleaved type
  __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxBf16ComplexPlanar(const matxBf16Complex &rhs) : matxBf16Complex(rhs) {}
  // Assign from the interleaved type by copying both scalar components
  __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxBf16ComplexPlanar &operator=(const matxBf16Complex &rhs)
  {
    x = rhs.x;
    y = rhs.y;
    return *this;
  }
};

}; // namespace matx

#ifndef __CUDACC_RTC__
Expand Down
4 changes: 3 additions & 1 deletion include/matx/core/operator_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ namespace matx {
if(supported) {
return make_tensor<typename Op::value_type>(in.Data(), in.Descriptor());
} else {
return make_tensor<typename Op::value_type>(in.Shape(), space, stream);
// Fresh allocation is row-major contiguous; copying an affine stride
// descriptor onto it would break transforms (e.g. cuFFT batch distance).
return make_tensor<typename Op::value_type>(Shape(in), space, stream);
}
}
}
Expand Down
20 changes: 20 additions & 0 deletions include/matx/core/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
detail::tensor_impl_t<T, RANK, Desc>{std::forward<D2>(desc)},
storage_{std::move(s)}
{
ValidatePlanarLayoutOnCreate_();
this->SetLocalData(storage_.data());
}

Expand All @@ -194,6 +195,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
detail::tensor_impl_t<T, RANK, D2>{std::forward<D2>(desc)},
storage_{std::move(s)}
{
ValidatePlanarLayoutOnCreate_();
this->SetLocalData(ldata);
}

Expand All @@ -210,6 +212,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
detail::tensor_impl_t<T, RANK, D2>{std::forward<D2>(desc)},
storage_{make_owning_storage<T>(this->desc_.TotalSize())}
{
ValidatePlanarLayoutOnCreate_();
this->SetLocalData(storage_.data());
}

Expand All @@ -225,6 +228,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
detail::tensor_impl_t<T, RANK, Desc>(cuda::std::array<index_t, 0>{}),
storage_{make_owning_storage<T>(1)}
{
ValidatePlanarLayoutOnCreate_();
this->SetLocalData(storage_.data());
}

Expand All @@ -239,6 +243,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
detail::tensor_impl_t<T, RANK, Desc>(shape),
storage_{make_owning_storage<T>(this->desc_.TotalSize())}
{
ValidatePlanarLayoutOnCreate_();
this->SetLocalData(storage_.data());
}

Expand Down Expand Up @@ -944,6 +949,7 @@ MATX_LOOP_UNROLL
Reset(T *const data, ShapeType &&shape) noexcept
{
this->desc_.InitFromShape(std::forward<ShapeType>(shape));
ValidatePlanarLayoutOnCreate_();
// For non-owning storage, we need to recreate the storage object
storage_ = make_non_owning_storage<T>(data, this->desc_.TotalSize());
this->SetData(data);
Expand All @@ -965,6 +971,7 @@ MATX_LOOP_UNROLL
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)

ValidatePlanarLayoutOnCreate_();
// For non-owning storage, we need to recreate the storage object
storage_ = make_non_owning_storage<T>(data, this->desc_.TotalSize());
this->SetData(data);
Expand All @@ -986,6 +993,7 @@ MATX_LOOP_UNROLL
__MATX_HOST__ __MATX_INLINE__ void
Reset(T *const data, T *const ldata) noexcept
{
ValidatePlanarLayoutOnCreate_();
// For non-owning storage, we need to recreate the storage object
storage_ = make_non_owning_storage<T>(data, this->desc_.TotalSize());
this->SetData(ldata);
Expand Down Expand Up @@ -1529,6 +1537,18 @@ MATX_LOOP_UNROLL
}

private:
// Called from every constructor/Reset path. Planar complex storage places all
// real parts in one plane followed by all imaginary parts, which is only
// addressable when the descriptor is fully contiguous with a unit innermost
// stride — reject any other layout up front.
__MATX_HOST__ __MATX_INLINE__ void ValidatePlanarLayoutOnCreate_() const
{
  if constexpr (!is_planar_complex_v<T>) {
    // Interleaved element types may use arbitrary stride layouts; nothing to check.
    return;
  }
  else {
    if constexpr (RANK > 0) {
      MATX_ASSERT_STR(this->Stride(RANK - 1) == 1, matxInvalidDim,
                      "Planar complex tensors must have unit innermost stride");
    }
    MATX_ASSERT_STR(this->IsContiguous(), matxInvalidDim,
                    "Planar complex tensors must be contiguous (non-unity strides are not supported)");
  }
}

Storage<T> storage_;
std::string name_ = std::string("tensor_") + std::to_string(RANK) + "_" + detail::to_short_str<T>();
};
Expand Down
1 change: 1 addition & 0 deletions include/matx/core/tensor_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <type_traits>
#include <cstdint>
#include "matx/core/error.h"
#include "matx/core/type_utils.h"

namespace matx {

Expand Down
93 changes: 86 additions & 7 deletions include/matx/core/tensor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,52 @@ class tensor_impl_t {
using data_type = TensorData;
using shape_type = typename Desc::shape_type;
using stride_type = typename Desc::stride_type;
using shape_container = typename Desc::shape_container;
using stride_container = typename Desc::stride_container;
using matxoplvalue = bool;
using self_type = tensor_impl_t<T, RANK, Desc, TensorData>;

// Planar complex wrappers store real/imag in separate contiguous planes:
// [real_0..real_n-1][imag_0..imag_n-1]. Since there is no contiguous T object
// at element i, operator() cannot return a true T&. This proxy provides
// reference-like read/write semantics for expression assignment paths.
struct PlanarComplexProxy {
  self_type *self;  // owning tensor, supplies the planar load/store helpers
  index_t offset;   // linear element offset within the real plane

  // Implicit read: materialize an interleaved T from the two planes
  __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ operator T() const
  {
    return self->LoadPlanarComplex(offset);
  }

  // Proxy-to-proxy assignment must copy the referenced *value*, not rebind the
  // proxy. Without this user-provided overload, `a(i) = b(j)` between two
  // planar tensors would select the implicitly-declared copy-assignment
  // operator, copying {self, offset} and silently dropping the store (the
  // classic vector<bool> proxy pitfall). Note: a user-declared assignment
  // operator does not affect aggregate status, so PlanarComplexProxy{this,
  // offset} initialization still works.
  __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ PlanarComplexProxy &operator=(const PlanarComplexProxy &rhs)
  {
    self->StorePlanarComplex(offset, rhs.self->LoadPlanarComplex(rhs.offset));
    return *this;
  }

  // Write an interleaved value into the planar layout
  __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ PlanarComplexProxy &operator=(const T &rhs)
  {
    self->StorePlanarComplex(offset, rhs);
    return *this;
  }

  // Accept any complex-like source exposing real()/imag() accessors
  template <typename U>
  __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ PlanarComplexProxy &operator=(const U &rhs)
    requires requires(const U &u) { u.real(); u.imag(); }
  {
    T tmp{};
    tmp.real(rhs.real());
    tmp.imag(rhs.imag());
    self->StorePlanarComplex(offset, tmp);
    return *this;
  }

  // Component accessors mirroring the interleaved complex interface
  __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto real() const
  {
    return self->LoadPlanarComplex(offset).real();
  }

  __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto imag() const
  {
    return self->LoadPlanarComplex(offset).imag();
  }
};

// Type specifier for signaling this is a matx operation
using matxop = bool;

Expand Down Expand Up @@ -1031,7 +1074,8 @@ MATX_IGNORE_WARNING_POP_GCC
s[i] = this->Stride(d);
}

return Desc{std::move(n), std::move(s)};
auto new_desc = Desc{std::move(n), std::move(s)};
return new_desc;
}

__MATX_INLINE__ auto Permute(const cuda::std::array<int32_t, RANK> &dims) const
Expand Down Expand Up @@ -1302,7 +1346,12 @@ MATX_IGNORE_WARNING_POP_GCC
const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);

if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {
return data_.ldata_[offset];
if constexpr (is_planar_complex_v<T>) {
return LoadPlanarComplex(offset);
}
else {
return data_.ldata_[offset];
}
} else if constexpr (EPT_int * sizeof(T) <= MAX_VEC_WIDTH_BYTES ) {
return *reinterpret_cast<detail::Vector<T, EPT_int>*>(data_.ldata_ + offset);
} else {
Expand Down Expand Up @@ -1366,7 +1415,12 @@ MATX_IGNORE_WARNING_POP_GCC
const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);

if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {
return data_.ldata_[offset];
if constexpr (is_planar_complex_v<T>) {
return PlanarComplexProxy{this, offset};
}
else {
return data_.ldata_[offset];
}
} else {
return *reinterpret_cast<detail::Vector<T, EPT_int>*>(data_.ldata_ + offset);
}
Expand All @@ -1386,7 +1440,7 @@ MATX_IGNORE_WARNING_POP_GCC
template <typename CapType>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array<index_t, RANK> &idx) const noexcept
{
return cuda::std::apply([&](auto &&...args) -> T {
return cuda::std::apply([&](auto &&...args) -> decltype(auto) {
return this->operator()<CapType>(args...);
}, idx);
}
Expand All @@ -1400,7 +1454,7 @@ MATX_IGNORE_WARNING_POP_GCC
template <typename CapType>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array<index_t, RANK> &idx) noexcept
{
return cuda::std::apply([&](auto &&...args) -> T& {
return cuda::std::apply([&](auto &&...args) -> decltype(auto) {
return this->operator()<CapType>(args...);
}, idx);
}
Expand All @@ -1413,7 +1467,7 @@ MATX_IGNORE_WARNING_POP_GCC
*/
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array<index_t, RANK> &idx) const noexcept
{
return cuda::std::apply([&](auto &&...args) -> T {
return cuda::std::apply([&](auto &&...args) -> decltype(auto) {
return this->operator()<DefaultCapabilities>(args...);
}, idx);
}
Expand All @@ -1426,7 +1480,7 @@ MATX_IGNORE_WARNING_POP_GCC
*/
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array<index_t, RANK> &idx) noexcept
{
return cuda::std::apply([&](auto &&...args) -> T& {
return cuda::std::apply([&](auto &&...args) -> decltype(auto) {
return this->operator()<DefaultCapabilities>(args...);
}, idx);
}
Expand All @@ -1437,6 +1491,10 @@ MATX_IGNORE_WARNING_POP_GCC
// Since tensors are a "leaf" operator type, we will never have an operator passed to a tensor as the
// type, but only POD types.
if constexpr (Cap == detail::OperatorCapability::ELEMENTS_PER_THREAD) {
if constexpr (is_planar_complex_v<T>) {
return cuda::std::array<detail::ElementsPerThread, 2>{detail::ElementsPerThread::ONE, detail::ElementsPerThread::ONE};
}

if constexpr (Rank() == 0) {
return cuda::std::array<detail::ElementsPerThread, 2>{detail::ElementsPerThread::ONE, detail::ElementsPerThread::ONE};
}
Expand Down Expand Up @@ -1702,6 +1760,27 @@ MATX_IGNORE_WARNING_POP_GCC
protected:
TensorData data_;
Desc desc_;

private:
// Gather one complex element from planar storage. The backing buffer holds all
// real parts first, then all imaginary parts, so the two scalars for linear
// element `offset` sit exactly TotalSize() elements apart.
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T LoadPlanarComplex(index_t offset) const
{
  using scalar_type = typename T::value_type;
  const scalar_type *planes = reinterpret_cast<const scalar_type *>(data_.ldata_);
  const index_t plane_stride = this->TotalSize();
  T result{};
  result.real(planes[offset]);
  result.imag(planes[offset + plane_stride]);
  return result;
}

// Scatter one complex element into planar storage: the real part goes into the
// first plane, the imaginary part into the second plane TotalSize() elements later.
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ void StorePlanarComplex(index_t offset, const T &v)
{
  using scalar_type = typename T::value_type;
  scalar_type *planes = reinterpret_cast<scalar_type *>(data_.ldata_);
  const index_t plane_stride = this->TotalSize();
  planes[offset]                = v.real();
  planes[offset + plane_stride] = v.imag();
}
Comment thread
cliffburdick marked this conversation as resolved.
};

}
Expand Down
20 changes: 14 additions & 6 deletions include/matx/core/tensor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,22 @@ namespace matx
return "f32";
if constexpr (std::is_same_v<T, double>)
return "f64";
if constexpr (std::is_same_v<T, matxHalf<__half>>)
if constexpr (std::is_same_v<T, matxFp16>)
return "f16";
if constexpr (std::is_same_v<T, matxHalf<__nv_bfloat16>>)
if constexpr (std::is_same_v<T, matxBf16>)
return "bf16";
else
return "x" + std::to_string(sizeof(T)*8);
}
else {
if constexpr (std::is_same_v<T, matxFp16ComplexPlanar>)
return "f16cp";
if constexpr (std::is_same_v<T, matxBf16ComplexPlanar>)
return "bf16cp";
if constexpr (std::is_same_v<T, matxFp16Complex>)
return "f16c";
if constexpr (std::is_same_v<T, matxBf16Complex>)
return "bf16c";
if constexpr (std::is_same_v<typename T::value_type, int32_t>)
return "i32c";
if constexpr (std::is_same_v<typename T::value_type, uint32_t>)
Expand All @@ -149,10 +157,6 @@ namespace matx
return "f32c";
if constexpr (std::is_same_v<typename T::value_type, double>)
return "f64c";
if constexpr (std::is_same_v<typename T::value_type, matxHalf<__half>>)
return "f16";
if constexpr (std::is_same_v<typename T::value_type, matxHalf<__nv_bfloat16>>)
return "bf16";
else
return "x" + std::to_string(sizeof(typename T::value_type)*8) + "c";
}
Expand Down Expand Up @@ -199,6 +203,10 @@ namespace matx
return {kDLComplex, 32, 1};
if constexpr (std::is_same_v<T, matxBf16Complex>)
return {kDLComplex, 32, 1}; // Wrong, but no other choice
if constexpr (std::is_same_v<T, matxFp16ComplexPlanar>)
return {kDLComplex, 32, 1};
if constexpr (std::is_same_v<T, matxBf16ComplexPlanar>)
return {kDLComplex, 32, 1}; // Wrong, but no other choice
if constexpr (std::is_same_v<T, float>)
return {kDLFloat, 32, 1};
if constexpr (std::is_same_v<T, double>)
Expand Down
Loading