Skip to content

Commit 25ebc7b

Browse files
support pickling allocator, test SVM
1 parent fdb3525 commit 25ebc7b

2 files changed

Lines changed: 29 additions & 1 deletion

File tree

pyopencl/array.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ class _copy_queue: # noqa: N801
352352

353353

354354
@contextmanager
355-
def queue_for_pickling(queue):
355+
def queue_for_pickling(queue, alloc=None):
356356
r"""A context manager that, for the current thread, sets the command queue
357357
to be used for pickling and unpickling :class:`Array`\ s to *queue*."""
358358
try:
@@ -365,6 +365,7 @@ def queue_for_pickling(queue):
365365
"inside the context of its own invocation.")
366366

367367
_QUEUE_FOR_PICKLING_TLS.queue = queue
368+
_QUEUE_FOR_PICKLING_TLS.alloc = alloc
368369
try:
369370
yield None
370371
finally:
@@ -749,6 +750,8 @@ def __getstate__(self):
749750
"queue_for_pickling is active.")
750751

751752
state = self.__dict__.copy()
753+
754+
del state["allocator"]
752755
del state["context"]
753756
del state["events"]
754757
del state["queue"]
@@ -760,14 +763,18 @@ def __getstate__(self):
760763
def __setstate__(self, state):
761764
try:
762765
queue = _QUEUE_FOR_PICKLING_TLS.queue
766+
alloc = _QUEUE_FOR_PICKLING_TLS.alloc
763767
except AttributeError:
764768
queue = None
769+
alloc = None
765770

766771
if queue is None:
767772
raise RuntimeError("CL Array instances can only be pickled while "
768773
"queue_for_pickling is active.")
769774

770775
self.__dict__.update(state)
776+
777+
self.allocator = alloc
771778
self.context = queue.context
772779
self.events = []
773780
self.queue = queue

test/test_array.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2419,6 +2419,8 @@ def test_array_pickling(ctx_factory):
24192419
a_gpu_pickled = pickle.loads(pickle.dumps(a_gpu))
24202420
assert np.all(a_gpu_pickled.get() == a)
24212421

2422+
# {{{ subclass test
2423+
24222424
a_gpu_tagged = TaggableCLArray(queue, a.shape, a.dtype, tags={"foo", "bar"})
24232425
a_gpu_tagged.set(a)
24242426

@@ -2428,6 +2430,25 @@ def test_array_pickling(ctx_factory):
24282430
assert np.all(a_gpu_tagged_pickled.get() == a)
24292431
assert a_gpu_tagged_pickled.tags == a_gpu_tagged.tags
24302432

2433+
# }}}
2434+
2435+
# {{{ SVM test
2436+
2437+
from pyopencl.tools import SVMAllocator, SVMPool
2438+
2439+
alloc = SVMAllocator(context, alignment=0, queue=queue)
2440+
alloc = SVMPool(alloc)
2441+
2442+
a_dev = cl_array.to_device(queue, a, allocator=alloc)
2443+
2444+
with cl_array.queue_for_pickling(queue, alloc):
2445+
a_dev_pickled = pickle.loads(pickle.dumps(a_dev))
2446+
2447+
assert np.all(a_dev_pickled.get() == a)
2448+
assert a_dev_pickled.allocator is alloc
2449+
2450+
# }}}
2451+
24312452

24322453
# }}}
24332454

0 commit comments

Comments
 (0)