base.py (3823B)
# standard imports
import logging
import sha3  # NOTE(review): not referenced in this module; kept in case other code relies on it

# external imports
from chainlib.eth.contract import ABIContractType
from chainlib.eth.contract import ABIContractEncoder
from chainlib.hash import keccak256
from chainlib.hash import keccak256_string_to_hex

logg = logging.getLogger(__name__)


class ERC712Encoder(ABIContractEncoder):
    """Encode one EIP-712 struct as its type hash followed by its field data.

    Fields are appended with :meth:`add`. Each field is recorded twice:
    as a ``"<type> <name>"`` type literal (feeding the type string that
    :meth:`encode_type` hashes) and as a hex-encoded value in
    ``self.contents`` (concatenated by :meth:`encode_data`).
    """

    def __init__(self, struct_name):
        """Start an encoder for the struct called ``struct_name``.

        :param struct_name: struct name, registered as the encoder's method
            name so that it prefixes the string returned by ``get_method()``.
        """
        super(ERC712Encoder, self).__init__()
        self.method(struct_name)
        # Per-instance alias; it is bound to self.get_contents, so a subclass
        # override of get_contents is what ``encode()`` will call.
        self.encode = self.get_contents

    def add(self, k, t, v):
        """Append field ``k`` of ABI type ``t`` holding value ``v``.

        :param k: field name as it appears in the type string.
        :param t: ``ABIContractType`` member describing the field type.
        :param v: field value; handed to the method of this encoder whose
            name equals ``t.value`` (e.g. ``string``, ``uint256``).
        """
        # Validation only: presumably raises (Enum lookup) if t.value is not
        # a known ABI type. The returned member itself is not needed.
        ABIContractType(t.value)
        self.typ_literal(t.value + ' ' + k)
        # Dispatch by type name, e.g. t.value == 'string' -> self.string(v).
        getattr(self, t.value)(v)

    def __add_hashed(self, s, typ):
        """Append the keccak256 hex digest of ``s`` tagged with ABI type ``typ``."""
        # Dynamic string/bytes values are stored as the hash of their
        # contents (EIP-712 encodeData rule), not as the raw value.
        self.contents.append(keccak256_string_to_hex(s))
        self.add_type(typ)
        self.__log_latest_erc712(s)

    def string(self, s):
        """Append a ``string`` field as the keccak256 hash of ``s``."""
        self.__add_hashed(s, ABIContractType.STRING)

    def bytes(self, s):
        """Append a ``bytes`` field as the keccak256 hash of ``s``."""
        # NOTE(review): hashed via keccak256_string_to_hex exactly like
        # ``string``; confirm this is correct for raw byte inputs.
        self.__add_hashed(s, ABIContractType.BYTES)

    def __log_latest_erc712(self, v):
        """Debug-log the most recently appended value with its encoding and type."""
        idx = len(self.types) - 1
        logg.debug('Encoder added %s -> %s (%s)', v, self.contents[idx], self.types[idx].value)

    def encode_type(self):
        """Return the type hash: keccak256 of the string from ``get_method()``."""
        v = self.get_method()
        r = keccak256(v)
        logg.debug('typehash material %s -> %s', v, r.hex())
        return r

    def encode_data(self):
        """Return all appended field values, decoded from hex and concatenated."""
        # ``bytes`` here is the builtin: class attributes (including the
        # ``bytes`` method above) are not in scope inside method bodies.
        return b''.join(bytes.fromhex(x) for x in self.contents)

    def get_contents(self):
        """Return the type hash followed by the encoded field data."""
        return self.encode_type() + self.encode_data()


class EIP712Domain(ERC712Encoder):
    """The ``EIP712Domain`` struct; :meth:`get_contents` yields its hash.

    Only arguments that are not ``None`` become fields. They are added in the
    fixed order name, version, chainId, verifyingContract, salt.
    """

    def __init__(self, name=None, version=None, chain_id=None, verifying_contract=None, salt=None):
        super(EIP712Domain, self).__init__('EIP712Domain')
        fields = (
            ('name', ABIContractType.STRING, name),
            ('version', ABIContractType.STRING, version),
            ('chainId', ABIContractType.UINT256, chain_id),
            ('verifyingContract', ABIContractType.ADDRESS, verifying_contract),
            ('salt', ABIContractType.BYTES32, salt),
        )
        for field_name, field_type, value in fields:
            if value is not None:
                self.add(field_name, field_type, value)

    def get_contents(self):
        """Return keccak256(type hash + field data), i.e. the domain separator."""
        v = self.encode_type() + self.encode_data()
        r = keccak256(v)
        logg.debug('domainseparator material %s -> %s', v.hex(), r.hex())
        return r


class EIP712DomainEncoder(ERC712Encoder):
    """Struct encoder bound to an :class:`EIP712Domain`.

    ``get_contents()`` returns domain separator + type hash + field data. The
    accessors slice that buffer, treating the first two parts as 32 bytes each.
    """

    def __init__(self, struct_name, domain):
        """Create an encoder for ``struct_name`` within ``domain``.

        :param struct_name: name of the struct being encoded.
        :param domain: an ``EIP712Domain`` instance (checked by class name).
        """
        # NOTE(review): assert-based validation is stripped under ``python -O``.
        assert domain.__class__.__name__ == 'EIP712Domain'
        # Assigned before the base initializer runs, as in the original order.
        self.domain = domain
        self.__cache_data = b''
        super(EIP712DomainEncoder, self).__init__(struct_name)

    def __cache(self):
        """Rebuild the domain+struct buffer when ``self.dirty`` is truthy."""
        # NOTE(review): this method never clears ``dirty`` and does not track
        # changes to ``self.domain``; unless the base class resets the flag,
        # every call recomputes.
        if not self.dirty:
            return
        domain = self.domain.get_contents()
        contents = super(EIP712DomainEncoder, self).get_contents()
        self.__cache_data = domain + contents

    def get_contents(self):
        """Return domain separator + type hash + field data."""
        self.__cache()
        return self.__cache_data

    def get_domain(self):
        """Return the domain separator (first 32 bytes)."""
        self.__cache()
        return self.__cache_data[:32]

    def get_type_hash(self):
        """Return the struct type hash (bytes 32-64)."""
        self.__cache()
        return self.__cache_data[32:64]

    def get_typed_data(self):
        """Return the encoded field data (everything after byte 64)."""
        self.__cache()
        return self.__cache_data[64:]

    def get_hash(self):
        """Return keccak256(type hash + field data) for this struct."""
        return keccak256(self.get_type_hash() + self.get_typed_data())

    def __str__(self):
        """Render the domain, type hash and field data as tab-separated hex lines."""
        self.__cache()
        domain = self.get_domain()
        data_hash = self.get_type_hash()
        data = self.get_typed_data()
        header = 'erc712domain\t{}\nerc712type\t{}\nerc712data\n'.format(
            domain.hex(),
            data_hash.hex(),
        )
        # One indented line per 32-byte word of field data.
        words = ''.join('\t' + data[i:i + 32].hex() + '\n' for i in range(0, len(data), 32))
        return header + words