| import torch |
| import importlib |
| from triton_kernels.specialize import cacheable, specialize |
| import triton |
| import triton.language as tl |
|
|
|
|
@triton.jit
def template_kernel(o):
    """Write the constant 1.0 through ``o`` with a single ``tl.store``."""
    # NOTE(review): presumably `o` points at a single float32 element; the
    # test in this file launches it on a one-element float32 tensor.
    value = 1.0
    tl.store(o, value)
|
|
|
|
def retrieve_fn(module, name):
    """Import ``module`` by its dotted name and return its attribute ``name``.

    Args:
        module: Dotted module path, as accepted by ``importlib.import_module``.
        name: Attribute to look up on the imported module.

    Returns:
        The object bound to ``name`` in the imported module.

    Raises:
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute ``name``.
    """
    return getattr(importlib.import_module(module), name)
|
|
|
|
# Module-level cache: holds the specialized kernel once it has been built by
# get_specialized_kernel(), and None before that.
_specialized_kernel = None


def get_specialized_kernel():
    """Return the specialization of ``template_kernel``, building it on first use.

    The first call creates an empty module object named
    ``"specialized_kernel"``, specializes ``template_kernel`` into it with no
    constants and no tuples, and stores the result in the module-level cache.
    Every later call returns that same cached object.
    """
    global _specialized_kernel
    if _specialized_kernel is None:
        import types

        host = types.ModuleType("specialized_kernel")
        # Keep the kernel reachable as an attribute of its host module.
        host.specialized = specialize(template_kernel, host, {}, {})
        _specialized_kernel = host.specialized
    return _specialized_kernel
|
|
|
|
@cacheable
def cacheable_kernel():
    """Hand the shared specialized kernel to the ``cacheable`` decorator."""
    kernel = get_specialized_kernel()
    return kernel
|
|
|
|
def test_cacheable(device, fresh_knobs):
    """Check that the cacheable kernel can be found by name, preloaded and reused.

    The test runs in three phases:

    1. Launch the specialized kernel with a JIT cache hook installed that
       records the specialization data and the kernel's reported name/module.
    2. Clear the kernel's device caches, look the kernel up again through
       ``retrieve_fn`` using the recorded module/name, and preload it from
       the recorded specialization data.
    3. Launch once more and assert that the hook is not invoked again.

    Args:
        device: Device on which the one-element output tensor is allocated.
        fresh_knobs: Fixture that is requested but not referenced in the body.
            NOTE(review): presumably it isolates ``triton.knobs`` changes made
            here from other tests -- confirm against the fixture definition.
    """
    specialized_kernel = get_specialized_kernel()

    # Values captured by cache_hook during the first launch.
    specialization_data = None
    fn_name = None
    module_name = None

    def cache_hook(*args, **kwargs):
        nonlocal specialization_data, fn_name, module_name
        specialization_data = kwargs["compile"]["specialization_data"]
        fn_name = kwargs["fn"].name
        module_name = kwargs["fn"].module

    # Phase 1: launch with the recording hook and check what it reported.
    triton.knobs.runtime.jit_cache_hook = cache_hook
    o = torch.empty((1, ), dtype=torch.float32, device=device)
    k = specialized_kernel[(1, )](o, )
    # Named `compiled_hash` rather than `hash` so the builtin is not shadowed.
    compiled_hash = k.hash
    assert o.item() == 1.0
    assert module_name == "tests.test_specialize"
    assert fn_name == "cacheable_kernel"

    compile_count = 0

    def count_hook(*args, **kwargs):
        nonlocal compile_count
        compile_count += 1

    triton.knobs.runtime.jit_cache_hook = count_hook
    # Drop the cached kernels so the preload below has to go through the hook.
    specialized_kernel.device_caches.clear()

    # Phase 2: retrieve the kernel by its reported module/name and preload it.
    fn = retrieve_fn(module_name, fn_name)
    assert fn == specialized_kernel
    preload = fn.preload(specialization_data)
    assert compile_count == 1
    assert preload.hash == compiled_hash

    # Phase 3: a new launch must not invoke the hook again.
    compile_count = 0
    specialized_kernel[(1, )](o, )
    assert compile_count == 0
|
|