Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cuda_core/cuda/core/graph/_graph_node.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from __future__ import annotations

from cpython.ref cimport Py_INCREF

from libc.stddef cimport size_t
from libc.stdint cimport uintptr_t
from libc.string cimport memset as c_memset
Expand Down Expand Up @@ -54,6 +56,7 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value
from cuda.core.graph._utils cimport (
_attach_host_callback_to_graph,
_attach_user_object,
_py_host_destructor,
)

import weakref
Expand Down Expand Up @@ -617,6 +620,12 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker,
_attach_user_object(as_cu(h_graph), <void*>new KernelHandle(ker._h_kernel),
<cydriver.CUhostFn>_destroy_kernel_handle_copy)

cdef object kernel_args = ker_args.kernel_args
if kernel_args is not None:
Py_INCREF(kernel_args)
_attach_user_object(as_cu(h_graph), <void*>kernel_args,
<cydriver.CUhostFn>_py_host_destructor)

return _registered(KernelNode._create_with_params(
create_graph_node_handle(new_node, h_graph),
conf.grid, conf.block, conf.shmem_size,
Expand Down
2 changes: 2 additions & 0 deletions cuda_core/cuda/core/graph/_utils.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ from cuda.bindings cimport cydriver

cdef bint _is_py_host_trampoline(cydriver.CUhostFn fn) noexcept nogil

cdef void _py_host_destructor(void* data) noexcept with gil

cdef void _attach_user_object(
cydriver.CUgraph graph, void* ptr,
cydriver.CUhostFn destroy) except *
Expand Down
107 changes: 106 additions & 1 deletion cuda_core/tests/graph/test_graph_definition_lifetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,34 @@

"""Tests for GraphDefinition resource lifetime management and RAII correctness."""

import ctypes
import gc
import time
import weakref

import pytest
from helpers.graph_kernels import compile_common_kernels
from helpers.misc import try_create_condition

from cuda.core import Device, EventOptions, Kernel, LaunchConfig

def _wait_until(predicate, timeout=2.0, interval=0.01):
"""Poll predicate() until True or timeout, driving gc each iteration.

Used for assertions about resource cleanup that may be delayed by CUDA's
asynchronous user-object destructor pump (DPC) or, on free-threaded
Python, by deferred reference-count processing. A bounded poll keeps the
test correct without depending on undocumented driver timing guarantees.
"""
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
gc.collect()
if predicate():
return
time.sleep(interval)
raise AssertionError(f"condition not satisfied within {timeout}s")
Comment on lines +16 to +30
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this still welcome flakiness? I am concerned about this being tested in SWQA hands



from cuda.core import Device, DeviceMemoryResource, EventOptions, Kernel, LaunchConfig
from cuda.core.graph import (
ChildGraphNode,
ConditionalNode,
Expand Down Expand Up @@ -485,3 +506,87 @@ def test_kernel_node_reconstruction_preserves_validity(init_cuda):
stream = Device().create_stream()
graph.launch(stream)
stream.sync()


# =============================================================================
# Kernel argument lifetime — kernel nodes should keep argument objects alive
# =============================================================================


def test_kernel_args_buffer_lifetime(init_cuda):
    """Buffer passed as a kernel arg is kept alive by the graph, the kernel
    executes against its memory after the original Python ref drops, and the
    Buffer is released once the graph is destroyed.

    Without the user-object attachment, the ParamHolder is destroyed when the
    kernel node is added, the Buffer is GC'd, and the graph is left with a
    stale device pointer.

    The final freeing assertion uses a bounded poll because CUgraphExec
    releases its user-object references via an asynchronous DPC, and on
    free-threaded Python the resulting Py_DECREF chain may need an extra
    GC pass to settle.
    """
    from cuda.core._utils.cuda_utils import driver, handle_return

    _skip_if_no_mempool()
    dev = Device()
    mr = DeviceMemoryResource(dev)
    add_one = compile_common_kernels().get_kernel("add_one")
    buf = mr.allocate(ctypes.sizeof(ctypes.c_int), stream=dev.default_stream)
    buf.fill(0, stream=dev.default_stream)
    dev.default_stream.sync()
    buf_weak = weakref.ref(buf)
    dptr = int(buf.handle)

    g = GraphDefinition()
    g.launch(LaunchConfig(grid=1, block=1), add_one, buf)

    del buf
    gc.collect()
    assert buf_weak() is not None  # graph kept the Buffer alive

    stream = dev.create_stream()
    g.instantiate().launch(stream)
    stream.sync()

    out = (ctypes.c_int * 1)(0)
    handle_return(driver.cuMemcpyDtoH(out, dptr, ctypes.sizeof(ctypes.c_int)))
    assert out[0] == 1

    # Dropping the graph must eventually release the last reference to the
    # Buffer; poll because the exec graph's user-object release is async.
    del g
    _wait_until(lambda: buf_weak() is None)


def test_kernel_args_survive_graph_clone(init_cuda):
    """Cloned graph keeps Buffer alive via CUDA user objects.

    A graph clone does not inherit Python-level references, so only user
    objects (which propagate through cuGraphClone) can keep the args alive.
    """
    from cuda.core._utils.cuda_utils import driver, handle_return

    _skip_if_no_mempool()
    device = Device()
    memres = DeviceMemoryResource(device)
    kernel = compile_common_kernels().get_kernel("add_one")

    int_nbytes = ctypes.sizeof(ctypes.c_int)
    arg_buf = memres.allocate(int_nbytes, stream=device.default_stream)
    arg_buf.fill(0, stream=device.default_stream)
    device.default_stream.sync()
    dev_ptr = int(arg_buf.handle)

    graph_def = GraphDefinition()
    graph_def.launch(LaunchConfig(grid=1, block=1), kernel, arg_buf)
    cloned = handle_return(driver.cuGraphClone(driver.CUgraph(graph_def.handle)))

    # Drop every Python-level reference before running the clone; only the
    # driver-side user objects can keep the argument buffer alive now.
    del arg_buf
    del graph_def
    gc.collect()

    exec_graph = handle_return(driver.cuGraphInstantiate(cloned, 0))
    stream = device.create_stream()
    handle_return(driver.cuGraphLaunch(exec_graph, driver.CUstream(int(stream.handle))))
    stream.sync()

    host_out = (ctypes.c_int * 1)(0)
    handle_return(driver.cuMemcpyDtoH(host_out, dev_ptr, int_nbytes))
    assert host_out[0] == 1
Loading