diff --git a/include/matx/core/allocator.h b/include/matx/core/allocator.h index dc776a202..7f3586ca4 100644 --- a/include/matx/core/allocator.h +++ b/include/matx/core/allocator.h @@ -128,24 +128,36 @@ struct MemTracker { matxMemoryStats.currentBytesAllocated -= bytes; + // Check if the CUDA context is still valid before attempting to free. + // During static destruction at program exit, the CUDA context may already + // be destroyed, making cudaFree/cudaFreeAsync calls fail with + // CUDA_ERROR_CONTEXT_IS_DESTROYED. + auto is_cuda_free = [&]() { + if (iter->second.kind == MATX_HOST_MALLOC_MEMORY) return true; // not CUDA + int dev; + return cudaGetDevice(&dev) == cudaSuccess; + }; + switch (iter->second.kind) { case MATX_MANAGED_MEMORY: [[fallthrough]]; case MATX_DEVICE_MEMORY: - cudaFree(ptr); + if (is_cuda_free()) cudaFree(ptr); break; case MATX_HOST_MEMORY: - cudaFreeHost(ptr); + if (is_cuda_free()) cudaFreeHost(ptr); break; case MATX_HOST_MALLOC_MEMORY: free(ptr); break; case MATX_ASYNC_DEVICE_MEMORY: - if constexpr (std::is_same_v) { - cudaFreeAsync(ptr, iter->second.stream); - } - else { - cudaFreeAsync(ptr, st.stream); + if (is_cuda_free()) { + if constexpr (std::is_same_v) { + cudaFreeAsync(ptr, iter->second.stream); + } + else { + cudaFreeAsync(ptr, st.stream); + } } break; default: diff --git a/include/matx/core/half_complex.h b/include/matx/core/half_complex.h index 525003c6e..a4168500e 100644 --- a/include/matx/core/half_complex.h +++ b/include/matx/core/half_complex.h @@ -1055,6 +1055,30 @@ tanh(const matxHalfComplex &x) using matxFp16Complex = matxHalfComplex; ///< Alias for a MatX fp16 complex wrapper using matxBf16Complex = matxHalfComplex; ///< Alias for a MatXbf16 complex wrapper +struct matxFp16ComplexPlanar : public matxFp16Complex { + using matxFp16Complex::matxFp16Complex; + __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxFp16ComplexPlanar() = default; + __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxFp16ComplexPlanar(const matxFp16Complex &rhs) : matxFp16Complex(rhs) {} + __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxFp16ComplexPlanar &operator=(const matxFp16Complex &rhs) + { + this->x = rhs.x; + this->y = rhs.y; + return *this; + } +}; + +struct matxBf16ComplexPlanar : public matxBf16Complex { + using matxBf16Complex::matxBf16Complex; + __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxBf16ComplexPlanar() = default; + __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxBf16ComplexPlanar(const matxBf16Complex &rhs) : matxBf16Complex(rhs) {} + __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxBf16ComplexPlanar &operator=(const matxBf16Complex &rhs) + { + this->x = rhs.x; + this->y = rhs.y; + return *this; + } +}; + }; // namespace matx #ifndef __CUDACC_RTC__ diff --git a/include/matx/core/operator_utils.h b/include/matx/core/operator_utils.h index 6a4271523..01a3fac8b 100644 --- a/include/matx/core/operator_utils.h +++ b/include/matx/core/operator_utils.h @@ -103,7 +103,9 @@ namespace matx { if(supported) { return make_tensor(in.Data(), in.Descriptor()); } else { - return make_tensor(in.Shape(), space, stream); + // Fresh allocation is row-major contiguous; copying an affine stride + // descriptor onto it would break transforms (e.g. cuFFT batch distance). + return make_tensor(Shape(in), space, stream); } } } diff --git a/include/matx/core/tensor.h b/include/matx/core/tensor.h index fa3328188..bed481886 100644 --- a/include/matx/core/tensor.h +++ b/include/matx/core/tensor.h @@ -179,6 +179,7 @@ class tensor_t : public detail::tensor_impl_t { detail::tensor_impl_t{std::forward(desc)}, storage_{std::move(s)} { + ValidatePlanarLayoutOnCreate_(); this->SetLocalData(storage_.data()); } @@ -194,6 +195,7 @@ class tensor_t : public detail::tensor_impl_t { detail::tensor_impl_t{std::forward(desc)}, storage_{std::move(s)} { + ValidatePlanarLayoutOnCreate_(); this->SetLocalData(ldata); } @@ -210,6 +212,7 @@ class tensor_t : public detail::tensor_impl_t { detail::tensor_impl_t{std::forward(desc)}, storage_{make_owning_storage(this->desc_.TotalSize())} { + ValidatePlanarLayoutOnCreate_(); this->SetLocalData(storage_.data()); } @@ -225,6 +228,7 @@ class tensor_t : public detail::tensor_impl_t { detail::tensor_impl_t(cuda::std::array{}), storage_{make_owning_storage(1)} { + ValidatePlanarLayoutOnCreate_(); this->SetLocalData(storage_.data()); } @@ -239,6 +243,7 @@ class tensor_t : public detail::tensor_impl_t { detail::tensor_impl_t(shape), storage_{make_owning_storage(this->desc_.TotalSize())} { + ValidatePlanarLayoutOnCreate_(); this->SetLocalData(storage_.data()); } @@ -944,6 +949,7 @@ MATX_LOOP_UNROLL Reset(T *const data, ShapeType &&shape) noexcept { this->desc_.InitFromShape(std::forward(shape)); + ValidatePlanarLayoutOnCreate_(); // For non-owning storage, we need to recreate the storage object storage_ = make_non_owning_storage(data, this->desc_.TotalSize()); this->SetData(data); @@ -965,6 +971,7 @@ MATX_LOOP_UNROLL { MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + ValidatePlanarLayoutOnCreate_(); // For non-owning storage, we need to recreate the storage object storage_ = make_non_owning_storage(data, this->desc_.TotalSize()); this->SetData(data); @@ -986,6 +993,7 @@ MATX_LOOP_UNROLL __MATX_HOST__ __MATX_INLINE__ void Reset(T *const data, T *const ldata) noexcept { + ValidatePlanarLayoutOnCreate_(); // For non-owning storage, we need to recreate the storage object storage_ = make_non_owning_storage(data, this->desc_.TotalSize()); this->SetData(ldata); @@ -1529,6 +1537,18 @@ MATX_LOOP_UNROLL } private: + __MATX_HOST__ __MATX_INLINE__ void ValidatePlanarLayoutOnCreate_() const + { + if constexpr (is_planar_complex_v) { + if constexpr (RANK > 0) { + MATX_ASSERT_STR(this->Stride(RANK - 1) == 1, matxInvalidDim, + "Planar complex tensors must have unit innermost stride"); + } + MATX_ASSERT_STR(this->IsContiguous(), matxInvalidDim, + "Planar complex tensors must be contiguous (non-unity strides are not supported)"); + } + } + Storage storage_; std::string name_ = std::string("tensor_") + std::to_string(RANK) + "_" + detail::to_short_str(); }; diff --git a/include/matx/core/tensor_desc.h b/include/matx/core/tensor_desc.h index d2987214c..1116ddff3 100644 --- a/include/matx/core/tensor_desc.h +++ b/include/matx/core/tensor_desc.h @@ -41,6 +41,7 @@ #include #include #include "matx/core/error.h" +#include "matx/core/type_utils.h" namespace matx { diff --git a/include/matx/core/tensor_impl.h b/include/matx/core/tensor_impl.h index 32507262a..955f85a4d 100644 --- a/include/matx/core/tensor_impl.h +++ b/include/matx/core/tensor_impl.h @@ -99,9 +99,52 @@ class tensor_impl_t { using data_type = TensorData; using shape_type = typename Desc::shape_type; using stride_type = typename Desc::stride_type; + using shape_container = typename Desc::shape_container; + using stride_container = typename Desc::stride_container; using matxoplvalue = bool; using self_type = tensor_impl_t; + // Planar complex wrappers store real/imag in separate contiguous planes: + // [real_0..real_n-1][imag_0..imag_n-1]. Since there is no contiguous T object + // at element i, operator() cannot return a true T&. This proxy provides + // reference-like read/write semantics for expression assignment paths. + struct PlanarComplexProxy { + self_type *self; + index_t offset; + + __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ operator T() const + { + return self->LoadPlanarComplex(offset); + } + + __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ PlanarComplexProxy &operator=(const T &rhs) + { + self->StorePlanarComplex(offset, rhs); + return *this; + } + + template + __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ PlanarComplexProxy &operator=(const U &rhs) + requires requires(const U &u) { u.real(); u.imag(); } + { + T tmp{}; + tmp.real(rhs.real()); + tmp.imag(rhs.imag()); + self->StorePlanarComplex(offset, tmp); + return *this; + } + + __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto real() const + { + return self->LoadPlanarComplex(offset).real(); + } + + __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto imag() const + { + return self->LoadPlanarComplex(offset).imag(); + } + }; + // Type specifier for signaling this is a matx operation using matxop = bool; @@ -1031,7 +1074,8 @@ MATX_IGNORE_WARNING_POP_GCC s[i] = this->Stride(d); } - return Desc{std::move(n), std::move(s)}; + auto new_desc = Desc{std::move(n), std::move(s)}; + return new_desc; } __MATX_INLINE__ auto Permute(const cuda::std::array &dims) const @@ -1302,7 +1346,12 @@ MATX_IGNORE_WARNING_POP_GCC const index_t offset = GetOffsetOptimized(indices...); if constexpr (CapType::ept == detail::ElementsPerThread::ONE) { - return data_.ldata_[offset]; + if constexpr (is_planar_complex_v) { + return LoadPlanarComplex(offset); + } + else { + return data_.ldata_[offset]; + } } else if constexpr (EPT_int * sizeof(T) <= MAX_VEC_WIDTH_BYTES ) { return *reinterpret_cast*>(data_.ldata_ + offset); } else { @@ -1366,7 +1415,12 @@ MATX_IGNORE_WARNING_POP_GCC const index_t offset = GetOffsetOptimized(indices...); if constexpr (CapType::ept == detail::ElementsPerThread::ONE) { - return data_.ldata_[offset]; + if constexpr (is_planar_complex_v) { + return PlanarComplexProxy{this, offset}; + } + else { + return data_.ldata_[offset]; + } } else { return *reinterpret_cast*>(data_.ldata_ + offset); } @@ -1386,7 +1440,7 @@ MATX_IGNORE_WARNING_POP_GCC template __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array &idx) const noexcept { - return cuda::std::apply([&](auto &&...args) -> T { + return cuda::std::apply([&](auto &&...args) -> decltype(auto) { return this->operator()(args...); }, idx); } @@ -1400,7 +1454,7 @@ MATX_IGNORE_WARNING_POP_GCC template __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array &idx) noexcept { - return cuda::std::apply([&](auto &&...args) -> T& { + return cuda::std::apply([&](auto &&...args) -> decltype(auto) { return this->operator()(args...); }, idx); } @@ -1413,7 +1467,7 @@ MATX_IGNORE_WARNING_POP_GCC */ __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array &idx) const noexcept { - return cuda::std::apply([&](auto &&...args) -> T { + return cuda::std::apply([&](auto &&...args) -> decltype(auto) { return this->operator()(args...); }, idx); } @@ -1426,7 +1480,7 @@ MATX_IGNORE_WARNING_POP_GCC */ __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array &idx) noexcept { - return cuda::std::apply([&](auto &&...args) -> T& { + return cuda::std::apply([&](auto &&...args) -> decltype(auto) { return this->operator()(args...); }, idx); } @@ -1437,6 +1491,10 @@ MATX_IGNORE_WARNING_POP_GCC // Since tensors are a "leaf" operator type, we will never have an operator passed to a tensor as the // type, but only POD types. if constexpr (Cap == detail::OperatorCapability::ELEMENTS_PER_THREAD) { + if constexpr (is_planar_complex_v) { + return cuda::std::array{detail::ElementsPerThread::ONE, detail::ElementsPerThread::ONE}; + } + if constexpr (Rank() == 0) { return cuda::std::array{detail::ElementsPerThread::ONE, detail::ElementsPerThread::ONE}; } @@ -1702,6 +1760,27 @@ MATX_IGNORE_WARNING_POP_GCC protected: TensorData data_; Desc desc_; + + private: + __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T LoadPlanarComplex(index_t offset) const + { + using Scalar = typename T::value_type; + const auto *base = reinterpret_cast(data_.ldata_); + const index_t total = this->TotalSize(); + T out{}; + out.real(base[offset]); + out.imag(base[offset + total]); + return out; + } + + __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ void StorePlanarComplex(index_t offset, const T &v) + { + using Scalar = typename T::value_type; + auto *base = reinterpret_cast(data_.ldata_); + const index_t total = this->TotalSize(); + base[offset] = v.real(); + base[offset + total] = v.imag(); + } }; } diff --git a/include/matx/core/tensor_utils.h b/include/matx/core/tensor_utils.h index 7b269bf2c..0b66df76d 100644 --- a/include/matx/core/tensor_utils.h +++ b/include/matx/core/tensor_utils.h @@ -129,14 +129,22 @@ namespace matx return "f32"; if constexpr (std::is_same_v) return "f64"; - if constexpr (std::is_same_v>) + if constexpr (std::is_same_v) return "f16"; - if constexpr (std::is_same_v>) + if constexpr (std::is_same_v) return "bf16"; else return "x" + std::to_string(sizeof(T)*8); } else { + if constexpr (std::is_same_v) + return "f16cp"; + if constexpr (std::is_same_v) + return "bf16cp"; + if constexpr (std::is_same_v) + return "f16c"; + if constexpr (std::is_same_v) + return "bf16c"; if constexpr (std::is_same_v) return "i32c"; if constexpr (std::is_same_v) @@ -149,10 +157,6 @@ namespace matx return "f32c"; if constexpr (std::is_same_v) return "f64c"; - if constexpr (std::is_same_v>) - return "f16"; - if constexpr (std::is_same_v>) - return "bf16"; else return "x" + std::to_string(sizeof(typename T::value_type)*8) + "c"; } @@ -199,6 +203,10 @@ namespace matx return {kDLComplex, 32, 1}; if constexpr (std::is_same_v) return {kDLComplex, 32, 1}; // Wrong, but no other choice + if constexpr (std::is_same_v) + return {kDLComplex, 32, 1}; + if constexpr (std::is_same_v) + return {kDLComplex, 32, 1}; // Wrong, but no other choice if constexpr (std::is_same_v) return {kDLFloat, 32, 1}; if constexpr (std::is_same_v) diff --git a/include/matx/core/type_utils.h b/include/matx/core/type_utils.h index bd6c70569..09979003f 100644 --- a/include/matx/core/type_utils.h +++ b/include/matx/core/type_utils.h @@ -240,6 +240,10 @@ template constexpr MatXDataType_t TypeToInt() return MATX_TYPE_COMPLEX_FP16; if constexpr (std::is_same_v) return MATX_TYPE_COMPLEX_BF16; + if constexpr (std::is_same_v) + return MATX_TYPE_COMPLEX_FP16; + if constexpr (std::is_same_v) + return MATX_TYPE_COMPLEX_BF16; if constexpr (std::is_same_v) return MATX_TYPE_FP32; if constexpr (std::is_same_v) @@ -348,6 +352,12 @@ template constexpr cudaDataType_t MatXTypeToCudaType() if constexpr (std::is_same_v) { return CUDA_C_16BF; } + if constexpr (std::is_same_v) { + return CUDA_C_16F; + } + if constexpr (std::is_same_v) { + return CUDA_C_16BF; + } return CUDA_C_32F; } @@ -357,7 +367,9 @@ template constexpr cublasComputeType_t MatXTypeToCudaComputeType() if constexpr (std::is_same_v> || std::is_same_v || is_matx_half_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || + std::is_same_v || + std::is_same_v) { return CUBLAS_COMPUTE_32F; } if constexpr (std::is_same_v> || diff --git a/include/matx/core/type_utils_both.h b/include/matx/core/type_utils_both.h index cf4c8f955..db99569cc 100644 --- a/include/matx/core/type_utils_both.h +++ b/include/matx/core/type_utils_both.h @@ -417,14 +417,18 @@ template concept is_complex = cuda::std::is_same_v, cuda::std::complex> || cuda::std::is_same_v, cuda::std::complex> || cuda::std::is_same_v, matxFp16Complex> || - cuda::std::is_same_v, matxBf16Complex>; + cuda::std::is_same_v, matxBf16Complex> || + cuda::std::is_same_v, matxFp16ComplexPlanar> || + cuda::std::is_same_v, matxBf16ComplexPlanar>; // Legacy variable for backwards compatibility template inline constexpr bool is_complex_v = cuda::std::is_same_v, cuda::std::complex> || cuda::std::is_same_v, cuda::std::complex> || cuda::std::is_same_v, matxFp16Complex> || - cuda::std::is_same_v, matxBf16Complex>; + cuda::std::is_same_v, matxBf16Complex> || + cuda::std::is_same_v, matxFp16ComplexPlanar> || + cuda::std::is_same_v, matxBf16ComplexPlanar>; namespace detail { template struct scalar_to_complex { @@ -505,11 +509,13 @@ struct inner_op_type_t { */ template concept is_bf16_type = cuda::std::is_same_v || + cuda::std::is_same_v || cuda::std::is_same_v; // Legacy variable for backwards compatibility template inline constexpr bool is_bf16_type_v = cuda::std::is_same_v || + cuda::std::is_same_v || cuda::std::is_same_v; /** @@ -519,11 +525,13 @@ inline constexpr bool is_bf16_type_v = cuda::std::is_same_v */ template concept is_fp16_type = cuda::std::is_same_v || + cuda::std::is_same_v || cuda::std::is_same_v; // Legacy variable for backwards compatibility template inline constexpr bool is_fp16_type_v = cuda::std::is_same_v || + cuda::std::is_same_v || cuda::std::is_same_v; /** @@ -564,12 +572,21 @@ inline constexpr bool is_matx_shape_v = requires { typename remove_cvref_t::m */ template concept is_complex_half = cuda::std::is_same_v, matxFp16Complex> || - cuda::std::is_same_v, matxBf16Complex>; + cuda::std::is_same_v, matxBf16Complex> || + cuda::std::is_same_v, matxFp16ComplexPlanar> || + cuda::std::is_same_v, matxBf16ComplexPlanar>; // Legacy variable for backwards compatibility template inline constexpr bool is_complex_half_v = cuda::std::is_same_v, matxFp16Complex> || - cuda::std::is_same_v, matxBf16Complex>; + cuda::std::is_same_v, matxBf16Complex> || + cuda::std::is_same_v, matxFp16ComplexPlanar> || + cuda::std::is_same_v, matxBf16ComplexPlanar>; + +template +inline constexpr bool is_planar_complex_v = + cuda::std::is_same_v, matxFp16ComplexPlanar> || + cuda::std::is_same_v, matxBf16ComplexPlanar>; /** * @brief Tests if a type is a half precision floating point @@ -619,14 +636,18 @@ template concept is_matx_type = cuda::std::is_same_v, matxFp16> || cuda::std::is_same_v, matxBf16> || cuda::std::is_same_v, matxFp16Complex> || - cuda::std::is_same_v, matxBf16Complex>; + cuda::std::is_same_v, matxBf16Complex> || + cuda::std::is_same_v, matxFp16ComplexPlanar> || + cuda::std::is_same_v, matxBf16ComplexPlanar>; // Legacy variable for backwards compatibility template inline constexpr bool is_matx_type_v = cuda::std::is_same_v, matxFp16> || cuda::std::is_same_v, matxBf16> || cuda::std::is_same_v, matxFp16Complex> || - cuda::std::is_same_v, matxBf16Complex>; + cuda::std::is_same_v, matxBf16Complex> || + cuda::std::is_same_v, matxFp16ComplexPlanar> || + cuda::std::is_same_v, matxBf16ComplexPlanar>; namespace detail { template @@ -993,9 +1014,19 @@ namespace detail { using type = __half; }; + template <> + struct inner_precision { + using type = __half; + }; + template <> struct inner_precision { using type = __nv_bfloat16; + }; + + template <> + struct inner_precision { + using type = __nv_bfloat16; }; } diff --git a/include/matx/core/utils.h b/include/matx/core/utils.h index 93e9faa4a..ad4819ee8 100644 --- a/include/matx/core/utils.h +++ b/include/matx/core/utils.h @@ -316,6 +316,12 @@ __MATX_INLINE__ __MATX_HOST__ std::string type_to_string() else if constexpr (std::is_same_v) { return "matx::matxBf16Complex"; } + else if constexpr (std::is_same_v) { + return "matx::matxFp16ComplexPlanar"; + } + else if constexpr (std::is_same_v) { + return "matx::matxBf16ComplexPlanar"; + } // CCCL complex types else if constexpr (std::is_same_v>) { return "cuda::std::complex"; @@ -372,6 +378,12 @@ __MATX_INLINE__ __MATX_HOST__ std::string type_to_string_c_name() else if constexpr (std::is_same_v) { return "matx_matxBf16Complex"; } + else if constexpr (std::is_same_v) { + return "matx_matxFp16ComplexPlanar"; + } + else if constexpr (std::is_same_v) { + return "matx_matxBf16ComplexPlanar"; + } else { return type_to_string(); } diff --git a/include/matx/operators/base_operator.h b/include/matx/operators/base_operator.h index bfba5a3a0..f16c94a7e 100644 --- a/include/matx/operators/base_operator.h +++ b/include/matx/operators/base_operator.h @@ -203,7 +203,14 @@ namespace matx MATX_THROW(matxInvalidParameter, "Possible aliased memory detected: LHS and RHS memory ranges overlap"); } - if (tp->get_lhs().IsContiguous() && tp->get_rhs().IsContiguous() && tp->get_lhs().Rank() == tp->get_rhs().Rank()) { + using lhs_value_type = remove_cvref_t; + using rhs_value_type = remove_cvref_t; + constexpr bool same_value_type = std::is_same_v; + + if (same_value_type && + tp->get_lhs().IsContiguous() && + tp->get_rhs().IsContiguous() && + tp->get_lhs().Rank() == tp->get_rhs().Rank()) { MATX_ASSERT_STR(tp->get_lhs().Bytes() >= tp->get_rhs().Bytes(), matxInvalidSize, "LHS tensor is smaller than RHS tensor in assignment"); MATX_LOG_TRACE("Copying {} bytes from {} to {} using cudaMemcpyAsync", tp->get_lhs().Bytes(), reinterpret_cast(tp->get_rhs().Data()), reinterpret_cast(tp->get_lhs().Data())); diff --git a/include/matx/operators/fft.h b/include/matx/operators/fft.h index c3063abfa..ba04e693f 100644 --- a/include/matx/operators/fft.h +++ b/include/matx/operators/fft.h @@ -446,11 +446,26 @@ namespace matx } } else { - if constexpr (Direction == detail::FFTDirection::FORWARD) { - fft_impl(permute(cuda::std::get<0>(out), perm_), permute(a_, perm_), fft_size_, norm_, ex); + bool perm_is_identity = true; + for (int32_t i = 0; i < Rank(); i++) { + if (perm_[static_cast(i)] != i) { + perm_is_identity = false; + break; + } } - else { - ifft_impl(permute(cuda::std::get<0>(out), perm_), permute(a_, perm_), fft_size_, norm_, ex); + auto &tout = cuda::std::get<0>(out); + if (perm_is_identity) { + if constexpr (Direction == detail::FFTDirection::FORWARD) { + fft_impl(tout, a_, fft_size_, norm_, ex); + } else { + ifft_impl(tout, a_, fft_size_, norm_, ex); + } + } else { + if constexpr (Direction == detail::FFTDirection::FORWARD) { + fft_impl(permute(tout, perm_), permute(a_, perm_), fft_size_, norm_, ex); + } else { + ifft_impl(permute(tout, perm_), permute(a_, perm_), fft_size_, norm_, ex); + } } } } @@ -788,11 +803,26 @@ namespace matx } } else { - if constexpr (Direction == detail::FFTDirection::FORWARD) { - fft2_impl(permute(cuda::std::get<0>(out), perm_), permute(a_, perm_), norm_, ex); + bool perm_is_identity = true; + for (int32_t i = 0; i < Rank(); i++) { + if (perm_[static_cast(i)] != i) { + perm_is_identity = false; + break; + } } - else { - ifft2_impl(permute(cuda::std::get<0>(out), perm_), permute(a_, perm_), norm_, ex); + auto &tout = cuda::std::get<0>(out); + if (perm_is_identity) { + if constexpr (Direction == detail::FFTDirection::FORWARD) { + fft2_impl(tout, a_, norm_, ex); + } else { + ifft2_impl(tout, a_, norm_, ex); + } + } else { + if constexpr (Direction == detail::FFTDirection::FORWARD) { + fft2_impl(permute(tout, perm_), permute(a_, perm_), norm_, ex); + } else { + ifft2_impl(permute(tout, perm_), permute(a_, perm_), norm_, ex); + } } } } diff --git a/include/matx/operators/interleaved.h b/include/matx/operators/interleaved.h index ba475289a..72a2e3929 100644 --- a/include/matx/operators/interleaved.h +++ b/include/matx/operators/interleaved.h @@ -35,6 +35,7 @@ #include "matx/core/type_utils.h" #include "matx/operators/base_operator.h" +#include "matx/operators/planar.h" namespace matx { @@ -110,6 +111,8 @@ namespace matx __MATX_INLINE__ std::string str() const { return "interleaved(" + op_.str() + ")"; } + __MATX_INLINE__ auto InnerOp() const { return op_; } + __MATX_INLINE__ ComplexInterleavedOp(const T1 &op) : op_(op) { MATX_LOG_TRACE("{} constructor: rank={}", str(), Rank()); static_assert(!is_complex_v>, "Complex interleaved op only works on scalar input types"); @@ -244,4 +247,17 @@ namespace matx static_assert(!is_complex_v>, "Input to interleaved operator must be real-valued"); return detail::ComplexInterleavedOp(t); } + + template + auto interleaved(const detail::ComplexPlanarOp &t) + { + return t.InnerOp(); + } + + template + auto planar(const detail::ComplexInterleavedOp &t) + { + return t.InnerOp(); + } + } // end namespace matx diff --git a/include/matx/operators/matmul.h b/include/matx/operators/matmul.h index 6c20ac126..ac73d3f2a 100644 --- a/include/matx/operators/matmul.h +++ b/include/matx/operators/matmul.h @@ -355,7 +355,19 @@ namespace matx } } else if constexpr (!std::is_same_v) { - matmul_impl(permute(cuda::std::get<0>(out), perm_), a_, b_, ex, alpha_, beta_); + bool perm_is_identity = true; + for (int32_t i = 0; i < Rank(); i++) { + if (perm_[static_cast(i)] != i) { + perm_is_identity = false; + break; + } + } + auto &tout = cuda::std::get<0>(out); + if (perm_is_identity) { + matmul_impl(tout, a_, b_, ex, alpha_, beta_); + } else { + matmul_impl(permute(tout, perm_), a_, b_, ex, alpha_, beta_); + } } else { matmul_impl(cuda::std::get<0>(out), a_, b_, ex, alpha_, beta_); diff --git a/include/matx/operators/planar.h b/include/matx/operators/planar.h index cb3e19325..e7515b4fd 100644 --- a/include/matx/operators/planar.h +++ b/include/matx/operators/planar.h @@ -107,6 +107,8 @@ namespace matx __MATX_INLINE__ std::string str() const { return "planar(" + op_.str() + ")"; } + __MATX_INLINE__ auto InnerOp() const { return op_; } + __MATX_INLINE__ ComplexPlanarOp(const T1 &op) : op_(op) { static_assert(is_complex_v>, "Complex planar op only works on complex types"); static_assert(Rank() > 0); diff --git a/include/matx/operators/reshape.h b/include/matx/operators/reshape.h index f9bb25092..40dfb0d8e 100644 --- a/include/matx/operators/reshape.h +++ b/include/matx/operators/reshape.h @@ -32,12 +32,23 @@ #pragma once +#include #include "matx/core/type_utils.h" #include "matx/operators/base_operator.h" namespace matx { + namespace detail { + template + struct is_initializer_list : std::false_type {}; + + template + struct is_initializer_list> : std::true_type {}; + + template + inline constexpr bool is_initializer_list_v = is_initializer_list::value; + } // namespace detail /** * logically reshapes dimensions of a tensor/operator * TotalSize for reshape and input operator must match @@ -239,6 +250,10 @@ MATX_LOOP_UNROLL else if constexpr (Cap == OperatorCapability::DYN_SHM_SIZE) { return detail::get_operator_capability(op_, in); } + else if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) { + const auto my_cap = cuda::std::array{ElementsPerThread::ONE, ElementsPerThread::ONE}; + return combine_capabilities(my_cap, detail::get_operator_capability(op_, in)); + } else { auto self_has_cap = capability_attributes::default_value; return combine_capabilities(self_has_cap, detail::get_operator_capability(op_, in)); @@ -292,7 +307,8 @@ MATX_LOOP_UNROLL * @return reshaped operator */ template - requires (!cuda::std::is_array_v>) + requires (!cuda::std::is_array_v> && + !detail::is_initializer_list_v>) __MATX_INLINE__ auto reshape(const T &op, ShapeType &&s) { return detail::ReshapeOp(op, std::forward(s)); diff --git a/include/matx/operators/set.h b/include/matx/operators/set.h index 43feb81f1..c1862ba41 100644 --- a/include/matx/operators/set.h +++ b/include/matx/operators/set.h @@ -114,9 +114,17 @@ class set : public BaseOp> { // functions, so we have to make a separate one. template __MATX_DEVICE__ __MATX_HOST__ inline auto _internal_mapply(Ts&&... args) const noexcept { - const auto r = detail::get_value(op_, args...); - out_(args...) = r; - return r; + if constexpr (is_planar_complex_v) { + const auto r = detail::get_value( + static_cast(op_), args...); + out_.template operator()(args...) = r; + return r; + } + else { + const auto r = detail::get_value(static_cast(op_), args...); + out_.template operator()(args...) = r; + return r; + } } template @@ -133,17 +141,25 @@ class set : public BaseOp> { template __MATX_DEVICE__ __MATX_HOST__ inline decltype(auto) operator()(Is... indices) const noexcept { - const auto in_val = detail::get_value(op_, indices...); - using out_type = decltype(out_.template operator()(indices...)); - - if constexpr (!is_vector_v && is_vector_v) { - Vector, static_cast(CapType::ept)> vec{in_val}; - out_.template operator()(indices...) = vec; + if constexpr (is_planar_complex_v) { + const auto in_val = detail::get_value( + static_cast(op_), indices...); + out_.template operator()(indices...) = in_val; + return in_val; } else { - out_.template operator()(indices...) = in_val; + const auto in_val = detail::get_value(static_cast(op_), indices...); + using out_type = decltype(out_.template operator()(indices...)); + + if constexpr (!is_vector_v && is_vector_v) { + Vector, static_cast(CapType::ept)> vec{in_val}; + out_.template operator()(indices...) = vec; + } + else { + out_.template operator()(indices...) = in_val; + } + return in_val; } - return in_val; } #ifdef MATX_EN_JIT @@ -271,7 +287,18 @@ class set : public BaseOp> { #else return false; #endif - } + } + else if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) { + // Only force scalar EPT for planar-complex outputs. Non-planar SetOp + // should retain normal vectorization negotiation with operands. + const auto my_cap = is_planar_complex_v + ? cuda::std::array{ElementsPerThread::ONE, + ElementsPerThread::ONE} + : capability_attributes::default_value; + return combine_capabilities( + my_cap, detail::get_operator_capability(out_, in), + detail::get_operator_capability(op_, in)); + } else { auto self_has_cap = capability_attributes::default_value; diff --git a/include/matx/transforms/convert/sparse2dense_cusparse.h b/include/matx/transforms/convert/sparse2dense_cusparse.h index 30a13679f..1d1d7b7ea 100644 --- a/include/matx/transforms/convert/sparse2dense_cusparse.h +++ b/include/matx/transforms/convert/sparse2dense_cusparse.h @@ -249,7 +249,20 @@ void sparse2dense_impl(OutputTensorType &O, const InputTensorType &a, std::is_same_v> || std::is_same_v>, "unsupported data type"); - MATX_ASSERT(o.Stride(RANKO - 1) == 1, matxInvalidParameter); + + // cuSPARSE sparse-to-dense conversion requires unit innermost stride. + // For transformed/view outputs (e.g., transpose(output)), materialize + // into a contiguous temporary and copy back. + if (o.Stride(RANKO - 1) != 1) { + auto o_contig = + make_tensor(o.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream); + sparse2dense_impl(o_contig, a, exec); + (o = o_contig).run(stream); + if (!o.isSameView(O)) { + (O = o).run(stream); + } + return; + } // Get parameters required by these tensors (for caching). auto params = diff --git a/include/matx/transforms/fft/fft_cuda.h b/include/matx/transforms/fft/fft_cuda.h index 4b31dcc18..efd996654 100644 --- a/include/matx/transforms/fft/fft_cuda.h +++ b/include/matx/transforms/fft/fft_cuda.h @@ -236,7 +236,7 @@ template class matxCUDAFFTPlan_t params.idist = (RANK == 1) ? 1 : i.Stride(RANK - 2); params.odist = (RANK == 1) ? 1 : o.Stride(RANK - 2); - if constexpr (is_complex_half_v || is_complex_half_v) { + if constexpr (is_complex_half_v || is_complex_half_v) { if ((params.n[0] & (params.n[0] - 1)) != 0) { MATX_THROW(matxInvalidDim, "Half precision only supports power of two transforms"); diff --git a/include/matx/transforms/fft/fft_fftw.h b/include/matx/transforms/fft/fft_fftw.h index 48b4a2854..8987cde48 100644 --- a/include/matx/transforms/fft/fft_fftw.h +++ b/include/matx/transforms/fft/fft_fftw.h @@ -76,6 +76,7 @@ struct FftFFTWParams_t { bool is_fp32; bool in_place; detail::FFTDirection dir; + int num_threads = 1; }; template @@ -227,7 +228,8 @@ struct FftFFTWParamsKeyHash { (std::hash()(k.fft_rank)) + (std::hash()(k.batch)) + (std::hash()(k.istride)) + (std::hash()(static_cast(k.dir))) + - (std::hash()(static_cast(k.is_fp32))); + (std::hash()(static_cast(k.is_fp32))) + + (std::hash()(static_cast(k.num_threads))); } }; @@ -240,6 +242,7 @@ struct FftFFTWParamsKeyEq { return l.n[0] == t.n[0] && l.n[1] == t.n[1] && l.batch == t.batch && l.dir == t.dir && l.fft_rank == t.fft_rank && l.is_fp32 == t.is_fp32 && l.in_place == t.in_place && + l.num_threads == t.num_threads && l.inembed[0] == t.inembed[0] && l.inembed[1] == t.inembed[1] && l.onembed[0] == t.onembed[0] && l.onembed[1] == t.onembed[1] && l.istride == t.istride && l.ostride == t.ostride && @@ -599,6 +602,7 @@ template class matxFFTWPlan_t { // Get parameters required by these tensors auto params = GetFFTParams(out, in, 1, dir); + params.num_threads = exec.GetNumThreads(); fft_exec(out, in, params, dir, exec); @@ -651,6 +655,7 @@ template class matxFFTWPlan_t { // Get parameters required by these tensors auto params = GetFFTParams(out, in, 2, dir); + params.num_threads = exec.GetNumThreads(); fft_exec(out, in, params, dir, exec); diff --git a/include/matx/transforms/matmul/matmul_cuda.h b/include/matx/transforms/matmul/matmul_cuda.h index 710600585..a9e2b1390 100644 --- a/include/matx/transforms/matmul/matmul_cuda.h +++ b/include/matx/transforms/matmul/matmul_cuda.h @@ -94,7 +94,9 @@ constexpr bool CompatibleGemmCUDATypes() { std::is_same_v> || std::is_same_v || std::is_same_v || - std::is_same_v; + std::is_same_v || + std::is_same_v || + std::is_same_v; } // Accumulator type different from A/B @@ -138,6 +140,9 @@ struct MatMulCUDAParams_t { MatXDataType_t dtype; cublasOperation_t opA; cublasOperation_t opB; + bool a_planar = false; + bool b_planar = false; + bool c_planar = false; }; template (); params.prov = PROV; params.rank = c.Rank(); + params.a_planar = is_planar_complex_v; + params.b_planar = is_planar_complex_v; + params.c_planar = is_planar_complex_v; // Batches params.batch = 1; @@ -409,19 +417,25 @@ class MatMulCUDAHandle_t { } // for complex half we have copied to planar row-major - if (is_complex_half_v) { - params.ldb = b.Size(TensorTypeB::Rank()-1); + if constexpr (is_complex_half_v) { + params.lda = a.Size(TensorTypeA::Rank()-1); } - // for complex half we have copied to planar row-major if constexpr (is_complex_half_v) { - params.lda = a.Size(TensorTypeA::Rank()-1); + params.ldb = b.Size(TensorTypeB::Rank()-1); } params.c_rows = params.a_rows; params.c_cols = params.b_cols; params.ldc = c.Stride(RANK - 2); + // For complex half paths we launch as planar row-major. Use compact + // row-major leading dimension so planar C metadata matches what cuBLAS + // expects, even when the original tensor type is planar. + if constexpr (is_complex_half_v) { + params.ldc = c.Size(RANK - 1); + } + } else if constexpr (PROV == PROVIDER_TYPE_CUTLASS) { params.opA = CUBLAS_OP_N; @@ -770,37 +784,58 @@ class MatMulCUDAHandle_t { // If the tensors are complex half precision, we need to do a planar // transform since all libraries expect this format at the moment. if constexpr (is_complex_half_v) { + constexpr bool a_is_planar = is_planar_complex_v; + constexpr bool b_is_planar = is_planar_complex_v; + constexpr bool c_is_planar = is_planar_complex_v; + + if (!a_is_planar) { + auto a_shape = a.Shape(); + *(a_shape.begin() + a.Rank() - 2) = a.Size(a.Rank() - 2) * 2; + if (a_hp == nullptr) { + matxAlloc(&a_hp, a.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream); + } + auto a_planar = make_tensor( + reinterpret_cast(a_hp), a_shape, false); - auto a_shape = a.Shape(); - *(a_shape.begin() + a.Rank() - 2) = a.Size(a.Rank() - 2) * 2; - if (a_hp == nullptr) { - matxAlloc(&a_hp, a.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream); - } - auto a_planar = make_tensor(reinterpret_cast(a_hp), a_shape, false); + // Convert A to planar layout + (a_planar = planar(a)).run(stream); - auto b_shape = b.Shape(); - *(b_shape.begin() + b.Rank() - 2) = b.Size(b.Rank() - 2) * 2; - if (b_hp == nullptr) { - matxAlloc(&b_hp, b.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream); + // update pointers to planar data. + // must use Reset because types for planar are different + a_adj.Reset(reinterpret_cast(a_planar.Data())); } - auto b_planar = make_tensor(reinterpret_cast(b_hp), b_shape, false); - auto c_shape = c.Shape(); - *(c_shape.begin() + c.Rank() - 2) = c.Size(c.Rank() - 2) * 2; - if (c_hp == nullptr) { - matxAlloc(&c_hp, c.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream); - } - auto c_planar = make_tensor(reinterpret_cast(c_hp), c_shape, false); + if (!b_is_planar) { + auto b_shape = b.Shape(); + *(b_shape.begin() + b.Rank() - 2) = b.Size(b.Rank() - 2) * 2; + if (b_hp == nullptr) { + matxAlloc(&b_hp, b.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream); + } + auto b_planar = make_tensor( + reinterpret_cast(b_hp), b_shape, false); + + // Convert B to planar layout + (b_planar = planar(b)).run(stream); - // Convert A/B to planar layout - (a_planar = planar(a)).run(stream); - (b_planar = planar(b)).run(stream); + // update pointers to planar data. + // must use Reset because types for planar are different + b_adj.Reset(reinterpret_cast(b_planar.Data())); + } - // update pointers to planar data. - // must use Reset because types for planar are different - a_adj.Reset(reinterpret_cast(a_planar.Data())); - b_adj.Reset(reinterpret_cast(b_planar.Data())); - c_adj.Reset(reinterpret_cast(c_planar.Data())); + if (!c_is_planar) { + auto c_shape = c.Shape(); + *(c_shape.begin() + c.Rank() - 2) = c.Size(c.Rank() - 2) * 2; + if (c_hp == nullptr) { + matxAlloc(&c_hp, c.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream); + } + auto c_planar = make_tensor( + reinterpret_cast(c_hp), c_shape, false); + c_adj.Reset(reinterpret_cast(c_planar.Data())); + } + else { + // Keep C adjustment explicit for the planar-output path. + c_adj.Reset(c.Data()); + } } // Prep for batch looping @@ -999,13 +1034,15 @@ class MatMulCUDAHandle_t { // If the tensors are complex half precisions, we need to convert C back to // interleaved format and free all temporary buffers if constexpr (is_complex_half_v) { - auto c_shape = c.Shape(); - *(c_shape.begin() + c.Rank() - 2) = c.Size(c.Rank() - 2) * 2; - auto c_planar = make_tensor( - reinterpret_cast(c_adj.Data()), c_shape); - - // Convert A/B to planar layout - (c = interleaved(c_planar)).run(stream); + constexpr bool c_is_planar = is_planar_complex_v; + if (!c_is_planar) { + auto c_shape = c.Shape(); + *(c_shape.begin() + c.Rank() - 2) = c.Size(c.Rank() - 2) * 2; + auto c_planar = make_tensor( + reinterpret_cast(c_adj.Data()), c_shape); + + (c = interleaved(c_planar)).run(stream); + } } } @@ -1089,6 +1126,9 @@ struct MatMulCUDAParamsKeyHash { return std::hash()(k.m) + std::hash()(k.n) + std::hash()(k.k) + std::hash()(k.batch) + std::hash()(k.prov) + + std::hash()(static_cast(k.a_planar)) + + std::hash()(static_cast(k.b_planar)) + + std::hash()(static_cast(k.c_planar)) + std::hash()((size_t)k.stream); } }; @@ -1109,7 +1149,10 @@ struct MatMulCUDAParamsKeyEq { l.stream == t.stream && l.lda == t.lda && l.ldb == t.ldb && l.ldc == t.ldc && l.batch == t.batch && l.prov == t.prov && l.dtype == t.dtype && l.opA == t.opA && - l.opB == t.opB && l.rank == t.rank; + l.opB == t.opB && l.rank == t.rank && + l.a_planar == t.a_planar && + l.b_planar == t.b_planar && + l.c_planar == t.c_planar; } }; diff --git a/include/matx/transforms/matmul/matmul_cusparse.h b/include/matx/transforms/matmul/matmul_cusparse.h index 83ca78240..5d2adcf3f 100644 --- a/include/matx/transforms/matmul/matmul_cusparse.h +++ b/include/matx/transforms/matmul/matmul_cusparse.h @@ -310,8 +310,25 @@ void sparse_matmul_impl(TensorTypeC &C, const TensorTypeA &a, c.Size(RANKC - 1) == b.Size(RANKB - 1) && c.Size(RANKC - 2) == a.Size(RANKA - 2), matxInvalidSize); - MATX_ASSERT(b.Stride(RANKB - 1) == 1 && c.Stride(RANKC - 1) == 1, - matxInvalidParameter); + + // cuSPARSE SpMM requires unit innermost stride for dense inputs/outputs. + // If either view is not compatible (e.g., transformed/permuted output), + // materialize contiguous temporaries, execute, then copy back. + if (b.Stride(RANKB - 1) != 1 || c.Stride(RANKC - 1) != 1) { + auto b_contig = + make_tensor(b.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream); + auto c_contig = + make_tensor(c.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream); + (b_contig = b).run(stream); + + sparse_matmul_impl(c_contig, a, b_contig, exec, alpha, beta); + + (c = c_contig).run(stream); + if (!c.isSameView(C)) { + (C = c).run(stream); + } + return; + } // Get parameters required by these tensors (for caching). auto params = diff --git a/test/00_operators/planar_test.cu b/test/00_operators/planar_test.cu index c764f0480..db88fa73b 100644 --- a/test/00_operators/planar_test.cu +++ b/test/00_operators/planar_test.cu @@ -2,6 +2,7 @@ #include "matx.h" #include "test_types.h" #include "utilities.h" +#include using namespace matx; using namespace matx::test; @@ -37,3 +38,44 @@ TYPED_TEST(OperatorTestsComplexTypesAllExecs, PlanarTransform) } MATX_EXIT_HANDLER(); } + +namespace { +template +void ValidatePlanarTensorOperatorLayout() +{ + constexpr index_t m = 3; + constexpr index_t k = 5; + constexpr index_t mk = m * k; + using ScalarType = typename PlanarType::value_type; + + tensor_t t{{m, k}}; + std::vector raw(static_cast(mk * 2)); + + for (index_t i = 0; i < mk; i++) { + raw[static_cast(i)] = ScalarType{static_cast(i + 1)}; + raw[static_cast(i + mk)] = ScalarType{static_cast(-(i + 1))}; + } + + ASSERT_EQ(cudaMemcpy(t.Data(), raw.data(), raw.size() * sizeof(ScalarType), + cudaMemcpyHostToDevice), + cudaSuccess); + + for (index_t i = 0; i < m; i++) { + for (index_t j = 0; j < k; j++) { + const index_t idx = i * k + j; + const auto v = t(i, j); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(v.real(), raw[static_cast(idx)])); + EXPECT_TRUE( + MatXUtils::MatXTypeCompare(v.imag(), raw[static_cast(idx + mk)])); + } + } +} +} // namespace + +TEST(PlanarHalfComplexTypes, TensorOperatorUsesPlanarStorageForFp16AndBf16) +{ + MATX_ENTER_HANDLER(); + ValidatePlanarTensorOperatorLayout(); + ValidatePlanarTensorOperatorLayout(); + MATX_EXIT_HANDLER(); +} diff --git a/test/00_transform/FFT.cu b/test/00_transform/FFT.cu index b435c93d4..796218d14 100644 --- a/test/00_transform/FFT.cu +++ b/test/00_transform/FFT.cu @@ -294,7 +294,6 @@ TYPED_TEST(FFTTestComplexNonHalfTypesAllExecs, FFT2Axis) { MATX_ENTER_HANDLER(); using TestType = cuda::std::tuple_element_t<0, TypeParam>; -; const int d1 = 8; const int d2 = 16; const int d3 = 32; diff --git a/test/00_transform/MatMul.cu b/test/00_transform/MatMul.cu index de008f120..286c47cdd 100644 --- a/test/00_transform/MatMul.cu +++ b/test/00_transform/MatMul.cu @@ -35,6 +35,7 @@ #include "test_types.h" #include "utilities.h" #include "gtest/gtest.h" +#include using namespace matx; @@ -81,10 +82,14 @@ class MatMulTestFloatNonHalfTypes : public MatMulTest { template class MatMulTestFloatNonComplexTypes : public MatMulTest { }; +template +class MatMulTestComplexHalfPlanarTypes : public MatMulTest { +}; TYPED_TEST_SUITE(MatMulTestFloatTypes, MatXTypesFloatAllExecs); TYPED_TEST_SUITE(MatMulTestFloatNonHalfTypes, MatXFloatNonHalfTypesAllExecs); TYPED_TEST_SUITE(MatMulTestFloatNonComplexTypes, MatXTypesFloatNonComplexAllExecs); +TYPED_TEST_SUITE(MatMulTestComplexHalfPlanarTypes, MatXComplexHalfPlanarTypesAllExecs); template struct float_to_complex @@ -107,6 +112,22 @@ struct float_to_complex template using float_to_complex_t = typename float_to_complex::type; +template +struct planar_to_interleaved; + +template <> +struct planar_to_interleaved { + using type = matxFp16Complex; +}; + +template <> +struct planar_to_interleaved { + using type = matxBf16Complex; +}; + +template +using planar_to_interleaved_t = typename planar_to_interleaved::type; + TYPED_TEST(MatMulTestFloatTypes, SmallRect) { MATX_ENTER_HANDLER(); @@ -549,6 +570,138 @@ TYPED_TEST(MatMulTestFloatNonComplexTypes, MixedTypes) MATX_EXIT_HANDLER(); } +TYPED_TEST(MatMulTestComplexHalfPlanarTypes, ComplexHalfPlanarLayoutAnnotation) +{ + MATX_ENTER_HANDLER(); + using PlanarTestType = cuda::std::tuple_element_t<0, TypeParam>; + using TestType = planar_to_interleaved_t; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + if constexpr (!detail::CheckMatMulSupport()) { + GTEST_SKIP(); + } + else { + constexpr index_t m = 4; + constexpr index_t k = 8; + constexpr index_t n = 16; + + tensor_t a{{m, k}}; + tensor_t b{{k, n}}; + tensor_t c_ref{{m, n}}; + tensor_t c_from_planar{{m, n}}; + + this->pb->template InitAndRunTVGenerator( + "00_transforms", "matmul_operators", "run", {m, k, n}); + this->pb->NumpyToTensorView(a, "a"); + this->pb->NumpyToTensorView(b, "b"); + + // Interleaved reference path. + (c_ref = matmul(a, b)).run(this->exec); + + // Materialize planar typed inputs/outputs into temporary tensors before GEMM. + tensor_t a_planar_tmp{{m, k}}; + tensor_t b_planar_tmp{{k, n}}; + tensor_t c_planar_tmp{{m, n}}; + + EXPECT_TRUE(is_planar_complex_v); + EXPECT_TRUE(is_planar_complex_v); + EXPECT_TRUE(is_planar_complex_v); + + // Write interleaved tensors into planar-typed temporaries. + (a_planar_tmp = a).run(this->exec); + (b_planar_tmp = b).run(this->exec); + + // Planar path (using planar-typed temporary tensors as matmul inputs/outputs). + (c_planar_tmp = matmul(a_planar_tmp, b_planar_tmp)).run(this->exec); + (c_from_planar = c_planar_tmp).run(this->exec); + + this->exec.sync(); + for (index_t i = 0; i < m; i++) { + for (index_t j = 0; j < n; j++) { + EXPECT_TRUE(MatXUtils::MatXTypeCompare(c_ref(i, j), c_from_planar(i, j), this->thresh)); + } + } + } + MATX_EXIT_HANDLER(); +} + +TYPED_TEST(MatMulTestComplexHalfPlanarTypes, ComplexHalfPlanarRawStorageMatchesInterleavedReference) +{ + MATX_ENTER_HANDLER(); + using PlanarTestType = cuda::std::tuple_element_t<0, TypeParam>; + using TestType = planar_to_interleaved_t; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + if constexpr (!detail::CheckMatMulSupport()) { + GTEST_SKIP(); + } + else { + constexpr index_t m = 7; + constexpr index_t k = 9; + constexpr index_t n = 5; + constexpr index_t a_elems = m * k; + constexpr index_t b_elems = k * n; + + using ScalarType = typename PlanarTestType::value_type; + + tensor_t a_planar{{m, k}}; + tensor_t b_planar{{k, n}}; + tensor_t c_planar{{m, n}}; + + tensor_t a_interleaved{{m, k}}; + tensor_t b_interleaved{{k, n}}; + tensor_t c_interleaved{{m, n}}; + tensor_t c_from_planar{{m, n}}; + + std::vector a_raw(static_cast(a_elems * 2)); + std::vector b_raw(static_cast(b_elems * 2)); + + for (index_t i = 0; i < a_elems; i++) { + a_raw[static_cast(i)] = ScalarType{static_cast((i % 13) - 6)}; + a_raw[static_cast(i + a_elems)] = + ScalarType{static_cast((i % 9) - 4)}; + } + + for (index_t i = 0; i < b_elems; i++) { + b_raw[static_cast(i)] = ScalarType{static_cast((i % 11) - 5)}; + b_raw[static_cast(i + b_elems)] = + ScalarType{static_cast(3 - (i % 7))}; + } + + ASSERT_EQ(cudaMemcpy(a_planar.Data(), a_raw.data(), + a_raw.size() * sizeof(ScalarType), cudaMemcpyHostToDevice), + cudaSuccess); + ASSERT_EQ(cudaMemcpy(b_planar.Data(), b_raw.data(), + b_raw.size() * sizeof(ScalarType), cudaMemcpyHostToDevice), + cudaSuccess); + + EXPECT_TRUE(is_planar_complex_v); + EXPECT_TRUE(is_planar_complex_v); + EXPECT_TRUE(is_planar_complex_v); + + // Build interleaved references from the same logical planar tensors. + (a_interleaved = a_planar).run(this->exec); + (b_interleaved = b_planar).run(this->exec); + + // Interleaved path should still perform the required conversion. + (c_interleaved = matmul(a_interleaved, b_interleaved)).run(this->exec); + + // Planar typed tensors should be consumed directly by matmul. + (c_planar = matmul(a_planar, b_planar)).run(this->exec); + (c_from_planar = c_planar).run(this->exec); + + this->exec.sync(); + + for (index_t i = 0; i < m; i++) { + for (index_t j = 0; j < n; j++) { + EXPECT_TRUE(MatXUtils::MatXTypeCompare(c_interleaved(i, j), + c_from_planar(i, j), + this->thresh)) + << "Mismatch at (" << i << ", " << j << ")"; + } + } + } + MATX_EXIT_HANDLER(); +} + TYPED_TEST(MatMulTestFloatTypes, MediumRectBatched4D) { MATX_ENTER_HANDLER(); diff --git a/test/include/test_types.h b/test/include/test_types.h index 455663565..e59800b68 100644 --- a/test/include/test_types.h +++ b/test/include/test_types.h @@ -143,6 +143,7 @@ using MatXFloatNonComplexTuple = cuda::std::tuple; using MatXBooleanTuple = cuda::std::tuple; using MatXDoubleOnlyTuple = cuda::std::tuple; +using MatXComplexHalfPlanarTuple = cuda::std::tuple; // CUDA-only types using MatXAllTypesCUDAExec = TupleToTypes::type>::type; @@ -176,5 +177,6 @@ using MatXTypesNumericAllExecs = TupleToTypes::type>::type; using MatXTypesBooleanAllExecs = TupleToTypes::type>::type; using MatXTypesCastToFloatAllExecs = TupleToTypes::type>::type; +using MatXComplexHalfPlanarTypesAllExecs = TupleToTypes::type>::type; using MatXTypesFloatNonComplexSingleThreadedHostAllExecs = TupleToTypes::type>::type;