commit 908f762cd09c7a0002b4450366b7cf6bf4868bc6
parent c738563d89c5b79e8f0d44133a81723b11f7afc9
Author: nolash <dev@holbrook.no>
Date: Thu, 15 Apr 2021 09:59:45 +0200
Add interrupt test base
Diffstat:
4 files changed, 127 insertions(+), 6 deletions(-)
diff --git a/chainsyncer/backend/memory.py b/chainsyncer/backend/memory.py
@@ -14,6 +14,8 @@ class MemBackend:
self.flags = 0
self.target_block = target_block
self.db_session = None
+ self.filter_names = []
+ self.filter_values = []
def connect(self):
@@ -28,6 +30,8 @@ class MemBackend:
logg.debug('stateless backend received {} {}'.format(block_height, tx_height))
self.block_height = block_height
self.tx_height = tx_height
+ for i in range(len(self.filter_values)):
+ self.filter_values[i] = False
def get(self):
@@ -39,11 +43,13 @@ class MemBackend:
def register_filter(self, name):
- pass
+ self.filter_names.append(name)
+ self.filter_values.append(False)
def complete_filter(self, n):
- pass
+ self.filter_values[n-1] = True
+ logg.debug('set filter {}'.format(self.filter_names[n-1]))
def __str__(self):
diff --git a/chainsyncer/driver.py b/chainsyncer/driver.py
@@ -72,6 +72,11 @@ class Syncer:
self.backend.register_filter(str(f))
+ def process_single(self, conn, block, tx, block_height, tx_index):
+ self.backend.set(block_height, tx_index)
+ self.filter.apply(conn, block, tx)
+
+
class BlockPollSyncer(Syncer):
def __init__(self, backend, pre_callback=None, block_callback=None, post_callback=None):
@@ -120,14 +125,16 @@ class HeadSyncer(BlockPollSyncer):
while True:
try:
tx = block.tx(i)
- rcpt = conn.do(receipt(tx.hash))
- tx.apply_receipt(rcpt)
- self.backend.set(block.number, i)
- self.filter.apply(conn, block, tx)
except IndexError as e:
logg.debug('index error syncer rcpt get {}'.format(e))
self.backend.set(block.number + 1, 0)
break
+
+ rcpt = conn.do(receipt(tx.hash))
+ tx.apply_receipt(rcpt)
+
+ self.process_single(conn, block, tx, block.number, i)
+
i += 1
diff --git a/sql_requirements.txt b/sql_requirements.txt
@@ -0,0 +1,2 @@
+psycopg2==2.8.6
+SQLAlchemy==1.3.20
diff --git a/tests/test_interrupt.py b/tests/test_interrupt.py
@@ -0,0 +1,106 @@
+# standard imports
+import logging
+import unittest
+import os
+
+# external imports
+from chainlib.chain import ChainSpec
+from hexathon import add_0x
+
+# local imports
+from chainsyncer.backend.memory import MemBackend
+from chainsyncer.driver import HeadSyncer
+from chainsyncer.error import NoBlockForYou
+
+# test imports
+from tests.base import TestBase
+
+logging.basicConfig(level=logging.DEBUG)
+logg = logging.getLogger()
+
+
+class TestSyncer(HeadSyncer):
+
+
+ def __init__(self, backend, tx_counts=[]):
+ self.tx_counts = tx_counts
+ super(TestSyncer, self).__init__(backend)
+
+
+ def get(self, conn):
+ if self.backend.block_height == self.backend.target_block:
+ raise NoBlockForYou()
+ if self.backend.block_height > len(self.tx_counts):
+ return []
+
+ block_txs = []
+ for i in range(self.tx_counts[self.backend.block_height]):
+ block_txs.append(add_0x(os.urandom(32).hex()))
+
+ return block_txs
+
+
+ def process(self, conn, block):
+ i = 0
+ for tx in block:
+ self.process_single(conn, block, tx, self.backend.block_height, i)
+ i += 1
+
+
+
+class NaughtyCountExceptionFilter:
+
+ def __init__(self, name, croak_on):
+ self.c = 0
+ self.croak = croak_on
+ self.name = name
+
+
+ def filter(self, conn, block, tx, db_session=None):
+ self.c += 1
+ if self.c == self.croak:
+ raise RuntimeError('foo')
+
+
+ def __str__(self):
+ return '{} {}'.format(self.__class__.__name__, self.name)
+
+
+class CountFilter:
+
+ def __init__(self, name):
+ self.c = 0
+ self.name = name
+
+
+ def filter(self, conn, block, tx, db_session=None):
+ self.c += 1
+
+
+ def __str__(self):
+ return '{} {}'.format(self.__class__.__name__, self.name)
+
+
+class TestInterrupt(unittest.TestCase):
+
+ def setUp(self):
+ self.chain_spec = ChainSpec('foo', 'bar', 42, 'baz')
+ self.backend = MemBackend(self.chain_spec, None, target_block=2)
+ self.syncer = TestSyncer(self.backend, [4, 2, 3])
+
+ def test_filter_interrupt(self):
+
+ fltrs = [
+ CountFilter('foo'),
+ CountFilter('bar'),
+ NaughtyCountExceptionFilter('xyzzy', 2),
+ CountFilter('baz'),
+ ]
+
+ for fltr in fltrs:
+ self.syncer.add_filter(fltr)
+
+ self.syncer.loop(0.1, None)
+
+if __name__ == '__main__':
+ unittest.main()