commit 987a18fd6b55b3d1a0e80c3b4be12ee781edee1f
parent d1077bf87aca2501feb8cfadf5f071750b1d286e
Author: nolash <dev@holbrook.no>
Date: Thu, 15 Apr 2021 15:06:07 +0200
Implement filter integrity test in sql backend
Diffstat:
5 files changed, 58 insertions(+), 34 deletions(-)
diff --git a/chainsyncer/backend/memory.py b/chainsyncer/backend/memory.py
@@ -44,9 +44,9 @@ class MemBackend:
def complete_filter(self, n):
- v = 1 << (n-1)
+ v = 1 << n
self.flags |= v
- logg.debug('set filter {} {}'.format(self.filter_names[n-1], v))
+ logg.debug('set filter {} {}'.format(self.filter_names[n], v))
def reset_filter(self):
diff --git a/chainsyncer/backend/sql.py b/chainsyncer/backend/sql.py
@@ -2,7 +2,7 @@
import logging
import uuid
-# third-party imports
+# imports
from chainlib.chain import ChainSpec
# local imports
@@ -56,6 +56,9 @@ class SyncerBackend:
def disconnect(self):
"""Commits state of sync to backend.
"""
+ if self.db_session == None:
+ return
+
if self.db_object_filter != None:
self.db_session.add(self.db_object_filter)
self.db_session.add(self.db_object)
@@ -97,7 +100,6 @@ class SyncerBackend:
"""
self.connect()
pair = self.db_object.set(block_height, tx_height)
- self.db_object_filter.clear()
(filter_state, count, digest)= self.db_object_filter.cursor()
self.disconnect()
return (pair, filter_state,)
@@ -294,5 +296,11 @@ class SyncerBackend:
self.disconnect()
+ def reset_filter(self):
+ self.connect()
+ self.db_object_filter.clear()
+ self.disconnect()
+
+
def __str__(self):
return "syncerbackend chain {} start {} target {}".format(self.chain(), self.start(), self.target())
diff --git a/chainsyncer/filter.py b/chainsyncer/filter.py
@@ -36,16 +36,15 @@ class SyncFilter:
i = 0
(pair, flags) = self.backend.get()
for f in self.filters:
+ if flags & (1 << i) == 0:
+ logg.debug('applying filter {} {}'.format(str(f), flags))
+ f.filter(conn, block, tx, session)
+ self.backend.complete_filter(i)
+ else:
+ logg.debug('skipping previously applied filter {} {}'.format(str(f), flags))
i += 1
- if flags & (1 << (i - 1)) > 0:
- logg.debug('skipping previously applied filter {}'.format(str(f)))
- continue
- logg.debug('applying filter {}'.format(str(f)))
- f.filter(conn, block, tx, session)
- self.backend.complete_filter(i)
- if session != None:
- self.backend.disconnect()
+ self.backend.disconnect()
class NoopFilter:
diff --git a/chainsyncer/unittest/base.py b/chainsyncer/unittest/base.py
@@ -39,19 +39,21 @@ class TestSyncer(HistorySyncer):
def get(self, conn):
- if self.backend.block_height == self.backend.target_block:
+ (pair, fltr) = self.backend.get()
+ (target_block, fltr) = self.backend.target()
+ block_height = pair[0]
+
+ if block_height == target_block:
self.running = False
raise NoBlockForYou()
return []
block_txs = []
- if self.backend.block_height < len(self.tx_counts):
- for i in range(self.tx_counts[self.backend.block_height]):
+ if block_height < len(self.tx_counts):
+ for i in range(self.tx_counts[block_height]):
block_txs.append(add_0x(os.urandom(32).hex()))
- logg.debug('get tx height {}'.format(self.backend.tx_height))
-
- return MockBlock(self.backend.block_height, block_txs)
+ return MockBlock(block_height, block_txs)
# TODO: implement mock conn instead, and use HeadSyncer.process
@@ -61,4 +63,4 @@ class TestSyncer(HistorySyncer):
self.process_single(conn, block, block.tx(i))
self.backend.reset_filter()
i += 1
- self.backend.set(self.backend.block_height + 1, 0)
+ self.backend.set(block.number + 1, 0)
diff --git a/tests/test_interrupt.py b/tests/test_interrupt.py
@@ -8,6 +8,7 @@ from chainlib.chain import ChainSpec
# local imports
from chainsyncer.backend.memory import MemBackend
+from chainsyncer.backend.sql import SyncerBackend
# test imports
from tests.base import TestBase
@@ -54,35 +55,49 @@ class CountFilter:
return '{} {}'.format(self.__class__.__name__, self.name)
-class TestInterrupt(unittest.TestCase):
- def setUp(self):
- self.chain_spec = ChainSpec('foo', 'bar', 42, 'baz')
- self.backend = MemBackend(self.chain_spec, None, target_block=4)
- self.syncer = TestSyncer(self.backend, [4, 2, 3])
+class TestInterrupt(TestBase):
- def test_filter_interrupt(self):
-
- fltrs = [
+ def setUp(self):
+ super(TestInterrupt, self).setUp()
+ self.filters = [
CountFilter('foo'),
CountFilter('bar'),
NaughtyCountExceptionFilter('xyzzy', 3),
CountFilter('baz'),
- ]
+ ]
+ self.backend = None
+
+
+ def assert_filter_interrupt(self):
+
+ syncer = TestSyncer(self.backend, [4, 2, 3])
- for fltr in fltrs:
- self.syncer.add_filter(fltr)
+ for fltr in self.filters:
+ syncer.add_filter(fltr)
try:
- self.syncer.loop(0.1, None)
+ syncer.loop(0.1, None)
except RuntimeError:
logg.info('caught croak')
pass
- self.syncer.loop(0.1, None)
+ (pair, fltr) = self.backend.get()
+ self.assertGreater(fltr, 0)
+ syncer.loop(0.1, None)
- for fltr in fltrs:
+ for fltr in self.filters:
logg.debug('{} {}'.format(str(fltr), fltr.c))
- #self.assertEqual(fltr.c, 11)
+ self.assertEqual(fltr.c, 9)
+
+
+ def test_filter_interrupt_memory(self):
+ self.backend = MemBackend(self.chain_spec, None, target_block=4)
+ self.assert_filter_interrupt()
+
+
+ def test_filter_interrupt_sql(self):
+ self.backend = SyncerBackend.initial(self.chain_spec, 4)
+ self.assert_filter_interrupt()
if __name__ == '__main__':