Kernels
danieldk (HF Staff) committed
Commit 2a51f9d · 1 Parent(s): e786cb7

Sources moved to GitHub.

README.md CHANGED
@@ -6,5 +6,6 @@ tags:
 
 ## causal-conv1d
 
- Causal depthwise conv1d kernel by Tri Dao. Source: https://github.com/Dao-AILab/causal-conv1d/
+ Causal [depthwise conv1d kernel](https://github.com/Dao-AILab/causal-conv1d/) by Tri Dao.
 
+ Kernel source: https://github.com/huggingface/kernels-community/tree/main/causal-conv1d
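For orientation now that the sources live on GitHub and only the built artifacts stay on the Hub, here is a minimal sketch of loading and calling the kernel with the `kernels` package. The repository id, the `pip install kernels` step, and the assumption that the loaded module exposes a `causal_conv1d_fwd` whose argument order mirrors the C++ binding deleted below are assumptions, not statements made by this commit.

```python
# A minimal usage sketch, not part of this commit. It assumes the built kernel
# is published on the Hub as "kernels-community/causal-conv1d", that the
# `kernels` package is installed, and that the loaded module exposes
# causal_conv1d_fwd with the same argument order as the C++ binding deleted
# below; the real Python-facing wrapper may differ.
import torch
from kernels import get_kernel

causal_conv1d = get_kernel("kernels-community/causal-conv1d")  # assumed repo id

batch, dim, seqlen, width = 2, 64, 128, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)
out = torch.empty_like(x)

# (x, weight, bias, seq_idx, initial_states, out, final_states_out, silu_activation)
causal_conv1d.causal_conv1d_fwd(x, weight, bias, None, None, out, None, True)
```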
build.toml DELETED
@@ -1,49 +0,0 @@
- [general]
- name = "causal_conv1d"
- universal = false
- 
- [torch]
- src = [
-   "torch-ext/pytorch_shim.h",
-   "torch-ext/torch_binding.cpp",
-   "torch-ext/torch_binding.h"
- ]
- 
- [kernel.causal_conv1d]
- backend = "cuda"
- src = [
-   "causal-conv1d/causal_conv1d_bwd.cu",
-   "causal-conv1d/causal_conv1d_common.h",
-   "causal-conv1d/causal_conv1d.cpp",
-   "causal-conv1d/causal_conv1d_fwd.cu",
-   "causal-conv1d/causal_conv1d.h",
-   "causal-conv1d/causal_conv1d_update.cu",
-   "causal-conv1d/static_switch.h",
- ]
- include = [ "causal-conv1d" ]
- depends = [ "torch" ]
- 
- #[kernel.causal_conv1d_rocm]
- #backend = "rocm"
- #rocm-archs = [
- #  "gfx906",
- #  "gfx908",
- #  "gfx90a",
- #  "gfx940",
- #  "gfx941",
- #  "gfx942",
- #  "gfx1030",
- #  "gfx1100",
- #  "gfx1101",
- #]
- #src = [
- #  "causal-conv1d/causal_conv1d_bwd.cu",
- #  "causal-conv1d/causal_conv1d_common.h",
- #  "causal-conv1d/causal_conv1d.cpp",
- #  "causal-conv1d/causal_conv1d_fwd.cu",
- #  "causal-conv1d/causal_conv1d.h",
- #  "causal-conv1d/causal_conv1d_update.cu",
- #  "causal-conv1d/static_switch.h",
- #]
- #include = [ "causal-conv1d" ]
- #depends = [ "torch" ]
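Before the deleted CUDA/C++ sources below, a pure-PyTorch reference of the operation they implement may help orient readers. This is a sketch written from the shapes and checks visible in this diff (x of shape (batch, dim, seqlen), weight of shape (dim, width), optional bias, optional SiLU, width between 2 and 4), not the project's own reference code.

```python
# Pure-PyTorch reference for the operation the deleted sources implement.
# A sketch for orientation, not the project's own reference code.
import torch
import torch.nn.functional as F

def causal_conv1d_ref(x, weight, bias=None, silu_activation=False):
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    # Depthwise (groups=dim) conv with left context only: pad by width-1 and
    # drop the trailing positions so out[..., l] sees x[..., l-width+1 : l+1].
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    out = out[..., :seqlen]
    return F.silu(out) if silu_activation else out

x = torch.randn(2, 16, 32)
w = torch.randn(16, 4)
b = torch.randn(16)
y = causal_conv1d_ref(x, w, b, silu_activation=True)
assert y.shape == x.shape
```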
causal-conv1d/causal_conv1d.cpp DELETED
@@ -1,486 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #include <torch/all.h>
6
- #if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
7
- #include <c10/core/DeviceGuard.h>
8
- #else
9
- #include <c10/cuda/CUDAGuard.h>
10
- #endif
11
-
12
- #include <c10/cuda/CUDAStream.h>
13
- #include <vector>
14
-
15
- #include "causal_conv1d.h"
16
-
17
- #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
18
-
19
- #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
20
- if (ITYPE == at::ScalarType::Half) { \
21
- using input_t = at::Half; \
22
- __VA_ARGS__(); \
23
- } else if (ITYPE == at::ScalarType::BFloat16) { \
24
- using input_t = at::BFloat16; \
25
- __VA_ARGS__(); \
26
- } else if (ITYPE == at::ScalarType::Float) { \
27
- using input_t = float; \
28
- __VA_ARGS__(); \
29
- } else { \
30
- AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
31
- }
32
-
33
- #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
34
- if (WTYPE == at::ScalarType::Half) { \
35
- using weight_t = at::Half; \
36
- __VA_ARGS__(); \
37
- } else if (WTYPE == at::ScalarType::BFloat16) { \
38
- using weight_t = at::BFloat16; \
39
- __VA_ARGS__(); \
40
- } else if (WTYPE == at::ScalarType::Float) { \
41
- using weight_t = float; \
42
- __VA_ARGS__(); \
43
- } else { \
44
- AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
45
- }
46
-
47
- template<typename input_t, typename weight_t>
48
- void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
49
- template <typename input_t, typename weight_t>
50
- void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
51
-
52
- template<typename input_t, typename weight_t>
53
- void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
54
- template<typename input_t, typename weight_t>
55
- void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
56
-
57
- template<typename input_t, typename weight_t>
58
- void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
59
-
60
- void set_conv_params_fwd(ConvParamsBase &params,
61
- // sizes
62
- const size_t batch,
63
- const size_t dim,
64
- const size_t seqlen,
65
- const size_t width,
66
- // device pointers
67
- const at::Tensor x,
68
- const at::Tensor weight,
69
- const at::Tensor out,
70
- void* bias_ptr,
71
- bool silu_activation) {
72
-
73
- // Reset the parameters
74
- memset(&params, 0, sizeof(params));
75
-
76
- params.batch = batch;
77
- params.dim = dim;
78
- params.seqlen = seqlen;
79
- params.width = width;
80
-
81
- params.silu_activation = silu_activation;
82
-
83
- // Set the pointers and strides.
84
- params.x_ptr = x.data_ptr();
85
- params.weight_ptr = weight.data_ptr();
86
- params.bias_ptr = bias_ptr;
87
- params.out_ptr = out.data_ptr();
88
- // All stride are in elements, not bytes.
89
- params.x_batch_stride = x.stride(0);
90
- params.x_c_stride = x.stride(1);
91
- params.x_l_stride = x.stride(-1);
92
- params.weight_c_stride = weight.stride(0);
93
- params.weight_width_stride = weight.stride(1);
94
- params.out_batch_stride = out.stride(0);
95
- params.out_c_stride = out.stride(1);
96
- params.out_l_stride = out.stride(-1);
97
- }
98
-
99
-
100
- void set_conv_params_bwd(ConvParamsBwd &params,
101
- // sizes
102
- const size_t batch,
103
- const size_t dim,
104
- const size_t seqlen,
105
- const size_t width,
106
- // device pointers
107
- const at::Tensor x,
108
- const at::Tensor weight,
109
- void* bias_ptr,
110
- const at::Tensor dout,
111
- const at::Tensor dx,
112
- const at::Tensor dweight,
113
- void* dbias_ptr,
114
- bool silu_activation) {
115
- // Pass in "dout" instead of "out", we're not gonna use "out" at all.
116
- set_conv_params_fwd(params, batch, dim, seqlen, width,
117
- x, weight, dout, bias_ptr, silu_activation);
118
-
119
- // Set the pointers and strides.
120
- params.dout_ptr = dout.data_ptr();
121
- params.dx_ptr = dx.data_ptr();
122
- params.dweight_ptr = dweight.data_ptr();
123
- params.dbias_ptr = dbias_ptr;
124
- // All stride are in elements, not bytes.
125
- params.dout_batch_stride = dout.stride(0);
126
- params.dout_c_stride = dout.stride(1);
127
- params.dout_l_stride = dout.stride(2);
128
- params.dweight_c_stride = dweight.stride(0);
129
- params.dweight_width_stride = dweight.stride(1);
130
- params.dx_batch_stride = dx.stride(0);
131
- params.dx_c_stride = dx.stride(1);
132
- params.dx_l_stride = dx.stride(2);
133
- }
134
-
135
- void
136
- causal_conv1d_fwd(const at::Tensor &x,
137
- const at::Tensor &weight,
138
- const c10::optional<at::Tensor> &bias_,
139
- const c10::optional<at::Tensor> &seq_idx_,
140
- const c10::optional<at::Tensor> &initial_states_,
141
- at::Tensor &out,
142
- c10::optional<at::Tensor> &final_states_out_,
143
- bool silu_activation) {
144
- auto input_type = x.scalar_type();
145
- auto weight_type = weight.scalar_type();
146
- TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
147
- TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
148
-
149
- TORCH_CHECK(x.is_cuda());
150
- TORCH_CHECK(weight.is_cuda());
151
-
152
- const auto sizes = x.sizes();
153
- const int batch_size = sizes[0];
154
- const int dim = sizes[1];
155
- const int seqlen = sizes[2];
156
- const int width = weight.size(-1);
157
-
158
- CHECK_SHAPE(x, batch_size, dim, seqlen);
159
- CHECK_SHAPE(weight, dim, width);
160
-
161
- TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
162
- const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
163
-
164
- if (is_channel_last) {
165
- TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
166
- TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
167
- }
168
- TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
169
-
170
- if (bias_.has_value()) {
171
- auto bias = bias_.value();
172
- TORCH_CHECK(bias.scalar_type() == weight_type);
173
- TORCH_CHECK(bias.is_cuda());
174
- TORCH_CHECK(bias.stride(-1) == 1);
175
- CHECK_SHAPE(bias, dim);
176
- }
177
-
178
- if (seq_idx_.has_value()) {
179
- TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
180
- auto seq_idx = seq_idx_.value();
181
- TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
182
- TORCH_CHECK(seq_idx.is_cuda());
183
- TORCH_CHECK(seq_idx.is_contiguous());
184
- CHECK_SHAPE(seq_idx, batch_size, seqlen);
185
- }
186
-
187
- ConvParamsBase params;
188
- set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
189
- bias_.has_value() ? bias_.value().data_ptr() : nullptr,
190
- silu_activation);
191
-
192
- if (seq_idx_.has_value()) {
193
- params.seq_idx_ptr = seq_idx_.value().data_ptr();
194
- } else {
195
- params.seq_idx_ptr = nullptr;
196
- }
197
-
198
- if (initial_states_.has_value()) {
199
- TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
200
- auto initial_states = initial_states_.value();
201
- TORCH_CHECK(initial_states.scalar_type() == input_type);
202
- TORCH_CHECK(initial_states.is_cuda());
203
- CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
204
- TORCH_CHECK(initial_states.stride(1) == 1);
205
- params.initial_states_ptr = initial_states.data_ptr();
206
- params.initial_states_batch_stride = initial_states.stride(0);
207
- params.initial_states_c_stride = initial_states.stride(1);
208
- params.initial_states_l_stride = initial_states.stride(2);
209
- } else {
210
- params.initial_states_ptr = nullptr;
211
- }
212
-
213
- if (final_states_out_.has_value()) {
214
- TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
215
- auto final_states = final_states_out_.value();
216
- TORCH_CHECK(final_states.scalar_type() == input_type);
217
- TORCH_CHECK(final_states.is_cuda());
218
- CHECK_SHAPE(final_states, batch_size, dim, width - 1);
219
- TORCH_CHECK(final_states.stride(1) == 1);
220
- params.final_states_ptr = final_states.data_ptr();
221
- params.final_states_batch_stride = final_states.stride(0);
222
- params.final_states_c_stride = final_states.stride(1);
223
- params.final_states_l_stride = final_states.stride(2);
224
- } else {
225
- params.final_states_ptr = nullptr;
226
- }
227
-
228
- // Otherwise the kernel will be launched from cuda:0 device
229
- #if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
230
- c10::DeviceGuard device_guard(x.device());
231
- #else
232
- at::cuda::CUDAGuard device_guard{x.device()};
233
- #endif
234
- auto stream = at::cuda::getCurrentCUDAStream().stream();
235
- DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
236
- DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
237
- if (!is_channel_last) {
238
- causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
239
- } else {
240
- causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
241
- }
242
- });
243
- });
244
- }
245
-
246
- void
247
- causal_conv1d_bwd(const at::Tensor &x,
248
- const at::Tensor &weight,
249
- const c10::optional<at::Tensor> &bias_,
250
- at::Tensor &dout,
251
- const c10::optional<at::Tensor> &seq_idx_,
252
- const c10::optional<at::Tensor> &initial_states_,
253
- const c10::optional<at::Tensor> &dfinal_states_,
254
- at::Tensor &dx,
255
- at::Tensor &dweight,
256
- c10::optional<at::Tensor> &dbias_,
257
- c10::optional<at::Tensor> &dinitial_states_,
258
- bool silu_activation) {
259
- auto input_type = x.scalar_type();
260
- auto weight_type = weight.scalar_type();
261
- TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
262
- TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
263
-
264
- TORCH_CHECK(x.is_cuda());
265
- TORCH_CHECK(weight.is_cuda());
266
- TORCH_CHECK(dout.is_cuda());
267
- TORCH_CHECK(bias_.has_value() == dbias_.has_value());
268
-
269
- const auto sizes = x.sizes();
270
- const int batch_size = sizes[0];
271
- const int dim = sizes[1];
272
- const int seqlen = sizes[2];
273
- const int width = weight.size(-1);
274
-
275
- TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
276
-
277
- CHECK_SHAPE(x, batch_size, dim, seqlen);
278
- CHECK_SHAPE(weight, dim, width);
279
- CHECK_SHAPE(dout, batch_size, dim, seqlen);
280
-
281
- TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
282
- const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
283
- if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); }
284
- if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); }
285
-
286
- if (is_channel_last) {
287
- TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
288
- TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
289
- TORCH_CHECK(dout.stride(2) % 8 == 0 and dout.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (dout.stride(0) and dout.stride(2)) to be multiples of 8");
290
- }
291
-
292
- if (bias_.has_value()) {
293
- auto bias = bias_.value();
294
- TORCH_CHECK(bias.scalar_type() == weight_type);
295
- TORCH_CHECK(bias.is_cuda());
296
- TORCH_CHECK(bias.stride(-1) == 1);
297
- CHECK_SHAPE(bias, dim);
298
- }
299
-
300
- if (seq_idx_.has_value()) {
301
- TORCH_CHECK(is_channel_last, "seq_idx only supported for channel last layout");
302
- auto seq_idx = seq_idx_.value();
303
- TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
304
- TORCH_CHECK(seq_idx.is_cuda());
305
- TORCH_CHECK(seq_idx.is_contiguous());
306
- CHECK_SHAPE(seq_idx, batch_size, seqlen);
307
- }
308
-
309
- TORCH_CHECK(dx.scalar_type() == input_type);
310
- TORCH_CHECK(dx.is_cuda());
311
- CHECK_SHAPE(dx, batch_size, dim, seqlen);
312
- if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); }
313
- if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); }
314
-
315
- // Otherwise the kernel will be launched from cuda:0 device
316
- #if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
317
- c10::Device device = x.device();
318
- c10::DeviceGuard device_guard(device);
319
- #else
320
- at::cuda::CUDAGuard device_guard{x.device()};
321
- #endif
322
- ConvParamsBwd params;
323
- set_conv_params_bwd(params, batch_size, dim, seqlen, width,
324
- x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr,
325
- dout, dx, dweight, bias_.has_value() ? dbias_.value().data_ptr() : nullptr,
326
- silu_activation);
327
-
328
- if (seq_idx_.has_value()) {
329
- params.seq_idx_ptr = seq_idx_.value().data_ptr();
330
- } else {
331
- params.seq_idx_ptr = nullptr;
332
- }
333
-
334
- if (initial_states_.has_value()) {
335
- TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
336
- auto initial_states = initial_states_.value();
337
- TORCH_CHECK(initial_states.scalar_type() == input_type);
338
- TORCH_CHECK(initial_states.is_cuda());
339
- CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
340
- TORCH_CHECK(initial_states.stride(1) == 1);
341
- params.initial_states_ptr = initial_states.data_ptr();
342
- params.initial_states_batch_stride = initial_states.stride(0);
343
- params.initial_states_c_stride = initial_states.stride(1);
344
- params.initial_states_l_stride = initial_states.stride(2);
345
- } else {
346
- params.initial_states_ptr = nullptr;
347
- }
348
-
349
- if (dfinal_states_.has_value()) {
350
- TORCH_CHECK(is_channel_last, "dfinal_states is only supported for channel last layout");
351
- auto dfinal_states = dfinal_states_.value();
352
- TORCH_CHECK(dfinal_states.scalar_type() == input_type);
353
- TORCH_CHECK(dfinal_states.is_cuda());
354
- CHECK_SHAPE(dfinal_states, batch_size, dim, width - 1);
355
- params.dfinal_states_ptr = dfinal_states.data_ptr();
356
- params.dfinal_states_batch_stride = dfinal_states.stride(0);
357
- params.dfinal_states_c_stride = dfinal_states.stride(1);
358
- params.dfinal_states_l_stride = dfinal_states.stride(2);
359
- } else {
360
- params.dfinal_states_ptr = nullptr;
361
- }
362
-
363
- if (dinitial_states_.has_value()) {
364
- at::Tensor dinitial_states = dinitial_states_.value();
365
- TORCH_CHECK(dinitial_states.stride(1) == 1);
366
- params.dinitial_states_ptr = dinitial_states.data_ptr();
367
- params.dinitial_states_batch_stride = dinitial_states.stride(0);
368
- params.dinitial_states_c_stride = dinitial_states.stride(1);
369
- params.dinitial_states_l_stride = dinitial_states.stride(2);
370
- } else {
371
- params.dinitial_states_ptr = nullptr;
372
- }
373
-
374
- auto stream = at::cuda::getCurrentCUDAStream().stream();
375
- DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] {
376
- DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] {
377
- if (!is_channel_last) {
378
- causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream);
379
- } else {
380
- causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream);
381
- }
382
- });
383
- });
384
- }
385
-
386
- void
387
- causal_conv1d_update(const at::Tensor &x,
388
- const at::Tensor &conv_state,
389
- const at::Tensor &weight,
390
- const c10::optional<at::Tensor> &bias_,
391
- at::Tensor &out,
392
- bool silu_activation,
393
- const c10::optional<at::Tensor> &cache_seqlens_,
394
- const c10::optional<at::Tensor> &conv_state_indices_
395
- ) {
396
- auto input_type = x.scalar_type();
397
- auto weight_type = weight.scalar_type();
398
- TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
399
- TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
400
- TORCH_CHECK(conv_state.scalar_type() == input_type);
401
-
402
- TORCH_CHECK(x.is_cuda());
403
- TORCH_CHECK(conv_state.is_cuda());
404
- TORCH_CHECK(weight.is_cuda());
405
-
406
- const auto sizes = x.sizes();
407
- const int batch_size = sizes[0];
408
- const int dim = sizes[1];
409
- const int seqlen = sizes[2];
410
- const int width = weight.size(-1);
411
- const int conv_state_len = conv_state.size(2);
412
- TORCH_CHECK(conv_state_len >= width - 1);
413
-
414
- CHECK_SHAPE(x, batch_size, dim, seqlen);
415
- CHECK_SHAPE(weight, dim, width);
416
-
417
- TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
418
-
419
- if (bias_.has_value()) {
420
- auto bias = bias_.value();
421
- TORCH_CHECK(bias.scalar_type() == weight_type);
422
- TORCH_CHECK(bias.is_cuda());
423
- TORCH_CHECK(bias.stride(-1) == 1);
424
- CHECK_SHAPE(bias, dim);
425
- }
426
-
427
- ConvParamsBase params;
428
- set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
429
- bias_.has_value() ? bias_.value().data_ptr() : nullptr,
430
- silu_activation);
431
- params.conv_state_ptr = conv_state.data_ptr();
432
- params.conv_state_len = conv_state_len;
433
- // All stride are in elements, not bytes.
434
- params.conv_state_batch_stride = conv_state.stride(0);
435
- params.conv_state_c_stride = conv_state.stride(1);
436
- params.conv_state_l_stride = conv_state.stride(2);
437
-
438
- if (conv_state_indices_.has_value()) {
439
- auto conv_state_indices = conv_state_indices_.value();
440
- TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
441
- TORCH_CHECK(conv_state_indices.is_cuda());
442
- TORCH_CHECK(conv_state_indices.stride(0) == 1)
443
- CHECK_SHAPE(conv_state_indices, batch_size);
444
-
445
- int conv_state_entries = conv_state.size(0);
446
- CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
447
-
448
- params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
449
- } else {
450
- CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
451
- params.conv_state_indices_ptr = nullptr;
452
- }
453
-
454
- if (cache_seqlens_.has_value()) {
455
- auto cache_seqlens = cache_seqlens_.value();
456
- TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
457
- TORCH_CHECK(cache_seqlens.is_cuda());
458
- TORCH_CHECK(cache_seqlens.stride(-1) == 1);
459
- CHECK_SHAPE(cache_seqlens, batch_size);
460
- params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
461
- } else {
462
- params.cache_seqlens = nullptr;
463
- }
464
-
465
- // Otherwise the kernel will be launched from cuda:0 device
466
- #if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
467
- c10::Device device = x.device();
468
- c10::DeviceGuard device_guard(device);
469
- #else
470
- at::cuda::CUDAGuard device_guard{x.device()};
471
- #endif
472
- auto stream = at::cuda::getCurrentCUDAStream().stream();
473
- DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
474
- DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
475
- causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
476
- });
477
- });
478
- }
479
-
480
- /*
481
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
482
- m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward");
483
- m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward");
484
- m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update");
485
- }
486
- */
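The stride checks in causal_conv1d_fwd and causal_conv1d_bwd above accept two memory layouts for x. Below is a small sketch of how each is usually constructed on the PyTorch side; it illustrates the checks and is not code from this repository (the kernel itself also requires CUDA tensors).

```python
# Illustration of the two layouts accepted by the stride checks above; run on
# CPU here only to show the strides, the kernel itself requires CUDA tensors.
import torch

batch, dim, seqlen = 4, 64, 256  # channel-last additionally needs dim % 8 == 0

# Layout 1: contiguous (batch, dim, seqlen), so x.stride(-1) == 1.
x_cf = torch.randn(batch, dim, seqlen)
assert x_cf.stride(2) == 1

# Layout 2: "channel last": allocate as (batch, seqlen, dim) and transpose, so
# x.stride(1) == 1. seq_idx and initial/final states are only supported here.
x_cl = torch.randn(batch, seqlen, dim).transpose(1, 2)
assert x_cl.stride(1) == 1 and x_cl.stride(2) == dim and x_cl.stride(0) % 8 == 0
```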
causal-conv1d/causal_conv1d.h DELETED
@@ -1,81 +0,0 @@
- /******************************************************************************
-  * Copyright (c) 2024, Tri Dao.
-  ******************************************************************************/
- 
- #pragma once
- 
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- 
- struct ConvParamsBase {
-     using index_t = uint32_t;
- 
-     int batch, dim, seqlen, width;
-     bool silu_activation;
- 
-     index_t x_batch_stride;
-     index_t x_c_stride;
-     index_t x_l_stride;
-     index_t weight_c_stride;
-     index_t weight_width_stride;
-     index_t out_batch_stride;
-     index_t out_c_stride;
-     index_t out_l_stride;
- 
-     int conv_state_len;
-     index_t conv_state_batch_stride;
-     index_t conv_state_c_stride;
-     index_t conv_state_l_stride;
- 
-     // Common data pointers.
-     void *__restrict__ x_ptr;
-     void *__restrict__ weight_ptr;
-     void *__restrict__ bias_ptr;
-     void *__restrict__ out_ptr;
- 
-     void *__restrict__ conv_state_ptr;
-     int32_t *__restrict__ cache_seqlens;
- 
-     // Only used if the elements of the batch are gathered from a larger buffer,
-     // which may happen for continuous batching.
-     int32_t *__restrict__ conv_state_indices_ptr;
- 
-     void *__restrict__ seq_idx_ptr;
- 
-     // No __restrict__ since initial_states could be the same as final_states.
-     void * initial_states_ptr;
-     index_t initial_states_batch_stride;
-     index_t initial_states_l_stride;
-     index_t initial_states_c_stride;
- 
-     void * final_states_ptr;
-     index_t final_states_batch_stride;
-     index_t final_states_l_stride;
-     index_t final_states_c_stride;
- };
- 
- struct ConvParamsBwd: public ConvParamsBase {
-     index_t dx_batch_stride;
-     index_t dx_c_stride;
-     index_t dx_l_stride;
-     index_t dweight_c_stride;
-     index_t dweight_width_stride;
-     index_t dout_batch_stride;
-     index_t dout_c_stride;
-     index_t dout_l_stride;
- 
-     // Common data pointers.
-     void *__restrict__ dx_ptr;
-     void *__restrict__ dweight_ptr;
-     void *__restrict__ dbias_ptr;
-     void *__restrict__ dout_ptr;
- 
-     void * dinitial_states_ptr;
-     index_t dinitial_states_batch_stride;
-     index_t dinitial_states_l_stride;
-     index_t dinitial_states_c_stride;
- 
-     void * dfinal_states_ptr;
-     index_t dfinal_states_batch_stride;
-     index_t dfinal_states_l_stride;
-     index_t dfinal_states_c_stride;
- };
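The conv_state, cache_seqlens, and conv_state_indices fields above exist for the single-token update path used during incremental decoding. Below is a simplified pure-PyTorch sketch of that rolling-state update (fixed state length of width - 1, no circular cache_seqlens indexing, no batch gathering); it illustrates the idea, not the kernel's actual logic.

```python
# Simplified sketch of the decode-time update: keep the last width-1 inputs per
# channel, shift in the new token, and apply the same depthwise causal conv.
import torch
import torch.nn.functional as F

def causal_conv1d_update_ref(x_new, conv_state, weight, bias=None, silu=False):
    # x_new: (batch, dim) single timestep; conv_state: (batch, dim, width - 1)
    window = torch.cat([conv_state, x_new.unsqueeze(-1)], dim=-1)  # (batch, dim, width)
    out = (window * weight.unsqueeze(0)).sum(-1)                   # (batch, dim)
    if bias is not None:
        out = out + bias
    new_state = window[..., 1:]  # drop the oldest timestep
    return (F.silu(out) if silu else out), new_state

state = torch.zeros(2, 8, 3)   # width = 4, so the state holds width - 1 = 3 steps
w = torch.randn(8, 4)
for _ in range(5):             # decode five tokens one at a time
    tok = torch.randn(2, 8)
    y, state = causal_conv1d_update_ref(tok, state, w, silu=True)
```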
causal-conv1d/causal_conv1d_bwd.cu DELETED
@@ -1,627 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #include <c10/util/BFloat16.h>
6
- #include <c10/util/Half.h>
7
- #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
-
9
- #ifndef USE_ROCM
10
- #include <cub/block/block_load.cuh>
11
- #include <cub/block/block_store.cuh>
12
- #include <cub/block/block_reduce.cuh>
13
- #else
14
- #include <hipcub/hipcub.hpp>
15
- namespace cub = hipcub;
16
- #endif
17
-
18
- #include "causal_conv1d.h"
19
- #include "causal_conv1d_common.h"
20
- #include "static_switch.h"
21
-
22
- template<int kNThreads_, int kWidth_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
23
- struct Causal_conv1d_bwd_kernel_traits {
24
- using input_t = input_t_;
25
- using weight_t = weight_t_;
26
- static constexpr int kNThreads = kNThreads_;
27
- static constexpr int kWidth = kWidth_;
28
- static constexpr bool kSiluAct = kSiluAct_;
29
- static constexpr int kNBytes = sizeof(input_t);
30
- static_assert(kNBytes == 2 || kNBytes == 4);
31
- static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
32
- static_assert(kWidth <= kNElts);
33
- // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
34
- // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
35
- static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t);
36
- static constexpr bool kIsVecLoad = kIsVecLoad_;
37
- using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
38
- using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
39
- using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
40
- using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
41
- using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
42
- using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
43
- static constexpr int kSmemIOSize = kIsVecLoad
44
- ? 0
45
- : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
46
- static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1);
47
- static constexpr int kSmemSize = custom_max({kSmemExchangeSize,
48
- int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 0 : kSmemIOSize);
49
- };
50
-
51
- template<typename Ktraits>
52
- __global__ __launch_bounds__(Ktraits::kNThreads)
53
- void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
54
- constexpr int kWidth = Ktraits::kWidth;
55
- constexpr int kNThreads = Ktraits::kNThreads;
56
- constexpr bool kSiluAct = Ktraits::kSiluAct;
57
- static constexpr int kNElts = Ktraits::kNElts;
58
- constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
59
- static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
60
- using input_t = typename Ktraits::input_t;
61
- using vec_t = typename Ktraits::vec_t;
62
- using weight_t = typename Ktraits::weight_t;
63
-
64
- // Shared memory.
65
- extern __shared__ char smem_[];
66
- auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
67
- auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
68
- auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
69
- auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
70
- vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
71
- vec_t *smem_exchange_x = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds;
72
- auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
73
-
74
- const int tidx = threadIdx.x;
75
- const int batch_id = blockIdx.x;
76
- const int dim_id = blockIdx.y;
77
- input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
78
- + dim_id * params.x_c_stride;
79
- weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + dim_id * params.weight_c_stride;
80
- input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
81
- + dim_id * params.dout_c_stride;
82
- input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
83
- + dim_id * params.dx_c_stride;
84
- float *dweight = reinterpret_cast<float *>(params.dweight_ptr) + dim_id * params.dweight_c_stride;
85
- float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[dim_id]);
86
-
87
- // Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0.
88
- if (tidx == 0) {
89
- if constexpr (!kSiluAct) {
90
- input_t zeros[kNElts] = {0};
91
- smem_exchange[0] = reinterpret_cast<vec_t *>(zeros)[0];
92
- } else {
93
- float zeros[kNElts] = {0};
94
- #pragma unroll
95
- for (int r = 0; r < kNExchangeRounds; ++r) {
96
- smem_exchange[r * kNThreads] = reinterpret_cast<vec_t *>(zeros)[r];
97
- }
98
- }
99
- }
100
-
101
- float weight_vals[kWidth];
102
- #pragma unroll
103
- for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; }
104
-
105
- float dweight_vals[kWidth] = {0};
106
- float dbias_val = 0;
107
-
108
- constexpr int kChunkSize = kNThreads * kNElts;
109
- const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
110
- x += (n_chunks - 1) * kChunkSize;
111
- dout += (n_chunks - 1) * kChunkSize;
112
- dx += (n_chunks - 1) * kChunkSize;
113
- for (int chunk = n_chunks - 1; chunk >= 0; --chunk) {
114
- input_t x_vals_load[2 * kNElts] = {0};
115
- input_t dout_vals_load[2 * kNElts] = {0};
116
- if constexpr(kIsVecLoad) {
117
- typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
118
- typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(dout), *reinterpret_cast<vec_t (*)[1]>(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts);
119
- } else {
120
- __syncthreads();
121
- typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
122
- __syncthreads();
123
- typename Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast<input_t (*)[kNElts]>(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize);
124
- }
125
- float dout_vals[2 * kNElts], x_vals[2 * kNElts];
126
- if constexpr (!kSiluAct) {
127
- __syncthreads();
128
- // Thread 0 don't write yet, so that thread kNThreads - 1 can read
129
- // the first elements of the next chunk.
130
- if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
131
- __syncthreads();
132
- reinterpret_cast<vec_t *>(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0];
133
- __syncthreads();
134
- // Now thread 0 can write the first elements of the current chunk.
135
- if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
136
- #pragma unroll
137
- for (int i = 0; i < 2 * kNElts; ++i) {
138
- dout_vals[i] = float(dout_vals_load[i]);
139
- x_vals[i] = float(x_vals_load[i]);
140
- }
141
- } else {
142
- if (tidx == 0 && chunk > 0) {
143
- if constexpr(kIsVecLoad) {
144
- reinterpret_cast<vec_t *>(x_vals_load)[0] = reinterpret_cast<vec_t *>(x)[-1];
145
- } else {
146
- #pragma unroll
147
- for (int i = 0; i < kNElts; ++i) {
148
- if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; }
149
- }
150
- }
151
- }
152
- __syncthreads();
153
- smem_exchange_x[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1];
154
- __syncthreads();
155
- if (tidx > 0) { reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange_x[tidx - 1]; }
156
- #pragma unroll
157
- for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
158
- // Recompute the output
159
- #pragma unroll
160
- for (int i = 0; i < kNElts; ++i) {
161
- float out_val = bias_val;
162
- #pragma unroll
163
- for (int w = 0; w < kWidth; ++w) {
164
- out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
165
- }
166
- float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val));
167
- dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val
168
- * (1.0f + out_val * (1.0f - out_sigmoid_val));
169
- }
170
- // Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange
171
- // if input_t is 16 bits (since then we'd have 8 values of float)
172
- __syncthreads();
173
- // Thread 0 don't write yet, so that thread kNThreads - 1 can read
174
- // the first elements of the next chunk.
175
- if (tidx > 0) {
176
- #pragma unroll
177
- for (int r = 0; r < kNExchangeRounds; ++r) {
178
- smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
179
- }
180
- }
181
- __syncthreads();
182
- #pragma unroll
183
- for (int r = 0; r < kNExchangeRounds; ++r) {
184
- reinterpret_cast<vec_t *>(dout_vals)[kNExchangeRounds + r]
185
- = smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)];
186
- }
187
- __syncthreads();
188
- // Now thread 0 can write the first elements of the current chunk.
189
- if (tidx == 0) {
190
- #pragma unroll
191
- for (int r = 0; r < kNExchangeRounds; ++r) {
192
- smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
193
- }
194
- }
195
- }
196
- dout -= kChunkSize;
197
- x -= kChunkSize;
198
-
199
- #pragma unroll
200
- for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; }
201
-
202
- float dx_vals[kNElts] = {0};
203
- #pragma unroll
204
- for (int i = 0; i < kNElts; ++i) {
205
- #pragma unroll
206
- for (int w = 0; w < kWidth; ++w) {
207
- dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1];
208
- }
209
- }
210
-
211
- input_t dx_vals_store[kNElts];
212
- #pragma unroll
213
- for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; }
214
- if constexpr(kIsVecLoad) {
215
- typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(dx), reinterpret_cast<vec_t (&)[1]>(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
216
- } else {
217
- typename Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize);
218
- }
219
- dx -= kChunkSize;
220
-
221
- #pragma unroll
222
- for (int w = 0; w < kWidth; ++w) {
223
- #pragma unroll
224
- for (int i = 0; i < kNElts; ++i) {
225
- dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1];
226
- }
227
- }
228
- }
229
-
230
- #pragma unroll
231
- for (int w = 0; w < kWidth; ++w) {
232
- __syncthreads();
233
- dweight_vals[w] = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]);
234
- if (tidx == 0) {
235
- atomicAdd(&reinterpret_cast<float *>(dweight)[w * params.dweight_width_stride], dweight_vals[w]);
236
- }
237
- }
238
- if (params.bias_ptr != nullptr) {
239
- __syncthreads();
240
- dbias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val);
241
- if (tidx == 0) {
242
- atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[dim_id], dbias_val);
243
- }
244
- }
245
- }
246
-
247
- template<int kNThreads, int kWidth, typename input_t, typename weight_t>
248
- void causal_conv1d_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
249
- static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
250
- BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
251
- BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
252
- using Ktraits = Causal_conv1d_bwd_kernel_traits<kNThreads, kWidth, kSiluAct, kIsVecLoad, input_t, weight_t>;
253
- constexpr int kSmemSize = Ktraits::kSmemSize;
254
- dim3 grid(params.batch, params.dim);
255
- auto kernel = &causal_conv1d_bwd_kernel<Ktraits>;
256
-
257
- if (kSmemSize >= 48 * 1024) {
258
- #ifndef USE_ROCM
259
- C10_CUDA_CHECK(cudaFuncSetAttribute(
260
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
261
- #else
262
- // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
263
- C10_CUDA_CHECK(cudaFuncSetAttribute(
264
- (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
265
- std::cerr << "Warning (causal_conv1d bwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
266
- #endif
267
- }
268
-
269
-
270
- kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
271
- C10_CUDA_KERNEL_LAUNCH_CHECK();
272
- });
273
- });
274
- }
275
-
276
- template<typename input_t, typename weight_t>
277
- void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
278
- if (params.width == 2) {
279
- causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream);
280
- } else if (params.width == 3) {
281
- causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream);
282
- } else if (params.width == 4) {
283
- causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream);
284
- }
285
- }
286
-
287
- template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
288
- struct Causal_conv1d_channellast_bwd_kernel_traits {
289
- // The cache line is 128 bytes, and we try to read 16 bytes per thread.
290
- // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
291
- // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
292
- // threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
293
- using input_t = input_t_;
294
- using weight_t = weight_t_;
295
- static constexpr bool kSiluAct = kSiluAct_;
296
- static constexpr int kNThreads = kNThreads_;
297
- static_assert(kNThreads % 32 == 0);
298
- static constexpr int kNWarps = kNThreads / 32;
299
- static constexpr int kWidth = kWidth_;
300
- static constexpr int kChunkSizeL = kChunkSizeL_;
301
- static constexpr int kNBytes = sizeof(input_t);
302
- static_assert(kNBytes == 2 || kNBytes == 4);
303
- static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
304
- static constexpr int kNEltsPerRow = 128 / kNBytes;
305
- static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
306
- static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
307
- static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
308
- static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
309
- static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
310
- static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
311
- static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
312
- static constexpr bool kIsVecLoad = kIsVecLoad_;
313
- using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
314
- // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
315
- // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
316
- // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
317
- // sizeof(typename BlockStoreT::TempStorage)});
318
- // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
319
- };
320
-
321
- template<typename Ktraits, bool kHasSeqIdx, bool kHasDfinalStates>
322
- __global__ __launch_bounds__(Ktraits::kNThreads)
323
- void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
324
- constexpr int kWidth = Ktraits::kWidth;
325
- constexpr int kNThreads = Ktraits::kNThreads;
326
- constexpr bool kSiluAct = Ktraits::kSiluAct;
327
- constexpr int kNElts = Ktraits::kNElts;
328
- constexpr int kNWarp = Ktraits::kNWarps;
329
- constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
330
- constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
331
- constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
332
- constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
333
- using input_t = typename Ktraits::input_t;
334
- using vec_t = typename Ktraits::vec_t;
335
- using weight_t = typename Ktraits::weight_t;
336
-
337
- // Shared memory.
338
- __shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
339
- __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
340
-
341
- const int batch_id = blockIdx.x;
342
- const int chunk_l_id = blockIdx.y;
343
- const int chunk_c_id = blockIdx.z;
344
- const int tid = threadIdx.x;
345
- const int l_idx = tid / kNThreadsPerC;
346
- const int c_idx = tid % kNThreadsPerC;
347
- input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
348
- + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
349
- weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
350
- + chunk_c_id * kChunkSizeC * params.weight_c_stride;
351
- input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
352
- + (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
353
- input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
354
- + (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
355
- float *dweight = reinterpret_cast<float *>(params.dweight_ptr)
356
- + chunk_c_id * kChunkSizeC * params.dweight_c_stride;
357
- int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
358
- + batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
359
- input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
360
- : reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
361
- input_t *dinitial_states = params.dinitial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
362
- : reinterpret_cast<input_t *>(params.dinitial_states_ptr) + batch_id * params.dinitial_states_batch_stride + l_idx * params.dinitial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
363
- input_t *dfinal_states = params.dfinal_states_ptr == nullptr ? nullptr
364
- : reinterpret_cast<input_t *>(params.dfinal_states_ptr) + batch_id * params.dfinal_states_batch_stride + chunk_c_id * kChunkSizeC;
365
-
366
- #pragma unroll
367
- for (int l = 0; l < Ktraits::kNLoads; ++l) {
368
- input_t dout_vals_load[kNElts] = {0};
369
- input_t x_vals_load[kNElts] = {0};
370
- if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
371
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
372
- reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + l * kLPerLoad * params.dout_l_stride);
373
- reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
374
- }
375
- reinterpret_cast<vec_t *>(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
376
- reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
377
- }
378
- // Load the elements from the previous chunk or next chunk that are needed for convolution.
379
- if (l_idx < kWidth - 1) {
380
- input_t dout_vals_load[kNElts] = {0};
381
- input_t x_vals_load[kNElts] = {0};
382
- if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
383
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
384
- reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + kChunkSizeL * params.dout_l_stride);
385
- }
386
- if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
387
- && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
388
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
389
- reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
390
- } else if (initial_states != nullptr
391
- && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
392
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
393
- reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
394
- }
395
- reinterpret_cast<vec_t *>(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
396
- reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
397
- }
398
- // Need to load (kWdith - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs
399
- if constexpr (kSiluAct) {
400
- if (l_idx < kWidth - 1) {
401
- input_t x_vals_load[kNElts] = {0};
402
- if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
403
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
404
- reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + kChunkSizeL * params.x_l_stride);
405
- }
406
- reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
407
- }
408
- }
409
-
410
- __syncthreads();
411
-
412
- constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
413
- static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
414
- constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
415
- static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
416
- // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
417
- static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
418
- static_assert((kLPerThread & (kLPerThread - 1)) == 0);
419
- static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
420
- static_assert(kNThreadsPerRow <= 32);
421
-
422
- const int row_idx = tid / kNThreadsPerRow;
423
- const int col_idx = tid % kNThreadsPerRow;
424
-
425
- float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
426
- float weight_vals[kWidth] = {0};
427
- if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
428
- #pragma unroll
429
- for (int w = 0; w < kWidth; ++w) {
430
- weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
431
- }
432
- }
433
- float dout_vals[kLPerThread + kWidth - 1];
434
- float x_vals[kWidth - 1 + kLPerThread + kWidth - 1];
435
- #pragma unroll
436
- for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
437
- dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]);
438
- x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
439
- }
440
-
441
- int seq_idx_thread[kWidth - 1 + kLPerThread + kWidth - 1];
442
- if constexpr (kHasSeqIdx) {
443
- #pragma unroll
444
- for (int i = 0; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
445
- const int l_idx = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1);
446
- seq_idx_thread[i] = l_idx >= 0 && l_idx < params.seqlen ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
447
- }
448
- }
449
-
450
- if constexpr (kSiluAct) { // Recompute the output
451
- #pragma unroll
452
- for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
453
- x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
454
- }
455
- #pragma unroll
456
- for (int i = 0; i < kLPerThread + kWidth - 1; ++i) {
457
- float out_val = bias_val;
458
- const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
459
- #pragma unroll
460
- for (int w = 0; w < kWidth; ++w) {
461
- if constexpr (!kHasSeqIdx) {
462
- out_val += weight_vals[w] * x_vals[i + w];
463
- } else {
464
- out_val += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
465
- }
466
- }
467
- float out_val_sigmoid = 1.f / (1.f + expf(-out_val));
468
- dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid));
469
- }
470
- }
471
-
472
- float dweight_vals[kWidth] = {0};
473
- SumOp<float> sum_op;
474
- #pragma unroll
475
- for (int w = 0; w < kWidth; ++w) {
476
- #pragma unroll
477
- for (int i = 0; i < kLPerThread; ++i) {
478
- if constexpr (!kHasSeqIdx) {
479
- dweight_vals[w] += x_vals[i + w] * dout_vals[i];
480
- } else {
481
- dweight_vals[w] += seq_idx_thread[i + w] == seq_idx_thread[kWidth - 1 + i] ? x_vals[i + w] * dout_vals[i] : 0.f;
482
- }
483
- }
484
- dweight_vals[w] = Allreduce<kNThreadsPerRow>::run(dweight_vals[w], sum_op);
485
- if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
486
- atomicAdd(&reinterpret_cast<float *>(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]);
487
- }
488
- }
489
-
490
- if (params.bias_ptr != nullptr) {
491
- float dbias_val = 0.f;
492
- for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; }
493
- dbias_val = Allreduce<kNThreadsPerRow>::run(dbias_val, sum_op);
494
- if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
495
- atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val);
496
- }
497
- }
498
-
499
- float dx_vals[kLPerThread] = {0};
500
- #pragma unroll
501
- for (int i = 0; i < kLPerThread; ++i) {
502
- const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
503
- #pragma unroll
504
- for (int w = 0; w < kWidth; ++w) {
505
- if constexpr (!kHasSeqIdx) {
506
- dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w];
507
- } else {
508
- dx_vals[i] += seq_idx_thread[kWidth - 1 + i + w] == seq_idx_cur ? weight_vals[kWidth - 1 - w] * dout_vals[i + w] : 0.f;
509
- }
510
- }
511
- // if (dfinal_states != nullptr) {
512
- if constexpr (kHasDfinalStates) {
513
- if (chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i >= params.seqlen - kWidth + 1
514
- && chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i < params.seqlen
515
- && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
516
- dx_vals[i] += float(dfinal_states[((chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i) - (params.seqlen - kWidth + 1)) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
517
- }
518
- }
519
- }
520
-
521
- float dxinit_vals[kWidth - 1] = {0};
522
- static_assert(kLPerThread >= kWidth - 1); // So only threads with col_idx == 0 need to handle dinitial_states
523
- if (dinitial_states != nullptr && col_idx == 0) {
524
- #pragma unroll
525
- for (int i = 0; i < kWidth - 1; ++i) {
526
- #pragma unroll
527
- for (int w = 0; w < kWidth; ++w) {
528
- dxinit_vals[i] += i + w - (kWidth - 1) >= 0 ? weight_vals[kWidth - 1 - w] * dout_vals[i + w - (kWidth - 1)] : 0.f;
529
- }
530
- // chunk_l_id must be 0 because dinitial_states != nullptr
531
- // if (dfinal_states != nullptr) {
532
- if constexpr (kHasDfinalStates) {
533
- if (i >= params.seqlen) {
534
- dxinit_vals[i] += float(dfinal_states[(i - params.seqlen) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
535
- }
536
- }
537
- }
538
- }
539
-
540
- __syncthreads();
541
- #pragma unroll
542
- for (int i = 0; i < kLPerThread; ++i) { x_smem[kWidth - 1 + col_idx * kLPerThread + i][row_idx] = dx_vals[i]; }
543
- if (dinitial_states != nullptr && col_idx == 0) {
544
- #pragma unroll
545
- for (int i = 0; i < kWidth - 1; ++i) { x_smem[i][row_idx] = dxinit_vals[i]; }
546
- }
547
- __syncthreads();
548
-
549
- #pragma unroll
550
- for (int l = 0; l < Ktraits::kNLoads; ++l) {
551
- input_t dx_vals_store[kNElts];
552
- reinterpret_cast<vec_t *>(dx_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx];
553
- if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
554
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
555
- *reinterpret_cast<vec_t *>(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast<vec_t *>(dx_vals_store)[0];
556
- }
557
- }
558
- if (dinitial_states != nullptr
559
- && l_idx < kWidth - 1
560
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
561
- input_t dxinit_vals_store[kNElts];
562
- reinterpret_cast<vec_t *>(dxinit_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx];
563
- *reinterpret_cast<vec_t *>(dinitial_states) = reinterpret_cast<vec_t *>(dxinit_vals_store)[0];
564
- }
565
-
566
- }
567
-
568
- template<int kNThreads, int kWidth, typename input_t, typename weight_t>
569
- void causal_conv1d_channellast_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
570
- BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
571
- BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
572
- BOOL_SWITCH(params.dfinal_states_ptr != nullptr, kHasDfinalStates, [&] {
573
- BOOL_SWITCH(params.seqlen <= 128, kChunkSizeL64, [&] {
574
- // kChunkSizeL = 128 is slightly faster than 64 when seqlen is larger
575
- static constexpr int kChunk = kChunkSizeL64 ? 64 : 128;
576
- using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits<kNThreads, kWidth, kChunk, kSiluAct, true, input_t, weight_t>;
577
- // constexpr int kSmemSize = Ktraits::kSmemSize;
578
- constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
579
- constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
580
- const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
581
- const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
582
- dim3 grid(params.batch, n_chunks_L, n_chunks_C);
583
- dim3 block(Ktraits::kNThreads);
584
- auto kernel = &causal_conv1d_channellast_bwd_kernel<Ktraits, kHasSeqIdx, kHasDfinalStates>;
585
- // if (kSmemSize >= 48 * 1024) {
586
- // C10_CUDA_CHECK(cudaFuncSetAttribute(
587
- // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
588
- // }
589
- // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
590
- kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
591
- C10_CUDA_KERNEL_LAUNCH_CHECK();
592
- });
593
- });
594
- });
595
- });
596
- }
597
-
598
- template<typename input_t, typename weight_t>
599
- void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
600
- if (params.width == 2) {
601
- causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream);
602
- } else if (params.width == 3) {
603
- causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream);
604
- } else if (params.width == 4) {
605
- causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream);
606
- }
607
- }
608
-
609
- template void causal_conv1d_bwd_cuda<float, float>(ConvParamsBwd &params, cudaStream_t stream);
610
- template void causal_conv1d_bwd_cuda<at::Half, float>(ConvParamsBwd &params, cudaStream_t stream);
611
- template void causal_conv1d_bwd_cuda<at::BFloat16, float>(ConvParamsBwd &params, cudaStream_t stream);
612
- template void causal_conv1d_bwd_cuda<float, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
613
- template void causal_conv1d_bwd_cuda<at::Half, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
614
- template void causal_conv1d_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
615
- template void causal_conv1d_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
616
- template void causal_conv1d_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
617
- template void causal_conv1d_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
618
-
619
- template void causal_conv1d_channellast_bwd_cuda<float, float>(ConvParamsBwd &params, cudaStream_t stream);
620
- template void causal_conv1d_channellast_bwd_cuda<at::Half, float>(ConvParamsBwd &params, cudaStream_t stream);
621
- template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, float>(ConvParamsBwd &params, cudaStream_t stream);
622
- template void causal_conv1d_channellast_bwd_cuda<float, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
623
- template void causal_conv1d_channellast_bwd_cuda<at::Half, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
624
- template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
625
- template void causal_conv1d_channellast_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
626
- template void causal_conv1d_channellast_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
627
- template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
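For orientation, the input gradient these backward kernels accumulate follows from the forward formula out[l] = bias + sum_w weight[w] * x[l - (width - 1) + w]: each dx[l] gathers the output gradients that consumed x[l]. A minimal host-side sketch of that math alone, assuming a single channel and plain float buffers — the function name is illustrative, and the SiLU and initial/final-state paths handled by the kernels above are omitted:

```cuda
// Illustrative single-channel reference of the input gradient (no SiLU, no states).
// dx[l] = sum_w weight[w] * dout[l + (width - 1) - w], clipped to valid output positions.
void causal_conv1d_bwd_dx_ref_1ch(const float *dout, const float *weight,
                                  float *dx, int seqlen, int width) {
    for (int l = 0; l < seqlen; ++l) {
        float acc = 0.f;
        for (int w = 0; w < width; ++w) {
            int o = l + (width - 1) - w;  // output position that used x[l] with weight[w]
            if (o < seqlen) { acc += weight[w] * dout[o]; }
        }
        dx[l] = acc;
    }
}
```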
 
causal-conv1d/causal_conv1d_common.h DELETED
@@ -1,98 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #ifndef USE_ROCM
8
- #include <cuda_bf16.h>
9
-
10
- template<typename T>
11
- __device__ inline T shuffle_xor(T val, int offset) {
12
- return __shfl_xor_sync(uint32_t(-1), val, offset);
13
- }
14
-
15
- constexpr size_t custom_max(std::initializer_list<size_t> ilist)
16
- {
17
- return std::max(ilist);
18
- }
19
-
20
- template<typename T>
21
- constexpr T constexpr_min(T a, T b) {
22
- return std::min(a, b);
23
- }
24
-
25
- #else
26
- #include <hip/hip_bf16.h>
27
-
28
- template<typename T>
29
- __device__ inline T shuffle_xor(T val, int offset) {
30
- return __shfl_xor(val, offset);
31
- }
32
- constexpr size_t custom_max(std::initializer_list<size_t> ilist)
33
- {
34
- return *std::max_element(ilist.begin(), ilist.end());
35
- }
36
-
37
- template<typename T>
38
- constexpr T constexpr_min(T a, T b) {
39
- return a < b ? a : b;
40
- }
41
- #endif
42
- #include <cuda_fp16.h>
43
-
44
- ////////////////////////////////////////////////////////////////////////////////////////////////////
45
-
46
- template<int BYTES> struct BytesToType {};
47
-
48
- template<> struct BytesToType<16> {
49
- using Type = uint4;
50
- static_assert(sizeof(Type) == 16);
51
- };
52
-
53
- template<> struct BytesToType<8> {
54
- using Type = uint64_t;
55
- static_assert(sizeof(Type) == 8);
56
- };
57
-
58
- template<> struct BytesToType<4> {
59
- using Type = uint32_t;
60
- static_assert(sizeof(Type) == 4);
61
- };
62
-
63
- template<> struct BytesToType<2> {
64
- using Type = uint16_t;
65
- static_assert(sizeof(Type) == 2);
66
- };
67
-
68
- template<> struct BytesToType<1> {
69
- using Type = uint8_t;
70
- static_assert(sizeof(Type) == 1);
71
- };
72
-
73
- ////////////////////////////////////////////////////////////////////////////////////////////////////
74
-
75
- template<typename T>
76
- struct SumOp {
77
- __device__ inline T operator()(T const & x, T const & y) { return x + y; }
78
- };
79
-
80
- template<int THREADS>
81
- struct Allreduce {
82
- static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
83
- template<typename T, typename Operator>
84
- static __device__ inline T run(T x, Operator &op) {
85
- constexpr int OFFSET = THREADS / 2;
86
- x = op(x, shuffle_xor(x, OFFSET));
87
- return Allreduce<OFFSET>::run(x, op);
88
- }
89
- };
90
-
91
- template<>
92
- struct Allreduce<2> {
93
- template<typename T, typename Operator>
94
- static __device__ inline T run(T x, Operator &op) {
95
- x = op(x, shuffle_xor(x, 1));
96
- return x;
97
- }
98
- };
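The helpers in this header — shuffle_xor, SumOp, and Allreduce — implement a butterfly warp reduction. A minimal sketch of how they compose, assuming the header is available as causal_conv1d_common.h; the kernel name and launch shape are illustrative:

```cuda
// Illustrative warp-sum kernel built on the helpers above: each of the 32 lanes
// contributes one float, and the butterfly exchange leaves the total in every lane.
#include "causal_conv1d_common.h"

__global__ void warp_sum_demo(const float *in, float *out) {
    float val = in[threadIdx.x];
    SumOp<float> op;
    val = Allreduce<32>::run(val, op);   // XOR shuffles with offsets 16, 8, 4, 2, 1
    if (threadIdx.x == 0) { *out = val; }
}

// Launch with exactly one warp so every lane participates in the shuffles:
// warp_sum_demo<<<1, 32>>>(d_in, d_out);
```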
 
causal-conv1d/causal_conv1d_fwd.cu DELETED
@@ -1,399 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #include <c10/util/BFloat16.h>
6
- #include <c10/util/Half.h>
7
- #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
-
9
- #ifndef USE_ROCM
10
- #include <cub/block/block_load.cuh>
11
- #include <cub/block/block_store.cuh>
12
- #else
13
- #include <hipcub/hipcub.hpp>
14
- namespace cub = hipcub;
15
- #endif
16
-
17
- #include "causal_conv1d.h"
18
- #include "causal_conv1d_common.h"
19
- #include "static_switch.h"
20
-
21
- template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
22
- struct Causal_conv1d_fwd_kernel_traits {
23
- using input_t = input_t_;
24
- using weight_t = weight_t_;
25
- static constexpr int kNThreads = kNThreads_;
26
- static constexpr int kWidth = kWidth_;
27
- static constexpr int kNBytes = sizeof(input_t);
28
- static_assert(kNBytes == 2 || kNBytes == 4);
29
- static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
30
- static_assert(kWidth <= kNElts);
31
- static constexpr bool kIsVecLoad = kIsVecLoad_;
32
- using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
33
- using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
34
- using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
35
- using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
36
- using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
37
- static constexpr int kSmemIOSize = kIsVecLoad
38
- ? 0
39
- : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
40
- static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
41
- static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
42
- };
43
-
44
- template<typename Ktraits>
45
- __global__ __launch_bounds__(Ktraits::kNThreads)
46
- void causal_conv1d_fwd_kernel(ConvParamsBase params) {
47
- constexpr int kWidth = Ktraits::kWidth;
48
- constexpr int kNThreads = Ktraits::kNThreads;
49
- constexpr int kNElts = Ktraits::kNElts;
50
- static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
51
- using input_t = typename Ktraits::input_t;
52
- using vec_t = typename Ktraits::vec_t;
53
- using weight_t = typename Ktraits::weight_t;
54
-
55
- // Shared memory.
56
- extern __shared__ char smem_[];
57
- auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
58
- auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
59
- auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
60
- auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
61
- vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
62
-
63
- const int tidx = threadIdx.x;
64
- const int batch_id = blockIdx.x;
65
- const int channel_id = blockIdx.y;
66
- input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
67
- + channel_id * params.x_c_stride;
68
- weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
69
- input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
70
- + channel_id * params.out_c_stride;
71
- float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
72
-
73
- // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
74
- if (tidx == 0) {
75
- input_t zeros[kNElts] = {0};
76
- smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
77
- }
78
-
79
- float weight_vals[kWidth];
80
- #pragma unroll
81
- for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
82
-
83
- constexpr int kChunkSize = kNThreads * kNElts;
84
- const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
85
- for (int chunk = 0; chunk < n_chunks; ++chunk) {
86
- input_t x_vals_load[2 * kNElts] = {0};
87
- if constexpr(kIsVecLoad) {
88
- typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
89
- } else {
90
- __syncthreads();
91
- typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
92
- }
93
- x += kChunkSize;
94
- __syncthreads();
95
- // Thread kNThreads - 1 doesn't write yet, so that thread 0 can read
96
- // the last elements of the previous chunk.
97
- if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
98
- __syncthreads();
99
- reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
100
- __syncthreads();
101
- // Now thread kNThreads - 1 can write the last elements of the current chunk.
102
- if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
103
-
104
- float x_vals[2 * kNElts];
105
- #pragma unroll
106
- for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
107
-
108
- float out_vals[kNElts];
109
- #pragma unroll
110
- for (int i = 0; i < kNElts; ++i) {
111
- out_vals[i] = bias_val;
112
- #pragma unroll
113
- for (int w = 0; w < kWidth; ++w) {
114
- out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
115
- }
116
- }
117
-
118
- if (params.silu_activation) {
119
- #pragma unroll
120
- for (int i = 0; i < kNElts; ++i) {
121
- out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
122
- }
123
- }
124
-
125
- input_t out_vals_store[kNElts];
126
- #pragma unroll
127
- for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
128
- if constexpr(kIsVecLoad) {
129
- typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
130
- } else {
131
- typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
132
- }
133
- out += kChunkSize;
134
- }
135
- }
136
-
137
- template<int kNThreads, int kWidth, typename input_t, typename weight_t>
138
- void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
139
- static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
140
- BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
141
- using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
142
- constexpr int kSmemSize = Ktraits::kSmemSize;
143
- dim3 grid(params.batch, params.dim);
144
-
145
- auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
146
-
147
- if (kSmemSize >= 48 * 1024) {
148
- #ifndef USE_ROCM
149
- C10_CUDA_CHECK(cudaFuncSetAttribute(
150
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
151
- #else
152
- // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
153
- C10_CUDA_CHECK(cudaFuncSetAttribute(
154
- (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
155
- std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
156
- #endif
157
- }
158
- kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
159
-
160
- C10_CUDA_KERNEL_LAUNCH_CHECK();
161
- });
162
- }
163
-
164
- template<typename input_t, typename weight_t>
165
- void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
166
- if (params.width == 2) {
167
- causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
168
- } else if (params.width == 3) {
169
- causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
170
- } else if (params.width == 4) {
171
- causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
172
- }
173
- }
174
-
175
- template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
176
- struct Causal_conv1d_channellast_fwd_kernel_traits {
177
- // The cache line is 128 bytes, and we try to read 16 bytes per thread.
178
- // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
179
- // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
180
- // threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
181
- using input_t = input_t_;
182
- using weight_t = weight_t_;
183
- static constexpr int kNThreads = kNThreads_;
184
- static_assert(kNThreads % 32 == 0);
185
- static constexpr int kNWarps = kNThreads / 32;
186
- static constexpr int kWidth = kWidth_;
187
- static constexpr int kChunkSizeL = kChunkSizeL_;
188
- static constexpr int kNBytes = sizeof(input_t);
189
- static_assert(kNBytes == 2 || kNBytes == 4);
190
- static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
191
- static constexpr int kNEltsPerRow = 128 / kNBytes;
192
- static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
193
- static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
194
- static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
195
- static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
196
- static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
197
- static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
198
- static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
199
- static constexpr bool kIsVecLoad = kIsVecLoad_;
200
- using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
201
- // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
202
- // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
203
- // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
204
- // sizeof(typename BlockStoreT::TempStorage)});
205
- // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
206
- };
207
-
208
- template<typename Ktraits, bool kHasSeqIdx>
209
- __global__ __launch_bounds__(Ktraits::kNThreads)
210
- void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
211
- constexpr int kWidth = Ktraits::kWidth;
212
- constexpr int kNThreads = Ktraits::kNThreads;
213
- constexpr int kNElts = Ktraits::kNElts;
214
- constexpr int kNWarp = Ktraits::kNWarps;
215
- constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
216
- constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
217
- constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
218
- constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
219
- using input_t = typename Ktraits::input_t;
220
- using vec_t = typename Ktraits::vec_t;
221
- using weight_t = typename Ktraits::weight_t;
222
-
223
- // Shared memory.
224
- __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
225
-
226
- const int batch_id = blockIdx.x;
227
- const int chunk_l_id = blockIdx.y;
228
- const int chunk_c_id = blockIdx.z;
229
- const int tid = threadIdx.x;
230
- const int l_idx = tid / kNThreadsPerC;
231
- const int c_idx = tid % kNThreadsPerC;
232
- input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
233
- + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
234
- weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
235
- + chunk_c_id * kChunkSizeC * params.weight_c_stride;
236
- input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
237
- + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
238
- int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
239
- + batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
240
- input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
241
- : reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
242
- // The last L-chunk will also have enough info to write to final states, since it also contain a few x values
243
- // from the previous L-chunk.
244
- input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
245
- : reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
246
-
247
- #pragma unroll
248
- for (int l = 0; l < Ktraits::kNLoads; ++l) {
249
- input_t x_vals_load[kNElts] = {0};
250
- if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
251
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
252
- reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
253
- }
254
- reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
255
- }
256
- // Load the elements from the previous chunk that are needed for convolution.
257
- if (l_idx < kWidth - 1) {
258
- input_t x_vals_load[kNElts] = {0};
259
- if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
260
- && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
261
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
262
- reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
263
- } else if (initial_states != nullptr
264
- && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
265
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
266
- reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
267
- }
268
- reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
269
- }
270
-
271
- __syncthreads();
272
-
273
- if (final_states != nullptr
274
- && l_idx < kWidth - 1
275
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
276
- // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
277
- // So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
278
- *reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
279
- }
280
-
281
- constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
282
- static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
283
- constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
284
- static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
285
- // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
286
- static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
287
- static_assert((kLPerThread & (kLPerThread - 1)) == 0);
288
- static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
289
- static_assert(kNThreadsPerRow <= 32);
290
-
291
- const int row_idx = tid / kNThreadsPerRow;
292
- const int col_idx = tid % kNThreadsPerRow;
293
-
294
- float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
295
- float weight_vals[kWidth] = {0};
296
- if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
297
- #pragma unroll
298
- for (int w = 0; w < kWidth; ++w) {
299
- weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
300
- }
301
- }
302
- float x_vals[kWidth - 1 + kLPerThread];
303
- #pragma unroll
304
- for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
305
- x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
306
- }
307
- int seq_idx_thread[kWidth - 1 + kLPerThread];
308
- if constexpr (kHasSeqIdx) {
309
- #pragma unroll
310
- for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
311
- seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
312
- }
313
- }
314
-
315
- float out_vals[kLPerThread];
316
- #pragma unroll
317
- for (int i = 0; i < kLPerThread; ++i) {
318
- out_vals[i] = bias_val;
319
- const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
320
- #pragma unroll
321
- for (int w = 0; w < kWidth; ++w) {
322
- if constexpr (!kHasSeqIdx) {
323
- out_vals[i] += weight_vals[w] * x_vals[i + w];
324
- } else {
325
- out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
326
- }
327
- }
328
- if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
329
- }
330
-
331
- __syncthreads();
332
- #pragma unroll
333
- for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
334
- __syncthreads();
335
-
336
- #pragma unroll
337
- for (int l = 0; l < Ktraits::kNLoads; ++l) {
338
- input_t out_vals_store[kNElts];
339
- reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
340
- if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
341
- && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
342
- *reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
343
- }
344
- }
345
-
346
- }
347
-
348
- template<int kNThreads, int kWidth, typename input_t, typename weight_t>
349
- void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
350
- BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
351
- using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
352
- // constexpr int kSmemSize = Ktraits::kSmemSize;
353
- constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
354
- constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
355
- const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
356
- const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
357
- dim3 grid(params.batch, n_chunks_L, n_chunks_C);
358
- dim3 block(Ktraits::kNThreads);
359
- auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
360
- // if (kSmemSize >= 48 * 1024) {
361
- // C10_CUDA_CHECK(cudaFuncSetAttribute(
362
- // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
363
- // }
364
- // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
365
- kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
366
- C10_CUDA_KERNEL_LAUNCH_CHECK();
367
- });
368
- }
369
-
370
- template<typename input_t, typename weight_t>
371
- void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
372
- if (params.width == 2) {
373
- causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
374
- } else if (params.width == 3) {
375
- causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
376
- } else if (params.width == 4) {
377
- causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
378
- }
379
- }
380
-
381
- template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
382
- template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
383
- template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
384
- template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
385
- template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
386
- template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
387
- template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
388
- template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
389
- template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
390
-
391
- template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
392
- template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
393
- template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
394
- template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
395
- template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
396
- template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
397
- template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
398
- template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
399
- template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
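For reference, the forward kernels above compute, per channel, a left-padded causal convolution with an optional SiLU: out[l] = act(bias + sum_w weight[w] * x[l - (width - 1) + w]). A host-side sketch of that arithmetic, assuming a single channel and contiguous float buffers — the function name is illustrative, and none of the kernel's vectorization, shared-memory exchange, or chunking is reproduced:

```cuda
// Illustrative single-channel reference of the forward math; x is zero-padded on the left.
#include <cmath>

void causal_conv1d_fwd_ref_1ch(const float *x, const float *weight, float bias,
                               float *out, int seqlen, int width, bool silu) {
    for (int l = 0; l < seqlen; ++l) {
        float acc = bias;
        for (int w = 0; w < width; ++w) {
            int idx = l - (width - 1) + w;               // causal: current and past inputs only
            acc += weight[w] * (idx >= 0 ? x[idx] : 0.f);
        }
        out[l] = silu ? acc / (1.f + std::exp(-acc)) : acc;  // SiLU = y / (1 + exp(-y))
    }
}
```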
 
causal-conv1d/causal_conv1d_update.cu DELETED
@@ -1,137 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #include <c10/util/BFloat16.h>
6
- #include <c10/util/Half.h>
7
- #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
-
9
- #include "causal_conv1d.h"
10
- #include "causal_conv1d_common.h"
11
- #include "static_switch.h"
12
-
13
- template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
14
- struct Causal_conv1d_update_kernel_traits {
15
- using input_t = input_t_;
16
- using weight_t = weight_t_;
17
- static constexpr int kNThreads = kNThreads_;
18
- static constexpr int kWidth = kWidth_;
19
- static constexpr int kNBytes = sizeof(input_t);
20
- static_assert(kNBytes == 2 || kNBytes == 4);
21
- };
22
-
23
- template<typename Ktraits, bool kIsCircularBuffer>
24
- __global__ __launch_bounds__(Ktraits::kNThreads)
25
- void causal_conv1d_update_kernel(ConvParamsBase params) {
26
- constexpr int kWidth = Ktraits::kWidth;
27
- constexpr int kNThreads = Ktraits::kNThreads;
28
- using input_t = typename Ktraits::input_t;
29
- using weight_t = typename Ktraits::weight_t;
30
-
31
- const int tidx = threadIdx.x;
32
- const int batch_id = blockIdx.x;
33
- const int channel_id = blockIdx.y * kNThreads + tidx;
34
- if (channel_id >= params.dim) return;
35
-
36
- input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
37
- + channel_id * params.x_c_stride;
38
-
39
- // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
40
- // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
41
- const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
42
- ? batch_id
43
- : params.conv_state_indices_ptr[batch_id];
44
- input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
45
- + conv_state_batch_coord * params.conv_state_batch_stride
46
- + channel_id * params.conv_state_c_stride;
47
- weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
48
- input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
49
- + channel_id * params.out_c_stride;
50
- float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
51
-
52
- int state_len = params.conv_state_len;
53
- int advance_len = params.seqlen;
54
- int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
55
- int update_idx = cache_seqlen - (kWidth - 1);
56
- update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
57
-
58
- float weight_vals[kWidth] = {0};
59
- #pragma unroll
60
- for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
61
-
62
- float x_vals[kWidth] = {0};
63
- if constexpr (!kIsCircularBuffer) {
64
- #pragma unroll 2
65
- for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
66
- conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
67
- }
68
- #pragma unroll
69
- for (int i = 0; i < kWidth - 1; ++i) {
70
- input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
71
- if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
72
- conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
73
- }
74
- x_vals[i] = float(state_val);
75
- }
76
- } else {
77
- #pragma unroll
78
- for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
79
- input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
80
- x_vals[i] = float(state_val);
81
- }
82
- }
83
- #pragma unroll 2
84
- for (int i = 0; i < params.seqlen; ++i) {
85
- input_t x_val = x[i * params.x_l_stride];
86
- if constexpr (!kIsCircularBuffer) {
87
- if (i < advance_len && state_len - advance_len + i >= 0) {
88
- conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
89
- }
90
- } else {
91
- conv_state[update_idx * params.conv_state_l_stride] = x_val;
92
- ++update_idx;
93
- update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
94
- }
95
- x_vals[kWidth - 1] = float(x_val);
96
- float out_val = bias_val;
97
- #pragma unroll
98
- for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
99
- if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
100
- out[i * params.out_l_stride] = input_t(out_val);
101
- // Shift the input buffer by 1
102
- #pragma unroll
103
- for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
104
- }
105
- }
106
-
107
- template<int kNThreads, int kWidth, typename input_t, typename weight_t>
108
- void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
109
- using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
110
- dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
111
- auto kernel = params.cache_seqlens == nullptr
112
- ? &causal_conv1d_update_kernel<Ktraits, false>
113
- : &causal_conv1d_update_kernel<Ktraits, true>;
114
- kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
115
- C10_CUDA_KERNEL_LAUNCH_CHECK();
116
- }
117
-
118
- template<typename input_t, typename weight_t>
119
- void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
120
- if (params.width == 2) {
121
- causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
122
- } else if (params.width == 3) {
123
- causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
124
- } else if (params.width == 4) {
125
- causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
126
- }
127
- }
128
-
129
- template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
130
- template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
131
- template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
132
- template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
133
- template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
134
- template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
135
- template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
136
- template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
137
- template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
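The update kernels keep a rolling window of the most recent inputs in conv_state and emit one output per new token. A host-side sketch of the non-circular path, assuming a single channel and state_len == width - 1; the general state shifting and the circular-buffer (cache_seqlens) case handled above are omitted, and the function name is illustrative:

```cuda
// Illustrative single-channel decode-time update: conv_state holds the width - 1
// most recent inputs, x holds the new tokens, out receives one value per token.
#include <cmath>

void causal_conv1d_update_ref_1ch(const float *x, int seqlen,
                                  float *conv_state,                 // length width - 1
                                  const float *weight, int width,
                                  float bias, bool silu, float *out) {
    for (int t = 0; t < seqlen; ++t) {
        float acc = bias;
        for (int w = 0; w < width - 1; ++w) { acc += weight[w] * conv_state[w]; }
        acc += weight[width - 1] * x[t];
        out[t] = silu ? acc / (1.f + std::exp(-acc)) : acc;
        // Roll the state: drop the oldest input, append the current one.
        for (int w = 0; w < width - 2; ++w) { conv_state[w] = conv_state[w + 1]; }
        conv_state[width - 2] = x[t];
    }
}
```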
 
causal-conv1d/static_switch.h DELETED
@@ -1,25 +0,0 @@
1
- // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2
- // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3
-
4
- #pragma once
5
-
6
- /// @param COND - a boolean expression to switch by
7
- /// @param CONST_NAME - a name given for the constexpr bool variable.
8
- /// @param ... - code to execute for true and false
9
- ///
10
- /// Usage:
11
- /// ```
12
- /// BOOL_SWITCH(flag, BoolConst, [&] {
13
- /// some_function<BoolConst>(...);
14
- /// });
15
- /// ```
16
- #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17
- [&] { \
18
- if (COND) { \
19
- static constexpr bool CONST_NAME = true; \
20
- return __VA_ARGS__(); \
21
- } else { \
22
- static constexpr bool CONST_NAME = false; \
23
- return __VA_ARGS__(); \
24
- } \
25
- }()
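As the doc comment above shows, BOOL_SWITCH turns a runtime flag into a constexpr template argument by instantiating both branches, which is how the launch functions in this kernel pick template specializations from runtime options. A small hypothetical use of the pattern (the function names are illustrative):

```cuda
// Hypothetical dispatcher: the runtime bool selects one of two compile-time instantiations.
#include "static_switch.h"

template <bool kHasBias>
void run_impl() {
    if constexpr (kHasBias) { /* bias path, compiled as its own instantiation */ }
    else                    { /* no-bias path */ }
}

void run(bool has_bias) {
    BOOL_SWITCH(has_bias, kHasBias, [&] {
        run_impl<kHasBias>();   // kHasBias is a constexpr bool inside the switch body
    });
}
```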
 
flake.lock DELETED
@@ -1,168 +0,0 @@
1
- {
2
- "nodes": {
3
- "flake-compat": {
4
- "locked": {
5
- "lastModified": 1747046372,
6
- "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
- "owner": "edolstra",
8
- "repo": "flake-compat",
9
- "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
- "type": "github"
11
- },
12
- "original": {
13
- "owner": "edolstra",
14
- "repo": "flake-compat",
15
- "type": "github"
16
- }
17
- },
18
- "flake-compat_2": {
19
- "locked": {
20
- "lastModified": 1747046372,
21
- "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
22
- "owner": "edolstra",
23
- "repo": "flake-compat",
24
- "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
25
- "type": "github"
26
- },
27
- "original": {
28
- "owner": "edolstra",
29
- "repo": "flake-compat",
30
- "type": "github"
31
- }
32
- },
33
- "flake-utils": {
34
- "inputs": {
35
- "systems": "systems"
36
- },
37
- "locked": {
38
- "lastModified": 1731533236,
39
- "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
- "owner": "numtide",
41
- "repo": "flake-utils",
42
- "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
- "type": "github"
44
- },
45
- "original": {
46
- "owner": "numtide",
47
- "repo": "flake-utils",
48
- "type": "github"
49
- }
50
- },
51
- "flake-utils_2": {
52
- "inputs": {
53
- "systems": "systems_2"
54
- },
55
- "locked": {
56
- "lastModified": 1731533236,
57
- "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
- "owner": "numtide",
59
- "repo": "flake-utils",
60
- "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
- "type": "github"
62
- },
63
- "original": {
64
- "owner": "numtide",
65
- "repo": "flake-utils",
66
- "type": "github"
67
- }
68
- },
69
- "hf-nix": {
70
- "inputs": {
71
- "flake-compat": "flake-compat_2",
72
- "flake-utils": "flake-utils_2",
73
- "nixpkgs": "nixpkgs"
74
- },
75
- "locked": {
76
- "lastModified": 1759493343,
77
- "narHash": "sha256-8fhl0gwMAnOkQbogPIVq+Fha+Yeq52FaRXfwF+F9Q+k=",
78
- "owner": "huggingface",
79
- "repo": "hf-nix",
80
- "rev": "b1fc3a18b52447a0f24bc6884418edc5e66082b9",
81
- "type": "github"
82
- },
83
- "original": {
84
- "owner": "huggingface",
85
- "repo": "hf-nix",
86
- "type": "github"
87
- }
88
- },
89
- "kernel-builder": {
90
- "inputs": {
91
- "flake-compat": "flake-compat",
92
- "flake-utils": "flake-utils",
93
- "hf-nix": "hf-nix",
94
- "nixpkgs": [
95
- "kernel-builder",
96
- "hf-nix",
97
- "nixpkgs"
98
- ]
99
- },
100
- "locked": {
101
- "lastModified": 1759516823,
102
- "narHash": "sha256-UJVvZHtS9c64Dm4iZRaOKWB+VHI7jzcazGH57KXWeg8=",
103
- "owner": "huggingface",
104
- "repo": "kernel-builder",
105
- "rev": "e13610a05f67b7296be9ead89ad172a0a088a1c3",
106
- "type": "github"
107
- },
108
- "original": {
109
- "owner": "huggingface",
110
- "repo": "kernel-builder",
111
- "type": "github"
112
- }
113
- },
114
- "nixpkgs": {
115
- "locked": {
116
- "lastModified": 1755963616,
117
- "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
118
- "owner": "nixos",
119
- "repo": "nixpkgs",
120
- "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
121
- "type": "github"
122
- },
123
- "original": {
124
- "owner": "nixos",
125
- "ref": "nixos-unstable-small",
126
- "repo": "nixpkgs",
127
- "type": "github"
128
- }
129
- },
130
- "root": {
131
- "inputs": {
132
- "kernel-builder": "kernel-builder"
133
- }
134
- },
135
- "systems": {
136
- "locked": {
137
- "lastModified": 1681028828,
138
- "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
139
- "owner": "nix-systems",
140
- "repo": "default",
141
- "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
142
- "type": "github"
143
- },
144
- "original": {
145
- "owner": "nix-systems",
146
- "repo": "default",
147
- "type": "github"
148
- }
149
- },
150
- "systems_2": {
151
- "locked": {
152
- "lastModified": 1681028828,
153
- "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
- "owner": "nix-systems",
155
- "repo": "default",
156
- "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
- "type": "github"
158
- },
159
- "original": {
160
- "owner": "nix-systems",
161
- "repo": "default",
162
- "type": "github"
163
- }
164
- }
165
- },
166
- "root": "root",
167
- "version": 7
168
- }
 
flake.nix DELETED
@@ -1,18 +0,0 @@
1
- {
2
- description = "Flake for causal-conv1d kernel";
3
-
4
- inputs = {
5
- kernel-builder.url = "github:huggingface/kernel-builder";
6
- };
7
-
8
- outputs =
9
- {
10
- self,
11
- kernel-builder,
12
- }:
13
- kernel-builder.lib.genFlakeOutputs {
14
- inherit self;
15
- path = ./.;
16
- pythonCheckInputs = pkgs: with pkgs; [ einops ];
17
- };
18
- }
 
tests/test_causal_conv1d.py DELETED
@@ -1,353 +0,0 @@
1
- # Copyright (C) 2024, Tri Dao.
2
-
3
- import math
4
-
5
- import torch
6
- import torch.nn.functional as F
7
-
8
- import pytest
9
-
10
- from einops import rearrange
11
-
12
-
13
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update, causal_conv1d_varlen_states
14
- from causal_conv1d.causal_conv1d_interface import causal_conv1d_ref
15
- from causal_conv1d.causal_conv1d_interface import causal_conv1d_update_ref
16
- from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states_ref
17
-
18
-
19
- @pytest.mark.parametrize("return_final_states", [False, True])
20
- # @pytest.mark.parametrize("return_final_states", [True])
21
- @pytest.mark.parametrize("has_initial_states", [False, True])
22
- # @pytest.mark.parametrize("has_initial_states", [False])
23
- @pytest.mark.parametrize("channel_last", [False, True])
24
- # @pytest.mark.parametrize('channel_last', [True])
25
- @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
26
- # @pytest.mark.parametrize('itype', [torch.float16])
27
- @pytest.mark.parametrize("silu_activation", [False, True])
28
- # @pytest.mark.parametrize('silu_activation', [True])
29
- @pytest.mark.parametrize("has_bias", [False, True])
30
- # @pytest.mark.parametrize('has_bias', [True])
31
- @pytest.mark.parametrize("width", [2, 3, 4])
32
- # @pytest.mark.parametrize('width', [3])
33
- @pytest.mark.parametrize(
34
- "seqlen", [1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
35
- )
36
- # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
37
- # @pytest.mark.parametrize('seqlen', [128])
38
- @pytest.mark.parametrize('dim', [64, 4096 + 32])
39
- # @pytest.mark.parametrize('dim', [64])
40
- def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states):
41
- if not channel_last and (has_initial_states or return_final_states):
42
- pytest.skip("Only channel_last supports initial_states or return_final_states")
43
- device = "cuda"
44
- rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
45
- if itype == torch.bfloat16:
46
- rtol, atol = 1e-2, 5e-2
47
- rtolw, atolw = (1e-3, 1e-3)
48
- # set seed
49
- torch.random.manual_seed(0)
50
- batch = 2
51
- # batch = 1
52
- if not channel_last:
53
- x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
54
- else:
55
- x = rearrange(
56
- torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
57
- ).requires_grad_()
58
- weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
59
- if has_bias:
60
- bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
61
- else:
62
- bias = None
63
- if has_initial_states:
64
- initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2).requires_grad_()
65
- else:
66
- initial_states = None
67
- x_ref = x.detach().clone().requires_grad_()
68
- weight_ref = weight.detach().clone().requires_grad_()
69
- bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
70
- initial_states_ref = initial_states.detach().clone().requires_grad_() if initial_states is not None else None
71
- activation = None if not silu_activation else "silu"
72
- out = causal_conv1d_fn(x, weight, bias, initial_states=initial_states, return_final_states=return_final_states,
73
- activation=activation)
74
- out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation)
75
- if return_final_states:
76
- out, final_states = out
77
- out_ref, final_states_ref = out_ref
78
- print(f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}")
79
- print(f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}")
80
- assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
81
-
82
- print(f"Output max diff: {(out - out_ref).abs().max().item()}")
83
- print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
84
- assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
85
-
86
- if return_final_states:
87
- out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
88
- out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)
89
-
90
- g = torch.randn_like(out)
91
- out.backward(g)
92
- out_ref.backward(g)
93
-
94
- print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
95
- print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
96
- if has_bias:
97
- print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
98
- if has_initial_states:
99
- print(f"dinitial_states max diff: {(initial_states.grad - initial_states_ref.grad).abs().max().item()}")
100
-
101
- assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
102
- assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
103
- if has_bias:
104
- assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
105
- if has_initial_states:
106
- assert torch.allclose(initial_states.grad, initial_states_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
107
-
108
-
109
- @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
110
- # @pytest.mark.parametrize('itype', [torch.float16])
111
- @pytest.mark.parametrize("silu_activation", [False, True])
112
- # @pytest.mark.parametrize('silu_activation', [True])
113
- @pytest.mark.parametrize("has_bias", [False, True])
114
- # @pytest.mark.parametrize('has_bias', [True])
115
- @pytest.mark.parametrize("has_cache_seqlens", [False, True])
116
- # @pytest.mark.parametrize('has_cache_seqlens', [True])
117
- @pytest.mark.parametrize("seqlen", [1, 4, 5])
118
- # @pytest.mark.parametrize('seqlen', [4])
119
- @pytest.mark.parametrize("width", [2, 3, 4])
120
- # @pytest.mark.parametrize('width', [4])
121
- @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
122
- # @pytest.mark.parametrize("dim", [2048])
123
- def test_causal_conv1d_update(dim, width, seqlen, has_cache_seqlens, has_bias, silu_activation, itype):
124
- device = "cuda"
125
- rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
126
- if itype == torch.bfloat16:
127
- rtol, atol = 1e-2, 5e-2
128
- rtolw, atolw = (1e-3, 1e-3)
129
- # set seed
130
- torch.random.manual_seed(0)
131
- batch = 64
132
- # batch = 1
133
- # dim = 64
134
- x = torch.randn(batch, seqlen, dim, device=device, dtype=itype).transpose(-1, -2)
135
- state_len = torch.randint(width - 1, width + 10, (1,)).item()
136
- conv_state = torch.randn(batch, state_len, dim, device=device, dtype=itype).transpose(-1, -2)
137
- weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
138
- if has_bias:
139
- bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
140
- else:
141
- bias = None
142
- conv_state_ref = conv_state.detach().clone()
143
- activation = None if not silu_activation else "silu"
144
- cache_seqlens = (torch.randint(0, 1024, (batch,), dtype=torch.int32, device=device)
145
- if has_cache_seqlens else None)
146
- out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
147
- out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
148
-
149
- print(f"Output max diff: {(out - out_ref).abs().max().item()}")
150
- print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
151
- assert torch.equal(conv_state, conv_state_ref)
152
- assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
153
-
154
- @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
155
- # @pytest.mark.parametrize('itype', [torch.float16])
156
- @pytest.mark.parametrize("silu_activation", [False, True])
157
- # @pytest.mark.parametrize('silu_activation', [True])
158
- @pytest.mark.parametrize("has_bias", [False, True])
159
- # @pytest.mark.parametrize('has_bias', [True])
160
- @pytest.mark.parametrize("has_cache_seqlens", [False, True])
161
- # @pytest.mark.parametrize('has_cache_seqlens', [True])
162
- @pytest.mark.parametrize("seqlen", [1, 4, 5])
163
- # @pytest.mark.parametrize('seqlen', [4])
164
- @pytest.mark.parametrize("width", [2, 3, 4])
165
- # @pytest.mark.parametrize('width', [4])
166
- @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
167
- # @pytest.mark.parametrize("dim", [2048])
168
- def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_cache_seqlens, has_bias, silu_activation, itype):
169
- device = "cuda"
170
- rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
171
- if itype == torch.bfloat16:
172
- rtol, atol = 1e-2, 5e-2
173
- rtolw, atolw = (1e-3, 1e-3)
174
- # set seed
175
- torch.random.manual_seed(0)
176
- batch = 64
177
- # batch = 1
178
- # dim = 64
179
- x = torch.randn(batch, seqlen, dim, device=device, dtype=itype).transpose(-1, -2)
180
- state_len = torch.randint(width - 1, width + 10, (1,)).item()
181
-
182
- total_entries = 10 * batch
183
- conv_state = torch.randn(total_entries, state_len, dim, device=device, dtype=itype).transpose(-1, -2)
184
- conv_state_indices = torch.randperm(total_entries)[:batch].to(dtype=torch.int32, device=device)
185
-
186
- weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
187
- if has_bias:
188
- bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
189
- else:
190
- bias = None
191
- conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
192
- activation = None if not silu_activation else "silu"
193
- cache_seqlens = (torch.randint(0, 1024, (batch,), dtype=torch.int32, device=device)
194
- if has_cache_seqlens else None)
195
- out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation,
196
- cache_seqlens=cache_seqlens, conv_state_indices=conv_state_indices)
197
- out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
198
-
199
- print(f"Output max diff: {(out - out_ref).abs().max().item()}")
200
- print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
201
- assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
202
- assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
203
-
204
-
205
- @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
206
- # @pytest.mark.parametrize('itype', [torch.float16])
207
- @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
208
- # @pytest.mark.parametrize("dim", [2048])
209
- def test_causal_conv1d_get_states(dim, itype):
210
- device = "cuda"
211
- # set seed
212
- torch.random.manual_seed(0)
213
- seqlens = torch.randint(1, 32, (100,), device=device)
214
- total_seqlen = seqlens.sum().item()
215
- x = torch.randn(total_seqlen, dim, device=device, dtype=itype)
216
- cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))
217
- state_len = 20
218
- out = causal_conv1d_varlen_states(x, cu_seqlens, state_len)
219
- out_ref = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len)
220
- assert torch.equal(out, out_ref)
221
-
222
-
223
- # @pytest.mark.parametrize("channel_last", [False, True])
224
- @pytest.mark.parametrize('channel_last', [True])
225
- # @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
226
- @pytest.mark.parametrize('itype', [torch.bfloat16])
227
- # @pytest.mark.parametrize("silu_activation", [False, True])
228
- @pytest.mark.parametrize('silu_activation', [True])
229
- # @pytest.mark.parametrize("has_bias", [False, True])
230
- @pytest.mark.parametrize('has_bias', [True])
231
- # @pytest.mark.parametrize("width", [2, 3, 4])
232
- @pytest.mark.parametrize('width', [4])
233
- @pytest.mark.parametrize(
234
- # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
235
- "seqlen", [2048]
236
- )
237
- # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
238
- # @pytest.mark.parametrize('seqlen', [128])
239
- def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
240
- device = "cuda"
241
- # set seed
242
- torch.random.manual_seed(0)
243
- batch = 2
244
- # batch = 1
245
- dim = 4096 + 32 # Try dim not divisible by 64
246
- # dim = 64
247
- if not channel_last:
248
- x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
249
- else:
250
- x = rearrange(
251
- torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
252
- ).requires_grad_()
253
- weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
254
- if has_bias:
255
- bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
256
- else:
257
- bias = None
258
- activation = None if not silu_activation else "silu"
259
- out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
260
- g = torch.randn_like(out0)
261
- dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
262
- dw_atol = 1e-4
263
- db_atol = 1e-4
264
-
265
- for i in range(10000):
266
- out = causal_conv1d_fn(x, weight, bias, activation=activation)
267
- dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
268
- dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
269
- # if not dw_equal:
270
- # breakpoint()
271
- if has_bias:
272
- db_equal = torch.allclose(db, db0, atol=db_atol)
273
- # if not db_equal:
274
- # breakpoint()
275
- assert torch.equal(out, out0)
276
- assert torch.equal(dx, dx0)
277
- assert dw_equal
278
- if has_bias:
279
- assert dw_equal
280
-
281
-
282
- @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
283
- # @pytest.mark.parametrize('itype', [torch.float16])
284
- @pytest.mark.parametrize("silu_activation", [False, True])
285
- # @pytest.mark.parametrize('silu_activation', [False])
286
- @pytest.mark.parametrize("has_bias", [False, True])
287
- # @pytest.mark.parametrize('has_bias', [False])
288
- @pytest.mark.parametrize("width", [2, 3, 4])
289
- # @pytest.mark.parametrize('width', [2])
290
- @pytest.mark.parametrize(
291
- "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
292
- )
293
- # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
294
- # @pytest.mark.parametrize('seqlen', [2048])
295
- @pytest.mark.parametrize('dim', [64, 4096 + 32])
296
- # @pytest.mark.parametrize('dim', [64])
297
- def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, itype):
298
- device = "cuda"
299
- rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
300
- if itype == torch.bfloat16:
301
- rtol, atol = 1e-2, 5e-2
302
- rtolw, atolw = (1e-3, 1e-3)
303
- # set seed
304
- torch.random.manual_seed(seqlen + dim + width)
305
- batch = 3
306
- seqlens = []
307
- for b in range(batch):
308
- nsplits = torch.randint(1, 5, (1,)).item()
309
- eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
310
- seqlens.append(torch.diff(torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])).tolist())
311
- assert sum(seqlens[-1]) == seqlen
312
- assert all(s > 0 for s in seqlens[-1])
313
- # Only support channel_last
314
- x = rearrange(
315
- torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
316
- ).requires_grad_()
317
- weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
318
- if has_bias:
319
- bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
320
- else:
321
- bias = None
322
- seq_idx = torch.stack([torch.cat([torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(sl)], dim=0)
323
- for sl in seqlens], dim=0)
324
- x_ref = x.detach().clone().requires_grad_()
325
- weight_ref = weight.detach().clone().requires_grad_()
326
- bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
327
- activation = None if not silu_activation else "silu"
328
- out = causal_conv1d_fn(x, weight, bias, seq_idx=seq_idx, activation=activation)
329
- out_ref = []
330
- for b in range(batch):
331
- out_ref_b = []
332
- for x_s in torch.split(x_ref[[b]], seqlens[b], dim=2):
333
- out_ref_b.append(causal_conv1d_ref(x_s, weight_ref, bias_ref, activation=activation))
334
- out_ref.append(torch.cat(out_ref_b, dim=2))
335
- out_ref = torch.cat(out_ref, dim=0)
336
-
337
- print(f"Output max diff: {(out - out_ref).abs().max().item()}")
338
- print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
339
- assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
340
-
341
- g = torch.randn_like(out)
342
- out_ref.backward(g)
343
- out.backward(g)
344
-
345
- print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
346
- print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
347
- if has_bias:
348
- print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
349
-
350
- assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
351
- assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
352
- if has_bias:
353
- assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
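
The varlen test above packs several independent sequences into each batch row and tells the kernel where the boundaries lie through `seq_idx`. The sketch below distills that packing pattern into a minimal usage example; it assumes a local build is importable as `causal_conv1d` and that `causal_conv1d_ref` is reachable from `causal_conv1d.causal_conv1d_interface`, as in this diff.

import torch

# Assumed import paths for a local build of this (now removed) extension.
from causal_conv1d import causal_conv1d_fn
from causal_conv1d.causal_conv1d_interface import causal_conv1d_ref

device = "cuda"
dim, width = 64, 4
lengths = [5, 3, 8]                 # three sequences packed into one batch row
seqlen = sum(lengths)

# Channel-last layout (stride(1) == 1) is required when seq_idx is used.
x = torch.randn(1, seqlen, dim, device=device).transpose(1, 2)
weight = torch.randn(dim, width, device=device)
bias = torch.randn(dim, device=device)

# seq_idx labels each position with the index of the packed sequence it belongs to.
seq_idx = torch.cat(
    [torch.full((n,), i, dtype=torch.int32, device=device) for i, n in enumerate(lengths)]
).unsqueeze(0)

out = causal_conv1d_fn(x, weight, bias, seq_idx=seq_idx, activation="silu")

# Reference: run each packed sequence independently and concatenate along time.
out_ref = torch.cat(
    [causal_conv1d_ref(x_s, weight, bias, activation="silu")
     for x_s in torch.split(x, lengths, dim=2)],
    dim=2,
)
print((out - out_ref).abs().max())
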
torch-ext/causal_conv1d/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
2
- from .causal_conv1d_varlen import causal_conv1d_varlen_states
3
-
4
- __all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
 
 
 
 
 
torch-ext/causal_conv1d/causal_conv1d_interface.py DELETED
@@ -1,242 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
-
3
- import torch
4
- import torch.nn.functional as F
5
-
6
- from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
7
-
8
-
9
- class CausalConv1dFn(torch.autograd.Function):
10
- @staticmethod
11
- def forward(
12
- ctx,
13
- x,
14
- weight,
15
- bias=None,
16
- seq_idx=None,
17
- initial_states=None,
18
- return_final_states=False,
19
- final_states_out=None,
20
- activation=None,
21
- ):
22
- if activation not in [None, "silu", "swish"]:
23
- raise NotImplementedError("activation must be None, silu, or swish")
24
- if x.stride(2) != 1 and x.stride(1) != 1:
25
- x = x.contiguous()
26
- bias = bias.contiguous() if bias is not None else None
27
- if seq_idx is not None:
28
- assert (
29
- initial_states is None
30
- ), "initial_states must be None if seq_idx is not None"
31
- assert (
32
- not return_final_states
33
- ), "If seq_idx is not None, we don't return final_states_out"
34
- seq_idx = seq_idx.contiguous() if seq_idx is not None else None
35
- if initial_states is not None and (
36
- initial_states.stride(2) != 1 and initial_states.stride(1) != 1
37
- ):
38
- initial_states = initial_states.contiguous()
39
- if return_final_states:
40
- assert (
41
- x.stride(1) == 1
42
- ), "Only channel-last layout support returning final_states_out"
43
- if final_states_out is not None:
44
- assert (
45
- final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
46
- )
47
- else:
48
- batch, dim, seqlen = x.shape
49
- width = weight.shape[1]
50
- final_states_out = torch.empty(
51
- batch, width - 1, dim, device=x.device, dtype=x.dtype
52
- ).transpose(1, 2)
53
- else:
54
- final_states_out = None
55
- ctx.activation = activation in ["silu", "swish"]
56
- out = causal_conv1d_fwd_function(
57
- x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
58
- )
59
- ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
60
- ctx.return_final_states = return_final_states
61
- ctx.return_dinitial_states = (
62
- initial_states is not None and initial_states.requires_grad
63
- )
64
- return out if not return_final_states else (out, final_states_out)
65
-
66
- @staticmethod
67
- def backward(ctx, dout, *args):
68
- x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
69
- dfinal_states = args[0] if ctx.return_final_states else None
70
- if dout.stride(2) != 1 and dout.stride(1) != 1:
71
- dout = dout.contiguous()
72
- # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
73
- # backward of conv1d with the backward of chunk).
74
- # Here we just pass in None and dx will be allocated in the C++ code.
75
- dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
76
- x,
77
- weight,
78
- bias,
79
- dout,
80
- seq_idx,
81
- initial_states,
82
- dfinal_states,
83
- None,
84
- ctx.return_dinitial_states,
85
- ctx.activation,
86
- )
87
- return (
88
- dx,
89
- dweight,
90
- dbias if bias is not None else None,
91
- None,
92
- dinitial_states if initial_states is not None else None,
93
- None,
94
- None,
95
- None,
96
- )
97
-
98
-
99
- def causal_conv1d_fn(
100
- x,
101
- weight,
102
- bias=None,
103
- seq_idx=None,
104
- initial_states=None,
105
- return_final_states=False,
106
- final_states_out=None,
107
- activation=None,
108
- ):
109
- """
110
- x: (batch, dim, seqlen)
111
- weight: (dim, width)
112
- bias: (dim,)
113
- seq_idx: (batch, seqlen)
114
- initial_states: (batch, dim, width - 1)
115
- final_states_out: (batch, dim, width - 1), to be written to
116
- activation: either None or "silu" or "swish"
117
-
118
- out: (batch, dim, seqlen)
119
- """
120
- return CausalConv1dFn.apply(
121
- x,
122
- weight,
123
- bias,
124
- seq_idx,
125
- initial_states,
126
- return_final_states,
127
- final_states_out,
128
- activation,
129
- )
130
-
131
-
132
- def causal_conv1d_ref(
133
- x,
134
- weight,
135
- bias=None,
136
- initial_states=None,
137
- return_final_states=False,
138
- final_states_out=None,
139
- activation=None,
140
- ):
141
- """
142
- x: (batch, dim, seqlen)
143
- weight: (dim, width)
144
- bias: (dim,)
145
- initial_states: (batch, dim, width - 1)
146
- final_states_out: (batch, dim, width - 1)
147
-
148
- out: (batch, dim, seqlen)
149
- """
150
- if activation not in [None, "silu", "swish"]:
151
- raise NotImplementedError("activation must be None, silu, or swish")
152
- dtype_in = x.dtype
153
- x = x.to(weight.dtype)
154
- seqlen = x.shape[-1]
155
- dim, width = weight.shape
156
- if initial_states is None:
157
- out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
158
- else:
159
- x = torch.cat([initial_states, x], dim=-1)
160
- out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
161
- out = out[..., :seqlen]
162
- if return_final_states:
163
- final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
164
- dtype_in
165
- ) # (batch, dim, width - 1)
166
- if final_states_out is not None:
167
- final_states_out.copy_(final_states)
168
- else:
169
- final_states_out = final_states
170
- out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
171
- return out if not return_final_states else (out, final_states_out)
172
-
173
-
174
- def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
175
- """
176
- x: (batch, dim) or (batch, dim, seqlen)
177
- conv_state: (batch, dim, state_len), where state_len >= width - 1
178
- weight: (dim, width)
179
- bias: (dim,)
180
- cache_seqlens: (batch,), dtype int32.
181
- If not None, the conv_state is treated as a circular buffer.
182
- The conv_state will be updated by copying x to the conv_state starting at the index
183
- @cache_seqlens % state_len.
184
- conv_state_indices: (batch,), dtype int32
185
- If not None, the conv_state is a larger tensor along the batch dim,
186
- and we are selecting the batch coords specified by conv_state_indices.
187
- Useful for a continuous batching scenario.
188
-
189
- out: (batch, dim) or (batch, dim, seqlen)
190
- """
191
- if activation not in [None, "silu", "swish"]:
192
- raise NotImplementedError("activation must be None, silu, or swish")
193
- activation = activation in ["silu", "swish"]
194
- unsqueeze = x.dim() == 2
195
- if unsqueeze:
196
- x = x.unsqueeze(-1)
197
- out = causal_conv1d_update_function(
198
- x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
199
- )
200
- if unsqueeze:
201
- out = out.squeeze(-1)
202
- return out
203
-
204
-
205
- def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
206
- """
207
- x: (batch, dim) or (batch, dim, seqlen)
208
- conv_state: (batch, dim, state_len), where state_len >= width - 1
209
- weight: (dim, width)
210
- bias: (dim,)
211
- cache_seqlens: (batch,), dtype int32.
212
- If not None, the conv_state is treated as a circular buffer.
213
- The conv_state will be updated by copying x to the conv_state starting at the index
214
- @cache_seqlens % state_len before performing the convolution.
215
-
216
- out: (batch, dim) or (batch, dim, seqlen)
217
- """
218
- if activation not in [None, "silu", "swish"]:
219
- raise NotImplementedError("activation must be None, silu, or swish")
220
- dtype_in = x.dtype
221
- unsqueeze = x.dim() == 2
222
- if unsqueeze:
223
- x = x.unsqueeze(-1)
224
- batch, dim, seqlen = x.shape
225
- width = weight.shape[1]
226
- state_len = conv_state.shape[-1]
227
- assert conv_state.shape == (batch, dim, state_len)
228
- assert weight.shape == (dim, width)
229
- if cache_seqlens is None:
230
- x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
231
- conv_state.copy_(x_new[:, :, -state_len:])
232
- else:
233
- width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
234
- width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
235
- x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
236
- copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
237
- copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
238
- conv_state.scatter_(2, copy_idx, x)
239
- out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
240
- if unsqueeze:
241
- out = out.squeeze(-1)
242
- return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
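
Taken together, `causal_conv1d_fn` covers the prefill pass (optionally emitting the final convolution state) and `causal_conv1d_update` advances that state one token at a time during decoding. Below is a minimal prefill-then-decode sketch, assuming a local build importable as `causal_conv1d`; the tensor sizes are illustrative only.

import torch

from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

device, dtype = "cuda", torch.float16
batch, dim, width, seqlen = 2, 64, 4, 16

weight = torch.randn(dim, width, device=device, dtype=dtype)
bias = torch.randn(dim, device=device, dtype=dtype)

# Prefill: channel-last layout is required when asking for the final states.
x = torch.randn(batch, seqlen, dim, device=device, dtype=dtype).transpose(1, 2)
out, final_states = causal_conv1d_fn(
    x, weight, bias, return_final_states=True, activation="silu"
)  # out: (batch, dim, seqlen), final_states: (batch, dim, width - 1)

# Decode: keep a rolling conv_state holding the last (width - 1) inputs.
conv_state = torch.zeros(batch, dim, width - 1, device=device, dtype=dtype)
conv_state.copy_(final_states)

x_new = torch.randn(batch, dim, device=device, dtype=dtype)  # one new token per sequence
y = causal_conv1d_update(x_new, conv_state, weight, bias, activation="silu")
print(y.shape)  # (batch, dim); conv_state was updated in place
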
torch-ext/causal_conv1d/causal_conv1d_varlen.py DELETED
@@ -1,86 +0,0 @@
1
- import torch
2
- from torch import Tensor
3
-
4
- import triton
5
- import triton.language as tl
6
-
7
-
8
- @triton.jit
9
- def _causal_conv1d_varlen_states(
10
- X,
11
- CU_SEQLENS,
12
- STATES,
13
- state_len,
14
- dim,
15
- stride_x_seqlen, stride_x_dim,
16
- stride_states_batch, stride_states_seqlen, stride_states_dim,
17
- BLOCK_M: tl.constexpr,
18
- BLOCK_N: tl.constexpr
19
- ):
20
- batch_idx = tl.program_id(2)
21
- STATES += batch_idx * stride_states_batch
22
- end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
23
- start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
24
- rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
25
- cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
26
- x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
27
- mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
28
- other=0)
29
- rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
30
- tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
31
- x,
32
- mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
33
-
34
-
35
- def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
36
- """
37
- Forward pass only, does not support backward pass.
38
- Parameters:
39
- x: (total_tokens, dim)
40
- cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
41
- state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
42
- If some of those elements belong to a different sequence, the value of the states will be zero.
43
- Return:
44
- states: (batch, dim, state_len)
45
- """
46
- _, dim = x.shape
47
- batch = cu_seqlens.shape[0] - 1
48
- cu_seqlens = cu_seqlens.contiguous()
49
- states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
50
- BLOCK_M = min(triton.next_power_of_2(state_len), 16)
51
- BLOCK_N = min(triton.next_power_of_2(dim), 256)
52
- grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
53
- with torch.cuda.device(x.device.index):
54
- _causal_conv1d_varlen_states[grid](
55
- x,
56
- cu_seqlens,
57
- states,
58
- state_len,
59
- dim,
60
- x.stride(0), x.stride(1),
61
- states.stride(0), states.stride(2), states.stride(1),
62
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
63
- )
64
- return states
65
-
66
-
67
- def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
68
- """
69
- Forward pass only, does not support backward pass.
70
- Parameters:
71
- x: (total_tokens, dim)
72
- cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
73
- state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
74
- If some of those elements belong to a different sequence, the value of the states will be zero.
75
- Return:
76
- states: (batch, dim, state_len)
77
- """
78
- _, dim = x.shape
79
- batch = cu_seqlens.shape[0] - 1
80
- cu_seqlens = cu_seqlens.contiguous()
81
- states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
82
- for i in range(batch):
83
- end_idx = cu_seqlens[i + 1]
84
- start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
85
- states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
86
- return states
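
`causal_conv1d_varlen_states` extracts, for each sequence in a packed `(total_tokens, dim)` buffer, its last `state_len` tokens (zero-padded on the left when a sequence is shorter), which is exactly the conv state needed to continue those sequences with `causal_conv1d_update`. A short usage sketch, checked against the reference above; the import paths assume a local build of this package.

import torch

from causal_conv1d import causal_conv1d_varlen_states
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states_ref

device = "cuda"
dim, state_len = 64, 3               # state_len is typically width - 1
lengths = [2, 7, 5]                  # per-sequence token counts in the packed buffer

x = torch.randn(sum(lengths), dim, device=device)
# Cumulative sequence lengths, starting from 0: [0, 2, 9, 14].
cu_seqlens = torch.tensor(
    [0] + torch.tensor(lengths).cumsum(0).tolist(), device=device, dtype=torch.int32
)

states = causal_conv1d_varlen_states(x, cu_seqlens, state_len)        # (batch, dim, state_len)
states_ref = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len)
print(torch.equal(states, states_ref), states.shape)
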
torch-ext/causal_conv1d/cpp_functions.py DELETED
@@ -1,96 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
-
3
- import torch
4
-
5
- from ._ops import ops
6
-
7
- def causal_conv1d_fwd_function(
8
- x: torch.Tensor,
9
- weight: torch.Tensor,
10
- bias: torch.Tensor | None,
11
- seq_idx: torch.Tensor | None,
12
- initial_states: torch.Tensor | None,
13
- final_states_out: torch.Tensor | None,
14
- silu_activation: bool,
15
- ) -> torch.Tensor:
16
- out = torch.empty_like(x)
17
- ops.causal_conv1d_fwd(
18
- x=x,
19
- weight=weight,
20
- bias=bias,
21
- seq_idx=seq_idx,
22
- initial_states=initial_states,
23
- out=out,
24
- final_states_out=final_states_out,
25
- silu_activation=silu_activation,
26
- )
27
- return out
28
-
29
-
30
- def causal_conv1d_bwd_function(
31
- x: torch.Tensor,
32
- weight: torch.Tensor,
33
- bias: torch.Tensor | None,
34
- dout: torch.Tensor,
35
- seq_idx: torch.Tensor | None,
36
- initial_states: torch.Tensor | None,
37
- dfinal_states: torch.Tensor | None,
38
- dx: torch.Tensor | None,
39
- return_dinitial_states: bool,
40
- silu_activation: bool,
41
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
42
- batch_size, dim = x.size()[:2]
43
- width = weight.size(-1)
44
-
45
- if dx is None:
46
- dx = torch.empty_like(x)
47
- dweight = torch.zeros_like(weight, dtype=torch.float32)
48
- dbias = None
49
- if bias is not None:
50
- dbias = torch.zeros_like(bias, dtype=torch.float32)
51
- dinitial_states = None
52
- if return_dinitial_states:
53
- dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
54
-
55
- ops.causal_conv1d_bwd(
56
- x=x,
57
- weight=weight,
58
- bias=bias,
59
- dout=dout,
60
- seq_idx=seq_idx,
61
- initial_states=initial_states,
62
- dfinal_states=dfinal_states,
63
- dx=dx,
64
- dweight=dweight,
65
- dbias=dbias,
66
- dinitial_states=dinitial_states,
67
- silu_activation=silu_activation,
68
- )
69
-
70
- dweight = dweight.type_as(weight)
71
- if dbias is not None:
72
- dbias = dbias.type_as(bias)
73
- return dx, dweight, dbias, dinitial_states
74
-
75
-
76
- def causal_conv1d_update_function(
77
- x: torch.Tensor,
78
- conv_state: torch.Tensor,
79
- weight: torch.Tensor,
80
- bias: torch.Tensor | None,
81
- silu_activation: bool,
82
- cache_seqlens: torch.Tensor | None,
83
- conv_state_indices: torch.Tensor | None,
84
- ) -> torch.Tensor:
85
- out = torch.empty_like(x)
86
- ops.causal_conv1d_update(
87
- x=x,
88
- conv_state=conv_state,
89
- weight=weight,
90
- bias=bias,
91
- out=out,
92
- silu_activation=silu_activation,
93
- cache_seqlens=cache_seqlens,
94
- conv_state_indices=conv_state_indices,
95
- )
96
- return out
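
The wrappers above follow an out-parameter convention: Python allocates the outputs (`out`, and for the backward `dx`, `dweight`, `dbias`, `dinitial_states`) and the C++ ops fill them in. A minimal sketch that calls the forward wrapper directly and compares it against the pure-PyTorch reference from `causal_conv1d_interface.py`; the import paths are assumptions for a local build.

import torch

from causal_conv1d.cpp_functions import causal_conv1d_fwd_function
from causal_conv1d.causal_conv1d_interface import causal_conv1d_ref

device = "cuda"
batch, dim, seqlen, width = 2, 64, 32, 4

x = torch.randn(batch, dim, seqlen, device=device)   # contiguous: stride(2) == 1
weight = torch.randn(dim, width, device=device)
bias = torch.randn(dim, device=device)

# Forward only: no seq_idx, no initial/final states, SiLU activation enabled.
out = causal_conv1d_fwd_function(
    x, weight, bias,
    seq_idx=None, initial_states=None, final_states_out=None,
    silu_activation=True,
)
out_ref = causal_conv1d_ref(x, weight, bias, activation="silu")
print((out - out_ref).abs().max())
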
torch-ext/pytorch_shim.h DELETED
@@ -1,105 +0,0 @@
1
- #pragma once
2
-
3
- #include <torch/library.h>
4
-
5
- /**
6
- * Unfortunately, the type signatures of the flash_attn ops are not compatible
7
- * with the PyTorch library bindings. To get around that we use
8
- * `make_pytorch_shim` which creates a lambda that exposes the API using
9
- * PyTorch-compatible types, then converts them to the types
10
- * expected by the flash_attn ops. This shim allows us to make minimal changes
11
- * to `flash_api.cpp` making it easier to synchronize with upstream changes.
12
- *
13
- * The `pytorch_library_compatible_type` struct is used to map from the
14
- * flash_attn ops types to a PyTorch library compatible one. The main issue is
15
- * that the following types are not supported by PyTorch library bindings:
16
- * - `int`
17
- * - `float`
18
- * - `std::optional<T> &`
19
- * - `std::optional<const at::Tensor> &`
20
- * So we convert them to (respectively):
21
- * - `int64_t`
22
- * - `double`
23
- * - `const std::optional<T>&`
24
- * - `const std::optional<at::Tensor>&`
25
- */
26
-
27
- template<typename T>
28
- struct pytorch_library_compatible_type {
29
- using type = T;
30
- static T convert_from_type(T arg) { return arg; }
31
- };
32
-
33
- template<typename T>
34
- using pytorch_library_compatible_type_t = \
35
- typename pytorch_library_compatible_type<T>::type;
36
-
37
- template<typename T>
38
- T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t<T> arg)
39
- { return pytorch_library_compatible_type<T>::convert_from_type(arg); }
40
-
41
- // Map `std::optional<T> &` -> `const std::optional<T>&`
42
- // (NOTE: this is a bit unsafe but none of the ops in flash_attn mutate
43
- // the optional container)
44
- template<typename T>
45
- struct pytorch_library_compatible_type<std::optional<T> &> {
46
- using type = const std::optional<T>&;
47
- static std::optional<T>& convert_from_type(const std::optional<T> &arg) {
48
- return const_cast<std::optional<T>&>(arg);
49
- }
50
- };
51
-
52
- // Map `std::optional<T>` ->
53
- // `std::optional<pytorch_library_compatible_type_t<T>>`
54
- // (NOTE: tested for `std::optional<int>` -> `std::optional<int64_t>`)
55
- template<typename T>
56
- struct pytorch_library_compatible_type<std::optional<T>> {
57
- using type = std::optional<pytorch_library_compatible_type_t<T>>;
58
- static std::optional<pytorch_library_compatible_type_t<T>> convert_from_type(std::optional<T> arg) {
59
- return arg;
60
- }
61
- };
62
-
63
- // Map `std::optional<const at::Tensor>&` -> `const std::optional<at::Tensor>&`
64
- template<>
65
- struct pytorch_library_compatible_type<std::optional<const at::Tensor> &> {
66
- using type = const std::optional<at::Tensor>&;
67
- static std::optional<const at::Tensor>& convert_from_type(
68
- const std::optional<at::Tensor> &arg) {
69
- return const_cast<std::optional<const at::Tensor>&>(
70
- reinterpret_cast<const std::optional<const at::Tensor>&>(arg));
71
- }
72
- };
73
-
74
- // Map `int` -> `int64_t`
75
- template<> struct pytorch_library_compatible_type<int> {
76
- using type = int64_t;
77
- static int convert_from_type(int64_t arg) {
78
- TORCH_CHECK(arg <= std::numeric_limits<int>::max(),
79
- "int64_t value is too large to be converted to int");
80
- TORCH_CHECK(arg >= std::numeric_limits<int>::min(),
81
- "int64_t value is too small to be converted to int");
82
- return arg;
83
- }
84
- };
85
-
86
- // Map `float` -> `double`
87
- template<> struct pytorch_library_compatible_type<float> {
88
- using type = double;
89
- static float convert_from_type(double arg) {
90
- TORCH_CHECK(std::abs(arg) <= std::numeric_limits<float>::max(),
91
- "double value is too large to be converted to float");
92
- return arg;
93
- }
94
- };
95
-
96
- //
97
- // Shim Utils
98
- //
99
-
100
- template <typename Ret, typename... Args>
101
- auto make_pytorch_shim(Ret(*fun)(Args... args)){
102
- return [fun](pytorch_library_compatible_type_t<Args>... args) {
103
- return fun(convert_from_pytorch_compatible_type<Args>(args)...);
104
- };
105
- }
torch-ext/torch_binding.cpp DELETED
@@ -1,32 +0,0 @@
1
- #include <torch/library.h>
2
-
3
- #include "registration.h"
4
-
5
- #include "pytorch_shim.h"
6
- #include "torch_binding.h"
7
-
8
- TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
9
- ops.def(
10
- "causal_conv1d_fwd("
11
- " Tensor x, Tensor weight, Tensor? bias, Tensor? seq_idx,"
12
- " Tensor? initial_states, Tensor! out, Tensor!? final_states_out,"
13
- " bool silu_activation) -> ()");
14
- ops.impl("causal_conv1d_fwd", torch::kCUDA, make_pytorch_shim(&causal_conv1d_fwd));
15
-
16
- ops.def(
17
- "causal_conv1d_bwd("
18
- " Tensor x, Tensor weight, Tensor? bias, Tensor! dout,"
19
- " Tensor? seq_idx, Tensor? initial_states, Tensor? dfinal_states,"
20
- " Tensor! dx, Tensor! dweight, Tensor!? dbias,"
21
- " Tensor!? dinitial_states, bool silu_activation) -> ()");
22
- ops.impl("causal_conv1d_bwd", torch::kCUDA, make_pytorch_shim(&causal_conv1d_bwd));
23
-
24
- ops.def(
25
- "causal_conv1d_update("
26
- " Tensor x, Tensor conv_state, Tensor weight, Tensor? bias,"
27
- " Tensor! out, bool silu_activation, Tensor? cache_seqlens,"
28
- " Tensor? conv_state_indices) -> ()");
29
- ops.impl("causal_conv1d_update", torch::kCUDA, make_pytorch_shim(&causal_conv1d_update));
30
- }
31
-
32
- REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h DELETED
@@ -1,39 +0,0 @@
1
- #pragma once
2
-
3
- #include <torch/torch.h>
4
-
5
- void
6
- causal_conv1d_fwd(const at::Tensor &x,
7
- const at::Tensor &weight,
8
- const c10::optional<at::Tensor> &bias_,
9
- const c10::optional<at::Tensor> &seq_idx_,
10
- const c10::optional<at::Tensor> &initial_states_,
11
- at::Tensor &out,
12
- c10::optional<at::Tensor> &final_states_out_,
13
- bool silu_activation);
14
-
15
- void
16
- causal_conv1d_bwd(const at::Tensor &x,
17
- const at::Tensor &weight,
18
- const c10::optional<at::Tensor> &bias_,
19
- at::Tensor &dout,
20
- const c10::optional<at::Tensor> &seq_idx_,
21
- const c10::optional<at::Tensor> &initial_states_,
22
- const c10::optional<at::Tensor> &dfinal_states_,
23
- at::Tensor &dx,
24
- at::Tensor &dweight,
25
- c10::optional<at::Tensor> &dbias_,
26
- c10::optional<at::Tensor> &dinitial_states_,
27
- bool silu_activation);
28
-
29
- void
30
- causal_conv1d_update(const at::Tensor &x,
31
- const at::Tensor &conv_state,
32
- const at::Tensor &weight,
33
- const c10::optional<at::Tensor> &bias_,
34
- at::Tensor &out,
35
- bool silu_activation,
36
- const c10::optional<at::Tensor> &cache_seqlens_,
37
- const c10::optional<at::Tensor> &conv_state_indices_
38
- );
39
-