Skip to content

Commit 0ec19de

Browse files
add _get_queue_for_pickling, outline some pool support
1 parent 60e71ee commit 0ec19de

2 files changed

Lines changed: 90 additions & 43 deletions

File tree

pyopencl/__init__.py

Lines changed: 79 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@
134134
MemoryObject,
135135
MemoryMap,
136136
Buffer,
137+
PooledBuffer,
137138

138139
_Program,
139140
Kernel,
@@ -194,7 +195,7 @@
194195
enqueue_migrate_mem_objects, unload_platform_compiler)
195196

196197
if get_cl_header_version() >= (2, 0):
197-
from pyopencl._cl import SVM, SVMAllocation, SVMPointer
198+
from pyopencl._cl import SVM, SVMAllocation, SVMPointer, PooledSVM
198199

199200
if _cl.have_gl():
200201
from pyopencl._cl import ( # noqa: F401
@@ -2436,21 +2437,28 @@ def queue_for_pickling(queue, alloc=None):
24362437
_QUEUE_FOR_PICKLING_TLS.alloc = None
24372438

24382439

2439-
def _getstate_buffer(self):
2440-
import pyopencl as cl
2441-
state = {}
2442-
state["size"] = self.size
2443-
state["flags"] = self.flags
2444-
2440+
def _get_queue_for_pickling(obj):
24452441
try:
24462442
queue = _QUEUE_FOR_PICKLING_TLS.queue
2443+
alloc = _QUEUE_FOR_PICKLING_TLS.alloc
24472444
except AttributeError:
24482445
queue = None
24492446

24502447
if queue is None:
2451-
raise RuntimeError("CL Buffer instances can only be pickled while "
2448+
raise RuntimeError(f"{type(obj).__name__} instances can only be pickled while "
24522449
"queue_for_pickling is active.")
24532450

2451+
return queue, alloc
2452+
2453+
2454+
def _getstate_buffer(self):
2455+
import pyopencl as cl
2456+
queue, _alloc = _get_queue_for_pickling(self)
2457+
2458+
state = {}
2459+
state["size"] = self.size
2460+
state["flags"] = self.flags
2461+
24542462
a = bytearray(self.size)
24552463
cl.enqueue_copy(queue, a, self)
24562464

@@ -2460,42 +2468,57 @@ def _getstate_buffer(self):
24602468

24612469

24622470
def _setstate_buffer(self, state):
2463-
try:
2464-
queue = _QUEUE_FOR_PICKLING_TLS.queue
2465-
except AttributeError:
2466-
queue = None
2467-
2468-
if queue is None:
2469-
raise RuntimeError("CL Buffer instances can only be unpickled while "
2470-
"queue_for_pickling is active.")
2471+
import pyopencl as cl
2472+
queue, _alloc = _get_queue_for_pickling(self)
24712473

24722474
size = state["size"]
24732475
flags = state["flags"]
24742476

2475-
import pyopencl as cl
2476-
24772477
a = state["_pickle_data"]
24782478
Buffer.__init__(self, queue.context, flags | cl.mem_flags.COPY_HOST_PTR, size, a)
24792479

24802480

24812481
Buffer.__getstate__ = _getstate_buffer
24822482
Buffer.__setstate__ = _setstate_buffer
24832483

2484+
2485+
def _getstate_pooledbuffer(self):
2486+
import pyopencl as cl
2487+
queue, _alloc = _get_queue_for_pickling(self)
2488+
2489+
state = {}
2490+
state["size"] = self.size
2491+
state["flags"] = self.flags
2492+
2493+
a = bytearray(self.size)
2494+
cl.enqueue_copy(queue, a, self)
2495+
state["_pickle_data"] = a
2496+
2497+
return state
2498+
2499+
2500+
def _setstate_pooledbuffer(self, state):
2501+
_queue, _alloc = _get_queue_for_pickling(self)
2502+
2503+
_size = state["size"]
2504+
_flags = state["flags"]
2505+
2506+
_a = state["_pickle_data"]
2507+
# FIXME: Unclear what to do here - PooledBuffer does not have __init__
2508+
2509+
2510+
PooledBuffer.__getstate__ = _getstate_pooledbuffer
2511+
PooledBuffer.__setstate__ = _setstate_pooledbuffer
2512+
2513+
24842514
if get_cl_header_version() >= (2, 0):
2485-
def _getstate_svm(self):
2515+
def _getstate_svmallocation(self):
24862516
import pyopencl as cl
24872517

24882518
state = {}
24892519
state["size"] = self.size
24902520

2491-
try:
2492-
queue = _QUEUE_FOR_PICKLING_TLS.queue
2493-
except AttributeError:
2494-
queue = None
2495-
2496-
if queue is None:
2497-
raise RuntimeError(f"{self.__class__.__name__} instances can only be "
2498-
"pickled while queue_for_pickling is active.")
2521+
queue, _alloc = _get_queue_for_pickling(self)
24992522

25002523
a = bytearray(self.size)
25012524
cl.enqueue_copy(queue, a, self)
@@ -2504,17 +2527,10 @@ def _getstate_svm(self):
25042527

25052528
return state
25062529

2507-
def _setstate_svm(self, state):
2530+
def _setstate_svmallocation(self, state):
25082531
import pyopencl as cl
25092532

2510-
try:
2511-
queue = _QUEUE_FOR_PICKLING_TLS.queue
2512-
except AttributeError:
2513-
queue = None
2514-
2515-
if queue is None:
2516-
raise RuntimeError(f"{self.__class__.__name__} instances can only be "
2517-
"unpickled while queue_for_pickling is active.")
2533+
queue, _alloc = _get_queue_for_pickling(self)
25182534

25192535
size = state["size"]
25202536

@@ -2523,8 +2539,33 @@ def _setstate_svm(self, state):
25232539
queue=queue)
25242540
cl.enqueue_copy(queue, self, a)
25252541

2526-
SVMAllocation.__getstate__ = _getstate_svm
2527-
SVMAllocation.__setstate__ = _setstate_svm
2542+
SVMAllocation.__getstate__ = _getstate_svmallocation
2543+
SVMAllocation.__setstate__ = _setstate_svmallocation
2544+
2545+
def _getstate_pooled_svm(self):
2546+
import pyopencl as cl
2547+
2548+
state = {}
2549+
state["size"] = self.size
2550+
2551+
queue, _alloc = _get_queue_for_pickling(self)
2552+
2553+
a = bytearray(self.size)
2554+
cl.enqueue_copy(queue, a, self)
2555+
2556+
state["_pickle_data"] = a
2557+
2558+
return state
2559+
2560+
def _setstate_pooled_svm(self, state):
2561+
_queue, _alloc = _get_queue_for_pickling(self)
2562+
_size = state["size"]
2563+
_data = state["_pickle_data"]
2564+
2565+
# FIXME: Unclear what to do here - PooledSVM does not have __init__
2566+
2567+
PooledSVM.__getstate__ = _getstate_pooled_svm
2568+
PooledSVM.__setstate__ = _setstate_pooled_svm
25282569

25292570
# }}}
25302571

test/test_array.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2404,12 +2404,18 @@ def __init__(self, cq, shape, dtype, tags):
24042404
self.tags = tags
24052405

24062406

2407-
def test_array_pickling(ctx_factory):
2407+
@pytest.mark.parametrize("use_mempool", [False, True])
2408+
def test_array_pickling(ctx_factory, use_mempool):
24082409
context = ctx_factory()
24092410
queue = cl.CommandQueue(context)
24102411

2412+
if use_mempool:
2413+
alloc = cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue))
2414+
else:
2415+
alloc = None
2416+
24112417
a = np.array([1, 2, 3, 4, 5]).astype(np.float32)
2412-
a_gpu = cl_array.to_device(queue, a)
2418+
a_gpu = cl_array.to_device(queue, a, allocator=alloc)
24132419

24142420
import pickle
24152421
with pytest.raises(RuntimeError):
@@ -2437,11 +2443,11 @@ def test_array_pickling(ctx_factory):
24372443
from pyopencl.characterize import has_coarse_grain_buffer_svm
24382444

24392445
if has_coarse_grain_buffer_svm(queue.device):
2440-
from pyopencl.tools import SVMAllocator
2446+
from pyopencl.tools import SVMAllocator, SVMPool
24412447

24422448
alloc = SVMAllocator(context, alignment=0, queue=queue)
2443-
# FIXME: SVMPool is not picklable
2444-
# alloc = SVMPool(alloc)
2449+
if use_mempool:
2450+
alloc = SVMPool(alloc)
24452451

24462452
a_dev = cl_array.to_device(queue, a, allocator=alloc)
24472453

0 commit comments

Comments
 (0)