cuda.core: keep kernel-argument objects alive in graph kernel nodes #2041
```diff
@@ -3,13 +3,34 @@
 """Tests for GraphDefinition resource lifetime management and RAII correctness."""
 
 import ctypes
 import gc
+import time
 import weakref
 
 import pytest
 from helpers.graph_kernels import compile_common_kernels
 from helpers.misc import try_create_condition
 
-from cuda.core import Device, EventOptions, Kernel, LaunchConfig
+
+def _wait_until(predicate, timeout=2.0, interval=0.01):
+    """Poll predicate() until True or timeout, driving gc each iteration.
+
+    Used for assertions about resource cleanup that may be delayed by CUDA's
+    asynchronous user-object destructor pump (DPC) or, on free-threaded
+    Python, by deferred reference-count processing. A bounded poll keeps the
+    test correct without depending on undocumented driver timing guarantees.
+    """
+    deadline = time.monotonic() + timeout
+    while time.monotonic() < deadline:
+        gc.collect()
+        if predicate():
+            return
+        time.sleep(interval)
+    raise AssertionError(f"condition not satisfied within {timeout}s")
+
+
+from cuda.core import Device, DeviceMemoryResource, EventOptions, Kernel, LaunchConfig
 from cuda.core.graph import (
     ChildGraphNode,
     ConditionalNode,
```
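The helper's intent is easiest to see in isolation. Below is a self-contained usage sketch, not part of the PR: the `Resource` class is a hypothetical stand-in for a Buffer, and `_wait_until` is copied from the diff above. A cleanup assertion that may lag behind `del` becomes a bounded poll instead of a one-shot `assert`.

```python
import gc
import time
import weakref


def _wait_until(predicate, timeout=2.0, interval=0.01):
    # Same bounded-poll shape as the helper added in the diff above.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        gc.collect()
        if predicate():
            return
        time.sleep(interval)
    raise AssertionError(f"condition not satisfied within {timeout}s")


class Resource:  # hypothetical stand-in for a kernel-arg Buffer
    pass


r = Resource()
ref = weakref.ref(r)
del r

# A one-shot `assert ref() is None` can race deferred cleanup (the DPC,
# or free-threaded refcount batching); the poll passes as soon as the
# reference is actually gone and fails loudly only after the timeout.
_wait_until(lambda: ref() is None)
```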
```diff
@@ -485,3 +506,87 @@ def test_kernel_node_reconstruction_preserves_validity(init_cuda):
     stream = Device().create_stream()
     graph.launch(stream)
     stream.sync()
+
+
+# =============================================================================
+# Kernel argument lifetime — kernel nodes should keep argument objects alive
+# =============================================================================
+
+
+def test_kernel_args_buffer_lifetime(init_cuda):
+    """Buffer passed as a kernel arg is kept alive by the graph, the kernel
+    executes against its memory after the original Python ref drops, and the
+    Buffer is released once the graph is destroyed.
+
+    Without the user-object attachment, the ParamHolder is destroyed when the
+    kernel node is added, the Buffer is GC'd, and the graph is left with a
+    stale device pointer.
+
+    The final freeing assertion uses a bounded poll because CUgraphExec
+    releases its user-object references via an asynchronous DPC, and on
+    free-threaded Python the resulting Py_DECREF chain may need an extra
+    GC pass to settle.
+    """
+    from cuda.core._utils.cuda_utils import driver, handle_return
```
**Member:** nit: move imports to the top, no need to defer import to here
```diff
+
+    _skip_if_no_mempool()
+    dev = Device()
+    mr = DeviceMemoryResource(dev)
+    add_one = compile_common_kernels().get_kernel("add_one")
+    buf = mr.allocate(ctypes.sizeof(ctypes.c_int), stream=dev.default_stream)
+    buf.fill(0, stream=dev.default_stream)
+    dev.default_stream.sync()
+    buf_weak = weakref.ref(buf)
+    dptr = int(buf.handle)
+
+    g = GraphDefinition()
+    g.launch(LaunchConfig(grid=1, block=1), add_one, buf)
+
+    del buf
+    gc.collect()
+    assert buf_weak() is not None  # graph kept the Buffer alive
```
**Collaborator:** The test proves the buffer is kept alive, but it doesn't validate that it's cleaned up after the graph is released.

**Contributor (author):** I added a test for this. If it is flaky, we might need to adjust the timeout. Update: I confirmed this is not a concern for source graphs. Asynchronous destruction only comes into play for exec graphs.

**Contributor (author):** This test creates an exec graph, so there is a race. CI for free-threaded Python seems more likely to trigger it. 9f2c8f2 adds polling, but removing the test would also be defensible.
```diff
+
+    stream = dev.create_stream()
+    g.instantiate().launch(stream)
+    stream.sync()
+
+    out = (ctypes.c_int * 1)(0)
+    handle_return(driver.cuMemcpyDtoH(out, dptr, ctypes.sizeof(ctypes.c_int)))
+    assert out[0] == 1
+
+    del g
+    _wait_until(lambda: buf_weak() is None)
+
+
+def test_kernel_args_survive_graph_clone(init_cuda):
+    """Cloned graph keeps Buffer alive via CUDA user objects.
+
+    A graph clone does not inherit Python-level references, so only user
+    objects (which propagate through cuGraphClone) can keep the args alive.
+    """
+    from cuda.core._utils.cuda_utils import driver, handle_return
+
+    _skip_if_no_mempool()
+    dev = Device()
+    mr = DeviceMemoryResource(dev)
+    add_one = compile_common_kernels().get_kernel("add_one")
+    buf = mr.allocate(ctypes.sizeof(ctypes.c_int), stream=dev.default_stream)
+    buf.fill(0, stream=dev.default_stream)
+    dev.default_stream.sync()
+    dptr = int(buf.handle)
+
+    g = GraphDefinition()
+    g.launch(LaunchConfig(grid=1, block=1), add_one, buf)
+    cloned_cu_graph = handle_return(driver.cuGraphClone(driver.CUgraph(g.handle)))
+
+    del buf, g
+    gc.collect()
+
+    graph_exec = handle_return(driver.cuGraphInstantiate(cloned_cu_graph, 0))
+    stream = dev.create_stream()
+    handle_return(driver.cuGraphLaunch(graph_exec, driver.CUstream(int(stream.handle))))
+    stream.sync()
+
+    out = (ctypes.c_int * 1)(0)
+    handle_return(driver.cuMemcpyDtoH(out, dptr, ctypes.sizeof(ctypes.c_int)))
+    assert out[0] == 1
```
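For intuition, here is a driver-free sketch of the keep-alive contract both tests exercise. `FakeGraph` and `Arg` are hypothetical stand-ins, not the PR's code; in the real change, a CUDA user object attached to the graph (cuUserObjectCreate plus cuGraphRetainUserObject) plays the role of the strong-reference container, which is also how the references survive cuGraphClone.

```python
import gc
import weakref


class FakeGraph:
    """Hypothetical stand-in for a graph that pins its kernel-arg objects."""

    def __init__(self):
        self._retained_args = []

    def add_kernel_node(self, *args):
        # Hold strong references for the lifetime of the graph, analogous
        # to attaching a user object to the underlying CUgraph.
        self._retained_args.append(args)


class Arg:  # stands in for a Buffer passed as a kernel argument
    pass


g = FakeGraph()
arg = Arg()
ref = weakref.ref(arg)
g.add_kernel_node(arg)

del arg
gc.collect()
assert ref() is not None  # the graph's reference keeps the arg alive

del g
gc.collect()
assert ref() is None  # released once the graph itself is destroyed
```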
**Reviewer:** Wouldn't this still invite flakiness? I am concerned about this being tested in SWQA hands.