Custom Gather-Scatter Operator with CUTLASS
This post logs my experience building an efficient custom operator based on CUTLASS. Jump to the final implementation of the gather-and-scatter matrix multiplication operator.
Intro
Implementing an efficient CUDA kernel is challenging: it requires a thorough understanding of the GPU architecture and a lot of design time. CUTLASS provides a collection of abstractions for GEMM-based operations. It exploits the GPU's memory hierarchy and swizzles data to maximize memory bandwidth. Using CUTLASS, we can easily build high-performance operators.
Simple Example
The developers provide several examples in cutlass/examples/python at main · NVIDIA/cutlass · GitHub, which show how to generate your own GEMM operator with a custom data type and layout. For example,
import cutlass
import torch

dtype = torch.float16
# Build a GEMM plan with fp16 operands in column-major layout and fp32 accumulation
plan = cutlass.Gemm(element=dtype, layout=cutlass.LayoutType.ColumnMajor, element_accumulator=torch.float32)
op = plan.construct()
# Emit PyTorch extension sources (CUDA, PyBind11, setup.py) under gemm_out/
gemm_module = cutlass.emit.pytorch(op, name='gemm', cc=plan.cc, sourcedir='gemm_out')
Then you will find the CUDA, PyBind11, and setup code under the gemm_out/ directory. To build and install it, run
cd gemm_out
pip install -e . # install in the current env
Now you can directly import your operator in your code, for example,
import torch
import gemm  # import your op

dtype = torch.float16
torch.manual_seed(2024)

# Utility function to initialize the A, B, and C matrices for dimensions M, N, and K
def initialize(dtype, M, N, K):
    sizes = [(M, K), (K, N), (M, N)]
    return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]
A, B, C = initialize(dtype, 128, 128, 128)
cutlass_out = gemm.run(A, B, C)
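As a quick sanity check (my own addition, not from the original example), we can compare against PyTorch. I am assuming the generated module exposes run(A, B, C=None, alpha=1.0, beta=0.0) like the gs_gemm wrapper shown later, so C is ignored with the default beta=0. Also note that the operands were declared column-major while PyTorch tensors are row-major, so for these square matrices the returned tensor corresponds to B @ A rather than A @ B (the same transposition shows up again in the experiment section below).
# Hypothetical sanity check: with the column-major declaration and the default beta = 0,
# the row-major view of the result for square matrices corresponds to B @ A.
torch_ref = (B.float() @ A.float()).to(dtype)
print(torch.allclose(cutlass_out, torch_ref))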
This is cool, but I am going to show you that cutlass.Gemm is not a full wrapper of the underlying CUDA operator.
Motivation
I was building a gather-and-scatter matrix multiplication operator several months ago. Although we chose Triton for easy kernel fusion, its performance was still not as good as the CUTLASS version provided in cutlass/examples/36_gather_scatter_fusion at main · NVIDIA/cutlass · GitHub.
Unfortunately, the template used by the Python emit function didn't include the necessary options GatherA/B and ScatterD, or the input indices. Thus, my plan was to generate a normal GEMM kernel and modify the generated C++ code directly instead of changing the CUTLASS source code.
Checking the source code in python/cutlass/backend/gemm_operation.py, the C++ code template looks like this:
// Gemm operator ${operation_name}
using ${operation_name}_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
${element_c}, ${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor},
${swizzling_functor},
${stages},
${math_operation}
>::GemmKernel;
Compared with the full DefaultGemmUniversal, it doesn't include the last six template parameters:
/// Gather operand A by using an index array
bool GatherA = false,
/// Gather operand B by using an index array
bool GatherB = false,
/// Scatter result D by using an index array
bool ScatterD = false,
/// Permute result D
typename PermuteDLayout = layout::NoPermute,
/// Permute operand A
typename PermuteALayout_ = layout::NoPermute,
/// Permute operand B
typename PermuteBLayout_ = layout::NoPermute,
Because changing the CUTLASS source can be risky and complicated, I decided to modify the code generated by the CUTLASS emitter instead.
Implementing the Custom Operator
Our first step is to generate normal GEMM code with CUTLASS, just as we did in the simple example, but with the name changed to gs_gemm (gather-scatter GEMM):
import cutlass
import torch
dtype = torch.float16
plan = cutlass.Gemm(element=dtype, layout=cutlass.LayoutType.ColumnMajor, element_accumulator=torch.float32)
op = plan.construct()
gs_gemm = cutlass.emit.pytorch(op, name='gs_gemm', cc=plan.cc, sourcedir='gs_out')
After that, go to the generated folder, which looks like
.
├── gs_gemm.cpp
├── gs_gemm_kernel.cu
└── setup.py
We start with the .cpp file and change the interface of the operator.
// This file was automatically generated by the CUTLASS 3.5.0 Python interface (https://github.com/nvidia/cutlass/python)
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>
// CUDA forward declarations
// Indices is an int32 tensor with gather_size elements (only its contiguous data pointer is used)
at::Tensor gs_gemm_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, at::optional<const at::Tensor> Indices=at::nullopt, float alpha=1.f, float beta=0.f);
// C++ interface
at::Tensor gs_gemm(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, at::optional<const at::Tensor> Indices=at::nullopt, float alpha=1.f, float beta=0.f) {
return gs_gemm_kernel(A, B, C, Indices, alpha, beta);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("run", py::overload_cast<const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>, at::optional<const at::Tensor>, float, float>(&gs_gemm), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("Indices") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
}
I added a new argument, Indices, which specifies the columns to select from matrix B.
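On the Python side, Indices is just an int32 tensor on the GPU holding those column indices; for example (a hypothetical snippet, the full test appears in the experiment section):
# Hypothetical example of the Indices argument: int32 indices on the GPU, one entry
# per gathered column of B (the same indices are reused to scatter the output rows).
indices = torch.arange(12, device='cuda', dtype=torch.int32)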
Then let’s go to the kernel code. First we need to set GatherB
and ScatterD
to true
when specifying the operator,
#include "cutlass/gemm/device/gemm_universal.h"
// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_256x128_64x3_tt_align8
using DeviceKernel =
typename cutlass::gemm::device::GemmUniversal<
// Data type and layout of operand A
cutlass::half_t, cutlass::layout::ColumnMajor,
// Data type and layout of operand B
cutlass::half_t, cutlass::layout::ColumnMajor,
// Data type and layout of operand C
cutlass::half_t, cutlass::layout::ColumnMajor,
// Data type of accumulator
float,
// Class of operation
cutlass::arch::OpClassTensorOp,
// Compute capability of the target kernel
cutlass::arch::Sm80,
// Threadblock tile shape
cutlass::gemm::GemmShape<256, 128, 64>,
// Warp tile shape
cutlass::gemm::GemmShape<64, 64, 64>,
// Instruction shape
cutlass::gemm::GemmShape<16, 8, 16>,
// Epilogue functor
cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,
// Swizzling function
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
// Number of pipeline stages
3,
// Alignment of operands A and B
8, 8,
// Type of math operation
cutlass::arch::OpMultiplyAdd,
// Complex transform types of operands A and B
cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone,
+ false, /*GatherA*/
+ true, /*GatherB*/
+ true /*ScatterD*/
>;
Then, in the kernel launch function that follows, add the new argument and pass it into the constructor of the kernel's arguments:
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
cutlass::Status gs_gemm_kernel_run(int M, int N, int K,
const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C,
+ const int* Indices,
DeviceKernel::ElementC* D,
ElementCompute alpha, ElementCompute beta) {
typename DeviceKernel::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K}, // problem size
1, // split k dimension
{alpha, beta},
A, B, C, D,
0, 0, 0, 0, // batch strides
DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd
+ nullptr, // <- pointer to index vector to gather A on device
+ Indices, // <- pointer to index vector to gather B on device
+ Indices // <- pointer to index vector to scatter D on device
};
// keep the rest of the kernel launcher
...
return status;
}
Finally, we also need to add the Indices argument to the function that converts tensors to pointers:
at::Tensor gs_gemm_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C,
+ at::optional<const at::Tensor> Indices,
float alpha, float beta) {
int M = A.size(0);
int N = B.size(1);
int K = A.size(1);
typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
nullptr :
reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
+ int* ptrIndices = (Indices == at::nullopt) ?
+ nullptr :
+ reinterpret_cast<int*>(Indices->contiguous().data_ptr());
at::Tensor D = B.new_empty({M, N}, torch::kF16);
cutlass::Status status = gs_gemm_kernel_run(M, N, K,
reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
ptrC,
+ ptrIndices,
reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
ElementCompute(alpha), ElementCompute(beta));
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
return D;
}
Here I made Indices an optional input. However, during my tests this turned out to be a bad idea: the kernel faults when the indices pointer is nullptr, because we compiled it with GatherB set to true. I will update the code later.
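Given that this kernel instance is compiled with GatherB and ScatterD enabled, a simple guard in gs_gemm_kernel (my own suggestion, not part of the generated code) would at least fail loudly instead of crashing:
// Suggested guard (hypothetical addition): this kernel was compiled with GatherB/ScatterD
// enabled, so refuse to launch without a valid int32 index tensor.
TORCH_CHECK(Indices.has_value(), "gs_gemm was built with GatherB/ScatterD enabled; Indices must be provided");
TORCH_CHECK(Indices->scalar_type() == at::kInt, "Indices must be an int32 tensor");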
Experiment
Now let’s see if the new kernel works. Start from compiling the kernel by pip install -e .
, you will get
.
├── build
│ ├── lib.linux-x86_64-cpython-39
│ │ └── gs_gemm.cpython-39-x86_64-linux-gnu.so
│ └── temp.linux-x86_64-cpython-39
│ ├── build.ninja
│ ├── gs_gemm_kernel.o
│ └── gs_gemm.o
├── compiled_cache.db
├── gs_gemm.cpp
├── gs_gemm.cpython-39-x86_64-linux-gnu.so
├── gs_gemm.egg-info
│ ├── dependency_links.txt
│ ├── PKG-INFO
│ ├── SOURCES.txt
│ └── top_level.txt
├── gs_gemm_kernel.cu
└── setup.py
I wrote a test script to check whether the results are correct:
import torch
import gs_gemm  # the operator we just built

dtype = torch.float16
torch.manual_seed(2023)

# Utility function to initialize the A, B, and C matrices for dimensions M, N, and K
def initialize(dtype, M, N, K):
    sizes = [(M, K), (K, N), (M, N)]
    return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]
A, B, C = initialize(dtype, 128, 128, 128)
# select the first 12 columns
indices = torch.arange(12, device='cuda', dtype=torch.int32).reshape(1, -1)
cutlass_out = gs_gemm.run(A.clone(), B.clone(), None, indices)
torch.cuda.synchronize()
cutlass_out = cutlass_out.cpu()
print(cutlass_out)
The output looks like
tensor([[ 98., 22., 30., ..., 24., 43., 31.],
[ 67., 0., 42., ..., -27., -21., 23.],
[ 18., -15., 64., ..., 57., 39., 33.],
...,
[ 0., 0., 0., ..., 0., 0., 0.],
[ 0., 0., 0., ..., 0., 0., 0.],
[ 0., 0., 0., ..., 0., 0., 0.]], dtype=torch.float16)
You can verify the correctness of the first 12 rows by comparing against
torch_out = torch.zeros_like(C)
torch_out[indices[0]] += (A.T @ B[indices[0]].T).T
print(torch_out)
The output is the same.
Here I applied several transposes because the matrices declared in our operator are column-major for better coalescing during gathering; you can find a more detailed explanation in my previous post, Efficient Gather-and-scatter Matrix Multiplication Kernel with Triton - Xueshen Liu.
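As a small illustration (my own sketch, not taken from the kernel code): reading a row-major PyTorch buffer as a column-major matrix amounts to a reshape followed by a transpose, which for square matrices is simply the transpose.
# Reading a row-major (M, K) buffer as a column-major (M, K) matrix:
M, K = 2, 3
X = torch.arange(M * K).reshape(M, K)   # row-major (M, K)
X_cm = X.reshape(K, M).T                # the same buffer interpreted as column-major (M, K)
print(X)
print(X_cm)                             # when M == K, this is just X.T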
Final Implementation
gs_gemm.cpp
// This file was automatically generated by the CUTLASS 3.5.0 Python interface (https://github.com/nvidia/cutlass/python)
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>
// CUDA forward declarations
// Indices is an int32 tensor with gather_size elements (only its contiguous data pointer is used)
at::Tensor gs_gemm_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, at::optional<const at::Tensor> Indices=at::nullopt, float alpha=1.f, float beta=0.f);
// C++ interface
at::Tensor gs_gemm(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, at::optional<const at::Tensor> Indices=at::nullopt, float alpha=1.f, float beta=0.f) {
return gs_gemm_kernel(A, B, C, Indices, alpha, beta);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("run", py::overload_cast<const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>, at::optional<const at::Tensor>, float, float>(&gs_gemm), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("Indices") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
}
gs_gemm_kernel.cu
// This file was automatically generated by the CUTLASS 3.5.0 Python interface (https://github.com/nvidia/cutlass/python)
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"
// helper function allocating the memory
void* device_memory_allocation(size_t size, int device_id=0) {
if (size > 0) {
torch::Device device(torch::kCUDA, device_id);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
at::Tensor device_tensor = torch::empty({(long)size,}, options);
return reinterpret_cast<void*>(device_tensor.data_ptr());
} else {
return nullptr;
}
}
#include "cutlass/gemm/device/gemm_universal.h"
// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_256x128_64x3_tt_align8
using DeviceKernel =
typename cutlass::gemm::device::GemmUniversal<
// Data type and layout of operand A
cutlass::half_t, cutlass::layout::ColumnMajor,
// Data type and layout of operand B
cutlass::half_t, cutlass::layout::ColumnMajor,
// Data type and layout of operand C
cutlass::half_t, cutlass::layout::ColumnMajor,
// Data type of accumulator
float,
// Class of operation
cutlass::arch::OpClassTensorOp,
// Compute capability of the target kernel
cutlass::arch::Sm80,
// Threadblock tile shape
cutlass::gemm::GemmShape<256, 128, 64>,
// Warp tile shape
cutlass::gemm::GemmShape<64, 64, 64>,
// Instruction shape
cutlass::gemm::GemmShape<16, 8, 16>,
// Epilogue functor
cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,
// Swizzling function
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
// Number of pipeline stages
3,
// Alignment of operands A and B
8, 8,
// Type of math operation
cutlass::arch::OpMultiplyAdd,
// Complex transform types of operands A and B
cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone,
false, /*GatherA*/
true, /*GatherB*/
true /*ScatterD*/
>;
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
cutlass::Status gs_gemm_kernel_run(int M, int N, int K,
const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C,
const int* Indices,
DeviceKernel::ElementC* D,
ElementCompute alpha, ElementCompute beta) {
typename DeviceKernel::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K}, // problem size
1, // split k dimension
{alpha, beta},
A, B, C, D,
0, 0, 0, 0, // batch strides
DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd
nullptr, // <- pointer to index vector to gather A on device
Indices, // <- pointer to index vector to gather B on device
Indices // <- pointer to index vector to scatter D on device
};
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
DeviceKernel gemm_op;
cutlass::Status status = gemm_op.initialize(arguments,
workspace.get(),
nullptr); // CUDA stream
if (status != cutlass::Status::kSuccess) {
return status;
}
status = gemm_op();
return status;
}
at::Tensor gs_gemm_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C,
at::optional<const at::Tensor> Indices,
float alpha, float beta) {
int M = A.size(0);
int N = B.size(1);
int K = A.size(1);
typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
nullptr :
reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
int* ptrIndices = (Indices == at::nullopt) ?
nullptr :
reinterpret_cast<int*>(Indices->contiguous().data_ptr());
at::Tensor D = B.new_empty({M, N}, torch::kF16);
cutlass::Status status = gs_gemm_kernel_run(M, N, K,
reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
ptrC,
ptrIndices,
reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
ElementCompute(alpha), ElementCompute(beta));
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
return D;
}
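To build and use the final operator, the workflow is the same as in the experiment above; here is a condensed sketch (assuming fp16 CUDA inputs and int32 CUDA indices, as before):
# Build the extension first: cd gs_out && pip install -e .
import torch
import gs_gemm

A = torch.randn(128, 128, device='cuda', dtype=torch.float16)
B = torch.randn(128, 128, device='cuda', dtype=torch.float16)
indices = torch.arange(12, device='cuda', dtype=torch.int32)  # gather the first 12 columns of B
D = gs_gemm.run(A, B, None, indices)  # C=None, default alpha=1.0, beta=0.0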