Add support for a compliant HKDF implementation.

This commit is contained in:
Moxie Marlinspike 2014-07-06 18:44:19 -07:00
parent d6c5e92c9d
commit f29d1e6269
11 changed files with 177 additions and 36 deletions

View file

@ -9,7 +9,109 @@ import java.util.Arrays;
public class HKDFTest extends AndroidTestCase {
public void testVector() {
public void testVectorV3() {
byte[] ikm = {0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
0x0b, 0x0b};
byte[] salt = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09,
0x0a, 0x0b, 0x0c};
byte[] info = {(byte) 0xf0, (byte) 0xf1, (byte) 0xf2, (byte) 0xf3, (byte) 0xf4,
(byte) 0xf5, (byte) 0xf6, (byte) 0xf7, (byte) 0xf8, (byte) 0xf9};
byte[] expectedOutputOne = {(byte) 0x3c, (byte) 0xb2, (byte) 0x5f, (byte) 0x25, (byte) 0xfa,
(byte) 0xac, (byte) 0xd5, (byte) 0x7a, (byte) 0x90, (byte) 0x43,
(byte) 0x4f, (byte) 0x64, (byte) 0xd0, (byte) 0x36, (byte) 0x2f,
(byte) 0x2a, (byte) 0x2d, (byte) 0x2d, (byte) 0x0a, (byte) 0x90,
(byte) 0xcf, (byte) 0x1a, (byte) 0x5a, (byte) 0x4c, (byte) 0x5d,
(byte) 0xb0, (byte) 0x2d, (byte) 0x56, (byte) 0xec, (byte) 0xc4,
(byte) 0xc5, (byte) 0xbf};
byte[] expectedOutputTwo = {(byte) 0x34, (byte) 0x00, (byte) 0x72, (byte) 0x08, (byte) 0xd5,
(byte) 0xb8, (byte) 0x87, (byte) 0x18, (byte) 0x58, (byte) 0x65};
DerivedSecrets derivedSecrets = HKDF.createFor(3).deriveSecrets(ikm, salt, info);
byte[] truncatedMacKey = new byte[expectedOutputTwo.length];
System.arraycopy(derivedSecrets.getMacKey().getEncoded(), 0, truncatedMacKey, 0, truncatedMacKey.length);
assertTrue(Arrays.equals(derivedSecrets.getCipherKey().getEncoded(), expectedOutputOne));
assertTrue(Arrays.equals(expectedOutputTwo, truncatedMacKey));
}
public void testVectorLongV3() {
byte[] ikm = {(byte) 0x00, (byte) 0x01, (byte) 0x02, (byte) 0x03, (byte) 0x04,
(byte) 0x05, (byte) 0x06, (byte) 0x07, (byte) 0x08, (byte) 0x09,
(byte) 0x0a, (byte) 0x0b, (byte) 0x0c, (byte) 0x0d, (byte) 0x0e,
(byte) 0x0f, (byte) 0x10, (byte) 0x11, (byte) 0x12, (byte) 0x13,
(byte) 0x14, (byte) 0x15, (byte) 0x16, (byte) 0x17, (byte) 0x18,
(byte) 0x19, (byte) 0x1a, (byte) 0x1b, (byte) 0x1c, (byte) 0x1d,
(byte) 0x1e, (byte) 0x1f, (byte) 0x20, (byte) 0x21, (byte) 0x22,
(byte) 0x23, (byte) 0x24, (byte) 0x25, (byte) 0x26, (byte) 0x27,
(byte) 0x28, (byte) 0x29, (byte) 0x2a, (byte) 0x2b, (byte) 0x2c,
(byte) 0x2d, (byte) 0x2e, (byte) 0x2f, (byte) 0x30, (byte) 0x31,
(byte) 0x32, (byte) 0x33, (byte) 0x34, (byte) 0x35, (byte) 0x36,
(byte) 0x37, (byte) 0x38, (byte) 0x39, (byte) 0x3a, (byte) 0x3b,
(byte) 0x3c, (byte) 0x3d, (byte) 0x3e, (byte) 0x3f, (byte) 0x40,
(byte) 0x41, (byte) 0x42, (byte) 0x43, (byte) 0x44, (byte) 0x45,
(byte) 0x46, (byte) 0x47, (byte) 0x48, (byte) 0x49, (byte) 0x4a,
(byte) 0x4b, (byte) 0x4c, (byte) 0x4d, (byte) 0x4e, (byte) 0x4f};
byte[] salt = {(byte) 0x60, (byte) 0x61, (byte) 0x62, (byte) 0x63, (byte) 0x64,
(byte) 0x65, (byte) 0x66, (byte) 0x67, (byte) 0x68, (byte) 0x69,
(byte) 0x6a, (byte) 0x6b, (byte) 0x6c, (byte) 0x6d, (byte) 0x6e,
(byte) 0x6f, (byte) 0x70, (byte) 0x71, (byte) 0x72, (byte) 0x73,
(byte) 0x74, (byte) 0x75, (byte) 0x76, (byte) 0x77, (byte) 0x78,
(byte) 0x79, (byte) 0x7a, (byte) 0x7b, (byte) 0x7c, (byte) 0x7d,
(byte) 0x7e, (byte) 0x7f, (byte) 0x80, (byte) 0x81, (byte) 0x82,
(byte) 0x83, (byte) 0x84, (byte) 0x85, (byte) 0x86, (byte) 0x87,
(byte) 0x88, (byte) 0x89, (byte) 0x8a, (byte) 0x8b, (byte) 0x8c,
(byte) 0x8d, (byte) 0x8e, (byte) 0x8f, (byte) 0x90, (byte) 0x91,
(byte) 0x92, (byte) 0x93, (byte) 0x94, (byte) 0x95, (byte) 0x96,
(byte) 0x97, (byte) 0x98, (byte) 0x99, (byte) 0x9a, (byte) 0x9b,
(byte) 0x9c, (byte) 0x9d, (byte) 0x9e, (byte) 0x9f, (byte) 0xa0,
(byte) 0xa1, (byte) 0xa2, (byte) 0xa3, (byte) 0xa4, (byte) 0xa5,
(byte) 0xa6, (byte) 0xa7, (byte) 0xa8, (byte) 0xa9, (byte) 0xaa,
(byte) 0xab, (byte) 0xac, (byte) 0xad, (byte) 0xae, (byte) 0xaf};
byte[] info = {(byte) 0xb0, (byte) 0xb1, (byte) 0xb2, (byte) 0xb3, (byte) 0xb4,
(byte) 0xb5, (byte) 0xb6, (byte) 0xb7, (byte) 0xb8, (byte) 0xb9,
(byte) 0xba, (byte) 0xbb, (byte) 0xbc, (byte) 0xbd, (byte) 0xbe,
(byte) 0xbf, (byte) 0xc0, (byte) 0xc1, (byte) 0xc2, (byte) 0xc3,
(byte) 0xc4, (byte) 0xc5, (byte) 0xc6, (byte) 0xc7, (byte) 0xc8,
(byte) 0xc9, (byte) 0xca, (byte) 0xcb, (byte) 0xcc, (byte) 0xcd,
(byte) 0xce, (byte) 0xcf, (byte) 0xd0, (byte) 0xd1, (byte) 0xd2,
(byte) 0xd3, (byte) 0xd4, (byte) 0xd5, (byte) 0xd6, (byte) 0xd7,
(byte) 0xd8, (byte) 0xd9, (byte) 0xda, (byte) 0xdb, (byte) 0xdc,
(byte) 0xdd, (byte) 0xde, (byte) 0xdf, (byte) 0xe0, (byte) 0xe1,
(byte) 0xe2, (byte) 0xe3, (byte) 0xe4, (byte) 0xe5, (byte) 0xe6,
(byte) 0xe7, (byte) 0xe8, (byte) 0xe9, (byte) 0xea, (byte) 0xeb,
(byte) 0xec, (byte) 0xed, (byte) 0xee, (byte) 0xef, (byte) 0xf0,
(byte) 0xf1, (byte) 0xf2, (byte) 0xf3, (byte) 0xf4, (byte) 0xf5,
(byte) 0xf6, (byte) 0xf7, (byte) 0xf8, (byte) 0xf9, (byte) 0xfa,
(byte) 0xfb, (byte) 0xfc, (byte) 0xfd, (byte) 0xfe, (byte) 0xff};
byte[] okm = {(byte) 0xb1, (byte) 0x1e, (byte) 0x39, (byte) 0x8d, (byte) 0xc8,
(byte) 0x03, (byte) 0x27, (byte) 0xa1, (byte) 0xc8, (byte) 0xe7,
(byte) 0xf7, (byte) 0x8c, (byte) 0x59, (byte) 0x6a, (byte) 0x49,
(byte) 0x34, (byte) 0x4f, (byte) 0x01, (byte) 0x2e, (byte) 0xda,
(byte) 0x2d, (byte) 0x4e, (byte) 0xfa, (byte) 0xd8, (byte) 0xa0,
(byte) 0x50, (byte) 0xcc, (byte) 0x4c, (byte) 0x19, (byte) 0xaf,
(byte) 0xa9, (byte) 0x7c, (byte) 0x59, (byte) 0x04, (byte) 0x5a,
(byte) 0x99, (byte) 0xca, (byte) 0xc7, (byte) 0x82, (byte) 0x72,
(byte) 0x71, (byte) 0xcb, (byte) 0x41, (byte) 0xc6, (byte) 0x5e,
(byte) 0x59, (byte) 0x0e, (byte) 0x09, (byte) 0xda, (byte) 0x32,
(byte) 0x75, (byte) 0x60, (byte) 0x0c, (byte) 0x2f, (byte) 0x09,
(byte) 0xb8, (byte) 0x36, (byte) 0x77, (byte) 0x93, (byte) 0xa9,
(byte) 0xac, (byte) 0xa3, (byte) 0xdb, (byte) 0x71, (byte) 0xcc,
(byte) 0x30, (byte) 0xc5, (byte) 0x81, (byte) 0x79, (byte) 0xec,
(byte) 0x3e, (byte) 0x87, (byte) 0xc1, (byte) 0x4c, (byte) 0x01,
(byte) 0xd5, (byte) 0xc1, (byte) 0xf3, (byte) 0x43, (byte) 0x4f,
(byte) 0x1d, (byte) 0x87};
}
public void testVectorV2() {
byte[] ikm = {0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
0x0b, 0x0b};
@ -36,7 +138,7 @@ public class HKDFTest extends AndroidTestCase {
(byte)0xa9, (byte)0xfd, (byte)0xa8, (byte)0x99, (byte)0xda,
(byte)0xeb, (byte)0xec};
DerivedSecrets derivedSecrets = new HKDF().deriveSecrets(ikm, salt, info);
DerivedSecrets derivedSecrets = HKDF.createFor(2).deriveSecrets(ikm, salt, info);
assertTrue(Arrays.equals(derivedSecrets.getCipherKey().getEncoded(), expectedOutputOne));
assertTrue(Arrays.equals(derivedSecrets.getMacKey().getEncoded(), expectedOutputTwo));

View file

@ -2,6 +2,7 @@ package org.whispersystems.test.ratchet;
import android.test.AndroidTestCase;
import org.whispersystems.libaxolotl.kdf.HKDF;
import org.whispersystems.libaxolotl.ratchet.ChainKey;
import java.security.NoSuchAlgorithmException;
@ -9,7 +10,7 @@ import java.util.Arrays;
public class ChainKeyTest extends AndroidTestCase {
public void testChainKeyDerivation() throws NoSuchAlgorithmException {
public void testChainKeyDerivationV2() throws NoSuchAlgorithmException {
byte[] seed = {(byte) 0x8a, (byte) 0xb7, (byte) 0x2d, (byte) 0x6f, (byte) 0x4c,
(byte) 0xc5, (byte) 0xac, (byte) 0x0d, (byte) 0x38, (byte) 0x7e,
@ -43,7 +44,7 @@ public class ChainKeyTest extends AndroidTestCase {
(byte) 0xc1, (byte) 0x03, (byte) 0x42, (byte) 0xa2, (byte) 0x46,
(byte) 0xd1, (byte) 0x5d};
ChainKey chainKey = new ChainKey(seed, 0);
ChainKey chainKey = new ChainKey(HKDF.createFor(2), seed, 0);
assertTrue(Arrays.equals(chainKey.getKey(), seed));
assertTrue(Arrays.equals(chainKey.getMessageKeys().getCipherKey().getEncoded(), messageKey));

View file

@ -7,6 +7,7 @@ import org.whispersystems.libaxolotl.ecc.Curve;
import org.whispersystems.libaxolotl.ecc.ECKeyPair;
import org.whispersystems.libaxolotl.ecc.ECPrivateKey;
import org.whispersystems.libaxolotl.ecc.ECPublicKey;
import org.whispersystems.libaxolotl.kdf.HKDF;
import org.whispersystems.libaxolotl.ratchet.ChainKey;
import org.whispersystems.libaxolotl.ratchet.RootKey;
import org.whispersystems.libaxolotl.util.Pair;
@ -16,7 +17,7 @@ import java.util.Arrays;
public class RootKeyTest extends AndroidTestCase {
public void testRootKeyDerivation() throws NoSuchAlgorithmException, InvalidKeyException {
public void testRootKeyDerivationV2() throws NoSuchAlgorithmException, InvalidKeyException {
byte[] rootKeySeed = {(byte) 0x7b, (byte) 0xa6, (byte) 0xde, (byte) 0xbc, (byte) 0x2b,
(byte) 0xc1, (byte) 0xbb, (byte) 0xf9, (byte) 0x1a, (byte) 0xbb,
(byte) 0xc1, (byte) 0x36, (byte) 0x74, (byte) 0x04, (byte) 0x17,
@ -69,8 +70,8 @@ public class RootKeyTest extends AndroidTestCase {
ECPrivateKey alicePrivateKey = Curve.decodePrivatePoint(alicePrivate);
ECKeyPair aliceKeyPair = new ECKeyPair(alicePublicKey, alicePrivateKey);
ECPublicKey bobPublicKey = Curve.decodePoint(bobPublic, 0);
RootKey rootKey = new RootKey(rootKeySeed);
ECPublicKey bobPublicKey = Curve.decodePoint(bobPublic, 0);
RootKey rootKey = new RootKey(HKDF.createFor(2), rootKeySeed);
Pair<RootKey, ChainKey> rootKeyChainKeyPair = rootKey.createChain(bobPublicKey, aliceKeyPair);
RootKey nextRootKey = rootKeyChainKeyPair.first();

View file

@ -24,7 +24,7 @@ import java.security.NoSuchAlgorithmException;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
public class HKDF {
public abstract class HKDF {
private static final int HASH_OUTPUT_SIZE = 32;
private static final int KEY_MATERIAL_SIZE = 64;
@ -32,6 +32,14 @@ public class HKDF {
private static final int CIPHER_KEYS_OFFSET = 0;
private static final int MAC_KEYS_OFFSET = 32;
public static HKDF createFor(int messageVersion) {
switch (messageVersion) {
case 2: return new HKDFv2();
case 3: return new HKDFv3();
default: throw new AssertionError("Unknown version: " + messageVersion);
}
}
public DerivedSecrets deriveSecrets(byte[] inputKeyMaterial, byte[] info) {
byte[] salt = new byte[HASH_OUTPUT_SIZE];
return deriveSecrets(inputKeyMaterial, salt, info);
@ -64,9 +72,7 @@ public class HKDF {
Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(salt, "HmacSHA256"));
return mac.doFinal(inputKeyMaterial);
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
} catch (InvalidKeyException e) {
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new AssertionError(e);
}
}
@ -77,7 +83,7 @@ public class HKDF {
byte[] mixin = new byte[0];
ByteArrayOutputStream results = new ByteArrayOutputStream();
for (int i=0;i<iterations;i++) {
for (int i= getIterationStartOffset();i<iterations + getIterationEndOffset();i++) {
Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(prk, "HmacSHA256"));
@ -94,13 +100,12 @@ public class HKDF {
}
return results.toByteArray();
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
} catch (InvalidKeyException e) {
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new AssertionError(e);
}
}
protected abstract int getIterationStartOffset();
protected abstract int getIterationEndOffset();
}

View file

@ -0,0 +1,13 @@
package org.whispersystems.libaxolotl.kdf;
public class HKDFv2 extends HKDF {
@Override
protected int getIterationStartOffset() {
return 0;
}
@Override
protected int getIterationEndOffset() {
return 0;
}
}

View file

@ -0,0 +1,13 @@
package org.whispersystems.libaxolotl.kdf;
public class HKDFv3 extends HKDF {
@Override
protected int getIterationStartOffset() {
return 1;
}
@Override
protected int getIterationEndOffset() {
return 1;
}
}

View file

@ -34,8 +34,8 @@ public class KeyExchangeMessage {
ECPublicKey baseKey, ECPublicKey ephemeralKey,
IdentityKey identityKey)
{
this.supportedVersion = CiphertextMessage.CURRENT_VERSION;
this.version = CiphertextMessage.CURRENT_VERSION;
this.supportedVersion = 2;
this.version = 2;
this.sequence = sequence;
this.flags = flags;
this.baseKey = baseKey;

View file

@ -31,10 +31,12 @@ public class ChainKey {
private static final byte[] MESSAGE_KEY_SEED = {0x01};
private static final byte[] CHAIN_KEY_SEED = {0x02};
private final HKDF kdf;
private final byte[] key;
private final int index;
public ChainKey(byte[] key, int index) {
public ChainKey(HKDF kdf, byte[] key, int index) {
this.kdf = kdf;
this.key = key;
this.index = index;
}
@ -49,11 +51,10 @@ public class ChainKey {
public ChainKey getNextChainKey() {
byte[] nextKey = getBaseMaterial(CHAIN_KEY_SEED);
return new ChainKey(nextKey, index + 1);
return new ChainKey(kdf, nextKey, index + 1);
}
public MessageKeys getMessageKeys() {
HKDF kdf = new HKDF();
byte[] inputKeyMaterial = getBaseMaterial(MESSAGE_KEY_SEED);
DerivedSecrets keyMaterial = kdf.deriveSecrets(inputKeyMaterial, "WhisperMessageKeys".getBytes());

View file

@ -108,8 +108,9 @@ public class RatchetingSession {
throws InvalidKeyException
{
try {
byte[] discontinuity = new byte[32];
ByteArrayOutputStream secrets = new ByteArrayOutputStream();
HKDF kdf = HKDF.createFor(sessionVersion);
byte[] discontinuity = new byte[32];
ByteArrayOutputStream secrets = new ByteArrayOutputStream();
if (sessionVersion >= 3) {
Arrays.fill(discontinuity, (byte) 0xFF);
@ -130,11 +131,10 @@ public class RatchetingSession {
secrets.write(Curve.calculateAgreement(theirPreKey, ourPreKey.getPrivateKey()));
}
DerivedSecrets derivedSecrets = new HKDF().deriveSecrets(secrets.toByteArray(),
"WhisperText".getBytes());
DerivedSecrets derivedSecrets = kdf.deriveSecrets(secrets.toByteArray(), "WhisperText".getBytes());
return new Pair<>(new RootKey(derivedSecrets.getCipherKey().getEncoded()),
new ChainKey(derivedSecrets.getMacKey().getEncoded(), 0));
return new Pair<>(new RootKey(kdf, derivedSecrets.getCipherKey().getEncoded()),
new ChainKey(kdf, derivedSecrets.getMacKey().getEncoded(), 0));
} catch (IOException e) {
throw new AssertionError(e);
}

View file

@ -26,9 +26,11 @@ import org.whispersystems.libaxolotl.util.Pair;
public class RootKey {
private final HKDF kdf;
private final byte[] key;
public RootKey(byte[] key) {
public RootKey(HKDF kdf, byte[] key) {
this.kdf = kdf;
this.key = key;
}
@ -39,11 +41,10 @@ public class RootKey {
public Pair<RootKey, ChainKey> createChain(ECPublicKey theirEphemeral, ECKeyPair ourEphemeral)
throws InvalidKeyException
{
HKDF kdf = new HKDF();
byte[] sharedSecret = Curve.calculateAgreement(theirEphemeral, ourEphemeral.getPrivateKey());
DerivedSecrets keys = kdf.deriveSecrets(sharedSecret, key, "WhisperRatchet".getBytes());
RootKey newRootKey = new RootKey(keys.getCipherKey().getEncoded());
ChainKey newChainKey = new ChainKey(keys.getMacKey().getEncoded(), 0);
RootKey newRootKey = new RootKey(kdf, keys.getCipherKey().getEncoded());
ChainKey newChainKey = new ChainKey(kdf, keys.getMacKey().getEncoded(), 0);
return new Pair<>(newRootKey, newChainKey);
}

View file

@ -28,6 +28,7 @@ import org.whispersystems.libaxolotl.ecc.Curve;
import org.whispersystems.libaxolotl.ecc.ECKeyPair;
import org.whispersystems.libaxolotl.ecc.ECPrivateKey;
import org.whispersystems.libaxolotl.ecc.ECPublicKey;
import org.whispersystems.libaxolotl.kdf.HKDF;
import org.whispersystems.libaxolotl.ratchet.ChainKey;
import org.whispersystems.libaxolotl.ratchet.MessageKeys;
import org.whispersystems.libaxolotl.ratchet.RootKey;
@ -137,7 +138,8 @@ public class SessionState {
}
public RootKey getRootKey() {
return new RootKey(this.sessionStructure.getRootKey().toByteArray());
return new RootKey(HKDF.createFor(getSessionVersion()),
this.sessionStructure.getRootKey().toByteArray());
}
public void setRootKey(RootKey rootKey) {
@ -180,7 +182,7 @@ public class SessionState {
ECPublicKey chainSenderEphemeral = Curve.decodePoint(receiverChain.getSenderEphemeral().toByteArray(), 0);
if (chainSenderEphemeral.equals(senderEphemeral)) {
return new Pair<Chain,Integer>(receiverChain,index);
return new Pair<>(receiverChain,index);
}
} catch (InvalidKeyException e) {
Log.w("SessionRecordV2", e);
@ -199,7 +201,8 @@ public class SessionState {
if (receiverChain == null) {
return null;
} else {
return new ChainKey(receiverChain.getChainKey().getKey().toByteArray(),
return new ChainKey(HKDF.createFor(getSessionVersion()),
receiverChain.getChainKey().getKey().toByteArray(),
receiverChain.getChainKey().getIndex());
}
}
@ -241,7 +244,8 @@ public class SessionState {
public ChainKey getSenderChainKey() {
Chain.ChainKey chainKeyStructure = sessionStructure.getSenderChain().getChainKey();
return new ChainKey(chainKeyStructure.getKey().toByteArray(), chainKeyStructure.getIndex());
return new ChainKey(HKDF.createFor(getSessionVersion()),
chainKeyStructure.getKey().toByteArray(), chainKeyStructure.getIndex());
}
@ -284,7 +288,7 @@ public class SessionState {
return null;
}
List<Chain.MessageKey> messageKeyList = new LinkedList<Chain.MessageKey>(chain.getMessageKeysList());
List<Chain.MessageKey> messageKeyList = new LinkedList<>(chain.getMessageKeysList());
Iterator<Chain.MessageKey> messageKeyIterator = messageKeyList.iterator();
MessageKeys result = null;