# sfa-local-server/src/paillier.py
#
# 194 lines
# 5.6 KiB
# Python

import phe
import json
import datetime
class Paillier(object):
    """Wrapper around ``phe`` (python-paillier) used by this server.

    Handles key-pair generation and persistence as JWK-like JSON files,
    encryption/decryption, (de)serialisation of ciphertexts, and a
    homomorphic Euclidean-distance computation between an encrypted vector
    and a plaintext vector.
    """

    # Default modulus length in bits, used when ``__init__`` receives no ``N``.
    # NOTE(review): a 256-bit modulus is far too small to be secure; this is
    # presumably a prototype value -- confirm before any production use.
    N = 256

    def __init__(self, N=None):
        """Store the modulus length used by ``generate_key_pair``.

        Args:
            N: key length in bits, passed to
                ``phe.paillier.generate_paillier_keypair(n_length=...)``.
                Defaults to the class attribute ``N`` when omitted.
        """
        self.N = type(self).N if N is None else N

    def generate_key_pair(
        self,
        public_path="./keys/public.json",
        private_path="./keys/private.json",
    ):
        """Generate a fresh key pair and write both halves to disk as JSON.

        The public key is stored as a JWK-like dict (``kty``, ``alg``,
        ``key_ops``, base64 modulus ``n``, ``kid``). The private key stores
        the primes ``p`` and ``q`` plus an embedded copy of the public key
        under ``pub``. Existing files are overwritten. The new keys are NOT
        stored on ``self``; call ``load_keys`` to use them.

        Args:
            public_path: destination file for the public key.
            private_path: destination file for the private key.
        """
        pub, priv = phe.paillier.generate_paillier_keypair(n_length=self.N)
        date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        jwk_public = {
            "kty": "DAJ",
            "alg": "PAI-GN1",
            "key_ops": ["encrypt"],
            "n": phe.util.int_to_base64(pub.n),
            "kid": "Paillier public key generated by pheutil on {}".format(date),
        }
        jwk_private = {
            "kty": "DAJ",
            "key_ops": ["decrypt"],
            "p": phe.util.int_to_base64(priv.p),
            "q": phe.util.int_to_base64(priv.q),
            "pub": jwk_public,
            "kid": "Paillier private key generated by pheutil on {}".format(date),
        }
        # ``with`` guarantees both files are flushed and closed, even if
        # serialisation fails part-way.
        # NOTE(review): the private key is written with default file
        # permissions; consider restricting access to it.
        with open(public_path, "w", encoding="utf-8") as public_fp:
            with open(private_path, "w", encoding="utf-8") as private_fp:
                json.dump(jwk_private, indent=4, fp=private_fp)
                json.dump(jwk_public, indent=4, fp=public_fp)
                public_fp.write("\n")
                private_fp.write("\n")
        print("Private key written to {}".format(private_path))
        print("Public key written to {}".format(public_path))

    @staticmethod
    def _require(condition, message):
        """Raise ``AssertionError(message)`` unless ``condition`` is truthy.

        Used instead of ``assert`` so validation still runs under
        ``python -O``. ``AssertionError`` is kept so callers that already
        catch it are unaffected.
        """
        if not condition:
            raise AssertionError(message)

    @staticmethod
    def _load_public_key(public_key_data):
        """Build a ``phe.PaillierPublicKey`` from a JWK-like dict.

        Args:
            public_key_data: dict as written by ``generate_key_pair``.

        Returns:
            The reconstructed ``phe.PaillierPublicKey``.

        Raises:
            AssertionError: if ``alg``/``kty`` are wrong or ``n`` is missing.
        """
        error_msg = "Invalid public key"
        Paillier._require(public_key_data.get("alg") == "PAI-GN1", error_msg)
        Paillier._require(public_key_data.get("kty") == "DAJ", error_msg)
        Paillier._require("n" in public_key_data, error_msg)
        n = phe.util.base64_to_int(public_key_data["n"])
        return phe.PaillierPublicKey(n)

    def load_keys(
        self,
        public_path="./keys/public.json",
        private_path="./keys/private.json",
    ):
        """Load the key pair written by ``generate_key_pair`` onto ``self``.

        Sets ``self.private_key`` and ``self.public_key`` only after both
        files have been read and validated.

        Args:
            public_path: JSON file holding the public key.
            private_path: JSON file holding the private key.

        Raises:
            AssertionError: on malformed key data, or when the public key
                file does not match the public key embedded in the private
                key file.
        """
        print("Loading private key")
        with open(private_path, encoding="utf-8") as private_fp:
            privatekeydata = json.load(private_fp)
        print("Loading public key")
        with open(public_path, encoding="utf-8") as public_fp:
            publickeydata = json.load(public_fp)

        private_key_error = "Invalid private key"
        self._require("pub" in privatekeydata, private_key_error)
        pub = Paillier._load_public_key(privatekeydata["pub"])
        self._require(
            "decrypt" in (privatekeydata.get("key_ops") or ()), private_key_error
        )
        self._require("p" in privatekeydata, private_key_error)
        self._require("q" in privatekeydata, private_key_error)
        self._require(privatekeydata.get("kty") == "DAJ", private_key_error)
        _p = phe.util.base64_to_int(privatekeydata["p"])
        _q = phe.util.base64_to_int(privatekeydata["q"])
        private_key = phe.PaillierPrivateKey(pub, _p, _q)

        public_key = Paillier._load_public_key(publickeydata)
        # Values encrypted under public.json must be decryptable by the
        # private key; a mismatched pair would make the wrapper unusable.
        self._require(
            public_key.n == pub.n, "Public key does not match private key"
        )
        self.private_key = private_key
        self.public_key = public_key

    def encrypt(self, x):
        """Encrypt the number ``x`` with ``self.public_key``."""
        return self.public_key.encrypt(x)

    def encode(self, x):
        """Encode ``x`` as a ``phe`` ``EncodedNumber`` (no encryption)."""
        return phe.paillier.EncodedNumber.encode(self.public_key, x)

    def decrypt(self, x):
        """Decrypt the ``EncryptedNumber`` ``x`` with ``self.private_key``."""
        return self.private_key.decrypt(x)

    def serialize(self, enc_obj):
        """Serialise an ``EncryptedNumber`` to a JSON string.

        Delegates to ``phe.command_line.serialise_encrypted``. The result
        is expected to carry ``v`` (ciphertext) and ``e`` (exponent), which
        is what ``deserialize`` reads back.
        NOTE(review): ``phe.command_line`` is phe's CLI module; confirm it
        is importable in the deployment environment.
        """
        return json.dumps(phe.command_line.serialise_encrypted(enc_obj))

    def deserialize(self, enc_json):
        """Rebuild an ``EncryptedNumber`` from the JSON made by ``serialize``.

        Raises:
            AssertionError: if the ``v`` or ``e`` field is missing.
        """
        ciphertext_data = json.loads(enc_json)
        self._require("v" in ciphertext_data, "Invalid ciphertext")
        self._require("e" in ciphertext_data, "Invalid ciphertext")
        return phe.EncryptedNumber(
            self.public_key, int(ciphertext_data["v"]), exponent=ciphertext_data["e"]
        )

    def encr_sqr_sum(self, A):
        """Return ``E(sum(a ** 2 for a in A))``.

        Only the sum of squares is encrypted; the elements of ``A`` are not
        encrypted individually (that would be wasted work).
        """
        return self.encrypt(sum(a ** 2 for a in A))

    def get_euclidean_dist(self, A, B):
        """Encrypt plaintext vector ``A`` and return its distance to ``B``.

        Encrypts each element of ``A`` and its sum of squares, then
        delegates to ``get_alt_euclidean_dist``.

        Raises:
            ValueError: if ``A`` and ``B`` have different lengths.
        """
        encA = [self.encrypt(a) for a in A]
        return self.get_alt_euclidean_dist(encA, self.encr_sqr_sum(A), B)

    def get_alt_euclidean_dist(self, encA, enc_sumA, B):
        """Return the Euclidean distance between an encrypted vector and ``B``.

        Uses ``|A - B|^2 = sum(A^2) + sum(B^2) - 2 * sum(A * B)``, evaluated
        homomorphically. The squared distance is then decrypted with
        ``self.private_key`` before taking the square root.

        Args:
            encA: sequence of encrypted elements ``E(A[i])``.
            enc_sumA: encrypted ``sum(A[i] ** 2)``.
            B: plaintext vector of the same length as ``encA``.

        Returns:
            The (non-negative) Euclidean distance.

        Raises:
            ValueError: if ``encA`` and ``B`` have different lengths.
        """
        if len(encA) != len(B):
            raise ValueError(
                "Length mismatch: {} encrypted values vs {} plaintext "
                "values".format(len(encA), len(B))
            )
        enc_sumB = self.encrypt(sum(b ** 2 for b in B))
        prodAB = self.encode(0)
        for enc_a, b in zip(encA, B):
            # E(a) * encode(-2b) is the ciphertext of -2ab (plaintext scalar
            # multiplication), accumulated homomorphically.
            prodAB = enc_a * self.encode(-2 * b) + prodAB
        squared = self.decrypt(enc_sumA + enc_sumB + prodAB)
        # Rounding can push a true zero slightly below 0, and a negative
        # float raised to 0.5 yields a complex number in Python 3.
        return max(squared, 0) ** 0.5