commit 81dd5753e84398fb1680ea1d95fdfd92a1cd420d
parent 82e26745556103a6cf5014d02d955d1396bf3d69
Author: nolash <dev@holbrook.no>
Date: Mon, 22 Feb 2021 10:57:07 +0100
Add filter to live instantiator, add filter start state
Diffstat:
9 files changed, 257 insertions(+), 59 deletions(-)
diff --git a/chainsyncer/backend.py b/chainsyncer/backend.py
@@ -7,6 +7,7 @@ from chainlib.chain import ChainSpec
# local imports
from chainsyncer.db.models.sync import BlockchainSync
+from chainsyncer.db.models.filter import BlockchainSyncFilter
from chainsyncer.db.models.base import SessionBase
logg = logging.getLogger()
@@ -23,6 +24,7 @@ class SyncerBackend:
def __init__(self, chain_spec, object_id):
self.db_session = None
self.db_object = None
+ self.db_object_filter = None
self.chain_spec = chain_spec
self.object_id = object_id
self.connect()
@@ -34,9 +36,17 @@ class SyncerBackend:
"""
if self.db_session == None:
self.db_session = SessionBase.create_session()
+
q = self.db_session.query(BlockchainSync)
q = q.filter(BlockchainSync.id==self.object_id)
self.db_object = q.first()
+
+ if self.db_object != None:
+ qtwo = self.db_session.query(BlockchainSyncFilter)
+ qtwo = qtwo.join(BlockchainSync)
+ qtwo = qtwo.filter(BlockchainSync.id==self.db_object.id)
+ self.db_object_filter = qtwo.first()
+
if self.db_object == None:
raise ValueError('sync entry with id {} not found'.format(self.object_id))
@@ -44,6 +54,8 @@ class SyncerBackend:
def disconnect(self):
"""Commits state of sync to backend.
"""
+ if self.db_object_filter != None:
+ self.db_session.add(self.db_object_filter)
self.db_session.add(self.db_object)
self.db_session.commit()
self.db_session.close()
@@ -67,8 +79,9 @@ class SyncerBackend:
"""
self.connect()
pair = self.db_object.cursor()
+ filter_state = self.db_object_filter.cursor()
self.disconnect()
- return pair
+ return (pair, filter_state,)
def set(self, block_height, tx_height):
@@ -82,8 +95,9 @@ class SyncerBackend:
"""
self.connect()
pair = self.db_object.set(block_height, tx_height)
+ filter_state = self.db_object_filter.cursor()
self.disconnect()
- return pair
+ return (pair, filter_state,)
def start(self):
@@ -94,8 +108,9 @@ class SyncerBackend:
"""
self.connect()
pair = self.db_object.start()
+ filter_state = self.db_object_filter.start()
self.disconnect()
- return pair
+ return (pair, filter_state,)
def target(self):
@@ -106,12 +121,13 @@ class SyncerBackend:
"""
self.connect()
target = self.db_object.target()
+ filter_state = self.db_object_filter.target()
self.disconnect()
- return target
+ return (target, filter_state,)
@staticmethod
- def first(chain):
+ def first(chain_spec):
"""Returns the model object of the most recent syncer in backend.
:param chain: Chain spec of chain that syncer is running for.
@@ -119,7 +135,12 @@ class SyncerBackend:
:returns: Last syncer object
:rtype: cic_eth.db.models.BlockchainSync
"""
- return BlockchainSync.first(chain)
+ #return BlockchainSync.first(str(chain_spec))
+ object_id = BlockchainSync.first(str(chain_spec))
+ if object_id == None:
+ return None
+ return SyncerBackend(chain_spec, object_id)
+
@staticmethod
@@ -193,15 +214,30 @@ class SyncerBackend:
"""
object_id = None
session = SessionBase.create_session()
+
o = BlockchainSync(str(chain_spec), block_height, 0, None)
session.add(o)
- session.commit()
+ session.flush()
object_id = o.id
+
+ of = BlockchainSyncFilter(o)
+ session.add(of)
+ session.commit()
+
session.close()
return SyncerBackend(chain_spec, object_id)
+ def register_filter(self, name):
+ self.connect()
+ if self.db_object_filter == None:
+ self.db_object_filter = BlockchainSyncFilter(self.db_object)
+ self.db_object_filter.add(name)
+ self.db_session.add(self.db_object_filter)
+ self.disconnect()
+
+
class MemBackend:
def __init__(self, chain_spec, object_id):
@@ -209,6 +245,7 @@ class MemBackend:
self.chain_spec = chain_spec
self.block_height = 0
self.tx_height = 0
+ self.flags = 0
self.db_session = None
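
For orientation, a minimal usage sketch of the extended SyncerBackend API above. It assumes SessionBase.connect() has already been run against a migrated database; the filter names are placeholders.

    # sketch: create a live syncer, register filters, read back the combined state
    from chainlib.chain import ChainSpec
    from chainsyncer.backend import SyncerBackend

    chain_spec = ChainSpec('evm', 'foo', 42, 'bar')

    # create a live syncer record starting at block 42
    backend = SyncerBackend.live(chain_spec, 42)

    # each registration adds one flag bit and chains the name into the digest
    backend.register_filter('token_transfer')
    backend.register_filter('register')

    # state getters now return the filter state alongside the (block, tx) pair
    (pair, filter_state) = backend.start()
    (block_height, tx_height) = pair
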
diff --git a/chainsyncer/db/models/base.py b/chainsyncer/db/models/base.py
@@ -1,8 +1,18 @@
+# standard imports
+import logging
+
# third-party imports
from sqlalchemy import Column, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
+from sqlalchemy.pool import (
+ StaticPool,
+ QueuePool,
+ AssertionPool,
+ )
+
+logg = logging.getLogger()
Model = declarative_base(name='Model')
@@ -21,7 +31,11 @@ class SessionBase(Model):
transactional = True
"""Whether the database backend supports query transactions. Should be explicitly set by initialization code"""
poolable = True
- """Whether the database backend supports query transactions. Should be explicitly set by initialization code"""
+ """Whether the database backend supports connection pools. Should be explicitly set by initialization code"""
+ procedural = True
+ """Whether the database backend supports stored procedures"""
+ localsessions = {}
+ """Contains dictionary of sessions initiated by db model components"""
@staticmethod
@@ -40,7 +54,7 @@ class SessionBase(Model):
@staticmethod
- def connect(dsn, debug=False):
+ def connect(dsn, pool_size=8, debug=False):
"""Create new database connection engine and connect to database backend.
:param dsn: DSN string defining connection.
@@ -48,14 +62,28 @@ class SessionBase(Model):
"""
e = None
if SessionBase.poolable:
- e = create_engine(
- dsn,
- max_overflow=50,
- pool_pre_ping=True,
- pool_size=20,
- pool_recycle=10,
- echo=debug,
- )
+ poolclass = QueuePool
+ if pool_size > 1:
+ e = create_engine(
+ dsn,
+ max_overflow=pool_size*3,
+ pool_pre_ping=True,
+ pool_size=pool_size,
+ pool_recycle=60,
+ poolclass=poolclass,
+ echo=debug,
+ )
+ else:
+ if debug:
+ poolclass = AssertionPool
+ else:
+ poolclass = StaticPool
+
+ e = create_engine(
+ dsn,
+ poolclass=poolclass,
+ echo=debug,
+ )
else:
e = create_engine(
dsn,
@@ -71,3 +99,24 @@ class SessionBase(Model):
"""
SessionBase.engine.dispose()
SessionBase.engine = None
+
+
+ @staticmethod
+ def bind_session(session=None):
+ localsession = session
+ if localsession == None:
+ localsession = SessionBase.create_session()
+ localsession_key = str(id(localsession))
+ logg.debug('creating new session {}'.format(localsession_key))
+ SessionBase.localsessions[localsession_key] = localsession
+ return localsession
+
+
+ @staticmethod
+ def release_session(session=None):
+ session.flush()
+ session_key = str(id(session))
+ if SessionBase.localsessions.get(session_key) != None:
+ logg.debug('destroying session {}'.format(session_key))
+ session.commit()
+ session.close()
+ del SessionBase.localsessions[session_key]
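
The bind_session/release_session pair above lets model-level helpers accept an optional caller-provided session and only commit and close sessions they created themselves, the pattern BlockchainSync.first switches to in the sync.py hunk below. A minimal sketch of that calling convention (count_syncers is a hypothetical helper):

    from chainsyncer.db.models.base import SessionBase
    from chainsyncer.db.models.sync import BlockchainSync

    def count_syncers(chain, session=None):
        # reuse the caller's session, or create and track a new one
        session = SessionBase.bind_session(session)
        n = session.query(BlockchainSync).filter(BlockchainSync.blockchain==chain).count()
        # flushes; commits and closes only sessions bind_session created itself
        SessionBase.release_session(session)
        return n
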
diff --git a/chainsyncer/db/models/filter.py b/chainsyncer/db/models/filter.py
@@ -1,40 +1,79 @@
# standard imports
+import logging
import hashlib
-# third-party imports
-from sqlalchemy import Column, String, Integer, BLOB
+# external imports
+from sqlalchemy import Column, String, Integer, BLOB, ForeignKey
from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method
# local imports
from .base import SessionBase
+from .sync import BlockchainSync
-
-zero_digest = '{:<064s'.format('0')
+zero_digest = bytearray(32)
+logg = logging.getLogger(__name__)
class BlockchainSyncFilter(SessionBase):
__tablename__ = 'chain_sync_filter'
- chain_sync_id = Column(Integer, ForeignKey='chain_sync.id')
+ chain_sync_id = Column(Integer, ForeignKey('chain_sync.id'))
+ flags_start = Column(BLOB)
flags = Column(BLOB)
- digest = Column(String)
+ digest = Column(BLOB)
count = Column(Integer)
- @staticmethod
- def set(self, names):
-
- def __init__(self, names, chain_sync, digest=None):
- if len(names) == 0:
- digest = zero_digest
- elif digest == None:
- h = hashlib.new('sha256')
- for n in names:
- h.update(n.encode('utf-8') + b'\x00')
- z = h.digest()
- digest = z.hex()
+ def __init__(self, chain_sync, count=0, flags=None, digest=zero_digest):
self.digest = digest
- self.count = len(names)
- self.flags = bytearray((len(names) -1 ) / 8 + 1)
+ self.count = count
+
+ if flags == None:
+ flags = bytearray(0)
+ self.flags_start = flags
+ self.flags = flags
+
self.chain_sync_id = chain_sync.id
+
+
+ def add(self, name):
+ h = hashlib.new('sha256')
+ h.update(self.digest)
+ h.update(name.encode('utf-8'))
+ z = h.digest()
+
+ old_byte_count = int((self.count - 1) / 8 + 1)
+ new_byte_count = int((self.count) / 8 + 1)
+
+ logg.debug('filter flag byte count old {} new {}'.format(old_byte_count, new_byte_count))
+ if old_byte_count != new_byte_count:
+ self.flags = bytearray(1) + self.flags
+ self.count += 1
+ self.digest = z
+
+
+ def start(self):
+ return self.flags_start
+
+
+ def cursor(self):
+ return self.flags
+
+
+ def clear(self):
+ self.flags = bytearray(len(self.flags))
+
+
+ def target(self):
+ n = 0
+ for i in range(self.count):
+ n |= 1 << i
+ return n
+
+
+ def set(self, n):
+ if self.flags & n > 0:
+ raise AttributeError('Filter bit already set')
+ self.flags |= n
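
The add() bookkeeping above allocates one flag bit per registered filter, prepending an extra byte whenever the existing bytes are full, which is what test_backend_filter below verifies (0 bytes before any filter, 2 bytes after 9). A standalone sketch of the same size arithmetic, mirroring the old/new byte-count comparison:

    def filter_flag_bytes(count):
        # whole bytes needed to hold one flag bit per registered filter
        return 0 if count == 0 else int((count - 1) / 8 + 1)

    for count in (0, 1, 8, 9):
        print(count, filter_flag_bytes(count))
    # prints: 0 0, 1 1, 8 1, 9 2
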
diff --git a/chainsyncer/db/models/sync.py b/chainsyncer/db/models/sync.py
@@ -41,19 +41,23 @@ class BlockchainSync(SessionBase):
:type chain: str
:param session: Session to use. If not specified, a separate session will be created for this method only.
:type session: SqlAlchemy Session
- :returns: True if sync record found
- :rtype: bool
+ :returns: Database primary key id of sync record
+ :rtype: number|None
"""
- local_session = False
- if session == None:
- session = SessionBase.create_session()
- local_session = True
+ session = SessionBase.bind_session(session)
+
q = session.query(BlockchainSync.id)
q = q.filter(BlockchainSync.blockchain==chain)
o = q.first()
- if local_session:
- session.close()
- return o == None
+
+ if o == None:
+ SessionBase.release_session(session)
+ return None
+
+ sync_id = o.id
+
+ SessionBase.release_session(session)
+
+ return sync_id
@staticmethod
@@ -165,4 +169,4 @@ class BlockchainSync(SessionBase):
self.tx_cursor = tx_start
self.block_target = block_target
self.date_created = datetime.datetime.utcnow()
- self.date_modified = datetime.datetime.utcnow()
+ self.date_updated = datetime.datetime.utcnow()
diff --git a/chainsyncer/filter.py b/chainsyncer/filter.py
@@ -9,6 +9,7 @@ from .error import BackendError
logg = logging.getLogger(__name__)
+
class SyncFilter:
def __init__(self, backend, safe=True):
@@ -32,11 +33,15 @@ class SyncFilter:
except sqlalchemy.exc.TimeoutError as e:
self.backend.disconnect()
raise BackendError('database connection fail: {}'.format(e))
+ i = 0
for f in self.filters:
+ i += 1
logg.debug('applying filter {}'.format(str(f)))
f.filter(conn, block, tx, self.backend.db_session)
+ self.backend.set_filter()
self.backend.disconnect()
+
class NoopFilter:
def filter(self, conn, block, tx, db_session=None):
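
Any object exposing the same filter(conn, block, tx, db_session=None) signature as NoopFilter can be run by SyncFilter; a minimal sketch of a custom filter follows (the class name and log message are illustrative only, and how instances end up in self.filters is not shown in this hunk):

    import logging

    logg = logging.getLogger(__name__)

    class LogTxFilter:

        def filter(self, conn, block, tx, db_session=None):
            # called once per transaction by the SyncFilter loop above
            logg.info('saw tx {} in block {}'.format(tx, block))
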
diff --git a/sql/sqlite/1.sql b/sql/sqlite/1.sql
@@ -1,13 +1,11 @@
CREATE TABLE IF NOT EXISTS chain_sync (
- id serial primary key not null,
+ id integer primary key autoincrement,
blockchain varchar not null,
- block_start int not null default 0,
- tx_start int not null default 0,
- block_cursor int not null default 0,
- tx_cursor int not null default 0,
- flags bytea not null,
- num_flags int not null,
- block_target int default null,
+ block_start integer not null default 0,
+ tx_start integer not null default 0,
+ block_cursor integer not null default 0,
+ tx_cursor integer not null default 0,
+ block_target integer default null,
date_created timestamp not null,
date_updated timestamp default null
);
diff --git a/sql/sqlite/2.sql b/sql/sqlite/2.sql
@@ -1,8 +1,9 @@
CREATE TABLE IF NOT EXISTS chain_sync_filter (
- id serial primary key not null,
- chain_sync_id int not null,
+ id integer primary key autoincrement not null,
+ chain_sync_id integer not null,
flags bytea default null,
- count int not null default 0,
+ flags_start bytea default null,
+ count integer not null default 0,
digest char(64) not null default '0000000000000000000000000000000000000000000000000000000000000000',
CONSTRAINT fk_chain_sync
FOREIGN KEY(chain_sync_id)
diff --git a/tests/base.py b/tests/base.py
@@ -1,13 +1,21 @@
+# standard imports
+import logging
import unittest
import tempfile
import os
#import pysqlite
+# external imports
+from chainlib.chain import ChainSpec
+
+# local imports
from chainsyncer.db import dsn_from_config
from chainsyncer.db.models.base import SessionBase
script_dir = os.path.realpath(os.path.dirname(__file__))
+logging.basicConfig(level=logging.DEBUG)
+
class TestBase(unittest.TestCase):
@@ -23,7 +31,7 @@ class TestBase(unittest.TestCase):
SessionBase.poolable = False
SessionBase.transactional = False
SessionBase.procedural = False
- SessionBase.connect(dsn, debug=True)
+ SessionBase.connect(dsn, debug=False)
f = open(os.path.join(script_dir, '..', 'sql', 'sqlite', '1.sql'), 'r')
sql = f.read()
@@ -39,6 +47,8 @@ class TestBase(unittest.TestCase):
conn = SessionBase.engine.connect()
conn.execute(sql)
+ self.chain_spec = ChainSpec('evm', 'foo', 42, 'bar')
+
def tearDown(self):
SessionBase.disconnect()
os.unlink(self.db_path)
diff --git a/tests/test_database.py b/tests/test_database.py
@@ -0,0 +1,55 @@
+# standard imports
+import unittest
+
+# external imports
+from chainlib.chain import ChainSpec
+
+# local imports
+from chainsyncer.db.models.base import SessionBase
+from chainsyncer.db.models.filter import BlockchainSyncFilter
+from chainsyncer.backend import SyncerBackend
+
+# testutil imports
+from tests.base import TestBase
+
+class TestDatabase(TestBase):
+
+
+ def test_backend_live(self):
+ s = SyncerBackend.live(self.chain_spec, 42)
+ self.assertEqual(s.object_id, 1)
+ backend = SyncerBackend.first(self.chain_spec)
+ #SyncerBackend(self.chain_spec, sync_id)
+ self.assertEqual(backend.object_id, 1)
+
+ bogus_chain_spec = ChainSpec('bogus', 'foo', 13, 'baz')
+ backend = SyncerBackend.first(bogus_chain_spec)
+ self.assertIsNone(backend)
+
+
+ def test_backend_filter(self):
+ s = SyncerBackend.live(self.chain_spec, 42)
+
+ s.connect()
+ filter_id = s.db_object_filter.id
+ s.disconnect()
+
+ session = SessionBase.create_session()
+ o = session.query(BlockchainSyncFilter).get(filter_id)
+ self.assertEqual(len(o.flags), 0)
+ session.close()
+
+ for i in range(9):
+ s.register_filter(str(i))
+
+ s.connect()
+ filter_id = s.db_object_filter.id
+ s.disconnect()
+
+ session = SessionBase.create_session()
+ o = session.query(BlockchainSyncFilter).get(filter_id)
+ self.assertEqual(len(o.flags), 2)
+ session.close()
+
+if __name__ == '__main__':
+ unittest.main()