Merge pull request #2457 from imachug/segfault

Make ThreadPool a context manager to prevent memory leaks
This commit is contained in:
ZeroNet 2020-03-05 10:45:14 +01:00 committed by GitHub
commit 7ba2c9344d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 99 deletions

View File

@ -149,21 +149,19 @@ class TestNoparallel:
def testMultithreadMix(self, queue_spawn): def testMultithreadMix(self, queue_spawn):
obj1 = ExampleClass() obj1 = ExampleClass()
thread_pool = ThreadPool.ThreadPool(10) with ThreadPool.ThreadPool(10) as thread_pool:
s = time.time()
t1 = queue_spawn(obj1.countBlocking, 5)
time.sleep(0.01)
t2 = thread_pool.spawn(obj1.countBlocking, 5)
time.sleep(0.01)
t3 = thread_pool.spawn(obj1.countBlocking, 5)
time.sleep(0.3)
t4 = gevent.spawn(obj1.countBlocking, 5)
threads = [t1, t2, t3, t4]
for thread in threads:
assert thread.get() == "counted:5"
s = time.time() time_taken = time.time() - s
t1 = queue_spawn(obj1.countBlocking, 5) assert obj1.counted == 5
time.sleep(0.01) assert 0.5 < time_taken < 0.7
t2 = thread_pool.spawn(obj1.countBlocking, 5)
time.sleep(0.01)
t3 = thread_pool.spawn(obj1.countBlocking, 5)
time.sleep(0.3)
t4 = gevent.spawn(obj1.countBlocking, 5)
threads = [t1, t2, t3, t4]
for thread in threads:
assert thread.get() == "counted:5"
time_taken = time.time() - s
assert obj1.counted == 5
assert 0.5 < time_taken < 0.7
thread_pool.kill()

View File

@ -9,31 +9,29 @@ from util import ThreadPool
class TestThreadPool: class TestThreadPool:
def testExecutionOrder(self): def testExecutionOrder(self):
pool = ThreadPool.ThreadPool(4) with ThreadPool.ThreadPool(4) as pool:
events = []
events = [] @pool.wrap
def blocker():
events.append("S")
out = 0
for i in range(10000000):
if i == 3000000:
events.append("M")
out += 1
events.append("D")
return out
@pool.wrap threads = []
def blocker(): for i in range(3):
events.append("S") threads.append(gevent.spawn(blocker))
out = 0 gevent.joinall(threads)
for i in range(10000000):
if i == 3000000:
events.append("M")
out += 1
events.append("D")
return out
threads = [] assert events == ["S"] * 3 + ["M"] * 3 + ["D"] * 3
for i in range(3):
threads.append(gevent.spawn(blocker))
gevent.joinall(threads)
assert events == ["S"] * 3 + ["M"] * 3 + ["D"] * 3 res = blocker()
assert res == 10000000
res = blocker()
assert res == 10000000
pool.kill()
def testLockBlockingSameThread(self): def testLockBlockingSameThread(self):
lock = ThreadPool.Lock() lock = ThreadPool.Lock()
@ -60,89 +58,88 @@ class TestThreadPool:
time.sleep(0.5) time.sleep(0.5)
lock.release() lock.release()
pool = ThreadPool.ThreadPool(10) with ThreadPool.ThreadPool(10) as pool:
threads = [ threads = [
pool.spawn(locker), pool.spawn(locker),
pool.spawn(locker), pool.spawn(locker),
gevent.spawn(locker), gevent.spawn(locker),
pool.spawn(locker) pool.spawn(locker)
] ]
time.sleep(0.1) time.sleep(0.1)
s = time.time() s = time.time()
lock.acquire(True, 5.0) lock.acquire(True, 5.0)
unlock_taken = time.time() - s unlock_taken = time.time() - s
assert 1.8 < unlock_taken < 2.2 assert 1.8 < unlock_taken < 2.2
gevent.joinall(threads) gevent.joinall(threads)
def testMainLoopCallerThreadId(self): def testMainLoopCallerThreadId(self):
main_thread_id = threading.current_thread().ident main_thread_id = threading.current_thread().ident
pool = ThreadPool.ThreadPool(5) with ThreadPool.ThreadPool(5) as pool:
def getThreadId(*args, **kwargs):
return threading.current_thread().ident
def getThreadId(*args, **kwargs): t = pool.spawn(getThreadId)
return threading.current_thread().ident assert t.get() != main_thread_id
t = pool.spawn(getThreadId) t = pool.spawn(lambda: ThreadPool.main_loop.call(getThreadId))
assert t.get() != main_thread_id assert t.get() == main_thread_id
t = pool.spawn(lambda: ThreadPool.main_loop.call(getThreadId))
assert t.get() == main_thread_id
def testMainLoopCallerGeventSpawn(self): def testMainLoopCallerGeventSpawn(self):
main_thread_id = threading.current_thread().ident main_thread_id = threading.current_thread().ident
pool = ThreadPool.ThreadPool(5) with ThreadPool.ThreadPool(5) as pool:
def waiter(): def waiter():
time.sleep(1) time.sleep(1)
return threading.current_thread().ident return threading.current_thread().ident
def geventSpawner(): def geventSpawner():
event = ThreadPool.main_loop.call(gevent.spawn, waiter) event = ThreadPool.main_loop.call(gevent.spawn, waiter)
with pytest.raises(Exception) as greenlet_err: with pytest.raises(Exception) as greenlet_err:
event.get() event.get()
assert str(greenlet_err.value) == "cannot switch to a different thread" assert str(greenlet_err.value) == "cannot switch to a different thread"
waiter_thread_id = ThreadPool.main_loop.call(event.get) waiter_thread_id = ThreadPool.main_loop.call(event.get)
return waiter_thread_id return waiter_thread_id
s = time.time() s = time.time()
waiter_thread_id = pool.apply(geventSpawner) waiter_thread_id = pool.apply(geventSpawner)
assert main_thread_id == waiter_thread_id assert main_thread_id == waiter_thread_id
time_taken = time.time() - s time_taken = time.time() - s
assert 0.9 < time_taken < 1.2 assert 0.9 < time_taken < 1.2
def testEvent(self): def testEvent(self):
pool = ThreadPool.ThreadPool(5) with ThreadPool.ThreadPool(5) as pool:
event = ThreadPool.Event() event = ThreadPool.Event()
def setter(): def setter():
time.sleep(1) time.sleep(1)
event.set("done!") event.set("done!")
def getter(): def getter():
return event.get() return event.get()
pool.spawn(setter) pool.spawn(setter)
t_gevent = gevent.spawn(getter) t_gevent = gevent.spawn(getter)
t_pool = pool.spawn(getter) t_pool = pool.spawn(getter)
s = time.time() s = time.time()
assert event.get() == "done!" assert event.get() == "done!"
time_taken = time.time() - s time_taken = time.time() - s
gevent.joinall([t_gevent, t_pool]) gevent.joinall([t_gevent, t_pool])
assert t_gevent.get() == "done!" assert t_gevent.get() == "done!"
assert t_pool.get() == "done!" assert t_pool.get() == "done!"
assert 0.9 < time_taken < 1.2 assert 0.9 < time_taken < 1.2
with pytest.raises(Exception) as err: with pytest.raises(Exception) as err:
event.set("another result") event.set("another result")
assert "Event already has value" in str(err.value) assert "Event already has value" in str(err.value)
def testMemoryLeak(self): def testMemoryLeak(self):
import gc import gc
@ -153,10 +150,9 @@ class TestThreadPool:
return "ok" return "ok"
def poolTest(): def poolTest():
pool = ThreadPool.ThreadPool(5) with ThreadPool.ThreadPool(5) as pool:
for i in range(20): for i in range(20):
pool.spawn(worker) pool.spawn(worker)
pool.kill()
for i in range(5): for i in range(5):
poolTest() poolTest()

View File

@ -55,6 +55,12 @@ class ThreadPool:
del self.pool del self.pool
self.pool = None self.pool = None
def __enter__(self):
return self
def __exit__(self, *args):
self.kill()
lock_pool = gevent.threadpool.ThreadPool(50) lock_pool = gevent.threadpool.ThreadPool(50)
main_thread_id = threading.current_thread().ident main_thread_id = threading.current_thread().ident