eth-erc712

ERC-712 typed-data signing material builder
Info | Log | Files | Refs

base.py (3823B)


      1 # standard imports
      2 import logging
      3 import sha3
      4 
      5 # external imports
      6 from chainlib.eth.contract import ABIContractType
      7 from chainlib.eth.contract import ABIContractEncoder
      8 from chainlib.hash import keccak256
      9 from chainlib.hash import keccak256_string_to_hex
     10 
     11 logg = logging.getLogger(__name__)
     12 
     13 
     14 class ERC712Encoder(ABIContractEncoder):
     15 
     16     def __init__(self, struct_name):
     17         super(ERC712Encoder, self).__init__()
     18         self.method(struct_name)
     19         self.encode = self.get_contents
     20 
     21 
     22     def add(self, k, t, v):
     23         typ_checked = ABIContractType(t.value) 
     24         self.typ_literal(t.value + ' ' + k)
     25         m = getattr(self, t.value)
     26         m(v)
     27 
     28 
     29     def string(self, s):
     30         v = keccak256_string_to_hex(s)
     31         self.contents.append(v)
     32         self.add_type(ABIContractType.STRING)
     33         self.__log_latest_erc712(s)
     34 
     35 
     36     def bytes(self, s):
     37         v = keccak256_string_to_hex(s)
     38         self.contents.append(v)
     39         self.add_type(ABIContractType.BYTES)
     40         self.__log_latest_erc712(s)
     41 
     42 
     43     def __log_latest_erc712(self, v):
     44         l = len(self.types) - 1 
     45         logg.debug('Encoder added {} -> {} ({})'.format(v, self.contents[l], self.types[l].value))
     46 
     47 
     48     def encode_type(self):
     49         v = self.get_method()
     50         r = keccak256(v)
     51         logg.debug('typehash material {} -> {}'.format(v, r.hex()))
     52         return r
     53 
     54 
     55     def encode_data(self):
     56         return b''.join(list(map(lambda x: bytes.fromhex(x), self.contents)))
     57 
     58 
     59     def get_contents(self):
     60         return self.encode_type() + self.encode_data()
     61 
     62 
     63 class EIP712Domain(ERC712Encoder):
     64 
     65     def __init__(self, name=None, version=None, chain_id=None, verifying_contract=None, salt=None):
     66         super(EIP712Domain, self).__init__('EIP712Domain')
     67         if name != None:
     68             self.add('name', ABIContractType.STRING, name)
     69         if version != None:
     70             self.add('version', ABIContractType.STRING, version)
     71         if chain_id != None:
     72             self.add('chainId', ABIContractType.UINT256, chain_id)
     73         if verifying_contract != None:
     74             self.add('verifyingContract', ABIContractType.ADDRESS, verifying_contract)
     75         if salt != None:
     76             self.add('salt', ABIContractType.BYTES32, salt)
     77 
     78 
     79     def get_contents(self):
     80         v = self.encode_type() + self.encode_data()
     81         r = keccak256(v)
     82         logg.debug('domainseparator material {} -> {}'.format(v.hex(), r.hex()))
     83         return r
     84 
     85 
     86 class EIP712DomainEncoder(ERC712Encoder):
     87     
     88     def __init__(self, struct_name, domain):
     89         assert domain.__class__.__name__ == 'EIP712Domain'
     90         self.domain = domain
     91         self.__cache_data = b''
     92         super(EIP712DomainEncoder, self).__init__(struct_name)
     93 
     94 
     95     def __cache(self):
     96         if not self.dirty:
     97             return
     98         domain = self.domain.get_contents()
     99         contents = super(EIP712DomainEncoder, self).get_contents()
    100         self.__cache_data = domain + contents
    101 
    102 
    103     def get_contents(self):
    104         self.__cache()
    105         return self.__cache_data
    106 
    107 
    108     def get_domain(self):
    109         self.__cache()
    110         return self.__cache_data[:32]
    111 
    112 
    113     def get_type_hash(self):
    114         self.__cache()
    115         return self.__cache_data[32:64]
    116 
    117 
    118     def get_typed_data(self):
    119         self.__cache()
    120         return self.__cache_data[64:]
    121 
    122 
    123     def get_hash(self):
    124         return keccak256(self.get_type_hash() + self.get_typed_data())
    125 
    126 
    127     def __str__(self):
    128         self.__cache()
    129         domain = self.get_domain()
    130         data_hash = self.get_type_hash()
    131         data = self.get_typed_data()
    132         s = 'erc712domain\t{}\nerc712type\t{}\nerc712data\n'.format(
    133                 domain.hex(),
    134                 data_hash.hex(),
    135                 )
    136         for i in range(0, len(data), 32):
    137             s += '\t' + data[i:i+32].hex() + '\n'
    138 
    139         return s