Sources moved to GitHub.

Files changed:
- README.md +2 -1
- build.toml +0 -49
- causal-conv1d/causal_conv1d.cpp +0 -486
- causal-conv1d/causal_conv1d.h +0 -81
- causal-conv1d/causal_conv1d_bwd.cu +0 -627
- causal-conv1d/causal_conv1d_common.h +0 -98
- causal-conv1d/causal_conv1d_fwd.cu +0 -399
- causal-conv1d/causal_conv1d_update.cu +0 -137
- causal-conv1d/static_switch.h +0 -25
- flake.lock +0 -168
- flake.nix +0 -18
- tests/test_causal_conv1d.py +0 -353
- torch-ext/causal_conv1d/__init__.py +0 -4
- torch-ext/causal_conv1d/causal_conv1d_interface.py +0 -242
- torch-ext/causal_conv1d/causal_conv1d_varlen.py +0 -86
- torch-ext/causal_conv1d/cpp_functions.py +0 -96
- torch-ext/pytorch_shim.h +0 -105
- torch-ext/torch_binding.cpp +0 -32
- torch-ext/torch_binding.h +0 -39
README.md
CHANGED
@@ -6,5 +6,6 @@ tags:
 
 ## causal-conv1d
 
-Causal depthwise conv1d kernel
+Causal [depthwise conv1d kernel](https://github.com/Dao-AILab/causal-conv1d/) by Tri Dao.
 
+Kernel source: https://github.com/huggingface/kernels-community/tree/main/causal-conv1d
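For reference, a minimal sketch of how the Hub-hosted build can be pulled in with the `kernels` loader. The `causal_conv1d_fn` entry point and its argument order are assumptions based on the deleted `torch-ext/causal_conv1d` interface, not something this commit documents.

```python
# Minimal sketch (assumptions: the `kernels` loader API and the
# `causal_conv1d_fn` entry point exposed by the hub build).
import torch
from kernels import get_kernel

# Download and load the pre-built kernel from the Hub repository
# referenced in the README above.
causal_conv1d = get_kernel("kernels-community/causal-conv1d")

batch, dim, seqlen, width = 2, 64, 128, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)

# Depthwise causal convolution over the sequence dimension,
# optionally fused with a SiLU activation.
out = causal_conv1d.causal_conv1d_fn(x, weight, bias, activation="silu")
print(out.shape)  # (batch, dim, seqlen)
```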
build.toml
DELETED
@@ -1,49 +0,0 @@
-[general]
-name = "causal_conv1d"
-universal = false
-
-[torch]
-src = [
-  "torch-ext/pytorch_shim.h",
-  "torch-ext/torch_binding.cpp",
-  "torch-ext/torch_binding.h"
-]
-
-[kernel.causal_conv1d]
-backend = "cuda"
-src = [
-  "causal-conv1d/causal_conv1d_bwd.cu",
-  "causal-conv1d/causal_conv1d_common.h",
-  "causal-conv1d/causal_conv1d.cpp",
-  "causal-conv1d/causal_conv1d_fwd.cu",
-  "causal-conv1d/causal_conv1d.h",
-  "causal-conv1d/causal_conv1d_update.cu",
-  "causal-conv1d/static_switch.h",
-]
-include = [ "causal-conv1d" ]
-depends = [ "torch" ]
-
-#[kernel.causal_conv1d_rocm]
-#backend = "rocm"
-#rocm-archs = [
-#  "gfx906",
-#  "gfx908",
-#  "gfx90a",
-#  "gfx940",
-#  "gfx941",
-#  "gfx942",
-#  "gfx1030",
-#  "gfx1100",
-#  "gfx1101",
-#]
-#src = [
-#  "causal-conv1d/causal_conv1d_bwd.cu",
-#  "causal-conv1d/causal_conv1d_common.h",
-#  "causal-conv1d/causal_conv1d.cpp",
-#  "causal-conv1d/causal_conv1d_fwd.cu",
-#  "causal-conv1d/causal_conv1d.h",
-#  "causal-conv1d/causal_conv1d_update.cu",
-#  "causal-conv1d/static_switch.h",
-#]
-#include = [ "causal-conv1d" ]
-#depends = [ "torch" ]
causal-conv1d/causal_conv1d.cpp
DELETED
@@ -1,486 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2024, Tri Dao.
- ******************************************************************************/
-
-#include <torch/all.h>
-#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
-#include <c10/core/DeviceGuard.h>
-#else
-#include <c10/cuda/CUDAGuard.h>
-#endif
-
-#include <c10/cuda/CUDAStream.h>
-#include <vector>
-
-#include "causal_conv1d.h"
-
-#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
-
-#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
-    if (ITYPE == at::ScalarType::Half) { \
-        using input_t = at::Half; \
-        __VA_ARGS__(); \
-    } else if (ITYPE == at::ScalarType::BFloat16) { \
-        using input_t = at::BFloat16; \
-        __VA_ARGS__(); \
-    } else if (ITYPE == at::ScalarType::Float) { \
-        using input_t = float; \
-        __VA_ARGS__(); \
-    } else { \
-        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
-    }
-
-#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
-    if (WTYPE == at::ScalarType::Half) { \
-        using weight_t = at::Half; \
-        __VA_ARGS__(); \
-    } else if (WTYPE == at::ScalarType::BFloat16) { \
-        using weight_t = at::BFloat16; \
-        __VA_ARGS__(); \
-    } else if (WTYPE == at::ScalarType::Float) { \
-        using weight_t = float; \
-        __VA_ARGS__(); \
-    } else { \
-        AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
-    }
-
-template<typename input_t, typename weight_t>
-void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
-template <typename input_t, typename weight_t>
-void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
-
-template<typename input_t, typename weight_t>
-void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
-template<typename input_t, typename weight_t>
-void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
-
-template<typename input_t, typename weight_t>
-void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
-
-void set_conv_params_fwd(ConvParamsBase &params,
-                         // sizes
-                         const size_t batch,
-                         const size_t dim,
-                         const size_t seqlen,
-                         const size_t width,
-                         // device pointers
-                         const at::Tensor x,
-                         const at::Tensor weight,
-                         const at::Tensor out,
-                         void* bias_ptr,
-                         bool silu_activation) {
-
-    // Reset the parameters
-    memset(&params, 0, sizeof(params));
-
-    params.batch = batch;
-    params.dim = dim;
-    params.seqlen = seqlen;
-    params.width = width;
-
-    params.silu_activation = silu_activation;
-
-    // Set the pointers and strides.
-    params.x_ptr = x.data_ptr();
-    params.weight_ptr = weight.data_ptr();
-    params.bias_ptr = bias_ptr;
-    params.out_ptr = out.data_ptr();
-    // All stride are in elements, not bytes.
-    params.x_batch_stride = x.stride(0);
-    params.x_c_stride = x.stride(1);
-    params.x_l_stride = x.stride(-1);
-    params.weight_c_stride = weight.stride(0);
-    params.weight_width_stride = weight.stride(1);
-    params.out_batch_stride = out.stride(0);
-    params.out_c_stride = out.stride(1);
-    params.out_l_stride = out.stride(-1);
-}
-
-
-void set_conv_params_bwd(ConvParamsBwd &params,
-                         // sizes
-                         const size_t batch,
-                         const size_t dim,
-                         const size_t seqlen,
-                         const size_t width,
-                         // device pointers
-                         const at::Tensor x,
-                         const at::Tensor weight,
-                         void* bias_ptr,
-                         const at::Tensor dout,
-                         const at::Tensor dx,
-                         const at::Tensor dweight,
-                         void* dbias_ptr,
-                         bool silu_activation) {
-    // Pass in "dout" instead of "out", we're not gonna use "out" at all.
-    set_conv_params_fwd(params, batch, dim, seqlen, width,
-                        x, weight, dout, bias_ptr, silu_activation);
-
-    // Set the pointers and strides.
-    params.dout_ptr = dout.data_ptr();
-    params.dx_ptr = dx.data_ptr();
-    params.dweight_ptr = dweight.data_ptr();
-    params.dbias_ptr = dbias_ptr;
-    // All stride are in elements, not bytes.
-    params.dout_batch_stride = dout.stride(0);
-    params.dout_c_stride = dout.stride(1);
-    params.dout_l_stride = dout.stride(2);
-    params.dweight_c_stride = dweight.stride(0);
-    params.dweight_width_stride = dweight.stride(1);
-    params.dx_batch_stride = dx.stride(0);
-    params.dx_c_stride = dx.stride(1);
-    params.dx_l_stride = dx.stride(2);
-}
-
-void
-causal_conv1d_fwd(const at::Tensor &x,
-                  const at::Tensor &weight,
-                  const c10::optional<at::Tensor> &bias_,
-                  const c10::optional<at::Tensor> &seq_idx_,
-                  const c10::optional<at::Tensor> &initial_states_,
-                  at::Tensor &out,
-                  c10::optional<at::Tensor> &final_states_out_,
-                  bool silu_activation) {
-    auto input_type = x.scalar_type();
-    auto weight_type = weight.scalar_type();
-    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
-    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
-
-    TORCH_CHECK(x.is_cuda());
-    TORCH_CHECK(weight.is_cuda());
-
-    const auto sizes = x.sizes();
-    const int batch_size = sizes[0];
-    const int dim = sizes[1];
-    const int seqlen = sizes[2];
-    const int width = weight.size(-1);
-
-    CHECK_SHAPE(x, batch_size, dim, seqlen);
-    CHECK_SHAPE(weight, dim, width);
-
-    TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
-    const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
-
-    if (is_channel_last) {
-        TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
-        TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
-    }
-    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
-
-    if (bias_.has_value()) {
-        auto bias = bias_.value();
-        TORCH_CHECK(bias.scalar_type() == weight_type);
-        TORCH_CHECK(bias.is_cuda());
-        TORCH_CHECK(bias.stride(-1) == 1);
-        CHECK_SHAPE(bias, dim);
-    }
-
-    if (seq_idx_.has_value()) {
-        TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
-        auto seq_idx = seq_idx_.value();
-        TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
-        TORCH_CHECK(seq_idx.is_cuda());
-        TORCH_CHECK(seq_idx.is_contiguous());
-        CHECK_SHAPE(seq_idx, batch_size, seqlen);
-    }
-
-    ConvParamsBase params;
-    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
-                        bias_.has_value() ? bias_.value().data_ptr() : nullptr,
-                        silu_activation);
-
-    if (seq_idx_.has_value()) {
-        params.seq_idx_ptr = seq_idx_.value().data_ptr();
-    } else {
-        params.seq_idx_ptr = nullptr;
-    }
-
-    if (initial_states_.has_value()) {
-        TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
-        auto initial_states = initial_states_.value();
-        TORCH_CHECK(initial_states.scalar_type() == input_type);
-        TORCH_CHECK(initial_states.is_cuda());
-        CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
-        TORCH_CHECK(initial_states.stride(1) == 1);
-        params.initial_states_ptr = initial_states.data_ptr();
-        params.initial_states_batch_stride = initial_states.stride(0);
-        params.initial_states_c_stride = initial_states.stride(1);
-        params.initial_states_l_stride = initial_states.stride(2);
-    } else {
-        params.initial_states_ptr = nullptr;
-    }
-
-    if (final_states_out_.has_value()) {
-        TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
-        auto final_states = final_states_out_.value();
-        TORCH_CHECK(final_states.scalar_type() == input_type);
-        TORCH_CHECK(final_states.is_cuda());
-        CHECK_SHAPE(final_states, batch_size, dim, width - 1);
-        TORCH_CHECK(final_states.stride(1) == 1);
-        params.final_states_ptr = final_states.data_ptr();
-        params.final_states_batch_stride = final_states.stride(0);
-        params.final_states_c_stride = final_states.stride(1);
-        params.final_states_l_stride = final_states.stride(2);
-    } else {
-        params.final_states_ptr = nullptr;
-    }
-
-    // Otherwise the kernel will be launched from cuda:0 device
-#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
-    c10::DeviceGuard device_guard(x.device());
-#else
-    at::cuda::CUDAGuard device_guard{x.device()};
-#endif
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
-        DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
-            if (!is_channel_last) {
-                causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
-            } else {
-                causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
-            }
-        });
-    });
-}
-
-void
-causal_conv1d_bwd(const at::Tensor &x,
-                  const at::Tensor &weight,
-                  const c10::optional<at::Tensor> &bias_,
-                  at::Tensor &dout,
-                  const c10::optional<at::Tensor> &seq_idx_,
-                  const c10::optional<at::Tensor> &initial_states_,
-                  const c10::optional<at::Tensor> &dfinal_states_,
-                  at::Tensor &dx,
-                  at::Tensor &dweight,
-                  c10::optional<at::Tensor> &dbias_,
-                  c10::optional<at::Tensor> &dinitial_states_,
-                  bool silu_activation) {
-    auto input_type = x.scalar_type();
-    auto weight_type = weight.scalar_type();
-    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
-    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
-
-    TORCH_CHECK(x.is_cuda());
-    TORCH_CHECK(weight.is_cuda());
-    TORCH_CHECK(dout.is_cuda());
-    TORCH_CHECK(bias_.has_value() == dbias_.has_value());
-
-    const auto sizes = x.sizes();
-    const int batch_size = sizes[0];
-    const int dim = sizes[1];
-    const int seqlen = sizes[2];
-    const int width = weight.size(-1);
-
-    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
-
-    CHECK_SHAPE(x, batch_size, dim, seqlen);
-    CHECK_SHAPE(weight, dim, width);
-    CHECK_SHAPE(dout, batch_size, dim, seqlen);
-
-    TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
-    const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
-    if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); }
-    if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); }
-
-    if (is_channel_last) {
-        TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
-        TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
-        TORCH_CHECK(dout.stride(2) % 8 == 0 and dout.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (dout.stride(0) and dout.stride(2)) to be multiples of 8");
-    }
-
-    if (bias_.has_value()) {
-        auto bias = bias_.value();
-        TORCH_CHECK(bias.scalar_type() == weight_type);
-        TORCH_CHECK(bias.is_cuda());
-        TORCH_CHECK(bias.stride(-1) == 1);
-        CHECK_SHAPE(bias, dim);
-    }
-
-    if (seq_idx_.has_value()) {
-        TORCH_CHECK(is_channel_last, "seq_idx only supported for channel last layout");
-        auto seq_idx = seq_idx_.value();
-        TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
-        TORCH_CHECK(seq_idx.is_cuda());
-        TORCH_CHECK(seq_idx.is_contiguous());
-        CHECK_SHAPE(seq_idx, batch_size, seqlen);
-    }
-
-    TORCH_CHECK(dx.scalar_type() == input_type);
-    TORCH_CHECK(dx.is_cuda());
-    CHECK_SHAPE(dx, batch_size, dim, seqlen);
-    if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); }
-    if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); }
-
-    // Otherwise the kernel will be launched from cuda:0 device
-#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
-    c10::Device device = x.device();
-    c10::DeviceGuard device_guard(device);
-#else
-    at::cuda::CUDAGuard device_guard{x.device()};
-#endif
-    ConvParamsBwd params;
-    set_conv_params_bwd(params, batch_size, dim, seqlen, width,
-                        x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr,
-                        dout, dx, dweight, bias_.has_value() ? dbias_.value().data_ptr() : nullptr,
-                        silu_activation);
-
-    if (seq_idx_.has_value()) {
-        params.seq_idx_ptr = seq_idx_.value().data_ptr();
-    } else {
-        params.seq_idx_ptr = nullptr;
-    }
-
-    if (initial_states_.has_value()) {
-        TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
-        auto initial_states = initial_states_.value();
-        TORCH_CHECK(initial_states.scalar_type() == input_type);
-        TORCH_CHECK(initial_states.is_cuda());
-        CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
-        TORCH_CHECK(initial_states.stride(1) == 1);
-        params.initial_states_ptr = initial_states.data_ptr();
-        params.initial_states_batch_stride = initial_states.stride(0);
-        params.initial_states_c_stride = initial_states.stride(1);
-        params.initial_states_l_stride = initial_states.stride(2);
-    } else {
-        params.initial_states_ptr = nullptr;
-    }
-
-    if (dfinal_states_.has_value()) {
-        TORCH_CHECK(is_channel_last, "dfinal_states is only supported for channel last layout");
-        auto dfinal_states = dfinal_states_.value();
-        TORCH_CHECK(dfinal_states.scalar_type() == input_type);
-        TORCH_CHECK(dfinal_states.is_cuda());
-        CHECK_SHAPE(dfinal_states, batch_size, dim, width - 1);
-        params.dfinal_states_ptr = dfinal_states.data_ptr();
-        params.dfinal_states_batch_stride = dfinal_states.stride(0);
-        params.dfinal_states_c_stride = dfinal_states.stride(1);
-        params.dfinal_states_l_stride = dfinal_states.stride(2);
-    } else {
-        params.dfinal_states_ptr = nullptr;
-    }
-
-    if (dinitial_states_.has_value()) {
-        at::Tensor dinitial_states = dinitial_states_.value();
-        TORCH_CHECK(dinitial_states.stride(1) == 1);
-        params.dinitial_states_ptr = dinitial_states.data_ptr();
-        params.dinitial_states_batch_stride = dinitial_states.stride(0);
-        params.dinitial_states_c_stride = dinitial_states.stride(1);
-        params.dinitial_states_l_stride = dinitial_states.stride(2);
-    } else {
-        params.dinitial_states_ptr = nullptr;
-    }
-
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] {
-        DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] {
-            if (!is_channel_last) {
-                causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream);
-            } else {
-                causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream);
-            }
-        });
-    });
-}
-
-void
-causal_conv1d_update(const at::Tensor &x,
-                     const at::Tensor &conv_state,
-                     const at::Tensor &weight,
-                     const c10::optional<at::Tensor> &bias_,
-                     at::Tensor &out,
-                     bool silu_activation,
-                     const c10::optional<at::Tensor> &cache_seqlens_,
-                     const c10::optional<at::Tensor> &conv_state_indices_
-                     ) {
-    auto input_type = x.scalar_type();
-    auto weight_type = weight.scalar_type();
-    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
-    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
-    TORCH_CHECK(conv_state.scalar_type() == input_type);
-
-    TORCH_CHECK(x.is_cuda());
-    TORCH_CHECK(conv_state.is_cuda());
-    TORCH_CHECK(weight.is_cuda());
-
-    const auto sizes = x.sizes();
-    const int batch_size = sizes[0];
-    const int dim = sizes[1];
-    const int seqlen = sizes[2];
-    const int width = weight.size(-1);
-    const int conv_state_len = conv_state.size(2);
-    TORCH_CHECK(conv_state_len >= width - 1);
-
-    CHECK_SHAPE(x, batch_size, dim, seqlen);
-    CHECK_SHAPE(weight, dim, width);
-
-    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
-
-    if (bias_.has_value()) {
-        auto bias = bias_.value();
-        TORCH_CHECK(bias.scalar_type() == weight_type);
-        TORCH_CHECK(bias.is_cuda());
-        TORCH_CHECK(bias.stride(-1) == 1);
-        CHECK_SHAPE(bias, dim);
-    }
-
-    ConvParamsBase params;
-    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
-                        bias_.has_value() ? bias_.value().data_ptr() : nullptr,
-                        silu_activation);
-    params.conv_state_ptr = conv_state.data_ptr();
-    params.conv_state_len = conv_state_len;
-    // All stride are in elements, not bytes.
-    params.conv_state_batch_stride = conv_state.stride(0);
-    params.conv_state_c_stride = conv_state.stride(1);
-    params.conv_state_l_stride = conv_state.stride(2);
-
-    if (conv_state_indices_.has_value()) {
-        auto conv_state_indices = conv_state_indices_.value();
-        TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
-        TORCH_CHECK(conv_state_indices.is_cuda());
-        TORCH_CHECK(conv_state_indices.stride(0) == 1)
-        CHECK_SHAPE(conv_state_indices, batch_size);
-
-        int conv_state_entries = conv_state.size(0);
-        CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
-
-        params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
-    } else {
-        CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
-        params.conv_state_indices_ptr = nullptr;
-    }
-
-    if (cache_seqlens_.has_value()) {
-        auto cache_seqlens = cache_seqlens_.value();
-        TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
-        TORCH_CHECK(cache_seqlens.is_cuda());
-        TORCH_CHECK(cache_seqlens.stride(-1) == 1);
-        CHECK_SHAPE(cache_seqlens, batch_size);
-        params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
-    } else {
-        params.cache_seqlens = nullptr;
-    }
-
-    // Otherwise the kernel will be launched from cuda:0 device
-#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
-    c10::Device device = x.device();
-    c10::DeviceGuard device_guard(device);
-#else
-    at::cuda::CUDAGuard device_guard{x.device()};
-#endif
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
-        DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
-            causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
-        });
-    });
-}
-
-/*
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward");
-    m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward");
-    m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update");
-}
-*/
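The deleted forward and backward entry points above dispatch on the memory layout of `x`: `x.stride(2) == 1` selects the channel-first kernels and `x.stride(1) == 1` the channel-last ones. A small sketch of how the two layouts look from PyTorch (tensor names are illustrative only):

```python
# Sketch: producing the two layouts the C++ dispatcher distinguishes.
import torch

batch, dim, seqlen = 2, 64, 128
# Channel-first: contiguous (batch, dim, seqlen), so stride(2) == 1.
x_cf = torch.randn(batch, dim, seqlen, device="cuda")
# Channel-last: same logical shape, but the channel axis is innermost in
# memory (stride(1) == 1), obtained via a transpose round-trip; this is
# the same trick the backward wrapper applies to a mis-strided dout.
x_cl = x_cf.transpose(1, 2).contiguous().transpose(1, 2)

print(x_cf.stride())  # (dim * seqlen, seqlen, 1) -> channel-first kernel
print(x_cl.stride())  # (dim * seqlen, 1, dim)    -> channel-last kernel
```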
causal-conv1d/causal_conv1d.h
DELETED
@@ -1,81 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2024, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-struct ConvParamsBase {
-    using index_t = uint32_t;
-
-    int batch, dim, seqlen, width;
-    bool silu_activation;
-
-    index_t x_batch_stride;
-    index_t x_c_stride;
-    index_t x_l_stride;
-    index_t weight_c_stride;
-    index_t weight_width_stride;
-    index_t out_batch_stride;
-    index_t out_c_stride;
-    index_t out_l_stride;
-
-    int conv_state_len;
-    index_t conv_state_batch_stride;
-    index_t conv_state_c_stride;
-    index_t conv_state_l_stride;
-
-    // Common data pointers.
-    void *__restrict__ x_ptr;
-    void *__restrict__ weight_ptr;
-    void *__restrict__ bias_ptr;
-    void *__restrict__ out_ptr;
-
-    void *__restrict__ conv_state_ptr;
-    int32_t *__restrict__ cache_seqlens;
-
-    // Only used if the elements of the batch are gathered from a larger buffer,
-    // which may happen for continuous batching.
-    int32_t *__restrict__ conv_state_indices_ptr;
-
-    void *__restrict__ seq_idx_ptr;
-
-    // No __restrict__ since initial_states could be the same as final_states.
-    void * initial_states_ptr;
-    index_t initial_states_batch_stride;
-    index_t initial_states_l_stride;
-    index_t initial_states_c_stride;
-
-    void * final_states_ptr;
-    index_t final_states_batch_stride;
-    index_t final_states_l_stride;
-    index_t final_states_c_stride;
-};
-
-struct ConvParamsBwd: public ConvParamsBase {
-    index_t dx_batch_stride;
-    index_t dx_c_stride;
-    index_t dx_l_stride;
-    index_t dweight_c_stride;
-    index_t dweight_width_stride;
-    index_t dout_batch_stride;
-    index_t dout_c_stride;
-    index_t dout_l_stride;
-
-    // Common data pointers.
-    void *__restrict__ dx_ptr;
-    void *__restrict__ dweight_ptr;
-    void *__restrict__ dbias_ptr;
-    void *__restrict__ dout_ptr;
-
-    void * dinitial_states_ptr;
-    index_t dinitial_states_batch_stride;
-    index_t dinitial_states_l_stride;
-    index_t dinitial_states_c_stride;
-
-    void * dfinal_states_ptr;
-    index_t dfinal_states_batch_stride;
-    index_t dfinal_states_l_stride;
-    index_t dfinal_states_c_stride;
-};
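The `conv_state_ptr`, `cache_seqlens`, and `conv_state_indices_ptr` fields above serve the incremental-decoding path (`causal_conv1d_update`). A hedged sketch of how that path is typically driven from Python; the wrapper name and keyword arguments are assumptions based on the deleted `torch-ext/causal_conv1d` interface, so the call is left commented out.

```python
# Sketch of the stateful update path: conv_state holds the last
# (width - 1) inputs per sequence and is updated in place each step.
import torch

batch, dim, width = 2, 64, 4
conv_state = torch.zeros(batch, dim, width - 1, device="cuda")
weight = torch.randn(dim, width, device="cuda")
x_step = torch.randn(batch, dim, 1, device="cuda")  # one new token per sequence

# Assumed wrapper (name/signature from the deleted torch-ext interface):
# out = causal_conv1d.causal_conv1d_update(
#     x_step, conv_state, weight, bias=None, activation="silu"
# )
# After the call, conv_state contains the rolling window including x_step.
```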
causal-conv1d/causal_conv1d_bwd.cu
DELETED
@@ -1,627 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2024, Tri Dao.
- ******************************************************************************/
-
-#include <c10/util/BFloat16.h>
-#include <c10/util/Half.h>
-#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
-
-#ifndef USE_ROCM
-#include <cub/block/block_load.cuh>
-#include <cub/block/block_store.cuh>
-#include <cub/block/block_reduce.cuh>
-#else
-#include <hipcub/hipcub.hpp>
-namespace cub = hipcub;
-#endif
-
-#include "causal_conv1d.h"
-#include "causal_conv1d_common.h"
-#include "static_switch.h"
-
-template<int kNThreads_, int kWidth_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
-struct Causal_conv1d_bwd_kernel_traits {
-    using input_t = input_t_;
-    using weight_t = weight_t_;
-    static constexpr int kNThreads = kNThreads_;
-    static constexpr int kWidth = kWidth_;
-    static constexpr bool kSiluAct = kSiluAct_;
-    static constexpr int kNBytes = sizeof(input_t);
-    static_assert(kNBytes == 2 || kNBytes == 4);
-    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
-    static_assert(kWidth <= kNElts);
-    // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
-    // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
-    static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t);
-    static constexpr bool kIsVecLoad = kIsVecLoad_;
-    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
-    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
-    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
-    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
-    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
-    using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
-    static constexpr int kSmemIOSize = kIsVecLoad
-        ? 0
-        : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
-    static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1);
-    static constexpr int kSmemSize = custom_max({kSmemExchangeSize,
-        int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 0 : kSmemIOSize);
-};
-
-template<typename Ktraits>
-__global__ __launch_bounds__(Ktraits::kNThreads)
-void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
-    constexpr int kWidth = Ktraits::kWidth;
-    constexpr int kNThreads = Ktraits::kNThreads;
-    constexpr bool kSiluAct = Ktraits::kSiluAct;
-    static constexpr int kNElts = Ktraits::kNElts;
-    constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
-    static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
-    using input_t = typename Ktraits::input_t;
-    using vec_t = typename Ktraits::vec_t;
-    using weight_t = typename Ktraits::weight_t;
-
-    // Shared memory.
-    extern __shared__ char smem_[];
-    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
-    auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
-    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
-    auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
-    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
-    vec_t *smem_exchange_x = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds;
-    auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
-
-    const int tidx = threadIdx.x;
-    const int batch_id = blockIdx.x;
-    const int dim_id = blockIdx.y;
-    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
-        + dim_id * params.x_c_stride;
-    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + dim_id * params.weight_c_stride;
-    input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
-        + dim_id * params.dout_c_stride;
-    input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
-        + dim_id * params.dx_c_stride;
-    float *dweight = reinterpret_cast<float *>(params.dweight_ptr) + dim_id * params.dweight_c_stride;
-    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[dim_id]);
-
-    // Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0.
-    if (tidx == 0) {
-        if constexpr (!kSiluAct) {
-            input_t zeros[kNElts] = {0};
-            smem_exchange[0] = reinterpret_cast<vec_t *>(zeros)[0];
-        } else {
-            float zeros[kNElts] = {0};
-            #pragma unroll
-            for (int r = 0; r < kNExchangeRounds; ++r) {
-                smem_exchange[r * kNThreads] = reinterpret_cast<vec_t *>(zeros)[r];
-            }
-        }
-    }
-
-    float weight_vals[kWidth];
-    #pragma unroll
-    for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; }
-
-    float dweight_vals[kWidth] = {0};
-    float dbias_val = 0;
-
-    constexpr int kChunkSize = kNThreads * kNElts;
-    const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
-    x += (n_chunks - 1) * kChunkSize;
-    dout += (n_chunks - 1) * kChunkSize;
-    dx += (n_chunks - 1) * kChunkSize;
-    for (int chunk = n_chunks - 1; chunk >= 0; --chunk) {
-        input_t x_vals_load[2 * kNElts] = {0};
-        input_t dout_vals_load[2 * kNElts] = {0};
-        if constexpr(kIsVecLoad) {
-            typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
-            typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(dout), *reinterpret_cast<vec_t (*)[1]>(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts);
-        } else {
-            __syncthreads();
-            typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
-            __syncthreads();
-            typename Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast<input_t (*)[kNElts]>(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize);
-        }
-        float dout_vals[2 * kNElts], x_vals[2 * kNElts];
-        if constexpr (!kSiluAct) {
-            __syncthreads();
-            // Thread 0 don't write yet, so that thread kNThreads - 1 can read
-            // the first elements of the next chunk.
-            if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
-            __syncthreads();
-            reinterpret_cast<vec_t *>(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0];
-            __syncthreads();
-            // Now thread 0 can write the first elements of the current chunk.
-            if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
-            #pragma unroll
-            for (int i = 0; i < 2 * kNElts; ++i) {
-                dout_vals[i] = float(dout_vals_load[i]);
-                x_vals[i] = float(x_vals_load[i]);
-            }
-        } else {
-            if (tidx == 0 && chunk > 0) {
-                if constexpr(kIsVecLoad) {
-                    reinterpret_cast<vec_t *>(x_vals_load)[0] = reinterpret_cast<vec_t *>(x)[-1];
-                } else {
-                    #pragma unroll
-                    for (int i = 0; i < kNElts; ++i) {
-                        if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; }
-                    }
-                }
-            }
-            __syncthreads();
-            smem_exchange_x[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1];
-            __syncthreads();
-            if (tidx > 0) { reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange_x[tidx - 1]; }
-            #pragma unroll
-            for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
-            // Recompute the output
-            #pragma unroll
-            for (int i = 0; i < kNElts; ++i) {
-                float out_val = bias_val;
-                #pragma unroll
-                for (int w = 0; w < kWidth; ++w) {
-                    out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
-                }
-                float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val));
-                dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val
-                    * (1.0f + out_val * (1.0f - out_sigmoid_val));
-            }
-            // Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange
-            // if input_t is 16 bits (since then we'd have 8 values of float)
-            __syncthreads();
-            // Thread 0 don't write yet, so that thread kNThreads - 1 can read
-            // the first elements of the next chunk.
-            if (tidx > 0) {
-                #pragma unroll
-                for (int r = 0; r < kNExchangeRounds; ++r) {
-                    smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
-                }
-            }
-            __syncthreads();
-            #pragma unroll
-            for (int r = 0; r < kNExchangeRounds; ++r) {
-                reinterpret_cast<vec_t *>(dout_vals)[kNExchangeRounds + r]
-                    = smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)];
-            }
-            __syncthreads();
-            // Now thread 0 can write the first elements of the current chunk.
-            if (tidx == 0) {
-                #pragma unroll
-                for (int r = 0; r < kNExchangeRounds; ++r) {
-                    smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
-                }
-            }
-        }
-        dout -= kChunkSize;
-        x -= kChunkSize;
-
-        #pragma unroll
-        for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; }
-
-        float dx_vals[kNElts] = {0};
-        #pragma unroll
-        for (int i = 0; i < kNElts; ++i) {
-            #pragma unroll
-            for (int w = 0; w < kWidth; ++w) {
-                dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1];
-            }
-        }
-
-        input_t dx_vals_store[kNElts];
-        #pragma unroll
-        for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; }
-        if constexpr(kIsVecLoad) {
-            typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(dx), reinterpret_cast<vec_t (&)[1]>(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
-        } else {
-            typename Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize);
-        }
-        dx -= kChunkSize;
-
-        #pragma unroll
-        for (int w = 0; w < kWidth; ++w) {
-            #pragma unroll
-            for (int i = 0; i < kNElts; ++i) {
-                dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1];
-            }
-        }
-    }
-
-    #pragma unroll
-    for (int w = 0; w < kWidth; ++w) {
-        __syncthreads();
-        dweight_vals[w] = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]);
-        if (tidx == 0) {
-            atomicAdd(&reinterpret_cast<float *>(dweight)[w * params.dweight_width_stride], dweight_vals[w]);
-        }
-    }
-    if (params.bias_ptr != nullptr) {
-        __syncthreads();
-        dbias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val);
-        if (tidx == 0) {
-            atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[dim_id], dbias_val);
-        }
-    }
-}
-
-template<int kNThreads, int kWidth, typename input_t, typename weight_t>
-void causal_conv1d_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
-    static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
-    BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
-        BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
-            using Ktraits = Causal_conv1d_bwd_kernel_traits<kNThreads, kWidth, kSiluAct, kIsVecLoad, input_t, weight_t>;
-            constexpr int kSmemSize = Ktraits::kSmemSize;
-            dim3 grid(params.batch, params.dim);
-            auto kernel = &causal_conv1d_bwd_kernel<Ktraits>;
-
-            if (kSmemSize >= 48 * 1024) {
-                #ifndef USE_ROCM
-                C10_CUDA_CHECK(cudaFuncSetAttribute(
-                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-                #else
-                // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
-                C10_CUDA_CHECK(cudaFuncSetAttribute(
-                    (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-                std::cerr << "Warning (causal_conv1d bwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
-                #endif
-            }
-
-
-            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
-            C10_CUDA_KERNEL_LAUNCH_CHECK();
-        });
-    });
-}
-
-template<typename input_t, typename weight_t>
-void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
-    if (params.width == 2) {
-        causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream);
-    } else if (params.width == 3) {
-        causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream);
-    } else if (params.width == 4) {
-        causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream);
-    }
-}
-
-template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
-struct Causal_conv1d_channellast_bwd_kernel_traits {
-    // The cache line is 128 bytes, and we try to read 16 bytes per thread.
-    // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
-    // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
-    // threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
-    using input_t = input_t_;
-    using weight_t = weight_t_;
-    static constexpr bool kSiluAct = kSiluAct_;
-    static constexpr int kNThreads = kNThreads_;
-    static_assert(kNThreads % 32 == 0);
-    static constexpr int kNWarps = kNThreads / 32;
-    static constexpr int kWidth = kWidth_;
-    static constexpr int kChunkSizeL = kChunkSizeL_;
-    static constexpr int kNBytes = sizeof(input_t);
-    static_assert(kNBytes == 2 || kNBytes == 4);
-    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
-    static constexpr int kNEltsPerRow = 128 / kNBytes;
-    static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts;  // Always 8 for now
-    static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
-    static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow;  // Always 4 for now
-    static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
-    static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
-    static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
-    static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
-    static constexpr bool kIsVecLoad = kIsVecLoad_;
-    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
-    // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
-    // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
-    // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
-    //                                            sizeof(typename BlockStoreT::TempStorage)});
-    // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
-};
-
-template<typename Ktraits, bool kHasSeqIdx, bool kHasDfinalStates>
-__global__ __launch_bounds__(Ktraits::kNThreads)
-void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
-    constexpr int kWidth = Ktraits::kWidth;
-    constexpr int kNThreads = Ktraits::kNThreads;
-    constexpr bool kSiluAct = Ktraits::kSiluAct;
-    constexpr int kNElts = Ktraits::kNElts;
-    constexpr int kNWarp = Ktraits::kNWarps;
-    constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
-    constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
-    constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
-    constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
-    using input_t = typename Ktraits::input_t;
-    using vec_t = typename Ktraits::vec_t;
-    using weight_t = typename Ktraits::weight_t;
-
-    // Shared memory.
-    __shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
-    __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
-
-    const int batch_id = blockIdx.x;
-    const int chunk_l_id = blockIdx.y;
-    const int chunk_c_id = blockIdx.z;
-    const int tid = threadIdx.x;
-    const int l_idx = tid / kNThreadsPerC;
-    const int c_idx = tid % kNThreadsPerC;
-    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
-        + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
-    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
-        + chunk_c_id * kChunkSizeC * params.weight_c_stride;
-    input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
-        + (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
-    input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
-        + (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
-    float *dweight = reinterpret_cast<float *>(params.dweight_ptr)
-        + chunk_c_id * kChunkSizeC * params.dweight_c_stride;
-    int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
-        + batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
-    input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
-        : reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
-    input_t *dinitial_states = params.dinitial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
-        : reinterpret_cast<input_t *>(params.dinitial_states_ptr) + batch_id * params.dinitial_states_batch_stride + l_idx * params.dinitial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
-    input_t *dfinal_states = params.dfinal_states_ptr == nullptr ? nullptr
-        : reinterpret_cast<input_t *>(params.dfinal_states_ptr) + batch_id * params.dfinal_states_batch_stride + chunk_c_id * kChunkSizeC;
-
-    #pragma unroll
-    for (int l = 0; l < Ktraits::kNLoads; ++l) {
-        input_t dout_vals_load[kNElts] = {0};
-        input_t x_vals_load[kNElts] = {0};
-        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
-            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
-            reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + l * kLPerLoad * params.dout_l_stride);
-            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
-        }
-        reinterpret_cast<vec_t *>(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
-        reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
-    }
-    // Load the elements from the previous chunk or next chunk that are needed for convolution.
-    if (l_idx < kWidth - 1) {
-        input_t dout_vals_load[kNElts] = {0};
-        input_t x_vals_load[kNElts] = {0};
-        if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
-            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
-            reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + kChunkSizeL * params.dout_l_stride);
-        }
-        if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
-            && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
-            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
-            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
-        } else if (initial_states != nullptr
-                   && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
-                   && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
-            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
-        }
-        reinterpret_cast<vec_t *>(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
-        reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
-    }
-    // Need to load (kWdith - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs
-    if constexpr (kSiluAct) {
-        if (l_idx < kWidth - 1) {
-            input_t x_vals_load[kNElts] = {0};
-            if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
-                && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
-                reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + kChunkSizeL * params.x_l_stride);
-            }
-            reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
-        }
-    }
-
-    __syncthreads();
-
-    constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
-    static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
-    constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
-    static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
-    // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
-    static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
-    static_assert((kLPerThread & (kLPerThread - 1)) == 0);
-    static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
-    static_assert(kNThreadsPerRow <= 32);
-
-    const int row_idx = tid / kNThreadsPerRow;
-    const int col_idx = tid % kNThreadsPerRow;
-
-    float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
-    float weight_vals[kWidth] = {0};
-    if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
-        #pragma unroll
-        for (int w = 0; w < kWidth; ++w) {
-            weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
-        }
-    }
-    float dout_vals[kLPerThread + kWidth - 1];
-    float x_vals[kWidth - 1 + kLPerThread + kWidth - 1];
-    #pragma unroll
-    for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
-        dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]);
-        x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
-    }
-
-    int seq_idx_thread[kWidth - 1 + kLPerThread + kWidth - 1];
-    if constexpr (kHasSeqIdx) {
-        #pragma unroll
-        for (int i = 0; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
-            const int l_idx = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1);
-            seq_idx_thread[i] = l_idx >= 0 && l_idx < params.seqlen ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
-        }
-    }
-
-    if constexpr (kSiluAct) {  // Recompute the output
-        #pragma unroll
-        for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
-            x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
-        }
-        #pragma unroll
-        for (int i = 0; i < kLPerThread + kWidth - 1; ++i) {
-            float out_val = bias_val;
-            const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
-            #pragma unroll
-            for (int w = 0; w < kWidth; ++w) {
-                if constexpr (!kHasSeqIdx) {
-                    out_val += weight_vals[w] * x_vals[i + w];
-                } else {
-                    out_val += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
-                }
-            }
-            float out_val_sigmoid = 1.f / (1.f + expf(-out_val));
-            dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid));
-        }
-    }
-
-    float dweight_vals[kWidth] = {0};
-    SumOp<float> sum_op;
-    #pragma unroll
-    for (int w = 0; w < kWidth; ++w) {
-        #pragma unroll
-        for (int i = 0; i < kLPerThread; ++i) {
-            if constexpr (!kHasSeqIdx) {
-                dweight_vals[w] += x_vals[i + w] * dout_vals[i];
-            } else {
-                dweight_vals[w] += seq_idx_thread[i + w] == seq_idx_thread[kWidth - 1 + i] ? x_vals[i + w] * dout_vals[i] : 0.f;
-            }
-        }
-        dweight_vals[w] = Allreduce<kNThreadsPerRow>::run(dweight_vals[w], sum_op);
-        if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
| 486 |
-
atomicAdd(&reinterpret_cast<float *>(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]);
|
| 487 |
-
}
|
| 488 |
-
}
|
| 489 |
-
|
| 490 |
-
if (params.bias_ptr != nullptr) {
|
| 491 |
-
float dbias_val = 0.f;
|
| 492 |
-
for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; }
|
| 493 |
-
dbias_val = Allreduce<kNThreadsPerRow>::run(dbias_val, sum_op);
|
| 494 |
-
if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
| 495 |
-
atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val);
|
| 496 |
-
}
|
| 497 |
-
}
|
| 498 |
-
|
| 499 |
-
float dx_vals[kLPerThread] = {0};
|
| 500 |
-
#pragma unroll
|
| 501 |
-
for (int i = 0; i < kLPerThread; ++i) {
|
| 502 |
-
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
| 503 |
-
#pragma unroll
|
| 504 |
-
for (int w = 0; w < kWidth; ++w) {
|
| 505 |
-
if constexpr (!kHasSeqIdx) {
|
| 506 |
-
dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w];
|
| 507 |
-
} else {
|
| 508 |
-
dx_vals[i] += seq_idx_thread[kWidth - 1 + i + w] == seq_idx_cur ? weight_vals[kWidth - 1 - w] * dout_vals[i + w] : 0.f;
|
| 509 |
-
}
|
| 510 |
-
}
|
| 511 |
-
// if (dfinal_states != nullptr) {
|
| 512 |
-
if constexpr (kHasDfinalStates) {
|
| 513 |
-
if (chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i >= params.seqlen - kWidth + 1
|
| 514 |
-
&& chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i < params.seqlen
|
| 515 |
-
&& chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
| 516 |
-
dx_vals[i] += float(dfinal_states[((chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i) - (params.seqlen - kWidth + 1)) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
|
| 517 |
-
}
|
| 518 |
-
}
|
| 519 |
-
}
|
| 520 |
-
|
| 521 |
-
float dxinit_vals[kWidth - 1] = {0};
|
| 522 |
-
static_assert(kLPerThread >= kWidth - 1); // So only threads with col_idx == 0 need to handle dinitial_states
|
| 523 |
-
if (dinitial_states != nullptr && col_idx == 0) {
|
| 524 |
-
#pragma unroll
|
| 525 |
-
for (int i = 0; i < kWidth - 1; ++i) {
|
| 526 |
-
#pragma unroll
|
| 527 |
-
for (int w = 0; w < kWidth; ++w) {
|
| 528 |
-
dxinit_vals[i] += i + w - (kWidth - 1) >= 0 ? weight_vals[kWidth - 1 - w] * dout_vals[i + w - (kWidth - 1)] : 0.f;
|
| 529 |
-
}
|
| 530 |
-
// chunk_l_id must be 0 because dinitial_states != nullptr
|
| 531 |
-
// if (dfinal_states != nullptr) {
|
| 532 |
-
if constexpr (kHasDfinalStates) {
|
| 533 |
-
if (i >= params.seqlen) {
|
| 534 |
-
dxinit_vals[i] += float(dfinal_states[(i - params.seqlen) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
|
| 535 |
-
}
|
| 536 |
-
}
|
| 537 |
-
}
|
| 538 |
-
}
|
| 539 |
-
|
| 540 |
-
__syncthreads();
|
| 541 |
-
#pragma unroll
|
| 542 |
-
for (int i = 0; i < kLPerThread; ++i) { x_smem[kWidth - 1 + col_idx * kLPerThread + i][row_idx] = dx_vals[i]; }
|
| 543 |
-
if (dinitial_states != nullptr && col_idx == 0) {
|
| 544 |
-
#pragma unroll
|
| 545 |
-
for (int i = 0; i < kWidth - 1; ++i) { x_smem[i][row_idx] = dxinit_vals[i]; }
|
| 546 |
-
}
|
| 547 |
-
__syncthreads();
|
| 548 |
-
|
| 549 |
-
#pragma unroll
|
| 550 |
-
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
| 551 |
-
input_t dx_vals_store[kNElts];
|
| 552 |
-
reinterpret_cast<vec_t *>(dx_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx];
|
| 553 |
-
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
| 554 |
-
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 555 |
-
*reinterpret_cast<vec_t *>(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast<vec_t *>(dx_vals_store)[0];
|
| 556 |
-
}
|
| 557 |
-
}
|
| 558 |
-
if (dinitial_states != nullptr
|
| 559 |
-
&& l_idx < kWidth - 1
|
| 560 |
-
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 561 |
-
input_t dxinit_vals_store[kNElts];
|
| 562 |
-
reinterpret_cast<vec_t *>(dxinit_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx];
|
| 563 |
-
*reinterpret_cast<vec_t *>(dinitial_states) = reinterpret_cast<vec_t *>(dxinit_vals_store)[0];
|
| 564 |
-
}
|
| 565 |
-
|
| 566 |
-
}
|
| 567 |
-
|
| 568 |
-
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
| 569 |
-
void causal_conv1d_channellast_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
| 570 |
-
BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
|
| 571 |
-
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
|
| 572 |
-
BOOL_SWITCH(params.dfinal_states_ptr != nullptr, kHasDfinalStates, [&] {
|
| 573 |
-
BOOL_SWITCH(params.seqlen <= 128, kChunkSizeL64, [&] {
|
| 574 |
-
// kChunkSizeL = 128 is slightly faster than 64 when seqlen is larger
|
| 575 |
-
static constexpr int kChunk = kChunkSizeL64 ? 64 : 128;
|
| 576 |
-
using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits<kNThreads, kWidth, kChunk, kSiluAct, true, input_t, weight_t>;
|
| 577 |
-
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
| 578 |
-
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
| 579 |
-
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
| 580 |
-
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
|
| 581 |
-
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
|
| 582 |
-
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
|
| 583 |
-
dim3 block(Ktraits::kNThreads);
|
| 584 |
-
auto kernel = &causal_conv1d_channellast_bwd_kernel<Ktraits, kHasSeqIdx, kHasDfinalStates>;
|
| 585 |
-
// if (kSmemSize >= 48 * 1024) {
|
| 586 |
-
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
| 587 |
-
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 588 |
-
// }
|
| 589 |
-
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
| 590 |
-
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
| 591 |
-
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 592 |
-
});
|
| 593 |
-
});
|
| 594 |
-
});
|
| 595 |
-
});
|
| 596 |
-
}
|
| 597 |
-
|
| 598 |
-
template<typename input_t, typename weight_t>
|
| 599 |
-
void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
| 600 |
-
if (params.width == 2) {
|
| 601 |
-
causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream);
|
| 602 |
-
} else if (params.width == 3) {
|
| 603 |
-
causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream);
|
| 604 |
-
} else if (params.width == 4) {
|
| 605 |
-
causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream);
|
| 606 |
-
}
|
| 607 |
-
}
|
| 608 |
-
|
| 609 |
-
template void causal_conv1d_bwd_cuda<float, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 610 |
-
template void causal_conv1d_bwd_cuda<at::Half, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 611 |
-
template void causal_conv1d_bwd_cuda<at::BFloat16, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 612 |
-
template void causal_conv1d_bwd_cuda<float, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 613 |
-
template void causal_conv1d_bwd_cuda<at::Half, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 614 |
-
template void causal_conv1d_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 615 |
-
template void causal_conv1d_bwd_cuda<float, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 616 |
-
template void causal_conv1d_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 617 |
-
template void causal_conv1d_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 618 |
-
|
| 619 |
-
template void causal_conv1d_channellast_bwd_cuda<float, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 620 |
-
template void causal_conv1d_channellast_bwd_cuda<at::Half, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 621 |
-
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 622 |
-
template void causal_conv1d_channellast_bwd_cuda<float, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 623 |
-
template void causal_conv1d_channellast_bwd_cuda<at::Half, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 624 |
-
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 625 |
-
template void causal_conv1d_channellast_bwd_cuda<float, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 626 |
-
template void causal_conv1d_channellast_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 627 |
-
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
causal-conv1d/causal_conv1d_common.h
DELETED
@@ -1,98 +0,0 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#ifndef USE_ROCM
    #include <cuda_bf16.h>

    template<typename T>
    __device__ inline T shuffle_xor(T val, int offset) {
        return __shfl_xor_sync(uint32_t(-1), val, offset);
    }

    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
    {
        return std::max(ilist);
    }

    template<typename T>
    constexpr T constexpr_min(T a, T b) {
        return std::min(a, b);
    }

#else
    #include <hip/hip_bf16.h>

    template<typename T>
    __device__ inline T shuffle_xor(T val, int offset) {
        return __shfl_xor(val, offset);
    }
    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
    {
        return *std::max_element(ilist.begin(), ilist.end());
    }

    template<typename T>
    constexpr T constexpr_min(T a, T b) {
        return a < b ? a : b;
    }
#endif
#include <cuda_fp16.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int BYTES> struct BytesToType {};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ inline T operator()(T const & x, T const & y) { return x + y; }
};

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, shuffle_xor(x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, shuffle_xor(x, 1));
        return x;
    }
};
causal-conv1d/causal_conv1d_fwd.cu
DELETED
@@ -1,399 +0,0 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#ifndef USE_ROCM
    #include <cub/block/block_load.cuh>
    #include <cub/block/block_store.cuh>
#else
    #include <hipcub/hipcub.hpp>
    namespace cub = hipcub;
#endif

#include "causal_conv1d.h"
#include "causal_conv1d_common.h"
#include "static_switch.h"

template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_fwd_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
    static_assert(kWidth <= kNElts);
    static constexpr bool kIsVecLoad = kIsVecLoad_;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
    static constexpr int kSmemIOSize = kIsVecLoad
        ? 0
        : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
    static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
    static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
};

template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNElts = Ktraits::kNElts;
    static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
    using input_t = typename Ktraits::input_t;
    using vec_t = typename Ktraits::vec_t;
    using weight_t = typename Ktraits::weight_t;

    // Shared memory.
    extern __shared__ char smem_[];
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);

    const int tidx = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int channel_id = blockIdx.y;
    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + channel_id * params.x_c_stride;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + channel_id * params.out_c_stride;
    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

    // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
    if (tidx == 0) {
        input_t zeros[kNElts] = {0};
        smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
    }

    float weight_vals[kWidth];
    #pragma unroll
    for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }

    constexpr int kChunkSize = kNThreads * kNElts;
    const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
    for (int chunk = 0; chunk < n_chunks; ++chunk) {
        input_t x_vals_load[2 * kNElts] = {0};
        if constexpr(kIsVecLoad) {
            typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
        } else {
            __syncthreads();
            typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
        }
        x += kChunkSize;
        __syncthreads();
        // Thread kNThreads - 1 don't write yet, so that thread 0 can read
        // the last elements of the previous chunk.
        if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
        __syncthreads();
        reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
        __syncthreads();
        // Now thread kNThreads - 1 can write the last elements of the current chunk.
        if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }

        float x_vals[2 * kNElts];
        #pragma unroll
        for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }

        float out_vals[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) {
            out_vals[i] = bias_val;
            #pragma unroll
            for (int w = 0; w < kWidth; ++w) {
                out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
            }
        }

        if (params.silu_activation) {
            #pragma unroll
            for (int i = 0; i < kNElts; ++i) {
                out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
            }
        }

        input_t out_vals_store[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
        if constexpr(kIsVecLoad) {
            typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
        } else {
            typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
        }
        out += kChunkSize;
    }
}

template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
    static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
    BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
        using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
        constexpr int kSmemSize = Ktraits::kSmemSize;
        dim3 grid(params.batch, params.dim);

        auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;

        if (kSmemSize >= 48 * 1024) {
            #ifndef USE_ROCM
            C10_CUDA_CHECK(cudaFuncSetAttribute(
                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
            #else
            // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
            C10_CUDA_CHECK(cudaFuncSetAttribute(
                (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
            std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
            #endif
        }
        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);

        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}

template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
    }
}

template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_channellast_fwd_kernel_traits {
    // The cache line is 128 bytes, and we try to read 16 bytes per thread.
    // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
    // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
    // threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static_assert(kNThreads % 32 == 0);
    static constexpr int kNWarps = kNThreads / 32;
    static constexpr int kWidth = kWidth_;
    static constexpr int kChunkSizeL = kChunkSizeL_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
    static constexpr int kNEltsPerRow = 128 / kNBytes;
    static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts;  // Always 8 for now
    static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
    static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow;  // Always 4 for now
    static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
    static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
    static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
    static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
    static constexpr bool kIsVecLoad = kIsVecLoad_;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
    //                                            sizeof(typename BlockStoreT::TempStorage)});
    // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
};

template<typename Ktraits, bool kHasSeqIdx>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNElts = Ktraits::kNElts;
    constexpr int kNWarp = Ktraits::kNWarps;
    constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
    constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
    constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
    constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
    using input_t = typename Ktraits::input_t;
    using vec_t = typename Ktraits::vec_t;
    using weight_t = typename Ktraits::weight_t;

    // Shared memory.
    __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];

    const int batch_id = blockIdx.x;
    const int chunk_l_id = blockIdx.y;
    const int chunk_c_id = blockIdx.z;
    const int tid = threadIdx.x;
    const int l_idx = tid / kNThreadsPerC;
    const int c_idx = tid % kNThreadsPerC;
    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
        + chunk_c_id * kChunkSizeC * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
    int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
        + batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
    input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
        : reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
    // The last L-chunk will also have enough info to write to final states, since it also contain a few x values
    // from the previous L-chunk.
    input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
        : reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;

    #pragma unroll
    for (int l = 0; l < Ktraits::kNLoads; ++l) {
        input_t x_vals_load[kNElts] = {0};
        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
        }
        reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
    }
    // Load the elements from the previous chunk that are needed for convolution.
    if (l_idx < kWidth - 1) {
        input_t x_vals_load[kNElts] = {0};
        if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
            && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
        } else if (initial_states != nullptr
                   && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
                   && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
        }
        reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
    }

    __syncthreads();

    if (final_states != nullptr
        && l_idx < kWidth - 1
        && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
        // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
        // So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
        *reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
    }

    constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
    static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
    constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
    static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
    // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
    static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
    static_assert((kLPerThread & (kLPerThread - 1)) == 0);
    static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
    static_assert(kNThreadsPerRow <= 32);

    const int row_idx = tid / kNThreadsPerRow;
    const int col_idx = tid % kNThreadsPerRow;

    float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
    float weight_vals[kWidth] = {0};
    if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
        #pragma unroll
        for (int w = 0; w < kWidth; ++w) {
            weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
        }
    }
    float x_vals[kWidth - 1 + kLPerThread];
    #pragma unroll
    for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
        x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
    }
    int seq_idx_thread[kWidth - 1 + kLPerThread];
    if constexpr (kHasSeqIdx) {
        #pragma unroll
        for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
            seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
        }
    }

    float out_vals[kLPerThread];
    #pragma unroll
    for (int i = 0; i < kLPerThread; ++i) {
        out_vals[i] = bias_val;
        const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
        #pragma unroll
        for (int w = 0; w < kWidth; ++w) {
            if constexpr (!kHasSeqIdx) {
                out_vals[i] += weight_vals[w] * x_vals[i + w];
            } else {
                out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
            }
        }
        if (params.silu_activation) { out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
    }

    __syncthreads();
    #pragma unroll
    for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
    __syncthreads();

    #pragma unroll
    for (int l = 0; l < Ktraits::kNLoads; ++l) {
        input_t out_vals_store[kNElts];
        reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            *reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
        }
    }

}

template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
    BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
        using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
        // constexpr int kSmemSize = Ktraits::kSmemSize;
        constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
        constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
        const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
        const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
        dim3 grid(params.batch, n_chunks_L, n_chunks_C);
        dim3 block(Ktraits::kNThreads);
        auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
        // if (kSmemSize >= 48 * 1024) {
        //     C10_CUDA_CHECK(cudaFuncSetAttribute(
        //         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
        // }
        // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
        kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}

template<typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
    }
}

template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);

template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
causal-conv1d/causal_conv1d_update.cu
DELETED
@@ -1,137 +0,0 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include "causal_conv1d.h"
#include "causal_conv1d_common.h"
#include "static_switch.h"

template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
struct Causal_conv1d_update_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
};

template<typename Ktraits, bool kIsCircularBuffer>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_update_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;

    const int tidx = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int channel_id = blockIdx.y * kNThreads + tidx;
    if (channel_id >= params.dim) return;

    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + channel_id * params.x_c_stride;

    // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
    // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
    const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
        ? batch_id
        : params.conv_state_indices_ptr[batch_id];
    input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
        + conv_state_batch_coord * params.conv_state_batch_stride
        + channel_id * params.conv_state_c_stride;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + channel_id * params.out_c_stride;
    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

    int state_len = params.conv_state_len;
    int advance_len = params.seqlen;
    int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
    int update_idx = cache_seqlen - (kWidth - 1);
    update_idx = update_idx < 0 ? update_idx + state_len : update_idx;

    float weight_vals[kWidth] = {0};
    #pragma unroll
    for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }

    float x_vals[kWidth] = {0};
    if constexpr (!kIsCircularBuffer) {
        #pragma unroll 2
        for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
            conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
        }
        #pragma unroll
        for (int i = 0; i < kWidth - 1; ++i) {
            input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
            if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
                conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
            }
            x_vals[i] = float(state_val);
        }
    } else {
        #pragma unroll
        for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
            input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
            x_vals[i] = float(state_val);
        }
    }
    #pragma unroll 2
    for (int i = 0; i < params.seqlen; ++i) {
        input_t x_val = x[i * params.x_l_stride];
        if constexpr (!kIsCircularBuffer) {
            if (i < advance_len && state_len - advance_len + i >= 0) {
                conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
            }
        } else {
            conv_state[update_idx * params.conv_state_l_stride] = x_val;
            ++update_idx;
            update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
        }
        x_vals[kWidth - 1] = float(x_val);
        float out_val = bias_val;
        #pragma unroll
        for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
        if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
        out[i * params.out_l_stride] = input_t(out_val);
        // Shift the input buffer by 1
        #pragma unroll
        for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
    }
}

template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
    using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
    dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
    auto kernel = params.cache_seqlens == nullptr
        ? &causal_conv1d_update_kernel<Ktraits, false>
        : &causal_conv1d_update_kernel<Ktraits, true>;
    kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
    }
}

template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
causal-conv1d/static_switch.h
DELETED
@@ -1,25 +0,0 @@
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h

#pragma once

/// @param COND       - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ...        - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)           \
    [&] {                                            \
        if (COND) {                                  \
            static constexpr bool CONST_NAME = true; \
            return __VA_ARGS__();                    \
        } else {                                     \
            static constexpr bool CONST_NAME = false; \
            return __VA_ARGS__();                    \
        }                                            \
    }()
flake.lock
DELETED
@@ -1,168 +0,0 @@
{
  "nodes": {
    "flake-compat": {
      "locked": {
        "lastModified": 1747046372,
        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-compat_2": {
      "locked": {
        "lastModified": 1747046372,
        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-utils": {
      "inputs": {
        "systems": "systems"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "flake-utils_2": {
      "inputs": {
        "systems": "systems_2"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "hf-nix": {
      "inputs": {
        "flake-compat": "flake-compat_2",
        "flake-utils": "flake-utils_2",
        "nixpkgs": "nixpkgs"
      },
      "locked": {
        "lastModified": 1759493343,
        "narHash": "sha256-8fhl0gwMAnOkQbogPIVq+Fha+Yeq52FaRXfwF+F9Q+k=",
        "owner": "huggingface",
        "repo": "hf-nix",
        "rev": "b1fc3a18b52447a0f24bc6884418edc5e66082b9",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "repo": "hf-nix",
        "type": "github"
      }
    },
    "kernel-builder": {
      "inputs": {
        "flake-compat": "flake-compat",
        "flake-utils": "flake-utils",
        "hf-nix": "hf-nix",
        "nixpkgs": [
          "kernel-builder",
          "hf-nix",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1759516823,
        "narHash": "sha256-UJVvZHtS9c64Dm4iZRaOKWB+VHI7jzcazGH57KXWeg8=",
        "owner": "huggingface",
        "repo": "kernel-builder",
        "rev": "e13610a05f67b7296be9ead89ad172a0a088a1c3",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "repo": "kernel-builder",
        "type": "github"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1755963616,
        "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
        "owner": "nixos",
        "repo": "nixpkgs",
        "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
        "type": "github"
      },
      "original": {
-
"owner": "nixos",
|
| 125 |
-
"ref": "nixos-unstable-small",
|
| 126 |
-
"repo": "nixpkgs",
|
| 127 |
-
"type": "github"
|
| 128 |
-
}
|
| 129 |
-
},
|
| 130 |
-
"root": {
|
| 131 |
-
"inputs": {
|
| 132 |
-
"kernel-builder": "kernel-builder"
|
| 133 |
-
}
|
| 134 |
-
},
|
| 135 |
-
"systems": {
|
| 136 |
-
"locked": {
|
| 137 |
-
"lastModified": 1681028828,
|
| 138 |
-
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 139 |
-
"owner": "nix-systems",
|
| 140 |
-
"repo": "default",
|
| 141 |
-
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 142 |
-
"type": "github"
|
| 143 |
-
},
|
| 144 |
-
"original": {
|
| 145 |
-
"owner": "nix-systems",
|
| 146 |
-
"repo": "default",
|
| 147 |
-
"type": "github"
|
| 148 |
-
}
|
| 149 |
-
},
|
| 150 |
-
"systems_2": {
|
| 151 |
-
"locked": {
|
| 152 |
-
"lastModified": 1681028828,
|
| 153 |
-
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 154 |
-
"owner": "nix-systems",
|
| 155 |
-
"repo": "default",
|
| 156 |
-
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 157 |
-
"type": "github"
|
| 158 |
-
},
|
| 159 |
-
"original": {
|
| 160 |
-
"owner": "nix-systems",
|
| 161 |
-
"repo": "default",
|
| 162 |
-
"type": "github"
|
| 163 |
-
}
|
| 164 |
-
}
|
| 165 |
-
},
|
| 166 |
-
"root": "root",
|
| 167 |
-
"version": 7
|
| 168 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flake.nix
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
description = "Flake for attention kernels";
|
| 3 |
-
|
| 4 |
-
inputs = {
|
| 5 |
-
kernel-builder.url = "github:huggingface/kernel-builder";
|
| 6 |
-
};
|
| 7 |
-
|
| 8 |
-
outputs =
|
| 9 |
-
{
|
| 10 |
-
self,
|
| 11 |
-
kernel-builder,
|
| 12 |
-
}:
|
| 13 |
-
kernel-builder.lib.genFlakeOutputs {
|
| 14 |
-
inherit self;
|
| 15 |
-
path = ./.;
|
| 16 |
-
pythonCheckInputs = pkgs: with pkgs; [ einops ];
|
| 17 |
-
};
|
| 18 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_causal_conv1d.py
DELETED
|
@@ -1,353 +0,0 @@
|
|
| 1 |
-
# Copyright (C) 2024, Tri Dao.
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
|
| 8 |
-
import pytest
|
| 9 |
-
|
| 10 |
-
from einops import rearrange
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update, causal_conv1d_varlen_states
|
| 14 |
-
from causal_conv1d.causal_conv1d_interface import causal_conv1d_ref
|
| 15 |
-
from causal_conv1d.causal_conv1d_interface import causal_conv1d_update_ref
|
| 16 |
-
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states_ref
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
@pytest.mark.parametrize("return_final_states", [False, True])
|
| 20 |
-
# @pytest.mark.parametrize("return_final_states", [True])
|
| 21 |
-
@pytest.mark.parametrize("has_initial_states", [False, True])
|
| 22 |
-
# @pytest.mark.parametrize("has_initial_states", [False])
|
| 23 |
-
@pytest.mark.parametrize("channel_last", [False, True])
|
| 24 |
-
# @pytest.mark.parametrize('channel_last', [True])
|
| 25 |
-
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 26 |
-
# @pytest.mark.parametrize('itype', [torch.float16])
|
| 27 |
-
@pytest.mark.parametrize("silu_activation", [False, True])
|
| 28 |
-
# @pytest.mark.parametrize('silu_activation', [True])
|
| 29 |
-
@pytest.mark.parametrize("has_bias", [False, True])
|
| 30 |
-
# @pytest.mark.parametrize('has_bias', [True])
|
| 31 |
-
@pytest.mark.parametrize("width", [2, 3, 4])
|
| 32 |
-
# @pytest.mark.parametrize('width', [3])
|
| 33 |
-
@pytest.mark.parametrize(
|
| 34 |
-
"seqlen", [1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
|
| 35 |
-
)
|
| 36 |
-
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
|
| 37 |
-
# @pytest.mark.parametrize('seqlen', [128])
|
| 38 |
-
@pytest.mark.parametrize('dim', [64, 4096 + 32])
|
| 39 |
-
# @pytest.mark.parametrize('dim', [64])
|
| 40 |
-
def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states):
|
| 41 |
-
if not channel_last and (has_initial_states or return_final_states):
|
| 42 |
-
pytest.skip("Only channel_last support initial_states or return_final_states")
|
| 43 |
-
device = "cuda"
|
| 44 |
-
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
| 45 |
-
if itype == torch.bfloat16:
|
| 46 |
-
rtol, atol = 1e-2, 5e-2
|
| 47 |
-
rtolw, atolw = (1e-3, 1e-3)
|
| 48 |
-
# set seed
|
| 49 |
-
torch.random.manual_seed(0)
|
| 50 |
-
batch = 2
|
| 51 |
-
# batch = 1
|
| 52 |
-
if not channel_last:
|
| 53 |
-
x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
|
| 54 |
-
else:
|
| 55 |
-
x = rearrange(
|
| 56 |
-
torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
|
| 57 |
-
).requires_grad_()
|
| 58 |
-
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
| 59 |
-
if has_bias:
|
| 60 |
-
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
| 61 |
-
else:
|
| 62 |
-
bias = None
|
| 63 |
-
if has_initial_states:
|
| 64 |
-
initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2).requires_grad_()
|
| 65 |
-
else:
|
| 66 |
-
initial_states = None
|
| 67 |
-
x_ref = x.detach().clone().requires_grad_()
|
| 68 |
-
weight_ref = weight.detach().clone().requires_grad_()
|
| 69 |
-
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
|
| 70 |
-
initial_states_ref = initial_states.detach().clone().requires_grad_() if initial_states is not None else None
|
| 71 |
-
activation = None if not silu_activation else "silu"
|
| 72 |
-
out = causal_conv1d_fn(x, weight, bias, initial_states=initial_states, return_final_states=return_final_states,
|
| 73 |
-
activation=activation)
|
| 74 |
-
out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation)
|
| 75 |
-
if return_final_states:
|
| 76 |
-
out, final_states = out
|
| 77 |
-
out_ref, final_states_ref = out_ref
|
| 78 |
-
print(f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}")
|
| 79 |
-
print(f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}")
|
| 80 |
-
assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
|
| 81 |
-
|
| 82 |
-
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
| 83 |
-
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
| 84 |
-
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
| 85 |
-
|
| 86 |
-
if return_final_states:
|
| 87 |
-
out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
|
| 88 |
-
out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)
|
| 89 |
-
|
| 90 |
-
g = torch.randn_like(out)
|
| 91 |
-
out.backward(g)
|
| 92 |
-
out_ref.backward(g)
|
| 93 |
-
|
| 94 |
-
print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
|
| 95 |
-
print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
|
| 96 |
-
if has_bias:
|
| 97 |
-
print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
|
| 98 |
-
if has_initial_states:
|
| 99 |
-
print(f"dinitial_states max diff: {(initial_states.grad - initial_states_ref.grad).abs().max().item()}")
|
| 100 |
-
|
| 101 |
-
assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
|
| 102 |
-
assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
|
| 103 |
-
if has_bias:
|
| 104 |
-
assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
|
| 105 |
-
if has_initial_states:
|
| 106 |
-
assert torch.allclose(initial_states.grad, initial_states_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 110 |
-
# @pytest.mark.parametrize('itype', [torch.float16])
|
| 111 |
-
@pytest.mark.parametrize("silu_activation", [False, True])
|
| 112 |
-
# @pytest.mark.parametrize('silu_activation', [True])
|
| 113 |
-
@pytest.mark.parametrize("has_bias", [False, True])
|
| 114 |
-
# @pytest.mark.parametrize('has_bias', [True])
|
| 115 |
-
@pytest.mark.parametrize("has_cache_seqlens", [False, True])
|
| 116 |
-
# @pytest.mark.parametrize('has_cache_seqlens', [True])
|
| 117 |
-
@pytest.mark.parametrize("seqlen", [1, 4, 5])
|
| 118 |
-
# @pytest.mark.parametrize('seqlen', [4])
|
| 119 |
-
@pytest.mark.parametrize("width", [2, 3, 4])
|
| 120 |
-
# @pytest.mark.parametrize('width', [4])
|
| 121 |
-
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
| 122 |
-
# @pytest.mark.parametrize("dim", [2048])
|
| 123 |
-
def test_causal_conv1d_update(dim, width, seqlen, has_cache_seqlens, has_bias, silu_activation, itype):
|
| 124 |
-
device = "cuda"
|
| 125 |
-
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
| 126 |
-
if itype == torch.bfloat16:
|
| 127 |
-
rtol, atol = 1e-2, 5e-2
|
| 128 |
-
rtolw, atolw = (1e-3, 1e-3)
|
| 129 |
-
# set seed
|
| 130 |
-
torch.random.manual_seed(0)
|
| 131 |
-
batch = 64
|
| 132 |
-
# batch = 1
|
| 133 |
-
# dim = 64
|
| 134 |
-
x = torch.randn(batch, seqlen, dim, device=device, dtype=itype).transpose(-1, -2)
|
| 135 |
-
state_len = torch.randint(width - 1, width + 10, (1,)).item()
|
| 136 |
-
conv_state = torch.randn(batch, state_len, dim, device=device, dtype=itype).transpose(-1, -2)
|
| 137 |
-
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
| 138 |
-
if has_bias:
|
| 139 |
-
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
| 140 |
-
else:
|
| 141 |
-
bias = None
|
| 142 |
-
conv_state_ref = conv_state.detach().clone()
|
| 143 |
-
activation = None if not silu_activation else "silu"
|
| 144 |
-
cache_seqlens = (torch.randint(0, 1024, (batch,), dtype=torch.int32, device=device)
|
| 145 |
-
if has_cache_seqlens else None)
|
| 146 |
-
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
|
| 147 |
-
out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
|
| 148 |
-
|
| 149 |
-
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
| 150 |
-
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
| 151 |
-
assert torch.equal(conv_state, conv_state_ref)
|
| 152 |
-
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
| 153 |
-
|
| 154 |
-
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 155 |
-
# @pytest.mark.parametrize('itype', [torch.float16])
|
| 156 |
-
@pytest.mark.parametrize("silu_activation", [False, True])
|
| 157 |
-
# @pytest.mark.parametrize('silu_activation', [True])
|
| 158 |
-
@pytest.mark.parametrize("has_bias", [False, True])
|
| 159 |
-
# @pytest.mark.parametrize('has_bias', [True])
|
| 160 |
-
@pytest.mark.parametrize("has_cache_seqlens", [False, True])
|
| 161 |
-
# @pytest.mark.parametrize('has_cache_seqlens', [True])
|
| 162 |
-
@pytest.mark.parametrize("seqlen", [1, 4, 5])
|
| 163 |
-
# @pytest.mark.parametrize('seqlen', [4])
|
| 164 |
-
@pytest.mark.parametrize("width", [2, 3, 4])
|
| 165 |
-
# @pytest.mark.parametrize('width', [4])
|
| 166 |
-
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
| 167 |
-
# @pytest.mark.parametrize("dim", [2048])
|
| 168 |
-
def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_cache_seqlens, has_bias, silu_activation, itype):
|
| 169 |
-
device = "cuda"
|
| 170 |
-
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
| 171 |
-
if itype == torch.bfloat16:
|
| 172 |
-
rtol, atol = 1e-2, 5e-2
|
| 173 |
-
rtolw, atolw = (1e-3, 1e-3)
|
| 174 |
-
# set seed
|
| 175 |
-
torch.random.manual_seed(0)
|
| 176 |
-
batch = 64
|
| 177 |
-
# batch = 1
|
| 178 |
-
# dim = 64
|
| 179 |
-
x = torch.randn(batch, seqlen, dim, device=device, dtype=itype).transpose(-1, -2)
|
| 180 |
-
state_len = torch.randint(width - 1, width + 10, (1,)).item()
|
| 181 |
-
|
| 182 |
-
total_entries = 10 * batch
|
| 183 |
-
conv_state = torch.randn(total_entries, state_len, dim, device=device, dtype=itype).transpose(-1, -2)
|
| 184 |
-
conv_state_indices = torch.randperm(total_entries)[:batch].to(dtype=torch.int32, device=device)
|
| 185 |
-
|
| 186 |
-
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
| 187 |
-
if has_bias:
|
| 188 |
-
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
| 189 |
-
else:
|
| 190 |
-
bias = None
|
| 191 |
-
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
|
| 192 |
-
activation = None if not silu_activation else "silu"
|
| 193 |
-
cache_seqlens = (torch.randint(0, 1024, (batch,), dtype=torch.int32, device=device)
|
| 194 |
-
if has_cache_seqlens else None)
|
| 195 |
-
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation,
|
| 196 |
-
cache_seqlens=cache_seqlens, conv_state_indices=conv_state_indices)
|
| 197 |
-
out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
|
| 198 |
-
|
| 199 |
-
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
| 200 |
-
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
| 201 |
-
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
|
| 202 |
-
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 206 |
-
# @pytest.mark.parametrize('itype', [torch.float16])
|
| 207 |
-
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
| 208 |
-
# @pytest.mark.parametrize("dim", [2048])
|
| 209 |
-
def test_causal_conv1d_get_states(dim, itype):
|
| 210 |
-
device = "cuda"
|
| 211 |
-
# set seed
|
| 212 |
-
torch.random.manual_seed(0)
|
| 213 |
-
seqlens = torch.randint(1, 32, (100,), device=device)
|
| 214 |
-
total_seqlen = seqlens.sum().item()
|
| 215 |
-
x = torch.randn(total_seqlen, dim, device=device, dtype=itype)
|
| 216 |
-
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))
|
| 217 |
-
state_len = 20
|
| 218 |
-
out = causal_conv1d_varlen_states(x, cu_seqlens, state_len)
|
| 219 |
-
out_ref = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len)
|
| 220 |
-
assert torch.equal(out, out_ref)
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
# @pytest.mark.parametrize("channel_last", [False, True])
|
| 224 |
-
@pytest.mark.parametrize('channel_last', [True])
|
| 225 |
-
# @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 226 |
-
@pytest.mark.parametrize('itype', [torch.bfloat16])
|
| 227 |
-
# @pytest.mark.parametrize("silu_activation", [False, True])
|
| 228 |
-
@pytest.mark.parametrize('silu_activation', [True])
|
| 229 |
-
# @pytest.mark.parametrize("has_bias", [False, True])
|
| 230 |
-
@pytest.mark.parametrize('has_bias', [True])
|
| 231 |
-
# @pytest.mark.parametrize("width", [2, 3, 4])
|
| 232 |
-
@pytest.mark.parametrize('width', [4])
|
| 233 |
-
@pytest.mark.parametrize(
|
| 234 |
-
# "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
|
| 235 |
-
"seqlen", [2048]
|
| 236 |
-
)
|
| 237 |
-
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
|
| 238 |
-
# @pytest.mark.parametrize('seqlen', [128])
|
| 239 |
-
def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
|
| 240 |
-
device = "cuda"
|
| 241 |
-
# set seed
|
| 242 |
-
torch.random.manual_seed(0)
|
| 243 |
-
batch = 2
|
| 244 |
-
# batch = 1
|
| 245 |
-
dim = 4096 + 32 # Try dim not divisible by 64
|
| 246 |
-
# dim = 64
|
| 247 |
-
if not channel_last:
|
| 248 |
-
x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
|
| 249 |
-
else:
|
| 250 |
-
x = rearrange(
|
| 251 |
-
torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
|
| 252 |
-
).requires_grad_()
|
| 253 |
-
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
| 254 |
-
if has_bias:
|
| 255 |
-
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
| 256 |
-
else:
|
| 257 |
-
bias = None
|
| 258 |
-
activation = None if not silu_activation else "silu"
|
| 259 |
-
out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
|
| 260 |
-
g = torch.randn_like(out0)
|
| 261 |
-
dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
|
| 262 |
-
dw_atol = 1e-4
|
| 263 |
-
db_atol = 1e-4
|
| 264 |
-
|
| 265 |
-
for i in range(10000):
|
| 266 |
-
out = causal_conv1d_fn(x, weight, bias, activation=activation)
|
| 267 |
-
dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
|
| 268 |
-
dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
|
| 269 |
-
# if not dw_equal:
|
| 270 |
-
# breakpoint()
|
| 271 |
-
if has_bias:
|
| 272 |
-
db_equal = torch.allclose(db, db0, atol=db_atol)
|
| 273 |
-
# if not db_equal:
|
| 274 |
-
# breakpoint()
|
| 275 |
-
assert torch.equal(out, out0)
|
| 276 |
-
assert torch.equal(dx, dx0)
|
| 277 |
-
assert dw_equal
|
| 278 |
-
if has_bias:
|
| 279 |
-
assert dw_equal
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 283 |
-
# @pytest.mark.parametrize('itype', [torch.float16])
|
| 284 |
-
@pytest.mark.parametrize("silu_activation", [False, True])
|
| 285 |
-
# @pytest.mark.parametrize('silu_activation', [False])
|
| 286 |
-
@pytest.mark.parametrize("has_bias", [False, True])
|
| 287 |
-
# @pytest.mark.parametrize('has_bias', [False])
|
| 288 |
-
@pytest.mark.parametrize("width", [2, 3, 4])
|
| 289 |
-
# @pytest.mark.parametrize('width', [2])
|
| 290 |
-
@pytest.mark.parametrize(
|
| 291 |
-
"seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
|
| 292 |
-
)
|
| 293 |
-
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
|
| 294 |
-
# @pytest.mark.parametrize('seqlen', [2048])
|
| 295 |
-
@pytest.mark.parametrize('dim', [64, 4096 + 32])
|
| 296 |
-
# @pytest.mark.parametrize('dim', [64])
|
| 297 |
-
def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, itype):
|
| 298 |
-
device = "cuda"
|
| 299 |
-
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
| 300 |
-
if itype == torch.bfloat16:
|
| 301 |
-
rtol, atol = 1e-2, 5e-2
|
| 302 |
-
rtolw, atolw = (1e-3, 1e-3)
|
| 303 |
-
# set seed
|
| 304 |
-
torch.random.manual_seed(seqlen + dim + width)
|
| 305 |
-
batch = 3
|
| 306 |
-
seqlens = []
|
| 307 |
-
for b in range(batch):
|
| 308 |
-
nsplits = torch.randint(1, 5, (1,)).item()
|
| 309 |
-
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
| 310 |
-
seqlens.append(torch.diff(torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])).tolist())
|
| 311 |
-
assert sum(seqlens[-1]) == seqlen
|
| 312 |
-
assert all(s > 0 for s in seqlens[-1])
|
| 313 |
-
# Only support channel_last
|
| 314 |
-
x = rearrange(
|
| 315 |
-
torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
|
| 316 |
-
).requires_grad_()
|
| 317 |
-
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
| 318 |
-
if has_bias:
|
| 319 |
-
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
| 320 |
-
else:
|
| 321 |
-
bias = None
|
| 322 |
-
seq_idx = torch.stack([torch.cat([torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(sl)], dim=0)
|
| 323 |
-
for sl in seqlens], dim=0)
|
| 324 |
-
x_ref = x.detach().clone().requires_grad_()
|
| 325 |
-
weight_ref = weight.detach().clone().requires_grad_()
|
| 326 |
-
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
|
| 327 |
-
activation = None if not silu_activation else "silu"
|
| 328 |
-
out = causal_conv1d_fn(x, weight, bias, seq_idx=seq_idx, activation=activation)
|
| 329 |
-
out_ref = []
|
| 330 |
-
for b in range(batch):
|
| 331 |
-
out_ref_b = []
|
| 332 |
-
for x_s in torch.split(x_ref[[b]], seqlens[b], dim=2):
|
| 333 |
-
out_ref_b.append(causal_conv1d_ref(x_s, weight_ref, bias_ref, activation=activation))
|
| 334 |
-
out_ref.append(torch.cat(out_ref_b, dim=2))
|
| 335 |
-
out_ref = torch.cat(out_ref, dim=0)
|
| 336 |
-
|
| 337 |
-
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
| 338 |
-
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
| 339 |
-
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
| 340 |
-
|
| 341 |
-
g = torch.randn_like(out)
|
| 342 |
-
out_ref.backward(g)
|
| 343 |
-
out.backward(g)
|
| 344 |
-
|
| 345 |
-
print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
|
| 346 |
-
print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
|
| 347 |
-
if has_bias:
|
| 348 |
-
print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
|
| 349 |
-
|
| 350 |
-
assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
|
| 351 |
-
assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
|
| 352 |
-
if has_bias:
|
| 353 |
-
assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch-ext/causal_conv1d/__init__.py
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
|
| 2 |
-
from .causal_conv1d_varlen import causal_conv1d_varlen_states
|
| 3 |
-
|
| 4 |
-
__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch-ext/causal_conv1d/causal_conv1d_interface.py
DELETED
|
@@ -1,242 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2024, Tri Dao.
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
|
| 6 |
-
from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class CausalConv1dFn(torch.autograd.Function):
|
| 10 |
-
@staticmethod
|
| 11 |
-
def forward(
|
| 12 |
-
ctx,
|
| 13 |
-
x,
|
| 14 |
-
weight,
|
| 15 |
-
bias=None,
|
| 16 |
-
seq_idx=None,
|
| 17 |
-
initial_states=None,
|
| 18 |
-
return_final_states=False,
|
| 19 |
-
final_states_out=None,
|
| 20 |
-
activation=None,
|
| 21 |
-
):
|
| 22 |
-
if activation not in [None, "silu", "swish"]:
|
| 23 |
-
raise NotImplementedError("activation must be None, silu, or swish")
|
| 24 |
-
if x.stride(2) != 1 and x.stride(1) != 1:
|
| 25 |
-
x = x.contiguous()
|
| 26 |
-
bias = bias.contiguous() if bias is not None else None
|
| 27 |
-
if seq_idx is not None:
|
| 28 |
-
assert (
|
| 29 |
-
initial_states is None
|
| 30 |
-
), "initial_states must be None if seq_idx is not None"
|
| 31 |
-
assert (
|
| 32 |
-
not return_final_states
|
| 33 |
-
), "If seq_idx is not None, we don't return final_states_out"
|
| 34 |
-
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
|
| 35 |
-
if initial_states is not None and (
|
| 36 |
-
initial_states.stride(2) != 1 and initial_states.stride(1) != 1
|
| 37 |
-
):
|
| 38 |
-
initial_states = initial_states.contiguous()
|
| 39 |
-
if return_final_states:
|
| 40 |
-
assert (
|
| 41 |
-
x.stride(1) == 1
|
| 42 |
-
), "Only channel-last layout support returning final_states_out"
|
| 43 |
-
if final_states_out is not None:
|
| 44 |
-
assert (
|
| 45 |
-
final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
|
| 46 |
-
)
|
| 47 |
-
else:
|
| 48 |
-
batch, dim, seqlen = x.shape
|
| 49 |
-
width = weight.shape[1]
|
| 50 |
-
final_states_out = torch.empty(
|
| 51 |
-
batch, width - 1, dim, device=x.device, dtype=x.dtype
|
| 52 |
-
).transpose(1, 2)
|
| 53 |
-
else:
|
| 54 |
-
final_states_out = None
|
| 55 |
-
ctx.activation = activation in ["silu", "swish"]
|
| 56 |
-
out = causal_conv1d_fwd_function(
|
| 57 |
-
x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
|
| 58 |
-
)
|
| 59 |
-
ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
|
| 60 |
-
ctx.return_final_states = return_final_states
|
| 61 |
-
ctx.return_dinitial_states = (
|
| 62 |
-
initial_states is not None and initial_states.requires_grad
|
| 63 |
-
)
|
| 64 |
-
return out if not return_final_states else (out, final_states_out)
|
| 65 |
-
|
| 66 |
-
@staticmethod
|
| 67 |
-
def backward(ctx, dout, *args):
|
| 68 |
-
x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
|
| 69 |
-
dfinal_states = args[0] if ctx.return_final_states else None
|
| 70 |
-
if dout.stride(2) != 1 and dout.stride(1) != 1:
|
| 71 |
-
dout = dout.contiguous()
|
| 72 |
-
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
| 73 |
-
# backward of conv1d with the backward of chunk).
|
| 74 |
-
# Here we just pass in None and dx will be allocated in the C++ code.
|
| 75 |
-
dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
|
| 76 |
-
x,
|
| 77 |
-
weight,
|
| 78 |
-
bias,
|
| 79 |
-
dout,
|
| 80 |
-
seq_idx,
|
| 81 |
-
initial_states,
|
| 82 |
-
dfinal_states,
|
| 83 |
-
None,
|
| 84 |
-
ctx.return_dinitial_states,
|
| 85 |
-
ctx.activation,
|
| 86 |
-
)
|
| 87 |
-
return (
|
| 88 |
-
dx,
|
| 89 |
-
dweight,
|
| 90 |
-
dbias if bias is not None else None,
|
| 91 |
-
None,
|
| 92 |
-
dinitial_states if initial_states is not None else None,
|
| 93 |
-
None,
|
| 94 |
-
None,
|
| 95 |
-
None,
|
| 96 |
-
)
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
def causal_conv1d_fn(
|
| 100 |
-
x,
|
| 101 |
-
weight,
|
| 102 |
-
bias=None,
|
| 103 |
-
seq_idx=None,
|
| 104 |
-
initial_states=None,
|
| 105 |
-
return_final_states=False,
|
| 106 |
-
final_states_out=None,
|
| 107 |
-
activation=None,
|
| 108 |
-
):
|
| 109 |
-
"""
|
| 110 |
-
x: (batch, dim, seqlen)
|
| 111 |
-
weight: (dim, width)
|
| 112 |
-
bias: (dim,)
|
| 113 |
-
seq_idx: (batch, seqlen)
|
| 114 |
-
initial_states: (batch, dim, width - 1)
|
| 115 |
-
final_states_out: (batch, dim, width - 1), to be written to
|
| 116 |
-
activation: either None or "silu" or "swish"
|
| 117 |
-
|
| 118 |
-
out: (batch, dim, seqlen)
|
| 119 |
-
"""
|
| 120 |
-
return CausalConv1dFn.apply(
|
| 121 |
-
x,
|
| 122 |
-
weight,
|
| 123 |
-
bias,
|
| 124 |
-
seq_idx,
|
| 125 |
-
initial_states,
|
| 126 |
-
return_final_states,
|
| 127 |
-
final_states_out,
|
| 128 |
-
activation,
|
| 129 |
-
)
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def causal_conv1d_ref(
|
| 133 |
-
x,
|
| 134 |
-
weight,
|
| 135 |
-
bias=None,
|
| 136 |
-
initial_states=None,
|
| 137 |
-
return_final_states=False,
|
| 138 |
-
final_states_out=None,
|
| 139 |
-
activation=None,
|
| 140 |
-
):
|
| 141 |
-
"""
|
| 142 |
-
x: (batch, dim, seqlen)
|
| 143 |
-
weight: (dim, width)
|
| 144 |
-
bias: (dim,)
|
| 145 |
-
initial_states: (batch, dim, width - 1)
|
| 146 |
-
final_states_out: (batch, dim, width - 1)
|
| 147 |
-
|
| 148 |
-
out: (batch, dim, seqlen)
|
| 149 |
-
"""
|
| 150 |
-
if activation not in [None, "silu", "swish"]:
|
| 151 |
-
raise NotImplementedError("activation must be None, silu, or swish")
|
| 152 |
-
dtype_in = x.dtype
|
| 153 |
-
x = x.to(weight.dtype)
|
| 154 |
-
seqlen = x.shape[-1]
|
| 155 |
-
dim, width = weight.shape
|
| 156 |
-
if initial_states is None:
|
| 157 |
-
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
|
| 158 |
-
else:
|
| 159 |
-
x = torch.cat([initial_states, x], dim=-1)
|
| 160 |
-
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
| 161 |
-
out = out[..., :seqlen]
|
| 162 |
-
if return_final_states:
|
| 163 |
-
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
| 164 |
-
dtype_in
|
| 165 |
-
) # (batch, dim, width - 1)
|
| 166 |
-
if final_states_out is not None:
|
| 167 |
-
final_states_out.copy_(final_states)
|
| 168 |
-
else:
|
| 169 |
-
final_states_out = final_states
|
| 170 |
-
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
| 171 |
-
return out if not return_final_states else (out, final_states_out)
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
|
| 175 |
-
"""
|
| 176 |
-
x: (batch, dim) or (batch, dim, seqlen)
|
| 177 |
-
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
| 178 |
-
weight: (dim, width)
|
| 179 |
-
bias: (dim,)
|
| 180 |
-
cache_seqlens: (batch,), dtype int32.
|
| 181 |
-
If not None, the conv_state is treated as a circular buffer.
|
| 182 |
-
The conv_state will be updated by copying x to the conv_state starting at the index
|
| 183 |
-
@cache_seqlens % state_len.
|
| 184 |
-
conv_state_indices: (batch,), dtype int32
|
| 185 |
-
If None, the conv_state is a larger tensor along the batch dim,
|
| 186 |
-
and we are selecting the batch coords specified by conv_state_indices.
|
| 187 |
-
Useful for a continuous batching scenario.
|
| 188 |
-
|
| 189 |
-
out: (batch, dim) or (batch, dim, seqlen)
|
| 190 |
-
"""
|
| 191 |
-
if activation not in [None, "silu", "swish"]:
|
| 192 |
-
raise NotImplementedError("activation must be None, silu, or swish")
|
| 193 |
-
activation = activation in ["silu", "swish"]
|
| 194 |
-
unsqueeze = x.dim() == 2
|
| 195 |
-
if unsqueeze:
|
| 196 |
-
x = x.unsqueeze(-1)
|
| 197 |
-
out = causal_conv1d_update_function(
|
| 198 |
-
x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
|
| 199 |
-
)
|
| 200 |
-
if unsqueeze:
|
| 201 |
-
out = out.squeeze(-1)
|
| 202 |
-
return out
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
|
| 206 |
-
"""
|
| 207 |
-
x: (batch, dim) or (batch, dim, seqlen)
|
| 208 |
-
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
| 209 |
-
weight: (dim, width)
|
| 210 |
-
bias: (dim,)
|
| 211 |
-
cache_seqlens: (batch,), dtype int32.
|
| 212 |
-
If not None, the conv_state is treated as a circular buffer.
|
| 213 |
-
The conv_state will be updated by copying x to the conv_state starting at the index
|
| 214 |
-
@cache_seqlens % state_len before performing the convolution.
|
| 215 |
-
|
| 216 |
-
out: (batch, dim) or (batch, dim, seqlen)
|
| 217 |
-
"""
|
| 218 |
-
if activation not in [None, "silu", "swish"]:
|
| 219 |
-
raise NotImplementedError("activation must be None, silu, or swish")
|
| 220 |
-
dtype_in = x.dtype
|
| 221 |
-
unsqueeze = x.dim() == 2
|
| 222 |
-
if unsqueeze:
|
| 223 |
-
x = x.unsqueeze(-1)
|
| 224 |
-
batch, dim, seqlen = x.shape
|
| 225 |
-
width = weight.shape[1]
|
| 226 |
-
state_len = conv_state.shape[-1]
|
| 227 |
-
assert conv_state.shape == (batch, dim, state_len)
|
| 228 |
-
assert weight.shape == (dim, width)
|
| 229 |
-
if cache_seqlens is None:
|
| 230 |
-
x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
|
| 231 |
-
conv_state.copy_(x_new[:, :, -state_len:])
|
| 232 |
-
else:
|
| 233 |
-
width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
|
| 234 |
-
width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
| 235 |
-
x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
|
| 236 |
-
copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
|
| 237 |
-
copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
| 238 |
-
conv_state.scatter_(2, copy_idx, x)
|
| 239 |
-
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
|
| 240 |
-
if unsqueeze:
|
| 241 |
-
out = out.squeeze(-1)
|
| 242 |
-
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch-ext/causal_conv1d/causal_conv1d_varlen.py
DELETED
|
@@ -1,86 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from torch import Tensor
|
| 3 |
-
|
| 4 |
-
import triton
|
| 5 |
-
import triton.language as tl
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
@triton.jit
|
| 9 |
-
def _causal_conv1d_varlen_states(
|
| 10 |
-
X,
|
| 11 |
-
CU_SEQLENS,
|
| 12 |
-
STATES,
|
| 13 |
-
state_len,
|
| 14 |
-
dim,
|
| 15 |
-
stride_x_seqlen, stride_x_dim,
|
| 16 |
-
stride_states_batch, stride_states_seqlen, stride_states_dim,
|
| 17 |
-
BLOCK_M: tl.constexpr,
|
| 18 |
-
BLOCK_N: tl.constexpr
|
| 19 |
-
):
|
| 20 |
-
batch_idx = tl.program_id(2)
|
| 21 |
-
STATES += batch_idx * stride_states_batch
|
| 22 |
-
end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
|
| 23 |
-
start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
|
| 24 |
-
rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 25 |
-
cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 26 |
-
x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
|
| 27 |
-
mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
|
| 28 |
-
other=0)
|
| 29 |
-
rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 30 |
-
tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
|
| 31 |
-
x,
|
| 32 |
-
mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
|
| 36 |
-
"""
|
| 37 |
-
Forward pass only, does not support backward pass.
|
| 38 |
-
Parameters:
|
| 39 |
-
x: (total_tokens, dim)
|
| 40 |
-
cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
|
| 41 |
-
state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
|
| 42 |
-
If some of those elements belong to a different sequence, the value of the states will be zero.
|
| 43 |
-
Return:
|
| 44 |
-
states: (batch, dim, state_len)
|
| 45 |
-
"""
|
| 46 |
-
_, dim = x.shape
|
| 47 |
-
batch = cu_seqlens.shape[0] - 1
|
| 48 |
-
cu_seqlens = cu_seqlens.contiguous()
|
| 49 |
-
states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
|
| 50 |
-
BLOCK_M = min(triton.next_power_of_2(state_len), 16)
|
| 51 |
-
BLOCK_N = min(triton.next_power_of_2(dim), 256)
|
| 52 |
-
grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
|
| 53 |
-
with torch.cuda.device(x.device.index):
|
| 54 |
-
_causal_conv1d_varlen_states[grid](
|
| 55 |
-
x,
|
| 56 |
-
cu_seqlens,
|
| 57 |
-
states,
|
| 58 |
-
state_len,
|
| 59 |
-
dim,
|
| 60 |
-
x.stride(0), x.stride(1),
|
| 61 |
-
states.stride(0), states.stride(2), states.stride(1),
|
| 62 |
-
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
|
| 63 |
-
)
|
| 64 |
-
return states
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
|
| 68 |
-
"""
|
| 69 |
-
Forward pass only, does not support backward pass.
|
| 70 |
-
Parameters:
|
| 71 |
-
x: (total_tokens, dim)
|
| 72 |
-
cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
|
| 73 |
-
state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
|
| 74 |
-
If some of those elements belong to a different sequence, the value of the states will be zero.
|
| 75 |
-
Return:
|
| 76 |
-
states: (batch, dim, state_len)
|
| 77 |
-
"""
|
| 78 |
-
_, dim = x.shape
|
| 79 |
-
batch = cu_seqlens.shape[0] - 1
|
| 80 |
-
cu_seqlens = cu_seqlens.contiguous()
|
| 81 |
-
states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
|
| 82 |
-
for i in range(batch):
|
| 83 |
-
end_idx = cu_seqlens[i + 1]
|
| 84 |
-
start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
|
| 85 |
-
states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
|
| 86 |
-
return states
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch-ext/causal_conv1d/cpp_functions.py
DELETED
|
@@ -1,96 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2024, Tri Dao.
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
from ._ops import ops
|
| 6 |
-
|
| 7 |
-
def causal_conv1d_fwd_function(
|
| 8 |
-
x: torch.Tensor,
|
| 9 |
-
weight: torch.Tensor,
|
| 10 |
-
bias: torch.Tensor | None,
|
| 11 |
-
seq_idx: torch.Tensor | None,
|
| 12 |
-
initial_states: torch.Tensor | None,
|
| 13 |
-
final_states_out: torch.Tensor | None,
|
| 14 |
-
silu_activation: bool,
|
| 15 |
-
) -> torch.Tensor:
|
| 16 |
-
out = torch.empty_like(x)
|
| 17 |
-
ops.causal_conv1d_fwd(
|
| 18 |
-
x=x,
|
| 19 |
-
weight=weight,
|
| 20 |
-
bias=bias,
|
| 21 |
-
seq_idx=seq_idx,
|
| 22 |
-
initial_states=initial_states,
|
| 23 |
-
out=out,
|
| 24 |
-
final_states_out=final_states_out,
|
| 25 |
-
silu_activation=silu_activation,
|
| 26 |
-
)
|
| 27 |
-
return out
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def causal_conv1d_bwd_function(
|
| 31 |
-
x: torch.Tensor,
|
| 32 |
-
weight: torch.Tensor,
|
| 33 |
-
bias: torch.Tensor | None,
|
| 34 |
-
dout: torch.Tensor,
|
| 35 |
-
seq_idx: torch.Tensor | None,
|
| 36 |
-
initial_states: torch.Tensor | None,
|
| 37 |
-
dfinal_states: torch.Tensor | None,
|
| 38 |
-
dx: torch.Tensor | None,
|
| 39 |
-
return_dinitial_states: torch.Tensor,
|
| 40 |
-
silu_activation: bool,
|
| 41 |
-
) -> tuple[torch.Tensor | None]:
|
| 42 |
-
batch_size, dim = x.size()[:2]
|
| 43 |
-
width = weight.size(-1)
|
| 44 |
-
|
| 45 |
-
if dx is None:
|
| 46 |
-
dx = torch.empty_like(x)
|
| 47 |
-
dweight = torch.zeros_like(weight, dtype=torch.float32)
|
| 48 |
-
dbias = None
|
| 49 |
-
if bias is not None:
|
| 50 |
-
dbias = torch.zeros_like(bias, dtype=torch.float32)
|
| 51 |
-
dinitial_states = None
|
| 52 |
-
if return_dinitial_states:
|
| 53 |
-
dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
|
| 54 |
-
|
| 55 |
-
ops.causal_conv1d_bwd(
|
| 56 |
-
x=x,
|
| 57 |
-
weight=weight,
|
| 58 |
-
bias=bias,
|
| 59 |
-
dout=dout,
|
| 60 |
-
seq_idx=seq_idx,
|
| 61 |
-
initial_states=initial_states,
|
| 62 |
-
dfinal_states=dfinal_states,
|
| 63 |
-
dx=dx,
|
| 64 |
-
dweight=dweight,
|
| 65 |
-
dbias=dbias,
|
| 66 |
-
dinitial_states=dinitial_states,
|
| 67 |
-
silu_activation=silu_activation,
|
| 68 |
-
)
|
| 69 |
-
|
| 70 |
-
dweight = dweight.type_as(weight)
|
| 71 |
-
if dbias is not None:
|
| 72 |
-
dbias = dbias.type_as(bias)
|
| 73 |
-
return dx, dweight, dbias, dinitial_states
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def causal_conv1d_update_function(
|
| 77 |
-
x: torch.Tensor,
|
| 78 |
-
conv_state: torch.Tensor,
|
| 79 |
-
weight: torch.Tensor,
|
| 80 |
-
bias: torch.Tensor | None,
|
| 81 |
-
silu_activation: bool,
|
| 82 |
-
cache_seqlens: torch.Tensor | None,
|
| 83 |
-
conv_state_indices: torch.Tensor | None,
|
| 84 |
-
) -> torch.Tensor:
|
| 85 |
-
out = torch.empty_like(x)
|
| 86 |
-
ops.causal_conv1d_update(
|
| 87 |
-
x=x,
|
| 88 |
-
conv_state=conv_state,
|
| 89 |
-
weight=weight,
|
| 90 |
-
bias=bias,
|
| 91 |
-
out=out,
|
| 92 |
-
silu_activation=silu_activation,
|
| 93 |
-
cache_seqlens=cache_seqlens,
|
| 94 |
-
conv_state_indices=conv_state_indices,
|
| 95 |
-
)
|
| 96 |
-
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch-ext/pytorch_shim.h
DELETED
|
@@ -1,105 +0,0 @@
|
|
| 1 |
-
#pragma once
|
| 2 |
-
|
| 3 |
-
#include <torch/library.h>
|
| 4 |
-
|
| 5 |
-
/**
|
| 6 |
-
* Unforunately, the type signatures of the flash_attn ops are not compatible
|
| 7 |
-
* with the PyTorch library bindings. To get around that we use
|
| 8 |
-
* `make_pytorch_shim` which creates a lambda that exponses the API using
|
| 9 |
-
* PyTorch compatible types to the types, then converts them to the types
|
| 10 |
-
* expected by the flash_attn ops. This shims allows us to make minimal changes
|
| 11 |
-
* to `flash_api.cpp` making it easier to synchronize with upstream changes.
|
| 12 |
-
*
|
| 13 |
-
* The `pytorch_library_compatible_type` struct is used to map from the
|
| 14 |
-
* flash_attn ops types to a PyTorch library compatible one. The main issues is
|
| 15 |
-
* that the following types are not support by PyTorch libary bindings:
|
| 16 |
-
* - `int`
|
| 17 |
-
* - `float`
|
| 18 |
-
* - `std::optional<T> &`
|
| 19 |
-
* - `std::optional<const at::Tensor> &`
|
| 20 |
-
* So we convert them to (respectively):
|
| 21 |
-
* - `int64_t`
|
| 22 |
-
* - `double`
|
| 23 |
-
* - `const std::optional<T>&`
|
| 24 |
-
* - `const std::optional<at::Tensor>&`
|
| 25 |
-
*/
|
| 26 |
-
|
| 27 |
-
template<typename T>
|
| 28 |
-
struct pytorch_library_compatible_type {
|
| 29 |
-
using type = T;
|
| 30 |
-
static T convert_from_type(T arg) { return arg; }
|
| 31 |
-
};
|
| 32 |
-
|
| 33 |
-
template<typename T>
|
| 34 |
-
using pytorch_library_compatible_type_t = \
|
| 35 |
-
typename pytorch_library_compatible_type<T>::type;
|
| 36 |
-
|
| 37 |
-
template<typename T>
|
| 38 |
-
T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t<T> arg)
|
| 39 |
-
{ return pytorch_library_compatible_type<T>::convert_from_type(arg); }
|
| 40 |
-
|
| 41 |
-
// Map `std::optional<T> &` -> `const std::optional<T>&`
|
| 42 |
-
// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate
|
| 43 |
-
// the optional container)
|
| 44 |
-
template<typename T>
|
| 45 |
-
struct pytorch_library_compatible_type<std::optional<T> &> {
|
| 46 |
-
using type = const std::optional<T>&;
|
| 47 |
-
static std::optional<T>& convert_from_type(const std::optional<T> &arg) {
|
| 48 |
-
return const_cast<std::optional<T>&>(arg);
|
| 49 |
-
}
|
| 50 |
-
};
|
| 51 |
-
|
| 52 |
-
// Map `std::optional<T>` ->
|
| 53 |
-
// `std::optional<pytorch_library_compatible_type_t<T>>`
|
| 54 |
-
// (NOTE: tested for `std::optional<int>` -> `std::optional<int64_t>`)
|
| 55 |
-
template<typename T>
|
| 56 |
-
struct pytorch_library_compatible_type<std::optional<T>> {
|
| 57 |
-
using type = std::optional<pytorch_library_compatible_type_t<T>>;
|
| 58 |
-
static std::optional<pytorch_library_compatible_type_t<T>> convert_from_type(std::optional<T> arg) {
|
| 59 |
-
return arg;
|
| 60 |
-
}
|
| 61 |
-
};
|
| 62 |
-
|
| 63 |
-
// Map `std::optional<const at::Tensor>&` -> `const std::optional<at::Tensor>&`
|
| 64 |
-
template<>
|
| 65 |
-
struct pytorch_library_compatible_type<std::optional<const at::Tensor> &> {
|
| 66 |
-
using type = const std::optional<at::Tensor>&;
|
| 67 |
-
static std::optional<const at::Tensor>& convert_from_type(
|
| 68 |
-
const std::optional<at::Tensor> &arg) {
|
| 69 |
-
return const_cast<std::optional<const at::Tensor>&>(
|
| 70 |
-
reinterpret_cast<const std::optional<const at::Tensor>&>(arg));
|
| 71 |
-
}
|
| 72 |
-
};
|
| 73 |
-
|
| 74 |
-
// Map `int` -> `int64_t`
|
| 75 |
-
template<> struct pytorch_library_compatible_type<int> {
|
| 76 |
-
using type = int64_t;
|
| 77 |
-
static int convert_from_type(int64_t arg) {
|
| 78 |
-
TORCH_CHECK(arg <= std::numeric_limits<int>::max(),
|
| 79 |
-
"int64_t value is too large to be converted to int");
|
| 80 |
-
TORCH_CHECK(arg >= std::numeric_limits<int>::min(),
|
| 81 |
-
"int64_t value is too small to be converted to int");
|
| 82 |
-
return arg;
|
| 83 |
-
}
|
| 84 |
-
};
|
| 85 |
-
|
| 86 |
-
// Map `float` -> `double`
|
| 87 |
-
template<> struct pytorch_library_compatible_type<float> {
|
| 88 |
-
using type = double;
|
| 89 |
-
static float convert_from_type(double arg) {
|
| 90 |
-
TORCH_CHECK(std::abs(arg) <= std::numeric_limits<float>::max(),
|
| 91 |
-
"double value is too large to be converted to float");
|
| 92 |
-
return arg;
|
| 93 |
-
}
|
| 94 |
-
};
|
| 95 |
-
|
| 96 |
-
//
|
| 97 |
-
// Shim Utils
|
| 98 |
-
//
|
| 99 |
-
|
| 100 |
-
template <typename Ret, typename... Args>
|
| 101 |
-
auto make_pytorch_shim(Ret(*fun)(Args... args)){
|
| 102 |
-
return [fun](pytorch_library_compatible_type_t<Args>... args) {
|
| 103 |
-
return fun(convert_from_pytorch_compatible_type<Args>(args)...);
|
| 104 |
-
};
|
| 105 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch-ext/torch_binding.cpp
DELETED
|
@@ -1,32 +0,0 @@
|
|
| 1 |
-
#include <torch/library.h>
|
| 2 |
-
|
| 3 |
-
#include "registration.h"
|
| 4 |
-
|
| 5 |
-
#include "pytorch_shim.h"
|
| 6 |
-
#include "torch_binding.h"
|
| 7 |
-
|
| 8 |
-
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
| 9 |
-
ops.def(
|
| 10 |
-
"causal_conv1d_fwd("
|
| 11 |
-
" Tensor x, Tensor weight, Tensor? bias, Tensor? seq_idx,"
|
| 12 |
-
" Tensor? initial_states, Tensor! out, Tensor!? final_states_out,"
|
| 13 |
-
" bool silu_activation) -> ()");
|
| 14 |
-
ops.impl("causal_conv1d_fwd", torch::kCUDA, make_pytorch_shim(&causal_conv1d_fwd));
|
| 15 |
-
|
| 16 |
-
ops.def(
|
| 17 |
-
"causal_conv1d_bwd("
|
| 18 |
-
" Tensor x, Tensor weight, Tensor? bias, Tensor! dout,"
|
| 19 |
-
" Tensor? seq_idx, Tensor? initial_states, Tensor? dfinal_states,"
|
| 20 |
-
" Tensor! dx, Tensor! dweight, Tensor!? dbias,"
|
| 21 |
-
" Tensor!? dinitial_states, bool silu_activation) -> ()");
|
| 22 |
-
ops.impl("causal_conv1d_bwd", torch::kCUDA, make_pytorch_shim(&causal_conv1d_bwd));
|
| 23 |
-
|
| 24 |
-
ops.def(
|
| 25 |
-
"causal_conv1d_update("
|
| 26 |
-
" Tensor x, Tensor conv_state, Tensor weight, Tensor? bias,"
|
| 27 |
-
" Tensor! out, bool silu_activation, Tensor? cache_seqlens,"
|
| 28 |
-
" Tensor? conv_state_indices) -> ()");
|
| 29 |
-
ops.impl("causal_conv1d_update", torch::kCUDA, make_pytorch_shim(&causal_conv1d_update));
|
| 30 |
-
}
|
| 31 |
-
|
| 32 |
-
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch-ext/torch_binding.h
DELETED
|
@@ -1,39 +0,0 @@
|
|
| 1 |
-
#pragma once
|
| 2 |
-
|
| 3 |
-
#include <torch/torch.h>
|
| 4 |
-
|
| 5 |
-
void
|
| 6 |
-
causal_conv1d_fwd(const at::Tensor &x,
|
| 7 |
-
const at::Tensor &weight,
|
| 8 |
-
const c10::optional<at::Tensor> &bias_,
|
| 9 |
-
const c10::optional<at::Tensor> &seq_idx_,
|
| 10 |
-
const c10::optional<at::Tensor> &initial_states_,
|
| 11 |
-
at::Tensor &out,
|
| 12 |
-
c10::optional<at::Tensor> &final_states_out_,
|
| 13 |
-
bool silu_activation);
|
| 14 |
-
|
| 15 |
-
void
|
| 16 |
-
causal_conv1d_bwd(const at::Tensor &x,
|
| 17 |
-
const at::Tensor &weight,
|
| 18 |
-
const c10::optional<at::Tensor> &bias_,
|
| 19 |
-
at::Tensor &dout,
|
| 20 |
-
const c10::optional<at::Tensor> &seq_idx_,
|
| 21 |
-
const c10::optional<at::Tensor> &initial_states_,
|
| 22 |
-
const c10::optional<at::Tensor> &dfinal_states_,
|
| 23 |
-
at::Tensor &dx,
|
| 24 |
-
at::Tensor &dweight,
|
| 25 |
-
c10::optional<at::Tensor> &dbias_,
|
| 26 |
-
c10::optional<at::Tensor> &dinitial_states_,
|
| 27 |
-
bool silu_activation);
|
| 28 |
-
|
| 29 |
-
void
|
| 30 |
-
causal_conv1d_update(const at::Tensor &x,
|
| 31 |
-
const at::Tensor &conv_state,
|
| 32 |
-
const at::Tensor &weight,
|
| 33 |
-
const c10::optional<at::Tensor> &bias_,
|
| 34 |
-
at::Tensor &out,
|
| 35 |
-
bool silu_activation,
|
| 36 |
-
const c10::optional<at::Tensor> &cache_seqlens_,
|
| 37 |
-
const c10::optional<at::Tensor> &conv_state_indices_
|
| 38 |
-
);
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|