194 lines
5.6 KiB
Python
194 lines
5.6 KiB
Python
import datetime
import json
import os

import phe


class Paillier:
    """Thin wrapper around the ``phe`` Paillier implementation.

    Handles key generation and JWK-style JSON key files, (de)serialization
    of ciphertexts, and a Euclidean-distance computation built on the
    scheme's additive homomorphism.
    """

    # Default key length in bits, passed to phe as ``n_length``.
    # NOTE(review): 256 bits is far too small for real security; presumably
    # chosen for speed in experiments -- confirm before any production use.
    N = 256

    # JSON key files written by generate_key_pair() and read by load_keys()
    # (relative to the current working directory).
    PUBLIC_KEY_PATH = "./keys/public.json"
    PRIVATE_KEY_PATH = "./keys/private.json"

    def __init__(self, N=None):
        """Create a helper that uses an ``N``-bit key length.

        Args:
            N: key length in bits; defaults to the class attribute ``N``.
        """
        self.N = type(self).N if N is None else N
        # Populated by load_keys(); None until then.
        self.public_key = None
        self.private_key = None

    # ------------------------------------------------------------------
    # Small internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _require(condition, message):
        """Raise ``AssertionError(message)`` unless ``condition`` is true.

        Unlike a bare ``assert`` this check is not stripped under
        ``python -O``; AssertionError is kept so callers see the same
        exception type as before.
        """
        if not condition:
            raise AssertionError(message)

    @staticmethod
    def _require_same_length(a, b):
        """Raise ValueError if sequences ``a`` and ``b`` differ in length."""
        if len(a) != len(b):
            raise ValueError(
                "Vectors must have the same length, got {} and {}".format(
                    len(a), len(b)
                )
            )

    @staticmethod
    def _read_json(path):
        """Parse and return the JSON document stored at ``path``."""
        with open(path, encoding="utf-8") as fp:
            return json.load(fp)

    @staticmethod
    def _write_json(path, data):
        """Write ``data`` to ``path`` as 4-space-indented JSON plus a newline.

        Missing parent directories are created first.
        """
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(path, "w", encoding="utf-8") as fp:
            json.dump(data, fp=fp, indent=4)
            fp.write("\n")

    # ------------------------------------------------------------------
    # Key management
    # ------------------------------------------------------------------

    def generate_key_pair(self):
        """Generate a fresh key pair and write it to the JSON key files.

        The public key goes to ``PUBLIC_KEY_PATH``. The private key, which
        embeds a copy of the public key under ``"pub"``, goes to
        ``PRIVATE_KEY_PATH``. The keys are not loaded into this instance;
        call load_keys() for that.
        """
        pub, priv = phe.paillier.generate_paillier_keypair(n_length=self.N)

        date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        jwk_public = {
            "kty": "DAJ",
            "alg": "PAI-GN1",
            "key_ops": ["encrypt"],
            "n": phe.util.int_to_base64(pub.n),
            "kid": "Paillier public key generated by pheutil on {}".format(date),
        }
        jwk_private = {
            "kty": "DAJ",
            "key_ops": ["decrypt"],
            "p": phe.util.int_to_base64(priv.p),
            "q": phe.util.int_to_base64(priv.q),
            "pub": jwk_public,
            "kid": "Paillier private key generated by pheutil on {}".format(date),
        }

        # NOTE(review): the private key file is created with default
        # permissions; consider restricting them (e.g. 0o600).
        self._write_json(self.PRIVATE_KEY_PATH, jwk_private)
        self._write_json(self.PUBLIC_KEY_PATH, jwk_public)

        print("Private key written to {}".format(self.PRIVATE_KEY_PATH))
        print("Public key written to {}".format(self.PUBLIC_KEY_PATH))

    @staticmethod
    def _load_public_key(public_key_data):
        """Build a ``phe.PaillierPublicKey`` from a JWK-style dict.

        Raises:
            AssertionError: if ``alg``/``kty``/``n`` are missing, or ``alg``
                and ``kty`` differ from what generate_key_pair() writes.
        """
        error_msg = "Invalid public key"
        Paillier._require("alg" in public_key_data, error_msg)
        Paillier._require(public_key_data["alg"] == "PAI-GN1", error_msg)
        # .get(): a missing "kty" is reported as an invalid key, not KeyError.
        Paillier._require(public_key_data.get("kty") == "DAJ", error_msg)
        Paillier._require("n" in public_key_data, error_msg)

        n = phe.util.base64_to_int(public_key_data["n"])
        return phe.PaillierPublicKey(n)

    def load_keys(self):
        """Load the key pair from ``PRIVATE_KEY_PATH`` and ``PUBLIC_KEY_PATH``.

        On success sets ``self.private_key`` and ``self.public_key``. Nothing
        is assigned if validation fails.

        Raises:
            AssertionError: if either file fails validation, or the public
                key file describes a different key than the private key's
                embedded ``"pub"`` entry.
        """
        print("Loading private key")
        privatekeydata = self._read_json(self.PRIVATE_KEY_PATH)

        print("Loading public key")
        publickeydata = self._read_json(self.PUBLIC_KEY_PATH)

        private_key_error = "Invalid private key"
        self._require("pub" in privatekeydata, private_key_error)
        pub = Paillier._load_public_key(privatekeydata["pub"])

        self._require("key_ops" in privatekeydata, private_key_error)
        self._require("decrypt" in privatekeydata["key_ops"], private_key_error)
        self._require("p" in privatekeydata, private_key_error)
        self._require("q" in privatekeydata, private_key_error)
        self._require(privatekeydata.get("kty") == "DAJ", private_key_error)

        _p = phe.util.base64_to_int(privatekeydata["p"])
        _q = phe.util.base64_to_int(privatekeydata["q"])
        private_key = phe.PaillierPrivateKey(pub, _p, _q)

        public_key = Paillier._load_public_key(publickeydata)
        # The two files are read independently. If they disagree, ciphertexts
        # made with self.public_key could not be decrypted by self.private_key.
        self._require(
            public_key.n == pub.n, "Public key does not match private key"
        )

        self.private_key = private_key
        self.public_key = public_key

    # ------------------------------------------------------------------
    # Encryption primitives
    # ------------------------------------------------------------------

    def encrypt(self, x):
        """Encrypt the number ``x`` with the loaded public key."""
        return self.public_key.encrypt(x)

    def encode(self, x):
        """Encode ``x`` as a phe EncodedNumber under the loaded public key."""
        return phe.paillier.EncodedNumber.encode(self.public_key, x)

    def decrypt(self, x):
        """Decrypt ``x`` with the loaded private key."""
        return self.private_key.decrypt(x)

    def serialize(self, enc_obj):
        """Serialize an encrypted number to a JSON string.

        Uses phe's command-line serializer; deserialize() is the inverse.
        """
        return json.dumps(phe.command_line.serialise_encrypted(enc_obj))

    def deserialize(self, enc_json):
        """Rebuild a ``phe.EncryptedNumber`` from serialize()'s JSON output.

        Raises:
            AssertionError: if the ``"v"`` (ciphertext) or ``"e"`` (exponent)
                field is missing.
        """
        ciphertext_data = json.loads(enc_json)
        self._require("v" in ciphertext_data, "Invalid ciphertext")
        self._require("e" in ciphertext_data, "Invalid ciphertext")

        return phe.EncryptedNumber(
            self.public_key, int(ciphertext_data["v"]), exponent=ciphertext_data["e"]
        )

    # ------------------------------------------------------------------
    # Homomorphic Euclidean distance
    # ------------------------------------------------------------------

    def encr_sqr_sum(self, A):
        """Return the encryption of ``sum(a ** 2 for a in A)``.

        The squared sum is computed in plaintext; only the total is encrypted.
        """
        return self.encrypt(sum(a ** 2 for a in A))

    def _distance_from_encrypted(self, encA, enc_sumA, B, verbose):
        """Combine E(A), E(sum A^2) and plaintext B into ||A - B||.

        Uses ||A - B||^2 = sum A^2 + sum B^2 - 2 * sum(A[i] * B[i]). The cross
        term is built from ciphertext-by-plaintext products, and the total is
        decrypted once at the end.
        """
        self._require_same_length(encA, B)

        enc_sumB = self.encr_sqr_sum(B)
        if verbose:
            print("enc_sumB : ", self.decrypt(enc_sumB))

        # Start from an encoded (plaintext) zero; adding an encrypted term to
        # it yields an encrypted running sum.
        prodAB = self.encode(0)
        for i, (enc_a, b) in enumerate(zip(encA, B)):
            x = enc_a * self.encode(-2 * b)  # E(A[i] * -2 * B[i])
            prodAB = x + prodAB
            if verbose:
                print(
                    "iteration : ",
                    i,
                    "prodAB : ",
                    self.decrypt(prodAB),
                    "x : ",
                    self.decrypt(x),
                )
        if verbose:
            print(prodAB)

        squared = self.decrypt(enc_sumA + enc_sumB + prodAB)
        # Float fixed-point encoding may leave a tiny negative value for
        # (near-)identical vectors; a negative base with ** 0.5 would give a
        # complex number, so clamp at zero.
        return max(squared, 0) ** 0.5

    def get_euclidean_dist(self, A, B, *, verbose=True):
        """Return the Euclidean distance between plaintext vectors A and B.

        A is encrypted element-wise and its squared norm is encrypted. The
        distance is then assembled homomorphically and decrypted with the
        loaded private key.

        Args:
            A, B: equal-length sequences of numbers.
            verbose: print decrypted intermediate values for debugging.
                Defaults to True to match this method's historical output.

        Raises:
            ValueError: if A and B differ in length.
        """
        # Fail fast, before doing any (expensive) encryption.
        self._require_same_length(A, B)

        encA = [self.encrypt(a) for a in A]
        enc_sumA = self.encr_sqr_sum(A)
        if verbose:
            print("enc_sumA : ", self.decrypt(enc_sumA))

        return self._distance_from_encrypted(encA, enc_sumA, B, verbose)

    def get_alt_euclidean_dist(self, encA, enc_sumA, B, *, verbose=False):
        """Return ||A - B|| when A is already available in encrypted form.

        Args:
            encA: encrypted elements of A, presumably
                ``[self.encrypt(a) for a in A]`` -- confirm with callers.
            enc_sumA: encryption of ``sum(a ** 2 for a in A)``, e.g. from
                encr_sqr_sum(A).
            B: plaintext sequence of numbers, same length as encA.
            verbose: print decrypted intermediate values (off by default).

        Raises:
            ValueError: if encA and B differ in length.
        """
        return self._distance_from_encrypted(encA, enc_sumA, B, verbose)