diff --git a/hpke/algs.go b/hpke/algs.go index a9fbc6612..865fb8c3c 100644 --- a/hpke/algs.go +++ b/hpke/algs.go @@ -16,6 +16,7 @@ import ( "github.com/cloudflare/circl/ecc/p384" "github.com/cloudflare/circl/kem" "github.com/cloudflare/circl/kem/kyber/kyber768" + "github.com/cloudflare/circl/kem/xwing" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/hkdf" ) @@ -39,6 +40,8 @@ const ( // KEM_X25519_KYBER768_DRAFT00 is a hybrid KEM built on DHKEM(X25519, HKDF-SHA256) // and Kyber768Draft00 KEM_X25519_KYBER768_DRAFT00 KEM = 0x30 + // KEM_XWING is a hybrid KEM using X25519 and ML-KEM-768. + KEM_XWING KEM = 0x647a ) // IsValid returns true if the KEM identifier is supported by the HPKE package. @@ -49,7 +52,8 @@ func (k KEM) IsValid() bool { KEM_P521_HKDF_SHA512, KEM_X25519_HKDF_SHA256, KEM_X448_HKDF_SHA512, - KEM_X25519_KYBER768_DRAFT00: + KEM_X25519_KYBER768_DRAFT00, + KEM_XWING: return true default: return false @@ -58,7 +62,7 @@ func (k KEM) IsValid() bool { // Scheme returns an instance of a KEM that supports authentication. Panics if // the KEM identifier is invalid. -func (k KEM) Scheme() kem.AuthScheme { +func (k KEM) Scheme() kem.Scheme { switch k { case KEM_P256_HKDF_SHA256: return dhkemp256hkdfsha256 @@ -72,6 +76,8 @@ func (k KEM) Scheme() kem.AuthScheme { return dhkemx448hkdfsha512 case KEM_X25519_KYBER768_DRAFT00: return hybridkemX25519Kyber768 + case KEM_XWING: + return kemXwing default: panic(ErrInvalidKEM) } @@ -237,6 +243,7 @@ var ( dhkemp256hkdfsha256, dhkemp384hkdfsha384, dhkemp521hkdfsha512 shortKEM dhkemx25519hkdfsha256, dhkemx448hkdfsha512 xKEM hybridkemX25519Kyber768 hybridKEM + kemXwing genericNoAuthKEM ) func init() { @@ -275,4 +282,7 @@ func init() { hybridkemX25519Kyber768.kemBase.Hash = crypto.SHA256 hybridkemX25519Kyber768.kemA = dhkemx25519hkdfsha256 hybridkemX25519Kyber768.kemB = kyber768.Scheme() + + kemXwing.Scheme = xwing.Scheme() + kemXwing.name = "HPKE_KEM_XWING" } diff --git a/hpke/genericnoauthkem.go b/hpke/genericnoauthkem.go new file mode 100644 index 000000000..88e12352f --- /dev/null +++ b/hpke/genericnoauthkem.go @@ -0,0 +1,27 @@ +package hpke + +// Shim to use generic KEM (kem.Scheme) as HPKE KEM. + +import ( + "github.com/cloudflare/circl/internal/sha3" + "github.com/cloudflare/circl/kem" +) + +// genericNoAuthKEM wraps a generic KEM (kem.Scheme) to be used as a HPKE KEM. +type genericNoAuthKEM struct { + kem.Scheme + name string +} + +func (h genericNoAuthKEM) Name() string { return h.name } + +// HPKE requires DeriveKeyPair() to take any seed larger than the private key +// size, whereas typical KEMs expect a specific seed size. We'll just use +// SHAKE256 to hash it to the right size as in X-Wing. +func (h genericNoAuthKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { + seed2 := make([]byte, h.Scheme.SeedSize()) + hh := sha3.NewShake256() + _, _ = hh.Write(seed) + _, _ = hh.Read(seed2) + return h.Scheme.DeriveKeyPair(seed2) +} diff --git a/hpke/hpke.go b/hpke/hpke.go index 4075b285e..ccd50d2cd 100644 --- a/hpke/hpke.go +++ b/hpke/hpke.go @@ -224,7 +224,12 @@ func (s *Sender) allSetup(rnd io.Reader) ([]byte, Sealer, error) { case modeBase, modePSK: enc, ss, err = scheme.EncapsulateDeterministically(s.pkR, seed) case modeAuth, modeAuthPSK: - enc, ss, err = scheme.AuthEncapsulateDeterministically(s.pkR, s.skS, seed) + authScheme, ok := scheme.(kem.AuthScheme) + if !ok { + return nil, nil, ErrInvalidAuthKEM + } + + enc, ss, err = authScheme.AuthEncapsulateDeterministically(s.pkR, s.skS, seed) } if err != nil { return nil, nil, err @@ -246,7 +251,12 @@ func (r *Receiver) allSetup() (Opener, error) { case modeBase, modePSK: ss, err = scheme.Decapsulate(r.skR, r.enc) case modeAuth, modeAuthPSK: - ss, err = scheme.AuthDecapsulate(r.skR, r.enc, r.pkS) + authScheme, ok := scheme.(kem.AuthScheme) + if !ok { + return nil, ErrInvalidAuthKEM + } + + ss, err = authScheme.AuthDecapsulate(r.skR, r.enc, r.pkS) } if err != nil { return nil, err @@ -263,6 +273,7 @@ var ( ErrInvalidHPKESuite = errors.New("hpke: invalid HPKE suite") ErrInvalidKDF = errors.New("hpke: invalid KDF identifier") ErrInvalidKEM = errors.New("hpke: invalid KEM identifier") + ErrInvalidAuthKEM = errors.New("hpke: KEM does not support Auth mode") ErrInvalidAEAD = errors.New("hpke: invalid AEAD identifier") ErrInvalidKEMPublicKey = errors.New("hpke: invalid KEM public key") ErrInvalidKEMPrivateKey = errors.New("hpke: invalid KEM private key") diff --git a/hpke/hpke_test.go b/hpke/hpke_test.go index b679c825c..23d782af0 100644 --- a/hpke/hpke_test.go +++ b/hpke/hpke_test.go @@ -160,6 +160,7 @@ func BenchmarkHpkeRoundTrip(b *testing.B) { }{ {hpke.KEM_X25519_HKDF_SHA256, hpke.KDF_HKDF_SHA256, hpke.AEAD_AES128GCM}, {hpke.KEM_X25519_KYBER768_DRAFT00, hpke.KDF_HKDF_SHA256, hpke.AEAD_AES128GCM}, + {hpke.KEM_XWING, hpke.KDF_HKDF_SHA256, hpke.AEAD_AES128GCM}, } for _, test := range tests { runHpkeBenchmark(b, test.kem, test.kdf, test.aead) diff --git a/kem/schemes/schemes.go b/kem/schemes/schemes.go index 747d99e9d..836948259 100644 --- a/kem/schemes/schemes.go +++ b/kem/schemes/schemes.go @@ -29,6 +29,7 @@ import ( "github.com/cloudflare/circl/kem/mlkem/mlkem1024" "github.com/cloudflare/circl/kem/mlkem/mlkem512" "github.com/cloudflare/circl/kem/mlkem/mlkem768" + "github.com/cloudflare/circl/kem/xwing" ) var allSchemes = [...]kem.Scheme{ @@ -50,6 +51,7 @@ var allSchemes = [...]kem.Scheme{ hybrid.Kyber1024X448(), hybrid.P256Kyber768Draft00(), hybrid.X25519MLKEM768(), + xwing.Scheme(), } var allSchemeNames map[string]kem.Scheme diff --git a/kem/schemes/schemes_test.go b/kem/schemes/schemes_test.go index d9caf70a9..e81896c53 100644 --- a/kem/schemes/schemes_test.go +++ b/kem/schemes/schemes_test.go @@ -160,4 +160,5 @@ func Example_schemes() { // Kyber1024-X448 // P256Kyber768Draft00 // X25519MLKEM768 + // X-Wing } diff --git a/kem/xwing/scheme.go b/kem/xwing/scheme.go new file mode 100644 index 000000000..6c01477b3 --- /dev/null +++ b/kem/xwing/scheme.go @@ -0,0 +1,140 @@ +package xwing + +import ( + "bytes" + cryptoRand "crypto/rand" + "crypto/subtle" + + "github.com/cloudflare/circl/kem" + "github.com/cloudflare/circl/kem/mlkem/mlkem768" +) + +// This file contains the boilerplate code to connect X-Wing to the +// generic KEM API. + +// Returns the generic KEM interface for X-Wing PQ/T hybrid KEM. +func Scheme() kem.Scheme { return scheme{} } + +type scheme struct{} + +func (scheme) Name() string { return "X-Wing" } +func (scheme) PublicKeySize() int { return PublicKeySize } +func (scheme) PrivateKeySize() int { return PrivateKeySize } +func (scheme) SeedSize() int { return SeedSize } +func (scheme) EncapsulationSeedSize() int { return EncapsulationSeedSize } +func (scheme) SharedKeySize() int { return SharedKeySize } +func (scheme) CiphertextSize() int { return CiphertextSize } +func (*PrivateKey) Scheme() kem.Scheme { return scheme{} } +func (*PublicKey) Scheme() kem.Scheme { return scheme{} } + +func (sch scheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) { + var seed [EncapsulationSeedSize]byte + _, err = cryptoRand.Read(seed[:]) + if err != nil { + return + } + return sch.EncapsulateDeterministically(pk, seed[:]) +} + +func (scheme) EncapsulateDeterministically( + pk kem.PublicKey, seed []byte, +) ([]byte, []byte, error) { + if len(seed) != EncapsulationSeedSize { + return nil, nil, kem.ErrSeedSize + } + pub, ok := pk.(*PublicKey) + if !ok { + return nil, nil, kem.ErrTypeMismatch + } + var ( + ct [CiphertextSize]byte + ss [SharedKeySize]byte + ) + pub.EncapsulateTo(ct[:], ss[:], seed) + return ct[:], ss[:], nil +} + +func (scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) { + var pk PublicKey + if len(buf) != PublicKeySize { + return nil, kem.ErrPubKeySize + } + + if err := pk.Unpack(buf); err != nil { + return nil, err + } + return &pk, nil +} + +func (scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) { + var sk PrivateKey + if len(buf) != PrivateKeySize { + return nil, kem.ErrPrivKeySize + } + + sk.Unpack(buf) + return &sk, nil +} + +func (sk *PrivateKey) MarshalBinary() ([]byte, error) { + var ret [PrivateKeySize]byte + sk.Pack(ret[:]) + return ret[:], nil +} + +func (sk *PrivateKey) Equal(other kem.PrivateKey) bool { + oth, ok := other.(*PrivateKey) + if !ok { + return false + } + return sk.m.Equal(&oth.m) && + subtle.ConstantTimeCompare(oth.x[:], sk.x[:]) == 1 +} + +func (sk *PrivateKey) Public() kem.PublicKey { + var pk PublicKey + pk.m = *(sk.m.Public().(*mlkem768.PublicKey)) + pk.x = sk.xpk + return &pk +} + +func (pk *PublicKey) Equal(other kem.PublicKey) bool { + oth, ok := other.(*PublicKey) + if !ok { + return false + } + return pk.m.Equal(&oth.m) && bytes.Equal(pk.x[:], oth.x[:]) +} + +func (pk *PublicKey) MarshalBinary() ([]byte, error) { + var ret [PublicKeySize]byte + pk.Pack(ret[:]) + return ret[:], nil +} + +func (scheme) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { + sk, pk := DeriveKeyPair(seed) + return pk, sk +} + +func (scheme) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) { + sk, pk, err := GenerateKeyPair(nil) + return pk, sk, err +} + +func (scheme) Decapsulate(sk kem.PrivateKey, ct []byte) ([]byte, error) { + if len(ct) != CiphertextSize { + return nil, kem.ErrCiphertextSize + } + + var ss [SharedKeySize]byte + + priv, ok := sk.(*PrivateKey) + if !ok { + return nil, kem.ErrTypeMismatch + } + + priv.DecapsulateTo(ss[:], ct[:]) + + return ss[:], nil +} diff --git a/kem/xwing/xwing.go b/kem/xwing/xwing.go new file mode 100644 index 000000000..7e2807907 --- /dev/null +++ b/kem/xwing/xwing.go @@ -0,0 +1,310 @@ +// Package xwing implements the X-Wing PQ/T hybrid KEM +// +// https://datatracker.ietf.org/doc/draft-connolly-cfrg-xwing-kem +// +// Implements the final version (-05). +package xwing + +import ( + cryptoRand "crypto/rand" + "errors" + "io" + + "github.com/cloudflare/circl/dh/x25519" + "github.com/cloudflare/circl/internal/sha3" + "github.com/cloudflare/circl/kem" + "github.com/cloudflare/circl/kem/mlkem/mlkem768" +) + +// An X-Wing private key. +type PrivateKey struct { + seed [32]byte + m mlkem768.PrivateKey + x x25519.Key + xpk x25519.Key +} + +// An X-Wing public key. +type PublicKey struct { + m mlkem768.PublicKey + x x25519.Key +} + +const ( + // Size of a seed of a keypair + SeedSize = 32 + + // Size of an X-Wing public key + PublicKeySize = 1216 + + // Size of an X-Wing private key + PrivateKeySize = 32 + + // Size of the seed passed to EncapsulateTo + EncapsulationSeedSize = 64 + + // Size of the established shared key + SharedKeySize = 32 + + // Size of an X-Wing ciphertext. + CiphertextSize = 1120 +) + +func combiner( + out []byte, + ssm *[mlkem768.SharedKeySize]byte, + ssx *x25519.Key, + ctx *x25519.Key, + pkx *x25519.Key, +) { + h := sha3.New256() + _, _ = h.Write(ssm[:]) + _, _ = h.Write(ssx[:]) + _, _ = h.Write(ctx[:]) + _, _ = h.Write(pkx[:]) + + // \./ + // /^\ + _, _ = h.Write([]byte(`\.//^\`)) + + _, _ = h.Read(out[:]) +} + +// Packs sk to buf. +// +// Panics if buf is not of size PrivateKeySize +func (sk *PrivateKey) Pack(buf []byte) { + if len(buf) != PrivateKeySize { + panic(kem.ErrPrivKeySize) + } + copy(buf, sk.seed[:]) +} + +// Packs pk to buf. +// +// Panics if buf is not of size PublicKeySize. +func (pk *PublicKey) Pack(buf []byte) { + if len(buf) != PublicKeySize { + panic(kem.ErrPubKeySize) + } + pk.m.Pack(buf[:mlkem768.PublicKeySize]) + copy(buf[mlkem768.PublicKeySize:], pk.x[:]) +} + +// DeriveKeyPair derives a public/private keypair deterministically +// from the given seed. +// +// Panics if seed is not of length SeedSize. +func DeriveKeyPair(seed []byte) (*PrivateKey, *PublicKey) { + var ( + sk PrivateKey + pk PublicKey + ) + + deriveKeyPair(seed, &sk, &pk) + + return &sk, &pk +} + +func deriveKeyPair(seed []byte, sk *PrivateKey, pk *PublicKey) { + if len(seed) != SeedSize { + panic(kem.ErrSeedSize) + } + + var seedm [mlkem768.KeySeedSize]byte + + copy(sk.seed[:], seed) + + h := sha3.NewShake256() + _, _ = h.Write(seed) + _, _ = h.Read(seedm[:]) + _, _ = h.Read(sk.x[:]) + + pkm, skm := mlkem768.NewKeyFromSeed(seedm[:]) + sk.m = *skm + pk.m = *pkm + + x25519.KeyGen(&pk.x, &sk.x) + sk.xpk = pk.x +} + +// DeriveKeyPairPacked derives a keypair like DeriveKeyPair, and +// returns them packed. +func DeriveKeyPairPacked(seed []byte) ([]byte, []byte) { + sk, pk := DeriveKeyPair(seed) + var ( + ppk [PublicKeySize]byte + psk [PrivateKeySize]byte + ) + pk.Pack(ppk[:]) + sk.Pack(psk[:]) + return psk[:], ppk[:] +} + +// GenerateKeyPair generates public and private keys using entropy from rand. +// If rand is nil, crypto/rand.Reader will be used. +func GenerateKeyPair(rand io.Reader) (*PrivateKey, *PublicKey, error) { + var seed [SeedSize]byte + if rand == nil { + rand = cryptoRand.Reader + } + _, err := io.ReadFull(rand, seed[:]) + if err != nil { + return nil, nil, err + } + sk, pk := DeriveKeyPair(seed[:]) + return sk, pk, nil +} + +// GenerateKeyPairPacked generates a keypair like GenerateKeyPair, and +// returns them packed. +func GenerateKeyPairPacked(rand io.Reader) ([]byte, []byte, error) { + sk, pk, err := GenerateKeyPair(rand) + if err != nil { + return nil, nil, err + } + var ( + ppk [PublicKeySize]byte + psk [PrivateKeySize]byte + ) + pk.Pack(ppk[:]) + sk.Pack(psk[:]) + return psk[:], ppk[:], nil +} + +// Encapsulate generates a shared key and ciphertext that contains it +// for the public key pk using randomness from seed. +// +// seed may be nil, in which case crypto/rand.Reader is used. +// +// Warning: note that the order of the returned ss and ct matches the +// X-Wing standard, which is the reverse of the Circl KEM API. +// +// Returns ErrPubKey if ML-KEM encapsulation key check fails. +// +// Panics if pk is not of size PublicKeySize, or randomness could not +// be read from crypto/rand.Reader. +func Encapsulate(pk, seed []byte) (ss, ct []byte, err error) { + var pub PublicKey + if err := pub.Unpack(pk); err != nil { + return nil, nil, err + } + ct = make([]byte, CiphertextSize) + ss = make([]byte, SharedKeySize) + pub.EncapsulateTo(ct, ss, seed) + return ss, ct, nil +} + +// Decapsulate computes the shared key which is encapsulated in ct +// for the private key sk. +// +// Panics if sk or ct are not of length PrivateKeySize and CiphertextSize +// respectively. +func Decapsulate(ct, sk []byte) (ss []byte) { + var priv PrivateKey + priv.Unpack(sk) + ss = make([]byte, SharedKeySize) + priv.DecapsulateTo(ss, ct) + return ss +} + +// Raised when passing a byte slice of the wrong size for the shared +// secret to the EncapsulateTo or DecapsulateTo functions. +var ErrSharedKeySize = errors.New("wrong size for shared key") + +// EncapsulateTo generates a shared key and ciphertext that contains it +// for the public key using randomness from seed and writes the shared key +// to ss and ciphertext to ct. +// +// Panics if ss, ct or seed are not of length SharedKeySize, CiphertextSize +// and EncapsulationSeedSize respectively. +// +// seed may be nil, in which case crypto/rand.Reader is used to generate one. +func (pk *PublicKey) EncapsulateTo(ct, ss, seed []byte) { + if seed == nil { + seed = make([]byte, EncapsulationSeedSize) + if _, err := cryptoRand.Read(seed[:]); err != nil { + panic(err) + } + } else { + if len(seed) != EncapsulationSeedSize { + panic(kem.ErrSeedSize) + } + } + + if len(ct) != CiphertextSize { + panic(kem.ErrCiphertextSize) + } + + if len(ss) != SharedKeySize { + panic(ErrSharedKeySize) + } + + var ( + seedm [32]byte + ekx x25519.Key + ctx x25519.Key + ssx x25519.Key + ssm [mlkem768.SharedKeySize]byte + ) + + copy(seedm[:], seed[:32]) + copy(ekx[:], seed[32:]) + + x25519.KeyGen(&ctx, &ekx) + x25519.Shared(&ssx, &ekx, &pk.x) + pk.m.EncapsulateTo(ct[:mlkem768.CiphertextSize], ssm[:], seedm[:]) + + combiner(ss, &ssm, &ssx, &ctx, &pk.x) + copy(ct[mlkem768.CiphertextSize:], ctx[:]) +} + +// DecapsulateTo computes the shared key which is encapsulated in ct +// for the private key. +// +// Panics if ct or ss are not of length CiphertextSize and SharedKeySize +// respectively. +func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) { + if len(ct) != CiphertextSize { + panic(kem.ErrCiphertextSize) + } + if len(ss) != SharedKeySize { + panic(ErrSharedKeySize) + } + + ctm := ct[:mlkem768.CiphertextSize] + + var ( + ssm [mlkem768.SharedKeySize]byte + ssx x25519.Key + ctx x25519.Key + ) + + copy(ctx[:], ct[mlkem768.CiphertextSize:]) + + sk.m.DecapsulateTo(ssm[:], ctm) + x25519.Shared(&ssx, &sk.x, &ctx) + combiner(ss, &ssm, &ssx, &ctx, &sk.xpk) +} + +// Unpacks pk from buf. +// +// Panics if buf is not of size PublicKeySize. +// +// Returns ErrPubKey if pk fails the ML-KEM encapsulation key check. +func (pk *PublicKey) Unpack(buf []byte) error { + if len(buf) != PublicKeySize { + panic(kem.ErrPubKeySize) + } + + copy(pk.x[:], buf[mlkem768.PublicKeySize:]) + return pk.m.Unpack(buf[:mlkem768.PublicKeySize]) +} + +// Unpacks sk from buf. +// +// Panics if buf is not of size PrivateKeySize. +func (sk *PrivateKey) Unpack(buf []byte) { + var pk PublicKey + deriveKeyPair(buf, sk, &pk) +} diff --git a/kem/xwing/xwing_test.go b/kem/xwing/xwing_test.go new file mode 100644 index 000000000..760c75a0d --- /dev/null +++ b/kem/xwing/xwing_test.go @@ -0,0 +1,79 @@ +package xwing + +import ( + "bytes" + "fmt" + "io" + "testing" + + "github.com/cloudflare/circl/internal/sha3" +) + +func writeHex(w io.Writer, prefix string, val interface{}) { + indent := " " + width := 74 + hex := fmt.Sprintf("%x", val) + if len(prefix)+len(hex)+5 < width { + fmt.Fprintf(w, "%s %s\n", prefix, hex) + return + } + fmt.Fprintf(w, "%s\n", prefix) + for len(hex) != 0 { + var toPrint string + if len(hex) < width-len(indent) { + toPrint = hex + hex = "" + } else { + toPrint = hex[:width-len(indent)] + hex = hex[width-len(indent):] + } + fmt.Fprintf(w, "%s%s\n", indent, toPrint) + } +} + +func TestVectors(t *testing.T) { + h := sha3.NewShake128() + w := new(bytes.Buffer) + + for i := 0; i < 3; i++ { + var seed [SeedSize]byte + _, _ = h.Read(seed[:]) + writeHex(w, "seed", seed) + + sk, pk := DeriveKeyPairPacked(seed[:]) + writeHex(w, "sk", sk) + writeHex(w, "pk", pk) + + var eseed [EncapsulationSeedSize]byte + _, _ = h.Read(eseed[:]) + writeHex(w, "eseed", eseed) + + ss, ct, err := Encapsulate(pk, eseed[:]) + if err != nil { + t.Fatal(err) + } + writeHex(w, "ct", ct) + writeHex(w, "ss", ss) + + ss2 := Decapsulate(ct, sk) + if !bytes.Equal(ss, ss2) { + t.Fatal() + } + + fmt.Fprintf(w, "\n") + } + + t.Logf("%s", w.String()) + h.Reset() + _, _ = h.Write(w.Bytes()) + var cs [32]byte + _, _ = h.Read(cs[:]) + got := fmt.Sprintf("%x", cs) + + // shake128 of spec/test-vectors.txt from X-Wing spec at + // https://github.com/dconnolly/draft-connolly-cfrg-xwing-kem + want := "1bcd0057d861d6b866239936cadcaeee1ec0164dedc181c386e9e54fe46156fe" + if got != want { + t.Fatalf("%s ≠ %s", got, want) + } +}