commit b49ab2ceac92fe74c6063ee330c582a45599e226
parent 63b898444be1769d52bb86c6b676a98c3465e88c
Author: lash <dev@holbrook.no>
Date: Thu, 1 Dec 2022 16:11:52 +0000
Set up complex visitor test pattern
Diffstat:
2 files changed, 55 insertions(+), 19 deletions(-)
diff --git a/piknik/render/base.py b/piknik/render/base.py
@@ -30,15 +30,15 @@ class Renderer:
pass
- def apply_message_pre(self, state, issue, tags, envelope, message, accumulator=None):
+ def apply_message_pre(self, state, issue, tags, envelope, message, message_id, accumulator=None):
pass
- def apply_message_post(self, state, issue, tags, envelope, message, accumulator=None):
+ def apply_message_post(self, state, issue, tags, envelope, message, message_id, accumulator=None):
pass
- def apply_message(self, state, issue, tags, envelope, message, accumulator=None):
+ def apply_message(self, state, issue, tags, envelope, message, message_id, accumulator=None):
pass
@@ -61,11 +61,11 @@ class Renderer:
self.__add(r)
def message_callback(envelope, message, message_id):
- r = self.apply_message_pre(state, issue, tags, envelope, message, accumulator=accumulator)
+ r = self.apply_message_pre(state, issue, tags, envelope, message, message_id, accumulator=accumulator)
self.__add(r)
- r = self.apply_message(state, issue, tags, envelope, message, accumulator=accumulator)
+ r = self.apply_message(state, issue, tags, envelope, message, message_id, accumulator=accumulator)
self.__add(r)
- r = self.apply_message_post(state, issue, tags, envelope, message, accumulator=accumulator)
+ r = self.apply_message_post(state, issue, tags, envelope, message, message_id, accumulator=accumulator)
self.__add(r)
#for msg in self.b.get_msg(issue.id, envelope_callback=envelope_callback, message_callback=message_callback):
diff --git a/tests/test_render.py b/tests/test_render.py
@@ -47,35 +47,71 @@ class TestRenderer(Renderer):
def __init__(self, basket, accumulator=None):
super(TestRenderer, self).__init__(basket, accumulator=accumulator)
- self.c = 0
+ self.p = 0
+ self.e = 0
- def apply_message(self, state, issue, tags, envelope, message, accumulator=None):
- r = self.c
- self.c += 1
+ def apply_envelope(self, state, issue, tags, envelope, accumulator=None):
+ r = self.e
+ self.e += 1
return r
+ def apply_message(self, state, issue, tags, envelope, message, message_id, accumulator=None):
+ r = self.p
+ self.p += 1
+ return r
+
+
+class TestRendererComposite(TestRenderer):
+
+ def __init__(self, basket, accumulator=None):
+ super(TestRendererComposite, self).__init__(basket, accumulator=accumulator)
+ self.last_message_id = None
+ self.m = []
+
+
+ def apply_message_post(self, state, issue, tags, envelope, message, message_id, accumulator=None):
+ if self.last_message_id != message_id:
+ self.m.append(message_id)
+ self.last_message_id = message_id
+
+
class TestMsg(unittest.TestCase):
def setUp(self):
self.acc = []
- def accumulate(v):
- self.acc.append(v)
-
- #(self.crypto, self.gpg, self.gpg_dir) = pgp_setup()
self.store = TestStates()
- self.b = Basket(self.store, message_wrapper=test_wrapper) #, message_wrapper=self.crypto.sign)
+ self.b = Basket(self.store, message_wrapper=test_wrapper)
self.render_dir = tempfile.mkdtemp()
- self.renderer = TestRenderer(self.b, accumulator=accumulate) #outdir=self.render_dir)
+
+
+ def accumulate(self, v):
+ self.acc.append(v)
def tearDown(self):
- #logg.debug('look in {}'.format(self.render_dir))
shutil.rmtree(self.render_dir)
def test_idlepass(self):
+ renderer = TestRenderer(self.b, accumulator=self.accumulate)
+ issue_one = Issue('foo')
+ self.b.add(issue_one)
+
+ issue_two = Issue('bar')
+ v = self.b.add(issue_two)
+
+ m = self.b.msg(v, 's:foo')
+
+ renderer.apply()
+ self.assertEqual(len(self.acc), 2)
+ self.assertEqual(renderer.e, 1)
+ self.assertEqual(renderer.p, 1)
+
+
+ def test_composite(self):
+ renderer = TestRendererComposite(self.b, accumulator=self.accumulate)
issue_one = Issue('foo')
self.b.add(issue_one)
@@ -84,8 +120,8 @@ class TestMsg(unittest.TestCase):
m = self.b.msg(v, 's:foo')
- self.renderer.apply()
- self.assertEqual(len(self.acc), 1)
+ renderer.apply()
+ self.assertEqual(len(renderer.m), 1)
if __name__ == '__main__':