diff --git a/cuda_core/cuda/core/graph/_graph_node.pyx b/cuda_core/cuda/core/graph/_graph_node.pyx index c9d1786caa..e4e00d5c5f 100644 --- a/cuda_core/cuda/core/graph/_graph_node.pyx +++ b/cuda_core/cuda/core/graph/_graph_node.pyx @@ -6,6 +6,8 @@ from __future__ import annotations +from cpython.ref cimport Py_INCREF + from libc.stddef cimport size_t from libc.stdint cimport uintptr_t from libc.string cimport memset as c_memset @@ -54,6 +56,7 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value from cuda.core.graph._utils cimport ( _attach_host_callback_to_graph, _attach_user_object, + _py_host_destructor, ) import weakref @@ -617,6 +620,12 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker, _attach_user_object(as_cu(h_graph), new KernelHandle(ker._h_kernel), _destroy_kernel_handle_copy) + cdef object kernel_args = ker_args.kernel_args + if kernel_args is not None: + Py_INCREF(kernel_args) + _attach_user_object(as_cu(h_graph), kernel_args, + _py_host_destructor) + return _registered(KernelNode._create_with_params( create_graph_node_handle(new_node, h_graph), conf.grid, conf.block, conf.shmem_size, diff --git a/cuda_core/cuda/core/graph/_utils.pxd b/cuda_core/cuda/core/graph/_utils.pxd index 63fdb00ac4..13d3742cc0 100644 --- a/cuda_core/cuda/core/graph/_utils.pxd +++ b/cuda_core/cuda/core/graph/_utils.pxd @@ -7,6 +7,8 @@ from cuda.bindings cimport cydriver cdef bint _is_py_host_trampoline(cydriver.CUhostFn fn) noexcept nogil +cdef void _py_host_destructor(void* data) noexcept with gil + cdef void _attach_user_object( cydriver.CUgraph graph, void* ptr, cydriver.CUhostFn destroy) except * diff --git a/cuda_core/tests/graph/test_graph_definition_lifetime.py b/cuda_core/tests/graph/test_graph_definition_lifetime.py index e231016c8a..c53009a572 100644 --- a/cuda_core/tests/graph/test_graph_definition_lifetime.py +++ b/cuda_core/tests/graph/test_graph_definition_lifetime.py @@ -3,13 +3,34 @@ """Tests for GraphDefinition resource lifetime management and RAII correctness.""" +import ctypes import gc +import time +import weakref import pytest from helpers.graph_kernels import compile_common_kernels from helpers.misc import try_create_condition -from cuda.core import Device, EventOptions, Kernel, LaunchConfig + +def _wait_until(predicate, timeout=2.0, interval=0.01): + """Poll predicate() until True or timeout, driving gc each iteration. + + Used for assertions about resource cleanup that may be delayed by CUDA's + asynchronous user-object destructor pump (DPC) or, on free-threaded + Python, by deferred reference-count processing. A bounded poll keeps the + test correct without depending on undocumented driver timing guarantees. + """ + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + gc.collect() + if predicate(): + return + time.sleep(interval) + raise AssertionError(f"condition not satisfied within {timeout}s") + + +from cuda.core import Device, DeviceMemoryResource, EventOptions, Kernel, LaunchConfig from cuda.core.graph import ( ChildGraphNode, ConditionalNode, @@ -485,3 +506,87 @@ def test_kernel_node_reconstruction_preserves_validity(init_cuda): stream = Device().create_stream() graph.launch(stream) stream.sync() + + +# ============================================================================= +# Kernel argument lifetime — kernel nodes should keep argument objects alive +# ============================================================================= + + +def test_kernel_args_buffer_lifetime(init_cuda): + """Buffer passed as a kernel arg is kept alive by the graph, the kernel + executes against its memory after the original Python ref drops, and the + Buffer is released once the graph is destroyed. + + Without the user-object attachment, the ParamHolder is destroyed when the + kernel node is added, the Buffer is GC'd, and the graph is left with a + stale device pointer. + + The final freeing assertion uses a bounded poll because CUgraphExec + releases its user-object references via an asynchronous DPC, and on + free-threaded Python the resulting Py_DECREF chain may need an extra + GC pass to settle. + """ + from cuda.core._utils.cuda_utils import driver, handle_return + + _skip_if_no_mempool() + dev = Device() + mr = DeviceMemoryResource(dev) + add_one = compile_common_kernels().get_kernel("add_one") + buf = mr.allocate(ctypes.sizeof(ctypes.c_int), stream=dev.default_stream) + buf.fill(0, stream=dev.default_stream) + dev.default_stream.sync() + buf_weak = weakref.ref(buf) + dptr = int(buf.handle) + + g = GraphDefinition() + g.launch(LaunchConfig(grid=1, block=1), add_one, buf) + + del buf + gc.collect() + assert buf_weak() is not None # graph kept the Buffer alive + + stream = dev.create_stream() + g.instantiate().launch(stream) + stream.sync() + + out = (ctypes.c_int * 1)(0) + handle_return(driver.cuMemcpyDtoH(out, dptr, ctypes.sizeof(ctypes.c_int))) + assert out[0] == 1 + + del g + _wait_until(lambda: buf_weak() is None) + + +def test_kernel_args_survive_graph_clone(init_cuda): + """Cloned graph keeps Buffer alive via CUDA user objects. + + A graph clone does not inherit Python-level references, so only user + objects (which propagate through cuGraphClone) can keep the args alive. + """ + from cuda.core._utils.cuda_utils import driver, handle_return + + _skip_if_no_mempool() + dev = Device() + mr = DeviceMemoryResource(dev) + add_one = compile_common_kernels().get_kernel("add_one") + buf = mr.allocate(ctypes.sizeof(ctypes.c_int), stream=dev.default_stream) + buf.fill(0, stream=dev.default_stream) + dev.default_stream.sync() + dptr = int(buf.handle) + + g = GraphDefinition() + g.launch(LaunchConfig(grid=1, block=1), add_one, buf) + cloned_cu_graph = handle_return(driver.cuGraphClone(driver.CUgraph(g.handle))) + + del buf, g + gc.collect() + + graph_exec = handle_return(driver.cuGraphInstantiate(cloned_cu_graph, 0)) + stream = dev.create_stream() + handle_return(driver.cuGraphLaunch(graph_exec, driver.CUstream(int(stream.handle)))) + stream.sync() + + out = (ctypes.c_int * 1)(0) + handle_return(driver.cuMemcpyDtoH(out, dptr, ctypes.sizeof(ctypes.c_int))) + assert out[0] == 1