ggerganov committed
Commit d4d67dc · Parent: 8fdd994

ggml : sync resolve (skip) (#0)

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. ggml/CMakeLists.txt +2 -1
  2. ggml/src/ggml-amx.cpp +0 -436
  3. ggml/src/ggml-backend.cpp +6 -671
  4. ggml/src/ggml-blas.cpp +0 -514
  5. ggml/src/ggml-cann.cpp +0 -2128
  6. ggml/src/ggml-cpu-impl.h +0 -614
  7. ggml/src/ggml-cpu.c +0 -0
  8. ggml/src/ggml-cuda.cu +0 -0
  9. ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  10. ggml/src/ggml-kompute.cpp +0 -2184
  11. ggml/src/ggml-metal.m +0 -0
  12. ggml/src/ggml-metal.metal +0 -0
  13. ggml/src/ggml-metal/ggml-metal.m +1 -1
  14. ggml/src/ggml-musa/CMakeLists.txt +100 -0
  15. ggml/src/ggml-rpc.cpp +0 -1403
  16. ggml/src/ggml-sycl.cpp +0 -0
  17. ggml/src/ggml-vulkan.cpp +0 -0
  18. ggml/src/kompute-shaders/common.comp +0 -102
  19. ggml/src/kompute-shaders/op_add.comp +0 -58
  20. ggml/src/kompute-shaders/op_addrow.comp +0 -25
  21. ggml/src/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  22. ggml/src/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  23. ggml/src/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  24. ggml/src/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  25. ggml/src/kompute-shaders/op_diagmask.comp +0 -30
  26. ggml/src/kompute-shaders/op_gelu.comp +0 -22
  27. ggml/src/kompute-shaders/op_getrows.comp +0 -17
  28. ggml/src/kompute-shaders/op_getrows_f16.comp +0 -31
  29. ggml/src/kompute-shaders/op_getrows_f32.comp +0 -31
  30. ggml/src/kompute-shaders/op_getrows_q4_0.comp +0 -38
  31. ggml/src/kompute-shaders/op_getrows_q4_1.comp +0 -39
  32. ggml/src/kompute-shaders/op_getrows_q6_k.comp +0 -44
  33. ggml/src/kompute-shaders/op_mul.comp +0 -52
  34. ggml/src/kompute-shaders/op_mul_mat_f16.comp +0 -67
  35. ggml/src/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  36. ggml/src/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  37. ggml/src/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  38. ggml/src/kompute-shaders/op_mul_mat_q6_k.comp +0 -94
  39. ggml/src/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  40. ggml/src/kompute-shaders/op_mul_mv_q_n.comp +0 -48
  41. ggml/src/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -22
  42. ggml/src/kompute-shaders/op_norm.comp +0 -84
  43. ggml/src/kompute-shaders/op_relu.comp +0 -21
  44. ggml/src/kompute-shaders/op_rmsnorm.comp +0 -53
  45. ggml/src/kompute-shaders/op_rope_f16.comp +0 -73
  46. ggml/src/kompute-shaders/op_rope_f32.comp +0 -73
  47. ggml/src/kompute-shaders/op_scale.comp +0 -19
  48. ggml/src/kompute-shaders/op_scale_8.comp +0 -23
  49. ggml/src/kompute-shaders/op_silu.comp +0 -22
  50. ggml/src/kompute-shaders/op_softmax.comp +0 -56
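Most of this sync removes or relocates backend implementations (AMX, BLAS, CANN, Kompute, RPC, and the CPU code that was embedded in ggml-backend.cpp) together with the Kompute shader sources, while adding dedicated per-backend build files such as ggml-cuda/CMakeLists.txt and ggml-musa/CMakeLists.txt. The public surface that survives the move is the backend registry API visible in the ggml-backend.cpp hunks below. As orientation only, here is a minimal consumer-side sketch of that API; it is not part of this commit, uses only function names that appear in the diff, and assumes the usual ggml-backend.h header and linkage against ggml:

// Illustrative sketch (not repository code): enumerate the registered
// devices and initialize a backend through the registry API.
#include "ggml-backend.h"
#include <cstdio>

int main() {
    // list every device known to the registry
    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        printf("device %zu: %s (%s)\n", i,
               ggml_backend_dev_name(dev),
               ggml_backend_dev_description(dev));
    }

    // prefer a GPU device, fall back to the CPU device
    ggml_backend_t backend = ggml_backend_init_best();
    if (backend == NULL) {
        fprintf(stderr, "no usable backend found\n");
        return 1;
    }
    // ... build and compute a graph with the backend ...
    return 0;
}

As the removed registry code shows, ggml_backend_init_best() simply looks for a GPU-type device first and falls back to the CPU-type device.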
ggml/CMakeLists.txt CHANGED
@@ -225,6 +225,7 @@ set(GGML_PUBLIC_HEADERS
 include/ggml-cann.h
 include/ggml-cuda.h
 include/ggml-kompute.h
+ include/ggml-opt.h
 include/ggml-metal.h
 include/ggml-rpc.h
 include/ggml-sycl.h
@@ -241,7 +242,7 @@ install(TARGETS ggml-base LIBRARY)
 if (GGML_METAL)
 # FIXME: does this need to be installed with GGML_METAL_EMBED_LIBRARY?
 install(
- FILES ggml/src/ggml-metal/ggml-metal.metal
+ FILES src/ggml-metal/ggml-metal.metal
 PERMISSIONS
 OWNER_READ
 OWNER_WRITE
ggml/src/ggml-amx.cpp DELETED
@@ -1,436 +0,0 @@
- #include "ggml-amx.h"
- #include "ggml-amx/common.h"
- #include "ggml-amx/mmq.h"
- #include "ggml-backend-impl.h"
- #include "ggml-impl.h"
-
- #if defined(__gnu_linux__)
- #include <sys/syscall.h>
- #include <unistd.h>
- #endif
-
- #include <cstdlib>
- #include <cstring>
- #include <memory>
-
- #if defined(__AMX_INT8__)
-
- // AMX buffer interface
- static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- free(buffer->context);
- }
-
- static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
- return (void *)(buffer->context);
- }
-
- static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
- memset((char *)tensor->data + offset, value, size);
-
- GGML_UNUSED(buffer);
- }
-
- static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- if (qtype_has_amx_kernels(tensor->type)) {
- ggml_backend_amx_convert_weight(tensor, data, offset, size);
- } else {
- memcpy((char *)tensor->data + offset, data, size);
- }
-
- GGML_UNUSED(buffer);
- }
-
- static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
- memcpy(data, (const char *)tensor->data + offset, size);
-
- GGML_UNUSED(buffer);
- }
-
- static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
- if (ggml_backend_buffer_is_host(src->buffer)) {
- if (qtype_has_amx_kernels(src->type)) {
- ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_backend_amx_get_alloc_size(dst));
- } else {
- memcpy(dst->data, src->data, ggml_nbytes(src));
- }
- return true;
- }
- return false;
-
- GGML_UNUSED(buffer);
- }
-
- static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
- memset(buffer->context, value, buffer->size);
- }
-
- static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
- /* .free_buffer = */ ggml_backend_amx_buffer_free_buffer,
- /* .get_base = */ ggml_backend_amx_buffer_get_base,
- /* .init_tensor = */ NULL, // no initialization required
- /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
- /* .set_tensor = */ ggml_backend_amx_buffer_set_tensor,
- /* .get_tensor = */ ggml_backend_amx_buffer_get_tensor,
- /* .cpy_tensor = */ ggml_backend_amx_buffer_cpy_tensor,
- /* .clear = */ ggml_backend_amx_buffer_clear,
- /* .reset = */ NULL,
- };
-
- static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
- return "AMX";
-
- GGML_UNUSED(buft);
- }
-
- static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
- void * data = aligned_alloc(TENSOR_ALIGNMENT, size);
- if (data == NULL) {
- fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
- return NULL;
- }
-
- return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
- }
-
- static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
- return TENSOR_ALIGNMENT;
-
- GGML_UNUSED(buft);
- }
-
- static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
- return ggml_backend_amx_get_alloc_size(tensor);
-
- GGML_UNUSED(buft);
- }
-
- static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
- return false;
-
- GGML_UNUSED(buft);
- }
-
- ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
- static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
- /* .iface = */ {
- /* .get_name = */ ggml_backend_amx_buffer_type_get_name,
- /* .alloc_buffer = */ ggml_backend_amx_buffer_type_alloc_buffer,
- /* .get_alignment = */ ggml_backend_amx_buffer_type_get_alignment,
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
- /* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size,
- /* .is_host = */ ggml_backend_amx_buffer_type_is_host,
- },
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
- /* .context = */ NULL,
- };
-
- return &ggml_backend_buffer_type_amx;
- }
-
- // backend interface
-
- static const char * ggml_backend_amx_name(ggml_backend_t backend) {
- return "AMX";
-
- GGML_UNUSED(backend);
- }
-
- static void ggml_backend_amx_free(ggml_backend_t backend) {
- ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
- delete ctx;
- delete backend;
- }
-
- static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
- ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
-
- for (int i = 0; i < cgraph->n_nodes; i++) {
- struct ggml_tensor * node = cgraph->nodes[i];
-
- switch (node->op) {
- case GGML_OP_MUL_MAT:
- ggml_backend_amx_mul_mat(ctx, node);
- break;
-
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- break;
-
- default:
- fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
- GGML_ASSERT(false);
- }
- }
-
- return GGML_STATUS_SUCCESS;
-
- GGML_UNUSED(backend);
- }
-
- static struct ggml_backend_i ggml_backend_amx_i = {
- /* .get_name = */ ggml_backend_amx_name,
- /* .free = */ ggml_backend_amx_free,
- /* .set_tensor_async = */ NULL,
- /* .get_tensor_async = */ NULL,
- /* .cpy_tensor_async = */ NULL,
- /* .synchronize = */ NULL,
- /* .graph_plan_create = */ NULL,
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_update = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ ggml_backend_amx_graph_compute,
- /* .event_record = */ NULL,
- /* .event_wait = */ NULL,
- };
-
- static ggml_guid_t ggml_backend_amx_guid() {
- static ggml_guid guid = { 0x13, 0xb8, 0xa4, 0xc4, 0xba, 0xfe, 0x51, 0x67, 0x87, 0x44, 0x55, 0x15, 0xb2, 0x35, 0x62, 0x3e };
- return &guid;
- }
-
- #define ARCH_GET_XCOMP_PERM 0x1022
- #define ARCH_REQ_XCOMP_PERM 0x1023
- #define XFEATURE_XTILECFG 17
- #define XFEATURE_XTILEDATA 18
-
- static bool ggml_amx_init() {
- #if defined(__gnu_linux__)
- if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
- fprintf(stderr, "AMX is not ready to be used!\n");
- return false;
- }
- return true;
- #elif defined(_WIN32)
- return true;
- #endif
- }
-
- ggml_backend_t ggml_backend_amx_init() {
-
- // invoke a Linux system call to request access to AMX features
- ggml_amx_init();
-
- // backend context
- ggml_backend_amx_context * ctx = new ggml_backend_amx_context;
-
- // ggml amx backend
- ggml_backend_t backend = new ggml_backend {
- /* .guid = */ ggml_backend_amx_guid(),
- /* .interface = */ ggml_backend_amx_i,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
- /* .context = */ ctx,
- };
-
- return backend;
- }
-
- bool ggml_backend_is_amx(ggml_backend_t backend) {
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_amx_guid());
- }
-
- void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
- GGML_ASSERT(ggml_backend_is_amx(backend_amx));
-
- ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend_amx->context;
- ctx->n_threads = n_threads;
- }
-
- // device interface
-
- static const char * ggml_backend_amx_device_get_name(ggml_backend_dev_t dev) {
- return "AMX";
-
- GGML_UNUSED(dev);
- }
-
- static const char * ggml_backend_amx_device_get_description(ggml_backend_dev_t dev) {
- return "Intel Advanced Matrix Extensions";
-
- GGML_UNUSED(dev);
- }
-
- static void ggml_backend_amx_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
- // TODO
- *free = 0;
- *total = 0;
-
- GGML_UNUSED(dev);
- }
-
- static enum ggml_backend_dev_type ggml_backend_amx_device_get_type(ggml_backend_dev_t dev) {
- return GGML_BACKEND_DEVICE_TYPE_ACCEL;
-
- GGML_UNUSED(dev);
- }
-
- static void ggml_backend_amx_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
- props->name = ggml_backend_amx_device_get_name(dev);
- props->description = ggml_backend_amx_device_get_description(dev);
- props->type = ggml_backend_amx_device_get_type(dev);
- ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total);
-
- // `buffer_from_host_ptr` is intended to be used in mmap, when memory layout unchanged
- props->caps = {
- /* .async = */ false,
- /* .host_buffer = */ false,
- /* .buffer_from_host_ptr = */ false,
- /* .events = */ false,
- };
- }
-
- static ggml_backend_t ggml_backend_amx_device_init(ggml_backend_dev_t dev, const char * params) {
- return ggml_backend_amx_init();
-
- GGML_UNUSED(dev);
- GGML_UNUSED(params);
- }
-
- static ggml_backend_buffer_type_t ggml_backend_amx_device_get_buffer_type(ggml_backend_dev_t dev) {
- return ggml_backend_amx_buffer_type();
-
- GGML_UNUSED(dev);
- }
-
- static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-
- // handle only 2d gemm for now
- auto is_contiguous_2d = [](const struct ggml_tensor * t) {
- return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
- };
-
- switch (op->op) {
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- return true;
-
- case GGML_OP_MUL_MAT: {
- const struct ggml_tensor * src0 = op->src[0];
- const struct ggml_tensor * src1 = op->src[1];
-
- const enum ggml_type type = src0->type;
- const int64_t ne0 = op->ne[0];
-
- bool is_training = src0->grad || src1->grad;
-
- // amx kernels enables for Q4_0, Q4_1, Q8_0, F16
- // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
- bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
-
- bool can_use_amx =
- is_contiguous_2d(src0) && // src0 must be contiguous
- is_contiguous_2d(src1) && // src1 must be contiguous
- !is_training && // inference only
- src1->type == GGML_TYPE_F32 && // src1 must be float32
- has_amx_kernels && // with amx kernel impls
- ne0 % (TILE_N * 2) == 0; // out_features is 32x
-
- return can_use_amx;
- }
- default:
- return false;
- }
-
- GGML_UNUSED(dev);
- }
-
- static bool ggml_backend_amx_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
- return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;
-
- GGML_UNUSED(dev);
- }
-
- static const struct ggml_backend_device_i ggml_backend_amx_device_i = {
- /* .get_name = */ ggml_backend_amx_device_get_name,
- /* .get_description = */ ggml_backend_amx_device_get_description,
- /* .get_memory = */ ggml_backend_amx_device_get_memory,
- /* .get_type = */ ggml_backend_amx_device_get_type,
- /* .get_props = */ ggml_backend_amx_device_get_props,
- /* .init_backend = */ ggml_backend_amx_device_init,
- /* .get_buffer_type = */ ggml_backend_amx_device_get_buffer_type,
- /* .get_host_buffer_type = */ NULL,
- /* .buffer_from_host_ptr = */ NULL,
- /* .supports_op = */ ggml_backend_amx_device_supports_op,
- /* .supports_buft = */ ggml_backend_amx_device_supports_buft,
- /* .offload_op = */ NULL,
- /* .event_new = */ NULL,
- /* .event_free = */ NULL,
- /* .event_synchronize = */ NULL,
- };
-
- // backend reg interface
-
- static const char * ggml_backend_amx_reg_get_name(ggml_backend_reg_t reg) {
- return "AMX";
-
- GGML_UNUSED(reg);
- }
-
- static size_t ggml_backend_amx_reg_get_device_count(ggml_backend_reg_t reg) {
- return 1;
-
- GGML_UNUSED(reg);
- }
-
- static ggml_backend_dev_t ggml_backend_amx_reg_get_device(ggml_backend_reg_t reg, size_t index) {
- GGML_ASSERT(index == 0);
-
- static ggml_backend_device ggml_backend_amx_device = {
- /* .iface = */ ggml_backend_amx_device_i,
- /* .reg = */ reg,
- /* .context = */ nullptr,
- };
-
- return &ggml_backend_amx_device;
-
- GGML_UNUSED(reg);
- GGML_UNUSED(index);
- }
-
- static void * ggml_backend_amx_get_proc_address(ggml_backend_reg_t reg, const char * name) {
- if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
- return (void *)ggml_backend_amx_set_n_threads;
- }
- return NULL;
-
- GGML_UNUSED(reg);
- GGML_UNUSED(name);
- }
-
- static const struct ggml_backend_reg_i ggml_backend_amx_reg_i = {
- /* .get_name = */ ggml_backend_amx_reg_get_name,
- /* .get_device_count = */ ggml_backend_amx_reg_get_device_count,
- /* .get_device = */ ggml_backend_amx_reg_get_device,
- /* .get_proc_address = */ ggml_backend_amx_get_proc_address,
- };
-
- ggml_backend_reg_t ggml_backend_amx_reg(void) {
- static struct ggml_backend_reg ggml_backend_amx_reg = {
- /* .iface = */ ggml_backend_amx_reg_i,
- /* .context = */ NULL,
- };
-
- return &ggml_backend_amx_reg;
- }
-
- #else // if defined(__AMX_INT8__)
-
- ggml_backend_t ggml_backend_amx_init(void) {
- fprintf(stderr, "GGML is not compiled with AMX support!\n");
- return ggml_backend_t{};
- }
-
- void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
- fprintf(stderr, "GGML is not compiled with AMX support!\n");
-
- GGML_UNUSED(backend_amx);
- GGML_UNUSED(n_threads);
- }
-
- #endif
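The deleted backend is compiled only when __AMX_INT8__ is defined and, on Linux, must first ask the kernel for permission to use the AMX tile-data state via arch_prctl (the ARCH_REQ_XCOMP_PERM request above). For readers unfamiliar with that mechanism, a standalone sketch of the same permission check follows; it mirrors ggml_amx_init() from the deleted file and is illustrative, not repository code:

// Illustrative sketch: request AMX tile-data permission from the Linux kernel,
// mirroring ggml_amx_init() from the deleted file. Linux-only.
#include <cstdio>
#include <sys/syscall.h>
#include <unistd.h>

#define ARCH_REQ_XCOMP_PERM 0x1023  // constants taken from the deleted file
#define XFEATURE_XTILEDATA  18

int main() {
    if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) != 0) {
        fprintf(stderr, "AMX is not ready to be used!\n");
        return 1;
    }
    printf("AMX tile-data permission granted\n");
    return 0;
}

If the syscall fails, the process has not been granted the XTILEDATA feature and the AMX kernels cannot be used.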
ggml/src/ggml-backend.cpp CHANGED
@@ -525,197 +525,6 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
525
  return reg->iface.get_proc_address(reg, name);
526
  }
527
 
528
- // Backend registry
529
-
530
- #ifdef GGML_USE_CUDA
531
- #include "ggml-cuda.h"
532
- #endif
533
-
534
- #ifdef GGML_USE_METAL
535
- #include "ggml-metal.h"
536
- #endif
537
-
538
- #ifdef GGML_USE_SYCL
539
- #include "ggml-sycl.h"
540
- #endif
541
-
542
- #ifdef GGML_USE_VULKAN
543
- #include "ggml-vulkan.h"
544
- #endif
545
-
546
- #ifdef GGML_USE_BLAS
547
- #include "ggml-blas.h"
548
- #endif
549
-
550
- #ifdef GGML_USE_RPC
551
- #include "ggml-rpc.h"
552
- #endif
553
-
554
- #ifndef __AMX_INT8__
555
- #undef GGML_USE_AMX
556
- #endif
557
-
558
- #ifdef GGML_USE_AMX
559
- # include "ggml-amx.h"
560
- #endif
561
-
562
- #ifdef GGML_USE_CANN
563
- #include "ggml-cann.h"
564
- #endif
565
-
566
- #ifdef GGML_USE_KOMPUTE
567
- #include "ggml-kompute.h"
568
- #endif
569
-
570
- #include "ggml-cpu.h"
571
-
572
- struct ggml_backend_registry {
573
- std::vector<ggml_backend_reg_t> backends;
574
- std::vector<ggml_backend_dev_t> devices;
575
-
576
- ggml_backend_registry() {
577
- #ifdef GGML_USE_CUDA
578
- register_backend(ggml_backend_cuda_reg());
579
- #endif
580
- #ifdef GGML_USE_METAL
581
- register_backend(ggml_backend_metal_reg());
582
- #endif
583
- #ifdef GGML_USE_SYCL
584
- register_backend(ggml_backend_sycl_reg());
585
- #endif
586
- #ifdef GGML_USE_VULKAN
587
- register_backend(ggml_backend_vk_reg());
588
- #endif
589
- #ifdef GGML_USE_CANN
590
- register_backend(ggml_backend_cann_reg());
591
- #endif
592
- #ifdef GGML_USE_BLAS
593
- register_backend(ggml_backend_blas_reg());
594
- #endif
595
- #ifdef GGML_USE_RPC
596
- register_backend(ggml_backend_rpc_reg());
597
- #endif
598
- #ifdef GGML_USE_AMX
599
- register_backend(ggml_backend_amx_reg());
600
- #endif
601
- #ifdef GGML_USE_KOMPUTE
602
- register_backend(ggml_backend_kompute_reg());
603
- #endif
604
-
605
- register_backend(ggml_backend_cpu_reg());
606
- }
607
-
608
- void register_backend(ggml_backend_reg_t reg) {
609
- #ifndef NDEBUG
610
- GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n",
611
- __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg));
612
- #endif
613
- backends.push_back(reg);
614
- for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
615
- register_device(ggml_backend_reg_dev_get(reg, i));
616
- }
617
- }
618
-
619
- void register_device(ggml_backend_dev_t device) {
620
- #ifndef NDEBUG
621
- GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
622
- #endif
623
- devices.push_back(device);
624
- }
625
- };
626
-
627
- static ggml_backend_registry & get_reg() {
628
- static ggml_backend_registry reg;
629
- return reg;
630
- }
631
-
632
- // Internal API
633
- void ggml_backend_register(ggml_backend_reg_t reg) {
634
- get_reg().register_backend(reg);
635
- }
636
-
637
- void ggml_backend_device_register(ggml_backend_dev_t device) {
638
- get_reg().register_device(device);
639
- }
640
-
641
- // Backend (reg) enumeration
642
- size_t ggml_backend_reg_count() {
643
- return get_reg().backends.size();
644
- }
645
-
646
- ggml_backend_reg_t ggml_backend_reg_get(size_t index) {
647
- GGML_ASSERT(index < ggml_backend_reg_count());
648
- return get_reg().backends[index];
649
- }
650
-
651
- ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) {
652
- for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
653
- ggml_backend_reg_t reg = ggml_backend_reg_get(i);
654
- if (strcmp(ggml_backend_reg_name(reg), name) == 0) {
655
- return reg;
656
- }
657
- }
658
- return NULL;
659
- }
660
-
661
- // Device enumeration
662
- size_t ggml_backend_dev_count() {
663
- return get_reg().devices.size();
664
- }
665
-
666
- ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
667
- GGML_ASSERT(index < ggml_backend_dev_count());
668
- return get_reg().devices[index];
669
- }
670
-
671
- ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {
672
- for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
673
- ggml_backend_dev_t dev = ggml_backend_dev_get(i);
674
- if (strcmp(ggml_backend_dev_name(dev), name) == 0) {
675
- return dev;
676
- }
677
- }
678
- return NULL;
679
- }
680
-
681
- ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type) {
682
- for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
683
- ggml_backend_dev_t dev = ggml_backend_dev_get(i);
684
- if (ggml_backend_dev_type(dev) == type) {
685
- return dev;
686
- }
687
- }
688
- return NULL;
689
- }
690
-
691
- // Convenience functions
692
- ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params) {
693
- ggml_backend_dev_t dev = ggml_backend_dev_by_name(name);
694
- if (!dev) {
695
- return NULL;
696
- }
697
- return ggml_backend_dev_init(dev, params);
698
- }
699
-
700
- ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params) {
701
- ggml_backend_dev_t dev = ggml_backend_dev_by_type(type);
702
- if (!dev) {
703
- return NULL;
704
- }
705
- return ggml_backend_dev_init(dev, params);
706
- }
707
-
708
- ggml_backend_t ggml_backend_init_best(void) {
709
- ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
710
- if (!dev) {
711
- dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
712
- }
713
- if (!dev) {
714
- return NULL;
715
- }
716
- return ggml_backend_dev_init(dev, NULL);
717
- }
718
-
719
  // multi-buffer buffer
720
 
721
  struct ggml_backend_multi_buffer_context {
@@ -1641,7 +1450,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
1641
  bool parallel) {
1642
  GGML_ASSERT(n_backends > 0);
1643
  GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
1644
- GGML_ASSERT(ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU
1645
 
1646
  struct ggml_backend_sched * sched = (ggml_backend_sched *) calloc(1, sizeof(struct ggml_backend_sched));
1647
 
@@ -2038,17 +1847,6 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
2038
  return true;
2039
  }
2040
 
2041
-
2042
-
2043
- #include "ggml-backend.h"
2044
- #include "ggml-backend-impl.h"
2045
- #include "ggml-cpu.h"
2046
- #include "ggml-impl.h"
2047
- #include <cctype>
2048
- #include <string>
2049
-
2050
- // ggml-backend interface
2051
-
2052
  // CPU backend - buffer
2053
 
2054
  static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
@@ -2122,7 +1920,9 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = {
2122
  /* .reset = */ NULL,
2123
  };
2124
 
2125
- // CPU backend - buffer type
 
 
2126
 
2127
  static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
2128
  return "CPU";
@@ -2163,7 +1963,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
2163
  /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
2164
  /* .is_host = */ ggml_backend_cpu_buffer_type_is_host,
2165
  },
2166
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
2167
  /* .context = */ NULL,
2168
  };
2169
 
@@ -2186,479 +1986,14 @@ static ggml_backend_buffer_type_t ggml_backend_cpu_buffer_from_ptr_type(void) {
2186
  /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
2187
  /* .is_host = */ ggml_backend_cpu_buffer_type_is_host,
2188
  },
2189
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
2190
  /* .context = */ NULL,
2191
  };
2192
 
2193
  return &ggml_backend_cpu_buffer_type;
2194
  }
2195
 
2196
- #ifdef GGML_USE_CPU_HBM
2197
-
2198
- // buffer type HBM
2199
-
2200
- #include <hbwmalloc.h>
2201
-
2202
- static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
2203
- return "CPU_HBM";
2204
-
2205
- GGML_UNUSED(buft);
2206
- }
2207
-
2208
- static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
2209
- hbw_free(buffer->context);
2210
- }
2211
-
2212
- static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
2213
- void * ptr;
2214
- int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);
2215
- if (result != 0) {
2216
- GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size);
2217
- return NULL;
2218
- }
2219
-
2220
- ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
2221
- buffer->buft = buft;
2222
- buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;
2223
-
2224
- return buffer;
2225
- }
2226
-
2227
- ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
2228
- static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {
2229
- /* .iface = */ {
2230
- /* .get_name = */ ggml_backend_cpu_hbm_buffer_type_get_name,
2231
- /* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
2232
- /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment,
2233
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
2234
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
2235
- /* .is_host = */ ggml_backend_cpu_buffer_type_is_host,
2236
- },
2237
- /* .context = */ NULL,
2238
- };
2239
-
2240
- return &ggml_backend_cpu_buffer_type_hbm;
2241
- }
2242
- #endif
2243
-
2244
- static ggml_backend_buffer_type_t * ggml_backend_cpu_get_extra_bufts(ggml_backend_dev_t device) {
2245
- static ggml_backend_buffer_type_t bufts[] = {
2246
- #ifdef GGML_USE_CPU_HBM
2247
- ggml_backend_cpu_hbm_buffer_type(),
2248
- #endif
2249
- NULL
2250
- };
2251
-
2252
- return bufts;
2253
-
2254
- GGML_UNUSED(device);
2255
- }
2256
-
2257
- // CPU backend - backend (stream)
2258
-
2259
- struct ggml_backend_cpu_context {
2260
- int n_threads;
2261
- ggml_threadpool_t threadpool;
2262
-
2263
- uint8_t * work_data;
2264
- size_t work_size;
2265
-
2266
- ggml_abort_callback abort_callback;
2267
- void * abort_callback_data;
2268
- };
2269
-
2270
- static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {
2271
- return "CPU";
2272
-
2273
- GGML_UNUSED(backend);
2274
- }
2275
-
2276
- static void ggml_backend_cpu_free(ggml_backend_t backend) {
2277
- struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
2278
- delete[] cpu_ctx->work_data;
2279
- delete cpu_ctx;
2280
- delete backend;
2281
- }
2282
-
2283
- struct ggml_backend_plan_cpu {
2284
- struct ggml_cplan cplan;
2285
- struct ggml_cgraph cgraph;
2286
- };
2287
-
2288
- static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
2289
- struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
2290
-
2291
- struct ggml_backend_plan_cpu * cpu_plan = new ggml_backend_plan_cpu;
2292
-
2293
- cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
2294
- cpu_plan->cgraph = *cgraph; // FIXME: deep copy
2295
-
2296
- if (cpu_plan->cplan.work_size > 0) {
2297
- cpu_plan->cplan.work_data = new uint8_t[cpu_plan->cplan.work_size];
2298
- if (cpu_plan->cplan.work_data == NULL) {
2299
- delete cpu_plan;
2300
- return NULL;
2301
- }
2302
- }
2303
-
2304
- cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback;
2305
- cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
2306
-
2307
- return cpu_plan;
2308
- }
2309
-
2310
- static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
2311
- struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
2312
-
2313
- delete[] cpu_plan->cplan.work_data;
2314
- delete cpu_plan;
2315
-
2316
- GGML_UNUSED(backend);
2317
- }
2318
-
2319
- static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
2320
- struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
2321
-
2322
- return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
2323
-
2324
- GGML_UNUSED(backend);
2325
- }
2326
-
2327
- static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
2328
- struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
2329
-
2330
- struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
2331
-
2332
- if (cpu_ctx->work_size < cplan.work_size) {
2333
- delete[] cpu_ctx->work_data;
2334
- cpu_ctx->work_data = new uint8_t[cplan.work_size];
2335
- if (cpu_ctx->work_data == NULL) {
2336
- cpu_ctx->work_size = 0;
2337
- return GGML_STATUS_ALLOC_FAILED;
2338
- }
2339
- cpu_ctx->work_size = cplan.work_size;
2340
- }
2341
- cplan.work_data = (uint8_t *)cpu_ctx->work_data;
2342
-
2343
- cplan.abort_callback = cpu_ctx->abort_callback;
2344
- cplan.abort_callback_data = cpu_ctx->abort_callback_data;
2345
-
2346
- return ggml_graph_compute(cgraph, &cplan);
2347
- }
2348
-
2349
- static const struct ggml_backend_i ggml_backend_cpu_i = {
2350
- /* .get_name = */ ggml_backend_cpu_get_name,
2351
- /* .free = */ ggml_backend_cpu_free,
2352
- /* .set_tensor_async = */ NULL,
2353
- /* .get_tensor_async = */ NULL,
2354
- /* .cpy_tensor_async = */ NULL,
2355
- /* .synchronize = */ NULL,
2356
- /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create,
2357
- /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free,
2358
- /* .graph_plan_update = */ NULL,
2359
- /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute,
2360
- /* .graph_compute = */ ggml_backend_cpu_graph_compute,
2361
- /* .event_record = */ NULL,
2362
- /* .event_wait = */ NULL,
2363
- };
2364
-
2365
- static ggml_guid_t ggml_backend_cpu_guid(void) {
2366
- static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 };
2367
- return &guid;
2368
- }
2369
-
2370
- ggml_backend_t ggml_backend_cpu_init(void) {
2371
- // initialize CPU backend now to avoid slowing the first graph computation
2372
- ggml_cpu_init();
2373
-
2374
- struct ggml_backend_cpu_context * ctx = new ggml_backend_cpu_context;
2375
- if (ctx == NULL) {
2376
- return NULL;
2377
- }
2378
-
2379
- ctx->n_threads = GGML_DEFAULT_N_THREADS;
2380
- ctx->threadpool = NULL;
2381
- ctx->work_data = NULL;
2382
- ctx->work_size = 0;
2383
- ctx->abort_callback = NULL;
2384
- ctx->abort_callback_data = NULL;
2385
-
2386
- ggml_backend_t cpu_backend = new ggml_backend {
2387
- /* .guid = */ ggml_backend_cpu_guid(),
2388
- /* .interface = */ ggml_backend_cpu_i,
2389
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
2390
- /* .context = */ ctx,
2391
- };
2392
-
2393
- if (cpu_backend == NULL) {
2394
- delete ctx;
2395
- return NULL;
2396
- }
2397
-
2398
- return cpu_backend;
2399
- }
2400
-
2401
- bool ggml_backend_is_cpu(ggml_backend_t backend) {
2402
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid());
2403
- }
2404
-
2405
- void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
2406
- GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
2407
-
2408
- struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
2409
- ctx->n_threads = n_threads;
2410
- }
2411
-
2412
- void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) {
2413
- GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
2414
-
2415
- struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
2416
-
2417
- if (ctx->threadpool && ctx->threadpool != threadpool) {
2418
- // already had a different threadpool, pause/suspend it before switching
2419
- ggml_threadpool_pause(ctx->threadpool);
2420
- }
2421
- ctx->threadpool = threadpool;
2422
- }
2423
-
2424
- void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
2425
- GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
2426
-
2427
- struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
2428
- ctx->abort_callback = abort_callback;
2429
- ctx->abort_callback_data = abort_callback_data;
2430
- }
2431
-
2432
  ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
2433
  GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned");
2434
  return ggml_backend_buffer_init(ggml_backend_cpu_buffer_from_ptr_type(), ggml_backend_cpu_buffer_from_ptr_i, ptr, size);
2435
  }
2436
-
2437
- // CPU backend - device
2438
-
2439
- struct ggml_backend_cpu_device_context {
2440
- std::string description = "CPU";
2441
-
2442
- ggml_backend_cpu_device_context() {
2443
- #ifdef __APPLE__
2444
- size_t len = 0;
2445
- if (!sysctlbyname("machdep.cpu.brand_string", NULL, &len, NULL, 0)) {
2446
- description.resize(len);
2447
- sysctlbyname("machdep.cpu.brand_string", &description[0], &len, NULL, 0); // NOLINT
2448
- }
2449
- #elif defined(__linux__)
2450
- FILE * f = fopen("/proc/cpuinfo", "r");
2451
- if (f) {
2452
- char buf[1024];
2453
- while (fgets(buf, sizeof(buf), f)) {
2454
- if (strncmp(buf, "model name", 10) == 0) {
2455
- char * p = strchr(buf, ':');
2456
- if (p) {
2457
- p++;
2458
- while (std::isspace(*p)) {
2459
- p++;
2460
- }
2461
- while (std::isspace(p[strlen(p) - 1])) {
2462
- p[strlen(p) - 1] = '\0';
2463
- }
2464
- description = p;
2465
- break;
2466
- }
2467
- }
2468
- }
2469
- fclose(f);
2470
- }
2471
- #elif defined(_WIN32)
2472
- HKEY hKey;
2473
- if (RegOpenKeyEx(HKEY_LOCAL_MACHINE,
2474
- TEXT("HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"),
2475
- 0,
2476
- KEY_READ,
2477
- &hKey) == ERROR_SUCCESS) {
2478
- DWORD cpu_brand_size = 0;
2479
- if (RegQueryValueExA(hKey,
2480
- TEXT("ProcessorNameString"),
2481
- NULL,
2482
- NULL,
2483
- NULL,
2484
- &cpu_brand_size) == ERROR_SUCCESS) {
2485
- description.resize(cpu_brand_size);
2486
- if (RegQueryValueExA(hKey,
2487
- TEXT("ProcessorNameString"),
2488
- NULL,
2489
- NULL,
2490
- (LPBYTE)&description[0], // NOLINT
2491
- &cpu_brand_size) == ERROR_SUCCESS) {
2492
- if (description.find('\0') != std::string::npos) {
2493
- description.resize(description.find('\0'));
2494
- }
2495
- }
2496
- }
2497
- RegCloseKey(hKey);
2498
- }
2499
- #endif
2500
- }
2501
- };
2502
-
2503
- static const char * ggml_backend_cpu_device_get_name(ggml_backend_dev_t dev) {
2504
- return "CPU";
2505
-
2506
- GGML_UNUSED(dev);
2507
- }
2508
-
2509
- static const char * ggml_backend_cpu_device_get_description(ggml_backend_dev_t dev) {
2510
- struct ggml_backend_cpu_device_context * ctx = (struct ggml_backend_cpu_device_context *)dev->context;
2511
-
2512
- return ctx->description.c_str();
2513
- }
2514
-
2515
- static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
2516
- // TODO
2517
- *free = 0;
2518
- *total = 0;
2519
-
2520
- GGML_UNUSED(dev);
2521
- }
2522
-
2523
- static enum ggml_backend_dev_type ggml_backend_cpu_device_get_type(ggml_backend_dev_t dev) {
2524
- return GGML_BACKEND_DEVICE_TYPE_CPU;
2525
-
2526
- GGML_UNUSED(dev);
2527
- }
2528
-
2529
- static void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
2530
- props->name = ggml_backend_cpu_device_get_name(dev);
2531
- props->description = ggml_backend_cpu_device_get_description(dev);
2532
- props->type = ggml_backend_cpu_device_get_type(dev);
2533
- ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
2534
- props->caps = {
2535
- /* .async = */ false,
2536
- /* .host_buffer = */ false,
2537
- /* .buffer_from_host_ptr = */ true,
2538
- /* .events = */ false,
2539
- };
2540
- }
2541
-
2542
- static ggml_backend_t ggml_backend_cpu_device_init_backend(ggml_backend_dev_t dev, const char * params) {
2543
- return ggml_backend_cpu_init();
2544
-
2545
- GGML_UNUSED(dev);
2546
- GGML_UNUSED(params);
2547
- }
2548
-
2549
- static ggml_backend_buffer_type_t ggml_backend_cpu_device_get_buffer_type(ggml_backend_dev_t dev) {
2550
- return ggml_backend_cpu_buffer_type();
2551
-
2552
- GGML_UNUSED(dev);
2553
- }
2554
-
2555
- static ggml_backend_buffer_t ggml_backend_cpu_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
2556
- return ggml_backend_cpu_buffer_from_ptr(ptr, size);
2557
-
2558
- GGML_UNUSED(dev);
2559
- GGML_UNUSED(max_tensor_size);
2560
- }
2561
-
2562
- static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
2563
- switch (op->op) {
2564
- case GGML_OP_CPY:
2565
- return
2566
- op->type != GGML_TYPE_IQ2_XXS &&
2567
- op->type != GGML_TYPE_IQ2_XS &&
2568
- op->type != GGML_TYPE_IQ1_S &&
2569
- op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
2570
- case GGML_OP_MUL_MAT:
2571
- //return op->src[1]->type == GGML_TYPE_F32; // TMP: workaround until sync with latest ggml
2572
- return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_get_type_traits_cpu(op->src[0]->type)->vec_dot_type;
2573
- case GGML_OP_ROPE_BACK:
2574
- return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
2575
- case GGML_OP_IM2COL_BACK:
2576
- return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
2577
- case GGML_OP_OUT_PROD:
2578
- return (op->src[0]->type == GGML_TYPE_F32 || ggml_is_quantized(op->src[0]->type)) && op->src[1]->type == GGML_TYPE_F32;
2579
- default:
2580
- return true;
2581
- }
2582
-
2583
- GGML_UNUSED(dev);
2584
- }
2585
-
2586
- static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
2587
- return ggml_backend_buft_is_host(buft);
2588
-
2589
- GGML_UNUSED(dev);
2590
- }
2591
-
2592
- static const struct ggml_backend_device_i ggml_backend_cpu_device_i = {
2593
- /* .get_name = */ ggml_backend_cpu_device_get_name,
2594
- /* .get_description = */ ggml_backend_cpu_device_get_description,
2595
- /* .get_memory = */ ggml_backend_cpu_device_get_memory,
2596
- /* .get_type = */ ggml_backend_cpu_device_get_type,
2597
- /* .get_props = */ ggml_backend_cpu_device_get_props,
2598
- /* .init_backend = */ ggml_backend_cpu_device_init_backend,
2599
- /* .get_buffer_type = */ ggml_backend_cpu_device_get_buffer_type,
2600
- /* .get_host_buffer_type = */ NULL,
2601
- /* .buffer_from_host_ptr = */ ggml_backend_cpu_device_buffer_from_host_ptr,
2602
- /* .supports_op = */ ggml_backend_cpu_device_supports_op,
2603
- /* .supports_buft = */ ggml_backend_cpu_device_supports_buft,
2604
- /* .offload_op = */ NULL,
2605
- /* .event_new = */ NULL,
2606
- /* .event_free = */ NULL,
2607
- /* .event_synchronize = */ NULL,
2608
- };
2609
-
2610
- // CPU backend - backend (reg)
2611
-
2612
- static const char * ggml_backend_cpu_reg_get_name(ggml_backend_reg_t reg) {
2613
- return "CPU";
2614
-
2615
- GGML_UNUSED(reg);
2616
- }
2617
-
2618
- static size_t ggml_backend_cpu_reg_get_device_count(ggml_backend_reg_t reg) {
2619
- return 1;
2620
-
2621
- GGML_UNUSED(reg);
2622
- }
2623
-
2624
- static ggml_backend_dev_t ggml_backend_cpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
2625
- GGML_ASSERT(index == 0);
2626
-
2627
- static ggml_backend_cpu_device_context ctx;
2628
- static ggml_backend_device ggml_backend_cpu_device = {
2629
- /* .iface = */ ggml_backend_cpu_device_i,
2630
- /* .reg = */ reg,
2631
- /* .context = */ &ctx,
2632
- };
2633
-
2634
- return &ggml_backend_cpu_device;
2635
- }
2636
-
2637
- static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) {
2638
- if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
2639
- return (void *)ggml_backend_cpu_set_n_threads;
2640
- }
2641
- if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) {
2642
- return (void *)ggml_backend_cpu_get_extra_bufts;
2643
- }
2644
-
2645
- return NULL;
2646
-
2647
- GGML_UNUSED(reg);
2648
- }
2649
-
2650
- static const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = {
2651
- /* .get_name = */ ggml_backend_cpu_reg_get_name,
2652
- /* .get_device_count = */ ggml_backend_cpu_reg_get_device_count,
2653
- /* .get_device = */ ggml_backend_cpu_reg_get_device,
2654
- /* .get_proc_address = */ ggml_backend_cpu_get_proc_address,
2655
- };
2656
-
2657
- ggml_backend_reg_t ggml_backend_cpu_reg(void) {
2658
- static struct ggml_backend_reg ggml_backend_cpu_reg = {
2659
- /* .iface = */ ggml_backend_cpu_reg_i,
2660
- /* .context = */ NULL,
2661
- };
2662
-
2663
- return &ggml_backend_cpu_reg;
2664
- }
 
528
  // multi-buffer buffer
529
 
530
  struct ggml_backend_multi_buffer_context {
 
1450
  bool parallel) {
1451
  GGML_ASSERT(n_backends > 0);
1452
  GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
1453
+ GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU);
1454
 
1455
  struct ggml_backend_sched * sched = (ggml_backend_sched *) calloc(1, sizeof(struct ggml_backend_sched));
1456
 
 
1847
  return true;
1848
  }
1849
 
 
 
 
 
 
 
 
 
 
 
 
1850
  // CPU backend - buffer
1851
 
1852
  static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
 
1920
  /* .reset = */ NULL,
1921
  };
1922
 
1923
+ // CPU backend buffer type
1924
+
1925
+ // this buffer type is defined here to make it available to all backends
1926
 
1927
  static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
1928
  return "CPU";
 
1963
  /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
1964
  /* .is_host = */ ggml_backend_cpu_buffer_type_is_host,
1965
  },
1966
+ /* .device = */ NULL, // FIXME ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
1967
  /* .context = */ NULL,
1968
  };
1969
 
 
1986
  /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
1987
  /* .is_host = */ ggml_backend_cpu_buffer_type_is_host,
1988
  },
1989
+ /* .device = */ NULL, // FIXME ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
1990
  /* .context = */ NULL,
1991
  };
1992
 
1993
  return &ggml_backend_cpu_buffer_type;
1994
  }
1996
  ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
1997
  GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned");
1998
  return ggml_backend_buffer_init(ggml_backend_cpu_buffer_from_ptr_type(), ggml_backend_cpu_buffer_from_ptr_i, ptr, size);
1999
  }
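Throughout the registry design shown in this file, optional backend entry points are resolved by name through ggml_backend_reg_get_proc_address() (its body is the surviving context line at the top of this diff); both the removed AMX code and the removed CPU code publish "ggml_backend_set_n_threads" this way. The following is a hedged sketch of how a caller resolves and invokes such an entry point, using only names that appear in the diff; the helper set_backend_threads is hypothetical:

// Illustrative sketch: resolve an optional, backend-specific entry point by
// name via the registry, as the removed CPU/AMX code exposes it.
#include "ggml-backend.h"

// function-pointer type matching ggml_backend_cpu_set_n_threads in the diff
typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads);

// hypothetical helper: set the thread count if the backend supports it
static void set_backend_threads(ggml_backend_t backend, ggml_backend_reg_t reg, int n_threads) {
    void * proc = ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
    if (proc != NULL) {
        ((ggml_backend_set_n_threads_t) proc)(backend, n_threads);
    }
}

The registration object itself can be obtained with ggml_backend_reg_by_name("CPU"), as in the enumeration helpers removed from this file.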
ggml/src/ggml-blas.cpp DELETED
@@ -1,514 +0,0 @@
1
- #include "ggml-impl.h"
2
- #include "ggml-blas.h"
3
- #include "ggml-backend-impl.h"
4
-
5
- #include <future>
6
- #include <vector>
7
- #include <cstring>
8
-
9
- #if defined(GGML_USE_ACCELERATE)
10
- # include <Accelerate/Accelerate.h>
11
- #elif defined(GGML_BLAS_USE_MKL)
12
- # include <mkl.h>
13
- #elif defined(GGML_BLAS_USE_BLIS)
14
- # include <blis.h>
15
- #elif defined(GGML_BLAS_USE_NVPL)
16
- # include <nvpl_blas.h>
17
- #else
18
- # include <cblas.h>
19
- #endif
20
-
21
- struct ggml_backend_blas_context {
22
- int n_threads = GGML_DEFAULT_N_THREADS;
23
- std::unique_ptr<char[]> work_data;
24
- size_t work_size = 0;
25
- #ifndef GGML_USE_OPENMP
26
- std::vector<std::future<void>> tasks;
27
- #endif
28
- };
29
-
30
- static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
31
- const struct ggml_tensor * src0 = dst->src[0];
32
- const struct ggml_tensor * src1 = dst->src[1];
33
-
34
- GGML_TENSOR_BINARY_OP_LOCALS
35
-
36
- const enum ggml_type type = src0->type;
37
-
38
- GGML_ASSERT(ne0 == ne01);
39
- GGML_ASSERT(ne1 == ne11);
40
- GGML_ASSERT(ne2 == ne12);
41
- GGML_ASSERT(ne3 == ne13);
42
-
43
- // we don't support permuted src0 or src1
44
- GGML_ASSERT(nb00 == ggml_type_size(type));
45
- GGML_ASSERT(nb10 == ggml_type_size(src1->type));
46
-
47
- // dst cannot be transposed or permuted
48
- GGML_ASSERT(nb0 == sizeof(float));
49
- GGML_ASSERT(nb0 <= nb1);
50
- GGML_ASSERT(nb1 <= nb2);
51
- GGML_ASSERT(nb2 <= nb3);
52
-
53
- // broadcast factors
54
- const int64_t r2 = ne12/ne02;
55
- const int64_t r3 = ne13/ne03;
56
-
57
- const int64_t ne_plane = ne01*ne00;
58
- const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
59
-
60
- if (ctx->work_size < desired_wsize) {
61
- ctx->work_data.reset(new char[desired_wsize]);
62
- ctx->work_size = desired_wsize;
63
- }
64
- void * wdata = ctx->work_data.get();
65
-
66
- // convert src0 to float
67
- if (type != GGML_TYPE_F32) {
68
- const auto * type_traits = ggml_get_type_traits(type);
69
- ggml_to_float_t const to_float = type_traits->to_float;
70
-
71
- for (int64_t i03 = 0; i03 < ne03; i03++) {
72
- for (int64_t i02 = 0; i02 < ne02; i02++) {
73
- const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
74
- float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
75
-
76
- const int min_cols_per_thread = 4096;
77
- const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
78
- const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);
79
-
80
- #ifdef GGML_USE_OPENMP
81
- #pragma omp parallel for num_threads(n_threads)
82
- for (int64_t i01 = 0; i01 < ne01; i01++) {
83
- to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
84
- }
85
- #else
86
- for (int i = 1; i < n_threads; i++) {
87
- const int64_t start = i*ne01/n_threads;
88
- const int64_t end = (i + 1)*ne01/n_threads;
89
- if (start < end) {
90
- ctx->tasks.push_back(std::async(std::launch::async, [=]() {
91
- for (int64_t i01 = start; i01 < end; i01++) {
92
- to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
93
- }
94
- }));
95
- }
96
- }
97
- {
98
- // reuse the current thread for the first task
99
- const int64_t start = 0;
100
- const int64_t end = ne01/n_threads;
101
- for (int64_t i01 = start; i01 < end; i01++) {
102
- to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
103
- }
104
- }
105
- #endif
106
- }
107
- }
108
-
109
- #ifndef GGML_USE_OPENMP
110
- // wait for all tasks to finish
111
- for (auto & task : ctx->tasks) {
112
- task.get();
113
- }
114
- ctx->tasks.clear();
115
- #endif
116
- }
117
-
118
- #if defined(OPENBLAS_VERSION)
119
- openblas_set_num_threads(ctx->n_threads);
120
- #endif
121
-
122
- #if defined(GGML_BLAS_USE_BLIS)
123
- bli_thread_set_num_threads(ctx->n_threads);
124
- #endif
125
-
126
- #if defined(GGML_BLAS_USE_NVPL)
127
- nvpl_blas_set_num_threads(ctx->n_threads);
128
- #endif
129
-
130
- for (int64_t i13 = 0; i13 < ne13; i13++) {
131
- for (int64_t i12 = 0; i12 < ne12; i12++) {
132
- const int64_t i03 = i13/r3;
133
- const int64_t i02 = i12/r2;
134
-
135
- const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
136
- const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
137
- float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
138
-
139
- if (type != GGML_TYPE_F32) {
140
- x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
141
- }
142
-
143
- cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
144
- ne1, ne01, ne10,
145
- 1.0f, y, ne10,
146
- x, ne00,
147
- 0.0f, d, ne01);
148
- }
149
- }
150
- }
151
-
152
- static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
153
- const struct ggml_tensor * src0 = dst->src[0];
154
- const struct ggml_tensor * src1 = dst->src[1];
155
-
156
- GGML_TENSOR_BINARY_OP_LOCALS
157
-
158
- GGML_ASSERT(ne0 == ne00);
159
- GGML_ASSERT(ne1 == ne10);
160
- GGML_ASSERT(ne2 == ne02);
161
- GGML_ASSERT(ne02 == ne12);
162
- GGML_ASSERT(ne3 == ne13);
163
- GGML_ASSERT(ne03 == ne13);
164
-
165
- // we don't support permuted src0 or src1
166
- GGML_ASSERT(nb00 == sizeof(float));
167
-
168
- // dst cannot be transposed or permuted
169
- GGML_ASSERT(nb0 == sizeof(float));
170
- // GGML_ASSERT(nb0 <= nb1);
171
- // GGML_ASSERT(nb1 <= nb2);
172
- // GGML_ASSERT(nb2 <= nb3);
173
-
174
- // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
175
- // src0: (k,n)
176
- // src1: (k,m)
177
- // dst: (m,n)
178
- //
179
- // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
180
- // Also expressed as (major,minor)
181
- // a: (m,k): so src1 transposed
182
- // b: (k,n): so src0
183
- // c: (m,n)
184
- //
185
- // However, if ggml_is_transposed(src1) is true, then
186
- // src1->data already contains a transposed version, so sgemm mustn't
187
- // transpose it further.
188
-
189
- int n = src0->ne[0];
190
- int k = src0->ne[1];
191
- int m = src1->ne[0];
192
-
193
- CBLAS_TRANSPOSE transposeA;
194
- int lda;
195
-
196
- if (!ggml_is_transposed(src1)) {
197
- transposeA = CblasTrans;
198
- lda = m;
199
- } else {
200
- transposeA = CblasNoTrans;
201
- lda = k;
202
- }
203
-
204
- float * a = (float *) ((char *) src1->data);
205
- float * b = (float *) ((char *) src0->data);
206
- float * c = (float *) ((char *) dst->data);
207
-
208
- cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
209
-
210
- GGML_UNUSED(ctx);
211
- }
212
-
213
- // backend interface
214
-
215
- static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
216
- return "BLAS";
217
-
218
- GGML_UNUSED(backend);
219
- }
220
-
221
- static void ggml_backend_blas_free(ggml_backend_t backend) {
222
- ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
223
- delete ctx;
224
- delete backend;
225
- }
226
-
227
- static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
228
- ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
229
-
230
- for (int i = 0; i < cgraph->n_nodes; i++) {
231
- struct ggml_tensor * node = cgraph->nodes[i];
232
-
233
- switch (node->op) {
234
- case GGML_OP_MUL_MAT:
235
- ggml_backend_blas_mul_mat(ctx, node);
236
- break;
237
-
238
- case GGML_OP_OUT_PROD:
239
- ggml_backend_blas_out_prod(ctx, node);
240
- break;
241
-
242
- case GGML_OP_NONE:
243
- case GGML_OP_RESHAPE:
244
- case GGML_OP_VIEW:
245
- case GGML_OP_PERMUTE:
246
- case GGML_OP_TRANSPOSE:
247
- break;
248
-
249
- default:
250
- GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node));
251
- }
252
- }
253
-
254
- return GGML_STATUS_SUCCESS;
255
-
256
- GGML_UNUSED(backend);
257
- }
258
-
259
- static struct ggml_backend_i blas_backend_i = {
260
- /* .get_name = */ ggml_backend_blas_get_name,
261
- /* .free = */ ggml_backend_blas_free,
262
- /* .set_tensor_async = */ NULL,
263
- /* .get_tensor_async = */ NULL,
264
- /* .cpy_tensor_async = */ NULL,
265
- /* .synchronize = */ NULL,
266
- /* .graph_plan_create = */ NULL,
267
- /* .graph_plan_free = */ NULL,
268
- /* .graph_plan_update = */ NULL,
269
- /* .graph_plan_compute = */ NULL,
270
- /* .graph_compute = */ ggml_backend_blas_graph_compute,
271
- /* .event_record = */ NULL,
272
- /* .event_wait = */ NULL,
273
- };
274
-
275
- static ggml_guid_t ggml_backend_blas_guid(void) {
276
- static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
277
- return &guid;
278
- }
279
-
280
- ggml_backend_t ggml_backend_blas_init(void) {
281
- ggml_backend_blas_context * ctx = new ggml_backend_blas_context;
282
-
283
- ggml_backend_t backend = new ggml_backend {
284
- /* .guid = */ ggml_backend_blas_guid(),
285
- /* .interface = */ blas_backend_i,
286
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
287
- /* .context = */ ctx,
288
- };
289
-
290
- #if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
291
- if (openblas_get_parallel() != OPENBLAS_OPENMP) {
292
- GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
293
- }
294
- #endif
295
-
296
- #if defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
297
- GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
298
- #endif
299
-
300
- return backend;
301
- }
302
-
303
- bool ggml_backend_is_blas(ggml_backend_t backend) {
304
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());
305
- }
306
-
307
- void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) {
308
- GGML_ASSERT(ggml_backend_is_blas(backend_blas));
309
-
310
- ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
311
- ctx->n_threads = n_threads;
312
- }
313
-
314
- // device interface
315
-
316
- static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
317
- return "BLAS";
318
-
319
- GGML_UNUSED(dev);
320
- }
321
-
322
- static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {
323
- #if defined(GGML_USE_ACCELERATE)
324
- return "Accelerate";
325
- #elif defined(GGML_BLAS_USE_MKL)
326
- return "MKL";
327
- #elif defined(GGML_BLAS_USE_BLIS)
328
- return "BLIS";
329
- #elif defined(GGML_BLAS_USE_NVPL)
330
- return "NVPL";
331
- #elif defined(OPENBLAS_VERSION)
332
- return "OpenBLAS";
333
- #else
334
- return "BLAS";
335
- #endif
336
-
337
- GGML_UNUSED(dev);
338
- }
339
-
340
- static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
341
- // TODO
342
- *free = 0;
343
- *total = 0;
344
-
345
- GGML_UNUSED(dev);
346
- }
347
-
348
- static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
349
- return GGML_BACKEND_DEVICE_TYPE_ACCEL;
350
-
351
- GGML_UNUSED(dev);
352
- }
353
-
354
- static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
355
- props->name = ggml_backend_blas_device_get_name(dev);
356
- props->description = ggml_backend_blas_device_get_description(dev);
357
- props->type = ggml_backend_blas_device_get_type(dev);
358
- ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);
359
- props->caps = {
360
- /* .async = */ false,
361
- /* .host_buffer = */ false,
362
- /* .buffer_from_host_ptr = */ true,
363
- /* .events = */ false,
364
- };
365
- }
366
-
367
- static ggml_backend_t ggml_backend_blas_device_init_backend(ggml_backend_dev_t dev, const char * params) {
368
- return ggml_backend_blas_init();
369
-
370
- GGML_UNUSED(dev);
371
- GGML_UNUSED(params);
372
- }
373
-
374
- static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {
375
- return ggml_backend_cpu_buffer_type();
376
-
377
- GGML_UNUSED(dev);
378
- }
379
-
380
- static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
381
- return ggml_backend_cpu_buffer_from_ptr(ptr, size);
382
-
383
- GGML_UNUSED(dev);
384
- GGML_UNUSED(max_tensor_size);
385
- }
386
-
387
- static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
388
- const struct ggml_tensor * src0 = op->src[0];
389
- const struct ggml_tensor * src1 = op->src[1];
390
-
391
- switch (op->op) {
392
- case GGML_OP_NONE:
393
- case GGML_OP_RESHAPE:
394
- case GGML_OP_VIEW:
395
- case GGML_OP_PERMUTE:
396
- case GGML_OP_TRANSPOSE:
397
- return true;
398
-
399
- case GGML_OP_MUL_MAT:
400
- {
401
- // BLAS is usually only faster for large matrices
402
- const struct ggml_tensor * src0 = op->src[0];
403
- const struct ggml_tensor * src1 = op->src[1];
404
-
405
- const int64_t ne10 = src1->ne[0];
406
-
407
- const int64_t ne0 = op->ne[0];
408
- const int64_t ne1 = op->ne[1];
409
-
410
- // TODO: find the optimal value
411
- const int64_t min_batch = 32;
412
-
413
- return ggml_is_contiguous(src0) &&
414
- ggml_is_contiguous(src1) &&
415
- src1->type == GGML_TYPE_F32 &&
416
- (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) &&
417
- (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
418
- }
419
-
420
- case GGML_OP_OUT_PROD:
421
- return op->src[0]->type == GGML_TYPE_F32 &&
422
- op->src[1]->type == GGML_TYPE_F32 &&
423
- ggml_is_matrix(src0) &&
424
- ggml_is_matrix(src1) &&
425
- ggml_is_contiguous(src0) &&
426
- (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) &&
427
- (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
428
-
429
- default:
430
- return false;
431
-
432
- }
433
-
434
- GGML_UNUSED(dev);
435
- }
436
-
437
- static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
438
- return ggml_backend_buft_is_host(buft);
439
-
440
- GGML_UNUSED(dev);
441
- }
442
-
443
- static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
444
- /* .get_name = */ ggml_backend_blas_device_get_name,
445
- /* .get_description = */ ggml_backend_blas_device_get_description,
446
- /* .get_memory = */ ggml_backend_blas_device_get_memory,
447
- /* .get_type = */ ggml_backend_blas_device_get_type,
448
- /* .get_props = */ ggml_backend_blas_device_get_props,
449
- /* .init_backend = */ ggml_backend_blas_device_init_backend,
450
- /* .get_buffer_type = */ ggml_backend_blas_device_get_buffer_type,
451
- /* .get_host_buffer_type = */ NULL,
452
- /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_host_ptr,
453
- /* .supports_op = */ ggml_backend_blas_device_supports_op,
454
- /* .supports_buft = */ ggml_backend_blas_device_supports_buft,
455
- /* .offload_op = */ NULL,
456
- /* .event_new = */ NULL,
457
- /* .event_free = */ NULL,
458
- /* .event_synchronize = */ NULL,
459
- };
460
-
461
- // backend reg interface
462
-
463
- static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {
464
- return "BLAS";
465
-
466
- GGML_UNUSED(reg);
467
- }
468
-
469
- static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {
470
- return 1;
471
-
472
- GGML_UNUSED(reg);
473
- }
474
-
475
- static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {
476
- GGML_ASSERT(index == 0);
477
-
478
- static ggml_backend_device ggml_backend_blas_device = {
479
- /* .iface = */ ggml_backend_blas_device_i,
480
- /* .reg = */ reg,
481
- /* .context = */ nullptr,
482
- };
483
-
484
- return &ggml_backend_blas_device;
485
-
486
- GGML_UNUSED(reg);
487
- GGML_UNUSED(index);
488
- }
489
-
490
- static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {
491
- if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
492
- return (void *)ggml_backend_blas_set_n_threads;
493
- }
494
- return NULL;
495
-
496
- GGML_UNUSED(reg);
497
- GGML_UNUSED(name);
498
- }
499
-
500
- static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
501
- /* .get_name = */ ggml_backend_blas_reg_get_name,
502
- /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,
503
- /* .get_device = */ ggml_backend_blas_reg_get_device,
504
- /* .get_proc_address = */ ggml_backend_blas_get_proc_address,
505
- };
506
-
507
- ggml_backend_reg_t ggml_backend_blas_reg(void) {
508
- static struct ggml_backend_reg ggml_backend_blas_reg = {
509
- /* .iface = */ ggml_backend_blas_reg_i,
510
- /* .context = */ NULL,
511
- };
512
-
513
- return &ggml_backend_blas_reg;
514
- }
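End of the removed ggml-blas.cpp. For orientation, a minimal, hedged usage sketch of the public surface this file provided (ggml_backend_blas_init, ggml_backend_is_blas, ggml_backend_blas_set_n_threads); the header name and thread count below are illustrative assumptions, not taken from the diff:

    #include <cstdio>
    #include "ggml-blas.h"      // assumed header declaring the BLAS backend API

    int main() {
        ggml_backend_t backend = ggml_backend_blas_init();
        if (backend == NULL) {
            std::fprintf(stderr, "failed to initialize the BLAS backend\n");
            return 1;
        }
        if (ggml_backend_is_blas(backend)) {
            ggml_backend_blas_set_n_threads(backend, 4);   // hypothetical thread count
        }
        // ... build a ggml graph whose MUL_MAT / OUT_PROD nodes this backend
        // supports, then run it with ggml_backend_graph_compute(backend, graph) ...
        ggml_backend_free(backend);
        return 0;
    }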
 
ggml/src/ggml-cann.cpp DELETED
@@ -1,2128 +0,0 @@
1
- /*
2
- * Copyright (c) 2023-2024 The ggml authors
3
- *
4
- * Permission is hereby granted, free of charge, to any person obtaining a copy
5
- * of this software and associated documentation files (the "Software"), to
6
- * deal in the Software without restriction, including without limitation the
7
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
8
- * sell copies of the Software, and to permit persons to whom the Software is
9
- * furnished to do so, subject to the following conditions:
10
- *
11
- * The above copyright notice and this permission notice shall be included in
12
- * all copies or substantial portions of the Software.
13
- *
14
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
- * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20
- * IN THE SOFTWARE.
21
- */
22
-
23
- #include "ggml-cann.h"
24
-
25
- #include <acl/acl.h>
26
- #include <stdarg.h>
27
-
28
- #include <cmath>
29
- #include <cstdio>
30
- #include <cstring>
31
- #include <mutex>
32
-
33
- #include "ggml-impl.h"
34
- #include "ggml-backend-impl.h"
35
- #include "ggml-cann/aclnn_ops.h"
36
- #include "ggml-cann/common.h"
37
-
38
- #define GGML_COMMON_DECL_C
39
-
40
- #include "ggml-common.h"
41
-
42
- #define GGML_CANN_NAME "CANN"
43
-
44
- /**
45
- * @brief Handles CANN errors by printing an error message and aborting.
46
- *
47
- * @param stmt The statement that caused the error.
48
- * @param func The function in which the error occurred.
49
- * @param file The file in which the error occurred.
50
- * @param line The line number where the error occurred.
51
- * @param msg The error message.
52
- */
53
- [[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
54
- const char* file, int line, const char* msg) {
55
- int32_t id = -1;
56
- aclrtGetDevice(&id);
57
-
58
- GGML_LOG_ERROR("CANN error: %s\n", msg);
59
- GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func,
60
- file, line);
61
- GGML_LOG_ERROR(" %s\n", stmt);
62
- // abort with GGML_ASSERT to get a stack trace
63
- GGML_ABORT("CANN error");
64
- }
65
-
66
- /**
67
- * @brief Sets the device to be used by CANN.
68
- *
69
- * @param device The device ID to set.
70
- */
71
- void ggml_cann_set_device(const int32_t device) {
72
- // TODO: uncomment these lines after the empty-context issue has been fixed.
73
- // int current_device;
74
- // ACL_CHECK(aclrtGetDevice(&current_device));
75
-
76
- // if (device == current_device) {
77
- // return;
78
- // }
79
- ACL_CHECK(aclrtSetDevice(device));
80
- }
81
-
82
- /**
83
- * @brief Retrieves the current device ID.
84
- *
85
- * @return The current device ID.
86
- */
87
- int32_t ggml_cann_get_device() {
88
- int32_t id;
89
- ACL_CHECK(aclrtGetDevice(&id));
90
- return id;
91
- }
92
-
93
- /**
94
- * @brief Initialize the CANN device information.
95
- *
96
- * This function initializes the CANN device information by obtaining the
97
- * device count and setting the memory allocation granularity for each device.
98
- *
99
- * @return A structure containing the device information.
100
- */
101
- static ggml_cann_device_info ggml_cann_init() {
102
- ggml_cann_device_info info = {};
103
-
104
- aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
105
-
106
- if (err != ACL_SUCCESS) {
107
- GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n",
108
- __func__, aclGetRecentErrMsg());
109
- return info;
110
- }
111
-
112
- GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);
113
-
114
- for (int id = 0; id < info.device_count; ++id) {
115
- aclrtPhysicalMemProp prop = {};
116
- prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
117
- prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
118
- prop.memAttr = ACL_HBM_MEM_HUGE;
119
- prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
120
- prop.location.id = id;
121
- prop.reserve = 0;
122
- ACL_CHECK(aclrtMemGetAllocationGranularity(
123
- &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
124
- &info.devices[id].vmm_granularity));
125
- }
126
-
127
- // TODO: add more device info later.
128
- return info;
129
- }
130
-
131
- /**
132
- * @brief Retrieve the CANN device information.
133
- *
134
- * This function returns a reference to a structure containing the CANN device
135
- * information. The device information is initialized once and reused on
136
- * subsequent calls.
137
- *
138
- * @return A reference to the structure containing the device information.
139
- */
140
- const ggml_cann_device_info& ggml_cann_info() {
141
- static ggml_cann_device_info info = ggml_cann_init();
142
- return info;
143
- }
144
-
145
- //#define DEBUG_CANN_MALLOC
146
- /**
147
- * @brief A pool of CANN buffers (legacy).
148
- *
149
- * This class manages a pool of CANN buffers for a specific device.
150
- */
151
- struct ggml_cann_pool_leg : public ggml_cann_pool {
152
- /**
153
- * @brief The maximum number of buffers in the pool.
154
- */
155
- static const int MAX_BUFFERS = 256;
156
-
157
- /**
158
- * @brief The device ID associated with this buffer pool.
159
- */
160
- int device;
161
-
162
- /**
163
- * @brief Structure representing a CANN buffer.
164
- */
165
- struct ggml_cann_buffer {
166
- void* ptr = nullptr; ///< Pointer to the buffer memory.
167
- size_t size = 0; ///< Size of the buffer.
168
- };
169
-
170
- /**
171
- * @brief Array of CANN buffers in the pool.
172
- */
173
- ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};
174
-
175
- /**
176
- * @brief Total size of all buffers in the pool.
177
- */
178
- size_t pool_size = 0;
179
-
180
- /**
181
- * @brief Constructor to initialize the buffer pool for a specific device.
182
- *
183
- * @param device The device ID to associate with this buffer pool.
184
- */
185
- explicit ggml_cann_pool_leg(int device) : device(device) {}
186
-
187
- /**
188
- * @brief Destructor to free all buffers in the pool.
189
- */
190
- ~ggml_cann_pool_leg() {
191
- ggml_cann_set_device(device);
192
- for (int i = 0; i < MAX_BUFFERS; ++i) {
193
- ggml_cann_buffer& b = buffer_pool[i];
194
- if (b.ptr != nullptr) {
195
- ACL_CHECK(aclrtFree(b.ptr));
196
- pool_size -= b.size;
197
- }
198
- }
199
- GGML_ASSERT(pool_size == 0);
200
- }
201
-
202
- /**
203
- * @brief Allocate a buffer of the given size.
204
- *
205
- * @param size The size of the buffer to allocate.
206
- * @param actual_size A pointer to a variable to receive the actual size of
207
- * the allocated buffer.
208
- * @return A pointer to the allocated buffer.
209
- */
210
- void* alloc(size_t size, size_t* actual_size) override {
211
- #ifdef DEBUG_CANN_MALLOC
212
- int nnz = 0;
213
- size_t max_size = 0;
214
- #endif
215
- size_t best_diff = 1ull << 36;
216
- int ibest = -1;
217
- for (int i = 0; i < MAX_BUFFERS; ++i) {
218
- ggml_cann_buffer& b = buffer_pool[i];
219
- if (b.ptr != nullptr) {
220
- #ifdef DEBUG_CANN_MALLOC
221
- ++nnz;
222
- if (b.size > max_size) max_size = b.size;
223
- #endif
224
- if (b.size >= size) {
225
- size_t diff = b.size - size;
226
- if (diff < best_diff) {
227
- best_diff = diff;
228
- ibest = i;
229
- if (!best_diff) {
230
- void* ptr = b.ptr;
231
- *actual_size = b.size;
232
- b.ptr = nullptr;
233
- b.size = 0;
234
- return ptr;
235
- }
236
- }
237
- }
238
- }
239
- }
240
- if (ibest >= 0) {
241
- ggml_cann_buffer& b = buffer_pool[ibest];
242
- void* ptr = b.ptr;
243
- *actual_size = b.size;
244
- b.ptr = nullptr;
245
- b.size = 0;
246
- return ptr;
247
- }
248
- void* ptr;
249
- size_t look_ahead_size = (size_t)(1.05 * size);
250
- look_ahead_size = 256 * ((look_ahead_size + 255) / 256);
251
- ggml_cann_set_device(device);
252
- ACL_CHECK(
253
- aclrtMalloc(&ptr, look_ahead_size, ACL_MEM_MALLOC_HUGE_FIRST));
254
- *actual_size = look_ahead_size;
255
- pool_size += look_ahead_size;
256
- #ifdef DEBUG_CANN_MALLOC
257
- GGML_LOG_INFO(
258
- "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
259
- "requested %u MB\n",
260
- __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024),
261
- (uint32_t)(pool_size / 1024 / 1024),
262
- (uint32_t)(size / 1024 / 1024));
263
- #endif
264
- return ptr;
265
- }
266
-
267
- /**
268
- * @brief Free a buffer and return it to the pool.
269
- *
270
- * @param ptr Pointer to the buffer to free.
271
- * @param size Size of the buffer to free.
272
- */
273
- void free(void* ptr, size_t size) override {
274
- for (int i = 0; i < MAX_BUFFERS; ++i) {
275
- ggml_cann_buffer& b = buffer_pool[i];
276
- if (b.ptr == nullptr) {
277
- b.ptr = ptr;
278
- b.size = size;
279
- return;
280
- }
281
- }
282
- // the memory should always be kept in the pool, because it may still be
283
- // needed by tasks in the stream.
284
- // TODO: fix me.
285
- GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n");
286
- }
287
- };
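The legacy pool above caches up to MAX_BUFFERS freed blocks and, on a cache miss, over-allocates by about 5% rounded up to a 256-byte multiple so that slightly larger follow-up requests can reuse the block. A standalone sketch of just that sizing rule (example request sizes are made up):

    #include <cstddef>
    #include <cstdio>
    #include <initializer_list>

    // Same sizing rule as ggml_cann_pool_leg::alloc above:
    // grow the request by ~5% and round up to a multiple of 256 bytes.
    static size_t look_ahead_size(size_t size) {
        size_t s = (size_t)(1.05 * size);
        return 256 * ((s + 255) / 256);
    }

    int main() {
        for (size_t req : {1000, 4096, 1 << 20}) {   // hypothetical request sizes
            std::printf("request %7zu -> allocate %7zu\n", req, look_ahead_size(req));
        }
        return 0;
    }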
288
-
289
- /**
290
- * @brief A pool of CANN buffers with virtual memory.
291
- *
292
- * This class manages a pool of CANN buffers with virtual memory for a specific
293
- * device.
294
- */
295
- struct ggml_cann_pool_vmm : public ggml_cann_pool {
296
- /**
297
- * @brief The maximum size of the virtual memory pool (32 GB).
298
- */
299
- static const size_t CANN_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
300
-
301
- /**
302
- * @brief The device ID associated with this buffer pool.
303
- */
304
- int device;
305
-
306
- /**
307
- * @brief Pointer to the start of the virtual memory pool.
308
- */
309
- void* pool_addr = 0;
310
-
311
- /**
312
- * @brief Amount of virtual memory used in the pool.
313
- */
314
- size_t pool_used = 0;
315
-
316
- /**
317
- * @brief Total size of the virtual memory pool.
318
- */
319
- size_t pool_size = 0;
320
-
321
- /**
322
- * @brief Allocation granularity for the virtual memory pool.
323
- */
324
- size_t granularity;
325
-
326
- /**
327
- * @brief Handles for the physical memory allocated.
328
- */
329
- std::vector<aclrtDrvMemHandle> handles;
330
-
331
- /**
332
- * @brief Offsets for the mapped memory regions.
333
- */
334
- std::vector<void*> map_offsets;
335
-
336
- /**
337
- * @brief Constructor to initialize the buffer pool with virtual memory for
338
- * a specific device.
339
- *
340
- * @param device The device ID to associate with this buffer pool.
341
- */
342
- explicit ggml_cann_pool_vmm(int device)
343
- : device(device),
344
- granularity(ggml_cann_info().devices[device].vmm_granularity) {}
345
-
346
- /**
347
- * @brief Destructor to free all buffers in the virtual memory pool.
348
- */
349
- ~ggml_cann_pool_vmm() {
350
- if (pool_addr != 0) {
351
- for (auto& offset : map_offsets) {
352
- ACL_CHECK(aclrtUnmapMem(offset));
353
- }
354
- for (auto& handle : handles) {
355
- ACL_CHECK(aclrtFreePhysical(handle));
356
- }
357
- ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
358
- }
359
- }
360
-
361
- /**
362
- * @brief Allocate a buffer of the given size in the virtual memory pool.
363
- *
364
- * @param size The size of the buffer to allocate.
365
- * @param actual_size A pointer to a variable to receive the actual size of
366
- * the allocated buffer.
367
- * @return A pointer to the allocated buffer.
368
- */
369
- void* alloc(size_t size, size_t* actual_size) override {
370
- // round up the allocation size to the alignment to ensure that all
371
- // allocations are aligned for all data types
372
- const size_t alignment = 128;
373
- size = alignment * ((size + alignment - 1) / alignment);
374
-
375
- size_t avail = pool_size - pool_used;
376
-
377
- if (size > avail) {
378
- // round up to the next multiple of the granularity
379
- size_t reserve_size = size - avail;
380
- reserve_size =
381
- granularity * ((reserve_size + granularity - 1) / granularity);
382
-
383
- GGML_ASSERT(pool_size + reserve_size <= CANN_POOL_VMM_MAX_SIZE);
384
-
385
- // allocate more physical memory
386
- aclrtPhysicalMemProp prop = {};
387
- prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
388
- prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
389
- prop.memAttr = ACL_HBM_MEM_HUGE;
390
- prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
391
- prop.location.id = device;
392
- prop.reserve = 0;
393
- aclrtDrvMemHandle handle;
394
- ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));
395
-
396
- // reserve virtual address space (if not already reserved)
397
- if (pool_addr == 0) {
398
- ACL_CHECK(aclrtReserveMemAddress(
399
- &pool_addr, CANN_POOL_VMM_MAX_SIZE, 0, NULL, 1));
400
- }
401
-
402
- // map at the end of the pool
403
- ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0,
404
- handle, 0));
405
-
406
- handles.push_back(handle);
407
- map_offsets.push_back((char*)pool_addr + pool_size);
408
-
409
- // add to the pool
410
- pool_size += reserve_size;
411
-
412
- // GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (
413
- // reserved %llu MB)\n",
414
- // device, (unsigned long long) (pool_size/1024/1024),
415
- // (unsigned long long) (reserve_size/1024/1024));
416
- }
417
-
418
- GGML_ASSERT(pool_addr != 0);
419
-
420
- void* ptr = (void*)((char*)pool_addr + pool_used);
421
- *actual_size = size;
422
- pool_used += size;
423
-
424
- #ifdef DEBUG_CANN_MALLOC
425
- GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device,
426
- (unsigned long long)size, (unsigned long long)ptr);
427
- #endif
428
- return ptr;
429
- }
430
-
431
- /**
432
- * @brief Free a buffer and return it to the virtual memory pool.
433
- *
434
- * @param ptr Pointer to the buffer to free.
435
- * @param size Size of the buffer to free.
436
- */
437
- void free(void* ptr, size_t size) override {
438
- #ifdef DEBUG_CANN_MALLOC
439
- GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device,
440
- (unsigned long long)size, (unsigned long long)ptr);
441
- #endif
442
-
443
- pool_used -= size;
444
-
445
- // all deallocations must be in reverse order of the allocations
446
- GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used));
447
- }
448
- };
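The VMM pool above is effectively a bump allocator over a reserved virtual address range: requests are rounded up to 128-byte alignment, physical memory is mapped at the end of the range in granularity-sized chunks only when the pool runs out, and frees must arrive in reverse (LIFO) order, as the assertion in free() enforces. A bookkeeping-only sketch of that logic, with no CANN calls and a made-up granularity (the real value comes from aclrtMemGetAllocationGranularity):

    #include <cassert>
    #include <cstddef>
    #include <cstdio>

    // Offsets stand in for pointers into the reserved virtual range.
    struct vmm_pool_sketch {
        size_t granularity = 2u << 20;   // hypothetical 2 MiB granularity
        size_t pool_size   = 0;          // bytes of physical memory mapped so far
        size_t pool_used   = 0;          // bytes currently handed out

        static size_t align_up(size_t v, size_t a) { return a * ((v + a - 1) / a); }

        size_t alloc(size_t size) {
            size = align_up(size, 128);
            if (size > pool_size - pool_used) {
                // real pool: aclrtMallocPhysical + aclrtMapMem at the end of the range
                pool_size += align_up(size - (pool_size - pool_used), granularity);
            }
            size_t offset = pool_used;
            pool_used += size;
            return offset;
        }

        void free(size_t offset, size_t size) {
            pool_used -= align_up(size, 128);
            assert(offset == pool_used);   // deallocations must be LIFO
        }
    };

    int main() {
        vmm_pool_sketch pool;
        size_t a = pool.alloc(1000), b = pool.alloc(5000);
        pool.free(b, 5000);
        pool.free(a, 1000);
        std::printf("mapped %zu bytes, %zu in use\n", pool.pool_size, pool.pool_used);
        return 0;
    }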
449
-
450
- /**
451
- * @brief Create a new CANN pool for a specific device.
452
- *
453
- * Factory method to create a new CANN pool object based on the device type.
454
- *
455
- * @param device The device ID for which to create the pool.
456
- * @return A unique pointer to the created CANN pool.
457
- */
458
- std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
459
- int device) {
460
- // return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
461
- return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
462
- }
463
-
464
- // cann buffer
465
- /**
466
- * @brief Context for managing a CANN buffer associated with a specific device.
467
- *
468
- * This structure holds information about a CANN buffer, including the device
469
- * ID and the pointer to the device memory allocated for it.
470
- */
471
- struct ggml_backend_cann_buffer_context {
472
- int32_t device; ///< The device ID associated with this buffer context.
473
- void* dev_ptr =
474
- nullptr; ///< Pointer to the device memory allocated for the buffer.
475
-
476
- /**
477
- * @brief Constructor to initialize the CANN buffer context.
478
- *
479
- * @param device The device ID associated with this buffer context.
480
- * @param dev_ptr Pointer to the device memory allocated for the buffer.
481
- */
482
- ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
483
- : device(device),
484
- dev_ptr(dev_ptr) {}
485
-
486
- /**
487
- * @brief Destructor to free the device memory allocated for the buffer.
488
- */
489
- ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
490
- };
491
-
492
- /**
493
- * @brief Check if a buffer is a CANN buffer.
494
- *
495
- * This function checks if a given buffer is a CANN buffer by checking whether
496
- * its buffer type is a CANN buffer type.
497
- *
498
- * @param buffer The buffer to check.
499
- * @return true if the buffer is a CANN buffer, false otherwise.
500
- */
501
- static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
502
- static bool ggml_backend_buffer_is_cann(
503
- ggml_backend_buffer_t buffer) {
504
- return ggml_backend_buft_is_cann(buffer->buft);
505
- }
506
-
507
- /**
508
- * @brief Free resources associated with a CANN buffer.
509
- *
510
- * This function frees the resources associated with a CANN buffer, including
511
- * its context.
512
- *
513
- * @param buffer The CANN buffer to free.
514
- */
515
- static void ggml_backend_cann_buffer_free_buffer(
516
- ggml_backend_buffer_t buffer) {
517
- ggml_backend_cann_buffer_context* ctx =
518
- (ggml_backend_cann_buffer_context*)buffer->context;
519
- delete ctx;
520
- }
521
-
522
- /**
523
- * @brief Retrieve the base pointer of a CANN buffer.
524
- *
525
- * This function returns the base pointer of a CANN buffer, which points to the
526
- * device memory allocated for the buffer.
527
- *
528
- * @param buffer The CANN buffer whose base pointer is to be retrieved.
529
- * @return A pointer to the base of the device memory allocated for the buffer.
530
- */
531
- static void* ggml_backend_cann_buffer_get_base(
532
- ggml_backend_buffer_t buffer) {
533
- ggml_backend_cann_buffer_context* ctx =
534
- (ggml_backend_cann_buffer_context*)buffer->context;
535
- return ctx->dev_ptr;
536
- }
537
-
538
- /**
539
- * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN
540
- * processing.
541
- *
542
- * This function transforms quantized Q4.0 tensor data into a format suitable
543
- * for CANN processing. It extracts quantization values and scales from the
544
- * source data and prepares them in a format expected by CANN operations.
545
- *
546
- * @param tensor Pointer to the tensor information.
547
- * @param src Pointer to the source data in Q4.0 format.
548
- * @param dst Pointer to the destination buffer where transformed data will be
549
- * stored.
550
- */
551
- static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
552
- const void* src,
553
- void* dst) {
554
-
555
- int64_t n_elems = ggml_nelements(tensor);
556
- int64_t groups = n_elems / QK4_0;
557
- size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
558
-
559
- uint8_t* quant_offset = (uint8_t*)dst;
560
- uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
561
-
562
- for (int i = 0; i < groups; i++) {
563
- const block_q4_0* group =
564
- (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
565
- *scale_offset = group->d;
566
- scale_offset++;
567
-
568
- // 0-15
569
- for (int j = 0; j < QK4_0 / 2; j += 2) {
570
- (*quant_offset) = (group->qs[j] & 0x0F);
571
- (*quant_offset) |= ((group->qs[j + 1] << 4));
572
- quant_offset++;
573
- }
574
-
575
- // 16-31
576
- for (int j = 0; j < QK4_0 / 2; j += 2) {
577
- (*quant_offset) = (group->qs[j] >> 4);
578
- (*quant_offset) |= (group->qs[j + 1] & 0xF0);
579
- quant_offset++;
580
- }
581
- }
582
-
583
- // put (uint4b_t -8) into int4b_t
584
- for (quant_offset = (uint8_t*)dst;
585
- quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) {
586
- (*quant_offset) ^= 0x88;
587
- }
588
- }
589
-
590
- /**
591
- * @brief Transform CANN processed data back into quantized Q4.0 format.
592
- *
593
- * This function transforms CANN processed data back into quantized Q4.0 format.
594
- * It reverses the transformation performed by
595
- * ggml_backend_cann_transform_q4_0(), converting the data back into its
596
- * original quantized form.
597
- *
598
- * @param tensor Pointer to the tensor information.
599
- * @param src Pointer to the source buffer containing transformed data.
600
- * @param dst Pointer to the destination buffer where the Q4.0 formatted data
601
- * will be stored.
602
- */
603
- static void ggml_backend_cann_transform_back_q4_0(
604
- const ggml_tensor* tensor, void* src, void* dst) {
605
-
606
- int64_t n_elems = ggml_nelements(tensor);
607
- int64_t groups = n_elems / QK4_0;
608
- size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
609
-
610
- uint8_t* quant_offset = (uint8_t*)src;
611
- uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);
612
-
613
- for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) {
614
- (*quant_offset) ^= 0x88;
615
- }
616
- quant_offset = (uint8_t*)src;
617
-
618
- for (int i = 0; i < groups; i++) {
619
- block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0));
620
- group->d = *scale_offset;
621
- scale_offset++;
622
-
623
- // 0-15
624
- for (int j = 0; j < QK4_0 / 2; j += 2) {
625
- group->qs[j] = ((*quant_offset) & 0x0F);
626
- group->qs[j + 1] = ((*quant_offset) >> 4);
627
- quant_offset++;
628
- }
629
-
630
- // 16-31
631
- for (int j = 0; j < QK4_0 / 2; j += 2) {
632
- group->qs[j] |= ((*quant_offset) << 4);
633
- group->qs[j + 1] |= ((*quant_offset) & 0xF0);
634
- quant_offset++;
635
- }
636
- }
637
- }
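Both Q4_0 transforms above rely on the same nibble trick: Q4_0 stores each quant as an unsigned 4-bit value u = q + 8 (q in [-8, 7]), and XOR-ing every byte with 0x88 flips the high bit of each nibble, which converts between that excess-8 encoding and two's-complement int4 in either direction. A tiny standalone check of the mapping (not part of the removed file):

    #include <cstdio>

    int main() {
        for (int u = 0; u < 16; ++u) {
            int n = u ^ 0x8;                    // nibble after the 0x88 XOR
            int q = (n & 0x8) ? n - 16 : n;     // reinterpret as signed int4
            std::printf("stored u=%2d  (q = u - 8 = %2d)  ->  int4 %2d\n", u, u - 8, q);
        }
        return 0;
    }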
638
-
639
- /**
640
- * @brief Transform quantized Q8.0 tensor data into a format suitable for CANN
641
- * processing.
642
- *
643
- * This function transforms quantized Q8.0 tensor data into a format suitable
644
- * for CANN processing. It extracts quantization values and scales from the
645
- * source data and prepares them in a format expected by CANN operations.
646
- *
647
- * @param tensor Pointer to the tensor information.
648
- * @param src Pointer to the source data in Q8.0 format.
649
- * @param dst Pointer to the destination buffer where transformed data will be
650
- * stored.
651
- */
652
- static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
653
- const void* src,
654
- void* dst) {
655
- int64_t n_elems = ggml_nelements(tensor);
656
- int64_t groups = n_elems / QK8_0;
657
- size_t quant_bytes = n_elems * sizeof(uint8_t);
658
-
659
- uint8_t* quant_offset = (uint8_t*)dst;
660
- uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
661
-
662
- for (int i = 0; i < groups; i++) {
663
- const block_q8_0* group =
664
- (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
665
- *scale_offset = group->d;
666
- scale_offset++;
667
- size_t group_quant_size = QK8_0 * sizeof(uint8_t);
668
- memcpy(quant_offset, group->qs, group_quant_size);
669
- quant_offset += group_quant_size;
670
- }
671
- }
672
-
673
- /**
674
- * @brief Transform CANN processed data back into quantized Q8.0 format.
675
- *
676
- * This function transforms CANN processed data back into quantized Q8.0 format.
677
- * It reverses the transformation performed by
678
- * ggml_backend_cann_transform_q8_0(), converting the data back into its
679
- * original quantized form.
680
- *
681
- * @param tensor Pointer to the tensor information.
682
- * @param src Pointer to the source buffer containing transformed data.
683
- * @param dst Pointer to the destination buffer where the Q8.0 formatted data
684
- * will be stored.
685
- */
686
- static void ggml_backend_cann_transform_back_q8_0(
687
- const ggml_tensor* tensor, const void* src, void* dst) {
688
- int64_t n_elems = ggml_nelements(tensor);
689
- int64_t groups = n_elems / QK8_0;
690
- size_t quant_bytes = n_elems * sizeof(uint8_t);
691
-
692
- const uint8_t* quant_offset = (const uint8_t*)src;
693
- const uint16_t* scale_offset =
694
- (const uint16_t*)((const char*)src + quant_bytes);
695
-
696
- for (int i = 0; i < groups; i++) {
697
- block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
698
- group->d = *scale_offset;
699
- scale_offset++;
700
- size_t group_quant_size = QK8_0 * sizeof(uint8_t);
701
- memcpy(group->qs, quant_offset, group_quant_size);
702
- quant_offset += group_quant_size;
703
- }
704
- }
705
-
706
- /**
707
- * @brief Transform tensor data based on its type for CANN processing.
708
- *
709
- * This function transforms tensor data based on its quantization type for CANN
710
- * processing. It dispatches the transformation based on the tensor's type to
711
- * specialized functions handling Q4.0 and Q8.0 formats.
712
- *
713
- * @param tensor Pointer to the tensor information.
714
- * @param src Pointer to the source data to be transformed.
715
- * @param dst Pointer to the destination buffer where transformed data will be
716
- * stored.
717
- */
718
- static void ggml_backend_cann_transform(ggml_tensor* tensor,
719
- const void* src, void* dst) {
720
- switch (tensor->type) {
721
- case GGML_TYPE_Q4_0:
722
- ggml_backend_cann_transform_q4_0(tensor, src, dst);
723
- break;
724
- case GGML_TYPE_Q8_0:
725
- ggml_backend_cann_transform_q8_0(tensor, src, dst);
726
- break;
727
- default:
728
- break;
729
- }
730
- }
731
-
732
- /**
733
- * @brief Transform CANN processed data back into tensor data based on its type.
734
- *
735
- * This function transforms CANN processed data back into tensor data based on
736
- * its quantization type for Q4.0 and Q8.0 formats. It dispatches the
737
- * transformation based on the tensor's type to specialized functions.
738
- *
739
- * @param tensor Pointer to the tensor information.
740
- * @param src Pointer to the source data containing CANN processed data.
741
- * @param dst Pointer to the destination buffer where transformed tensor data
742
- * will be stored.
743
- */
744
- static void ggml_backend_cann_transform_back(
745
- const ggml_tensor* tensor, void* src, void* dst) {
746
- switch (tensor->type) {
747
- case GGML_TYPE_Q4_0:
748
- ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
749
- break;
750
- case GGML_TYPE_Q8_0:
751
- ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
752
- break;
753
- default:
754
- break;
755
- }
756
- }
757
-
758
- /**
759
- * @brief Check if transformation is needed for a given tensor type.
760
- *
761
- * This function checks if transformation is needed for a given tensor type
762
- * to prepare data for CANN processing.
763
- *
764
- * @param type The tensor type to check.
765
- * @return true if transformation is needed, false otherwise.
766
- */
767
- static bool need_transform(ggml_type type) {
768
- switch (type) {
769
- case GGML_TYPE_Q4_0:
770
- case GGML_TYPE_Q8_0:
771
- return true;
772
- default:
773
- return false;
774
- }
775
- }
776
-
777
- /**
778
- * @brief Initialize a tensor using data from a CANN buffer.
779
- *
780
- * This function initializes a tensor using data from a CANN buffer.
781
- * It handles special cases such as views and quantization.
782
- *
783
- * @param buffer The CANN buffer from which to initialize the tensor.
784
- * @param tensor Pointer to the tensor to be initialized.
785
- */
786
- static void ggml_backend_cann_buffer_init_tensor(
787
- ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
788
- if (tensor->view_src != NULL && tensor->view_offs == 0) {
789
- GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
790
- return;
791
- }
792
-
793
- // TODO: the CANN backend doesn't support quantized tensors yet. Just leave
794
- // the code here.
795
- if (ggml_is_quantized(tensor->type)) {
796
- // Initialize padding to 0 to avoid possible NaN values
797
- size_t original_size = ggml_nbytes(tensor);
798
- size_t padded_size =
799
- ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
800
-
801
- if (padded_size > original_size && tensor->view_src == nullptr) {
802
- size_t memset_size = padded_size - original_size;
803
- ACL_CHECK(aclrtMemset((char*)tensor->data + original_size,
804
- memset_size, 0, memset_size));
805
- }
806
- }
807
- }
808
-
809
- // TODO: need to handle tensors which have padding.
810
- /**
811
- * @brief Set tensor data in a CANN buffer.
812
- *
813
- * This function sets tensor data in a CANN buffer, handling transformations
814
- * if needed based on the tensor's type.
815
- *
816
- * @param buffer The CANN buffer where the tensor data will be set.
817
- * @param tensor Pointer to the tensor whose data will be set.
818
- * @param data Pointer to the source data to be copied into the tensor.
819
- * @param offset Offset in the source data from where to start copying.
820
- * @param size Size of the data to be copied, in bytes.
821
- */
822
- static void ggml_backend_cann_buffer_set_tensor(
823
- ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data,
824
- size_t offset, size_t size) {
825
- ggml_backend_cann_buffer_context *ctx =
826
- (ggml_backend_cann_buffer_context *)buffer->context;
827
-
828
- ggml_cann_set_device(ctx->device);
829
- // TODO: refer to cann(#6017); it uses the thread's default stream.
830
- // For acl, synchronous functions use this default stream.
831
- // Why aclrtSynchronizeDevice?
832
-
833
- if (!need_transform(tensor->type)) {
834
- ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
835
- ACL_MEMCPY_HOST_TO_DEVICE));
836
- } else {
837
- void *transform_buffer = malloc(size);
838
- ggml_backend_cann_transform(tensor, data, transform_buffer);
839
-
840
- ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
841
- transform_buffer, size,
842
- ACL_MEMCPY_HOST_TO_DEVICE));
843
- free(transform_buffer);
844
- }
845
- }
846
-
847
- /**
848
- * @brief Get tensor data from a CANN buffer.
849
- *
850
- * This function retrieves tensor data from a CANN buffer, handling
851
- * transformations if needed based on the tensor's type.
852
- *
853
- * @param buffer The CANN buffer from which to retrieve tensor data.
854
- * @param tensor Pointer to the tensor whose data will be retrieved.
855
- * @param data Pointer to the destination buffer where the tensor data will be
856
- * copied.
857
- * @param offset Offset in the destination buffer where to start copying.
858
- * @param size Size of the data to be copied, in bytes.
859
- */
860
- static void ggml_backend_cann_buffer_get_tensor(
861
- ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
862
- size_t offset, size_t size) {
863
- ggml_backend_cann_buffer_context* ctx =
864
- (ggml_backend_cann_buffer_context*)buffer->context;
865
-
866
- ggml_cann_set_device(ctx->device);
867
-
868
- if (!need_transform(tensor->type)) {
869
- ACL_CHECK(aclrtMemcpy(data, size, (char*)tensor->data + offset, size,
870
- ACL_MEMCPY_DEVICE_TO_HOST));
871
- } else {
872
- void* transform_buffer = malloc(size);
873
- ACL_CHECK(aclrtMemcpy(transform_buffer, size,
874
- (char*)tensor->data + offset, size,
875
- ACL_MEMCPY_DEVICE_TO_HOST));
876
- ggml_backend_cann_transform_back(tensor, transform_buffer, data);
877
- free(transform_buffer);
878
- }
879
- }
880
-
881
- /**
882
- * @brief Copy tensor data between CANN buffers if possible.
883
- *
884
- * This function copies tensor data between CANN buffers if the source and
885
- * destination buffers are CANN buffers and they meet the necessary conditions
886
- * (same device or devices can access each other).
887
- *
888
- * @param buffer The destination CANN buffer where the tensor data will be
889
- * copied.
890
- * @param src Pointer to the source tensor whose data will be copied.
891
- * @param dst Pointer to the destination tensor where the data will be copied.
892
- * @return true if the copy operation succeeded, false otherwise.
893
- */
894
- static bool ggml_backend_cann_buffer_cpy_tensor(
895
- ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
896
- if (ggml_backend_buffer_is_cann(src->buffer)) {
897
- ggml_backend_cann_buffer_context* src_ctx =
898
- (ggml_backend_cann_buffer_context*)src->buffer->context;
899
- ggml_backend_cann_buffer_context* dst_ctx =
900
- (ggml_backend_cann_buffer_context*)buffer->context;
901
-
902
- size_t memcpy_size = ggml_nbytes(src);
903
- // Same device.
904
- if (src_ctx->device == dst_ctx->device) {
905
- ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
906
- (const char*)src->data, memcpy_size,
907
- ACL_MEMCPY_DEVICE_TO_DEVICE));
908
- return true;
909
- } else {
910
- // Different device but can access by peer.
911
- int32_t canAccessPeer = 0;
912
- ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
913
- dst_ctx->device));
914
- if (canAccessPeer) {
915
- ggml_cann_set_device(src_ctx->device);
916
- ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
917
- ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
918
- (const char*)src->data, memcpy_size,
919
- ACL_MEMCPY_DEVICE_TO_DEVICE));
920
- return true;
921
- }
922
- }
923
- }
924
- return false;
925
- }
926
-
927
- /**
928
- * @brief Clear a CANN buffer by setting all its memory to a specified value.
929
- *
930
- * This function clears a CANN buffer by setting all its memory to a specified
931
- * value.
932
- *
933
- * @param buffer The CANN buffer to be cleared.
934
- * @param value The value to which each byte in the buffer will be set.
935
- */
936
- static void ggml_backend_cann_buffer_clear(
937
- ggml_backend_buffer_t buffer, uint8_t value) {
938
- ggml_backend_cann_buffer_context* ctx =
939
- (ggml_backend_cann_buffer_context*)buffer->context;
940
-
941
- ggml_cann_set_device(ctx->device);
942
- ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
943
- }
944
-
945
- /**
946
- * @brief Interface for a CANN buffer in the backend.
947
- *
948
- * This structure defines function pointers to operations that can be performed
949
- * on a CANN buffer within the backend.
950
- */
951
- static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
952
- /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
953
- /* .get_base = */ ggml_backend_cann_buffer_get_base,
954
- /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor,
955
- /* .memset_tensor = */ NULL,
956
- /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor,
957
- /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor,
958
- /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor,
959
- /* .clear = */ ggml_backend_cann_buffer_clear,
960
- /* .reset = */ NULL,
961
- };
962
-
963
- // cann buffer type
964
- /**
965
- * @brief Structure representing context information for a specific backend
966
- * buffer type.
967
- */
968
- struct ggml_backend_cann_buffer_type_context {
969
- int32_t
970
- device; /**< Device identifier associated with the buffer context. */
971
- std::string name; /**< Name associated with the buffer context. */
972
- };
973
-
974
- /**
975
- * @brief Retrieves the name associated with a CANN buffer type.
976
- *
977
- * This function returns the descriptive name associated with the specified
978
- * CANN buffer type context.
979
- *
980
- * @param buft Pointer to the buffer type context.
981
- * @return Const pointer to the C-style string containing the name.
982
- */
983
- static const char* ggml_backend_cann_buffer_type_name(
984
- ggml_backend_buffer_type_t buft) {
985
- ggml_backend_cann_buffer_type_context* buft_ctx =
986
- (ggml_backend_cann_buffer_type_context*)buft->context;
987
-
988
- return buft_ctx->name.c_str();
989
- }
990
-
991
- /**
992
- * @brief Allocates a new CANN buffer of the specified type and size.
993
- *
994
- * This function allocates a new CANN buffer on the specified device with the
995
- * given size.
996
- *
997
- * @param buft Pointer to the buffer type context.
998
- * @param size Size in bytes of the buffer to allocate.
999
- * @return Pointer to the allocated buffer, or nullptr if allocation fails.
1000
- */
1001
- static ggml_backend_buffer_t
1002
- ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1003
- size_t size) {
1004
- ggml_backend_cann_buffer_type_context* buft_ctx =
1005
- (ggml_backend_cann_buffer_type_context*)buft->context;
1006
-
1007
- ggml_cann_set_device(buft_ctx->device);
1008
-
1009
- size = std::max(size, (size_t)1);
1010
-
1011
- void* dev_ptr;
1012
- aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
1013
- if (err != ACL_SUCCESS) {
1014
- GGML_LOG_ERROR(
1015
- "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
1016
- __func__, size / 1024.0 / 1024.0, buft_ctx->device,
1017
- aclGetRecentErrMsg());
1018
- return nullptr;
1019
- }
1020
-
1021
- ggml_backend_cann_buffer_context* ctx =
1022
- new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
1023
-
1024
- return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
1025
- ctx, size);
1026
- }
1027
-
1028
- /**
1029
- * @brief Retrieves the memory alignment requirement for CANN buffers of this
1030
- * type.
1031
- *
1032
- * This function returns the alignment requirement in bytes for memory allocated
1033
- * by the CANN buffer type.
1034
- *
1035
- * @param buft Pointer to the buffer type context (unused in this
1036
- * implementation).
1037
- * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
1038
- * buffers).
1039
- */
1040
- static size_t ggml_backend_cann_buffer_type_get_alignment(
1041
- ggml_backend_buffer_type_t buft) {
1042
- return 128;
1043
-
1044
- GGML_UNUSED(buft);
1045
- }
1046
-
1047
- /**
1048
- * @brief Calculates the allocation size required for a tensor in a CANN buffer.
1049
- *
1050
- * Computes the total allocation size needed for storing the tensor's data in a
1051
- * CANN buffer, considering any necessary padding or adjustments for quantized
1052
- * types.
1053
- *
1054
- * @param buft Pointer to the buffer type context (unused in this
1055
- * implementation).
1056
- * @param tensor Pointer to the tensor for which the allocation size is
1057
- * calculated.
1058
- * @return The total allocation size in bytes required for the tensor in the
1059
- * CANN buffer.
1060
- */
1061
- static size_t ggml_backend_cann_buffer_type_get_alloc_size(
1062
- ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
1063
- size_t size = ggml_nbytes(tensor);
1064
- int64_t ne0 = tensor->ne[0];
1065
-
1066
- // the last line must be bigger than 32 bytes, because every single op
1067
- // deals with at least 32 bytes.
1068
- // TODO: quantized type?
1069
- // int64_t line_size = ne0 * ggml_element_size(tensor);
1070
- // int64_t line_size_align_32 = (line_size + 31) & ~31;
1071
- // size += (line_size_align_32 - line_size);
1072
-
1073
- // TODO: quantized types are not supported yet.
1074
- // TODO: consider non-contiguous tensors.
1075
- if (ggml_is_quantized(tensor->type)) {
1076
- if (ne0 % MATRIX_ROW_PADDING != 0) {
1077
- size += ggml_row_size(
1078
- tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1079
- }
1080
- }
1081
-
1082
- return size;
1083
-
1084
- GGML_UNUSED(buft);
1085
- }
1086
-
1087
- static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
1088
- return false;
1089
-
1090
- GGML_UNUSED(buft);
1091
- }
1092
-
1093
- /**
1094
- * @brief Interface for managing CANN buffer types in the GGML backend.
1095
- *
1096
- * Provides function pointers for allocating, querying properties, and managing
1097
- * memory for CANN buffer types in the GGML backend.
1098
- */
1099
- static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
1100
- /* .get_name = */ ggml_backend_cann_buffer_type_name,
1101
- /* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer,
1102
- /* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment,
1103
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1104
- /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
1105
- /* .is_host = */ ggml_backend_cann_buffer_type_is_host,
1106
- };
1107
-
1108
- /**
1109
- * @brief Retrieves the CANN buffer type for a specified device.
1110
- *
1111
- * This function initializes and returns the buffer type interface associated
1112
- * with the given device. It ensures thread-safe access using a mutex.
1113
- *
1114
- * @param device The device index for which to retrieve the buffer type.
1115
- * @return A pointer to the buffer type interface for the specified device, or
1116
- * nullptr if the device index is out of range.
1117
- */
1118
- ggml_backend_buffer_type_t
1119
- ggml_backend_cann_buffer_type(int32_t device) {
1120
- static std::mutex mutex;
1121
- std::lock_guard<std::mutex> lock(mutex);
1122
-
1123
- if (device >= ggml_backend_cann_get_device_count()) {
1124
- return nullptr;
1125
- }
1126
-
1127
- static ggml_backend_buffer_type
1128
- ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
1129
-
1130
- static bool ggml_backend_cann_buffer_type_initialized = false;
1131
-
1132
- if (!ggml_backend_cann_buffer_type_initialized) {
1133
- for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
1134
- ggml_backend_cann_buffer_types[i] = {
1135
- /* .iface = */ ggml_backend_cann_buffer_type_interface,
1136
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
1137
- /* .context = */
1138
- new ggml_backend_cann_buffer_type_context{
1139
- i, "CANN" + std::to_string(i)},
1140
- };
1141
- }
1142
- ggml_backend_cann_buffer_type_initialized = true;
1143
- }
1144
-
1145
- return &ggml_backend_cann_buffer_types[device];
1146
- }
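A hedged sketch of how a caller would use the per-device buffer type returned above together with the generic ggml-backend helpers; the device index and buffer size are illustrative, and the header names assume the usual ggml include layout:

    #include "ggml-backend.h"
    #include "ggml-cann.h"     // declares ggml_backend_cann_buffer_type()

    // minimal sketch, assuming CANN device 0 exists
    void example_alloc_on_cann(void) {
        ggml_backend_buffer_type_t buft = ggml_backend_cann_buffer_type(0);
        if (buft != NULL) {
            ggml_backend_buffer_t buf =
                ggml_backend_buft_alloc_buffer(buft, 16 * 1024 * 1024);
            // ... place tensors into buf (e.g. via ggml_tallocr) and run a graph ...
            ggml_backend_buffer_free(buf);
        }
    }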
1147
-
1148
- /**
1149
- * @brief Retrieves the name associated with a CANN host buffer type.
1150
- *
1151
- * This function returns the descriptive name associated with the specified
1152
- * CANN host buffer type context.
1153
- *
1154
- * @param buft Pointer to the host buffer type context.
1155
- * @return Const pointer to the C-style string containing the name.
1156
- */
1157
- static const char * ggml_backend_cann_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
1158
- return "CANN_Host";
1159
-
1160
- GGML_UNUSED(buft);
1161
- }
1162
-
1163
- /**
1164
- * @brief Retrieves the name associated with a CANN host buffer.
1165
- *
1166
- * This function returns the descriptive name associated with the specified
1167
- * CANN host buffer context.
1168
- *
1169
- * @param buffer Pointer to the host buffer.
1170
- * @return Const pointer to the C-style string containing the name.
1171
- */
1172
- static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buffer) {
1173
- return "CANN_Host";
1174
-
1175
- GGML_UNUSED(buffer);
1176
- }
1177
-
1178
- /**
1179
- * @brief Free resources associated with a CANN host buffer.
1180
- *
1181
- * This function frees the resources associated with a CANN host buffer, including
1182
- * its context.
1183
- *
1184
- * @param buffer The CANN host buffer to free.
1185
- */
1186
- static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
1187
- ACL_CHECK(aclrtFreeHost(buffer->context));
1188
- }
1189
-
1190
- /**
1191
- * @brief Allocates a new CANN host buffer of the specified size.
1192
- *
1193
- * This function allocates a new CANN host buffer with the given size.
1194
- * @param size Size in bytes of the host buffer to allocate.
1195
- * @return Pointer to the allocated host buffer, or nullptr if allocation fails.
1196
- */
1197
- static void * ggml_cann_host_malloc(size_t size) {
1198
- if (getenv("GGML_CANN_NO_PINNED") != nullptr) {
1199
- return nullptr;
1200
- }
1201
-
1202
- void * hostPtr = nullptr;
1203
- aclError err = aclrtMallocHost((void **) &hostPtr, size);
1204
- if (err != ACL_SUCCESS) {
1205
-
1206
- GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
1207
- size / 1024.0 / 1024.0, aclGetRecentErrMsg());
1208
- return nullptr;
1209
- }
1210
- return hostPtr;
1211
- }
1212
-
1213
- /**
1214
- * @brief Allocates a new CANN host buffer of the specified type and size.
1215
- *
1216
- * @param buft Pointer to the host buffer type context.
1217
- * @param size Size in bytes of the host buffer to allocate.
1218
- * @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails.
1219
- */
1220
- static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1221
- void * hostPtr = ggml_cann_host_malloc(size);
1222
-
1223
- if (hostPtr == nullptr) {
1224
- // fallback to cpu buffer
1225
- return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
1226
- }
1227
-
1228
- ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);
1229
- buffer->buft = buft;
1230
- buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;
1231
-
1232
- return buffer;
1233
- }
1234
-
1235
- /**
1236
- * @brief Interface for managing CANN host buffer types in the GGML backend.
1237
- *
1238
- * Provides function pointers for allocating, querying properties, and managing
1239
- * memory for CANN buffer types in the GGML backend.
1240
- */
1241
- ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
1242
- static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = {
1243
- /* .iface = */ {
1244
- /* .get_name = */ ggml_backend_cann_host_buffer_type_name,
1245
- /* .alloc_buffer = */ ggml_backend_cann_host_buffer_type_alloc_buffer,
1246
- /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
1247
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1248
- /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
1249
- /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
1250
- },
1251
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
1252
- /* .context = */ nullptr,
1253
- };
1254
-
1255
- return &ggml_backend_cann_buffer_type_host;
1256
- }
1257
-
1258
- /**
1259
- * @brief Computes the forward operation for a given tensor using CANN
1260
- * operations.
1261
- *
1262
- * This function selects the appropriate CANN operation based on the type of
1263
- * operation specified in the tensor and performs the computation.
1264
- *
1265
- * @param ctx The CANN context containing necessary resources and
1266
- * configurations.
1267
- * @param dst The destination tensor where the result of the computation will be
1268
- * stored.
1269
- * @return true if the computation was successful; false otherwise.
1270
- */
1271
- static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1272
- struct ggml_tensor* dst) {
1273
- switch (dst->op) {
1274
- case GGML_OP_REPEAT:
1275
- ggml_cann_repeat(ctx, dst);
1276
- break;
1277
- case GGML_OP_GET_ROWS:
1278
- ggml_cann_get_rows(ctx, dst);
1279
- break;
1280
- case GGML_OP_DUP:
1281
- ggml_cann_dup(ctx, dst);
1282
- break;
1283
- case GGML_OP_ADD:
1284
- ggml_cann_add(ctx, dst);
1285
- break;
1286
- case GGML_OP_ACC:
1287
- ggml_cann_acc(ctx, dst);
1288
- break;
1289
- case GGML_OP_MUL:
1290
- ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
1291
- break;
1292
- case GGML_OP_DIV:
1293
- ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
1294
- break;
1295
- case GGML_OP_UNARY:
1296
- switch (ggml_get_unary_op(dst)) {
1297
- case GGML_UNARY_OP_GELU:
1298
- ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
1299
- ctx, dst);
1300
- break;
1301
- case GGML_UNARY_OP_SILU:
1302
- ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(
1303
- ctx, dst);
1304
- break;
1305
- // TODO: Use faster gelu??
1306
- case GGML_UNARY_OP_GELU_QUICK:
1307
- ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
1308
- ctx, dst);
1309
- break;
1310
- case GGML_UNARY_OP_TANH:
1311
- ggml_cann_activation<aclnnTanhGetWorkspaceSize, aclnnTanh>(
1312
- ctx, dst);
1313
- break;
1314
- case GGML_UNARY_OP_RELU:
1315
- ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(
1316
- ctx, dst);
1317
- break;
1318
- case GGML_UNARY_OP_HARDSIGMOID:
1319
- ggml_cann_activation<aclnnHardsigmoidGetWorkspaceSize,
1320
- aclnnHardsigmoid>(ctx, dst);
1321
- break;
1322
- case GGML_UNARY_OP_HARDSWISH:
1323
- ggml_cann_activation<aclnnHardswishGetWorkspaceSize,
1324
- aclnnHardswish>(ctx, dst);
1325
- break;
1326
- default:
1327
- return false;
1328
- }
1329
- break;
1330
- case GGML_OP_NORM:
1331
- ggml_cann_norm(ctx, dst);
1332
- break;
1333
- case GGML_OP_GROUP_NORM:
1334
- ggml_cann_group_norm(ctx, dst);
1335
- break;
1336
- case GGML_OP_CONCAT:
1337
- ggml_cann_concat(ctx, dst);
1338
- break;
1339
- case GGML_OP_UPSCALE:
1340
- ggml_cann_upsample_nearest2d(ctx, dst);
1341
- break;
1342
- case GGML_OP_PAD:
1343
- ggml_cann_pad(ctx, dst);
1344
- break;
1345
- case GGML_OP_ARANGE:
1346
- ggml_cann_arange(ctx, dst);
1347
- break;
1348
- case GGML_OP_TIMESTEP_EMBEDDING:
1349
- ggml_cann_timestep_embedding(ctx, dst);
1350
- break;
1351
- case GGML_OP_LEAKY_RELU:
1352
- ggml_cann_leaky_relu(ctx, dst);
1353
- break;
1354
- case GGML_OP_RMS_NORM:
1355
- ggml_cann_rms_norm(ctx, dst);
1356
- break;
1357
- case GGML_OP_MUL_MAT:
1358
- ggml_cann_mul_mat(ctx, dst);
1359
- break;
1360
- case GGML_OP_MUL_MAT_ID:
1361
- return false;
1362
- case GGML_OP_SCALE:
1363
- ggml_cann_scale(ctx, dst);
1364
- break;
1365
- case GGML_OP_SQR:
1366
- ggml_cann_sqr(ctx, dst);
1367
- break;
1368
- case GGML_OP_CLAMP:
1369
- ggml_cann_clamp(ctx, dst);
1370
- break;
1371
- case GGML_OP_CPY:
1372
- ggml_cann_cpy(ctx, dst);
1373
- break;
1374
- case GGML_OP_CONT:
1375
- ggml_cann_dup(ctx, dst);
1376
- break;
1377
- case GGML_OP_NONE:
1378
- case GGML_OP_RESHAPE:
1379
- case GGML_OP_VIEW:
1380
- case GGML_OP_PERMUTE:
1381
- case GGML_OP_TRANSPOSE:
1382
- break;
1383
- case GGML_OP_DIAG_MASK_INF:
1384
- ggml_cann_diag_mask(ctx, dst, -INFINITY);
1385
- break;
1386
- case GGML_OP_SOFT_MAX:
1387
- ggml_cann_softmax(ctx, dst);
1388
- break;
1389
- case GGML_OP_ROPE:
1390
- ggml_cann_rope(ctx, dst);
1391
- break;
1392
- case GGML_OP_IM2COL:
1393
- ggml_cann_im2col(ctx, dst);
1394
- break;
1395
- case GGML_OP_POOL_2D:
1396
- ggml_cann_pool2d(ctx, dst);
1397
- break;
1398
- case GGML_OP_SUM_ROWS:
1399
- ggml_cann_sum_rows(ctx, dst);
1400
- break;
1401
- case GGML_OP_ARGSORT:
1402
- ggml_cann_argsort(ctx, dst);
1403
- break;
1404
- default:
1405
- return false;
1406
- }
1407
-
1408
- return true;
1409
- }
1410
-
1411
- // backend
1412
- /**
1413
- * @brief Retrieves the name associated with the CANN backend.
1414
- *
1415
- * This function returns the name assigned to the CANN backend, which is stored
1416
- * in the context of the provided backend structure.
1417
- *
1418
- * @param backend Pointer to the CANN backend structure.
1419
- * @return A pointer to a constant string representing the backend name.
1420
- */
1421
- static const char* ggml_backend_cann_name(ggml_backend_t backend) {
1422
- ggml_backend_cann_context* cann_ctx =
1423
- (ggml_backend_cann_context*)backend->context;
1424
-
1425
- return cann_ctx->name.c_str();
1426
- }
1427
-
1428
- /**
1429
- * @brief Frees resources associated with the CANN backend.
1430
- *
1431
- * This function releases resources associated with the CANN backend context
1432
- * and resets the device associated with the backend to its initial state.
1433
- *
1434
- * @param backend Pointer to the CANN backend structure to be freed.
1435
- */
1436
- static void ggml_backend_cann_free(ggml_backend_t backend) {
1437
- ggml_backend_cann_context* cann_ctx =
1438
- (ggml_backend_cann_context*)backend->context;
1439
- ACL_CHECK(aclrtSynchronizeDevice());
1440
- ACL_CHECK(aclrtResetDevice(cann_ctx->device));
1441
-
1442
- // finalize when the last backend is freed.
1443
- if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
1444
- ACL_CHECK(aclFinalize());
1445
- }
1446
-
1447
- delete cann_ctx;
1448
- delete backend;
1449
- }
1450
-
1451
- /**
1452
- * @brief Sets tensor data asynchronously in the CANN backend.
1453
- *
1454
- * This function asynchronously sets tensor data in the CANN backend. Depending
1455
- * on the tensor type, it may perform data transformations before copying data
1456
- * to the device.
1457
- *
1458
- * @param backend Pointer to the CANN backend structure.
1459
- * @param tensor Pointer to the tensor structure to set data for.
1460
- * @param data Pointer to the host data to copy to the tensor.
1461
- * @param offset Offset in bytes within the host data.
1462
- * @param size Size of the data to copy in bytes.
1463
- */
1464
- static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
1465
- ggml_tensor *tensor,
1466
- const void *data,
1467
- size_t offset,
1468
- size_t size) {
1469
- ggml_backend_cann_context *cann_ctx =
1470
- (ggml_backend_cann_context *)backend->context;
1471
-
1472
- if (!need_transform(tensor->type)) {
1473
- ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
1474
- size, ACL_MEMCPY_HOST_TO_DEVICE,
1475
- cann_ctx->stream()));
1476
- } else {
1477
- void *transform_buffer = malloc(size);
1478
- ggml_backend_cann_transform(tensor, data, transform_buffer);
1479
-
1480
- ACL_CHECK(aclrtMemcpyAsync(
1481
- (char *)tensor->data + offset, size, transform_buffer, size,
1482
- ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
1483
- ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1484
- free(transform_buffer);
1485
- }
1486
- }
1487
-
1488
- static void ggml_backend_cann_get_tensor_async(
1489
- ggml_backend_t backend, const ggml_tensor *tensor, void *data,
1490
- size_t offset, size_t size) {
1491
- ggml_backend_cann_context *cann_ctx =
1492
- (ggml_backend_cann_context *)backend->context;
1493
- ggml_backend_buffer_t buf =
1494
- tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
1495
-
1496
- GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
1497
- "unsupported buffer type");
1498
-
1499
- if (!need_transform(tensor->type)) {
1500
- ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
1501
- size, ACL_MEMCPY_DEVICE_TO_HOST,
1502
- cann_ctx->stream()));
1503
- } else {
1504
- void *transform_buffer = malloc(size);
1505
- ACL_CHECK(aclrtMemcpyAsync(
1506
- transform_buffer, size, (char *)tensor->data + offset, size,
1507
- ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
1508
- ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1509
- ggml_backend_cann_transform_back(tensor, transform_buffer, data);
1510
- free(transform_buffer);
1511
- }
1512
- }
1513
-
1514
- /**
1515
- * @brief Asynchronously copies tensor data between CANN backends.
1516
- *
1517
- * This function copies tensor data asynchronously between two CANN backends. It
1518
- * checks if both tensors reside in CANN buffers and whether the devices support
1519
- * peer-to-peer access for direct copying. If not, it returns false.
1520
- *
1521
- * @param backend_src Pointer to the source CANN backend structure.
1522
- * @param backend_dst Pointer to the destination CANN backend structure.
1523
- * @param src Pointer to the source tensor to copy data from.
1524
- * @param dst Pointer to the destination tensor to copy data to.
1525
- * @return true if the copy operation succeeds, false otherwise.
1526
- */
1527
- static bool ggml_backend_cann_cpy_tensor_async(
1528
- ggml_backend_t backend_src, ggml_backend_t backend_dst,
1529
- const ggml_tensor* src, ggml_tensor* dst) {
1530
- GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
1531
- ggml_backend_is_cann(backend_dst));
1532
-
1533
- if (!ggml_backend_buffer_is_cann(src->buffer) ||
1534
- !ggml_backend_buffer_is_cann(dst->buffer)) {
1535
- return false;
1536
- }
1537
-
1538
- ggml_backend_buffer_t buf_src =
1539
- src->view_src ? src->view_src->buffer : src->buffer;
1540
- ggml_backend_buffer_t buf_dst =
1541
- dst->view_src ? dst->view_src->buffer : dst->buffer;
1542
-
1543
- ggml_backend_cann_context* cann_ctx_src =
1544
- (ggml_backend_cann_context*)backend_src->context;
1545
- ggml_backend_cann_context* cann_ctx_dst =
1546
- (ggml_backend_cann_context*)backend_dst->context;
1547
-
1548
- size_t copy_size = ggml_nbytes(dst);
1549
- if (backend_src != backend_dst) {
1550
- ggml_backend_cann_buffer_context* buf_ctx_src =
1551
- (ggml_backend_cann_buffer_context*)buf_src->context;
1552
- ggml_backend_cann_buffer_context* buf_ctx_dst =
1553
- (ggml_backend_cann_buffer_context*)buf_dst->context;
1554
-
1555
- GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
1556
- GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);
1557
-
1558
- int32_t canAccessPeer = 0;
1559
- ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device,
1560
- cann_ctx_dst->device));
1561
- if (!canAccessPeer) {
1562
- return false;
1563
- }
1564
-
1565
- // need to open both directions for aclrtMemcpyAsync between devices.
1566
- ggml_cann_set_device(cann_ctx_dst->device);
1567
- ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
1568
- ggml_cann_set_device(cann_ctx_src->device);
1569
- ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
1570
-
1571
- ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
1572
- ACL_MEMCPY_DEVICE_TO_DEVICE,
1573
- cann_ctx_src->stream()));
1574
-
1575
- // TODO: workaround: events did not work here, so synchronize the stream instead.
1576
- aclrtSynchronizeStream(cann_ctx_src->stream());
1577
- } else {
1578
- // src and dst are on the same backend
1579
- ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
1580
- ACL_MEMCPY_DEVICE_TO_DEVICE,
1581
- cann_ctx_dst->stream()));
1582
- }
1583
-
1584
- return true;
1585
- }
1586
-
1587
- /**
1588
- * @brief Synchronizes a CANN backend.
1589
- *
1590
- * This function synchronizes the specified CANN backend by waiting for all
1591
- * operations in its associated stream to complete.
1592
- *
1593
- * @param backend Pointer to the CANN backend structure to synchronize.
1594
- */
1595
- static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
1596
- ggml_backend_cann_context* cann_ctx =
1597
- (ggml_backend_cann_context*)backend->context;
1598
-
1599
- ggml_cann_set_device(cann_ctx->device);
1600
-
1601
- ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1602
- }
1603
-
1604
- /**
1605
- * @brief Computes a computational graph using a CANN backend.
1606
- *
1607
- * This function computes the operations defined in the computational graph
1608
- * using the specified CANN backend.
1609
- *
1610
- * @param backend Pointer to the CANN backend structure to use for computation.
1611
- * @param cgraph Pointer to the computational graph structure containing nodes
1612
- * representing operations to be computed.
1613
- * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
1614
- * completes successfully, otherwise an appropriate error status.
1615
- */
1616
- static enum ggml_status ggml_backend_cann_graph_compute(
1617
- ggml_backend_t backend, ggml_cgraph* cgraph) {
1618
- ggml_backend_cann_context* cann_ctx =
1619
- (ggml_backend_cann_context*)backend->context;
1620
-
1621
- ggml_cann_set_device(cann_ctx->device);
1622
-
1623
- for (int i = 0; i < cgraph->n_nodes; i++) {
1624
- ggml_tensor* node = cgraph->nodes[i];
1625
-
1626
- if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
1627
- continue;
1628
- }
1629
-
1630
- bool ok = ggml_cann_compute_forward(*cann_ctx, node);
1631
-
1632
- if (!ok) {
1633
- GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
1634
- node->name, ggml_op_name(node->op));
1635
- }
1636
- GGML_ASSERT(ok);
1637
- }
1638
-
1639
- return GGML_STATUS_SUCCESS;
1640
- }
1641
-
1642
- /**
1643
- * @brief Checks if the CANN backend supports a specific operation.
1644
- *
1645
- * This function checks whether the specified operation is supported by the
1646
- * CANN backend.
1647
- *
1648
- * @param dev Pointer to the CANN backend device to check support for
1649
- * the operation.
1650
- * @param op Pointer to the tensor representing the operation to check.
1651
- * @return bool Returns true if the operation is supported by the backend,
1652
- * otherwise false.
1653
- */
1654
- static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
1655
- const ggml_tensor* op) {
1656
- switch (op->op) {
1657
- case GGML_OP_UNARY:
1658
- switch (ggml_get_unary_op(op)) {
1659
- case GGML_UNARY_OP_GELU:
1660
- case GGML_UNARY_OP_SILU:
1661
- case GGML_UNARY_OP_RELU:
1662
- case GGML_UNARY_OP_HARDSIGMOID:
1663
- case GGML_UNARY_OP_HARDSWISH:
1664
- case GGML_UNARY_OP_GELU_QUICK:
1665
- case GGML_UNARY_OP_TANH:
1666
- return true;
1667
- default:
1668
- return false;
1669
- }
1670
- case GGML_OP_MUL_MAT: {
1671
- switch (op->src[0]->type) {
1672
- case GGML_TYPE_F16:
1673
- case GGML_TYPE_F32:
1674
- case GGML_TYPE_Q8_0:
1675
- // TODO: fix me
1676
- // Current groupsize should not be greater than k-1 in
1677
- // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize().
1678
- case GGML_TYPE_Q4_0:
1679
- return true;
1680
- default:
1681
- return false;
1682
- }
1683
- }
1684
- case GGML_OP_MUL_MAT_ID:
1685
- return false;
1686
- // embedding
1687
- case GGML_OP_GET_ROWS: {
1688
- switch (op->src[0]->type) {
1689
- case GGML_TYPE_F32:
1690
- case GGML_TYPE_F16:
1691
- case GGML_TYPE_Q4_0:
1692
- case GGML_TYPE_Q8_0:
1693
- return true;
1694
- default:
1695
- return false;
1696
- }
1697
- } break;
1698
- case GGML_OP_CPY: {
1699
- switch (op->type) {
1700
- case GGML_TYPE_F32:
1701
- case GGML_TYPE_F16:
1702
- case GGML_TYPE_Q8_0:
1703
- case GGML_TYPE_Q4_0:
1704
- return true;
1705
- default:
1706
- return false;
1707
- }
1708
- }
1709
- case GGML_OP_DUP:
1710
- case GGML_OP_REPEAT:
1711
- case GGML_OP_CONCAT:
1712
- case GGML_OP_NONE:
1713
- case GGML_OP_RESHAPE:
1714
- case GGML_OP_VIEW:
1715
- case GGML_OP_PERMUTE:
1716
- case GGML_OP_TRANSPOSE:
1717
- case GGML_OP_NORM:
1718
- case GGML_OP_ADD:
1719
- case GGML_OP_MUL:
1720
- case GGML_OP_DIV:
1721
- case GGML_OP_RMS_NORM:
1722
- case GGML_OP_SCALE:
1723
- case GGML_OP_SQR:
1724
- case GGML_OP_CLAMP:
1725
- case GGML_OP_CONT:
1726
- case GGML_OP_DIAG_MASK_INF:
1727
- case GGML_OP_SOFT_MAX:
1728
- case GGML_OP_ROPE:
1729
- case GGML_OP_IM2COL:
1730
- case GGML_OP_POOL_2D:
1731
- case GGML_OP_SUM_ROWS:
1732
- case GGML_OP_ARGSORT:
1733
- case GGML_OP_ACC:
1734
- case GGML_OP_GROUP_NORM:
1735
- case GGML_OP_UPSCALE:
1736
- case GGML_OP_PAD:
1737
- case GGML_OP_ARANGE:
1738
- case GGML_OP_TIMESTEP_EMBEDDING:
1739
- case GGML_OP_LEAKY_RELU:
1740
- return true;
1741
- default:
1742
- return false;
1743
- }
1744
-
1745
- GGML_UNUSED(dev);
1746
- }
1747
-
1748
- /**
1749
- * @brief Checks if the backend buffer type is associated with the CANN backend.
1750
- *
1751
- * This function checks whether the provided backend buffer type is associated
1752
- * with the CANN backend based on the comparison of its name retrieval function
1753
- * pointer.
1754
- *
1755
- * @param buft Pointer to the backend buffer type to check.
1756
- * @return bool Returns true if the buffer type is associated with the CANN
1757
- * backend, otherwise false.
1758
- */
1759
- static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
1760
- return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
1761
- }
1762
-
1763
- /**
1764
- * @brief Determines if a tensor operation should be offloaded to the CANN
1765
- * backend.
1766
- *
1767
- * This function checks if a given tensor operation should be offloaded to the
1768
- * CANN backend based on the operation type and the size of the tensor. It
1769
- * returns true if the second dimension (ne[1]) of the tensor is greater than or
1770
- * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
1771
- *
1772
- * @param dev Pointer to the CANN backend device.
1773
- * @param op Pointer to the tensor operation to check.
1774
- * @return bool Returns true if the operation should be offloaded, otherwise
1775
- * false.
1776
- */
1777
- static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
1778
- const ggml_tensor* op) {
1779
- const int min_batch_size = 32;
1780
- GGML_UNUSED(dev);
1781
-
1782
- return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
1783
- }
1784
-
1785
- /**
1786
- * @brief Records an event on the CANN backend stream.
1787
- *
1788
- * This function records the given event on the ACL runtime stream associated
1789
- * with the backend context.
1790
- *
1791
- * @param event Pointer to the event structure to be recorded.
1792
- */
1793
- static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
1794
- ggml_backend_cann_context* cann_ctx =
1795
- (ggml_backend_cann_context*)backend->context;
1796
- ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
1797
- }
1798
-
1799
- /**
1800
- * @brief Waits for a recorded event to complete on the CANN backend stream.
1801
- *
1802
- * This function makes the given backend wait for the event to complete on its
1803
- * ACL runtime stream.
1804
- *
1805
- * @param backend Pointer to the backend structure.
1806
- * @param event Pointer to the event structure that the backend needs to wait
1807
- * for.
1808
- */
1809
- static void ggml_backend_cann_event_wait(ggml_backend_t backend,
1810
- ggml_backend_event_t event) {
1811
- ggml_backend_cann_context* cann_ctx =
1812
- (ggml_backend_cann_context*)backend->context;
1813
- if (ggml_backend_is_cann(backend)) {
1814
- ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
1815
- (aclrtEvent)event->context));
1816
- } else {
1817
- GGML_ABORT("fatal error");
1818
- }
1819
- }
1820
-
1821
- /**
1822
- * @brief Structure defining the interface for the CANN backend.
1823
- *
1824
- * This structure contains function pointers for various operations
1825
- * supported by the CANN backend, including name retrieval, memory
1826
- * management, tensor operations, synchronization, and event handling.
1827
- */
1828
- static const ggml_backend_i ggml_backend_cann_interface = {
1829
- /* .get_name = */ ggml_backend_cann_name,
1830
- /* .free = */ ggml_backend_cann_free,
1831
- /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async,
1832
- /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async,
1833
- /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async,
1834
- /* .synchronize = */ ggml_backend_cann_synchronize,
1835
- /* .graph_plan_create = */ NULL,
1836
- /* .graph_plan_free = */ NULL,
1837
- /* .graph_plan_update = */ NULL,
1838
- /* .graph_plan_compute = */ NULL,
1839
- /* .graph_compute = */ ggml_backend_cann_graph_compute,
1840
- /* .event_record = */ ggml_backend_cann_event_record,
1841
- /* .event_wait = */ ggml_backend_cann_event_wait,
1842
- };
1843
-
1844
- /**
1845
- * @brief Return the hardcoded GUID for the CANN backend.
1846
- *
1847
- * This function returns a static GUID which uniquely identifies the CANN
1848
- * backend.
1849
- *
1850
- * @return A pointer to the static GUID.
1851
- */
1852
- static ggml_guid_t ggml_backend_cann_guid() {
1853
- static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
1854
- 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64};
1855
- return &guid;
1856
- }
1857
-
1858
- // backend device
1859
- struct ggml_backend_cann_device_context {
1860
- int device;
1861
- std::string name;
1862
- std::string description;
1863
- };
1864
-
1865
- static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
1866
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1867
- return ctx->name.c_str();
1868
- }
1869
-
1870
- static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
1871
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1872
- return ctx->description.c_str();
1873
- }
1874
-
1875
- static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1876
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1877
- ggml_backend_cann_get_device_memory(ctx->device, free, total);
1878
- }
1879
-
1880
- static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
1881
- GGML_UNUSED(dev);
1882
- return GGML_BACKEND_DEVICE_TYPE_GPU;
1883
- }
1884
-
1885
- static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
1886
- props->name = ggml_backend_cann_device_get_name(dev);
1887
- props->description = ggml_backend_cann_device_get_description(dev);
1888
- props->type = ggml_backend_cann_device_get_type(dev);
1889
- ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);
1890
-
1891
- bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr;
1892
-
1893
- props->caps = {
1894
- /* .async = */ false,
1895
- /* .host_buffer = */ host_buffer,
1896
- /* .buffer_from_host_ptr = */ false,
1897
- /* .events = */ true,
1898
- };
1899
- }
1900
-
1901
- static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
1902
- GGML_UNUSED(params);
1903
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1904
- return ggml_backend_cann_init(ctx->device);
1905
- }
1906
-
1907
- /**
1908
- * @brief Checks if the CANN backend supports a specific backend buffer type.
1909
- *
1910
- * This function determines whether the CANN backend supports the given backend
1911
- * buffer type by comparing the device context of the backend and buffer type.
1912
- * It returns true if the device of the backend context matches that of the
1913
- * buffer type context.
1914
- *
1915
- * @param dev Pointer to the CANN backend device.
1916
- * @param buft Pointer to the backend buffer type to check.
1917
- * @return bool Returns true if the CANN backend supports the buffer type,
1918
- * otherwise false.
1919
- */
1920
- static bool ggml_backend_cann_supports_buft(
1921
- ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
1922
- if (ggml_backend_buft_is_cann(buft)) {
1923
- ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
1924
- ggml_backend_cann_buffer_type_context * buft_ctx =
1925
- (ggml_backend_cann_buffer_type_context *)buft->context;
1926
- return buft_ctx->device == dev_ctx->device;
1927
- }
1928
- return false;
1929
- }
1930
-
1931
- static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
1932
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1933
- return ggml_backend_cann_buffer_type(ctx->device);
1934
- }
1935
-
1936
- static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
1937
- GGML_UNUSED(dev);
1938
- return ggml_backend_cann_host_buffer_type();
1939
- }
1940
-
1941
- /**
1942
- * @brief Creates a new event for the CANN backend device.
1943
- *
1944
- * This function initializes a new event for the CANN backend by setting the
1945
- * device and creating an ACL runtime event. The created event is then wrapped
1946
- * in a ggml_backend_event structure and returned.
1947
- *
1948
- * @param dev Pointer to the CANN backend device.
1949
- * @return ggml_backend_event_t Returns a pointer to the new event structure.
1950
- */
1951
- static ggml_backend_event_t ggml_backend_cann_device_event_new(
1952
- ggml_backend_dev_t dev) {
1953
- ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
1954
-
1955
- ggml_cann_set_device(dev_ctx->device);
1956
-
1957
- aclrtEvent event;
1958
- ACL_CHECK(aclrtCreateEvent(&event));
1959
-
1960
- return new ggml_backend_event{
1961
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
1962
- /* .context = */ event,
1963
- };
1964
- }
1965
-
1966
- /**
1967
- * @brief Frees a CANN backend event.
1968
- *
1969
- * This function destroys the ACL runtime event associated with the given CANN
1970
- * backend event and then deletes the event structure itself.
1971
- *
1972
- * @param event Pointer to the event structure to be freed.
1973
- */
1974
- static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
1975
- ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
1976
-
1977
- delete event;
1978
- GGML_UNUSED(dev);
1979
- }
1980
-
1981
- /**
1982
- * @brief Synchronizes the given event on the CANN backend.
1983
- *
1984
- * This function waits for the specified event to complete on the ACL runtime.
1985
- *
1986
- * @param event Pointer to the event structure to be synchronized.
1987
- */
1988
- static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
1989
- ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
1990
-
1991
- GGML_UNUSED(dev);
1992
- }
1993
-
1994
- static const ggml_backend_device_i ggml_backend_cann_device_interface = {
1995
- /* .get_name = */ ggml_backend_cann_device_get_name,
1996
- /* .get_description = */ ggml_backend_cann_device_get_description,
1997
- /* .get_memory = */ ggml_backend_cann_device_get_memory,
1998
- /* .get_type = */ ggml_backend_cann_device_get_type,
1999
- /* .get_props = */ ggml_backend_cann_device_get_props,
2000
- /* .init_backend = */ ggml_backend_cann_device_init, // called for every card
2001
- /* .get_buffer_type = */ ggml_backend_cann_device_get_buffer_type,
2002
- /* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type,
2003
- /* .buffer_from_host_ptr = */ NULL, // not supported for CANN
2004
- /* .supports_op = */ ggml_backend_cann_supports_op,
2005
- /* .supports_buft = */ ggml_backend_cann_supports_buft,
2006
- /* .offload_op = */ ggml_backend_cann_offload_op,
2007
- /* .event_new = */ ggml_backend_cann_device_event_new,
2008
- /* .event_free = */ ggml_backend_cann_device_event_free,
2009
- /* .event_synchronize = */ ggml_backend_cann_device_event_synchronize,
2010
- };
2011
-
2012
-
2013
- // backend reg
2014
- struct ggml_backend_cann_reg_context {
2015
- std::vector<ggml_backend_dev_t> devices;
2016
- };
2017
-
2018
- static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
2019
- GGML_UNUSED(reg);
2020
- return GGML_CANN_NAME;
2021
- }
2022
-
2023
- static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
2024
- ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
2025
- return ctx->devices.size();
2026
- }
2027
-
2028
- static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
2029
- ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
2030
- GGML_ASSERT(index < ctx->devices.size());
2031
- return ctx->devices[index];
2032
- }
2033
-
2034
- static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
2035
- GGML_UNUSED(reg);
2036
- GGML_UNUSED(name);
2037
- // reserved for future use
2038
- return nullptr;
2039
- }
2040
-
2041
- static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
2042
- /* .get_name = */ ggml_backend_cann_reg_get_name,
2043
- /* .get_device_count = */ ggml_backend_cann_reg_get_device_count,
2044
- /* .get_device = */ ggml_backend_cann_reg_get_device,
2045
- /* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address,
2046
- };
2047
-
2048
- // backend registry, called only once for cann backend
2049
- ggml_backend_reg_t ggml_backend_cann_reg() {
2050
- static ggml_backend_reg reg;
2051
- static bool initialized = false;
2052
-
2053
- {
2054
- static std::mutex mutex;
2055
- std::lock_guard<std::mutex> lock(mutex);
2056
- if (!initialized) {
2057
- aclInit(nullptr);
2058
- ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
2059
-
2060
- for (int i = 0; i < ggml_cann_info().device_count; i++) {
2061
- ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
2062
- dev_ctx->description = aclrtGetSocName();
2063
- dev_ctx->device = i;
2064
- dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
2065
- ggml_cann_set_device(i);
2066
- ggml_backend_dev_t dev = new ggml_backend_device {
2067
- /* .interface = */ ggml_backend_cann_device_interface,
2068
- /* .reg = */ &reg,
2069
- /* .context = */ dev_ctx
2070
- };
2071
- ctx->devices.push_back(dev);
2072
- }
2073
-
2074
- reg = ggml_backend_reg {
2075
- /* .interface = */ ggml_backend_cann_reg_interface,
2076
- /* .context = */ ctx
2077
- };
2078
- }
2079
-
2080
- initialized = true;
2081
- }
2082
-
2083
- return &reg;
2084
- }
2085
-
2086
- ggml_backend_t ggml_backend_cann_init(int32_t device) {
2087
- aclInit(nullptr);
2088
- if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
2089
- GGML_LOG_ERROR("%s: error: invalid device %d\n", __func__, device);
2090
- return nullptr;
2091
- }
2092
-
2093
- ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
2094
- if (ctx == nullptr) {
2095
- GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
2096
- return nullptr;
2097
- }
2098
- ggml_cann_set_device(ctx->device);
2099
- ggml_backend_t cann_backend =
2100
- new ggml_backend{/* .guid = */ ggml_backend_cann_guid(),
2101
- /* .interface = */ ggml_backend_cann_interface,
2102
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
2103
- /* .context = */ ctx};
2104
-
2105
- return cann_backend;
2106
- }
2107
-
2108
- bool ggml_backend_is_cann(ggml_backend_t backend) {
2109
- return backend != NULL &&
2110
- ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
2111
- }
2112
-
2113
- int32_t ggml_backend_cann_get_device_count() {
2114
- return ggml_cann_info().device_count;
2115
- }
2116
-
2117
- void ggml_backend_cann_get_device_description(
2118
- int32_t device, char* description, size_t description_size) {
2119
- ggml_cann_set_device(device);
2120
- const char* soc_name = aclrtGetSocName();
2121
- snprintf(description, description_size, "%s", soc_name);
2122
- }
2123
-
2124
- void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
2125
- size_t* total) {
2126
- ggml_cann_set_device(device);
2127
- ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
2128
- }
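The deleted file closes with the public CANN entry points (ggml_backend_cann_init, ggml_backend_is_cann, ggml_backend_cann_get_device_count, ggml_backend_cann_get_device_description, ggml_backend_cann_get_device_memory). A minimal usage sketch in C++, assuming only those declarations from ggml-cann.h plus the generic helpers in ggml-backend.h; graph construction and error handling are omitted:

// Sketch: enumerate CANN devices and create a backend for device 0.
// Assumes the public API shown in the deleted ggml-cann.cpp above.
#include <cstdio>
#include "ggml-backend.h"
#include "ggml-cann.h"

int main() {
    const int32_t n_dev = ggml_backend_cann_get_device_count();
    for (int32_t i = 0; i < n_dev; ++i) {
        char   desc[128];
        size_t mem_free = 0, mem_total = 0;
        ggml_backend_cann_get_device_description(i, desc, sizeof(desc));
        ggml_backend_cann_get_device_memory(i, &mem_free, &mem_total);
        std::printf("CANN%d: %s, %zu/%zu bytes free\n", i, desc, mem_free, mem_total);
    }
    if (n_dev > 0) {
        ggml_backend_t backend = ggml_backend_cann_init(0);
        if (backend != nullptr && ggml_backend_is_cann(backend)) {
            // ... build a ggml graph and run it with ggml_backend_graph_compute() ...
            ggml_backend_free(backend);
        }
    }
    return 0;
}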
 
ggml/src/ggml-cpu-impl.h DELETED
@@ -1,614 +0,0 @@
1
- #pragma once
2
-
3
- // GGML CPU internal header
4
-
5
- #include "ggml.h"
6
- #include "ggml-impl.h"
7
- #include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
8
- //#include <stddef.h>
9
- #include <stdbool.h>
10
- #include <string.h> // memcpy
11
- #include <math.h> // fabsf
12
-
13
-
14
- #ifdef __cplusplus
15
- extern "C" {
16
- #endif
17
-
18
- #if defined(_MSC_VER)
19
-
20
- #define m512bh(p) p
21
- #define m512i(p) p
22
-
23
- #else
24
-
25
- #define m512bh(p) (__m512bh)(p)
26
- #define m512i(p) (__m512i)(p)
27
-
28
- #endif
29
-
30
- /**
31
- * Converts brain16 to float32.
32
- *
33
- * The bfloat16 floating point format has the following structure:
34
- *
35
- * ┌sign
36
- * │
37
- * │ ┌exponent
38
- * │ │
39
- * │ │ ┌mantissa
40
- * │ │ │
41
- * │┌──┴───┐┌─┴───┐
42
- * 0b0000000000000000 brain16
43
- *
44
- * Since bf16 has the same number of exponent bits as a 32bit float,
45
- * encoding and decoding numbers becomes relatively straightforward.
46
- *
47
- * ┌sign
48
- * │
49
- * │ ┌exponent
50
- * │ │
51
- * │ │ ┌mantissa
52
- * │ │ │
53
- * │┌──┴───┐┌─┴───────────────────┐
54
- * 0b00000000000000000000000000000000 IEEE binary32
55
- *
56
- * For comparison, the standard fp16 format has fewer exponent bits.
57
- *
58
- * ┌sign
59
- * │
60
- * │ ┌exponent
61
- * │ │
62
- * │ │ ┌mantissa
63
- * │ │ │
64
- * │┌─┴─┐┌─┴──────┐
65
- * 0b0000000000000000 IEEE binary16
66
- *
67
- * @see IEEE 754-2008
68
- */
69
- static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
70
- union {
71
- float f;
72
- uint32_t i;
73
- } u;
74
- u.i = (uint32_t)h.bits << 16;
75
- return u.f;
76
- }
77
-
78
- /**
79
- * Converts float32 to brain16.
80
- *
81
- * This is binary identical with Google Brain float conversion.
82
- * Floats shall round to nearest even, and NANs shall be quiet.
83
- * Subnormals aren't flushed to zero, except perhaps when used.
84
- * This code should vectorize nicely if using modern compilers.
85
- */
86
- static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
87
- ggml_bf16_t h;
88
- union {
89
- float f;
90
- uint32_t i;
91
- } u;
92
- u.f = s;
93
- if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
94
- h.bits = (u.i >> 16) | 64; /* force to quiet */
95
- return h;
96
- }
97
- h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
98
- return h;
99
- }
100
-
101
- #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
102
- #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
103
-
104
- // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
105
- #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
106
- #ifndef __FMA__
107
- #define __FMA__
108
- #endif
109
- #ifndef __F16C__
110
- #define __F16C__
111
- #endif
112
- #endif
113
-
114
- // __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
115
- #if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
116
- #ifndef __SSE3__
117
- #define __SSE3__
118
- #endif
119
- #ifndef __SSSE3__
120
- #define __SSSE3__
121
- #endif
122
- #endif
123
-
124
- #if defined(__ARM_FEATURE_SVE)
125
- #include <arm_sve.h>
126
- #include <sys/prctl.h>
127
- #endif
128
-
129
- // 16-bit float
130
- // on Arm, we use __fp16
131
- // on x86, we use uint16_t
132
- #if defined(__ARM_NEON)
133
-
134
- // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
135
- //
136
- // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
137
- //
138
- #include <arm_neon.h>
139
-
140
- #ifdef _MSC_VER
141
-
142
- typedef uint16_t ggml_fp16_internal_t;
143
-
144
- #define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
145
-
146
- #else
147
-
148
- typedef __fp16 ggml_fp16_internal_t;
149
-
150
- #define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
151
-
152
- #endif // _MSC_VER
153
-
154
- #if !defined(__aarch64__)
155
-
156
- // 32-bit ARM compatibility
157
-
158
- // vaddlvq_s16
159
- // vpaddq_s16
160
- // vpaddq_s32
161
- // vaddvq_s32
162
- // vaddvq_f32
163
- // vmaxvq_f32
164
- // vcvtnq_s32_f32
165
- // vzip1_u8
166
- // vzip2_u8
167
-
168
- inline static int32_t vaddlvq_s16(int16x8_t v) {
169
- int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));
170
- return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);
171
- }
172
-
173
- inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
174
- int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
175
- int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
176
- return vcombine_s16(a0, b0);
177
- }
178
-
179
- inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
180
- int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
181
- int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
182
- return vcombine_s32(a0, b0);
183
- }
184
-
185
- inline static int32_t vaddvq_s32(int32x4_t v) {
186
- return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
187
- }
188
-
189
- inline static float vaddvq_f32(float32x4_t v) {
190
- return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
191
- }
192
-
193
- inline static float vmaxvq_f32(float32x4_t v) {
194
- return
195
- MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
196
- MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
197
- }
198
-
199
- inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
200
- int32x4_t res;
201
-
202
- res[0] = roundf(vgetq_lane_f32(v, 0));
203
- res[1] = roundf(vgetq_lane_f32(v, 1));
204
- res[2] = roundf(vgetq_lane_f32(v, 2));
205
- res[3] = roundf(vgetq_lane_f32(v, 3));
206
-
207
- return res;
208
- }
209
-
210
- inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
211
- uint8x8_t res;
212
-
213
- res[0] = a[0]; res[1] = b[0];
214
- res[2] = a[1]; res[3] = b[1];
215
- res[4] = a[2]; res[5] = b[2];
216
- res[6] = a[3]; res[7] = b[3];
217
-
218
- return res;
219
- }
220
-
221
- inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
222
- uint8x8_t res;
223
-
224
- res[0] = a[4]; res[1] = b[4];
225
- res[2] = a[5]; res[3] = b[5];
226
- res[4] = a[6]; res[5] = b[6];
227
- res[6] = a[7]; res[7] = b[7];
228
-
229
- return res;
230
- }
231
-
232
- // vld1q_s16_x2
233
- // vld1q_u8_x2
234
- // vld1q_u8_x4
235
- // vld1q_s8_x2
236
- // vld1q_s8_x4
237
- // TODO: double-check these work correctly
238
-
239
- typedef struct ggml_int16x8x2_t {
240
- int16x8_t val[2];
241
- } ggml_int16x8x2_t;
242
-
243
- inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
244
- ggml_int16x8x2_t res;
245
-
246
- res.val[0] = vld1q_s16(ptr + 0);
247
- res.val[1] = vld1q_s16(ptr + 8);
248
-
249
- return res;
250
- }
251
-
252
- typedef struct ggml_uint8x16x2_t {
253
- uint8x16_t val[2];
254
- } ggml_uint8x16x2_t;
255
-
256
- inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
257
- ggml_uint8x16x2_t res;
258
-
259
- res.val[0] = vld1q_u8(ptr + 0);
260
- res.val[1] = vld1q_u8(ptr + 16);
261
-
262
- return res;
263
- }
264
-
265
- typedef struct ggml_uint8x16x4_t {
266
- uint8x16_t val[4];
267
- } ggml_uint8x16x4_t;
268
-
269
- inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
270
- ggml_uint8x16x4_t res;
271
-
272
- res.val[0] = vld1q_u8(ptr + 0);
273
- res.val[1] = vld1q_u8(ptr + 16);
274
- res.val[2] = vld1q_u8(ptr + 32);
275
- res.val[3] = vld1q_u8(ptr + 48);
276
-
277
- return res;
278
- }
279
-
280
- typedef struct ggml_int8x16x2_t {
281
- int8x16_t val[2];
282
- } ggml_int8x16x2_t;
283
-
284
- inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
285
- ggml_int8x16x2_t res;
286
-
287
- res.val[0] = vld1q_s8(ptr + 0);
288
- res.val[1] = vld1q_s8(ptr + 16);
289
-
290
- return res;
291
- }
292
-
293
- typedef struct ggml_int8x16x4_t {
294
- int8x16_t val[4];
295
- } ggml_int8x16x4_t;
296
-
297
- inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
298
- ggml_int8x16x4_t res;
299
-
300
- res.val[0] = vld1q_s8(ptr + 0);
301
- res.val[1] = vld1q_s8(ptr + 16);
302
- res.val[2] = vld1q_s8(ptr + 32);
303
- res.val[3] = vld1q_s8(ptr + 48);
304
-
305
- return res;
306
- }
307
-
308
- // NOTE: not tested
309
- inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
310
- int8x16_t res;
311
-
312
- res[ 0] = a[b[ 0]];
313
- res[ 1] = a[b[ 1]];
314
- res[ 2] = a[b[ 2]];
315
- res[ 3] = a[b[ 3]];
316
- res[ 4] = a[b[ 4]];
317
- res[ 5] = a[b[ 5]];
318
- res[ 6] = a[b[ 6]];
319
- res[ 7] = a[b[ 7]];
320
- res[ 8] = a[b[ 8]];
321
- res[ 9] = a[b[ 9]];
322
- res[10] = a[b[10]];
323
- res[11] = a[b[11]];
324
- res[12] = a[b[12]];
325
- res[13] = a[b[13]];
326
- res[14] = a[b[14]];
327
- res[15] = a[b[15]];
328
-
329
- return res;
330
- }
331
-
332
- // NOTE: not tested
333
- inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
334
- uint8x16_t res;
335
-
336
- res[ 0] = a[b[ 0]];
337
- res[ 1] = a[b[ 1]];
338
- res[ 2] = a[b[ 2]];
339
- res[ 3] = a[b[ 3]];
340
- res[ 4] = a[b[ 4]];
341
- res[ 5] = a[b[ 5]];
342
- res[ 6] = a[b[ 6]];
343
- res[ 7] = a[b[ 7]];
344
- res[ 8] = a[b[ 8]];
345
- res[ 9] = a[b[ 9]];
346
- res[10] = a[b[10]];
347
- res[11] = a[b[11]];
348
- res[12] = a[b[12]];
349
- res[13] = a[b[13]];
350
- res[14] = a[b[14]];
351
- res[15] = a[b[15]];
352
-
353
- return res;
354
- }
355
-
356
- #else
357
-
358
- #define ggml_int16x8x2_t int16x8x2_t
359
- #define ggml_uint8x16x2_t uint8x16x2_t
360
- #define ggml_uint8x16x4_t uint8x16x4_t
361
- #define ggml_int8x16x2_t int8x16x2_t
362
- #define ggml_int8x16x4_t int8x16x4_t
363
-
364
- #define ggml_vld1q_s16_x2 vld1q_s16_x2
365
- #define ggml_vld1q_u8_x2 vld1q_u8_x2
366
- #define ggml_vld1q_u8_x4 vld1q_u8_x4
367
- #define ggml_vld1q_s8_x2 vld1q_s8_x2
368
- #define ggml_vld1q_s8_x4 vld1q_s8_x4
369
- #define ggml_vqtbl1q_s8 vqtbl1q_s8
370
- #define ggml_vqtbl1q_u8 vqtbl1q_u8
371
-
372
- #endif // !defined(__aarch64__)
373
-
374
- #if !defined(__ARM_FEATURE_DOTPROD)
375
-
376
- inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
377
- const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
378
- const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
379
-
380
- return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
381
- }
382
-
383
- #else
384
-
385
- #define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
386
-
387
- #endif // !defined(__ARM_FEATURE_DOTPROD)
388
-
389
- #endif // defined(__ARM_NEON)
390
-
391
- #if defined(__ARM_NEON) && !defined(_MSC_VER)
392
-
393
- #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
394
- #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
395
-
396
- #define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
397
-
398
- static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
399
- ggml_fp16_internal_t tmp;
400
- memcpy(&tmp, &h, sizeof(ggml_fp16_t));
401
- return (float)tmp;
402
- }
403
-
404
- static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
405
- ggml_fp16_t res;
406
- ggml_fp16_internal_t tmp = f;
407
- memcpy(&res, &tmp, sizeof(ggml_fp16_t));
408
- return res;
409
- }
410
-
411
- #else
412
-
413
- #ifdef __wasm_simd128__
414
- #include <wasm_simd128.h>
415
- #else
416
- #ifdef __POWER9_VECTOR__
417
- #include <altivec.h>
418
- #undef bool
419
- #define bool _Bool
420
- #else
421
- #if defined(_MSC_VER) || defined(__MINGW32__)
422
- #include <intrin.h>
423
- #else
424
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
425
- #if !defined(__riscv)
426
- #include <immintrin.h>
427
- #endif
428
- #endif
429
- #endif
430
- #endif
431
- #endif
432
-
433
- #ifdef __riscv_v_intrinsic
434
- #include <riscv_vector.h>
435
- #endif
436
-
437
- #if defined(__loongarch64)
438
- #if defined(__loongarch_asx)
439
- #include <lasxintrin.h>
440
- #endif
441
- #if defined(__loongarch_sx)
442
- #include <lsxintrin.h>
443
- #endif
444
- #endif
445
-
446
- #if defined(__loongarch_asx)
447
-
448
- typedef union {
449
- int32_t i;
450
- float f;
451
- } ft_union;
452
-
453
- /* float type data load instructions */
454
- static __m128 __lsx_vreplfr2vr_s(float val) {
455
- ft_union fi_tmpval = {.f = val};
456
- return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
457
- }
458
-
459
- static __m256 __lasx_xvreplfr2vr_s(float val) {
460
- ft_union fi_tmpval = {.f = val};
461
- return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
462
- }
463
- #endif
464
-
465
- #ifdef __F16C__
466
-
467
- #ifdef _MSC_VER
468
- #define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
469
- #define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
470
- #else
471
- #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
472
- #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
473
- #endif
474
-
475
- #elif defined(__POWER9_VECTOR__)
476
-
477
- #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
478
- #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
479
- /* the inline asm below is about 12% faster than the lookup method */
480
- #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
481
- #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
482
-
483
- static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
484
- register float f;
485
- register double d;
486
- __asm__(
487
- "mtfprd %0,%2\n"
488
- "xscvhpdp %0,%0\n"
489
- "frsp %1,%0\n" :
490
- /* temp */ "=d"(d),
491
- /* out */ "=f"(f):
492
- /* in */ "r"(h));
493
- return f;
494
- }
495
-
496
- static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
497
- register double d;
498
- register ggml_fp16_t r;
499
- __asm__( /* xscvdphp can work on double or single precision */
500
- "xscvdphp %0,%2\n"
501
- "mffprd %1,%0\n" :
502
- /* temp */ "=d"(d),
503
- /* out */ "=r"(r):
504
- /* in */ "f"(f));
505
- return r;
506
- }
507
-
508
- #else
509
-
510
- // FP16 <-> FP32
511
- // ref: https://github.com/Maratyszcza/FP16
512
-
513
- static inline float fp32_from_bits(uint32_t w) {
514
- union {
515
- uint32_t as_bits;
516
- float as_value;
517
- } fp32;
518
- fp32.as_bits = w;
519
- return fp32.as_value;
520
- }
521
-
522
- static inline uint32_t fp32_to_bits(float f) {
523
- union {
524
- float as_value;
525
- uint32_t as_bits;
526
- } fp32;
527
- fp32.as_value = f;
528
- return fp32.as_bits;
529
- }
530
-
531
- static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
532
- const uint32_t w = (uint32_t) h << 16;
533
- const uint32_t sign = w & UINT32_C(0x80000000);
534
- const uint32_t two_w = w + w;
535
-
536
- const uint32_t exp_offset = UINT32_C(0xE0) << 23;
537
- #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
538
- const float exp_scale = 0x1.0p-112f;
539
- #else
540
- const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
541
- #endif
542
- const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
543
-
544
- const uint32_t magic_mask = UINT32_C(126) << 23;
545
- const float magic_bias = 0.5f;
546
- const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
547
-
548
- const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
549
- const uint32_t result = sign |
550
- (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
551
- return fp32_from_bits(result);
552
- }
553
-
554
- static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
555
- #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
556
- const float scale_to_inf = 0x1.0p+112f;
557
- const float scale_to_zero = 0x1.0p-110f;
558
- #else
559
- const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
560
- const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
561
- #endif
562
- float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
563
-
564
- const uint32_t w = fp32_to_bits(f);
565
- const uint32_t shl1_w = w + w;
566
- const uint32_t sign = w & UINT32_C(0x80000000);
567
- uint32_t bias = shl1_w & UINT32_C(0xFF000000);
568
- if (bias < UINT32_C(0x71000000)) {
569
- bias = UINT32_C(0x71000000);
570
- }
571
-
572
- base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
573
- const uint32_t bits = fp32_to_bits(base);
574
- const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
575
- const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
576
- const uint32_t nonsign = exp_bits + mantissa_bits;
577
- return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
578
- }
579
-
580
- #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
581
- #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
582
-
583
- #endif // __F16C__
584
-
585
- #endif // defined(__ARM_NEON) && !defined(_MSC_VER)
586
-
587
- #ifdef __ARM_FEATURE_SVE
588
- #include <arm_sve.h>
589
- #endif // __ARM_FEATURE_SVE
590
-
591
- // precomputed f32 table for f16 (256 KB)
592
- // defined in ggml.c, initialized in ggml_init()
593
- extern float ggml_table_f32_f16[1 << 16];
594
-
595
- // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
596
- // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
597
- // This is also true for POWER9.
598
- #if !defined(GGML_FP16_TO_FP32)
599
- inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
600
- uint16_t s;
601
- memcpy(&s, &f, sizeof(uint16_t));
602
- return ggml_table_f32_f16[s];
603
- }
604
-
605
- #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
606
- #endif
607
-
608
- #if !defined(GGML_FP32_TO_FP16)
609
- #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
610
- #endif
611
-
612
- #ifdef __cplusplus
613
- }
614
- #endif
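For reference, the bf16 helpers near the top of this deleted header are self-contained bit manipulations: fp32 to bf16 keeps the sign and the full 8-bit exponent and truncates the mantissa to 7 bits with round-to-nearest-even (quieting NaNs), while bf16 to fp32 simply shifts the 16 bits back into the high half of a float. A standalone sketch that mirrors that logic; it re-implements the conversions rather than including the internal header, so the names below are illustrative, not the library API:

// Standalone illustration of the bf16 round trip described above; mirrors
// ggml_compute_fp32_to_bf16 / ggml_compute_bf16_to_fp32 from the deleted header.
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint16_t fp32_to_bf16_bits(float s) {
    uint32_t u;
    std::memcpy(&u, &s, sizeof(u));
    if ((u & 0x7fffffff) > 0x7f800000) {
        return (uint16_t)((u >> 16) | 64);                       // NaN: force quiet
    }
    return (uint16_t)((u + (0x7fff + ((u >> 16) & 1))) >> 16);   // round to nearest even
}

static float bf16_bits_to_fp32(uint16_t h) {
    const uint32_t u = (uint32_t)h << 16;                        // bf16 = high half of fp32
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

int main() {
    const float x = 3.14159265f;
    const float y = bf16_bits_to_fp32(fp32_to_bf16_bits(x));
    std::printf("%.8f -> bf16 -> %.8f\n", x, y);                 // ~7 mantissa bits survive
    return 0;
}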
 
ggml/src/ggml-cpu.c DELETED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-cuda.cu DELETED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-cuda/CMakeLists.txt ADDED
@@ -0,0 +1,155 @@
1
+ cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
2
+
3
+ find_package(CUDAToolkit)
4
+
5
+ if (CUDAToolkit_FOUND)
6
+ message(STATUS "CUDA Toolkit found")
7
+
8
+ if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
9
+ # native == GPUs available at build time
10
+ # 52 == Maxwell, lowest CUDA 12 standard
11
+ # 60 == P100, FP16 CUDA intrinsics
12
+ # 61 == Pascal, __dp4a instruction (per-byte integer dot product)
13
+ # 70 == V100, FP16 tensor cores
14
+ # 75 == Turing, int8 tensor cores
15
+ if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
16
+ set(CMAKE_CUDA_ARCHITECTURES "native")
17
+ elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
18
+ set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
19
+ else()
20
+ set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75")
21
+ endif()
22
+ endif()
23
+ message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
24
+
25
+ enable_language(CUDA)
26
+
27
+ file(GLOB GGML_HEADERS_CUDA "*.cuh")
28
+ list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
29
+
30
+ file(GLOB GGML_SOURCES_CUDA "*.cu")
31
+ file(GLOB SRCS "template-instances/fattn-wmma*.cu")
32
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
33
+ file(GLOB SRCS "template-instances/mmq*.cu")
34
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
35
+
36
+ if (GGML_CUDA_FA_ALL_QUANTS)
37
+ file(GLOB SRCS "template-instances/fattn-vec*.cu")
38
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
39
+ add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
40
+ else()
41
+ file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
42
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
43
+ file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
44
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
45
+ file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
46
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
47
+ endif()
48
+
49
+ add_library(ggml-cuda
50
+ ${GGML_HEADERS_CUDA}
51
+ ${GGML_SOURCES_CUDA}
52
+ )
53
+
54
+ target_link_libraries(ggml-cuda PRIVATE ggml-base)
55
+ target_include_directories(ggml-cuda PRIVATE . ..)
56
+
57
+ add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
58
+
59
+ if (GGML_CUDA_GRAPHS)
60
+ add_compile_definitions(GGML_CUDA_USE_GRAPHS)
61
+ endif()
62
+
63
+ if (GGML_CUDA_FORCE_MMQ)
64
+ add_compile_definitions(GGML_CUDA_FORCE_MMQ)
65
+ endif()
66
+
67
+ if (GGML_CUDA_FORCE_CUBLAS)
68
+ add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
69
+ endif()
70
+
71
+ if (GGML_CUDA_NO_VMM)
72
+ add_compile_definitions(GGML_CUDA_NO_VMM)
73
+ endif()
74
+
75
+ if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
76
+ add_compile_definitions(GGML_CUDA_F16)
77
+ endif()
78
+
79
+ if (GGML_CUDA_NO_PEER_COPY)
80
+ add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
81
+ endif()
82
+
83
+ if (GGML_STATIC)
84
+ if (WIN32)
85
+ # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
86
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
87
+ else ()
88
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
89
+ endif()
90
+ else()
91
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt)
92
+ endif()
93
+
94
+ if (GGML_CUDA_NO_VMM)
95
+ # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
96
+ else()
97
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver)
98
+ endif()
99
+
100
+ set(CUDA_CXX_FLAGS "")
101
+
102
+ set(CUDA_FLAGS -use_fast_math)
103
+
104
+ if (GGML_FATAL_WARNINGS)
105
+ list(APPEND CUDA_FLAGS -Werror all-warnings)
106
+ endif()
107
+
108
+ if (GGML_ALL_WARNINGS AND NOT MSVC)
109
+ set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c)
110
+ if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "")
111
+ list(APPEND NVCC_CMD -ccbin ${CMAKE_CUDA_HOST_COMPILER})
112
+ endif()
113
+
114
+ execute_process(
115
+ COMMAND ${NVCC_CMD} -Xcompiler --version
116
+ OUTPUT_VARIABLE CUDA_CCFULLVER
117
+ ERROR_QUIET
118
+ )
119
+
120
+ if (NOT CUDA_CCFULLVER MATCHES clang)
121
+ set(CUDA_CCID "GNU")
122
+ execute_process(
123
+ COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion"
124
+ OUTPUT_VARIABLE CUDA_CCVER
125
+ ERROR_QUIET
126
+ )
127
+ else()
128
+ if (CUDA_CCFULLVER MATCHES Apple)
129
+ set(CUDA_CCID "AppleClang")
130
+ else()
131
+ set(CUDA_CCID "Clang")
132
+ endif()
133
+ string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER})
134
+ endif()
135
+
136
+ message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
137
+
138
+ get_flags(${CUDA_CCID} ${CUDA_CCVER})
139
+ list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later
140
+ endif()
141
+
142
+ if (NOT MSVC)
143
+ list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
144
+ endif()
145
+
146
+ list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument
147
+
148
+ if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "")
149
+ list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})
150
+ endif()
151
+
152
+ target_compile_options(ggml-cuda PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>")
153
+ else()
154
+ message(FATAL_ERROR "CUDA Toolkit not found")
155
+ endif()
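[editor's note] Several of the options handled above are forwarded to the sources purely as compile definitions (GGML_CUDA_F16, GGML_CUDA_FORCE_MMQ, GGML_CUDA_NO_VMM, ...). As a minimal, hedged illustration of how such a definition is typically consumed on the C++ side — the program below is illustrative only and not taken from ggml:

    #include <cstdio>

    // A compile definition set by the CMake logic above (here GGML_CUDA_F16)
    // simply gates an alternate code path at compile time.
    int main() {
    #ifdef GGML_CUDA_F16
        std::printf("built with FP16 CUDA paths enabled\n");
    #else
        std::printf("built with FP32-only CUDA paths\n");
    #endif
    }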
ggml/src/ggml-kompute.cpp DELETED
@@ -1,2184 +0,0 @@
1
- #include "ggml-impl.h"
2
- #include "ggml-backend.h"
3
- #include "ggml-backend-impl.h"
4
- #include "ggml-kompute.h"
5
-
6
- // These are generated at build time by cmake custom command
7
- #include "shaderop_scale.h"
8
- #include "shaderop_scale_8.h"
9
- #include "shaderop_add.h"
10
- #include "shaderop_addrow.h"
11
- #include "shaderop_mul.h"
12
- #include "shaderop_silu.h"
13
- #include "shaderop_relu.h"
14
- #include "shaderop_gelu.h"
15
- #include "shaderop_softmax.h"
16
- #include "shaderop_norm.h"
17
- #include "shaderop_rmsnorm.h"
18
- #include "shaderop_diagmask.h"
19
- #include "shaderop_mul_mat_f16.h"
20
- #include "shaderop_mul_mat_q8_0.h"
21
- #include "shaderop_mul_mat_q4_0.h"
22
- #include "shaderop_mul_mat_q4_1.h"
23
- #include "shaderop_mul_mat_q4_k.h"
24
- #include "shaderop_mul_mat_q6_k.h"
25
- #include "shaderop_mul_mat_mat_f32.h"
26
- #include "shaderop_getrows_f32.h"
27
- #include "shaderop_getrows_f16.h"
28
- #include "shaderop_getrows_q4_0.h"
29
- #include "shaderop_getrows_q4_1.h"
30
- #include "shaderop_getrows_q6_k.h"
31
- #include "shaderop_rope_f16.h"
32
- #include "shaderop_rope_f32.h"
33
- #include "shaderop_cpy_f16_f16.h"
34
- #include "shaderop_cpy_f16_f32.h"
35
- #include "shaderop_cpy_f32_f16.h"
36
- #include "shaderop_cpy_f32_f32.h"
37
-
38
- #include <algorithm>
39
- #include <array>
40
- #include <cassert>
41
- #include <cstdint>
42
- #include <cstdio>
43
- #include <cstring>
44
- #include <iostream>
45
- #include <memory>
46
- #include <mutex>
47
- #include <stdexcept>
48
- #include <string>
49
- #include <unordered_map>
50
- #include <utility>
51
- #include <vector>
52
-
53
- #include <kompute/Kompute.hpp>
54
- #include <vulkan/vulkan.hpp>
55
-
56
- #ifdef __linux__
57
- #include <cstdlib> // for setenv
58
- #endif
59
-
60
- #define QK4_0 32
61
- #define QR4_0 2
62
- #define QK4_1 32
63
- #define QK_NL 16
64
-
65
- typedef ggml_fp16_t half;
66
-
67
- static std::string ggml_kompute_format_name(int device) {
68
- return "Kompute" + std::to_string(device);
69
- }
70
-
71
- struct ggml_kompute_context {
72
- int device;
73
- std::string name;
74
- std::shared_ptr<vk::DescriptorPool> pool;
75
-
76
- ggml_kompute_context(int device)
77
- : device(device), name(ggml_kompute_format_name(device)) {}
78
- };
79
-
80
- // FIXME: It would be good to consolidate the kompute manager and the kompute context into one object
81
- // and consolidate the init functions and simplify object lifetime management. As it currently stands,
82
- // we *have* to have the kompute manager no matter what for device discovery, but the kompute context
83
- // is only created when a device is set and vulkan is explicitly turned on.
84
- static ggml_kompute_context *s_kompute_context = nullptr;
85
-
86
- class kompute_manager {
87
- kp::Manager *s_mgr = nullptr;
88
-
89
- public:
90
- kp::Manager *operator()() {
91
- if (s_mgr && !s_mgr->hasInstance()) {
92
- destroy();
93
- }
94
- if (!s_mgr) {
95
- s_mgr = new kp::Manager;
96
- }
97
- return s_mgr;
98
- }
99
-
100
- void destroy() {
101
- delete s_mgr;
102
- s_mgr = nullptr;
103
- }
104
- };
105
-
106
- static kompute_manager komputeManager;
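[editor's note] The kompute_manager wrapper above lazily constructs a kp::Manager on first use and recreates it if the Vulkan instance has gone away. A self-contained C++ sketch of the same create-on-demand, recreate-on-invalid pattern, with a dummy resource standing in for kp::Manager (all names here are illustrative):

    #include <cstdio>

    // Stand-in for a heavyweight handle such as kp::Manager.
    struct dummy_resource {
        bool valid = true;
        bool has_instance() const { return valid; }
    };

    // operator() hands back a live resource, recreating it if the
    // previous one lost its underlying instance.
    class lazy_handle {
        dummy_resource * res = nullptr;
    public:
        dummy_resource * operator()() {
            if (res && !res->has_instance()) {
                destroy();
            }
            if (!res) {
                res = new dummy_resource;
            }
            return res;
        }
        void destroy() {
            delete res;
            res = nullptr;
        }
        ~lazy_handle() { destroy(); }
    };

    int main() {
        lazy_handle handle;
        handle()->valid = false;  // simulate the instance going away
        std::printf("recreated: %s\n", handle()->has_instance() ? "yes" : "no"); // yes
    }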
107
-
108
- struct ggml_vk_memory {
109
- void *data = nullptr;
110
- size_t size = 0;
111
- vk::DeviceMemory *primaryMemory = nullptr;
112
- vk::Buffer *primaryBuffer = nullptr;
113
- vk::DeviceMemory *stagingMemory = nullptr;
114
- vk::Buffer *stagingBuffer = nullptr;
115
- };
116
-
117
- #ifdef __linux__
118
- __attribute__((constructor))
119
- static void enable_sam() {
120
- setenv("RADV_PERFTEST", "sam", false);
121
- }
122
- #endif
123
-
124
- static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_device) {
125
- vk::PhysicalDeviceFeatures availableFeatures;
126
- physical_device.getFeatures(&availableFeatures);
127
-
128
- if (!availableFeatures.shaderInt16)
129
- return false;
130
-
131
- vk::PhysicalDeviceVulkan11Features availableFeatures11;
132
- vk::PhysicalDeviceVulkan12Features availableFeatures12;
133
-
134
- availableFeatures11.pNext = &availableFeatures12;
135
- availableFeatures12.pNext = nullptr;
136
-
137
- vk::PhysicalDeviceFeatures2 features2;
138
- features2.pNext = &availableFeatures11;
139
-
140
- physical_device.getFeatures2(&features2);
141
-
142
- if (!availableFeatures11.uniformAndStorageBuffer16BitAccess ||
143
- !availableFeatures11.storageBuffer16BitAccess) {
144
- return false;
145
- }
146
-
147
- if (!availableFeatures12.storageBuffer8BitAccess ||
148
- !availableFeatures12.uniformAndStorageBuffer8BitAccess ||
149
- !availableFeatures12.shaderFloat16 ||
150
- !availableFeatures12.shaderInt8) {
151
- return false;
152
- }
153
-
154
- return true;
155
- }
156
-
157
- static const char * ggml_vk_getVendorName(uint32_t vendorID) {
158
- switch (vendorID) {
159
- case 0x10DE:
160
- return "nvidia";
161
- case 0x1002:
162
- return "amd";
163
- case 0x8086:
164
- return "intel";
165
- default:
166
- return "unknown";
167
- }
168
- }
169
-
170
- static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t memoryRequired) {
171
- std::vector<ggml_vk_device> results;
172
- if (!komputeManager()->hasVulkan() || !komputeManager()->hasInstance())
173
- return results;
174
-
175
- std::vector<vk::PhysicalDevice> physical_devices;
176
- try {
177
- physical_devices = komputeManager()->listDevices();
178
- } catch (vk::SystemError & err) {
179
- std::cerr << __func__ << ": ignoring Vulkan exception: " << err.what() << "\n";
180
- return results;
181
- }
182
-
183
- uint32_t deviceCount = physical_devices.size();
184
- if (deviceCount == 0)
185
- return results;
186
-
187
- std::unordered_map<std::string, size_t> count_by_name;
188
-
189
- for (uint32_t i = 0; i < deviceCount; i++) {
190
- const auto & physical_device = physical_devices[i];
191
-
192
- VkPhysicalDeviceProperties dev_props = physical_device.getProperties();
193
- VkPhysicalDeviceMemoryProperties memoryProperties = physical_device.getMemoryProperties();
194
- const uint32_t major = VK_VERSION_MAJOR(dev_props.apiVersion);
195
- const uint32_t minor = VK_VERSION_MINOR(dev_props.apiVersion);
196
- if (major < 1 || minor < 2)
197
- continue;
198
-
199
- if (!ggml_vk_checkPhysicalDeviceFeatures(physical_device))
200
- continue;
201
-
202
- size_t heapSize = 0;
203
- for (uint32_t j = 0; j < memoryProperties.memoryHeapCount; ++j) {
204
- VkMemoryHeap heap = memoryProperties.memoryHeaps[j];
205
- if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) {
206
- heapSize = heap.size;
207
- break;
208
- }
209
- }
210
-
211
- if (heapSize < memoryRequired)
212
- continue;
213
-
214
- auto ext_props = physical_device.enumerateDeviceExtensionProperties();
215
- bool has_maintenance4 = false;
216
-
217
- // Check if maintenance4 is supported
218
- for (const auto & properties : ext_props) {
219
- if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
220
- has_maintenance4 = true;
221
- }
222
- }
223
-
224
- vk::PhysicalDeviceSubgroupProperties subgroup_props;
225
- vk::PhysicalDeviceProperties2 dev_props2;
226
- vk::PhysicalDeviceMaintenance3Properties dev_props3;
227
- vk::PhysicalDeviceMaintenance4Properties dev_props4;
228
- dev_props2.pNext = &dev_props3;
229
- dev_props3.pNext = &subgroup_props;
230
- if (has_maintenance4) {
231
- subgroup_props.pNext = &dev_props4;
232
- }
233
- physical_device.getProperties2(&dev_props2);
234
-
235
- if (subgroup_props.subgroupSize < 32)
236
- continue;
237
-
238
- ggml_vk_device d;
239
- d.index = i;
240
- d.type = dev_props.deviceType;
241
- d.heapSize = heapSize;
242
- d.vendor = strdup(ggml_vk_getVendorName(dev_props.vendorID));
243
- d.subgroupSize = subgroup_props.subgroupSize;
244
- d.bufferAlignment = dev_props.limits.minStorageBufferOffsetAlignment;
245
-
246
- if (has_maintenance4) {
247
- d.maxAlloc = std::min(dev_props3.maxMemoryAllocationSize, dev_props4.maxBufferSize);
248
- } else {
249
- d.maxAlloc = dev_props3.maxMemoryAllocationSize;
250
- }
251
-
252
- std::string name(dev_props.deviceName);
253
- size_t n_idx = ++count_by_name[name];
254
- if (n_idx > 1) {
255
- name += " (" + std::to_string(n_idx) + ")";
256
- }
257
- d.name = strdup(name.c_str());
258
-
259
- results.push_back(d);
260
- }
261
-
262
- std::stable_sort(results.begin(), results.end(),
263
- [](const ggml_vk_device& lhs, const ggml_vk_device& rhs) -> bool {
264
- if (lhs.type != rhs.type) {
265
- if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return true;
266
- if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return false;
267
-
268
- if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return true;
269
- if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return false;
270
- }
271
- return lhs.heapSize < rhs.heapSize;
272
- }
273
- );
274
-
275
- return results;
276
- }
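[editor's note] The stable_sort at the end of the enumeration above orders discrete GPUs first, then integrated GPUs, breaking ties within a tier by heapSize (ascending, exactly as written in the code). A small standalone C++ sketch of that comparator, with made-up type codes and heap sizes:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Simplified stand-in for ggml_vk_device, keeping only the sort keys.
    struct dev { int type; size_t heapSize; const char * name; };

    enum { OTHER = 0, INTEGRATED = 1, DISCRETE = 2 }; // illustrative type codes

    int main() {
        std::vector<dev> devs = {
            {OTHER,      1024, "other"},   // heap sizes in MiB, illustrative
            {INTEGRATED, 2048, "igpu"},
            {DISCRETE,   8192, "dgpu"},
        };
        // Same shape as the comparator above: discrete before integrated before
        // everything else; ties inside a tier fall through to heapSize.
        std::stable_sort(devs.begin(), devs.end(), [](const dev & lhs, const dev & rhs) {
            if (lhs.type != rhs.type) {
                if (lhs.type == DISCRETE)   return true;
                if (rhs.type == DISCRETE)   return false;
                if (lhs.type == INTEGRATED) return true;
                if (rhs.type == INTEGRATED) return false;
            }
            return lhs.heapSize < rhs.heapSize;
        });
        for (const auto & d : devs) std::printf("%s\n", d.name); // dgpu, igpu, other
    }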
277
-
278
- static std::vector<ggml_vk_device>& ggml_vk_available_devices() {
279
- static std::vector<ggml_vk_device> devices = ggml_vk_available_devices_internal(0);
280
- return devices;
281
- }
282
-
283
- static void ggml_vk_filterByVendor(std::vector<ggml_vk_device>& devices, const std::string& targetVendor) {
284
- devices.erase(
285
- std::remove_if(devices.begin(), devices.end(),
286
- [&targetVendor](const ggml_vk_device& device) {
287
- return device.vendor != targetVendor;
288
- }),
289
- devices.end()
290
- );
291
- }
292
-
293
- static void ggml_vk_filterByName(std::vector<ggml_vk_device>& devices, const std::string& targetName) {
294
- devices.erase(
295
- std::remove_if(devices.begin(), devices.end(),
296
- [&targetName](const ggml_vk_device& device) {
297
- return device.name != targetName;
298
- }),
299
- devices.end()
300
- );
301
- }
302
-
303
- static bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const std::string & name) {
304
- if (name.empty())
305
- return false;
306
-
307
- auto devices = ggml_vk_available_devices_internal(memoryRequired);
308
- if (name == "amd" || name == "nvidia" || name == "intel") {
309
- ggml_vk_filterByVendor(devices, name);
310
- } else if (name != "gpu") {
311
- ggml_vk_filterByName(devices, name);
312
- }
313
-
314
- if (devices.empty())
315
- return false;
316
-
317
- *device = devices.front();
318
- return true;
319
- }
320
-
321
- bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const char * name) {
322
- return ggml_vk_get_device(device, memoryRequired, std::string(name));
323
- }
324
-
325
- bool ggml_vk_has_vulkan() {
326
- return komputeManager()->hasVulkan();
327
- }
328
-
329
- bool ggml_vk_has_device() {
330
- return komputeManager()->hasDevice();
331
- }
332
-
333
- ggml_vk_device ggml_vk_current_device() {
334
- if (!komputeManager()->hasDevice())
335
- return ggml_vk_device();
336
-
337
- auto devices = ggml_vk_available_devices();
338
- ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data());
339
- GGML_ASSERT(!devices.empty());
340
- return devices.front();
341
- }
342
-
343
- static
344
- void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t size) {
345
- std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
346
- vk::DescriptorPoolSize(
347
- vk::DescriptorType::eStorageBuffer,
348
- 3 * size // Descriptor count is number of possible tensors to pass into an algorithm
349
- )
350
- };
351
-
352
- vk::DescriptorPoolCreateInfo descriptorPoolInfo(
353
- vk::DescriptorPoolCreateFlags(),
354
- size, // Max sets
355
- static_cast<uint32_t>(descriptorPoolSizes.size()),
356
- descriptorPoolSizes.data());
357
-
358
- ctx->pool = std::make_shared<vk::DescriptorPool>();
359
- vk::Result r = komputeManager()->device()->createDescriptorPool(
360
- &descriptorPoolInfo, nullptr, ctx->pool.get());
361
- if (r != vk::Result::eSuccess)
362
- std::cerr << "Error allocating descriptor pool" << vk::to_string(r);
363
- }
364
-
365
- static
366
- void ggml_vk_free_descriptor_pool(struct ggml_kompute_context * ctx) {
367
- if (ctx->pool) {
368
- komputeManager()->device()->destroy(
369
- *ctx->pool,
370
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
371
- ctx->pool = nullptr;
372
- }
373
- }
374
-
375
- static
376
- vk::Buffer *ggml_vk_allocate_buffer(size_t size) {
377
- vk::BufferCreateInfo bufferCreateInfo;
378
- bufferCreateInfo.size = size;
379
- bufferCreateInfo.usage = vk::BufferUsageFlagBits::eStorageBuffer |
380
- vk::BufferUsageFlagBits::eTransferSrc |
381
- vk::BufferUsageFlagBits::eTransferDst;
382
- bufferCreateInfo.sharingMode = vk::SharingMode::eExclusive;
383
-
384
- vk::Buffer *vkBuffer = new vk::Buffer;
385
- vk::Result r = komputeManager()->device()->createBuffer(&bufferCreateInfo, nullptr, vkBuffer);
386
- if (r != vk::Result::eSuccess)
387
- std::cerr << "Error allocating buffer " << vk::to_string(r) << std::endl;
388
- return vkBuffer;
389
- }
390
-
391
- static
392
- vk::DeviceMemory *ggml_vk_allocate(size_t size, vk::MemoryPropertyFlags flags, vk::MemoryRequirements requirements, bool *isHostVisible) {
393
-
394
- uint32_t memoryTypeIndex = -1;
395
- bool memoryTypeIndexFound = false;
396
- vk::PhysicalDeviceMemoryProperties memoryProperties = komputeManager()->physicalDevice()->getMemoryProperties();
397
- for (uint32_t i = 0; i < memoryProperties.memoryTypeCount; i++) {
398
- const vk::MemoryType &memoryType = memoryProperties.memoryTypes[i];
399
- const vk::MemoryHeap &memoryHeap = memoryProperties.memoryHeaps[memoryType.heapIndex];
400
- if (memoryHeap.size < size) {
401
- continue;
402
- }
403
-
404
- if (requirements.memoryTypeBits & (1 << i)) {
405
- if (((memoryProperties.memoryTypes[i]).propertyFlags &
406
- flags) == flags) {
407
- memoryTypeIndex = i;
408
- memoryTypeIndexFound = true;
409
- if (isHostVisible && (memoryProperties.memoryTypes[i].propertyFlags & vk::MemoryPropertyFlagBits::eHostVisible)) {
410
- *isHostVisible = true;
411
- }
412
- break;
413
- }
414
- }
415
- }
416
- if (!memoryTypeIndexFound) {
417
- throw std::runtime_error(
418
- "Memory type index for buffer creation not found");
419
- }
420
-
421
- vk::MemoryAllocateInfo allocInfo;
422
- allocInfo.allocationSize = size;
423
- allocInfo.memoryTypeIndex = memoryTypeIndex;
424
- vk::DeviceMemory *vkDeviceMemory = new vk::DeviceMemory;
425
- vk::Result r = komputeManager()->device()->allocateMemory(&allocInfo, nullptr, vkDeviceMemory);
426
- if (r != vk::Result::eSuccess) {
427
- std::cerr << "Error allocating memory " << vk::to_string(r) << std::endl;
428
- throw std::runtime_error("Error allocating vulkan memory.");
429
- }
430
- return vkDeviceMemory;
431
- }
432
-
433
- static size_t ggml_vk_aligned_offset(ggml_backend_buffer_t buffer, size_t offset) {
434
- size_t minStorageBufferOffsetAlignment = ggml_backend_buffer_get_alignment(buffer);
435
-
436
- // If offset is already aligned, return it directly
437
- if (offset % minStorageBufferOffsetAlignment == 0) {
438
- return offset;
439
- }
440
-
441
- // Otherwise, return the largest multiple of minStorageBufferOffsetAlignment less than offset
442
- return (offset / minStorageBufferOffsetAlignment) * minStorageBufferOffsetAlignment;
443
- }
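[editor's note] ggml_vk_aligned_offset rounds a tensor's byte offset down to the buffer's minStorageBufferOffsetAlignment; the caller (ggml_vk_get_tensor, further down) carries the remainder separately so the shader can still address the exact element. A small self-contained C++ sketch of that split, assuming an alignment of 256 bytes for the example:

    #include <cstdint>
    #include <cstdio>

    // Round a byte offset down to the previous multiple of `alignment`,
    // mirroring ggml_vk_aligned_offset, and report the leftover bytes.
    static void split_offset(uint64_t offset, uint64_t alignment) {
        const uint64_t aligned  = (offset / alignment) * alignment; // buffer is bound here
        const uint64_t residual = offset - aligned;                 // shader adds this back
        std::printf("offset=%llu -> aligned=%llu + residual=%llu\n",
                    (unsigned long long) offset,
                    (unsigned long long) aligned,
                    (unsigned long long) residual);
    }

    int main() {
        split_offset(1000, 256); // offset=1000 -> aligned=768 + residual=232
        split_offset(1024, 256); // already aligned: residual is 0
    }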
444
-
445
- static ggml_vk_memory ggml_vk_allocate(size_t size) {
446
- ggml_vk_memory memory;
447
- bool isHostVisible = false;
448
- {
449
- memory.primaryBuffer = ggml_vk_allocate_buffer(size);
450
- vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.primaryBuffer);
451
- vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eDeviceLocal;
452
- memory.primaryMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
453
- komputeManager()->device()->bindBufferMemory(*memory.primaryBuffer, *memory.primaryMemory, 0);
454
- if (isHostVisible) {
455
- vk::Result r = komputeManager()->device()->mapMemory(*memory.primaryMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
456
- if (r != vk::Result::eSuccess)
457
- std::cerr << "Error mapping memory" << vk::to_string(r);
458
- }
459
- }
460
-
461
- if (!isHostVisible) {
462
- memory.stagingBuffer = ggml_vk_allocate_buffer(size);
463
- vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.stagingBuffer);
464
- vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eHostVisible |
465
- vk::MemoryPropertyFlagBits::eHostCoherent |
466
- vk::MemoryPropertyFlagBits::eHostCached;
467
- memory.stagingMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
468
- komputeManager()->device()->bindBufferMemory(*memory.stagingBuffer, *memory.stagingMemory, 0);
469
- vk::Result r = komputeManager()->device()->mapMemory(*memory.stagingMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
470
- if (r != vk::Result::eSuccess)
471
- std::cerr << "Error mapping memory" << vk::to_string(r);
472
- }
473
-
474
- memory.size = size;
475
- return memory;
476
- }
477
-
478
- static void ggml_vk_free_memory(ggml_vk_memory &memory)
479
- {
480
- komputeManager()->device()->destroy(
481
- *memory.primaryBuffer,
482
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
483
- if (memory.stagingBuffer) {
484
- komputeManager()->device()->destroy(
485
- *memory.stagingBuffer,
486
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
487
- }
488
- komputeManager()->device()->freeMemory(
489
- *memory.primaryMemory,
490
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
491
- if (memory.stagingMemory) {
492
- komputeManager()->device()->freeMemory(
493
- *memory.stagingMemory,
494
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
495
- }
496
- }
497
-
498
- static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft);
499
-
500
- static
501
- ggml_vk_memory * ggml_vk_find_tensor(const struct ggml_tensor * t, uint64_t & offset) {
502
- ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
503
-
504
- // compatibility with ggml-backend
505
- GGML_ASSERT(buffer && buffer->buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name);
506
-
507
- ggml_vk_memory * buf_ctx = static_cast<ggml_vk_memory *>(buffer->context);
508
-
509
- const intptr_t ioffs = intptr_t(t->data) - intptr_t(buf_ctx->data);
510
-
511
- GGML_ASSERT(ioffs >= 0 && ioffs + int64_t(ggml_nbytes(t)) <= int64_t(buffer->size));
512
-
513
- offset = uint64_t(ioffs);
514
- return buf_ctx;
515
- }
516
-
517
- static
518
- const std::shared_ptr<kp::Tensor> ggml_vk_get_tensor(const struct ggml_tensor * t, uint32_t * alignedOffset = nullptr) {
519
- uint64_t originalOffset = 0;
520
- auto * res = ggml_vk_find_tensor(t, originalOffset);
521
- if (!res) {
522
- static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
523
- return nullTensor;
524
- }
525
-
526
- // Create a tensor whose memory will be composed of our buffers at the correct offset
527
- const size_t nelements = ggml_nelements(t);
528
- size_t nbytes = ggml_nbytes(t);
529
-
530
- size_t vulkanOffset = ggml_vk_aligned_offset(t->buffer, originalOffset);
531
- if (alignedOffset) {
532
- *alignedOffset = originalOffset - vulkanOffset;
533
- nbytes += *alignedOffset;
534
- }
535
-
536
- return komputeManager()->tensor(
537
- t->data,
538
- nelements,
539
- nbytes, kp::Tensor::TensorDataTypes::eFloat,
540
- res->primaryMemory, res->primaryBuffer,
541
- res->stagingMemory, res->stagingBuffer,
542
- vulkanOffset);
543
- }
544
-
545
- static std::vector<uint32_t> getSpirvShader(const unsigned char* rawData, size_t size) {
546
- if (size % sizeof(uint32_t) != 0) {
547
- throw std::runtime_error("Invalid size: must be divisible by sizeof(uint32_t)");
548
- }
549
-
550
- const uint32_t* data_ptr = reinterpret_cast<const uint32_t*>(rawData);
551
- size_t count = size / sizeof(uint32_t);
552
- return std::vector<uint32_t>(data_ptr, data_ptr + count);
553
- }
554
-
555
- inline static
556
- uint32_t safe_divide(uint32_t a, uint32_t b) {
557
- if (b <= 1) {
558
- return a;
559
- }
560
- if ((a % b) != 0) {
561
- fprintf(stderr, "((%u %% %u) == %u) != 0\n", a, b, a % b);
562
- GGML_ABORT("safe_divide result would've had remainder");
563
- }
564
- return a / b;
565
- }
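[editor's note] safe_divide is what turns the byte offsets computed from tensor data pointers into element offsets for the push constants — divide by 4 for F32, by 2 for F16, pass through unchanged for unaligned quantized blocks — aborting on misalignment. A self-contained restatement with sample values (function name and the non-aborting error path are this sketch's own, not ggml's):

    #include <cstdint>
    #include <cstdio>

    // Same contract as safe_divide above: byte offset -> element offset,
    // valid only when the byte offset is a multiple of the element size.
    static uint32_t byte_to_elem_offset(uint32_t byte_off, uint32_t elem_size) {
        if (elem_size > 1 && byte_off % elem_size != 0) {
            std::fprintf(stderr, "unaligned offset %u for element size %u\n", byte_off, elem_size);
            return 0; // the real code aborts here
        }
        return elem_size <= 1 ? byte_off : byte_off / elem_size;
    }

    int main() {
        std::printf("%u\n", byte_to_elem_offset(512, 4)); // 128, as safe_divide(inOff, 4) would yield
        std::printf("%u\n", byte_to_elem_offset(6, 2));   // 3, F16 offsets are divided by 2
    }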
566
-
567
- static void ggml_vk_add(
568
- kp::Sequence& seq,
569
- const std::shared_ptr<kp::Tensor>& inA,
570
- const std::shared_ptr<kp::Tensor>& inB,
571
- const std::shared_ptr<kp::Tensor>& out,
572
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
573
- int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
574
- int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
575
- int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
576
- int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
577
- int32_t ne0,
578
- int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3
579
- ) {
580
- const static auto spirv = getSpirvShader(kp::shader_data::op_add_comp_spv,
581
- kp::shader_data::op_add_comp_spv_len);
582
-
583
- struct PushConstants {
584
- uint32_t inAOff, inBOff, outOff;
585
- int32_t ne00;
586
- int32_t nb00, nb01, nb02, nb03;
587
- int32_t ne10, ne11, ne12, ne13;
588
- int32_t nb10, nb11, nb12, nb13;
589
- int32_t ne0;
590
- int32_t nb0, nb1, nb2, nb3;
591
- } const pushConsts {
592
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
593
- ne00,
594
- nb00, nb01, nb02, nb03,
595
- ne10, ne11, ne12, ne13,
596
- nb10, nb11, nb12, nb13,
597
- ne0,
598
- nb0, nb1, nb2, nb3
599
- };
600
-
601
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
602
- if (!komputeManager()->hasAlgorithm(__func__)) {
603
- s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
604
- } else {
605
- s_algo = komputeManager()->getAlgorithm(__func__);
606
- s_algo->setTensors({inA, inB, out});
607
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
608
- s_algo->setPushConstants<PushConstants>({pushConsts});
609
- s_algo->updateDescriptors(s_kompute_context->pool.get());
610
- }
611
- seq.record<kp::OpAlgoDispatch>(s_algo);
612
- }
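[editor's note] ggml_vk_add is the first of many ops below that follow the same idiom: build the kp::Algorithm once under a name, cache it in the manager, and on reuse only rebind tensors, workgroup and push constants before recording the dispatch. A condensed, hedged sketch of that control flow with the Kompute calls stubbed out (the types and helper names below are placeholders, not ggml or Kompute symbols):

    #include <cstdio>
    #include <map>
    #include <string>

    // Minimal cache that mimics the flow around komputeManager()->hasAlgorithm(),
    // algorithm<>() and getAlgorithm() in the functions above.
    struct fake_algorithm { std::string name; int uses; };

    static std::map<std::string, fake_algorithm> g_cache;

    static fake_algorithm & get_or_create(const std::string & name) {
        auto it = g_cache.find(name);
        if (it == g_cache.end()) {
            // in the real code this branch compiles the pipeline, allocates
            // descriptors and sets the initial push constants
            it = g_cache.emplace(name, fake_algorithm{name, 0}).first;
        } else {
            // in the real code this branch only rebinds tensors, workgroup
            // and push constants on the cached algorithm
        }
        it->second.uses++;
        return it->second;
    }

    int main() {
        get_or_create("ggml_vk_add");
        get_or_create("ggml_vk_add"); // second dispatch reuses the cached pipeline
        std::printf("uses=%d\n", g_cache["ggml_vk_add"].uses); // 2
    }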
613
-
614
- static void ggml_vk_addrow(kp::Sequence& seq,
615
- const std::shared_ptr<kp::Tensor>& inA,
616
- const std::shared_ptr<kp::Tensor>& inB,
617
- const std::shared_ptr<kp::Tensor>& out,
618
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
619
- uint32_t size, uint32_t row = 0) {
620
-
621
- const static auto spirv = getSpirvShader(kp::shader_data::op_addrow_comp_spv,
622
- kp::shader_data::op_addrow_comp_spv_len);
623
-
624
- struct PushConstants {
625
- uint32_t inAOff, inBOff, outOff;
626
- uint32_t row;
627
- } const pushConsts {
628
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
629
- row
630
- };
631
-
632
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
633
- if (!komputeManager()->hasAlgorithm(__func__))
634
- s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
635
- else {
636
- s_algo = komputeManager()->getAlgorithm(__func__);
637
- s_algo->setTensors({inA, inB, out});
638
- s_algo->setWorkgroup({size});
639
- s_algo->setPushConstants<PushConstants>({pushConsts});
640
- s_algo->updateDescriptors(s_kompute_context->pool.get());
641
- }
642
- seq.record<kp::OpAlgoDispatch>(s_algo);
643
- }
644
-
645
- static void ggml_vk_mul(
646
- kp::Sequence& seq,
647
- const std::shared_ptr<kp::Tensor>& inA,
648
- const std::shared_ptr<kp::Tensor>& inB,
649
- const std::shared_ptr<kp::Tensor>& out,
650
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
651
- int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
652
- int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
653
- int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
654
- int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
655
- int32_t ne0,
656
- int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3
657
- ) {
658
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_comp_spv,
659
- kp::shader_data::op_mul_comp_spv_len);
660
-
661
- struct PushConstants {
662
- uint32_t inAOff, inBOff, outOff;
663
- int32_t ne00;
664
- int32_t nb00, nb01, nb02, nb03;
665
- int32_t ne10, ne11, ne12, ne13;
666
- int32_t nb10, nb11, nb12, nb13;
667
- int32_t ne0;
668
- int32_t nb0, nb1, nb2, nb3;
669
- } const pushConsts {
670
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
671
- ne00,
672
- nb00, nb01, nb02, nb03,
673
- ne10, ne11, ne12, ne13,
674
- nb10, nb11, nb12, nb13,
675
- ne0,
676
- nb0, nb1, nb2, nb3
677
- };
678
-
679
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
680
- if (!komputeManager()->hasAlgorithm(__func__)) {
681
- s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
682
- } else {
683
- s_algo = komputeManager()->getAlgorithm(__func__);
684
- s_algo->setTensors({inA, inB, out});
685
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
686
- s_algo->setPushConstants<PushConstants>({pushConsts});
687
- s_algo->updateDescriptors(s_kompute_context->pool.get());
688
- }
689
- seq.record<kp::OpAlgoDispatch>(s_algo);
690
- }
691
-
692
- static void ggml_vk_scale(kp::Sequence& seq,
693
- const std::shared_ptr<kp::Tensor>& in,
694
- const std::shared_ptr<kp::Tensor>& out,
695
- uint32_t inOff, uint32_t outOff,
696
- uint32_t size, float scale) {
697
- const static auto spirv_1 = getSpirvShader(
698
- kp::shader_data::op_scale_comp_spv, kp::shader_data::op_scale_comp_spv_len
699
- );
700
- const static auto spirv_8 = getSpirvShader(
701
- kp::shader_data::op_scale_8_comp_spv, kp::shader_data::op_scale_8_comp_spv_len
702
- );
703
-
704
- struct PushConstants {
705
- uint32_t inOff, outOff;
706
- float scale;
707
- } const pushConsts {
708
- safe_divide(inOff, 4), safe_divide(outOff, 4),
709
- scale
710
- };
711
-
712
- const auto * spirv = &spirv_1;
713
- std::string name(__func__);
714
- if (size % 8 == 0) {
715
- size /= 8;
716
- name += "_8";
717
- spirv = &spirv_8;
718
- }
719
-
720
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
721
- if (!komputeManager()->hasAlgorithm(name)) {
722
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, *spirv, {size}, {}, {pushConsts});
723
- } else {
724
- s_algo = komputeManager()->getAlgorithm(name);
725
- s_algo->setTensors({in, out});
726
- s_algo->setWorkgroup({size});
727
- s_algo->setPushConstants<PushConstants>({pushConsts});
728
- s_algo->updateDescriptors(s_kompute_context->pool.get());
729
- }
730
- seq.record<kp::OpAlgoDispatch>(s_algo);
731
- }
732
-
733
- static void ggml_vk_xxlu(
734
- const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
735
- const std::shared_ptr<kp::Tensor>& in,
736
- const std::shared_ptr<kp::Tensor>& out,
737
- uint32_t inOff, uint32_t outOff,
738
- uint32_t size
739
- ) {
740
- struct PushConstants {
741
- uint32_t inOff, outOff;
742
- } const pushConsts {
743
- safe_divide(inOff, 4), safe_divide(outOff, 4),
744
- };
745
-
746
- auto name = std::string(__func__) + "_" + suffix;
747
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
748
- if (!komputeManager()->hasAlgorithm(name)) {
749
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts});
750
- } else {
751
- s_algo = komputeManager()->getAlgorithm(name);
752
- s_algo->setTensors({in, out});
753
- s_algo->setWorkgroup({size});
754
- s_algo->setPushConstants<PushConstants>({pushConsts});
755
- s_algo->updateDescriptors(s_kompute_context->pool.get());
756
- }
757
- seq.record<kp::OpAlgoDispatch>(s_algo);
758
- }
759
-
760
- template <typename... Args>
761
- static void ggml_vk_silu(Args&&... args) {
762
- const static auto spirv = getSpirvShader(kp::shader_data::op_silu_comp_spv,
763
- kp::shader_data::op_silu_comp_spv_len);
764
-
765
- ggml_vk_xxlu(spirv, "silu", std::forward<Args>(args)...);
766
- }
767
-
768
- template <typename... Args>
769
- static void ggml_vk_relu(Args&&... args) {
770
- const static auto spirv = getSpirvShader(kp::shader_data::op_relu_comp_spv,
771
- kp::shader_data::op_relu_comp_spv_len);
772
-
773
- ggml_vk_xxlu(spirv, "relu", std::forward<Args>(args)...);
774
- }
775
-
776
- template <typename... Args>
777
- static void ggml_vk_gelu(Args&&... args) {
778
- const static auto spirv = getSpirvShader(kp::shader_data::op_gelu_comp_spv,
779
- kp::shader_data::op_gelu_comp_spv_len);
780
-
781
- ggml_vk_xxlu(spirv, "gelu", std::forward<Args>(args)...);
782
- }
783
-
784
- static void ggml_vk_soft_max(
785
- kp::Sequence& seq,
786
- const std::shared_ptr<kp::Tensor>& inA,
787
- const std::shared_ptr<kp::Tensor>& inB,
788
- const std::shared_ptr<kp::Tensor>& out,
789
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
790
- int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
791
- float scale
792
- ) {
793
- const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
794
- kp::shader_data::op_softmax_comp_spv_len);
795
-
796
- struct PushConstants {
797
- uint32_t inAOff, inBOff, outOff;
798
- int32_t ne00, ne01, ne02;
799
- float scale;
800
- int32_t mask;
801
- } pushConsts {
802
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
803
- ne00, ne01, ne02,
804
- scale,
805
- bool(inB)
806
- };
807
-
808
- auto & inB_ = inB ? inB : inA;
809
-
810
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
811
- if (!komputeManager()->hasAlgorithm(__func__)) {
812
- // FIXME: The softmax kernel needs to be fixed to use the subgroupsize which can vary by device
813
- const uint32_t local_x = 32;
814
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB_, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts});
815
- } else {
816
- s_algo = komputeManager()->getAlgorithm(__func__);
817
- s_algo->setTensors({inA, inB_, out});
818
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
819
- s_algo->setPushConstants<PushConstants>({pushConsts});
820
- s_algo->updateDescriptors(s_kompute_context->pool.get());
821
- }
822
- seq.record<kp::OpAlgoDispatch>(s_algo);
823
- }
824
-
825
- static void ggml_vk_norm_(
826
- const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
827
- const std::shared_ptr<kp::Tensor>& in,
828
- const std::shared_ptr<kp::Tensor>& out,
829
- uint32_t inOff, uint32_t outOff,
830
- int32_t ne00, int32_t nb01,
831
- int32_t nrows, float epsilon
832
- ) {
833
- GGML_ASSERT(nb01%sizeof(float) == 0);
834
- GGML_ASSERT(ne00%sizeof(float) == 0);
835
-
836
- struct PushConstants {
837
- uint32_t inOff, outOff;
838
- uint32_t ne00, nb01;
839
- float eps;
840
- } pushConsts {
841
- safe_divide(inOff, 4), safe_divide(outOff, 4),
842
- (uint32_t)ne00, (uint32_t)nb01, epsilon
843
- };
844
-
845
- auto name = std::string(__func__) + "_" + suffix;
846
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
847
- if (!komputeManager()->hasAlgorithm(name)) {
848
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {(uint32_t)nrows}, {}, {pushConsts});
849
- } else {
850
- s_algo = komputeManager()->getAlgorithm(name);
851
- s_algo->setTensors({in, out});
852
- s_algo->setWorkgroup({(uint32_t)nrows});
853
- s_algo->setPushConstants<PushConstants>({pushConsts});
854
- s_algo->updateDescriptors(s_kompute_context->pool.get());
855
- }
856
- seq.record<kp::OpAlgoDispatch>(s_algo);
857
- }
858
-
859
- template <typename... Args>
860
- static void ggml_vk_norm(Args&&... args) {
861
- const static auto spirv = getSpirvShader(kp::shader_data::op_norm_comp_spv,
862
- kp::shader_data::op_norm_comp_spv_len);
863
-
864
- ggml_vk_norm_(spirv, "norm", std::forward<Args>(args)...);
865
- }
866
-
867
- template <typename... Args>
868
- static void ggml_vk_rms_norm(Args&&... args) {
869
- const static auto spirv = getSpirvShader(kp::shader_data::op_rmsnorm_comp_spv,
870
- kp::shader_data::op_rmsnorm_comp_spv_len);
871
-
872
- ggml_vk_norm_(spirv, "rms", std::forward<Args>(args)...);
873
- }
874
-
875
- static void ggml_vk_diag_mask_inf(kp::Sequence& seq,
876
- const std::shared_ptr<kp::Tensor>& in,
877
- const std::shared_ptr<kp::Tensor>& out,
878
- uint32_t inOff, uint32_t outOff,
879
- uint32_t n_past,
880
- int32_t ne00, int32_t ne01, int32_t ne02) {
881
- const static auto spirv = getSpirvShader(kp::shader_data::op_diagmask_comp_spv,
882
- kp::shader_data::op_diagmask_comp_spv_len);
883
-
884
- struct PushConstants {
885
- uint32_t inOff, outOff;
886
- uint32_t n_past;
887
- int32_t ne00, ne01;
888
- } pushConsts {
889
- safe_divide(inOff, 4), safe_divide(outOff, 4),
890
- n_past,
891
- ne00, ne01
892
- };
893
-
894
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
895
- if (!komputeManager()->hasAlgorithm(__func__))
896
- s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne00), unsigned(ne01), unsigned(ne02)}, {}, {pushConsts});
897
- else {
898
- s_algo = komputeManager()->getAlgorithm(__func__);
899
- s_algo->setTensors({in, out});
900
- s_algo->setWorkgroup({unsigned(ne00), unsigned(ne01), unsigned(ne02)});
901
- s_algo->setPushConstants<PushConstants>({pushConsts});
902
- s_algo->updateDescriptors(s_kompute_context->pool.get());
903
- }
904
- seq.record<kp::OpAlgoDispatch>(s_algo);
905
- }
906
-
907
- static void ggml_vk_mul_mat_f16(
908
- kp::Sequence& seq,
909
- const std::shared_ptr<kp::Tensor>& inA,
910
- const std::shared_ptr<kp::Tensor>& inB,
911
- const std::shared_ptr<kp::Tensor>& out,
912
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
913
- int32_t ne00, int32_t ne01, int32_t ne02,
914
- uint32_t nb00, uint32_t nb01, uint32_t nb02,
915
- int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
916
- uint32_t nb10, uint32_t nb11, uint32_t nb12,
917
- int32_t ne0, int32_t ne1,
918
- uint32_t r2, uint32_t r3
919
- ) {
920
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_f16_comp_spv,
921
- kp::shader_data::op_mul_mat_f16_comp_spv_len);
922
-
923
- struct PushConstants {
924
- uint32_t inAOff, inBOff, outOff;
925
- int32_t ne00, ne01, ne02;
926
- uint32_t nb00, nb01, nb02;
927
- int32_t ne10, ne11, ne12;
928
- uint32_t nb10, nb11, nb12;
929
- int32_t ne0, ne1;
930
- uint32_t r2, r3;
931
- } pushConsts {
932
- safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
933
- ne00, ne01, ne02,
934
- nb00, nb01, nb02,
935
- ne10, ne11, ne12,
936
- nb10, nb11, nb12,
937
- ne0, ne1,
938
- r2, r3
939
- };
940
-
941
- const unsigned ny = unsigned((ne11 + 4 - 1)/4);
942
-
943
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
944
- if (!komputeManager()->hasAlgorithm(__func__)) {
945
- const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
946
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), ny, unsigned(ne12*ne13)}, {local_x}, {pushConsts});
947
- } else {
948
- s_algo = komputeManager()->getAlgorithm(__func__);
949
- s_algo->setTensors({inA, inB, out});
950
- s_algo->setWorkgroup({unsigned(ne01), ny, unsigned(ne12*ne13)});
951
- s_algo->setPushConstants<PushConstants>({pushConsts});
952
- s_algo->updateDescriptors(s_kompute_context->pool.get());
953
- }
954
- seq.record<kp::OpAlgoDispatch>(s_algo);
955
- }
956
-
957
- static void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq,
958
- const std::shared_ptr<kp::Tensor>& inA,
959
- const std::shared_ptr<kp::Tensor>& inB,
960
- const std::shared_ptr<kp::Tensor>& out,
961
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
962
- int32_t ne00, int32_t ne01, int32_t ne02,
963
- uint32_t nb01, uint32_t nb02,
964
- int32_t ne11, int32_t ne12,
965
- uint32_t nb11, uint32_t nb12,
966
- uint32_t nb1, uint32_t nb2) {
967
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_mat_f32_comp_spv,
968
- kp::shader_data::op_mul_mat_mat_f32_comp_spv_len);
969
-
970
- struct PushConstants {
971
- uint32_t inAOff, inBOff, outOff;
972
- int32_t ne00, ne01, ne02, ne11, ne12;
973
- uint32_t nb01, nb02;
974
- uint32_t nb11, nb12;
975
- uint32_t nb1, nb2;
976
- } pushConsts {
977
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
978
- ne00, ne01, ne02, ne11, ne12,
979
- nb01, nb02, nb11, nb12,
980
- nb1, nb2
981
- };
982
-
983
- const uint32_t local_x = ggml_vk_current_device().subgroupSize;
984
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
985
- if (!komputeManager()->hasAlgorithm(__func__)) {
986
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(),
987
- {inA, inB, out}, spirv,
988
- {unsigned(ne01),
989
- unsigned(ne11),
990
- unsigned(std::max(ne12, ne02))
991
- },
992
- {local_x},
993
- {pushConsts});
994
- } else {
995
- s_algo = komputeManager()->getAlgorithm(__func__);
996
- s_algo->setTensors({inA, inB, out});
997
- s_algo->setWorkgroup({unsigned(ne01),
998
- unsigned(ne11),
999
- unsigned(std::max(ne12, ne02)),
1000
- });
1001
- s_algo->setPushConstants<PushConstants>({pushConsts});
1002
- s_algo->updateDescriptors(s_kompute_context->pool.get());
1003
- }
1004
- seq.record<kp::OpAlgoDispatch>(s_algo);
1005
- }
1006
-
1007
- static void ggml_vk_mul_mat_impl(
1008
- const std::vector<uint32_t>& spirv, const char * suffix, uint32_t block_size, kp::Sequence& seq,
1009
- const std::shared_ptr<kp::Tensor>& inA,
1010
- const std::shared_ptr<kp::Tensor>& inB,
1011
- const std::shared_ptr<kp::Tensor>& out,
1012
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1013
- int32_t ne00, int32_t ne01, int32_t ne02,
1014
- int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
1015
- int32_t ne0, int32_t ne1,
1016
- uint32_t r2, uint32_t r3
1017
- ) {
1018
- struct PushConstants {
1019
- uint32_t inAOff, inBOff, outOff;
1020
- int32_t ne00, ne01, ne02;
1021
- int32_t ne10, ne12;
1022
- int32_t ne0, ne1;
1023
- uint32_t r2, r3;
1024
- } pushConsts {
1025
- safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
1026
- ne00, ne01, ne02,
1027
- ne10, ne12,
1028
- ne0, ne1,
1029
- r2, r3
1030
- };
1031
-
1032
- auto name = std::string(__func__) + "_" + suffix;
1033
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1034
- if (!komputeManager()->hasAlgorithm(name)) {
1035
- const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
1036
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
1037
- } else {
1038
- s_algo = komputeManager()->getAlgorithm(name);
1039
- s_algo->setTensors({inA, inB, out});
1040
- s_algo->setWorkgroup({unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)});
1041
- s_algo->setPushConstants<PushConstants>({pushConsts});
1042
- s_algo->updateDescriptors(s_kompute_context->pool.get());
1043
- }
1044
- seq.record<kp::OpAlgoDispatch>(s_algo);
1045
- }
1046
-
1047
- template <typename... Args>
1048
- static void ggml_vk_mul_mat_q4_0(Args&&... args) {
1049
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_0_comp_spv,
1050
- kp::shader_data::op_mul_mat_q4_0_comp_spv_len);
1051
-
1052
- ggml_vk_mul_mat_impl(spirv, "q4_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
1053
- }
1054
-
1055
- template <typename... Args>
1056
- static void ggml_vk_mul_mat_q4_1(Args&&... args) {
1057
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_1_comp_spv,
1058
- kp::shader_data::op_mul_mat_q4_1_comp_spv_len);
1059
-
1060
- ggml_vk_mul_mat_impl(spirv, "q4_1", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
1061
- }
1062
-
1063
- template <typename... Args>
1064
- static void ggml_vk_mul_mat_q8_0(Args&&... args) {
1065
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q8_0_comp_spv,
1066
- kp::shader_data::op_mul_mat_q8_0_comp_spv_len);
1067
-
1068
- ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
1069
- }
1070
-
1071
- static void ggml_vk_mul_mat_q4_k(
1072
- kp::Sequence& seq,
1073
- const std::shared_ptr<kp::Tensor>& inA,
1074
- const std::shared_ptr<kp::Tensor>& inB,
1075
- const std::shared_ptr<kp::Tensor>& out,
1076
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1077
- int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
1078
- int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
1079
- int32_t ne1, int32_t r2, int32_t r3
1080
- ) {
1081
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
1082
- kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
1083
-
1084
- struct PushConstants {
1085
- uint32_t inAOff, inBOff, outOff;
1086
- int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
1087
- } pushConsts {
1088
- 0, 0, 0,
1089
- ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
1090
- };
1091
-
1092
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1093
- if (!komputeManager()->hasAlgorithm(__func__)) {
1094
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
1095
- } else {
1096
- s_algo = komputeManager()->getAlgorithm(__func__);
1097
- s_algo->setTensors({inA, inB, out});
1098
- s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
1099
- s_algo->setPushConstants<PushConstants>({pushConsts});
1100
- s_algo->updateDescriptors(s_kompute_context->pool.get());
1101
- }
1102
- seq.record<kp::OpAlgoDispatch>(s_algo);
1103
- }
1104
-
1105
- static void ggml_vk_mul_mat_q6_k(
1106
- kp::Sequence& seq,
1107
- const std::shared_ptr<kp::Tensor>& inA,
1108
- const std::shared_ptr<kp::Tensor>& inB,
1109
- const std::shared_ptr<kp::Tensor>& out,
1110
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1111
- int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
1112
- int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
1113
- ) {
1114
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
1115
- kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
1116
-
1117
- struct PushConstants {
1118
- uint32_t inAOff, inBOff, outOff;
1119
- int32_t ne00, ne10, ne0, ne1, ne01, gqa;
1120
- } pushConsts {
1121
- inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
1122
- ne00, ne10, ne0, ne1, ne01, ne12/ne02
1123
- };
1124
-
1125
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1126
- if (!komputeManager()->hasAlgorithm(__func__)) {
1127
- const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
1128
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
1129
- } else {
1130
- s_algo = komputeManager()->getAlgorithm(__func__);
1131
- s_algo->setTensors({inA, inB, out});
1132
- s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
1133
- s_algo->setPushConstants<PushConstants>({pushConsts});
1134
- s_algo->updateDescriptors(s_kompute_context->pool.get());
1135
- }
1136
- seq.record<kp::OpAlgoDispatch>(s_algo);
1137
- }
1138
-
1139
- static void ggml_vk_get_rows(
1140
- const std::vector<uint32_t>& spirv,
1141
- const char * suffix,
1142
- unsigned element_size, unsigned qk,
1143
- kp::Sequence& seq,
1144
- const std::shared_ptr<kp::Tensor>& inA,
1145
- const std::shared_ptr<kp::Tensor>& inB,
1146
- const std::shared_ptr<kp::Tensor>& out,
1147
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1148
- int32_t ne00, int32_t nb01, int32_t nb1,
1149
- uint32_t size
1150
- ) {
1151
- GGML_ASSERT(nb01%element_size == 0);
1152
- GGML_ASSERT(nb1%sizeof(float) == 0);
1153
- if (qk) GGML_ASSERT(ne00%qk == 0);
1154
-
1155
- struct PushConstants {
1156
- uint32_t inAOff, inBOff, outOff;
1157
- int32_t ne00, nb01, nb1;
1158
- } pushConsts {
1159
- safe_divide(inAOff, element_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
1160
- ne00, nb01, nb1
1161
- };
1162
-
1163
- auto name = std::string(__func__) + "_" + suffix;
1164
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1165
- if (!komputeManager()->hasAlgorithm(name)) {
1166
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
1167
- } else {
1168
- s_algo = komputeManager()->getAlgorithm(name);
1169
- s_algo->setTensors({inA, inB, out});
1170
- s_algo->setWorkgroup({size});
1171
- s_algo->setPushConstants<PushConstants>({pushConsts});
1172
- s_algo->updateDescriptors(s_kompute_context->pool.get());
1173
- }
1174
- seq.record<kp::OpAlgoDispatch>(s_algo);
1175
- }
1176
-
1177
- template <typename... Args>
1178
- static void ggml_vk_get_rows_f32(Args&&... args) {
1179
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
1180
- kp::shader_data::op_getrows_f32_comp_spv_len);
1181
-
1182
- ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
1183
- }
1184
-
1185
- template <typename... Args>
1186
- static void ggml_vk_get_rows_f16(Args&&... args) {
1187
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
1188
- kp::shader_data::op_getrows_f16_comp_spv_len);
1189
-
1190
- ggml_vk_get_rows(spirv, "f16", sizeof(half), 0, std::forward<Args>(args)...);
1191
- }
1192
-
1193
- template <typename... Args>
1194
- static void ggml_vk_get_rows_q4_0(Args&&... args) {
1195
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_0_comp_spv,
1196
- kp::shader_data::op_getrows_q4_0_comp_spv_len);
1197
-
1198
- ggml_vk_get_rows(spirv, "q4_0", 1/*We access blocks unaligned*/, QK4_0, std::forward<Args>(args)...);
1199
- }
1200
-
1201
- template <typename... Args>
1202
- static void ggml_vk_get_rows_q4_1(Args&&... args) {
1203
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_1_comp_spv,
1204
- kp::shader_data::op_getrows_q4_1_comp_spv_len);
1205
-
1206
- ggml_vk_get_rows(spirv, "q4_1", 1/*We access blocks unaligned*/, QK4_1, std::forward<Args>(args)...);
1207
- }
1208
-
1209
- template <typename... Args>
1210
- static void ggml_vk_get_rows_q6_k(Args&&... args) {
1211
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q6_k_comp_spv,
1212
- kp::shader_data::op_getrows_q6_k_comp_spv_len);
1213
- ggml_vk_get_rows(spirv, "q6_k", 1/*We access blocks unaligned*/, QK_NL, std::forward<Args>(args)...);
1214
- }
1215
-
1216
- static void ggml_vk_rope(
1217
- kp::Sequence& seq,
1218
- const std::shared_ptr<kp::Tensor>& inA,
1219
- const std::shared_ptr<kp::Tensor>& inB,
1220
- const std::shared_ptr<kp::Tensor>& out,
1221
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1222
- ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
1223
- float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
1224
- int32_t ne01, int32_t ne02, int32_t ne03,
1225
- uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
1226
- int32_t ne0,
1227
- uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
1228
- ) {
1229
- GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
1230
-
1231
- static const auto spirv_f16 = getSpirvShader(
1232
- kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
1233
- );
1234
- static const auto spirv_f32 = getSpirvShader(
1235
- kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
1236
- );
1237
-
1238
- int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
1239
-
1240
- GGML_ASSERT(nb03 % type_size == 0);
1241
- GGML_ASSERT(nb02 % type_size == 0);
1242
- GGML_ASSERT(nb01 % type_size == 0);
1243
- GGML_ASSERT(nb00 % type_size == 0);
1244
- GGML_ASSERT(nb3 % type_size == 0);
1245
- GGML_ASSERT(nb2 % type_size == 0);
1246
- GGML_ASSERT(nb1 % type_size == 0);
1247
- GGML_ASSERT(nb0 % type_size == 0);
1248
-
1249
- struct PushConstants {
1250
- uint32_t inAOff, inBOff, outOff;
1251
- int32_t n_dims, mode, n_ctx_orig;
1252
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1253
- uint32_t nb00, nb01, nb02, nb03;
1254
- int32_t ne0;
1255
- uint32_t nb0, nb1, nb2, nb3;
1256
- } pushConsts {
1257
- safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
1258
- n_dims, mode, n_ctx_orig,
1259
- freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
1260
- nb00, nb01, nb02, nb03,
1261
- ne0,
1262
- nb0, nb1, nb2, nb3
1263
- };
1264
-
1265
- auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
1266
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1267
- if (!komputeManager()->hasAlgorithm(name)) {
1268
- s_algo = komputeManager()->algorithm<float, PushConstants>(
1269
- name, s_kompute_context->pool.get(), {inA, inB, out},
1270
- src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
1271
- {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
1272
- );
1273
- } else {
1274
- s_algo = komputeManager()->getAlgorithm(name);
1275
- s_algo->setTensors({inA, inB, out});
1276
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
1277
- s_algo->setPushConstants<PushConstants>({pushConsts});
1278
- s_algo->updateDescriptors(s_kompute_context->pool.get());
1279
- }
1280
- seq.record<kp::OpAlgoDispatch>(s_algo);
1281
- }
1282
-
1283
- static void ggml_vk_cpy(
1284
- const std::vector<uint32_t>& spirv,
1285
- uint32_t in_element_size, uint32_t out_element_size,
1286
- kp::Sequence& seq,
1287
- const std::shared_ptr<kp::Tensor>& in,
1288
- const std::shared_ptr<kp::Tensor>& out,
1289
- uint32_t inOff, uint32_t outOff,
1290
- int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
1291
- uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
1292
- int32_t ne0, int32_t ne1, int32_t ne2,
1293
- uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
1294
- ) {
1295
- struct PushConstants {
1296
- uint32_t inOff, outOff;
1297
- int32_t ne00, ne01, ne02;
1298
- uint32_t nb00, nb01, nb02, nb03;
1299
- int32_t ne0, ne1, ne2;
1300
- uint32_t nb0, nb1, nb2, nb3;
1301
- } pushConsts {
1302
- safe_divide(inOff, in_element_size), safe_divide(outOff, out_element_size),
1303
- ne00, ne01, ne02,
1304
- nb00, nb01, nb02, nb03,
1305
- ne0, ne1, ne2,
1306
- nb0, nb1, nb2, nb3
1307
- };
1308
-
1309
- std::string name = std::string(__func__)
1310
- + "_i_" + std::to_string(in_element_size)
1311
- + "_o_" + std::to_string(out_element_size);
1312
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1313
- if (!komputeManager()->hasAlgorithm(name))
1314
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
1315
- else {
1316
- s_algo = komputeManager()->getAlgorithm(name);
1317
- s_algo->setTensors({in, out});
1318
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
1319
- s_algo->setPushConstants<PushConstants>({pushConsts});
1320
- s_algo->updateDescriptors(s_kompute_context->pool.get());
1321
- }
1322
- seq.record<kp::OpAlgoDispatch>(s_algo);
1323
- }
1324
-
1325
- template <typename... Args>
1326
- static void ggml_vk_cpy_f32_f16(Args&&... args) {
1327
- const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f16_comp_spv,
1328
- kp::shader_data::op_cpy_f32_f16_comp_spv_len);
1329
- ggml_vk_cpy(spirv, 4, 2, std::forward<Args>(args)...);
1330
- }
1331
-
1332
- template <typename... Args>
1333
- static void ggml_vk_cpy_f32_f32(Args&&... args) {
1334
- const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f32_comp_spv,
1335
- kp::shader_data::op_cpy_f32_f32_comp_spv_len);
1336
- ggml_vk_cpy(spirv, 4, 4, std::forward<Args>(args)...);
1337
- }
1338
-
1339
- template <typename... Args>
1340
- static void ggml_vk_cpy_f16_f16(Args&&... args) {
1341
- const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f16_comp_spv,
1342
- kp::shader_data::op_cpy_f16_f16_comp_spv_len);
1343
- ggml_vk_cpy(spirv, 2, 2, std::forward<Args>(args)...);
1344
- }
1345
-
1346
- template <typename... Args>
1347
- static void ggml_vk_cpy_f16_f32(Args&&... args) {
1348
- const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f32_comp_spv,
1349
- kp::shader_data::op_cpy_f16_f32_comp_spv_len);
1350
- ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
1351
- }
1352
-
1353
- static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1354
- switch (op->op) {
1355
- case GGML_OP_UNARY:
1356
- switch (ggml_get_unary_op(op)) {
1357
- case GGML_UNARY_OP_RELU:
1358
- case GGML_UNARY_OP_GELU:
1359
- case GGML_UNARY_OP_SILU:
1360
- return ggml_is_contiguous(op->src[0]);
1361
- default:
1362
- ;
1363
- }
1364
- break;
1365
- case GGML_OP_NONE:
1366
- case GGML_OP_RESHAPE:
1367
- case GGML_OP_VIEW:
1368
- case GGML_OP_TRANSPOSE:
1369
- case GGML_OP_PERMUTE:
1370
- case GGML_OP_ADD:
1371
- case GGML_OP_MUL:
1372
- case GGML_OP_SCALE:
1373
- case GGML_OP_SOFT_MAX:
1374
- case GGML_OP_RMS_NORM:
1375
- case GGML_OP_NORM:
1376
- case GGML_OP_ROPE:
1377
- return true;
1378
- case GGML_OP_DUP:
1379
- case GGML_OP_CPY:
1380
- case GGML_OP_CONT:
1381
- switch (op->src[0]->type) {
1382
- case GGML_TYPE_F32:
1383
- case GGML_TYPE_F16:
1384
- break;
1385
- default:
1386
- return false;
1387
- }
1388
- switch (op->type) {
1389
- case GGML_TYPE_F32:
1390
- case GGML_TYPE_F16:
1391
- break;
1392
- default:
1393
- return false;
1394
- }
1395
- return true;
1396
- case GGML_OP_DIAG_MASK_INF:
1397
- return op->ne[3] == 1;
1398
- case GGML_OP_GET_ROWS:
1399
- switch (op->src[0]->type) {
1400
- case GGML_TYPE_F32:
1401
- case GGML_TYPE_F16:
1402
- case GGML_TYPE_Q4_0:
1403
- case GGML_TYPE_Q4_1:
1404
- case GGML_TYPE_Q6_K:
1405
- return op->ne[2] == 1 && op->ne[3] == 1;
1406
- default:
1407
- ;
1408
- }
1409
- return false;
1410
- case GGML_OP_MUL_MAT:
1411
- if (op->src[1]->type != GGML_TYPE_F32 || ggml_is_transposed(op->src[0]) || ggml_is_transposed(op->src[1]))
1412
- return false;
1413
-
1414
- switch (op->src[0]->type) {
1415
- case GGML_TYPE_F32:
1416
- case GGML_TYPE_Q6_K:
1417
- return op->ne[3] == 1;
1418
- case GGML_TYPE_F16:
1419
- case GGML_TYPE_Q8_0:
1420
- case GGML_TYPE_Q4_0:
1421
- case GGML_TYPE_Q4_1:
1422
- case GGML_TYPE_Q4_K:
1423
- return true;
1424
- default:
1425
- ;
1426
- }
1427
- default:
1428
- ;
1429
- }
1430
- return false;
1431
-
1432
- GGML_UNUSED(dev);
1433
- }
1434
-
1435
- static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
1436
- const int n_seq = 8;
1437
-
1438
- // FIXME: Figure out if we can somehow optimize the size of the pool... right now we're setting
1439
- // it to the size of the graph, but I think it can be made smaller?
1440
- ggml_vk_allocate_descriptor_pool(ctx, gf->n_nodes);
1441
-
1442
- std::vector<std::shared_ptr<kp::Sequence>> sequences(n_seq);
1443
-
1444
- for (auto& sequence : sequences) {
1445
- sequence = komputeManager()->sequence();
1446
- }
1447
- for (int seq_idx = 0; seq_idx < n_seq; ++seq_idx) {
1448
- const int n_nodes_per_seq = (gf->n_nodes + n_seq - 1) / n_seq;
1449
-
1450
- auto& seq = *sequences[seq_idx];
1451
-
1452
- const int node_start = (seq_idx + 0) * n_nodes_per_seq;
1453
- const int node_end = std::min((seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq, gf->n_nodes);
1454
-
1455
- bool any_commands_recorded = false;
1456
-
1457
- for (int i = node_start; i < node_end; ++i) {
1458
- struct ggml_tensor * src0 = gf->nodes[i]->src[0];
1459
- struct ggml_tensor * src1 = gf->nodes[i]->src[1];
1460
- struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
1461
- struct ggml_tensor * dst = gf->nodes[i];
1462
- GGML_ASSERT(dst->data != nullptr);
1463
-
1464
- if (ggml_is_empty(dst)) {
1465
- continue;
1466
- }
1467
-
1468
- switch (dst->op) {
1469
- case GGML_OP_NONE:
1470
- case GGML_OP_RESHAPE:
1471
- case GGML_OP_VIEW:
1472
- case GGML_OP_TRANSPOSE:
1473
- case GGML_OP_PERMUTE:
1474
- continue; // noop -> next node
1475
- default:
1476
- break;
1477
- }
1478
-
1479
- any_commands_recorded = true;
1480
-
1481
- const int32_t ne00 = src0 ? src0->ne[0] : 0;
1482
- const int32_t ne01 = src0 ? src0->ne[1] : 0;
1483
- const int32_t ne02 = src0 ? src0->ne[2] : 0;
1484
- const int32_t ne03 = src0 ? src0->ne[3] : 0;
1485
-
1486
- const uint32_t nb00 = src0 ? src0->nb[0] : 0;
1487
- const uint32_t nb01 = src0 ? src0->nb[1] : 0;
1488
- const uint32_t nb02 = src0 ? src0->nb[2] : 0;
1489
- const uint32_t nb03 = src0 ? src0->nb[3] : 0;
1490
-
1491
- const int32_t ne10 = src1 ? src1->ne[0] : 0;
1492
- const int32_t ne11 = src1 ? src1->ne[1] : 0;
1493
- const int32_t ne12 = src1 ? src1->ne[2] : 0;
1494
- const int32_t ne13 = src1 ? src1->ne[3] : 0;
1495
-
1496
- const uint32_t nb10 = src1 ? src1->nb[0] : 0;
1497
- const uint32_t nb11 = src1 ? src1->nb[1] : 0;
1498
- const uint32_t nb12 = src1 ? src1->nb[2] : 0;
1499
- const uint32_t nb13 = src1 ? src1->nb[3] : 0;
1500
-
1501
- const int32_t ne0 = dst ? dst->ne[0] : 0;
1502
- const int32_t ne1 = dst ? dst->ne[1] : 0;
1503
- const int32_t ne2 = dst ? dst->ne[2] : 0;
1504
- // const int32_t ne3 = dst ? dst->ne[3] : 0;
1505
-
1506
- const uint32_t nb0 = dst ? dst->nb[0] : 0;
1507
- const uint32_t nb1 = dst ? dst->nb[1] : 0;
1508
- const uint32_t nb2 = dst ? dst->nb[2] : 0;
1509
- const uint32_t nb3 = dst ? dst->nb[3] : 0;
1510
-
1511
- const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
1512
- const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
1513
- const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
1514
-
1515
- const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
1516
- uint32_t off_src0 = 0;
1517
- uint32_t off_src1 = 0;
1518
- uint32_t off_dst = 0;
1519
- const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
1520
- const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
1521
- const std::shared_ptr<kp::Tensor>& id_dst = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor;
1522
-
1523
- switch (dst->op) {
1524
- case GGML_OP_ADD:
1525
- {
1526
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1527
- // src1 is a row
1528
- ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
1529
- } else {
1530
- ggml_vk_add(
1531
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1532
- ne00, ne01, ne02, ne03,
1533
- nb00, nb01, nb02, nb03,
1534
- ne10, ne11, ne12, ne13,
1535
- nb10, nb11, nb12, nb13,
1536
- ne0,
1537
- nb0, nb1, nb2, nb3
1538
- );
1539
- }
1540
- } break;
1541
- case GGML_OP_MUL:
1542
- {
1543
- ggml_vk_mul(
1544
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1545
- ne00, ne01, ne02, ne03,
1546
- nb00, nb01, nb02, nb03,
1547
- ne10, ne11, ne12, ne13,
1548
- nb10, nb11, nb12, nb13,
1549
- ne0,
1550
- nb0, nb1, nb2, nb3
1551
- );
1552
- } break;
1553
- case GGML_OP_SCALE:
1554
- {
1555
- float scale; memcpy(&scale, dst->op_params, sizeof(float));
1556
-
1557
- ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
1558
- } break;
1559
- case GGML_OP_UNARY:
1560
- {
1561
- int64_t n = ggml_nelements(dst);
1562
- GGML_ASSERT(n % 4 == 0);
1563
- switch (ggml_get_unary_op(gf->nodes[i])) {
1564
- case GGML_UNARY_OP_SILU:
1565
- {
1566
- ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
1567
- } break;
1568
- case GGML_UNARY_OP_RELU:
1569
- {
1570
- ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
1571
- } break;
1572
- case GGML_UNARY_OP_GELU:
1573
- {
1574
- GGML_ASSERT(n % 8 == 0);
1575
- ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, n/8);
1576
- } break;
1577
- default:
1578
- {
1579
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1580
- GGML_ABORT("fatal error");
1581
- }
1582
- }
1583
- } break;
1584
- case GGML_OP_SOFT_MAX:
1585
- {
1586
- float scale;
1587
- float max_bias;
1588
-
1589
- memcpy(&scale, (float *)dst->op_params + 0, sizeof(float));
1590
- memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
1591
-
1592
- #pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
1593
- #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
1594
- GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
1595
-
1596
- #pragma message("TODO: add ALiBi support")
1597
- #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
1598
- GGML_ASSERT(max_bias == 0.0f);
1599
-
1600
- ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
1601
- } break;
1602
- case GGML_OP_DIAG_MASK_INF:
1603
- {
1604
- const int n_past = ((int32_t *)(dst->op_params))[0];
1605
- ggml_vk_diag_mask_inf(seq, id_src0, id_dst, off_src0, off_dst, n_past, ne00, ne01, ne02);
1606
- } break;
1607
- case GGML_OP_NORM:
1608
- {
1609
- float eps;
1610
- memcpy(&eps, dst->op_params, sizeof(float));
1611
- ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
1612
- } break;
1613
- case GGML_OP_RMS_NORM:
1614
- {
1615
- GGML_ASSERT(ne00 % 4 == 0);
1616
-
1617
- float eps;
1618
- memcpy(&eps, dst->op_params, sizeof(float));
1619
- ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
1620
- } break;
1621
- case GGML_OP_MUL_MAT:
1622
- {
1623
- GGML_ASSERT(ne00 == ne10);
1624
-
1625
- GGML_ASSERT(ne12 % ne02 == 0);
1626
- GGML_ASSERT(ne13 % ne03 == 0);
1627
-
1628
- const uint32_t r2 = ne12/ne02;
1629
- const uint32_t r3 = ne13/ne03;
1630
-
1631
- if (src1t != GGML_TYPE_F32) {
1632
- fprintf(stderr, "%s: %s: Unsupported src1 type: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
1633
- goto not_implemented;
1634
- }
1635
-
1636
- if (ggml_is_transposed(src0) ||
1637
- ggml_is_transposed(src1)) {
1638
- fprintf(stderr, "%s: %s: matmul on tranposed tensor not supported: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
1639
- goto not_implemented;
1640
- }
1641
-
1642
- switch (src0t) {
1643
- case GGML_TYPE_F32:
1644
- ggml_vk_mul_mat_mat_f32(
1645
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1646
- ne00, ne01, ne02, nb01, nb02, ne11, ne12, nb11, nb12, nb1, nb2
1647
- );
1648
- break;
1649
- case GGML_TYPE_F16:
1650
- ggml_vk_mul_mat_f16(
1651
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1652
- ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12,
1653
- ne0, ne1, r2, r3
1654
- );
1655
- break;
1656
- case GGML_TYPE_Q8_0:
1657
- ggml_vk_mul_mat_q8_0(
1658
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1659
- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
1660
- );
1661
- break;
1662
- case GGML_TYPE_Q4_0:
1663
- ggml_vk_mul_mat_q4_0(
1664
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1665
- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
1666
- );
1667
- break;
1668
- case GGML_TYPE_Q4_1:
1669
- ggml_vk_mul_mat_q4_1(
1670
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1671
- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
1672
- );
1673
- break;
1674
- case GGML_TYPE_Q4_K:
1675
- ggml_vk_mul_mat_q4_k(
1676
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1677
- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
1678
- );
1679
- break;
1680
- case GGML_TYPE_Q6_K:
1681
- ggml_vk_mul_mat_q6_k(
1682
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1683
- ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02
1684
- );
1685
- break;
1686
- default: {
1687
- fprintf(stderr, "%s: %s: Unsupported quantization: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
1688
- goto not_implemented;
1689
- }
1690
- }
1691
-
1692
- } break;
1693
- case GGML_OP_GET_ROWS:
1694
- {
1695
- if (src0t == GGML_TYPE_F32) {
1696
- ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1697
- } else if (src0t == GGML_TYPE_F16) {
1698
- ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1699
- } else if (src0t == GGML_TYPE_Q4_0) {
1700
- ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1701
- } else if (src0t == GGML_TYPE_Q4_1) {
1702
- ggml_vk_get_rows_q4_1(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1703
- } else if (src0t == GGML_TYPE_Q6_K) {
1704
- ggml_vk_get_rows_q6_k(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1705
- } else {
1706
- fprintf(stderr, "%s: %s: Unsupported quantization: %u\n", __func__, ggml_op_name(dst->op), src0t);
1707
- goto not_implemented;
1708
- }
1709
- } break;
1710
- case GGML_OP_ROPE:
1711
- {
1712
- #pragma message("TODO: implement phi3 frequency factors support")
1713
- #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
1714
- GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
1715
-
1716
- #pragma message("TODO: update rope NORM mode to match NEOX mode")
1717
- #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
1718
-
1719
- GGML_ASSERT(ne10 == ne02);
1720
- GGML_ASSERT(src0t == dstt);
1721
- // const int n_past = ((int32_t *) dst->op_params)[0];
1722
- const int n_dims = ((int32_t *) dst->op_params)[1];
1723
- const int mode = ((int32_t *) dst->op_params)[2];
1724
- // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
1725
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
1726
-
1727
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1728
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
1729
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
1730
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
1731
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
1732
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
1733
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1734
- ggml_vk_rope(
1735
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
1736
- freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
1737
- ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
1738
- );
1739
- } break;
1740
- case GGML_OP_DUP:
1741
- case GGML_OP_CPY:
1742
- case GGML_OP_CONT:
1743
- {
1744
- switch (src0t) {
1745
- case GGML_TYPE_F32:
1746
- {
1747
- switch (dstt) {
1748
- case GGML_TYPE_F16: ggml_vk_cpy_f32_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
1749
- case GGML_TYPE_F32: ggml_vk_cpy_f32_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
1750
- default: goto not_implemented;
1751
- }
1752
- } break;
1753
- case GGML_TYPE_F16:
1754
- {
1755
- switch (dstt) {
1756
- case GGML_TYPE_F16: ggml_vk_cpy_f16_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
1757
- case GGML_TYPE_F32: ggml_vk_cpy_f16_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
1758
- default: goto not_implemented;
1759
- } break;
1760
- default: goto not_implemented;
1761
- }
1762
- }
1763
- } break;
1764
- default: goto not_implemented;
1765
- }
1766
- continue;
1767
- not_implemented: {}
1768
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1769
- //GGML_ABORT("fatal error");
1770
- }
1771
-
1772
- // Evaluate sequence
1773
- if (any_commands_recorded) {
1774
- seq.evalAsync();
1775
- }
1776
- }
1777
-
1778
- // Wait for all sequences to finish
1779
- for (auto& sequence : sequences) {
1780
- if (sequence->isRunning())
1781
- sequence->evalAwait();
1782
- }
1783
-
1784
- ggml_vk_free_descriptor_pool(ctx);
1785
- }
1786
-
1787
- template<>
1788
- kp::Tensor::TensorDataTypes
1789
- kp::TensorT<half>::dataType()
1790
- {
1791
- return TensorDataTypes::eFloat;
1792
- }
1793
-
1794
- template<>
1795
- kp::Tensor::TensorDataTypes
1796
- kp::TensorT<uint8_t>::dataType()
1797
- {
1798
- return TensorDataTypes::eUnsignedInt;
1799
- }
1800
-
1801
- ////////////////////////////////////////////////////////////////////////////////
1802
-
1803
- // backend interface
1804
-
1805
- struct ggml_backend_kompute_buffer_type_context {
1806
- int device;
1807
- int device_ref = 0;
1808
- uint64_t buffer_alignment;
1809
- uint64_t max_alloc;
1810
- std::string name;
1811
-
1812
- ggml_backend_kompute_buffer_type_context(int device, uint64_t buffer_alignment, uint64_t max_alloc)
1813
- : device(device), buffer_alignment(buffer_alignment), max_alloc(max_alloc), name(ggml_kompute_format_name(device)) {}
1814
- };
1815
-
1816
- static void ggml_backend_kompute_device_ref(ggml_backend_buffer_type_t buft) {
1817
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
1818
-
1819
- if (!ctx->device_ref) {
1820
- komputeManager()->initializeDevice(
1821
- ctx->device, {}, {
1822
- "VK_KHR_shader_float16_int8", "VK_KHR_8bit_storage",
1823
- "VK_KHR_16bit_storage", "VK_KHR_shader_non_semantic_info"
1824
- }
1825
- );
1826
- }
1827
-
1828
- assert(ggml_vk_has_device());
1829
- ctx->device_ref++;
1830
- }
1831
-
1832
- static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) {
1833
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
1834
-
1835
- assert(ctx->device_ref > 0);
1836
-
1837
- ctx->device_ref--;
1838
-
1839
- if (!ctx->device_ref) {
1840
- komputeManager.destroy();
1841
- }
1842
- }
1843
-
1844
- static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1845
- auto * memory = (ggml_vk_memory *)buffer->context;
1846
- if (ggml_vk_has_device()) {
1847
- ggml_vk_free_memory(*memory);
1848
- }
1849
- delete memory;
1850
- }
1851
-
1852
- static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) {
1853
- return ((ggml_vk_memory *)buffer->context)->data;
1854
- }
1855
-
1856
- static void ggml_backend_kompute_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1857
- GGML_UNUSED(buffer);
1858
-
1859
- const auto res = ggml_vk_get_tensor(tensor);
1860
- GGML_ASSERT(res);
1861
-
1862
- memcpy((char *)tensor->data + offset, data, size);
1863
-
1864
- komputeManager()->sequence()->eval<kp::OpTensorSyncDevice>({res});
1865
- }
1866
-
1867
- static void ggml_backend_kompute_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1868
- GGML_UNUSED(buffer);
1869
-
1870
- const auto res = ggml_vk_get_tensor(tensor);
1871
- GGML_ASSERT(res);
1872
-
1873
- komputeManager()->sequence()->eval<kp::OpTensorSyncLocal>({res});
1874
-
1875
- memcpy(data, (const char *)tensor->data + offset, size);
1876
- }
1877
-
1878
- static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1879
- auto * memory = (ggml_vk_memory *)buffer->context;
1880
- memset(memory->data, value, buffer->size);
1881
-
1882
- if (memory->stagingBuffer)
1883
- komputeManager()->sequence()->eval<kp::OpBufferSyncDevice>(memory->primaryBuffer, memory->stagingBuffer, memory->size);
1884
- }
1885
-
1886
- static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
1887
- /* .free_buffer = */ ggml_backend_kompute_buffer_free_buffer,
1888
- /* .get_base = */ ggml_backend_kompute_buffer_get_base,
1889
- /* .init_tensor = */ NULL,
1890
- /* .memset_tensor = */ NULL,
1891
- /* .set_tensor = */ ggml_backend_kompute_buffer_set_tensor,
1892
- /* .get_tensor = */ ggml_backend_kompute_buffer_get_tensor,
1893
- /* .cpy_tensor = */ NULL,
1894
- /* .clear = */ ggml_backend_kompute_buffer_clear,
1895
- /* .reset = */ NULL,
1896
- };
1897
-
1898
- // default buffer type
1899
-
1900
- static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
1901
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
1902
- return ctx->name.c_str();
1903
- }
1904
-
1905
- static ggml_backend_buffer_t ggml_backend_kompute_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1906
- ggml_backend_kompute_device_ref(buft);
1907
- auto * ctx = new ggml_vk_memory(ggml_vk_allocate(size));
1908
- return ggml_backend_buffer_init(buft, ggml_backend_kompute_buffer_i, ctx, size);
1909
- }
1910
-
1911
- static size_t ggml_backend_kompute_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
1912
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
1913
- return ctx->buffer_alignment;
1914
- }
1915
-
1916
- static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
1917
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
1918
- return ctx->max_alloc;
1919
- }
1920
-
1921
- static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
1922
- /* .get_name = */ ggml_backend_kompute_buffer_type_get_name,
1923
- /* .alloc_buffer = */ ggml_backend_kompute_buffer_type_alloc_buffer,
1924
- /* .get_alignment = */ ggml_backend_kompute_buffer_type_get_alignment,
1925
- /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size,
1926
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
1927
- /* .is_host = */ NULL,
1928
- };
1929
-
1930
- ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
1931
- static std::mutex mutex;
1932
- std::lock_guard<std::mutex> lock(mutex);
1933
-
1934
- auto devices = ggml_vk_available_devices();
1935
- int32_t device_count = (int32_t) devices.size();
1936
- GGML_ASSERT(device < device_count);
1937
- GGML_ASSERT(devices.size() <= GGML_KOMPUTE_MAX_DEVICES);
1938
-
1939
- static ggml_backend_buffer_type
1940
- ggml_backend_kompute_buffer_types[GGML_KOMPUTE_MAX_DEVICES];
1941
-
1942
- static bool ggml_backend_kompute_buffer_type_initialized = false;
1943
-
1944
- if (!ggml_backend_kompute_buffer_type_initialized) {
1945
- for (int32_t i = 0; i < device_count; i++) {
1946
- ggml_backend_kompute_buffer_types[i] = {
1947
- /* .iface = */ ggml_backend_kompute_buffer_type_interface,
1948
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), i),
1949
- /* .context = */ new ggml_backend_kompute_buffer_type_context{ i, devices[i].bufferAlignment, devices[i].maxAlloc },
1950
- };
1951
- }
1952
- ggml_backend_kompute_buffer_type_initialized = true;
1953
- }
1954
-
1955
- return &ggml_backend_kompute_buffer_types[device];
1956
- }
1957
-
1958
- // backend
1959
-
1960
- static const char * ggml_backend_kompute_name(ggml_backend_t backend) {
1961
- auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
1962
- return ctx->name.c_str();
1963
- }
1964
-
1965
- static void ggml_backend_kompute_free(ggml_backend_t backend) {
1966
- auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
1967
-
1968
- assert(ctx == s_kompute_context);
1969
- s_kompute_context = nullptr;
1970
- if (ctx != nullptr) {
1971
- delete ctx;
1972
- }
1973
-
1974
- delete backend;
1975
- }
1976
-
1977
- static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1978
- auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
1979
- ggml_vk_graph_compute(ctx, cgraph);
1980
- return GGML_STATUS_SUCCESS;
1981
- }
1982
-
1983
- static struct ggml_backend_i kompute_backend_i = {
1984
- /* .get_name = */ ggml_backend_kompute_name,
1985
- /* .free = */ ggml_backend_kompute_free,
1986
- /* .set_tensor_async = */ NULL,
1987
- /* .get_tensor_async = */ NULL,
1988
- /* .cpy_tensor_async = */ NULL,
1989
- /* .synchronize = */ NULL,
1990
- /* .graph_plan_create = */ NULL,
1991
- /* .graph_plan_free = */ NULL,
1992
- /* .graph_plan_update = */ NULL,
1993
- /* .graph_plan_compute = */ NULL,
1994
- /* .graph_compute = */ ggml_backend_kompute_graph_compute,
1995
- /* .event_record = */ NULL,
1996
- /* .event_wait = */ NULL,
1997
- };
1998
-
1999
- static ggml_guid_t ggml_backend_kompute_guid() {
2000
- static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
2001
- return &guid;
2002
- }
2003
-
2004
- ggml_backend_t ggml_backend_kompute_init(int device) {
2005
- GGML_ASSERT(s_kompute_context == nullptr);
2006
- s_kompute_context = new ggml_kompute_context(device);
2007
-
2008
- ggml_backend_t kompute_backend = new ggml_backend {
2009
- /* .guid = */ ggml_backend_kompute_guid(),
2010
- /* .interface = */ kompute_backend_i,
2011
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), device),
2012
- /* .context = */ s_kompute_context,
2013
- };
2014
-
2015
- return kompute_backend;
2016
- }
2017
-
2018
- bool ggml_backend_is_kompute(ggml_backend_t backend) {
2019
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
2020
- }
2021
-
2022
- static size_t ggml_backend_kompute_get_device_count() {
2023
- auto devices = ggml_vk_available_devices();
2024
- return devices.size();
2025
- }
2026
-
2027
- static void ggml_backend_kompute_get_device_description(int device, char * description, size_t description_size) {
2028
- auto devices = ggml_vk_available_devices();
2029
- GGML_ASSERT((size_t) device < devices.size());
2030
- snprintf(description, description_size, "%s", devices[device].name);
2031
- }
2032
-
2033
- static void ggml_backend_kompute_get_device_memory(int device, size_t * free, size_t * total) {
2034
- auto devices = ggml_vk_available_devices();
2035
- GGML_ASSERT((size_t) device < devices.size());
2036
- *total = devices[device].heapSize;
2037
- *free = devices[device].heapSize;
2038
- }
2039
-
2040
- //////////////////////////
2041
-
2042
- struct ggml_backend_kompute_device_context {
2043
- int device;
2044
- std::string name;
2045
- std::string description;
2046
- };
2047
-
2048
- static const char * ggml_backend_kompute_device_get_name(ggml_backend_dev_t dev) {
2049
- ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
2050
- return ctx->name.c_str();
2051
- }
2052
-
2053
- static const char * ggml_backend_kompute_device_get_description(ggml_backend_dev_t dev) {
2054
- ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
2055
- return ctx->description.c_str();
2056
- }
2057
-
2058
- static void ggml_backend_kompute_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
2059
- ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
2060
- ggml_backend_kompute_get_device_memory(ctx->device, free, total);
2061
- }
2062
-
2063
- static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type(ggml_backend_dev_t dev) {
2064
- ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
2065
- return ggml_backend_kompute_buffer_type(ctx->device);
2066
- }
2067
-
2068
- static bool ggml_backend_kompute_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
2069
- if (buft->iface.get_name != ggml_backend_kompute_buffer_type_get_name) {
2070
- return false;
2071
- }
2072
-
2073
- ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
2074
- ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context;
2075
-
2076
- return buft_ctx->device == ctx->device;
2077
- }
2078
-
2079
- static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) {
2080
- GGML_UNUSED(dev);
2081
- return GGML_BACKEND_DEVICE_TYPE_GPU;
2082
- }
2083
-
2084
- static void ggml_backend_kompute_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
2085
- props->name = ggml_backend_kompute_device_get_name(dev);
2086
- props->description = ggml_backend_kompute_device_get_description(dev);
2087
- props->type = ggml_backend_kompute_device_get_type(dev);
2088
- ggml_backend_kompute_device_get_memory(dev, &props->memory_free, &props->memory_total);
2089
- props->caps = {
2090
- /* async = */ false,
2091
- /* host_buffer = */ false,
2092
- /* .buffer_from_host_ptr = */ false,
2093
- /* events = */ false,
2094
- };
2095
- }
2096
-
2097
- static ggml_backend_t ggml_backend_kompute_device_init(ggml_backend_dev_t dev, const char * params) {
2098
- GGML_UNUSED(params);
2099
- ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
2100
- return ggml_backend_kompute_init(ctx->device);
2101
- }
2102
-
2103
- static bool ggml_backend_kompute_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
2104
- const int min_batch_size = 32;
2105
-
2106
- return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
2107
- (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
2108
-
2109
- GGML_UNUSED(dev);
2110
- }
2111
-
2112
- static const struct ggml_backend_device_i ggml_backend_kompute_device_i = {
2113
- /* .get_name = */ ggml_backend_kompute_device_get_name,
2114
- /* .get_description = */ ggml_backend_kompute_device_get_description,
2115
- /* .get_memory = */ ggml_backend_kompute_device_get_memory,
2116
- /* .get_type = */ ggml_backend_kompute_device_get_type,
2117
- /* .get_props = */ ggml_backend_kompute_device_get_props,
2118
- /* .init_backend = */ ggml_backend_kompute_device_init,
2119
- /* .get_buffer_type = */ ggml_backend_kompute_device_get_buffer_type,
2120
- /* .get_host_buffer_type = */ NULL,
2121
- /* .buffer_from_host_ptr = */ NULL,
2122
- /* .supports_op = */ ggml_backend_kompute_device_supports_op,
2123
- /* .supports_buft = */ ggml_backend_kompute_device_supports_buft,
2124
- /* .offload_op = */ ggml_backend_kompute_device_offload_op,
2125
- /* .event_new = */ NULL,
2126
- /* .event_free = */ NULL,
2127
- /* .event_synchronize = */ NULL,
2128
- };
2129
-
2130
- static const char * ggml_backend_kompute_reg_get_name(ggml_backend_reg_t reg) {
2131
- GGML_UNUSED(reg);
2132
- return "Kompute";
2133
- }
2134
-
2135
- static size_t ggml_backend_kompute_reg_get_device_count(ggml_backend_reg_t reg) {
2136
- GGML_UNUSED(reg);
2137
- return ggml_backend_kompute_get_device_count();
2138
- }
2139
-
2140
- static ggml_backend_dev_t ggml_backend_kompute_reg_get_device(ggml_backend_reg_t reg, size_t device) {
2141
- static std::vector<ggml_backend_dev_t> devices;
2142
-
2143
- static bool initialized = false;
2144
-
2145
- {
2146
- static std::mutex mutex;
2147
- std::lock_guard<std::mutex> lock(mutex);
2148
- if (!initialized) {
2149
- for (size_t i = 0; i < ggml_backend_kompute_get_device_count(); i++) {
2150
- ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context;
2151
- char desc[256];
2152
- ggml_backend_kompute_get_device_description(i, desc, sizeof(desc));
2153
- ctx->device = i;
2154
- ctx->name = "Kompute" + std::to_string(i);
2155
- ctx->description = desc;
2156
- devices.push_back(new ggml_backend_device {
2157
- /* .iface = */ ggml_backend_kompute_device_i,
2158
- /* .reg = */ reg,
2159
- /* .context = */ ctx,
2160
- });
2161
- }
2162
- initialized = true;
2163
- }
2164
- }
2165
-
2166
- GGML_ASSERT(device < devices.size());
2167
- return devices[device];
2168
- }
2169
-
2170
- static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
2171
- /* .get_name = */ ggml_backend_kompute_reg_get_name,
2172
- /* .get_device_count = */ ggml_backend_kompute_reg_get_device_count,
2173
- /* .get_device = */ ggml_backend_kompute_reg_get_device,
2174
- /* .get_proc_address = */ NULL,
2175
- };
2176
-
2177
- ggml_backend_reg_t ggml_backend_kompute_reg() {
2178
- static ggml_backend_reg reg = {
2179
- /* .iface = */ ggml_backend_kompute_reg_i,
2180
- /* .context = */ nullptr,
2181
- };
2182
-
2183
- return &reg;
2184
- }
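
The deleted ggml-kompute.cpp above ends by exposing the backend through the generic ggml-backend registry (ggml_backend_kompute_reg plus the device and backend interfaces). For orientation only, here is a minimal caller-side sketch of that registry path; it is not part of this commit, and it assumes the public ggml-backend API (ggml_backend_reg_dev_count, ggml_backend_dev_init, ggml_backend_graph_compute, ggml_backend_free) and at least one Vulkan-capable device being enumerated:

    // illustrative sketch: drive the Kompute backend through the generic registry
    #include "ggml-backend.h"
    #include "ggml-kompute.h"

    static ggml_backend_t init_first_kompute_device(void) {
        ggml_backend_reg_t reg = ggml_backend_kompute_reg();        // registry defined in the file above
        if (ggml_backend_reg_dev_count(reg) == 0) {
            return nullptr;                                         // no Vulkan/Kompute device enumerated
        }
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0);  // first device
        // dispatches to ggml_backend_kompute_device_init / ggml_backend_kompute_init above
        return ggml_backend_dev_init(dev, /*params =*/ nullptr);
    }

    // typical use: allocate the graph tensors in a buffer obtained from
    // ggml_backend_get_default_buffer_type(backend), then
    //   ggml_backend_graph_compute(backend, graph);
    //   ggml_backend_free(backend);
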
ggml/src/ggml-metal.m DELETED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-metal.metal DELETED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -3651,7 +3651,7 @@ static enum ggml_status ggml_metal_graph_compute(
3651
  dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
3652
 
3653
  // wait for completion and check status of each command buffer
3654
- // needed to detect if the device ran out-of-memory for example (ggml/1881)
3655
  {
3656
  id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
3657
  [command_buffer waitUntilCompleted];
 
3651
  dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
3652
 
3653
  // wait for completion and check status of each command buffer
3654
+ // needed to detect if the device ran out-of-memory for example (#1881)
3655
  {
3656
  id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
3657
  [command_buffer waitUntilCompleted];
ggml/src/ggml-musa/CMakeLists.txt ADDED
@@ -0,0 +1,100 @@
1
+ if (NOT EXISTS $ENV{MUSA_PATH})
2
+ if (NOT EXISTS /opt/musa)
3
+ set(MUSA_PATH /usr/local/musa)
4
+ else()
5
+ set(MUSA_PATH /opt/musa)
6
+ endif()
7
+ else()
8
+ set(MUSA_PATH $ENV{MUSA_PATH})
9
+ endif()
10
+
11
+ set(CMAKE_C_COMPILER "${MUSA_PATH}/bin/clang")
12
+ set(CMAKE_C_EXTENSIONS OFF)
13
+ set(CMAKE_CXX_COMPILER "${MUSA_PATH}/bin/clang++")
14
+ set(CMAKE_CXX_EXTENSIONS OFF)
15
+
16
+ list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
17
+
18
+ find_package(MUSAToolkit)
19
+
20
+ if (MUSAToolkit_FOUND)
21
+ message(STATUS "MUSA Toolkit found")
22
+
23
+ file(GLOB GGML_HEADERS_MUSA "../ggml-cuda/*.cuh")
24
+ list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
25
+
26
+ file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
27
+ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
28
+ list(APPEND GGML_SOURCES_MUSA ${SRCS})
29
+ file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
30
+ list(APPEND GGML_SOURCES_MUSA ${SRCS})
31
+
32
+ if (GGML_CUDA_FA_ALL_QUANTS)
33
+ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
34
+ list(APPEND GGML_SOURCES_MUSA ${SRCS})
35
+ add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
36
+ else()
37
+ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
38
+ list(APPEND GGML_SOURCES_MUSA ${SRCS})
39
+ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
40
+ list(APPEND GGML_SOURCES_MUSA ${SRCS})
41
+ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
42
+ list(APPEND GGML_SOURCES_MUSA ${SRCS})
43
+ endif()
44
+
45
+ set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)
46
+ foreach(SOURCE ${GGML_SOURCES_MUSA})
47
+ set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22")
48
+ endforeach()
49
+
50
+ add_library(ggml-musa
51
+ ${GGML_HEADERS_MUSA}
52
+ ${GGML_SOURCES_MUSA})
53
+
54
+ target_link_libraries(ggml-musa PRIVATE ggml-base)
55
+ target_include_directories(ggml-musa PRIVATE . ..)
56
+
57
+ # TODO: do not use CUDA definitions for MUSA
58
+ target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
59
+
60
+ add_compile_definitions(GGML_USE_MUSA)
61
+ add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
62
+
63
+ if (GGML_CUDA_GRAPHS)
64
+ add_compile_definitions(GGML_CUDA_USE_GRAPHS)
65
+ endif()
66
+
67
+ if (GGML_CUDA_FORCE_MMQ)
68
+ add_compile_definitions(GGML_CUDA_FORCE_MMQ)
69
+ endif()
70
+
71
+ if (GGML_CUDA_FORCE_CUBLAS)
72
+ add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
73
+ endif()
74
+
75
+ if (GGML_CUDA_NO_VMM)
76
+ add_compile_definitions(GGML_CUDA_NO_VMM)
77
+ endif()
78
+
79
+ if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
80
+ add_compile_definitions(GGML_CUDA_F16)
81
+ endif()
82
+
83
+ if (GGML_CUDA_NO_PEER_COPY)
84
+ add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
85
+ endif()
86
+
87
+ if (GGML_STATIC)
88
+ target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
89
+ else()
90
+ target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
91
+ endif()
92
+
93
+ if (GGML_CUDA_NO_VMM)
94
+ # No VMM requested, no need to link directly with the musa driver lib (libmusa.so)
95
+ else()
96
+ target_link_libraries(ggml-musa PRIVATE MUSA::musa_driver)
97
+ endif()
98
+ else()
99
+ message(FATAL_ERROR "MUSA Toolkit not found")
100
+ endif()
ggml/src/ggml-rpc.cpp DELETED
@@ -1,1403 +0,0 @@
1
- #include "ggml-rpc.h"
2
- #include "ggml-impl.h"
3
- #include "ggml-backend-impl.h"
4
-
5
- #include <cinttypes>
6
- #include <string>
7
- #include <vector>
8
- #include <memory>
9
- #include <mutex>
10
- #include <unordered_map>
11
- #include <unordered_set>
12
- #ifdef _WIN32
13
- # define WIN32_LEAN_AND_MEAN
14
- # ifndef NOMINMAX
15
- # define NOMINMAX
16
- # endif
17
- # include <windows.h>
18
- # include <winsock2.h>
19
- #else
20
- # include <arpa/inet.h>
21
- # include <sys/socket.h>
22
- # include <sys/types.h>
23
- # include <netinet/in.h>
24
- # include <netinet/tcp.h>
25
- # include <netdb.h>
26
- # include <unistd.h>
27
- #endif
28
- #include <cstring>
29
-
30
- #define UNUSED GGML_UNUSED
31
-
32
- #define GGML_DEBUG 0
33
- #if (GGML_DEBUG >= 1)
34
- #define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
35
- #else
36
- #define GGML_PRINT_DEBUG(...)
37
- #endif
38
-
39
- #ifdef _WIN32
40
- typedef SOCKET sockfd_t;
41
- using ssize_t = __int64;
42
- #else
43
- typedef int sockfd_t;
44
- #endif
45
-
46
- // cross-platform socket
47
- struct socket_t {
48
- sockfd_t fd;
49
- socket_t(sockfd_t fd) : fd(fd) {}
50
- ~socket_t() {
51
- GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
52
- #ifdef _WIN32
53
- closesocket(this->fd);
54
- #else
55
- close(this->fd);
56
- #endif
57
- }
58
- };
59
-
60
- // all RPC structures must be packed
61
- #pragma pack(push, 1)
62
- // ggml_tensor is serialized into rpc_tensor
63
- struct rpc_tensor {
64
- uint64_t id;
65
- uint32_t type;
66
- uint64_t buffer;
67
- uint32_t ne[GGML_MAX_DIMS];
68
- uint32_t nb[GGML_MAX_DIMS];
69
- uint32_t op;
70
- int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
71
- int32_t flags;
72
- uint64_t src[GGML_MAX_SRC];
73
- uint64_t view_src;
74
- uint64_t view_offs;
75
- uint64_t data;
76
- char name[GGML_MAX_NAME];
77
-
78
- char padding[4];
79
- };
80
-
81
- static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
82
-
83
- // RPC commands
84
- enum rpc_cmd {
85
- RPC_CMD_ALLOC_BUFFER = 0,
86
- RPC_CMD_GET_ALIGNMENT,
87
- RPC_CMD_GET_MAX_SIZE,
88
- RPC_CMD_BUFFER_GET_BASE,
89
- RPC_CMD_FREE_BUFFER,
90
- RPC_CMD_BUFFER_CLEAR,
91
- RPC_CMD_SET_TENSOR,
92
- RPC_CMD_GET_TENSOR,
93
- RPC_CMD_COPY_TENSOR,
94
- RPC_CMD_GRAPH_COMPUTE,
95
- RPC_CMD_GET_DEVICE_MEMORY,
96
- RPC_CMD_COUNT,
97
- };
98
-
99
- struct rpc_msg_alloc_buffer_req {
100
- uint64_t size;
101
- };
102
-
103
- struct rpc_msg_alloc_buffer_rsp {
104
- uint64_t remote_ptr;
105
- uint64_t remote_size;
106
- };
107
-
108
- struct rpc_msg_get_alignment_rsp {
109
- uint64_t alignment;
110
- };
111
-
112
- struct rpc_msg_get_max_size_rsp {
113
- uint64_t max_size;
114
- };
115
-
116
- struct rpc_msg_buffer_get_base_req {
117
- uint64_t remote_ptr;
118
- };
119
-
120
- struct rpc_msg_buffer_get_base_rsp {
121
- uint64_t base_ptr;
122
- };
123
-
124
- struct rpc_msg_free_buffer_req {
125
- uint64_t remote_ptr;
126
- };
127
-
128
- struct rpc_msg_buffer_clear_req {
129
- uint64_t remote_ptr;
130
- uint8_t value;
131
- };
132
-
133
- struct rpc_msg_get_tensor_req {
134
- rpc_tensor tensor;
135
- uint64_t offset;
136
- uint64_t size;
137
- };
138
-
139
- struct rpc_msg_copy_tensor_req {
140
- rpc_tensor src;
141
- rpc_tensor dst;
142
- };
143
-
144
- struct rpc_msg_copy_tensor_rsp {
145
- uint8_t result;
146
- };
147
-
148
- struct rpc_msg_graph_compute_rsp {
149
- uint8_t result;
150
- };
151
-
152
- struct rpc_msg_get_device_memory_rsp {
153
- uint64_t free_mem;
154
- uint64_t total_mem;
155
- };
156
- #pragma pack(pop)
157
-
158
- // RPC data structures
159
-
160
- static ggml_guid_t ggml_backend_rpc_guid() {
161
- static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
162
- return &guid;
163
- }
164
-
165
- struct ggml_backend_rpc_buffer_type_context {
166
- std::string endpoint;
167
- std::string name;
168
- size_t alignment;
169
- size_t max_size;
170
- };
171
-
172
- struct ggml_backend_rpc_context {
173
- std::string endpoint;
174
- std::string name;
175
- };
176
-
177
- struct ggml_backend_rpc_buffer_context {
178
- std::shared_ptr<socket_t> sock;
179
- std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
180
- uint64_t remote_ptr;
181
- };
182
-
183
- // RPC helper functions
184
-
185
- static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
186
- #ifdef _WIN32
187
- if (fd == INVALID_SOCKET) {
188
- return nullptr;
189
- }
190
- #else
191
- if (fd < 0) {
192
- return nullptr;
193
- }
194
- #endif
195
- return std::make_shared<socket_t>(fd);
196
- }
197
-
198
- static bool set_no_delay(sockfd_t sockfd) {
199
- int flag = 1;
200
- // set TCP_NODELAY to disable Nagle's algorithm
201
- int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
202
- return ret == 0;
203
- }
204
-
205
- static bool set_reuse_addr(sockfd_t sockfd) {
206
- int flag = 1;
207
- int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
208
- return ret == 0;
209
- }
210
-
211
- static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
212
- struct sockaddr_in addr;
213
- auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
214
- auto sock_ptr = make_socket(sockfd);
215
- if (sock_ptr == nullptr) {
216
- return nullptr;
217
- }
218
- if (!set_no_delay(sockfd)) {
219
- fprintf(stderr, "Failed to set TCP_NODELAY\n");
220
- return nullptr;
221
- }
222
- addr.sin_family = AF_INET;
223
- addr.sin_port = htons(port);
224
- struct hostent * server = gethostbyname(host);
225
- if (server == NULL) {
226
- fprintf(stderr, "Cannot resolve host '%s'\n", host);
227
- return nullptr;
228
- }
229
- memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
230
- if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
231
- return nullptr;
232
- }
233
- return sock_ptr;
234
- }
235
-
236
- static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
237
- auto client_socket_fd = accept(srv_sockfd, NULL, NULL);
238
- auto client_socket = make_socket(client_socket_fd);
239
- if (client_socket == nullptr) {
240
- return nullptr;
241
- }
242
- if (!set_no_delay(client_socket_fd)) {
243
- fprintf(stderr, "Failed to set TCP_NODELAY\n");
244
- return nullptr;
245
- }
246
- return client_socket;
247
- }
248
-
249
- static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
250
- auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
251
- auto sock = make_socket(sockfd);
252
- if (sock == nullptr) {
253
- return nullptr;
254
- }
255
- if (!set_reuse_addr(sockfd)) {
256
- fprintf(stderr, "Failed to set SO_REUSEADDR\n");
257
- return nullptr;
258
- }
259
- if (inet_addr(host) == INADDR_NONE) {
260
- fprintf(stderr, "Invalid host address: %s\n", host);
261
- return nullptr;
262
- }
263
- struct sockaddr_in serv_addr;
264
- serv_addr.sin_family = AF_INET;
265
- serv_addr.sin_addr.s_addr = inet_addr(host);
266
- serv_addr.sin_port = htons(port);
267
-
268
- if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
269
- return nullptr;
270
- }
271
- if (listen(sockfd, 1) < 0) {
272
- return nullptr;
273
- }
274
- return sock;
275
- }
276
-
277
- static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
278
- size_t bytes_sent = 0;
279
- while (bytes_sent < size) {
280
- ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0);
281
- if (n < 0) {
282
- return false;
283
- }
284
- bytes_sent += n;
285
- }
286
- return true;
287
- }
288
-
289
- static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
290
- size_t bytes_recv = 0;
291
- while (bytes_recv < size) {
292
- ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0);
293
- if (n <= 0) {
294
- return false;
295
- }
296
- bytes_recv += n;
297
- }
298
- return true;
299
- }
300
-
301
- static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
302
- if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
303
- return false;
304
- }
305
- return send_data(sockfd, msg, msg_size);
306
- }
307
-
308
- static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
309
- uint64_t size;
310
- if (!recv_data(sockfd, &size, sizeof(size))) {
311
- return false;
312
- }
313
- if (size != msg_size) {
314
- return false;
315
- }
316
- return recv_data(sockfd, msg, msg_size);
317
- }
318
-
319
- static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
320
- uint64_t size;
321
- if (!recv_data(sockfd, &size, sizeof(size))) {
322
- return false;
323
- }
324
- try {
325
- input.resize(size);
326
- } catch (const std::bad_alloc & e) {
327
- fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
328
- return false;
329
- }
330
- return recv_data(sockfd, input.data(), size);
331
- }
332
-
333
- static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
334
- size_t pos = endpoint.find(':');
335
- if (pos == std::string::npos) {
336
- return false;
337
- }
338
- host = endpoint.substr(0, pos);
339
- port = std::stoi(endpoint.substr(pos + 1));
340
- return true;
341
- }
342
-
343
- // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
344
- // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
345
- static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
346
- uint8_t cmd_byte = cmd;
347
- if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
348
- return false;
349
- }
350
- if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
351
- return false;
352
- }
353
- if (!send_data(sock->fd, input, input_size)) {
354
- return false;
355
- }
356
- // TODO: currently the output_size is always known, do we need support for commands with variable output size?
357
- // even if we do, we can skip sending output_size from the server for commands with known output size
358
- uint64_t out_size;
359
- if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
360
- return false;
361
- }
362
- if (out_size != output_size) {
363
- return false;
364
- }
365
- if (!recv_data(sock->fd, output, output_size)) {
366
- return false;
367
- }
368
- return true;
369
- }
370
-
371
- // RPC client-side implementation
372
-
373
- static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
374
- static std::mutex mutex;
375
- std::lock_guard<std::mutex> lock(mutex);
376
- static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
377
- static bool initialized = false;
378
-
379
- auto it = sockets.find(endpoint);
380
- if (it != sockets.end()) {
381
- if (auto sock = it->second.lock()) {
382
- return sock;
383
- }
384
- }
385
- std::string host;
386
- int port;
387
- if (!parse_endpoint(endpoint, host, port)) {
388
- return nullptr;
389
- }
390
- #ifdef _WIN32
391
- if (!initialized) {
392
- WSADATA wsaData;
393
- int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
394
- if (res != 0) {
395
- return nullptr;
396
- }
397
- initialized = true;
398
- }
399
- #else
400
- UNUSED(initialized);
401
- #endif
402
- auto sock = socket_connect(host.c_str(), port);
403
- if (sock == nullptr) {
404
- return nullptr;
405
- }
406
- GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
407
- sockets[endpoint] = sock;
408
- return sock;
409
- }
410
-
411
- static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
412
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
413
- rpc_msg_free_buffer_req request = {ctx->remote_ptr};
414
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
415
- GGML_ASSERT(status);
416
- delete ctx;
417
- }
418
-
419
- static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
420
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
421
- if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
422
- return ctx->base_cache[buffer];
423
- }
424
- rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
425
- rpc_msg_buffer_get_base_rsp response;
426
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
427
- GGML_ASSERT(status);
428
- void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
429
- ctx->base_cache[buffer] = base_ptr;
430
- return base_ptr;
431
- }
432
-
433
- static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
434
- rpc_tensor result;
435
- result.id = reinterpret_cast<uint64_t>(tensor);
436
- result.type = tensor->type;
437
- if (tensor->buffer) {
438
- ggml_backend_buffer_t buffer = tensor->buffer;
439
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
440
- result.buffer = ctx->remote_ptr;
441
- } else {
442
- result.buffer = 0;
443
- }
444
- for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
445
- result.ne[i] = tensor->ne[i];
446
- result.nb[i] = tensor->nb[i];
447
- }
448
- result.op = tensor->op;
449
- for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
450
- result.op_params[i] = tensor->op_params[i];
451
- }
452
- result.flags = tensor->flags;
453
- for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
454
- result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
455
- }
456
- result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
457
- result.view_offs = tensor->view_offs;
458
- result.data = reinterpret_cast<uint64_t>(tensor->data);
459
- snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
460
- return result;
461
- }
462
-
463
- static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
464
- UNUSED(buffer);
465
- if (ggml_is_quantized(tensor->type)) {
466
- // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
467
- GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
468
- }
469
- }
470
-
471
- static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
472
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
473
- // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
474
- size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
475
- std::vector<uint8_t> input(input_size, 0);
476
- rpc_tensor rpc_tensor = serialize_tensor(tensor);
477
- memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
478
- memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
479
- memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
480
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
481
- GGML_ASSERT(status);
482
- }
483
-
484
- static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
485
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
486
- rpc_msg_get_tensor_req request;
487
- request.tensor = serialize_tensor(tensor);
488
- request.offset = offset;
489
- request.size = size;
490
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
491
- GGML_ASSERT(status);
492
- }
493
-
494
- static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
495
- // check if src and dst are on the same server
496
- ggml_backend_buffer_t src_buffer = src->buffer;
497
- ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
498
- ggml_backend_buffer_t dst_buffer = dst->buffer;
499
- ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
500
- if (src_ctx->sock != dst_ctx->sock) {
501
- return false;
502
- }
503
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
504
- rpc_msg_copy_tensor_req request;
505
- request.src = serialize_tensor(src);
506
- request.dst = serialize_tensor(dst);
507
- rpc_msg_copy_tensor_rsp response;
508
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
509
- GGML_ASSERT(status);
510
- return response.result;
511
- }
512
-
513
- static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
514
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
515
- rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
516
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
517
- GGML_ASSERT(status);
518
- }
519
-
520
- static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
521
- /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
522
- /* .get_base = */ ggml_backend_rpc_buffer_get_base,
523
- /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
524
- /* .memset_tensor = */ NULL,
525
- /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor,
526
- /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor,
527
- /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor,
528
- /* .clear = */ ggml_backend_rpc_buffer_clear,
529
- /* .reset = */ NULL,
530
- };
531
-
532
- static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
533
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
534
- return buft_ctx->name.c_str();
535
- }
536
-
537
- static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
538
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
539
- rpc_msg_alloc_buffer_req request = {size};
540
- rpc_msg_alloc_buffer_rsp response;
541
- auto sock = get_socket(buft_ctx->endpoint);
542
- bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
543
- GGML_ASSERT(status);
544
- if (response.remote_ptr != 0) {
545
- ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
546
- ggml_backend_rpc_buffer_interface,
547
- new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
548
- response.remote_size);
549
- return buffer;
550
- } else {
551
- return nullptr;
552
- }
553
- }
554
-
555
- static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
556
- rpc_msg_get_alignment_rsp response;
557
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
558
- GGML_ASSERT(status);
559
- return response.alignment;
560
- }
561
-
562
- static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
563
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
564
- return buft_ctx->alignment;
565
- }
566
-
567
- static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
568
- rpc_msg_get_max_size_rsp response;
569
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
570
- GGML_ASSERT(status);
571
- return response.max_size;
572
- }
573
-
574
- static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
575
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
576
- return buft_ctx->max_size;
577
- }
578
-
579
- static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
580
- UNUSED(buft);
581
- return ggml_nbytes(tensor);
582
- }
583
-
584
- static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
585
- /* .get_name = */ ggml_backend_rpc_buffer_type_name,
586
- /* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer,
587
- /* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment,
588
- /* .get_max_size = */ ggml_backend_rpc_get_max_size,
589
- /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size,
590
- /* .is_host = */ NULL,
591
- };
592
-
593
- static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
594
- ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
595
-
596
- return rpc_ctx->name.c_str();
597
- }
598
-
599
- static void ggml_backend_rpc_free(ggml_backend_t backend) {
600
- ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
601
- delete rpc_ctx;
602
- delete backend;
603
- }
604
-
605
- static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
606
- UNUSED(backend);
607
- // this is no-op because we don't have any async operations
608
- }
609
-
610
- static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
611
- if (tensor == nullptr) {
612
- return;
613
- }
614
- if (visited.find(tensor) != visited.end()) {
615
- return;
616
- }
617
- visited.insert(tensor);
618
- for (int i = 0; i < GGML_MAX_SRC; i++) {
619
- add_tensor(tensor->src[i], tensors, visited);
620
- }
621
- add_tensor(tensor->view_src, tensors, visited);
622
- tensors.push_back(serialize_tensor(tensor));
623
- }
624
-
625
- static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
626
- uint32_t n_nodes = cgraph->n_nodes;
627
- std::vector<rpc_tensor> tensors;
628
- std::unordered_set<ggml_tensor*> visited;
629
- for (uint32_t i = 0; i < n_nodes; i++) {
630
- add_tensor(cgraph->nodes[i], tensors, visited);
631
- }
632
- // serialization format:
633
- // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
634
- uint32_t n_tensors = tensors.size();
635
- int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
636
- output.resize(output_size, 0);
637
- memcpy(output.data(), &n_nodes, sizeof(n_nodes));
638
- for (uint32_t i = 0; i < n_nodes; i++) {
639
- memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
640
- }
641
- uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
642
- *out_ntensors = n_tensors;
643
- rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
644
- memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
645
- }
646
-
647
- static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
648
- ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
649
- std::vector<uint8_t> input;
650
- serialize_graph(cgraph, input);
651
- rpc_msg_graph_compute_rsp response;
652
- auto sock = get_socket(rpc_ctx->endpoint);
653
- bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
654
- GGML_ASSERT(status);
655
- return (enum ggml_status)response.result;
656
- }
657
-
658
- static ggml_backend_i ggml_backend_rpc_interface = {
659
- /* .get_name = */ ggml_backend_rpc_name,
660
- /* .free = */ ggml_backend_rpc_free,
661
- /* .set_tensor_async = */ NULL,
662
- /* .get_tensor_async = */ NULL,
663
- /* .cpy_tensor_async = */ NULL,
664
- /* .synchronize = */ ggml_backend_rpc_synchronize,
665
- /* .graph_plan_create = */ NULL,
666
- /* .graph_plan_free = */ NULL,
667
- /* .graph_plan_update = */ NULL,
668
- /* .graph_plan_compute = */ NULL,
669
- /* .graph_compute = */ ggml_backend_rpc_graph_compute,
670
- /* .event_record = */ NULL,
671
- /* .event_wait = */ NULL,
672
- };
673
-
674
- GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
675
- static std::mutex mutex;
676
- std::lock_guard<std::mutex> lock(mutex);
677
- // NOTE: buffer types are allocated and never freed; this is by design
678
- static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
679
- auto it = buft_map.find(endpoint);
680
- if (it != buft_map.end()) {
681
- return it->second;
682
- }
683
- auto sock = get_socket(endpoint);
684
- if (sock == nullptr) {
685
- fprintf(stderr, "Failed to connect to %s\n", endpoint);
686
- return nullptr;
687
- }
688
- size_t alignment = get_alignment(sock);
689
- size_t max_size = get_max_size(sock);
690
- ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
691
- /* .endpoint = */ endpoint,
692
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
693
- /* .alignment = */ alignment,
694
- /* .max_size = */ max_size
695
- };
696
-
697
- ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
698
- /* .iface = */ ggml_backend_rpc_buffer_type_interface,
699
- /* .device = */ ggml_backend_rpc_add_device(endpoint),
700
- /* .context = */ buft_ctx
701
- };
702
- buft_map[endpoint] = buft;
703
- return buft;
704
- }
705
-
706
- ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
707
- ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
708
- /* .endpoint = */ endpoint,
709
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
710
- };
711
-
712
- ggml_backend_t backend = new ggml_backend {
713
- /* .guid = */ ggml_backend_rpc_guid(),
714
- /* .interface = */ ggml_backend_rpc_interface,
715
- /* .device = */ ggml_backend_rpc_add_device(endpoint),
716
- /* .context = */ ctx
717
- };
718
- return backend;
719
- }
720
-
721
- GGML_API bool ggml_backend_is_rpc(ggml_backend_t backend) {
722
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
723
- }
724
-
725
- static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
726
- rpc_msg_get_device_memory_rsp response;
727
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
728
- GGML_ASSERT(status);
729
- *free = response.free_mem;
730
- *total = response.total_mem;
731
- }
732
-
733
- GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
734
- auto sock = get_socket(endpoint);
735
- if (sock == nullptr) {
736
- *free = 0;
737
- *total = 0;
738
- return;
739
- }
740
- get_device_memory(sock, free, total);
741
- }
742
-
743
- // RPC server-side implementation
744
-
745
- class rpc_server {
746
- public:
747
- rpc_server(ggml_backend_t backend) : backend(backend) {}
748
- ~rpc_server();
749
-
750
- void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
751
- void get_alignment(rpc_msg_get_alignment_rsp & response);
752
- void get_max_size(rpc_msg_get_max_size_rsp & response);
753
- bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
754
- bool free_buffer(const rpc_msg_free_buffer_req & request);
755
- bool buffer_clear(const rpc_msg_buffer_clear_req & request);
756
- bool set_tensor(const std::vector<uint8_t> & input);
757
- bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
758
- bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
759
- bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
760
-
761
- private:
762
- ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
763
- ggml_tensor * create_node(uint64_t id,
764
- struct ggml_context * ctx,
765
- const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
766
- std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
767
-
768
-
769
- ggml_backend_t backend;
770
- std::unordered_set<ggml_backend_buffer_t> buffers;
771
- };
772
-
773
- void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
774
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
775
- ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
776
- response.remote_ptr = 0;
777
- response.remote_size = 0;
778
- if (buffer != nullptr) {
779
- response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
780
- response.remote_size = buffer->size;
781
- GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
782
- buffers.insert(buffer);
783
- } else {
784
- GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
785
- }
786
- }
787
-
788
- void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
789
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
790
- size_t alignment = ggml_backend_buft_get_alignment(buft);
791
- GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
792
- response.alignment = alignment;
793
- }
794
-
795
- void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
796
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
797
- size_t max_size = ggml_backend_buft_get_max_size(buft);
798
- GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
799
- response.max_size = max_size;
800
- }
801
-
802
- bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
803
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
804
- ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
805
- if (buffers.find(buffer) == buffers.end()) {
806
- GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
807
- return false;
808
- }
809
- void * base = ggml_backend_buffer_get_base(buffer);
810
- response.base_ptr = reinterpret_cast<uint64_t>(base);
811
- return true;
812
- }
813
-
814
- bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
815
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
816
- ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
817
- if (buffers.find(buffer) == buffers.end()) {
818
- GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
819
- return false;
820
- }
821
- ggml_backend_buffer_free(buffer);
822
- buffers.erase(buffer);
823
- return true;
824
- }
825
-
826
- bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
827
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
828
- ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
829
- if (buffers.find(buffer) == buffers.end()) {
830
- GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
831
- return false;
832
- }
833
- ggml_backend_buffer_clear(buffer, request.value);
834
- return true;
835
- }
836
-
837
- ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
838
- ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
839
- tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
840
- for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
841
- result->nb[i] = tensor->nb[i];
842
- }
843
- result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
844
- if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
845
- result->buffer = nullptr;
846
- }
847
-
848
- if (result->buffer) {
849
- // require that the tensor data does not go beyond the buffer end
850
- uint64_t tensor_size = (uint64_t) ggml_nbytes(result);
851
- uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
852
- uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
853
- GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow
854
- GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);
855
- }
856
-
857
- result->op = (ggml_op) tensor->op;
858
- for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
859
- result->op_params[i] = tensor->op_params[i];
860
- }
861
- result->flags = tensor->flags;
862
- result->data = reinterpret_cast<void *>(tensor->data);
863
- ggml_set_name(result, tensor->name);
864
- return result;
865
- }
866
-
867
-
868
- bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
869
- // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
870
- if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
871
- return false;
872
- }
873
- const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
874
- uint64_t offset;
875
- memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
876
- const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
877
-
878
- struct ggml_init_params params {
879
- /*.mem_size =*/ ggml_tensor_overhead(),
880
- /*.mem_buffer =*/ NULL,
881
- /*.no_alloc =*/ true,
882
- };
883
- struct ggml_context * ctx = ggml_init(params);
884
- ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
885
- if (tensor == nullptr) {
886
- GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
887
- ggml_free(ctx);
888
- return false;
889
- }
890
- GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
891
-
892
- // sanitize tensor->data
893
- {
894
- const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
895
- const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
896
-
897
- if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
898
- GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
899
- }
900
- }
901
-
902
- const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
903
- ggml_backend_tensor_set(tensor, data, offset, size);
904
- ggml_free(ctx);
905
- return true;
906
- }
907
-
908
- bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
909
- struct ggml_init_params params {
910
- /*.mem_size =*/ ggml_tensor_overhead(),
911
- /*.mem_buffer =*/ NULL,
912
- /*.no_alloc =*/ true,
913
- };
914
- struct ggml_context * ctx = ggml_init(params);
915
- ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
916
- if (tensor == nullptr) {
917
- GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
918
- ggml_free(ctx);
919
- return false;
920
- }
921
- GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
922
-
923
- // sanitize tensor->data
924
- {
925
- const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
926
- const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
927
-
928
- if (request.tensor.data + request.offset < p0 ||
929
- request.tensor.data + request.offset >= p1 ||
930
- request.size > (p1 - request.tensor.data - request.offset)) {
931
- GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
932
- }
933
- }
934
-
935
- response.resize(request.size, 0);
936
- ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
937
- ggml_free(ctx);
938
- return true;
939
- }
940
-
941
- bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
942
- struct ggml_init_params params {
943
- /*.mem_size =*/ 2*ggml_tensor_overhead(),
944
- /*.mem_buffer =*/ NULL,
945
- /*.no_alloc =*/ true,
946
- };
947
- struct ggml_context * ctx = ggml_init(params);
948
- ggml_tensor * src = deserialize_tensor(ctx, &request.src);
949
- ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
950
- if (src == nullptr || dst == nullptr) {
951
- GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
952
- ggml_free(ctx);
953
- return false;
954
- }
955
- GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
956
- response.result = ggml_backend_buffer_copy_tensor(src, dst);
957
- ggml_free(ctx);
958
- return true;
959
- }
960
-
961
- ggml_tensor * rpc_server::create_node(uint64_t id,
962
- struct ggml_context * ctx,
963
- const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
964
- std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
965
- if (id == 0) {
966
- return nullptr;
967
- }
968
- if (tensor_map.find(id) != tensor_map.end()) {
969
- return tensor_map[id];
970
- }
971
- const rpc_tensor * tensor = tensor_ptrs.at(id);
972
- struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
973
- if (result == nullptr) {
974
- return nullptr;
975
- }
976
- tensor_map[id] = result;
977
- for (int i = 0; i < GGML_MAX_SRC; i++) {
978
- result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
979
- }
980
- result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
981
- result->view_offs = tensor->view_offs;
982
- return result;
983
- }
984
-
985
- bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
986
- // serialization format:
987
- // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
988
- if (input.size() < sizeof(uint32_t)) {
989
- return false;
990
- }
991
- uint32_t n_nodes;
992
- memcpy(&n_nodes, input.data(), sizeof(n_nodes));
993
- if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
994
- return false;
995
- }
996
- const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
997
- uint32_t n_tensors;
998
- memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
999
- if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
1000
- return false;
1001
- }
1002
- const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
1003
- GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
1004
-
1005
- size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
1006
- struct ggml_init_params params = {
1007
- /*.mem_size =*/ buf_size,
1008
- /*.mem_buffer =*/ NULL,
1009
- /*.no_alloc =*/ true,
1010
- };
1011
- struct ggml_context * ctx = ggml_init(params);
1012
- struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
1013
- graph->n_nodes = n_nodes;
1014
- std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
1015
- for (uint32_t i = 0; i < n_tensors; i++) {
1016
- tensor_ptrs[tensors[i].id] = &tensors[i];
1017
- }
1018
- std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
1019
- for (uint32_t i = 0; i < n_nodes; i++) {
1020
- int64_t id;
1021
- memcpy(&id, &nodes[i], sizeof(id));
1022
- graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
1023
- }
1024
- ggml_status status = ggml_backend_graph_compute(backend, graph);
1025
- response.result = status;
1026
- ggml_free(ctx);
1027
- return true;
1028
- }
1029
-
1030
- rpc_server::~rpc_server() {
1031
- for (auto buffer : buffers) {
1032
- ggml_backend_buffer_free(buffer);
1033
- }
1034
- }
1035
-
1036
- static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1037
- rpc_server server(backend);
1038
- while (true) {
1039
- uint8_t cmd;
1040
- if (!recv_data(sockfd, &cmd, 1)) {
1041
- break;
1042
- }
1043
- if (cmd >= RPC_CMD_COUNT) {
1044
- // fail fast if the command is invalid
1045
- fprintf(stderr, "Unknown command: %d\n", cmd);
1046
- break;
1047
- }
1048
- switch (cmd) {
1049
- case RPC_CMD_ALLOC_BUFFER: {
1050
- rpc_msg_alloc_buffer_req request;
1051
- if (!recv_msg(sockfd, &request, sizeof(request))) {
1052
- return;
1053
- }
1054
- rpc_msg_alloc_buffer_rsp response;
1055
- server.alloc_buffer(request, response);
1056
- if (!send_msg(sockfd, &response, sizeof(response))) {
1057
- return;
1058
- }
1059
- break;
1060
- }
1061
- case RPC_CMD_GET_ALIGNMENT: {
1062
- if (!recv_msg(sockfd, nullptr, 0)) {
1063
- return;
1064
- }
1065
- rpc_msg_get_alignment_rsp response;
1066
- server.get_alignment(response);
1067
- if (!send_msg(sockfd, &response, sizeof(response))) {
1068
- return;
1069
- }
1070
- break;
1071
- }
1072
- case RPC_CMD_GET_MAX_SIZE: {
1073
- if (!recv_msg(sockfd, nullptr, 0)) {
1074
- return;
1075
- }
1076
- rpc_msg_get_max_size_rsp response;
1077
- server.get_max_size(response);
1078
- if (!send_msg(sockfd, &response, sizeof(response))) {
1079
- return;
1080
- }
1081
- break;
1082
- }
1083
- case RPC_CMD_BUFFER_GET_BASE: {
1084
- rpc_msg_buffer_get_base_req request;
1085
- if (!recv_msg(sockfd, &request, sizeof(request))) {
1086
- return;
1087
- }
1088
- rpc_msg_buffer_get_base_rsp response;
1089
- if (!server.buffer_get_base(request, response)) {
1090
- return;
1091
- }
1092
- if (!send_msg(sockfd, &response, sizeof(response))) {
1093
- return;
1094
- }
1095
- break;
1096
- }
1097
- case RPC_CMD_FREE_BUFFER: {
1098
- rpc_msg_free_buffer_req request;
1099
- if (!recv_msg(sockfd, &request, sizeof(request))) {
1100
- return;
1101
- }
1102
- if (!server.free_buffer(request)) {
1103
- return;
1104
- }
1105
- if (!send_msg(sockfd, nullptr, 0)) {
1106
- return;
1107
- }
1108
- break;
1109
- }
1110
- case RPC_CMD_BUFFER_CLEAR: {
1111
- rpc_msg_buffer_clear_req request;
1112
- if (!recv_msg(sockfd, &request, sizeof(request))) {
1113
- return;
1114
- }
1115
- if (!server.buffer_clear(request)) {
1116
- return;
1117
- }
1118
- if (!send_msg(sockfd, nullptr, 0)) {
1119
- return;
1120
- }
1121
- break;
1122
- }
1123
- case RPC_CMD_SET_TENSOR: {
1124
- std::vector<uint8_t> input;
1125
- if (!recv_msg(sockfd, input)) {
1126
- return;
1127
- }
1128
- if (!server.set_tensor(input)) {
1129
- return;
1130
- }
1131
- if (!send_msg(sockfd, nullptr, 0)) {
1132
- return;
1133
- }
1134
- break;
1135
- }
1136
- case RPC_CMD_GET_TENSOR: {
1137
- rpc_msg_get_tensor_req request;
1138
- if (!recv_msg(sockfd, &request, sizeof(request))) {
1139
- return;
1140
- }
1141
- std::vector<uint8_t> response;
1142
- if (!server.get_tensor(request, response)) {
1143
- return;
1144
- }
1145
- if (!send_msg(sockfd, response.data(), response.size())) {
1146
- return;
1147
- }
1148
- break;
1149
- }
1150
- case RPC_CMD_COPY_TENSOR: {
1151
- rpc_msg_copy_tensor_req request;
1152
- if (!recv_msg(sockfd, &request, sizeof(request))) {
1153
- return;
1154
- }
1155
- rpc_msg_copy_tensor_rsp response;
1156
- if (!server.copy_tensor(request, response)) {
1157
- return;
1158
- }
1159
- if (!send_msg(sockfd, &response, sizeof(response))) {
1160
- return;
1161
- }
1162
- break;
1163
- }
1164
- case RPC_CMD_GRAPH_COMPUTE: {
1165
- std::vector<uint8_t> input;
1166
- if (!recv_msg(sockfd, input)) {
1167
- return;
1168
- }
1169
- rpc_msg_graph_compute_rsp response;
1170
- if (!server.graph_compute(input, response)) {
1171
- return;
1172
- }
1173
- if (!send_msg(sockfd, &response, sizeof(response))) {
1174
- return;
1175
- }
1176
- break;
1177
- }
1178
- case RPC_CMD_GET_DEVICE_MEMORY: {
1179
- if (!recv_msg(sockfd, nullptr, 0)) {
1180
- return;
1181
- }
1182
- rpc_msg_get_device_memory_rsp response;
1183
- response.free_mem = free_mem;
1184
- response.total_mem = total_mem;
1185
- if (!send_msg(sockfd, &response, sizeof(response))) {
1186
- return;
1187
- }
1188
- break;
1189
- }
1190
- default: {
1191
- fprintf(stderr, "Unknown command: %d\n", cmd);
1192
- return;
1193
- }
1194
- }
1195
- }
1196
- }
1197
-
1198
- void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1199
- std::string host;
1200
- int port;
1201
- if (!parse_endpoint(endpoint, host, port)) {
1202
- return;
1203
- }
1204
- #ifdef _WIN32
1205
- {
1206
- WSADATA wsaData;
1207
- int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
1208
- if (res != 0) {
1209
- fprintf(stderr, "WSAStartup failed: %d\n", res);
1210
- return;
1211
- }
1212
- }
1213
- #endif
1214
- auto server_socket = create_server_socket(host.c_str(), port);
1215
- if (server_socket == nullptr) {
1216
- fprintf(stderr, "Failed to create server socket\n");
1217
- return;
1218
- }
1219
- while (true) {
1220
- auto client_socket = socket_accept(server_socket->fd);
1221
- if (client_socket == nullptr) {
1222
- fprintf(stderr, "Failed to accept client connection\n");
1223
- return;
1224
- }
1225
- printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
1226
- fflush(stdout);
1227
- rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
1228
- printf("Client connection closed\n");
1229
- fflush(stdout);
1230
- }
1231
- #ifdef _WIN32
1232
- WSACleanup();
1233
- #endif
1234
- }
1235
-
1236
- // device interface
1237
-
1238
- struct ggml_backend_rpc_device_context {
1239
- std::string endpoint;
1240
- std::string name;
1241
- };
1242
-
1243
- static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
1244
- ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1245
-
1246
- return ctx->name.c_str();
1247
- }
1248
-
1249
- static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
1250
- ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1251
-
1252
- return ctx->name.c_str();
1253
- }
1254
-
1255
- static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1256
- ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1257
-
1258
- ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
1259
-
1260
- UNUSED(dev);
1261
- }
1262
-
1263
- static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
1264
- // TODO: obtain value from the server
1265
- return GGML_BACKEND_DEVICE_TYPE_GPU;
1266
-
1267
- UNUSED(dev);
1268
- }
1269
-
1270
- static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
1271
- props->name = ggml_backend_rpc_device_get_name(dev);
1272
- props->description = ggml_backend_rpc_device_get_description(dev);
1273
- props->type = ggml_backend_rpc_device_get_type(dev);
1274
- ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
1275
- props->caps = {
1276
- /* .async = */ false,
1277
- /* .host_buffer = */ false,
1278
- /* .buffer_from_host_ptr = */ false,
1279
- /* .events = */ false,
1280
- };
1281
- }
1282
-
1283
- static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
1284
- ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1285
-
1286
- return ggml_backend_rpc_init(ctx->endpoint.c_str());
1287
-
1288
- UNUSED(params);
1289
- }
1290
-
1291
- static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
1292
- ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1293
-
1294
- return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
1295
-
1296
- UNUSED(dev);
1297
- }
1298
-
1299
- static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1300
- UNUSED(dev);
1301
- UNUSED(op);
1302
- //TODO: call the remote backend and cache the results
1303
- return true;
1304
- }
1305
-
1306
- static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
1307
- if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
1308
- return false;
1309
- }
1310
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
1311
- ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
1312
- return buft_ctx->endpoint == dev_ctx->endpoint;
1313
- }
1314
-
1315
- static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
1316
- /* .get_name = */ ggml_backend_rpc_device_get_name,
1317
- /* .get_description = */ ggml_backend_rpc_device_get_description,
1318
- /* .get_memory = */ ggml_backend_rpc_device_get_memory,
1319
- /* .get_type = */ ggml_backend_rpc_device_get_type,
1320
- /* .get_props = */ ggml_backend_rpc_device_get_props,
1321
- /* .init_backend = */ ggml_backend_rpc_device_init,
1322
- /* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type,
1323
- /* .get_host_buffer_type = */ NULL,
1324
- /* .buffer_from_host_ptr = */ NULL,
1325
- /* .supports_op = */ ggml_backend_rpc_device_supports_op,
1326
- /* .supports_buft = */ ggml_backend_rpc_device_supports_buft,
1327
- /* .offload_op = */ NULL,
1328
- /* .event_new = */ NULL,
1329
- /* .event_free = */ NULL,
1330
- /* .event_synchronize = */ NULL,
1331
- };
1332
-
1333
- // backend reg interface
1334
-
1335
- static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
1336
- return "RPC";
1337
-
1338
- UNUSED(reg);
1339
- }
1340
-
1341
- static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
1342
- return 0;
1343
-
1344
- UNUSED(reg);
1345
- }
1346
-
1347
- static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
1348
- GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
1349
-
1350
- UNUSED(reg);
1351
- UNUSED(index);
1352
- }
1353
-
1354
- static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
1355
- if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
1356
- return (void *)ggml_backend_rpc_add_device;
1357
- }
1358
- return NULL;
1359
-
1360
- UNUSED(reg);
1361
- }
1362
-
1363
- static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
1364
- /* .get_name = */ ggml_backend_rpc_reg_get_name,
1365
- /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
1366
- /* .get_device = */ ggml_backend_rpc_reg_get_device,
1367
- /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
1368
- };
1369
-
1370
- ggml_backend_reg_t ggml_backend_rpc_reg(void) {
1371
- static struct ggml_backend_reg ggml_backend_rpc_reg = {
1372
- /* .iface = */ ggml_backend_rpc_reg_i,
1373
- /* .context = */ NULL,
1374
- };
1375
-
1376
- return &ggml_backend_rpc_reg;
1377
- }
1378
-
1379
- ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
1380
- static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
1381
-
1382
- static std::mutex mutex;
1383
- std::lock_guard<std::mutex> lock(mutex);
1384
-
1385
- if (dev_map.find(endpoint) != dev_map.end()) {
1386
- return dev_map[endpoint];
1387
- }
1388
-
1389
- ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
1390
- /* .endpoint = */ endpoint,
1391
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
1392
- };
1393
-
1394
- ggml_backend_dev_t dev = new ggml_backend_device {
1395
- /* .iface = */ ggml_backend_rpc_device_i,
1396
- /* .reg = */ ggml_backend_rpc_reg(),
1397
- /* .context = */ ctx,
1398
- };
1399
-
1400
- dev_map[endpoint] = dev;
1401
-
1402
- return dev;
1403
- }
 
ggml/src/ggml-sycl.cpp DELETED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-vulkan.cpp DELETED
The diff for this file is too large to render. See raw diff
 
ggml/src/kompute-shaders/common.comp DELETED
@@ -1,102 +0,0 @@
- #extension GL_EXT_shader_16bit_storage: require
- #extension GL_EXT_shader_8bit_storage: require
- #extension GL_EXT_shader_explicit_arithmetic_types_float16: require
- #extension GL_EXT_shader_explicit_arithmetic_types_int8: require
- #extension GL_EXT_shader_explicit_arithmetic_types_int16: require
- #extension GL_EXT_control_flow_attributes: enable
- #extension GL_KHR_shader_subgroup_arithmetic : require
- #extension GL_EXT_debug_printf : enable
-
- #define QK4_0 32
- #define QK4_1 32
-
- #define GELU_COEF_A 0.044715
- #define SQRT_2_OVER_PI 0.79788456080286535587989211986876
- #define TWOPI_F 6.283185307179586f
-
- #define QK_K 256
-
- #define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
- #define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
- #define u8BufToU32(buf, idx) (((uint32_t u8BufToU16(buf, idx + 2) << 8 | buf[idx + 1]) << 8) | buf[idx])
- #define u8BufToFloat(buf, idx) uintBitsToFloat u8BufToU32(buf, idx)
-
- #define sizeof_block_q4_0 0x12
- struct block_q4_0 {
- float16_t d;
- uint8_t qs[QK4_0 / 2];
- };
- mat4 dequantize_q4_0(const block_q4_0 xb, uint il) {
- const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
- const float d2 = d1 / 256.f;
- const float md = -8.f * xb.d;
- const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
- const uint16_t mask1 = mask0 << 8;
-
- mat4 reg;
- for (int i=0;i<8;i++) {
- uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
- reg[i/2][2*(i%2)+0] = d1 * (b & mask0) + md;
- reg[i/2][2*(i%2)+1] = d2 * (b & mask1) + md;
- }
- return reg;
- }
-
- #define sizeof_block_q4_1 0x14
- struct block_q4_1 {
- float16_t d;
- float16_t m;
- uint8_t qs[QK4_1 / 2];
- };
- mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
- const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
- const float d2 = d1 / 256.f;
- const float m = xb.m;
- const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
- const uint16_t mask1 = mask0 << 8;
-
- mat4 reg;
- for (int i=0;i<8;i++) {
- uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
- reg[i/2][2*(i%2)+0] = ((b & mask0) * d1) + m;
- reg[i/2][2*(i%2)+1] = ((b & mask1) * d2) + m;
- }
- return reg;
- }
-
- #define sizeof_block_q6_k 210
- struct block_q6_k {
- uint8_t ql[QK_K/2]; // quants, lower 4 bits
- uint8_t qh[QK_K/4]; // quants, upper 2 bits
- int8_t scales[QK_K/16]; // scales, quantized with 8 bits
- float16_t d; // super-block scale
- };
- mat4 dequantize_q6_k(const block_q6_k xb, uint il) {
- const float16_t d_all = xb.d;
-
- const uint qlIndex = 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
- const uint qhIndex = 32*(il/8) + 16*(il&1);
- float16_t sc = xb.scales[(il%2) + 2 * ((il/2))];
- il = (il/2) & 3;
-
- const uint16_t kmask1 = il>1 ? uint16_t(il>2 ? 192 : 48) : uint16_t(il>0 ? 12 : 3);
- const uint16_t kmask2 = il>1 ? uint8_t(0xF0) : uint8_t(0x0F);
- const float16_t coef = il>1 ? float16_t(1.f/16.f) : float16_t(1.f);
- const float16_t ml = float16_t(d_all * sc * 32.f);
- const float16_t dl = float16_t(d_all * sc * coef);
- mat4 reg;
- for (int i = 0; i < 16; ++i) {
- const float16_t q = (il&1) != 0 ? ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 2))
- : ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 4));
- reg[i/4][i%4] = dl * q - ml;
- }
- return reg;
- }
-
-
- #define QK8_0 32
- // struct block_q8_0 {
- // float16_t d; // delta
- // int8_t qs[QK8_0]; // quants
- // };
- #define sizeof_block_q8_0 34
 
ggml/src/kompute-shaders/op_add.comp DELETED
@@ -1,58 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- layout(local_size_x = 1024) in;
-
- layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
- layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
- layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
- layout(push_constant) uniform PushConstants {
- uint inAOff;
- uint inBOff;
- uint outOff;
- int ne00;
- int nb00;
- int nb01;
- int nb02;
- int nb03;
- int ne10;
- int ne11;
- int ne12;
- int ne13;
- int nb10;
- int nb11;
- int nb12;
- int nb13;
- int ne0;
- int nb0;
- int nb1;
- int nb2;
- int nb3;
- //int offs; // TODO: needed for GGML_OP_ACC, see metal code
- } pcs;
-
- // general-purpose kernel for addition of two tensors
- // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
- // cons: not very efficient
- void main() {
- const uint i03 = gl_WorkGroupID.z;
- const uint i02 = gl_WorkGroupID.y;
- const uint i01 = gl_WorkGroupID.x;
-
- const uint i13 = i03 % pcs.ne13;
- const uint i12 = i02 % pcs.ne12;
- const uint i11 = i01 % pcs.ne11;
-
- int offs = 0; // TMP (see above)
-
- uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + offs) / 4);
- uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11 ) / 4);
- uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1 + offs) / 4);
-
- for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
- const uint i10 = i0 % pcs.ne10;
- out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] + inB[pcs.inBOff + src1_off + i10];
- }
- }
 
ggml/src/kompute-shaders/op_addrow.comp DELETED
@@ -1,25 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- layout(local_size_x = 1) in;
-
- layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
- layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
- layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
- layout(push_constant) uniform PushConstants {
- uint inAOff;
- uint inBOff;
- uint outOff;
- uint row;
- } pcs;
-
- void main() {
- const uint baseIndex = gl_WorkGroupID.x * 4;
-
- for (uint x = 0; x < 4; x++) {
- const uint i = baseIndex + x;
- out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff];
- }
- }
 
ggml/src/kompute-shaders/op_cpy_f16_f16.comp DELETED
@@ -1,52 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- #define IN_TYPE float16_t
- #define IN_TYPE_SIZE 2
- #define OUT_TYPE float16_t
- #define OUT_TYPE_SIZE 2
-
- layout(local_size_x = 1024) in;
-
- layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
- layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
- layout (push_constant) uniform parameter {
- uint inOff;
- uint outOff;
- int ne00;
- int ne01;
- int ne02;
- uint nb00;
- uint nb01;
- uint nb02;
- uint nb03;
- int ne0;
- int ne1;
- int ne2;
- uint nb0;
- uint nb1;
- uint nb2;
- uint nb3;
- } pcs;
-
- void main() {
- const uint i03 = gl_WorkGroupID.z;
- const uint i02 = gl_WorkGroupID.y;
- const uint i01 = gl_WorkGroupID.x;
-
- const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
- const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
- const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
- const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
- const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
- const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
- for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
- const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
- out_[dst_data+i00] = OUT_TYPE(in_[src]);
- }
- }
 
ggml/src/kompute-shaders/op_cpy_f16_f32.comp DELETED
@@ -1,52 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- #define IN_TYPE float16_t
- #define IN_TYPE_SIZE 2
- #define OUT_TYPE float
- #define OUT_TYPE_SIZE 4
-
- layout(local_size_x = 1024) in;
-
- layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
- layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
- layout (push_constant) uniform parameter {
- uint inOff;
- uint outOff;
- int ne00;
- int ne01;
- int ne02;
- uint nb00;
- uint nb01;
- uint nb02;
- uint nb03;
- int ne0;
- int ne1;
- int ne2;
- uint nb0;
- uint nb1;
- uint nb2;
- uint nb3;
- } pcs;
-
- void main() {
- const uint i03 = gl_WorkGroupID.z;
- const uint i02 = gl_WorkGroupID.y;
- const uint i01 = gl_WorkGroupID.x;
-
- const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
- const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
- const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
- const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
- const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
- const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
- for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
- const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
- out_[dst_data+i00] = OUT_TYPE(in_[src]);
- }
- }
 
ggml/src/kompute-shaders/op_cpy_f32_f16.comp DELETED
@@ -1,52 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- #define IN_TYPE float
- #define IN_TYPE_SIZE 4
- #define OUT_TYPE float16_t
- #define OUT_TYPE_SIZE 2
-
- layout(local_size_x = 1024) in;
-
- layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
- layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
- layout (push_constant) uniform parameter {
- uint inOff;
- uint outOff;
- int ne00;
- int ne01;
- int ne02;
- uint nb00;
- uint nb01;
- uint nb02;
- uint nb03;
- int ne0;
- int ne1;
- int ne2;
- uint nb0;
- uint nb1;
- uint nb2;
- uint nb3;
- } pcs;
-
- void main() {
- const uint i03 = gl_WorkGroupID.z;
- const uint i02 = gl_WorkGroupID.y;
- const uint i01 = gl_WorkGroupID.x;
-
- const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
- const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
- const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
- const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
- const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
- const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
- for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
- const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
- out_[dst_data+i00] = OUT_TYPE(in_[src]);
- }
- }
 
ggml/src/kompute-shaders/op_cpy_f32_f32.comp DELETED
@@ -1,52 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- #define IN_TYPE float
- #define IN_TYPE_SIZE 4
- #define OUT_TYPE float
- #define OUT_TYPE_SIZE 4
-
- layout(local_size_x = 1024) in;
-
- layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
- layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
- layout (push_constant) uniform parameter {
- uint inOff;
- uint outOff;
- int ne00;
- int ne01;
- int ne02;
- uint nb00;
- uint nb01;
- uint nb02;
- uint nb03;
- int ne0;
- int ne1;
- int ne2;
- uint nb0;
- uint nb1;
- uint nb2;
- uint nb3;
- } pcs;
-
- void main() {
- const uint i03 = gl_WorkGroupID.z;
- const uint i02 = gl_WorkGroupID.y;
- const uint i01 = gl_WorkGroupID.x;
-
- const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
- const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
- const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
- const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
- const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
- const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
- for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
- const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
- out_[dst_data+i00] = OUT_TYPE(in_[src]);
- }
- }
 
ggml/src/kompute-shaders/op_diagmask.comp DELETED
@@ -1,30 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- layout(local_size_x = 1) in;
-
- layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
- layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-
- layout(push_constant) uniform PushConstants {
- uint inOff;
- uint outOff;
- uint n_past;
- int ne00;
- int ne01;
- } pcs;
-
- void main() {
- const uint i02 = gl_WorkGroupID.z;
- const uint i01 = gl_WorkGroupID.y;
- const uint i00 = gl_WorkGroupID.x;
-
- const uint index = i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00;
-
- if (i00 > pcs.n_past + i01) {
- out_[index + pcs.outOff] = uintBitsToFloat(0xFF800000);
- } else {
- out_[index + pcs.outOff] = in_[index + pcs.inOff];
- }
- }
 
ggml/src/kompute-shaders/op_gelu.comp DELETED
@@ -1,22 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- layout(local_size_x = 1) in;
-
- layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
- layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
- layout(push_constant) uniform PushConstants {
- uint inOff;
- uint outOff;
- } pcs;
-
- void main() {
- const uint baseIndex = gl_WorkGroupID.x * 8;
-
- for (uint x = 0; x < 8; x++) {
- const uint i = baseIndex + x;
- const float y = in_[i + pcs.inOff];
- out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(clamp(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y), -15.0, 15.0)));
- }
- }
 
ggml/src/kompute-shaders/op_getrows.comp DELETED
@@ -1,17 +0,0 @@
- void main() {
- const uint i = gl_WorkGroupID.x;
- const int r = inB[i + pcs.inBOff];
-
- int z = 0;
- for (uint ind = gl_LocalInvocationID.x; ind < pcs.ne00/16; ind += gl_WorkGroupSize.x) {
- const uint inIndex = (r * pcs.nb01 + pcs.inAOff) + ind/NL * SIZE_OF_BLOCK;
- const mat4 result = dequantize_block(inIndex, ind%NL);
- for (uint j = 0; j < 4; ++j) {
- for (uint k = 0; k < 4; ++k) {
- const uint outIndex = i * pcs.nb1/BYTES_FOR_TYPE + pcs.outOff + z;
- out_[outIndex] = result[j][k];
- ++z;
- }
- }
- }
- }
 
ggml/src/kompute-shaders/op_getrows_f16.comp DELETED
@@ -1,31 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- layout(local_size_x = 1) in;
-
- layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
- layout (binding = 1) readonly buffer tensorInB { int inB[]; };
- layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
- layout (push_constant) uniform parameter {
- uint inAOff;
- uint inBOff;
- uint outOff;
- int ne00;
- int nb01;
- int nb1;
- } pcs;
-
- void dequantize_row_f16(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
- for (int j = 0; j < k; j++) {
- out_[y + j] = inA[x + j];
- }
- }
-
- void main() {
- const uint i = gl_WorkGroupID.x;
- const int r = inB[i + pcs.inBOff];
-
- dequantize_row_f16(r*pcs.nb01/2/*bytes for float16*/ + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
- }
 
ggml/src/kompute-shaders/op_getrows_f32.comp DELETED
@@ -1,31 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- layout(local_size_x = 1) in;
-
- layout (binding = 0) readonly buffer tensorInA { float inA[]; };
- layout (binding = 1) readonly buffer tensorInB { int inB[]; };
- layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
- layout (push_constant) uniform parameter {
- uint inAOff;
- uint inBOff;
- uint outOff;
- int ne00;
- int nb01;
- int nb1;
- } pcs;
-
- void dequantize_row_f32(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
- for (int j = 0; j < k; j++) {
- out_[y + j] = inA[x + j];
- }
- }
-
- void main() {
- const uint i = gl_WorkGroupID.x;
- const int r = inB[i + pcs.inBOff];
-
- dequantize_row_f32(r*pcs.nb01/4 + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
- }
 
ggml/src/kompute-shaders/op_getrows_q4_0.comp DELETED
@@ -1,38 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- #define NL 2
- #define BYTES_FOR_TYPE 4 /*bytes for float*/
- #define SIZE_OF_BLOCK sizeof_block_q4_0
-
- layout(local_size_x = 1) in;
-
- layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
- layout (binding = 1) readonly buffer tensorInB { int inB[]; };
- layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
- layout (push_constant) uniform parameter {
- uint inAOff;
- uint inBOff;
- uint outOff;
- int ne00;
- int nb01;
- int nb1;
- } pcs;
-
- block_q4_0 get_unaligned_block_q4_0(uint index) {
- block_q4_0 fres;
- fres.d = u8BufToFloat16(inA, index);
- [[unroll]] for (uint it = 0; it != QK4_0 / 2; it++) {
- fres.qs[it] = inA[index+2+it];
- }
- return fres;
- }
-
- mat4 dequantize_block(uint index, uint il) {
- const block_q4_0 block = get_unaligned_block_q4_0(index);
- return dequantize_q4_0(block, il);
- }
-
- #include "op_getrows.comp"
 
ggml/src/kompute-shaders/op_getrows_q4_1.comp DELETED
@@ -1,39 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- #define NL 2
- #define BYTES_FOR_TYPE 4 /*bytes for float*/
- #define SIZE_OF_BLOCK sizeof_block_q4_1
-
- layout(local_size_x = 1) in;
-
- layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
- layout (binding = 1) readonly buffer tensorInB { int inB[]; };
- layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
- layout (push_constant) uniform parameter {
- uint inAOff;
- uint inBOff;
- uint outOff;
- int ne00;
- int nb01;
- int nb1;
- } pcs;
-
- block_q4_1 get_unaligned_block_q4_1(uint index) {
- block_q4_1 fres;
- fres.d = u8BufToFloat16(inA, index);
- fres.m = u8BufToFloat16(inA, index+2);
- [[unroll]] for (uint it = 0; it != QK4_1 / 2; it++) {
- fres.qs[it] = inA[index+4+it];
- }
- return fres;
- }
-
- mat4 dequantize_block(uint index, uint il) {
- const block_q4_1 block = get_unaligned_block_q4_1(index);
- return dequantize_q4_1(block, il);
- }
-
- #include "op_getrows.comp"
ggml/src/kompute-shaders/op_getrows_q6_k.comp DELETED
@@ -1,44 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- #define NL 16
- #define BYTES_FOR_TYPE 4 /*bytes for float*/
- #define SIZE_OF_BLOCK sizeof_block_q6_k
-
- layout(local_size_x = 1) in;
-
- layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
- layout (binding = 1) readonly buffer tensorInB { int inB[]; };
- layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
- layout (push_constant) uniform parameter {
-     uint inAOff;
-     uint inBOff;
-     uint outOff;
-     int ne00;
-     int nb01;
-     int nb1;
- } pcs;
-
- block_q6_k get_unaligned_block_q6_k(uint index) {
-     block_q6_k fres;
-     [[unroll]] for (uint it = 0; it != QK_K / 2; it++) {
-         fres.ql[it] = inA[index + it];
-     }
-     [[unroll]] for (uint it = 0; it != QK_K / 4; it++) {
-         fres.qh[it] = inA[index + QK_K/2 + it];
-     }
-     [[unroll]] for (uint it = 0; it != QK_K / 16; it++) {
-         fres.scales[it] = int8_t(inA[index + QK_K/2 + QK_K/4 + it]);
-     }
-     fres.d = u8BufToFloat16(inA, index + QK_K/2 + QK_K/4 + QK_K/16);
-     return fres;
- }
-
- mat4 dequantize_block(uint index, uint il) {
-     const block_q6_k block = get_unaligned_block_q6_k(index);
-     return dequantize_q6_k(block, il);
- }
-
- #include "op_getrows.comp"
ggml/src/kompute-shaders/op_mul.comp DELETED
@@ -1,52 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- layout(local_size_x = 1024) in;
-
- layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
- layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
- layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
- layout(push_constant) uniform PushConstants {
-     uint inAOff;
-     uint inBOff;
-     uint outOff;
-     int ne00;
-     int nb00;
-     int nb01;
-     int nb02;
-     int nb03;
-     int ne10;
-     int ne11;
-     int ne12;
-     int ne13;
-     int nb10;
-     int nb11;
-     int nb12;
-     int nb13;
-     int ne0;
-     int nb0;
-     int nb1;
-     int nb2;
-     int nb3;
- } pcs;
-
- void main() {
-     const uint i03 = gl_WorkGroupID.z;
-     const uint i02 = gl_WorkGroupID.y;
-     const uint i01 = gl_WorkGroupID.x;
-
-     const uint i13 = i03 % pcs.ne13;
-     const uint i12 = i02 % pcs.ne12;
-     const uint i11 = i01 % pcs.ne11;
-
-     uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01) / 4);
-     uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11) / 4);
-     uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1) / 4);
-
-     for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
-         const uint i10 = i0 % pcs.ne10;
-         out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] * inB[pcs.inBOff + src1_off + i10];
-     }
- }
ggml/src/kompute-shaders/op_mul_mat_f16.comp DELETED
@@ -1,67 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- #extension GL_KHR_shader_subgroup_arithmetic : require
-
- layout(local_size_x_id = 0) in;
-
- layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
- layout (binding = 1) readonly buffer tensorInB { float inB[]; };
- layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
- layout (push_constant) uniform parameter {
-     uint inAOff;
-     uint inBOff;
-     uint outOff;
-     int ne00;
-     int ne01;
-     int ne02;
-     uint nb00;
-     uint nb01;
-     uint nb02;
-     int ne10;
-     int ne11;
-     int ne12;
-     uint nb10;
-     uint nb11;
-     uint nb12;
-     int ne0;
-     int ne1;
-     uint r2;
-     uint r3;
- } pcs;
-
- #define N_F16_F32 4
-
- void main() {
-     const uint r0 = gl_WorkGroupID.x;
-     const uint rb = gl_WorkGroupID.y*N_F16_F32;
-     const uint im = gl_WorkGroupID.z;
-
-     const uint i12 = im%pcs.ne12;
-     const uint i13 = im/pcs.ne12;
-
-     const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb02*pcs.ne02;
-
-     const uint x = offset0 / 2 + pcs.inAOff; // Based from inA
-
-     for (uint row = 0; row < N_F16_F32; ++row) {
-         uint r1 = rb + row;
-         if (r1 >= pcs.ne11) {
-             break;
-         }
-
-         const uint y = (r1*pcs.nb11 + im*pcs.nb12) / 4 + pcs.inBOff; // Based from inB
-
-         float sumf = 0;
-         for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
-             sumf += float(inA[x+i]) * float(inB[y+i]);
-         }
-
-         const float all_sum = subgroupAdd(sumf);
-         if (subgroupElect()) {
-             out_[im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0 + pcs.outOff] = all_sum;
-         }
-     }
- }
ggml/src/kompute-shaders/op_mul_mat_mat_f32.comp DELETED
@@ -1,51 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- #extension GL_KHR_shader_subgroup_arithmetic : require
- #extension GL_EXT_debug_printf : enable
-
- // device subgroup size
- layout (local_size_x_id = 0) in;
-
- layout(binding = 0) readonly buffer tensorInA { float inA[]; };
- layout(binding = 1) readonly buffer tensorInB { float inB[]; };
- layout(binding = 2) writeonly buffer tensorOut { float out_[]; };
-
- layout(push_constant) uniform parameter {
-     uint inAOff;
-     uint inBOff;
-     uint outOff;
-     int ne00;
-     int ne01;
-     int ne02;
-     int ne11;
-     int ne12;
-     uint nb01;
-     uint nb02;
-     uint nb11;
-     uint nb12;
-     uint nb1;
-     uint nb2;
- }
- pcs;
-
-
- void main() {
-     uvec3 gid = gl_WorkGroupID;
-
-     uint bc_ab = pcs.ne12 > pcs.ne02 ? gid.z / (pcs.ne12 / pcs.ne02) : gid.z;
-     uint bc_ba = pcs.ne02 > pcs.ne12 ? gid.z / (pcs.ne02 / pcs.ne12) : gid.z;
-
-     const uint x = (gid.x*pcs.nb01 + bc_ab*pcs.nb02) / 4 + pcs.inAOff; // Based from inA
-     const uint y = (gid.y*pcs.nb11 + bc_ba*pcs.nb12) / 4 + pcs.inBOff; // based from inB
-     float sum = 0.0f;
-     for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
-         sum += float(inA[x+i]) * float(inB[y+i]);
-     }
-
-     const float all_sum = subgroupAdd(sum);
-     if (subgroupElect()) {
-         out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = all_sum;
-     }
- }
ggml/src/kompute-shaders/op_mul_mat_q4_0.comp DELETED
@@ -1,33 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- #define BLOCKS_IN_QUANT QK4_0
- #define SIZE_OF_BLOCK sizeof_block_q4_0
- #define N_ROWS 4
-
- #include "op_mul_mv_q_n_pre.comp"
-
- // The q4_0 version of this function
- float block_q_n_dot_y(uint block_index, uint yb, uint il) {
-     vec2 acc = vec2(0.0, 0.0);
-     const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
-     float d = float(u8BufToFloat16(inA, index));
-     float sumy = 0.0f;
-     for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
-         const uint16_t b = u8BufToU16(inA, index + 2 + il + i);
-
-         const float yl0 = inB[yb + i];
-         const float yl1 = inB[yb + i + 1];
-         const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
-         const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
-
-         sumy += yl0 + yl1 + yl8 + yl9;
-
-         acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
-         acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
-     }
-     return d * (sumy * -8.f + acc[0] + acc[1]);
- }
-
- #include "op_mul_mv_q_n.comp"
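Note: block_q_n_dot_y above is the subgroup-parallel form of an ordinary q4_0 block dot product, assuming 32 weights stored as packed 4-bit quants with a single fp16 scale d and dequantization q -> (q - 8) * d. A minimal scalar sketch under those assumptions (illustrative only, not the ggml code; the function name is hypothetical):

    #include <stdint.h>

    /* Sketch: dot product of one q4_0 block (scale d, 16 packed bytes -> 32 quants)
     * with 32 floats of y. Low nibble of qs[j] is element j, high nibble is j + 16. */
    float q4_0_block_dot(float d, const uint8_t qs[16], const float y[32]) {
        float sum = 0.0f;
        for (int j = 0; j < 16; ++j) {
            sum += ((qs[j] & 0x0F) - 8) * y[j]
                 + ((qs[j] >>   4) - 8) * y[j + 16];
        }
        return d * sum; /* equals d * (acc - 8 * sum(y)), as the shader computes */
    }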
ggml/src/kompute-shaders/op_mul_mat_q4_1.comp DELETED
@@ -1,35 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- #define BLOCKS_IN_QUANT QK4_1
- #define SIZE_OF_BLOCK sizeof_block_q4_1
- #define N_ROWS 4
-
- #include "op_mul_mv_q_n_pre.comp"
-
- // The q4_1 version of this function
- float block_q_n_dot_y(uint block_index, uint yb, uint il) {
-     vec2 acc = vec2(0.0, 0.0);
-     const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
-     float d = float(u8BufToFloat16(inA, index));
-     float m = float(u8BufToFloat16(inA, index+2));
-
-     float sumy = 0.0f;
-     for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
-         const uint16_t b = u8BufToU16(inA, index + 4 + il + i);
-
-         const float yl0 = inB[yb + i];
-         const float yl1 = inB[yb + i + 1];
-         const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
-         const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
-
-         sumy += yl0 + yl1 + yl8 + yl9;
-
-         acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
-         acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
-     }
-     return d * (acc[0] + acc[1]) + sumy * m;
- }
-
- #include "op_mul_mv_q_n.comp"
ggml/src/kompute-shaders/op_mul_mat_q6_k.comp DELETED
@@ -1,94 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- #define SIZE_OF_BLOCK sizeof_block_q6_k
-
- layout(local_size_x_id = 0) in;
- layout(local_size_y_id = 1) in;
- layout(local_size_z = 1) in;
-
- layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
- layout (binding = 1) readonly buffer tensorInB { float inB[]; };
- layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
- layout (push_constant) uniform parameter {
-     uint inAOff;
-     uint inBOff;
-     uint outOff;
-     int ne00;
-     int ne10;
-     int ne0;
-     int ne1;
-     int ne01;
-     int gqa;
- } pcs;
-
- void main() {
-     const uint8_t kmask1 = uint8_t(0x03);
-     const uint8_t kmask2 = uint8_t(0x0C);
-     const uint8_t kmask3 = uint8_t(0x30);
-     const uint8_t kmask4 = uint8_t(0xC0);
-
-     const uint nb = pcs.ne00/QK_K;
-
-     const uint r0 = gl_WorkGroupID.x;
-     const uint r1 = gl_WorkGroupID.y;
-     const uint r2 = gl_WorkGroupID.z;
-
-     const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
-     const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);
-     const uint x = row * nb + offset0; // Based from inA without base offset
-     const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
-
-     float sumf = 0;
-
-     // bits of invocation ID for gl_SubgroupSize=32:
-     //  x x x x x
-     //  4 3 2 1 0
-     // (  tid  ) ix
-     //  ip (  il  )
-
-     const uint block_stride = gl_SubgroupSize / 16; // number of blocks each subgroup processes
-     const uint tid = gl_SubgroupInvocationID/block_stride; // first block_stride groups have tid=0
-     const uint ix = gl_SubgroupInvocationID%block_stride; // first block is 0..block_stride-1
-     const uint ip = tid/8;  // first or second half of block (0 or 1)
-     const uint il = tid%8;  // each half has 8 parts, one per scale
-     const uint n = 4;       // 4 scales at a time (and 4 sums)
-     const uint l0 = n*il;   // offset into half-block, 0..28
-     const uint is = 8*ip + l0/16; // 0, 1, 8, 9
-
-     const uint y_offset = 128*ip + l0;
-     const uint q_offset_l = 64*ip + l0;
-     const uint q_offset_h = 32*ip + l0;
-
-     for (uint i = ix; i < nb; i += block_stride) {
-
-         const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff;
-
-         const uint qlIndex = q_offset_l;
-         const uint q2Index = qlIndex + QK_K/8;
-         const uint qhIndex = q_offset_h;
-         const uint y = yy + i * QK_K + y_offset;
-
-         float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f};
-         for (uint l = 0; l < n; ++l) {
-             const uint8_t currentQ1 = inA[baseIndex + qlIndex + l];
-             const uint8_t currentQ2 = inA[baseIndex + q2Index + l];
-             const uint8_t currentQh = inA[baseIndex + QK_K/2 + qhIndex + l];
-
-             sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32);
-             sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32);
-             sums[2] += inB[y+l+64] * (int8_t((currentQ1 >> 4) | ((currentQh & kmask3) << 0)) - 32);
-             sums[3] += inB[y+l+96] * (int8_t((currentQ2 >> 4) | ((currentQh & kmask4) >> 2)) - 32);
-         }
-
-         float d = u8BufToFloat16(inA, baseIndex + QK_K/2 + QK_K/4 + QK_K/16);
-         sumf += d * (sums[0] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + is]) + sums[1] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 2 + is]) + sums[2] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 4 + is]) + sums[3] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 6 + is]));
-     }
-
-     const float tot = subgroupAdd(sumf);
-     if (subgroupElect()) {
-         out_[r1*pcs.ne0 + r2*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
-     }
- }
ggml/src/kompute-shaders/op_mul_mat_q8_0.comp DELETED
@@ -1,73 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- #include "op_mul_mv_q_n_pre.comp"
-
- #define SIZE_OF_D 2
-
- #define N_DST 4 // each SIMD group works on 4 rows
- #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
- #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-
- #define NB_Q8_0 8
-
- void main() {
-     // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
-     if (gl_SubgroupInvocationID > 31)
-         return;
-
-     const int nr = N_DST;
-     const int nsg = N_SIMDGROUP;
-     const int nw = N_SIMDWIDTH;
-
-     const int nb = pcs.ne00/QK8_0;
-     const uint r0 = gl_WorkGroupID.x;
-     const uint r1 = gl_WorkGroupID.y;
-     const uint im = gl_WorkGroupID.z;
-
-     const uint first_row = (r0 * nsg + gl_SubgroupID) * nr;
-
-     const uint i12 = im%pcs.ne12;
-     const uint i13 = im/pcs.ne12;
-
-     const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
-
-     const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // Based from inA
-     const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // based from inB
-
-     float yl[NB_Q8_0];
-     float sumf[N_DST]={0.f, 0.f, 0.f, 0.f};
-
-     const uint ix = gl_SubgroupInvocationID.x/4;
-     const uint il = gl_SubgroupInvocationID.x%4;
-
-     uint yb = y + ix * QK8_0 + NB_Q8_0*il;
-
-     // each thread in a SIMD group deals with NB_Q8_0 quants at a time
-     for (uint ib = ix; ib < nb; ib += nw/4) {
-         for (int i = 0; i < NB_Q8_0; ++i) {
-             yl[i] = inB[yb + i];
-         }
-
-         for (int row = 0; row < nr; row++) {
-             const uint block_offset = (ib+row*nb) * sizeof_block_q8_0;
-             float sumq = 0.f;
-             for (int iq = 0; iq < NB_Q8_0; ++iq) {
-                 const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]);
-                 sumq += qs_iq * yl[iq];
-             }
-             const float16_t d = u8BufToFloat16(inA, x + block_offset);
-             sumf[row] += sumq*d;
-         }
-
-         yb += NB_Q8_0 * nw;
-     }
-
-     for (int row = 0; row < nr; ++row) {
-         const float tot = subgroupAdd(sumf[row]);
-         if (subgroupElect() && first_row + row < pcs.ne01) {
-             out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row] = tot;
-         }
-     }
- }
ggml/src/kompute-shaders/op_mul_mv_q_n.comp DELETED
@@ -1,48 +0,0 @@
- void main() {
-     // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
-     if (gl_SubgroupInvocationID > 31)
-         return;
-
-     const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT);
-
-     const uint r0 = gl_WorkGroupID.x;
-     const uint r1 = gl_WorkGroupID.y;
-     const uint im = gl_WorkGroupID.z;
-
-     const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS;
-
-     const uint i12 = im%pcs.ne12;
-     const uint i13 = im/pcs.ne12;
-
-     const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
-
-     const uint x = offset0; // Based from inA without base offset
-     const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
-
-     float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};
-
-     const uint ix = gl_SubgroupInvocationID/2;
-     const uint il = (BLOCKS_IN_QUANT/4)*(gl_SubgroupInvocationID%2);
-
-     uint yb = y + ix * BLOCKS_IN_QUANT + il;
-
-     //debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n",
-     //               gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize,
-     //               gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z);
-
-     for (uint ib = ix; ib < nb; ib += 16) {
-         for (int row = 0; row < N_ROWS; row++) {
-             const uint block_index = x + ib + row * nb;
-             sumf[row] += block_q_n_dot_y(block_index, yb, il);
-         }
-
-         yb += BLOCKS_IN_QUANT * 16;
-     }
-
-     for (int row = 0; row < N_ROWS; ++row) {
-         const float tot = subgroupAdd(sumf[row]);
-         if (first_row + row < pcs.ne01 && subgroupElect()) {
-             out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot;
-         }
-     }
- }
ggml/src/kompute-shaders/op_mul_mv_q_n_pre.comp DELETED
@@ -1,22 +0,0 @@
- layout(local_size_x_id = 0) in;
- layout(local_size_y = 1) in;
- layout(local_size_z = 1) in;
-
- layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
- layout (binding = 1) readonly buffer tensorInB { float inB[]; };
- layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
- layout (push_constant) uniform parameter {
-     uint inAOff;
-     uint inBOff;
-     uint outOff;
-     int ne00;
-     int ne01;
-     int ne02;
-     int ne10;
-     int ne12;
-     int ne0;
-     int ne1;
-     uint r2;
-     uint r3;
- } pcs;
ggml/src/kompute-shaders/op_norm.comp DELETED
@@ -1,84 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- layout(local_size_x = 256) in;
-
- layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
- layout(binding = 1) buffer restrict tensorOut { float out_[]; };
-
- layout(push_constant) uniform PushConstants {
-     uint inOff;
-     uint outOff;
-     uint ne00;
-     uint nb01;
-     float eps;
- } pcs;
-
- shared float sum[gl_WorkGroupSize.x];
-
- void main() {
-     const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
-     // MEAN
-     // parallel sum
-     sum[gl_LocalInvocationID.x] = 0.0;
-     for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-         sum[gl_LocalInvocationID.x] += in_[x+i00];
-     }
-
-     // reduce
-     barrier();
-     memoryBarrierShared();
-     [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
-         if (gl_LocalInvocationID.x < i) {
-             sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
-         }
-         barrier();
-         memoryBarrierShared();
-     }
-
-     // broadcast
-     if (gl_LocalInvocationID.x == 0) {
-         sum[0] /= float(pcs.ne00);
-     }
-     barrier();
-     memoryBarrierShared();
-     const float mean = sum[0];
-
-     // recenter
-     const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
-     for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-         out_[y+i00] = in_[x+i00] - mean;
-     }
-
-     // VARIANCE
-     // parallel sum
-     sum[gl_LocalInvocationID.x] = 0.0;
-     for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-         sum[gl_LocalInvocationID.x] += out_[y+i00] * out_[y+i00];
-     }
-
-     // reduce
-     barrier();
-     memoryBarrierShared();
-     [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
-         if (gl_LocalInvocationID.x < i) {
-             sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
-         }
-         barrier();
-         memoryBarrierShared();
-     }
-
-     // broadcast
-     if (gl_LocalInvocationID.x == 0) {
-         sum[0] /= float(pcs.ne00);
-     }
-     barrier();
-     memoryBarrierShared();
-     const float variance = sum[0];
-
-     const float scale = 1.0f/sqrt(variance + pcs.eps);
-     for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-         out_[y+i00] *= scale;
-     }
- }
ggml/src/kompute-shaders/op_relu.comp DELETED
@@ -1,21 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- layout(local_size_x = 1) in;
-
- layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
- layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
- layout(push_constant) uniform PushConstants {
-     uint inOff;
-     uint outOff;
- } pcs;
-
- void main() {
-     const uint baseIndex = gl_WorkGroupID.x * 4;
-
-     for (uint x = 0; x < 4; x++) {
-         const uint i = baseIndex + x;
-         out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
-     }
- }
ggml/src/kompute-shaders/op_rmsnorm.comp DELETED
@@ -1,53 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- layout(local_size_x = 512) in;
-
- layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
- layout(binding = 1) buffer restrict tensorOut { float out_[]; };
-
- layout(push_constant) uniform PushConstants {
-     uint inOff;
-     uint outOff;
-     uint ne00;
-     uint nb01;
-     float eps;
- } pcs;
-
- shared float sum[gl_WorkGroupSize.x];
-
- void main() {
-     const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
-
-     // parallel sum
-     sum[gl_LocalInvocationID.x] = 0.0;
-     for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-         sum[gl_LocalInvocationID.x] += in_[x+i00] * in_[x+i00];
-     }
-
-     // reduce
-     barrier();
-     memoryBarrierShared();
-     [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
-         if (gl_LocalInvocationID.x < i) {
-             sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
-         }
-         barrier();
-         memoryBarrierShared();
-     }
-
-     // broadcast
-     if (gl_LocalInvocationID.x == 0) {
-         sum[0] /= float(pcs.ne00);
-     }
-     barrier();
-     memoryBarrierShared();
-
-     const float scale = 1.0f/sqrt(sum[0] + pcs.eps);
-
-     const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
-     for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-         out_[y+i00] = in_[x+i00] * scale;
-     }
- }
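Note: the shared-memory reduction above computes the usual RMS norm over each row: out = x / sqrt(mean(x^2) + eps). A scalar reference sketch of the same computation (illustrative only; the function name is hypothetical, not ggml API):

    #include <math.h>

    /* Sketch: RMS-normalize one row of ne00 floats from src into dst. */
    void rms_norm_row(const float *src, float *dst, int ne00, float eps) {
        float sumsq = 0.0f;
        for (int i = 0; i < ne00; ++i) {
            sumsq += src[i] * src[i];
        }
        const float scale = 1.0f / sqrtf(sumsq / ne00 + eps);
        for (int i = 0; i < ne00; ++i) {
            dst[i] = src[i] * scale;
        }
    }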
ggml/src/kompute-shaders/op_rope_f16.comp DELETED
@@ -1,73 +0,0 @@
- #version 450
-
- #include "rope_common.comp"
-
- layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
- layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
- layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; };
-
- void main() {
-     const uint i3 = gl_WorkGroupID.z;
-     const uint i2 = gl_WorkGroupID.y;
-     const uint i1 = gl_WorkGroupID.x;
-
-     const bool is_neox = (pcs.mode & 2) != 0;
-
-     float corr_dims[2];
-     rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
-
-     const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
-
-     const int p = inB[pcs.inBOff + i2];
-
-     float theta = float(p);
-
-     if (!is_neox) {
-         for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
-             float cos_theta, sin_theta;
-             rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-             theta *= theta_scale;
-
-             const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
-             const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
-
-             const float x0 = float(inA[src]);
-             const float x1 = float(inA[src+1]);
-
-             out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
-             out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
-         }
-     } else {
-         const float inv_ndims = -1.f/pcs.n_dims;
-         for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
-             const uint cur_rot = ic;
-
-             float cos_theta, sin_theta;
-             rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-             theta *= theta_scale;
-
-             const uint i0 = ic/2;
-
-             const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
-             const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
-
-             const float x0 = float(inA[src]);
-             const float x1 = float(inA[src+pcs.n_dims/2]);
-
-             out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
-             out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
-         }
-
-         for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
-             const uint i0 = ic;
-
-             const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
-             const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
-
-             out_[dst_data + 0] = inA[src + 0];
-             out_[dst_data + 1] = inA[src + 1];
-         }
-     }
- }
ggml/src/kompute-shaders/op_rope_f32.comp DELETED
@@ -1,73 +0,0 @@
- #version 450
-
- #include "rope_common.comp"
-
- layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
- layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
- layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
- void main() {
-     const uint i3 = gl_WorkGroupID.z;
-     const uint i2 = gl_WorkGroupID.y;
-     const uint i1 = gl_WorkGroupID.x;
-
-     const bool is_neox = (pcs.mode & 2) != 0;
-
-     float corr_dims[2];
-     rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
-
-     const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
-
-     const int p = inB[pcs.inBOff + i2];
-
-     float theta = float(p);
-
-     if (!is_neox) {
-         for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
-             float cos_theta, sin_theta;
-             rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-             theta *= theta_scale;
-
-             const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
-             const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
-
-             const float x0 = inA[src];
-             const float x1 = inA[src+1];
-
-             out_[dst_data] = x0*cos_theta - x1*sin_theta;
-             out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
-         }
-     } else {
-         const float inv_ndims = -1.f/pcs.n_dims;
-         for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
-             const uint cur_rot = ic;
-
-             float cos_theta, sin_theta;
-             rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-             theta *= theta_scale;
-
-             const uint i0 = ic/2;
-
-             const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
-             const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
-
-             const float x0 = inA[src];
-             const float x1 = inA[src+pcs.n_dims/2];
-
-             out_[dst_data] = x0*cos_theta - x1*sin_theta;
-             out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
-         }
-
-         for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
-             const uint i0 = ic;
-
-             const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
-             const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
-
-             out_[dst_data + 0] = inA[src + 0];
-             out_[dst_data + 1] = inA[src + 1];
-         }
-     }
- }
ggml/src/kompute-shaders/op_scale.comp DELETED
@@ -1,19 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- layout(local_size_x = 1) in;
-
- layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
- layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-
- layout(push_constant) uniform PushConstants {
-     uint inOff;
-     uint outOff;
-     float scale;
- } pcs;
-
- void main() {
-     const uint i = gl_WorkGroupID.x;
-     out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
- }
ggml/src/kompute-shaders/op_scale_8.comp DELETED
@@ -1,23 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- layout(local_size_x = 1) in;
-
- layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
- layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-
- layout(push_constant) uniform PushConstants {
-     uint inOff;
-     uint outOff;
-     float scale;
- } pcs;
-
- void main() {
-     const uint baseIndex = gl_WorkGroupID.x * 8;
-
-     for (uint x = 0; x < 8; x++) {
-         const uint i = baseIndex + x;
-         out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
-     }
- }
ggml/src/kompute-shaders/op_silu.comp DELETED
@@ -1,22 +0,0 @@
- #version 450
-
- #include "common.comp"
-
- layout(local_size_x = 1) in;
-
- layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
- layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
- layout(push_constant) uniform PushConstants {
-     uint inOff;
-     uint outOff;
- } pcs;
-
- void main() {
-     const uint baseIndex = gl_WorkGroupID.x * 4;
-
-     for (uint x = 0; x < 4; x++) {
-         const uint i = baseIndex + x;
-         const float y = in_[i + pcs.inOff];
-         out_[i + pcs.outOff] = y / (1.0 + exp(-y));
-     }
- }
ggml/src/kompute-shaders/op_softmax.comp DELETED
@@ -1,56 +0,0 @@
- // TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4)
-
- #version 450
-
- #include "common.comp"
-
- layout(local_size_x_id = 0) in;
-
- layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
- layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
- layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
- layout(push_constant) uniform PushConstants {
-     uint inAOff;
-     uint inBOff;
-     uint outOff;
-     int ne00;
-     int ne01;
-     int ne02;
-     float scale;
-     int mask;
- } pcs;
-
- void main() {
-     if (gl_SubgroupInvocationID > 31)
-         return;
-
-     const uint i03 = gl_WorkGroupID.z;
-     const uint i02 = gl_WorkGroupID.y;
-     const uint i01 = gl_WorkGroupID.x;
-
-     const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
-     const uint psrc0 = extra_off + pcs.inAOff; // Based from inA
-     const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
-     const uint pdst = extra_off + pcs.outOff; // Based from out_
-
-     // parallel max
-     float localMax = uintBitsToFloat(0xFF800000);
-     for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
-         localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f));
-     }
-     float max_ = subgroupMax(localMax);
-
-     // parallel sum
-     float localSum = 0.0f;
-     for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
-         const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_);
-         localSum += exp_psrc0;
-         out_[pdst + i00] = exp_psrc0;
-     }
-
-     const float sum = subgroupAdd(localSum);
-     for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
-         out_[pdst + i00] /= sum;
-     }
- }
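Note: the deleted softmax shader above implements the standard numerically stable form over each row, with an optional additive mask applied after scaling. A scalar sketch of the same computation (illustrative only; the function name is hypothetical, not ggml API):

    #include <math.h>
    #include <stddef.h>

    /* Sketch: softmax(x * scale + mask) over one row of n floats, max-subtracted
     * for numerical stability. Pass mask = NULL to skip the additive mask. */
    void softmax_row(const float *x, const float *mask, float *out, int n, float scale) {
        float max_v = -INFINITY;
        for (int i = 0; i < n; ++i) {
            const float v = x[i] * scale + (mask ? mask[i] : 0.0f);
            if (v > max_v) max_v = v;
        }
        float sum = 0.0f;
        for (int i = 0; i < n; ++i) {
            const float e = expf(x[i] * scale + (mask ? mask[i] : 0.0f) - max_v);
            out[i] = e;
            sum += e;
        }
        for (int i = 0; i < n; ++i) {
            out[i] /= sum;
        }
    }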