# ZeroNet/src/lib/sslcrypto/openssl/ecc.py
#
# (Source-listing metadata: 576 lines, 20 KiB, Python.)

import ctypes
import hmac
import threading
from .._ecc import ECC
from .aes import aes
from .library import lib, openssl_backend
# Initialize functions: declare pointer return types. ctypes assumes a foreign
# function returns a C int unless told otherwise, which would truncate 64-bit
# pointers (and could make a non-NULL pointer look like NULL). Every OpenSSL
# routine used in this module that returns a pointer must be listed here.
for _name in (
    "BN_new",
    "BN_bin2bn",
    "BN_CTX_new",
    "BN_mod_inverse",  # returns BIGNUM *; BN.inverse tests it for NULL
    "EC_GROUP_new_by_curve_name",
    "EC_KEY_new",
    "EC_POINT_new",
    "EC_KEY_get0_private_key",
    "EVP_PKEY_new",
):
    getattr(lib, _name).restype = ctypes.POINTER(ctypes.c_char)
del _name
# EVP_PKEY_CTX_new may be missing from the loaded library; EllipticCurveBackend
# checks for it with hasattr() and falls back to ECDH_compute_key.
if hasattr(lib, "EVP_PKEY_CTX_new"):
    lib.EVP_PKEY_CTX_new.restype = ctypes.POINTER(ctypes.c_char)
# Per-thread storage; BN.Context.get() caches each thread's BN_CTX here.
thread_local = threading.local()
# This lock is required to keep ECC thread-safe. Old OpenSSL versions (before
# 1.1.0) use global objects, so they aren't thread-safe. Fortunately, reading
# the OpenSSL source tells us which functions are safe to call concurrently.
#
# For example, EC_GROUP_new_by_curve_name checks the global error code to
# initialize the group. If two errors happen at once, two threads read the
# error code, or the codes are read in the wrong order, the group is
# initialized incorrectly.
#
# EC_KEY_new_by_curve_name calls EC_GROUP_new_by_curve_name, so it isn't
# thread-safe either. Taking the lock on every key creation would be too slow.
# Instead we call EC_KEY_new followed by EC_KEY_set_group, which relies on
# EC_GROUP_copy (thread-safe) rather than constructing a new group.
lock = threading.Lock()
class BN:
    """ctypes wrapper around an OpenSSL BIGNUM with Python operator support.

    A BN either owns its BIGNUM (freed in ``__del__``) or, when created with
    ``link_only=True``, merely references a BIGNUM owned by someone else.
    """

    # BN_CTX
    class Context:
        """Owner of one OpenSSL BN_CTX scratch context (one per thread)."""

        def __init__(self):
            self.lib = lib  # For finalizer
            self.ptr = lib.BN_CTX_new()
            if not self.ptr:
                raise MemoryError("Could not allocate BN_CTX")

        def __del__(self):
            self.lib.BN_CTX_free(self.ptr)

        @classmethod
        def get(cls):
            """Return the calling thread's BN_CTX pointer, creating it lazily.

            A BN_CTX must only be used by one thread, so each thread keeps
            its own instance in ``thread_local``.
            """
            # Get thread-safe context
            if not hasattr(thread_local, "bn_ctx"):
                thread_local.bn_ctx = cls()
            return thread_local.bn_ctx.ptr

    def __init__(self, value=None, link_only=False):
        """Create a BN.

        :param value: ``None`` for zero, big-endian ``bytes``, a non-negative
            ``int``, or (with ``link_only``) a raw BIGNUM pointer.
        :param link_only: wrap ``value`` without taking ownership; the
            pointer is never freed by this object.
        :raises ValueError: if ``value`` is a negative integer.
        :raises MemoryError: if OpenSSL cannot allocate the BIGNUM.
        """
        if link_only:
            self.bn = value
            self._free = False
            return
        # Nothing is owned yet, so __del__ stays a no-op if allocation fails.
        self._free = False
        if value is None:
            bn = lib.BN_new()
        elif isinstance(value, bytes):
            bn = lib.BN_bin2bn(value, len(value), None)
        else:
            # Go through BN_bin2bn: handing a Python int directly to
            # BN_add_word (no argtypes) squeezes it through a C int, which
            # truncates or rejects large values and mangles negative ones.
            if value < 0:
                raise ValueError("Cannot create BN from a negative integer")
            raw = value.to_bytes((value.bit_length() + 7) // 8, "big")
            bn = lib.BN_bin2bn(raw, len(raw), None)
        if not bn:
            raise MemoryError("Could not allocate BN")
        self.bn = bn
        self._free = True

    def __del__(self):
        if self._free:
            lib.BN_free(self.bn)

    def bytes(self, length=None):
        """Return the magnitude as big-endian bytes (sign is dropped).

        :param length: if given, left-pad with zero bytes to this length.
        :raises ValueError: if the value needs more than ``length`` bytes.
        """
        buf = ctypes.create_string_buffer((len(self) + 7) // 8)
        lib.BN_bn2bin(self.bn, buf)
        raw = buf.raw
        if length is None:
            return raw
        if length < len(raw):
            raise ValueError("Too little space for BN")
        return b"\x00" * (length - len(raw)) + raw

    def __int__(self):
        # Magnitude only, matching bytes().
        return int.from_bytes(self.bytes(), "big")

    def __len__(self):
        """Return the number of significant bits."""
        return lib.BN_num_bits(self.bn)

    def inverse(self, modulo):
        """Return the modular inverse of ``self`` modulo ``modulo``."""
        result = BN()
        if not lib.BN_mod_inverse(result.bn, self.bn, modulo.bn, BN.Context.get()):
            raise ValueError("Could not compute inverse")
        return result

    def __floordiv__(self, other):
        # Follows OpenSSL BN_div semantics, which differ from Python's // for
        # negative operands.
        if not isinstance(other, BN):
            raise TypeError("Can only divide BN by BN, not {}".format(other))
        result = BN()
        if not lib.BN_div(result.bn, None, self.bn, other.bn, BN.Context.get()):
            raise ZeroDivisionError("Division by zero")
        return result

    def __mod__(self, other):
        # BN_nnmod always yields a non-negative remainder, like Python's % for
        # a positive modulus (BN_div's remainder takes the dividend's sign).
        if not isinstance(other, BN):
            raise TypeError("Can only divide BN by BN, not {}".format(other))
        result = BN()
        if not lib.BN_nnmod(result.bn, self.bn, other.bn, BN.Context.get()):
            raise ZeroDivisionError("Division by zero")
        return result

    def __add__(self, other):
        if not isinstance(other, BN):
            raise TypeError("Can only sum BN's, not BN and {}".format(other))
        result = BN()
        if not lib.BN_add(result.bn, self.bn, other.bn):
            raise ValueError("Could not sum two BN's")
        return result

    def __sub__(self, other):
        if not isinstance(other, BN):
            raise TypeError("Can only subtract BN's, not BN and {}".format(other))
        result = BN()
        if not lib.BN_sub(result.bn, self.bn, other.bn):
            raise ValueError("Could not subtract BN from BN")
        return result

    def __mul__(self, other):
        if not isinstance(other, BN):
            raise TypeError("Can only multiply BN by BN, not {}".format(other))
        result = BN()
        if not lib.BN_mul(result.bn, self.bn, other.bn, BN.Context.get()):
            raise ValueError("Could not multiply two BN's")
        return result

    def __neg__(self):
        return BN(0) - self

    def _adopt(self, res):
        """Take over ``res``'s BIGNUM and hand ours to ``res`` for disposal.

        Ownership flags travel with the pointers: ``res`` frees our old
        BIGNUM only if we owned it. A link_only BN therefore never frees a
        borrowed pointer, and the newly computed BIGNUM never leaks.
        """
        self.bn, res.bn = res.bn, self.bn
        self._free, res._free = res._free, self._free
        return self

    def __imod__(self, other):
        return self._adopt(self % other)

    def __iadd__(self, other):
        return self._adopt(self + other)

    def __isub__(self, other):
        return self._adopt(self - other)

    def __imul__(self, other):
        return self._adopt(self * other)

    def cmp(self, other):
        """Return <0, 0 or >0 as ``self`` is less than, equal to or greater than ``other``."""
        if not isinstance(other, BN):
            raise TypeError("Can only compare BN with BN, not {}".format(other))
        return lib.BN_cmp(self.bn, other.bn)

    def __eq__(self, other):
        return self.cmp(other) == 0

    def __lt__(self, other):
        return self.cmp(other) < 0

    def __gt__(self, other):
        return self.cmp(other) > 0

    def __ne__(self, other):
        return self.cmp(other) != 0

    def __le__(self, other):
        return self.cmp(other) <= 0

    def __ge__(self, other):
        return self.cmp(other) >= 0

    def __repr__(self):
        return "<BN {}>".format(int(self))

    def __str__(self):
        return str(int(self))
class EllipticCurveBackend:
    """OpenSSL implementation of the elliptic-curve primitives used by ``ECC``.

    Each instance wraps one ``EC_GROUP`` chosen by OpenSSL curve NID. Scalars
    are big-endian ``bytes``. Public keys are ``(x, y)`` pairs of big-endian
    coordinates, each ``public_key_length`` bytes long.
    """

    def __init__(self, nid):
        self.lib = lib  # For finalizer
        self.nid = nid
        # EC_GROUP_new_by_curve_name is not thread-safe on OpenSSL < 1.1.0,
        # so serialize it with the module-wide lock.
        with lock:
            self.group = lib.EC_GROUP_new_by_curve_name(self.nid)
        if not self.group:
            raise ValueError("The curve is not supported by OpenSSL")
        self.order = BN()
        self.p = BN()
        bn_ctx = BN.Context.get()
        if not lib.EC_GROUP_get_order(self.group, self.order.bn, bn_ctx):
            raise ValueError("Could not get curve order")
        if not lib.EC_GROUP_get_curve_GFp(self.group, self.p.bn, None, None, bn_ctx):
            raise ValueError("Could not get curve parameters")
        # Byte length of one affine coordinate; also used for r, s and keys.
        self.public_key_length = (len(self.p) + 7) // 8
        # Without EVP_PKEY_CTX_new, ecdh() falls back to ECDH_compute_key.
        self.is_supported_evp_pkey_ctx = hasattr(lib, "EVP_PKEY_CTX_new")

    def __del__(self):
        # __init__ may have failed before (or while) creating the group.
        group = getattr(self, "group", None)
        if group:
            self.lib.EC_GROUP_free(group)

    def _new_ec_key(self):
        """Return a fresh EC_KEY bound to this curve; the caller must free it.

        EC_KEY_new + EC_KEY_set_group is used instead of
        EC_KEY_new_by_curve_name because the latter is not thread-safe on
        old OpenSSL versions.
        """
        eckey = lib.EC_KEY_new()
        if not eckey:
            raise ValueError("Failed to allocate EC_KEY")
        if not lib.EC_KEY_set_group(eckey, self.group):
            lib.EC_KEY_free(eckey)
            raise ValueError("Failed to set EC_KEY group")
        return eckey

    def _new_point(self, error="Could not create point"):
        """Return a fresh EC_POINT on this curve; the caller must free it.

        :raises ValueError: with message ``error`` if allocation fails.
        """
        point = lib.EC_POINT_new(self.group)
        if not point:
            raise ValueError(error)
        return point

    def _private_key_to_ec_key(self, private_key):
        """Build an EC_KEY holding ``private_key``.

        Returns ``(eckey, bn)`` where ``bn`` is the private key as a BN. The
        caller must free ``eckey`` with EC_KEY_free.
        """
        eckey = self._new_ec_key()
        try:
            private_key = BN(private_key)
            if not lib.EC_KEY_set_private_key(eckey, private_key.bn):
                raise ValueError("Invalid private key")
        except Exception:
            lib.EC_KEY_free(eckey)
            raise
        return eckey, private_key

    def _public_key_to_point(self, public_key):
        """Return a new EC_POINT for the affine ``(x, y)`` public key.

        The caller must free the point with EC_POINT_free.
        """
        x = BN(public_key[0])
        y = BN(public_key[1])
        # EC_KEY_set_public_key_affine_coordinates is not supported by
        # OpenSSL 1.0.0 so we can't use it
        point = self._new_point()
        if not lib.EC_POINT_set_affine_coordinates_GFp(self.group, point, x.bn, y.bn, BN.Context.get()):
            lib.EC_POINT_free(point)
            raise ValueError("Could not set public key affine coordinates")
        return point

    def _public_key_to_ec_key(self, public_key):
        """Build an EC_KEY holding the affine public key; the caller frees it."""
        eckey = self._new_ec_key()
        try:
            point = self._public_key_to_point(public_key)
            try:
                # The point is freed right after, so the EC_KEY must keep its
                # own copy.
                if not lib.EC_KEY_set_public_key(eckey, point):
                    raise ValueError("Could not set point")
            finally:
                lib.EC_POINT_free(point)
        except Exception:
            lib.EC_KEY_free(eckey)
            raise
        return eckey

    def _point_to_affine(self, point):
        """Return the affine ``(x, y)`` of ``point`` as padded big-endian bytes."""
        # Convert to affine coordinates
        x = BN()
        y = BN()
        if lib.EC_POINT_get_affine_coordinates_GFp(self.group, point, x.bn, y.bn, BN.Context.get()) != 1:
            raise ValueError("Failed to convert public key to affine coordinates")
        # Convert to binary
        if (len(x) + 7) // 8 > self.public_key_length:
            raise ValueError("Public key X coordinate is too large")
        if (len(y) + 7) // 8 > self.public_key_length:
            raise ValueError("Public key Y coordinate is too large")
        return x.bytes(self.public_key_length), y.bytes(self.public_key_length)

    def decompress_point(self, public_key):
        """Decode an octet-encoded public key into affine ``(x, y)`` bytes."""
        point = self._new_point()
        try:
            if not lib.EC_POINT_oct2point(self.group, point, public_key, len(public_key), BN.Context.get()):
                raise ValueError("Invalid compressed public key")
            return self._point_to_affine(point)
        finally:
            lib.EC_POINT_free(point)

    def new_private_key(self):
        """Generate a random private key.

        Returns big-endian bytes left-padded to the byte length of the group
        order, so the size does not shrink when the key has leading zeros.
        """
        eckey = self._new_ec_key()
        try:
            if not lib.EC_KEY_generate_key(eckey):
                raise ValueError("Could not generate private key")
            # get0 returns a borrowed pointer that EC_KEY_free releases, so
            # serialize it before freeing the key.
            private_key = BN(lib.EC_KEY_get0_private_key(eckey), link_only=True)
            return private_key.bytes((len(self.order) + 7) // 8)
        finally:
            lib.EC_KEY_free(eckey)

    def private_to_public(self, private_key):
        """Derive the affine public key ``(x, y)`` for ``private_key``."""
        # Loading the key into an EC_KEY lets OpenSSL reject keys it considers
        # invalid (raised as ValueError).
        eckey, private_key = self._private_key_to_ec_key(private_key)
        try:
            point = self._new_point()
            try:
                if not lib.EC_POINT_mul(self.group, point, private_key.bn, None, None, BN.Context.get()):
                    raise ValueError("Failed to derive public key")
                return self._point_to_affine(point)
            finally:
                lib.EC_POINT_free(point)
        finally:
            lib.EC_KEY_free(eckey)

    def ecdh(self, private_key, public_key):
        """Return the ECDH shared secret of ``private_key`` and affine ``public_key``.

        No KDF is applied. Uses the EVP_PKEY derive API when the library
        provides EVP_PKEY_CTX_new, otherwise ECDH_compute_key.
        """
        if self.is_supported_evp_pkey_ctx:
            return self._ecdh_evp(private_key, public_key)
        return self._ecdh_legacy(private_key, public_key)

    def _ecdh_legacy(self, private_key, public_key):
        """ECDH via ECDH_compute_key (fallback for libraries without EVP_PKEY_CTX)."""
        eckey, _ = self._private_key_to_ec_key(private_key)
        try:
            point = self._public_key_to_point(public_key)
            try:
                key = ctypes.create_string_buffer(self.public_key_length)
                # Failure is -1 on old OpenSSL and 0 on newer releases.
                if lib.ECDH_compute_key(key, self.public_key_length, point, eckey, None) <= 0:
                    raise ValueError("Could not compute shared secret")
                return bytes(key)
            finally:
                lib.EC_POINT_free(point)
        finally:
            lib.EC_KEY_free(eckey)

    def _ec_key_to_evp_pkey(self, eckey, what):
        """Wrap ``eckey`` in a new EVP_PKEY; the caller frees it with EVP_PKEY_free.

        ``what`` names the key ("private key" / "public key") in error messages.
        """
        pkey = lib.EVP_PKEY_new()
        if not pkey:
            raise ValueError("Could not create {} object".format(what))
        if lib.EVP_PKEY_set1_EC_KEY(pkey, eckey) != 1:
            lib.EVP_PKEY_free(pkey)
            raise ValueError("Could not attach EC_KEY to {} object".format(what))
        return pkey

    def _evp_derive(self, pkey, peer_pkey):
        """Run EVP_PKEY_derive between ``pkey`` and ``peer_pkey``."""
        ctx = lib.EVP_PKEY_CTX_new(pkey, None)
        if not ctx:
            raise ValueError("Could not create EVP context")
        try:
            if lib.EVP_PKEY_derive_init(ctx) != 1:
                raise ValueError("Could not initialize key derivation")
            # Failure may be 0 or negative (e.g. -2 for "unsupported").
            if lib.EVP_PKEY_derive_set_peer(ctx, peer_pkey) != 1:
                raise ValueError("Could not set peer")
            # EVP_PKEY_derive writes a size_t through keylen; a c_int is too
            # small on 64-bit platforms.
            key_len = ctypes.c_size_t(0)
            if lib.EVP_PKEY_derive(ctx, None, ctypes.byref(key_len)) != 1:
                raise ValueError("Could not compute shared secret")
            key = ctypes.create_string_buffer(key_len.value)
            if lib.EVP_PKEY_derive(ctx, key, ctypes.byref(key_len)) != 1:
                raise ValueError("Could not compute shared secret")
            return key.raw[:key_len.value]
        finally:
            lib.EVP_PKEY_CTX_free(ctx)

    def _ecdh_evp(self, private_key, public_key):
        """ECDH via the EVP_PKEY derive API."""
        eckey, _ = self._private_key_to_ec_key(private_key)
        try:
            pkey = self._ec_key_to_evp_pkey(eckey, "private key")
            try:
                peer_eckey = self._public_key_to_ec_key(public_key)
                try:
                    peer_pkey = self._ec_key_to_evp_pkey(peer_eckey, "public key")
                    try:
                        return self._evp_derive(pkey, peer_pkey)
                    finally:
                        lib.EVP_PKEY_free(peer_pkey)
                finally:
                    lib.EC_KEY_free(peer_eckey)
            finally:
                lib.EVP_PKEY_free(pkey)
        finally:
            lib.EC_KEY_free(eckey)

    def _subject_to_bn(self, subject):
        """Truncate ``subject`` to the byte length of the group order, as a BN."""
        return BN(subject[:(len(self.order) + 7) // 8])

    def _check_signature_bounds(self, r, s):
        """Raise ValueError unless 0 < r < order and 0 < s < order."""
        zero = BN(0)
        if r == zero or r >= self.order:
            raise ValueError("r is out of bounds")
        if s == zero or s >= self.order:
            raise ValueError("s is out of bounds")

    def sign(self, subject, private_key, recoverable, is_compressed, entropy):
        """Create an ECDSA signature of ``subject`` using nonce bytes ``entropy``.

        Returns ``r || s`` with s normalized to the lower half of the order.
        When ``recoverable``, this is prefixed with a header byte: 31 + recid
        if ``is_compressed``, else 27 + recid.
        """
        z = self._subject_to_bn(subject)
        private_key = BN(private_key)
        k = BN(entropy)
        rp = self._new_point("Could not create R")
        bn_ctx = BN.Context.get()
        try:
            # Fix Minerva (nonce bit-length timing leak): multiply by k + n or
            # k + 2n instead of k. Both give the same point because n*G is the
            # identity, but the scalar's bit length no longer tracks k's.
            k1 = k + self.order
            k2 = k1 + self.order
            if len(k1) == len(k2):
                k = k2
            else:
                k = k1
            if not lib.EC_POINT_mul(self.group, rp, k.bn, None, None, bn_ctx):
                raise ValueError("Could not generate R")
            # Convert to affine coordinates
            rx = BN()
            ry = BN()
            if lib.EC_POINT_get_affine_coordinates_GFp(self.group, rp, rx.bn, ry.bn, bn_ctx) != 1:
                raise ValueError("Failed to convert R to affine coordinates")
            r = rx % self.order
            if r == BN(0):
                raise ValueError("Invalid k")
            # Calculate s = k^-1 * (z + r * private_key) mod n
            s = (k.inverse(self.order) * (z + r * private_key)) % self.order
            if s == BN(0):
                raise ValueError("Invalid k")
            inverted = False
            if s * BN(2) >= self.order:
                s = self.order - s
                inverted = True
            r_buf = r.bytes(self.public_key_length)
            s_buf = s.bytes(self.public_key_length)
            if not recoverable:
                return r_buf + s_buf
            # Generate recid: bit 0 is the parity of R.y (flipped when s was
            # negated); higher bits record how many times rx wrapped the order.
            recid = int(ry % BN(2)) ^ inverted
            # The line below is highly unlikely to matter in case of
            # secp256k1 but might make sense for other curves
            recid += 2 * int(rx // self.order)
            if is_compressed:
                return bytes([31 + recid]) + r_buf + s_buf
            if recid >= 4:
                raise ValueError("Too big recovery ID, use compressed address instead")
            return bytes([27 + recid]) + r_buf + s_buf
        finally:
            lib.EC_POINT_free(rp)

    def recover(self, signature, subject):
        """Recover the signer's affine public key from a recoverable signature.

        ``signature`` is ``header || r || s``. The header byte is 27 + recid
        or 31 + recid, as produced by ``sign``.
        """
        recid = signature[0] - 27 if signature[0] < 31 else signature[0] - 31
        if recid < 0:
            raise ValueError("Invalid recovery ID")
        r = BN(signature[1:self.public_key_length + 1])
        s = BN(signature[self.public_key_length + 1:])
        # Verify bounds
        self._check_signature_bounds(r, s)
        bn_ctx = BN.Context.get()
        z = self._subject_to_bn(subject)
        rinv = r.inverse(self.order)
        u1 = (-z * rinv) % self.order
        u2 = (s * rinv) % self.order
        # Recover R
        rx = r + BN(recid // 2) * self.order
        if rx >= self.p:
            raise ValueError("Rx is out of bounds")
        rp = self._new_point("Could not create R")
        try:
            init_buf = b"\x02" + rx.bytes(self.public_key_length)
            if not lib.EC_POINT_oct2point(self.group, rp, init_buf, len(init_buf), bn_ctx):
                raise ValueError("Could not use Rx to initialize point")
            ry = BN()
            if lib.EC_POINT_get_affine_coordinates_GFp(self.group, rp, None, ry.bn, bn_ctx) != 1:
                raise ValueError("Failed to convert R to affine coordinates")
            if int(ry % BN(2)) != recid % 2:
                # Fix Ry sign
                ry = self.p - ry
                if lib.EC_POINT_set_affine_coordinates_GFp(self.group, rp, rx.bn, ry.bn, bn_ctx) != 1:
                    raise ValueError("Failed to update R coordinates")
            # Recover public key: Q = u1*G + u2*R
            result = self._new_point()
            try:
                if not lib.EC_POINT_mul(self.group, result, u1.bn, rp, u2.bn, bn_ctx):
                    raise ValueError("Could not recover public key")
                return self._point_to_affine(result)
            finally:
                lib.EC_POINT_free(result)
        finally:
            lib.EC_POINT_free(rp)

    def verify(self, signature, subject, public_key):
        """Verify an ``r || s`` ECDSA signature of ``subject``.

        Returns True for a valid signature. Raises ValueError otherwise.
        """
        r = BN(signature[:self.public_key_length])
        s = BN(signature[self.public_key_length:])
        self._check_signature_bounds(r, s)
        bn_ctx = BN.Context.get()
        z = self._subject_to_bn(subject)
        pub_p = self._new_point("Could not create public key point")
        try:
            init_buf = b"\x04" + public_key[0] + public_key[1]
            if not lib.EC_POINT_oct2point(self.group, pub_p, init_buf, len(init_buf), bn_ctx):
                raise ValueError("Could not initialize point")
            sinv = s.inverse(self.order)
            u1 = (z * sinv) % self.order
            u2 = (r * sinv) % self.order
            # R' = u1*G + u2*Q; the signature is valid iff x(R') mod n == r.
            result = self._new_point()
            try:
                if not lib.EC_POINT_mul(self.group, result, u1.bn, pub_p, u2.bn, bn_ctx):
                    raise ValueError("Could not recover public key")
                if BN(self._point_to_affine(result)[0]) % self.order != r:
                    raise ValueError("Invalid signature")
                return True
            finally:
                lib.EC_POINT_free(result)
        finally:
            lib.EC_POINT_free(pub_p)

    def derive_child(self, seed, child):
        """Derive a child private key from ``seed`` and index ``child``.

        Round 1 computes HMAC-SHA512(b"Bitcoin seed", seed). The left half is
        the master key and the right half keys round 2. Round 2 HMACs the
        compressed master public key followed by the low 32 bits of ``child``
        (big-endian). The resulting left half is added to the master key
        modulo the group order.
        """
        # Round 1
        h = hmac.new(key=b"Bitcoin seed", msg=seed, digestmod="sha512").digest()
        private_key1 = h[:32]
        x, y = self.private_to_public(private_key1)
        public_key1 = bytes([0x02 + (y[-1] % 2)]) + x
        private_key1 = BN(private_key1)
        # Round 2
        child_bytes = (child & 0xFFFFFFFF).to_bytes(4, "big")
        msg = public_key1 + child_bytes
        h = hmac.new(key=h[32:], msg=msg, digestmod="sha512").digest()
        private_key2 = BN(h[:32])
        return ((private_key1 + private_key2) % self.order).bytes(self.public_key_length)

    @classmethod
    def get_backend(cls):
        """Return the identifier of the backend this class implements."""
        return openssl_backend
# Module-level ECC front end, built from the OpenSSL curve backend defined above
# and the AES implementation imported from .aes.
ecc = ECC(EllipticCurveBackend, aes)