From be4e2164e64dfa0697561763e8079120a485a566 Mon Sep 17 00:00:00 2001 From: Richard Moore Date: Sun, 18 Oct 2020 21:52:25 -0400 Subject: [PATCH] Initial Signer support for EIP-712 signed typed data (#687). --- packages/abstract-signer/src.ts/index.ts | 12 +- packages/hash/src.ts/typed-data.ts | 219 ++++++++++++++++-- .../providers/src.ts/json-rpc-provider.ts | 23 +- packages/wallet/src.ts/index.ts | 24 +- 4 files changed, 240 insertions(+), 38 deletions(-) diff --git a/packages/abstract-signer/src.ts/index.ts b/packages/abstract-signer/src.ts/index.ts index 374292970..d8a10cf7a 100644 --- a/packages/abstract-signer/src.ts/index.ts +++ b/packages/abstract-signer/src.ts/index.ts @@ -49,6 +49,12 @@ export interface ExternallyOwnedAccount { // key or mnemonic) in a function, so that console.log does not leak // the data +// @TODO: This is a temporary measure to preserse backwards compatibility +// In v6, the method on TypedDataSigner will be added to Signer +export interface TypedDataSigner { + _signTypedData(domain: TypedDataDomain, types: Record>, value: Record): Promise; +} + export abstract class Signer { readonly provider?: Provider; @@ -70,8 +76,6 @@ export abstract class Signer { // it does, sentTransaction MUST be overridden. abstract signTransaction(transaction: Deferrable): Promise; -// abstract _signTypedData(domain: TypedDataDomain, types: Array, data: any): Promise; - // Returns a new instance of the Signer, connected to provider. // This MAY throw if changing providers is not supported. abstract connect(provider: Provider): Signer; @@ -236,7 +240,7 @@ export abstract class Signer { } } -export class VoidSigner extends Signer { +export class VoidSigner extends Signer implements TypedDataSigner { readonly address: string; constructor(address: string, provider?: Provider) { @@ -264,7 +268,7 @@ export class VoidSigner extends Signer { return this._fail("VoidSigner cannot sign transactions", "signTransaction"); } - _signTypedData(domain: TypedDataDomain, types: Array, data: any): Promise { + _signTypedData(domain: TypedDataDomain, types: Record>, value: Record): Promise { return this._fail("VoidSigner cannot sign typed data", "signTypedData"); } diff --git a/packages/hash/src.ts/typed-data.ts b/packages/hash/src.ts/typed-data.ts index 7863c5cd9..c1bc339b6 100644 --- a/packages/hash/src.ts/typed-data.ts +++ b/packages/hash/src.ts/typed-data.ts @@ -1,9 +1,9 @@ import { TypedDataDomain, TypedDataField } from "@ethersproject/abstract-signer"; import { getAddress } from "@ethersproject/address"; import { BigNumber, BigNumberish } from "@ethersproject/bignumber"; -import { arrayify, BytesLike, hexConcat, hexlify, hexZeroPad } from "@ethersproject/bytes"; +import { arrayify, BytesLike, hexConcat, hexlify, hexValue, hexZeroPad, isHexString } from "@ethersproject/bytes"; import { keccak256 } from "@ethersproject/keccak256"; -import { deepCopy, defineReadOnly } from "@ethersproject/properties"; +import { deepCopy, defineReadOnly, shallowCopy } from "@ethersproject/properties"; import { Logger } from "@ethersproject/logger"; import { version } from "./_version"; @@ -43,32 +43,63 @@ const domainFieldNames: Array = [ "name", "version", "chainId", "verifyingContract", "salt" ]; +function checkString(key: string): (value: any) => string { + return function (value: any){ + if (typeof(value) !== "string") { + logger.throwArgumentError(`invalid domain value for ${ JSON.stringify(key) }`, `domain.${ key }`, value); + } + return value; + } +} + +const domainChecks: Record any> = { + name: checkString("name"), + version: checkString("version"), + chainId: function(value: any) { + try { + return BigNumber.from(value).toString() + } catch (error) { } + return logger.throwArgumentError(`invalid domain value for "chainId"`, "domain.chainId", value); + }, + verifyingContract: function(value: any) { + try { + return getAddress(value).toLowerCase(); + } catch (error) { } + return logger.throwArgumentError(`invalid domain value "verifyingContract"`, "domain.verifyingContract", value); + }, + salt: function(value: any) { + try { + const bytes = arrayify(value); + if (bytes.length !== 32) { throw new Error("bad length"); } + return hexlify(bytes); + } catch (error) { } + return logger.throwArgumentError(`invalid domain value "salt"`, "domain.salt", value); + } +} + function getBaseEncoder(type: string): (value: any) => string { // intXX and uintXX { - const match = type.match(/^(u?)int(\d+)$/); + const match = type.match(/^(u?)int(\d*)$/); if (match) { - const width = parseInt(match[2]); - if (width % 8 !== 0 || width > 256 || match[2] !== String(width)) { - logger.throwArgumentError("invalid numeric width", "type", type); - } const signed = (match[1] === ""); - return function(value: BigNumberish) { - let v = BigNumber.from(value); + const width = parseInt(match[2] || "256"); + if (width % 8 !== 0 || width > 256 || (match[2] && match[2] !== String(width))) { + logger.throwArgumentError("invalid numeric width", "type", type); + } - if (signed) { - let bounds = MaxUint256.mask(width - 1); - if (v.gt(bounds) || v.lt(bounds.add(One).mul(NegativeOne))) { - logger.throwArgumentError(`value out-of-bounds for ${ type }`, "value", value); - } - } else if (v.lt(Zero) || v.gt(MaxUint256.mask(width))) { + const boundsUpper = MaxUint256.mask(signed ? (width - 1): width); + const boundsLower = signed ? boundsUpper.add(One).mul(NegativeOne): Zero; + + return function(value: BigNumberish) { + const v = BigNumber.from(value); + + if (v.lt(boundsLower) || v.gt(boundsUpper)) { logger.throwArgumentError(`value out-of-bounds for ${ type }`, "value", value); } - v = v.toTwos(256); - - return hexZeroPad(v.toHexString(), 32); + return hexZeroPad(v.toTwos(256).toHexString(), 32); }; } } @@ -81,6 +112,7 @@ function getBaseEncoder(type: string): (value: any) => string { if (width === 0 || width > 32 || match[1] !== String(width)) { logger.throwArgumentError("invalid bytes width", "type", type); } + return function(value: BytesLike) { const bytes = arrayify(value); if (bytes.length !== width) { @@ -110,7 +142,7 @@ function getBaseEncoder(type: string): (value: any) => string { } function encodeType(name: string, fields: Array): string { - return `${ name }(${ fields.map((f) => (f.type + " " + f.name)).join(",") })`; + return `${ name }(${ fields.map(({ name, type }) => (type + " " + name)).join(",") })`; } export class TypedDataEncoder { @@ -226,7 +258,7 @@ export class TypedDataEncoder { _getEncoder(type: string): (value: any) => string { - // Basic encoder type + // Basic encoder type (address, bool, uint256, etc) { const encoder = getBaseEncoder(type); if (encoder) { return encoder; } @@ -257,9 +289,9 @@ export class TypedDataEncoder { if (fields) { const encodedType = id(this._types[type]); return (value: Record) => { - const values = fields.map((f) => { - const result = this.getEncoder(f.type)(value[f.name]); - if (this._types[f.type]) { return keccak256(result); } + const values = fields.map(({ name, type }) => { + const result = this.getEncoder(type)(value[name]); + if (this._types[type]) { return keccak256(result); } return result; }); values.unshift(encodedType); @@ -294,6 +326,40 @@ export class TypedDataEncoder { return this.hashStruct(this.primaryType, value); } + _visit(type: string, value: any, callback: (type: string, data: any) => any): any { + // Basic encoder type (address, bool, uint256, etc) + { + const encoder = getBaseEncoder(type); + if (encoder) { return callback(type, value); } + } + + // Array + const match = type.match(/^(.*)(\x5b(\d*)\x5d)$/); + if (match) { + const subtype = match[1]; + const length = parseInt(match[3]); + if (length >= 0 && value.length !== length) { + logger.throwArgumentError("array length mismatch; expected length ${ arrayLength }", "value", value); + } + return value.map((v: any) => this._visit(subtype, v, callback)); + } + + // Struct + const fields = this.types[type]; + if (fields) { + return fields.reduce((accum, { name, type }) => { + accum[name] = this._visit(type, value[name], callback); + return accum; + }, >{}); + } + + return logger.throwArgumentError(`unknown type: ${ type }`, "type", type); + } + + visit(value: Record, callback: (type: string, data: any) => any): any { + return this._visit(this.primaryType, value, callback); + } + static from(types: Record>): TypedDataEncoder { return new TypedDataEncoder(types); } @@ -334,5 +400,112 @@ export class TypedDataEncoder { static hash(domain: TypedDataDomain, types: Record>, value: Record): string { return keccak256(TypedDataEncoder.encode(domain, types, value)); } + + // Replaces all address types with ENS names with their looked up address + static async resolveNames(domain: TypedDataDomain, types: Record>, value: Record, resolveName: (name: string) => Promise): Promise<{ domain: TypedDataDomain, value: any }> { + // Make a copy to isolate it from the object passed in + domain = shallowCopy(domain); + + // Look up all ENS names + const ensCache: Record = { }; + + // Do we need to look up the domain's verifyingContract? + if (domain.verifyingContract && !isHexString(domain.verifyingContract, 20)) { + ensCache[domain.verifyingContract] = "0x"; + } + + // We are going to use the encoder to visit all the base values + const encoder = TypedDataEncoder.from(types); + + // Get a list of all the addresses + encoder.visit(value, (type: string, value: any) => { + if (type === "address" && !isHexString(value, 20)) { + ensCache[value] = "0x"; + } + return value; + }); + + // Lookup each name + for (const name in ensCache) { + ensCache[name] = await resolveName(name); + } + + // Replace the domain verifyingContract if needed + if (domain.verifyingContract && ensCache[domain.verifyingContract]) { + domain.verifyingContract = ensCache[domain.verifyingContract]; + } + + // Replace all ENS names with their address + value = encoder.visit(value, (type: string, value: any) => { + if (type === "address" && ensCache[value]) { return ensCache[value]; } + return value; + }); + + return { domain, value }; + } + + static getPayload(domain: TypedDataDomain, types: Record>, value: Record): any { + // Validate the domain fields + TypedDataEncoder.hashDomain(domain); + + // Derive the EIP712Domain Struct reference type + const domainValues: Record = { }; + const domainTypes: Array<{ name: string, type:string }> = [ ]; + + domainFieldNames.forEach((name) => { + const value = (domain)[name]; + if (value == null) { return; } + domainValues[name] = domainChecks[name](value); + domainTypes.push({ name, type: domainFieldTypes[name] }); + }); + + const encoder = TypedDataEncoder.from(types); + + const typesWithDomain = shallowCopy(types); + if (typesWithDomain.EIP712Domain) { + typesWithDomain.EIP712Domain = domainTypes; + } + + // Validate the data structures and types + encoder.encode(value); + + return { + types: typesWithDomain, + domain: domainValues, + primaryType: encoder.primaryType, + message: encoder.visit(value, (type: string, value: any) => { + + // bytes + if (type.match(/^bytes(\d*)/)) { + return hexlify(arrayify(value)); + } + + // uint or int + if (type.match(/^u?int/)) { + let prefix = ""; + let v = BigNumber.from(value); + if (v.isNegative()) { + prefix = "-"; + v = v.mul(-1); + } + return prefix + hexValue(v.toHexString()); + } + + switch (type) { + case "address": + return value.toLowerCase(); + case "bool": + return !!value; + case "string": + if (typeof(value) !== "string") { + logger.throwArgumentError(`invalid string`, "value", value); + } + return value; + } + + return logger.throwArgumentError("unsupported type", "type", type); + }) + }; + } } diff --git a/packages/providers/src.ts/json-rpc-provider.ts b/packages/providers/src.ts/json-rpc-provider.ts index 58f7a8797..89b648623 100644 --- a/packages/providers/src.ts/json-rpc-provider.ts +++ b/packages/providers/src.ts/json-rpc-provider.ts @@ -3,9 +3,10 @@ // See: https://github.com/ethereum/wiki/wiki/JSON-RPC import { Provider, TransactionRequest, TransactionResponse } from "@ethersproject/abstract-provider"; -import { Signer } from "@ethersproject/abstract-signer"; +import { Signer, TypedDataDomain, TypedDataField, TypedDataSigner } from "@ethersproject/abstract-signer"; import { BigNumber } from "@ethersproject/bignumber"; import { Bytes, hexlify, hexValue } from "@ethersproject/bytes"; +import { _TypedDataEncoder } from "@ethersproject/hash"; import { Network, Networkish } from "@ethersproject/networks"; import { checkProperties, deepCopy, Deferrable, defineReadOnly, getStatic, resolveProperties, shallowCopy } from "@ethersproject/properties"; import { toUtf8Bytes } from "@ethersproject/strings"; @@ -88,7 +89,7 @@ function getLowerCase(value: string): string { const _constructorGuard = {}; -export class JsonRpcSigner extends Signer { +export class JsonRpcSigner extends Signer implements TypedDataSigner { readonly provider: JsonRpcProvider; _index: number; _address: string; @@ -203,13 +204,23 @@ export class JsonRpcSigner extends Signer { }); } - signMessage(message: Bytes | string): Promise { + async signMessage(message: Bytes | string): Promise { const data = ((typeof(message) === "string") ? toUtf8Bytes(message): message); - return this.getAddress().then((address) => { + const address = await this.getAddress(); - // https://github.com/ethereum/wiki/wiki/JSON-RPC#eth_sign - return this.provider.send("eth_sign", [ address.toLowerCase(), hexlify(data) ]); + // https://github.com/ethereum/wiki/wiki/JSON-RPC#eth_sign + return await this.provider.send("eth_sign", [ address.toLowerCase(), hexlify(data) ]); + } + + async _signTypedData(domain: TypedDataDomain, types: Record>, value: Record): Promise { + // Populate any ENS names (in-place) + const populated = await _TypedDataEncoder.resolveNames(domain, types, value, (name: string) => { + return this.provider.resolveName(name); }); + + return await this.provider.send("eth_signTypedData_v4", [ + _TypedDataEncoder.getPayload(populated.domain, types, populated.value) + ]); } unlock(password: string): Promise { diff --git a/packages/wallet/src.ts/index.ts b/packages/wallet/src.ts/index.ts index 66178344d..a922dc6e9 100644 --- a/packages/wallet/src.ts/index.ts +++ b/packages/wallet/src.ts/index.ts @@ -2,9 +2,9 @@ import { getAddress } from "@ethersproject/address"; import { Provider, TransactionRequest } from "@ethersproject/abstract-provider"; -import { ExternallyOwnedAccount, Signer } from "@ethersproject/abstract-signer"; +import { ExternallyOwnedAccount, Signer, TypedDataDomain, TypedDataField, TypedDataSigner } from "@ethersproject/abstract-signer"; import { arrayify, Bytes, BytesLike, concat, hexDataSlice, isHexString, joinSignature, SignatureLike } from "@ethersproject/bytes"; -import { hashMessage } from "@ethersproject/hash"; +import { hashMessage, _TypedDataEncoder } from "@ethersproject/hash"; import { defaultPath, HDNode, entropyToMnemonic, Mnemonic } from "@ethersproject/hdnode"; import { keccak256 } from "@ethersproject/keccak256"; import { defineReadOnly, resolveProperties } from "@ethersproject/properties"; @@ -27,7 +27,7 @@ function hasMnemonic(value: any): value is { mnemonic: Mnemonic } { return (mnemonic && mnemonic.phrase); } -export class Wallet extends Signer implements ExternallyOwnedAccount { +export class Wallet extends Signer implements ExternallyOwnedAccount, TypedDataSigner { readonly address: string; readonly provider: Provider; @@ -119,8 +119,22 @@ export class Wallet extends Signer implements ExternallyOwnedAccount { }); } - signMessage(message: Bytes | string): Promise { - return Promise.resolve(joinSignature(this._signingKey().signDigest(hashMessage(message)))); + async signMessage(message: Bytes | string): Promise { + return joinSignature(this._signingKey().signDigest(hashMessage(message))); + } + + async _signTypedData(domain: TypedDataDomain, types: Record>, value: Record): Promise { + // Populate any ENS names + const populated = await _TypedDataEncoder.resolveNames(domain, types, value, (name: string) => { + if (this.provider == null) { + logger.throwError("cannot resolve ENS names without a provider", Logger.errors.UNSUPPORTED_OPERATION, { + operation: "resolveName" + }); + } + return this.provider.resolveName(name); + }); + + return joinSignature(this._signingKey().signDigest(_TypedDataEncoder.hash(populated.domain, types, populated.value))); } encrypt(password: Bytes | string, options?: any, progressCallback?: ProgressCallback): Promise {