From 09e65e1d95ebb61f43261dc01697747a48e9eee5 Mon Sep 17 00:00:00 2001 From: Ivanq Date: Wed, 4 Mar 2020 22:11:59 +0300 Subject: [PATCH] Make ThreadPool a context manager to prevent memory leaks --- src/Test/TestNoparallel.py | 32 ++++---- src/Test/TestThreadPool.py | 160 ++++++++++++++++++------------------- src/util/ThreadPool.py | 6 ++ 3 files changed, 99 insertions(+), 99 deletions(-) diff --git a/src/Test/TestNoparallel.py b/src/Test/TestNoparallel.py index d80cc5fb..6fc4f57d 100644 --- a/src/Test/TestNoparallel.py +++ b/src/Test/TestNoparallel.py @@ -149,21 +149,19 @@ class TestNoparallel: def testMultithreadMix(self, queue_spawn): obj1 = ExampleClass() - thread_pool = ThreadPool.ThreadPool(10) + with ThreadPool.ThreadPool(10) as thread_pool: + s = time.time() + t1 = queue_spawn(obj1.countBlocking, 5) + time.sleep(0.01) + t2 = thread_pool.spawn(obj1.countBlocking, 5) + time.sleep(0.01) + t3 = thread_pool.spawn(obj1.countBlocking, 5) + time.sleep(0.3) + t4 = gevent.spawn(obj1.countBlocking, 5) + threads = [t1, t2, t3, t4] + for thread in threads: + assert thread.get() == "counted:5" - s = time.time() - t1 = queue_spawn(obj1.countBlocking, 5) - time.sleep(0.01) - t2 = thread_pool.spawn(obj1.countBlocking, 5) - time.sleep(0.01) - t3 = thread_pool.spawn(obj1.countBlocking, 5) - time.sleep(0.3) - t4 = gevent.spawn(obj1.countBlocking, 5) - threads = [t1, t2, t3, t4] - for thread in threads: - assert thread.get() == "counted:5" - - time_taken = time.time() - s - assert obj1.counted == 5 - assert 0.5 < time_taken < 0.7 - thread_pool.kill() + time_taken = time.time() - s + assert obj1.counted == 5 + assert 0.5 < time_taken < 0.7 diff --git a/src/Test/TestThreadPool.py b/src/Test/TestThreadPool.py index 6c7f35e7..5e95005e 100644 --- a/src/Test/TestThreadPool.py +++ b/src/Test/TestThreadPool.py @@ -9,31 +9,29 @@ from util import ThreadPool class TestThreadPool: def testExecutionOrder(self): - pool = ThreadPool.ThreadPool(4) + with ThreadPool.ThreadPool(4) as pool: + events = [] - events = [] + @pool.wrap + def blocker(): + events.append("S") + out = 0 + for i in range(10000000): + if i == 3000000: + events.append("M") + out += 1 + events.append("D") + return out - @pool.wrap - def blocker(): - events.append("S") - out = 0 - for i in range(10000000): - if i == 3000000: - events.append("M") - out += 1 - events.append("D") - return out + threads = [] + for i in range(3): + threads.append(gevent.spawn(blocker)) + gevent.joinall(threads) - threads = [] - for i in range(3): - threads.append(gevent.spawn(blocker)) - gevent.joinall(threads) + assert events == ["S"] * 3 + ["M"] * 3 + ["D"] * 3 - assert events == ["S"] * 3 + ["M"] * 3 + ["D"] * 3 - - res = blocker() - assert res == 10000000 - pool.kill() + res = blocker() + assert res == 10000000 def testLockBlockingSameThread(self): lock = ThreadPool.Lock() @@ -60,89 +58,88 @@ class TestThreadPool: time.sleep(0.5) lock.release() - pool = ThreadPool.ThreadPool(10) - threads = [ - pool.spawn(locker), - pool.spawn(locker), - gevent.spawn(locker), - pool.spawn(locker) - ] - time.sleep(0.1) + with ThreadPool.ThreadPool(10) as pool: + threads = [ + pool.spawn(locker), + pool.spawn(locker), + gevent.spawn(locker), + pool.spawn(locker) + ] + time.sleep(0.1) - s = time.time() + s = time.time() - lock.acquire(True, 5.0) + lock.acquire(True, 5.0) - unlock_taken = time.time() - s + unlock_taken = time.time() - s - assert 1.8 < unlock_taken < 2.2 + assert 1.8 < unlock_taken < 2.2 - gevent.joinall(threads) + gevent.joinall(threads) def testMainLoopCallerThreadId(self): main_thread_id = threading.current_thread().ident - pool = ThreadPool.ThreadPool(5) + with ThreadPool.ThreadPool(5) as pool: + def getThreadId(*args, **kwargs): + return threading.current_thread().ident - def getThreadId(*args, **kwargs): - return threading.current_thread().ident + t = pool.spawn(getThreadId) + assert t.get() != main_thread_id - t = pool.spawn(getThreadId) - assert t.get() != main_thread_id - - t = pool.spawn(lambda: ThreadPool.main_loop.call(getThreadId)) - assert t.get() == main_thread_id + t = pool.spawn(lambda: ThreadPool.main_loop.call(getThreadId)) + assert t.get() == main_thread_id def testMainLoopCallerGeventSpawn(self): main_thread_id = threading.current_thread().ident - pool = ThreadPool.ThreadPool(5) - def waiter(): - time.sleep(1) - return threading.current_thread().ident + with ThreadPool.ThreadPool(5) as pool: + def waiter(): + time.sleep(1) + return threading.current_thread().ident - def geventSpawner(): - event = ThreadPool.main_loop.call(gevent.spawn, waiter) + def geventSpawner(): + event = ThreadPool.main_loop.call(gevent.spawn, waiter) - with pytest.raises(Exception) as greenlet_err: - event.get() - assert str(greenlet_err.value) == "cannot switch to a different thread" + with pytest.raises(Exception) as greenlet_err: + event.get() + assert str(greenlet_err.value) == "cannot switch to a different thread" - waiter_thread_id = ThreadPool.main_loop.call(event.get) - return waiter_thread_id + waiter_thread_id = ThreadPool.main_loop.call(event.get) + return waiter_thread_id - s = time.time() - waiter_thread_id = pool.apply(geventSpawner) - assert main_thread_id == waiter_thread_id - time_taken = time.time() - s - assert 0.9 < time_taken < 1.2 + s = time.time() + waiter_thread_id = pool.apply(geventSpawner) + assert main_thread_id == waiter_thread_id + time_taken = time.time() - s + assert 0.9 < time_taken < 1.2 def testEvent(self): - pool = ThreadPool.ThreadPool(5) - event = ThreadPool.Event() + with ThreadPool.ThreadPool(5) as pool: + event = ThreadPool.Event() - def setter(): - time.sleep(1) - event.set("done!") + def setter(): + time.sleep(1) + event.set("done!") - def getter(): - return event.get() + def getter(): + return event.get() - pool.spawn(setter) - t_gevent = gevent.spawn(getter) - t_pool = pool.spawn(getter) - s = time.time() - assert event.get() == "done!" - time_taken = time.time() - s - gevent.joinall([t_gevent, t_pool]) + pool.spawn(setter) + t_gevent = gevent.spawn(getter) + t_pool = pool.spawn(getter) + s = time.time() + assert event.get() == "done!" + time_taken = time.time() - s + gevent.joinall([t_gevent, t_pool]) - assert t_gevent.get() == "done!" - assert t_pool.get() == "done!" + assert t_gevent.get() == "done!" + assert t_pool.get() == "done!" - assert 0.9 < time_taken < 1.2 + assert 0.9 < time_taken < 1.2 - with pytest.raises(Exception) as err: - event.set("another result") + with pytest.raises(Exception) as err: + event.set("another result") - assert "Event already has value" in str(err.value) + assert "Event already has value" in str(err.value) def testMemoryLeak(self): import gc @@ -153,10 +150,9 @@ class TestThreadPool: return "ok" def poolTest(): - pool = ThreadPool.ThreadPool(5) - for i in range(20): - pool.spawn(worker) - pool.kill() + with ThreadPool.ThreadPool(5) as pool: + for i in range(20): + pool.spawn(worker) for i in range(5): poolTest() diff --git a/src/util/ThreadPool.py b/src/util/ThreadPool.py index 5bb3c0d6..8c759039 100644 --- a/src/util/ThreadPool.py +++ b/src/util/ThreadPool.py @@ -55,6 +55,12 @@ class ThreadPool: del self.pool self.pool = None + def __enter__(self): + return self + + def __exit__(self, *args): + self.kill() + lock_pool = gevent.threadpool.ThreadPool(50) main_thread_id = threading.current_thread().ident