Skip to content
Snippets Groups Projects
Commit 739aa1d7 authored by Jan David Mol's avatar Jan David Mol
Browse files

Task #8437: Added MethodTrigger class to test function calls in asynchronous objects

parent 5c898296
No related branches found
No related tags found
No related merge requests found
...@@ -2668,6 +2668,7 @@ LCS/PyCommon/postgres.py -text ...@@ -2668,6 +2668,7 @@ LCS/PyCommon/postgres.py -text
LCS/PyCommon/test/python-coverage.sh eol=lf LCS/PyCommon/test/python-coverage.sh eol=lf
LCS/PyCommon/test/t_dbcredentials.run eol=lf LCS/PyCommon/test/t_dbcredentials.run eol=lf
LCS/PyCommon/test/t_dbcredentials.sh eol=lf LCS/PyCommon/test/t_dbcredentials.sh eol=lf
LCS/PyCommon/test/t_methodtrigger.sh eol=lf
LCS/PyCommon/util.py -text LCS/PyCommon/util.py -text
LCS/Tools/src/checkcomp.py -text LCS/Tools/src/checkcomp.py -text
LCS/Tools/src/countalllines -text LCS/Tools/src/countalllines -text
......
...@@ -10,6 +10,7 @@ add_subdirectory(test) ...@@ -10,6 +10,7 @@ add_subdirectory(test)
set(_py_files set(_py_files
dbcredentials.py dbcredentials.py
factory.py factory.py
methodtrigger.py
util.py util.py
postgres.py postgres.py
datetimeutils.py) datetimeutils.py)
......
from threading import Lock, Condition
__all__ = ["MethodTrigger"]
class MethodTrigger:
"""
Set a flag when a specific method is called, possibly asynchronously. Caller can wait on this flag.
Example:
class Foo(object):
def bar(self):
pass
foo = Foo()
trigger = MethodTrigger(foo, "bar")
if trigger.wait(): # Waits for 10 seconds for foo.bar() to get called
print "foo.bar() got called"
else
# This will happen, as foo.bar() wasn't called
print "foo.bar() did not get called"
Calls that were made before the trigger has been installed will not get recorded.
"""
def __init__(self, obj, method):
assert isinstance(obj, object), "Object %s does not derive from object." % (obj,)
self.obj = obj
self.method = method
self.old_func = obj.__getattribute__(method)
self.called = False
self.args = []
self.kwargs = {}
self.lock = Lock()
self.cond = Condition(self.lock)
# Patch the target method
obj.__setattr__(method, self.trigger)
def trigger(self, *args, **kwargs):
# Save the call parameters
self.args = args
self.kwargs = kwargs
# Call the original method
self.old_func(*args, **kwargs)
# Restore the original method
self.obj.__setattr__(self.method, self.old_func)
# Release waiting thread
with self.lock:
self.called = True
self.cond.notify()
def wait(self, timeout=10.0):
# Wait for method to get called
with self.lock:
if self.called:
return True
self.cond.wait(timeout)
return self.called
...@@ -7,3 +7,4 @@ file(COPY ...@@ -7,3 +7,4 @@ file(COPY
DESTINATION ${CMAKE_BINARY_DIR}/bin) DESTINATION ${CMAKE_BINARY_DIR}/bin)
lofar_add_test(t_dbcredentials) lofar_add_test(t_dbcredentials)
lofar_add_test(t_methodtrigger)
import unittest
from lofar.common.methodtrigger import MethodTrigger
from threading import Thread
import time
class TestMethodTrigger(unittest.TestCase):
def setUp(self):
# Create a basic object
class TestClass(object):
def func(self):
pass
self.testobj = TestClass()
# Install trigger
self.trigger = MethodTrigger(self.testobj, "func")
def test_no_call(self):
""" Do not trigger. """
# Wait for trigger
self.assertFalse(self.trigger.wait(0.1))
def test_serial_call(self):
""" Trigger and wait serially. """
# Call function
self.testobj.func()
# Wait for trigger
self.assertTrue(self.trigger.wait(0.1))
def test_parallel_call(self):
""" Trigger and wait in parallel. """
class wait_thread(Thread):
def __init__(self, trigger):
Thread.__init__(self)
self.result = None
self.trigger = trigger
def run(self):
self.result = self.trigger.wait(1.0)
class call_thread(Thread):
def __init__(self,func):
Thread.__init__(self)
self.func = func
def run(self):
time.sleep(0.5)
self.func()
# Start threads
t1 = wait_thread(self.trigger)
t1.start()
t2 = call_thread(self.testobj.func)
t2.start()
# Wait for them to finish
t1.join()
t2.join()
# Inspect result
self.assertTrue(t1.result)
class TestArgs(unittest.TestCase):
def setUp(self):
# Create a basic object
class TestClass(object):
def func(self, a, b, c=None, d=None):
pass
self.testobj = TestClass()
# Install trigger
self.trigger = MethodTrigger(self.testobj, "func")
def test_args(self):
""" Trigger and check args. """
# Call function
self.testobj.func(1, 2)
# Wait for trigger
self.assertTrue(self.trigger.wait(0.1))
# Check stored arguments
self.assertEqual(self.trigger.args, (1, 2))
def test_kwargs(self):
""" Trigger and check kwargs. """
# Call function
self.testobj.func(a=1, b=2)
# Wait for trigger
self.assertTrue(self.trigger.wait(0.1))
# Check stored arguments
self.assertEqual(self.trigger.kwargs, {"a": 1, "b": 2})
def test_full(self):
""" Trigger and check both args and kwargs. """
# Call function
self.testobj.func(1, 2, c=3, d=4)
# Wait for trigger
self.assertTrue(self.trigger.wait(0.1))
# Check stored arguments
self.assertEqual(self.trigger.args, (1, 2))
self.assertEqual(self.trigger.kwargs, {"c": 3, "d": 4})
def main(argv):
unittest.main(verbosity=2)
if __name__ == "__main__":
# run all tests
import sys
main(sys.argv[1:])
#!/bin/sh
./runctest.sh t_methodtrigger
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment