diff --git a/kem/kyber/gen.go b/kem/kyber/gen.go index 1abca8e5e..c0c815b61 100644 --- a/kem/kyber/gen.go +++ b/kem/kyber/gen.go @@ -7,8 +7,10 @@ package main import ( "bytes" + "fmt" "go/format" "io/ioutil" + "path" "strings" "text/template" ) @@ -17,8 +19,33 @@ type Instance struct { Name string } +func (m Instance) KemName() string { + if m.NIST() { + return m.Name + } + return m.Name + ".CCAKEM" +} + +func (m Instance) NIST() bool { + return strings.HasPrefix(m.Name, "ML-KEM") +} + +func (m Instance) PkePkg() string { + if !m.NIST() { + return m.Pkg() + } + return strings.ReplaceAll(m.Pkg(), "mlkem", "kyber") +} + func (m Instance) Pkg() string { - return strings.ToLower(m.Name) + return strings.ToLower(strings.ReplaceAll(m.Name, "-", "")) +} + +func (m Instance) PkgPath() string { + if m.NIST() { + return path.Join("..", "mlkem", m.Pkg()) + } + return m.Pkg() } var ( @@ -26,6 +53,9 @@ var ( {Name: "Kyber512"}, {Name: "Kyber768"}, {Name: "Kyber1024"}, + {Name: "ML-KEM-512"}, + {Name: "ML-KEM-768"}, + {Name: "ML-KEM-1024"}, } TemplateWarning = "// Code generated from" ) @@ -51,7 +81,7 @@ func generatePackageFiles() { // Formating output code code, err := format.Source(buf.Bytes()) if err != nil { - panic("error formating code") + panic(fmt.Sprintf("error formating code: %v", err)) } res := string(code) @@ -59,7 +89,7 @@ func generatePackageFiles() { if offset == -1 { panic("Missing template warning in pkg.templ.go") } - err = ioutil.WriteFile(mode.Pkg()+"/kyber.go", []byte(res[offset:]), 0o644) + err = ioutil.WriteFile(mode.PkgPath()+"/kyber.go", []byte(res[offset:]), 0o644) if err != nil { panic(err) } diff --git a/kem/kyber/kat_test.go b/kem/kyber/kat_test.go index 73eee008d..5b6d39e86 100644 --- a/kem/kyber/kat_test.go +++ b/kem/kyber/kat_test.go @@ -7,6 +7,7 @@ import ( "bytes" "crypto/sha256" "fmt" + "strings" "testing" "github.com/cloudflare/circl/internal/nist" @@ -22,6 +23,12 @@ func TestPQCgenKATKem(t *testing.T) { {"Kyber1024", "89248f2f33f7f4f7051729111f3049c409a933ec904aedadf035f30fa5646cd5"}, {"Kyber768", "a1e122cad3c24bc51622e4c242d8b8acbcd3f618fee4220400605ca8f9ea02c2"}, {"Kyber512", "e9c2bd37133fcb40772f81559f14b1f58dccd1c816701be9ba6214d43baf4547"}, + + // TODO crossreference with standard branch of reference implementation + // once they've added the final change: domain separation in K-PKE.KeyGen(). + {"ML-KEM-512", "a30184edee53b3b009356e1e31d7f9e93ce82550e3c622d7192e387b0cc84f2e"}, + {"ML-KEM-768", "729367b590637f4a93c68d5e4a4d2e2b4454842a52c9eec503e3a0d24cb66471"}, + {"ML-KEM-1024", "3fba7327d0320cb6134badf2a1bcb963a5b3c0026c7dece8f00d6a6155e47b33"}, } for _, kat := range kats { kat := kat @@ -45,18 +52,26 @@ func testPQCgenKATKem(t *testing.T, name, expected string) { } f := sha256.New() g := nist.NewDRBG(&seed) - fmt.Fprintf(f, "# %s\n\n", name) + + // The "standard" branch reference implementation still uses Kyber + // as name instead of ML-KEM. + fmt.Fprintf(f, "# %s\n\n", strings.ReplaceAll(name, "ML-KEM-", "Kyber")) for i := 0; i < 100; i++ { g.Fill(seed[:]) fmt.Fprintf(f, "count = %d\n", i) fmt.Fprintf(f, "seed = %X\n", seed) g2 := nist.NewDRBG(&seed) - // This is not equivalent to g2.Fill(kseed[:]). As the reference - // implementation calls randombytes twice generating the keypair, - // we have to do that as well. - g2.Fill(kseed[:32]) - g2.Fill(kseed[32:]) + if strings.HasPrefix(name, "ML-KEM") { + // https://github.com/pq-crystals/kyber/commit/830e0ba1a7fdba6cde03f8139b0d41ad2102b860 + g2.Fill(kseed[:]) + } else { + // This is not equivalent to g2.Fill(kseed[:]). As the reference + // implementation calls randombytes twice generating the keypair, + // we have to do that as well. + g2.Fill(kseed[:32]) + g2.Fill(kseed[32:]) + } g2.Fill(eseed) pk, sk := scheme.DeriveKeyPair(kseed) @@ -73,6 +88,6 @@ func testPQCgenKATKem(t *testing.T, name, expected string) { fmt.Fprintf(f, "ss = %X\n\n", ss) } if fmt.Sprintf("%x", f.Sum(nil)) != expected { - t.Fatal() + t.Fatalf("%s %x %s", name, f.Sum(nil), expected) } } diff --git a/kem/kyber/kyber1024/kyber.go b/kem/kyber/kyber1024/kyber.go index 428584528..4131fef57 100644 --- a/kem/kyber/kyber1024/kyber.go +++ b/kem/kyber/kyber1024/kyber.go @@ -123,10 +123,10 @@ func (pk *PublicKey) EncapsulateTo(ct, ss []byte, seed []byte) { panic("ss must be of length SharedKeySize") } - // m = H(seed) var m [32]byte + // m = H(seed), the hash of shame h := sha3.New256() - h.Write(seed[:]) + h.Write(seed) h.Read(m[:]) // (K', r) = G(m ‖ H(pk)) @@ -194,7 +194,7 @@ func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) { // K = KDF(K''/z, H(c)) kdf := sha3.NewShake256() kdf.Write(kr2[:]) - kdf.Read(ss[:SharedKeySize]) + kdf.Read(ss) } // Packs sk to buf. diff --git a/kem/kyber/kyber512/kyber.go b/kem/kyber/kyber512/kyber.go index c250d78c6..447d135f9 100644 --- a/kem/kyber/kyber512/kyber.go +++ b/kem/kyber/kyber512/kyber.go @@ -123,10 +123,10 @@ func (pk *PublicKey) EncapsulateTo(ct, ss []byte, seed []byte) { panic("ss must be of length SharedKeySize") } - // m = H(seed) var m [32]byte + // m = H(seed), the hash of shame h := sha3.New256() - h.Write(seed[:]) + h.Write(seed) h.Read(m[:]) // (K', r) = G(m ‖ H(pk)) @@ -194,7 +194,7 @@ func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) { // K = KDF(K''/z, H(c)) kdf := sha3.NewShake256() kdf.Write(kr2[:]) - kdf.Read(ss[:SharedKeySize]) + kdf.Read(ss) } // Packs sk to buf. diff --git a/kem/kyber/kyber768/kyber.go b/kem/kyber/kyber768/kyber.go index 832d9b371..cfdc705ef 100644 --- a/kem/kyber/kyber768/kyber.go +++ b/kem/kyber/kyber768/kyber.go @@ -123,10 +123,10 @@ func (pk *PublicKey) EncapsulateTo(ct, ss []byte, seed []byte) { panic("ss must be of length SharedKeySize") } - // m = H(seed) var m [32]byte + // m = H(seed), the hash of shame h := sha3.New256() - h.Write(seed[:]) + h.Write(seed) h.Read(m[:]) // (K', r) = G(m ‖ H(pk)) @@ -194,7 +194,7 @@ func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) { // K = KDF(K''/z, H(c)) kdf := sha3.NewShake256() kdf.Write(kr2[:]) - kdf.Read(ss[:SharedKeySize]) + kdf.Read(ss) } // Packs sk to buf. diff --git a/kem/kyber/templates/pkg.templ.go b/kem/kyber/templates/pkg.templ.go index 22eb1fd74..18b9261e1 100644 --- a/kem/kyber/templates/pkg.templ.go +++ b/kem/kyber/templates/pkg.templ.go @@ -5,10 +5,14 @@ // Code generated from pkg.templ.go. DO NOT EDIT. // Package {{.Pkg}} implements the IND-CCA2 secure key encapsulation mechanism -// {{.Name}}.CCAKEM as submitted to round 3 of the NIST PQC competition and +{{ if .NIST -}} +// {{.KemName}} as defined in FIPS203. +{{- else -}} +// {{.KemName}} as submitted to round 3 of the NIST PQC competition and // described in // // https://pq-crystals.org/kyber/data/kyber-specification-round3.pdf +{{- end }} package {{.Pkg}} import ( @@ -18,7 +22,7 @@ import ( "github.com/cloudflare/circl/internal/sha3" "github.com/cloudflare/circl/kem" - cpapke "github.com/cloudflare/circl/pke/kyber/{{.Pkg}}" + cpapke "github.com/cloudflare/circl/pke/kyber/{{.PkePkg}}" cryptoRand "crypto/rand" ) @@ -42,14 +46,14 @@ const ( PrivateKeySize = cpapke.PrivateKeySize + cpapke.PublicKeySize + 64 ) -// Type of a {{.Name}}.CCAKEM public key +// Type of a {{.KemName}} public key type PublicKey struct { pk *cpapke.PublicKey hpk [32]byte // H(pk) } -// Type of a {{.Name}}.CCAKEM private key +// Type of a {{.KemName}} private key type PrivateKey struct { sk *cpapke.PrivateKey pk *cpapke.PublicKey @@ -69,7 +73,11 @@ func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) { panic("seed must be of length KeySeedSize") } + {{ if .NIST -}} + pk.pk, sk.sk = cpapke.NewKeyFromSeedMLKEM(seed[:cpapke.KeySeedSize]) + {{- else -}} pk.pk, sk.sk = cpapke.NewKeyFromSeed(seed[:cpapke.KeySeedSize]) + {{- end }} sk.pk = pk.pk copy(sk.z[:], seed[cpapke.KeySeedSize:]) @@ -127,11 +135,15 @@ func (pk *PublicKey) EncapsulateTo(ct, ss []byte, seed []byte) { panic("ss must be of length SharedKeySize") } - // m = H(seed) var m [32]byte + {{ if .NIST -}} + copy(m[:], seed) + {{- else -}} + // m = H(seed), the hash of shame h := sha3.New256() - h.Write(seed[:]) + h.Write(seed) h.Read(m[:]) + {{- end }} // (K', r) = G(m ‖ H(pk)) var kr [64]byte @@ -143,6 +155,9 @@ func (pk *PublicKey) EncapsulateTo(ct, ss []byte, seed []byte) { // c = Kyber.CPAPKE.Enc(pk, m, r) pk.pk.EncryptTo(ct, m[:], kr[32:]) + {{ if .NIST -}} + copy(ss, kr[:SharedKeySize]) + {{- else -}} // Compute H(c) and put in second slot of kr, which will be (K', H(c)). h.Reset() h.Write(ct[:CiphertextSize]) @@ -152,6 +167,7 @@ func (pk *PublicKey) EncapsulateTo(ct, ss []byte, seed []byte) { kdf := sha3.NewShake256() kdf.Write(kr[:]) kdf.Read(ss[:SharedKeySize]) + {{- end }} } // DecapsulateTo computes the shared key which is encapsulated in ct @@ -183,6 +199,24 @@ func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) { var ct2 [CiphertextSize]byte sk.pk.EncryptTo(ct2[:], m2[:], kr2[32:]) + {{ if .NIST -}} + var ss2 [SharedKeySize]byte + + // Compute shared secret in case of rejection: ss₂ = PRF(z ‖ c) + prf := sha3.NewShake256() + prf.Write(sk.z[:]) + prf.Write(ct[:CiphertextSize]) + prf.Read(ss2[:]) + + // Set ss2 to the real shared secret if c = c'. + subtle.ConstantTimeCopy( + subtle.ConstantTimeCompare(ct, ct2[:]), + ss2[:], + kr2[:SharedKeySize], + ) + + copy(ss, ss2[:]) + {{- else -}} // Compute H(c) and put in second slot of kr2, which will be (K'', H(c)). h := sha3.New256() h.Write(ct[:CiphertextSize]) @@ -198,7 +232,8 @@ func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) { // K = KDF(K''/z, H(c)) kdf := sha3.NewShake256() kdf.Write(kr2[:]) - kdf.Read(ss[:SharedKeySize]) + kdf.Read(ss) + {{- end }} } // Packs sk to buf. diff --git a/kem/mlkem/acvp_test.go b/kem/mlkem/acvp_test.go new file mode 100644 index 000000000..4e5306a6b --- /dev/null +++ b/kem/mlkem/acvp_test.go @@ -0,0 +1,269 @@ +package mlkem + +import ( + "bytes" + "compress/gzip" + "encoding/hex" + "encoding/json" + "io" + "os" + "testing" + + "github.com/cloudflare/circl/kem/schemes" +) + +// []byte but is encoded in hex for JSON +type HexBytes []byte + +func (b HexBytes) MarshalJSON() ([]byte, error) { + return json.Marshal(hex.EncodeToString(b)) +} + +func (b *HexBytes) UnmarshalJSON(data []byte) (err error) { + var s string + if err = json.Unmarshal(data, &s); err != nil { + return err + } + *b, err = hex.DecodeString(s) + return err +} + +func gunzip(in []byte) ([]byte, error) { + buf := bytes.NewBuffer(in) + r, err := gzip.NewReader(buf) + if err != nil { + return nil, err + } + return io.ReadAll(r) +} + +func readGzip(path string) ([]byte, error) { + buf, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return gunzip(buf) +} + +func TestACVP(t *testing.T) { + for _, sub := range []string{ + "keyGen", + "encapDecap", + } { + t.Run(sub, func(t *testing.T) { + testACVP(t, sub) + }) + } +} + +func testACVP(t *testing.T, sub string) { + buf, err := readGzip("testdata/ML-KEM-" + sub + "-FIPS203/prompt.json.gz") + if err != nil { + t.Fatal(err) + } + + var prompt struct { + TestGroups []json.RawMessage `json:"testGroups"` + } + + if err := json.Unmarshal(buf, &prompt); err != nil { + t.Fatal(err) + } + + buf, err = readGzip("testdata/ML-KEM-" + sub + "-FIPS203/expectedResults.json.gz") + if err != nil { + t.Fatal(err) + } + + var results struct { + TestGroups []json.RawMessage `json:"testGroups"` + } + + if err := json.Unmarshal(buf, &results); err != nil { + t.Fatal(err) + } + + rawResults := make(map[int]json.RawMessage) + + for _, rawGroup := range results.TestGroups { + var abstractGroup struct { + Tests []json.RawMessage `json:tests` + } + if err := json.Unmarshal(rawGroup, &abstractGroup); err != nil { + t.Fatal(err) + } + for _, rawTest := range abstractGroup.Tests { + var abstractTest struct { + TcId int `json:"tcId"` + } + if err := json.Unmarshal(rawTest, &abstractTest); err != nil { + t.Fatal(err) + } + if _, exists := rawResults[abstractTest.TcId]; exists { + t.Fatalf("Duplicate test id: %d", abstractTest.TcId) + } + rawResults[abstractTest.TcId] = rawTest + } + } + + for _, rawGroup := range prompt.TestGroups { + var abstractGroup struct { + TestType string `json:"testType"` + } + if err := json.Unmarshal(rawGroup, &abstractGroup); err != nil { + t.Fatal(err) + } + switch { + case abstractGroup.TestType == "AFT" && sub == "keyGen": + var group struct { + TgId int `json:"tgId"` + ParameterSet string `json:"parameterSet"` + Tests []struct { + TcId int `json:"tcId"` + Z HexBytes `json:"z"` + D HexBytes `json:"d"` + } + } + if err := json.Unmarshal(rawGroup, &group); err != nil { + t.Fatal(err) + } + + scheme := schemes.ByName(group.ParameterSet) + if scheme == nil { + t.Fatalf("No such scheme: %s", group.ParameterSet) + } + + for _, test := range group.Tests { + var result struct { + Ek HexBytes `json:"ek"` + Dk HexBytes `json:"dk"` + } + rawResult, ok := rawResults[test.TcId] + if !ok { + t.Fatalf("Missing result: %d", test.TcId) + } + if err := json.Unmarshal(rawResult, &result); err != nil { + t.Fatal(err) + } + + var seed [64]byte + copy(seed[:], test.D) + copy(seed[32:], test.Z) + + ek, dk := scheme.DeriveKeyPair(seed[:]) + + ek2, err := scheme.UnmarshalBinaryPublicKey(result.Ek) + if err != nil { + t.Fatal(err) + } + dk2, err := scheme.UnmarshalBinaryPrivateKey(result.Dk) + if err != nil { + t.Fatal(err) + } + + if dk.Equal(dk2) { + t.Fatal("dk does not match") + } + if ek.Equal(ek2) { + t.Fatal("ek does not match") + } + } + case abstractGroup.TestType == "AFT" && sub == "encapDecap": + var group struct { + TgId int `json:"tgId"` + ParameterSet string `json:"parameterSet"` + Tests []struct { + TcId int `json:"tcId"` + Ek HexBytes `json:"ek"` + M HexBytes `json:"m"` + } + } + if err := json.Unmarshal(rawGroup, &group); err != nil { + t.Fatal(err) + } + + scheme := schemes.ByName(group.ParameterSet) + if scheme == nil { + t.Fatalf("No such scheme: %s", group.ParameterSet) + } + + for _, test := range group.Tests { + var result struct { + C HexBytes `json:"c"` + K HexBytes `json:"k"` + } + rawResult, ok := rawResults[test.TcId] + if !ok { + t.Fatalf("Missing result: %d", test.TcId) + } + if err := json.Unmarshal(rawResult, &result); err != nil { + t.Fatal(err) + } + + ek, err := scheme.UnmarshalBinaryPublicKey(test.Ek) + if err != nil { + t.Fatal(err) + } + + ct, ss, err := scheme.EncapsulateDeterministically(ek, test.M) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(ct, result.C) { + t.Fatalf("ciphertext doesn't match: %x ≠ %x", ct, result.C) + } + if !bytes.Equal(ss, result.K) { + t.Fatalf("shared secret doesn't match: %x ≠ %x", ss, result.K) + } + } + case abstractGroup.TestType == "VAL" && sub == "encapDecap": + var group struct { + TgId int `json:"tgId"` + ParameterSet string `json:"parameterSet"` + Dk HexBytes `json:"dk"` + Tests []struct { + TcId int `json:"tcId"` + C HexBytes `json:"c"` + } + } + if err := json.Unmarshal(rawGroup, &group); err != nil { + t.Fatal(err) + } + + scheme := schemes.ByName(group.ParameterSet) + if scheme == nil { + t.Fatalf("No such scheme: %s", group.ParameterSet) + } + + dk, err := scheme.UnmarshalBinaryPrivateKey(group.Dk) + if err != nil { + t.Fatal(err) + } + + for _, test := range group.Tests { + var result struct { + K HexBytes `json:"k"` + } + rawResult, ok := rawResults[test.TcId] + if !ok { + t.Fatalf("Missing rawResult: %d", test.TcId) + } + if err := json.Unmarshal(rawResult, &result); err != nil { + t.Fatal(err) + } + + ss, err := scheme.Decapsulate(dk, test.C) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(ss, result.K) { + t.Fatalf("shared secret doesn't match: %x ≠ %x", ss, result.K) + } + } + default: + t.Fatalf("unknown type %s for %s", abstractGroup.TestType, sub) + } + } +} diff --git a/kem/mlkem/doc.go b/kem/mlkem/doc.go new file mode 100644 index 000000000..563e52f96 --- /dev/null +++ b/kem/mlkem/doc.go @@ -0,0 +1,7 @@ +// Package mlkem implements IND-CCA2 secure ML-KEM key encapsulation +// mechanism (KEM) as defined in FIPS 203. +// +// https://nvlpubs.nist.gov/nistpubs/fips/nist.fips.203.pdf +package mlkem + +// See ../kyber/gen.go and ../kyber/kat_test.go. diff --git a/kem/mlkem/mlkem1024/kyber.go b/kem/mlkem/mlkem1024/kyber.go new file mode 100644 index 000000000..fabd1e46e --- /dev/null +++ b/kem/mlkem/mlkem1024/kyber.go @@ -0,0 +1,390 @@ +// Code generated from pkg.templ.go. DO NOT EDIT. + +// Package mlkem1024 implements the IND-CCA2 secure key encapsulation mechanism +// ML-KEM-1024 as defined in FIPS203. +package mlkem1024 + +import ( + "bytes" + "crypto/subtle" + "io" + + cryptoRand "crypto/rand" + "github.com/cloudflare/circl/internal/sha3" + "github.com/cloudflare/circl/kem" + cpapke "github.com/cloudflare/circl/pke/kyber/kyber1024" +) + +const ( + // Size of seed for NewKeyFromSeed + KeySeedSize = cpapke.KeySeedSize + 32 + + // Size of seed for EncapsulateTo. + EncapsulationSeedSize = 32 + + // Size of the established shared key. + SharedKeySize = 32 + + // Size of the encapsulated shared key. + CiphertextSize = cpapke.CiphertextSize + + // Size of a packed public key. + PublicKeySize = cpapke.PublicKeySize + + // Size of a packed private key. + PrivateKeySize = cpapke.PrivateKeySize + cpapke.PublicKeySize + 64 +) + +// Type of a ML-KEM-1024 public key +type PublicKey struct { + pk *cpapke.PublicKey + + hpk [32]byte // H(pk) +} + +// Type of a ML-KEM-1024 private key +type PrivateKey struct { + sk *cpapke.PrivateKey + pk *cpapke.PublicKey + hpk [32]byte // H(pk) + z [32]byte +} + +// NewKeyFromSeed derives a public/private keypair deterministically +// from the given seed. +// +// Panics if seed is not of length KeySeedSize. +func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) { + var sk PrivateKey + var pk PublicKey + + if len(seed) != KeySeedSize { + panic("seed must be of length KeySeedSize") + } + + pk.pk, sk.sk = cpapke.NewKeyFromSeedMLKEM(seed[:cpapke.KeySeedSize]) + sk.pk = pk.pk + copy(sk.z[:], seed[cpapke.KeySeedSize:]) + + // Compute H(pk) + var ppk [cpapke.PublicKeySize]byte + sk.pk.Pack(ppk[:]) + h := sha3.New256() + h.Write(ppk[:]) + h.Read(sk.hpk[:]) + copy(pk.hpk[:], sk.hpk[:]) + + return &pk, &sk +} + +// GenerateKeyPair generates public and private keys using entropy from rand. +// If rand is nil, crypto/rand.Reader will be used. +func GenerateKeyPair(rand io.Reader) (*PublicKey, *PrivateKey, error) { + var seed [KeySeedSize]byte + if rand == nil { + rand = cryptoRand.Reader + } + _, err := io.ReadFull(rand, seed[:]) + if err != nil { + return nil, nil, err + } + pk, sk := NewKeyFromSeed(seed[:]) + return pk, sk, nil +} + +// EncapsulateTo generates a shared key and ciphertext that contains it +// for the public key using randomness from seed and writes the shared key +// to ss and ciphertext to ct. +// +// Panics if ss, ct or seed are not of length SharedKeySize, CiphertextSize +// and EncapsulationSeedSize respectively. +// +// seed may be nil, in which case crypto/rand.Reader is used to generate one. +func (pk *PublicKey) EncapsulateTo(ct, ss []byte, seed []byte) { + if seed == nil { + seed = make([]byte, EncapsulationSeedSize) + if _, err := cryptoRand.Read(seed[:]); err != nil { + panic(err) + } + } else { + if len(seed) != EncapsulationSeedSize { + panic("seed must be of length EncapsulationSeedSize") + } + } + + if len(ct) != CiphertextSize { + panic("ct must be of length CiphertextSize") + } + + if len(ss) != SharedKeySize { + panic("ss must be of length SharedKeySize") + } + + var m [32]byte + copy(m[:], seed) + + // (K', r) = G(m ‖ H(pk)) + var kr [64]byte + g := sha3.New512() + g.Write(m[:]) + g.Write(pk.hpk[:]) + g.Read(kr[:]) + + // c = Kyber.CPAPKE.Enc(pk, m, r) + pk.pk.EncryptTo(ct, m[:], kr[32:]) + + copy(ss, kr[:SharedKeySize]) +} + +// DecapsulateTo computes the shared key which is encapsulated in ct +// for the private key. +// +// Panics if ct or ss are not of length CiphertextSize and SharedKeySize +// respectively. +func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) { + if len(ct) != CiphertextSize { + panic("ct must be of length CiphertextSize") + } + + if len(ss) != SharedKeySize { + panic("ss must be of length SharedKeySize") + } + + // m' = Kyber.CPAPKE.Dec(sk, ct) + var m2 [32]byte + sk.sk.DecryptTo(m2[:], ct) + + // (K'', r') = G(m' ‖ H(pk)) + var kr2 [64]byte + g := sha3.New512() + g.Write(m2[:]) + g.Write(sk.hpk[:]) + g.Read(kr2[:]) + + // c' = Kyber.CPAPKE.Enc(pk, m', r') + var ct2 [CiphertextSize]byte + sk.pk.EncryptTo(ct2[:], m2[:], kr2[32:]) + + var ss2 [SharedKeySize]byte + + // Compute shared secret in case of rejection: ss₂ = PRF(z ‖ c) + prf := sha3.NewShake256() + prf.Write(sk.z[:]) + prf.Write(ct[:CiphertextSize]) + prf.Read(ss2[:]) + + // Set ss2 to the real shared secret if c = c'. + subtle.ConstantTimeCopy( + subtle.ConstantTimeCompare(ct, ct2[:]), + ss2[:], + kr2[:SharedKeySize], + ) + + copy(ss, ss2[:]) +} + +// Packs sk to buf. +// +// Panics if buf is not of size PrivateKeySize. +func (sk *PrivateKey) Pack(buf []byte) { + if len(buf) != PrivateKeySize { + panic("buf must be of length PrivateKeySize") + } + + sk.sk.Pack(buf[:cpapke.PrivateKeySize]) + buf = buf[cpapke.PrivateKeySize:] + sk.pk.Pack(buf[:cpapke.PublicKeySize]) + buf = buf[cpapke.PublicKeySize:] + copy(buf, sk.hpk[:]) + buf = buf[32:] + copy(buf, sk.z[:]) +} + +// Unpacks sk from buf. +// +// Panics if buf is not of size PrivateKeySize. +func (sk *PrivateKey) Unpack(buf []byte) { + if len(buf) != PrivateKeySize { + panic("buf must be of length PrivateKeySize") + } + + sk.sk = new(cpapke.PrivateKey) + sk.sk.Unpack(buf[:cpapke.PrivateKeySize]) + buf = buf[cpapke.PrivateKeySize:] + sk.pk = new(cpapke.PublicKey) + sk.pk.Unpack(buf[:cpapke.PublicKeySize]) + buf = buf[cpapke.PublicKeySize:] + copy(sk.hpk[:], buf[:32]) + copy(sk.z[:], buf[32:]) +} + +// Packs pk to buf. +// +// Panics if buf is not of size PublicKeySize. +func (pk *PublicKey) Pack(buf []byte) { + if len(buf) != PublicKeySize { + panic("buf must be of length PublicKeySize") + } + + pk.pk.Pack(buf) +} + +// Unpacks pk from buf. +// +// Panics if buf is not of size PublicKeySize. +func (pk *PublicKey) Unpack(buf []byte) { + if len(buf) != PublicKeySize { + panic("buf must be of length PublicKeySize") + } + + pk.pk = new(cpapke.PublicKey) + pk.pk.Unpack(buf) + + // Compute cached H(pk) + h := sha3.New256() + h.Write(buf) + h.Read(pk.hpk[:]) +} + +// Boilerplate down below for the KEM scheme API. + +type scheme struct{} + +var sch kem.Scheme = &scheme{} + +// Scheme returns a KEM interface. +func Scheme() kem.Scheme { return sch } + +func (*scheme) Name() string { return "ML-KEM-1024" } +func (*scheme) PublicKeySize() int { return PublicKeySize } +func (*scheme) PrivateKeySize() int { return PrivateKeySize } +func (*scheme) SeedSize() int { return KeySeedSize } +func (*scheme) SharedKeySize() int { return SharedKeySize } +func (*scheme) CiphertextSize() int { return CiphertextSize } +func (*scheme) EncapsulationSeedSize() int { return EncapsulationSeedSize } + +func (sk *PrivateKey) Scheme() kem.Scheme { return sch } +func (pk *PublicKey) Scheme() kem.Scheme { return sch } + +func (sk *PrivateKey) MarshalBinary() ([]byte, error) { + var ret [PrivateKeySize]byte + sk.Pack(ret[:]) + return ret[:], nil +} + +func (sk *PrivateKey) Equal(other kem.PrivateKey) bool { + oth, ok := other.(*PrivateKey) + if !ok { + return false + } + if sk.pk == nil && oth.pk == nil { + return true + } + if sk.pk == nil || oth.pk == nil { + return false + } + if !bytes.Equal(sk.hpk[:], oth.hpk[:]) || + subtle.ConstantTimeCompare(sk.z[:], oth.z[:]) != 1 { + return false + } + return sk.sk.Equal(oth.sk) +} + +func (pk *PublicKey) Equal(other kem.PublicKey) bool { + oth, ok := other.(*PublicKey) + if !ok { + return false + } + if pk.pk == nil && oth.pk == nil { + return true + } + if pk.pk == nil || oth.pk == nil { + return false + } + return bytes.Equal(pk.hpk[:], oth.hpk[:]) +} + +func (sk *PrivateKey) Public() kem.PublicKey { + pk := new(PublicKey) + pk.pk = sk.pk + copy(pk.hpk[:], sk.hpk[:]) + return pk +} + +func (pk *PublicKey) MarshalBinary() ([]byte, error) { + var ret [PublicKeySize]byte + pk.Pack(ret[:]) + return ret[:], nil +} + +func (*scheme) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) { + return GenerateKeyPair(cryptoRand.Reader) +} + +func (*scheme) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { + if len(seed) != KeySeedSize { + panic(kem.ErrSeedSize) + } + return NewKeyFromSeed(seed[:]) +} + +func (*scheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) { + ct = make([]byte, CiphertextSize) + ss = make([]byte, SharedKeySize) + + pub, ok := pk.(*PublicKey) + if !ok { + return nil, nil, kem.ErrTypeMismatch + } + pub.EncapsulateTo(ct, ss, nil) + return +} + +func (*scheme) EncapsulateDeterministically(pk kem.PublicKey, seed []byte) ( + ct, ss []byte, err error) { + if len(seed) != EncapsulationSeedSize { + return nil, nil, kem.ErrSeedSize + } + + ct = make([]byte, CiphertextSize) + ss = make([]byte, SharedKeySize) + + pub, ok := pk.(*PublicKey) + if !ok { + return nil, nil, kem.ErrTypeMismatch + } + pub.EncapsulateTo(ct, ss, seed) + return +} + +func (*scheme) Decapsulate(sk kem.PrivateKey, ct []byte) ([]byte, error) { + if len(ct) != CiphertextSize { + return nil, kem.ErrCiphertextSize + } + + priv, ok := sk.(*PrivateKey) + if !ok { + return nil, kem.ErrTypeMismatch + } + ss := make([]byte, SharedKeySize) + priv.DecapsulateTo(ss, ct) + return ss, nil +} + +func (*scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) { + if len(buf) != PublicKeySize { + return nil, kem.ErrPubKeySize + } + var ret PublicKey + ret.Unpack(buf) + return &ret, nil +} + +func (*scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) { + if len(buf) != PrivateKeySize { + return nil, kem.ErrPrivKeySize + } + var ret PrivateKey + ret.Unpack(buf) + return &ret, nil +} diff --git a/kem/mlkem/mlkem512/kyber.go b/kem/mlkem/mlkem512/kyber.go new file mode 100644 index 000000000..24283e431 --- /dev/null +++ b/kem/mlkem/mlkem512/kyber.go @@ -0,0 +1,390 @@ +// Code generated from pkg.templ.go. DO NOT EDIT. + +// Package mlkem512 implements the IND-CCA2 secure key encapsulation mechanism +// ML-KEM-512 as defined in FIPS203. +package mlkem512 + +import ( + "bytes" + "crypto/subtle" + "io" + + cryptoRand "crypto/rand" + "github.com/cloudflare/circl/internal/sha3" + "github.com/cloudflare/circl/kem" + cpapke "github.com/cloudflare/circl/pke/kyber/kyber512" +) + +const ( + // Size of seed for NewKeyFromSeed + KeySeedSize = cpapke.KeySeedSize + 32 + + // Size of seed for EncapsulateTo. + EncapsulationSeedSize = 32 + + // Size of the established shared key. + SharedKeySize = 32 + + // Size of the encapsulated shared key. + CiphertextSize = cpapke.CiphertextSize + + // Size of a packed public key. + PublicKeySize = cpapke.PublicKeySize + + // Size of a packed private key. + PrivateKeySize = cpapke.PrivateKeySize + cpapke.PublicKeySize + 64 +) + +// Type of a ML-KEM-512 public key +type PublicKey struct { + pk *cpapke.PublicKey + + hpk [32]byte // H(pk) +} + +// Type of a ML-KEM-512 private key +type PrivateKey struct { + sk *cpapke.PrivateKey + pk *cpapke.PublicKey + hpk [32]byte // H(pk) + z [32]byte +} + +// NewKeyFromSeed derives a public/private keypair deterministically +// from the given seed. +// +// Panics if seed is not of length KeySeedSize. +func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) { + var sk PrivateKey + var pk PublicKey + + if len(seed) != KeySeedSize { + panic("seed must be of length KeySeedSize") + } + + pk.pk, sk.sk = cpapke.NewKeyFromSeedMLKEM(seed[:cpapke.KeySeedSize]) + sk.pk = pk.pk + copy(sk.z[:], seed[cpapke.KeySeedSize:]) + + // Compute H(pk) + var ppk [cpapke.PublicKeySize]byte + sk.pk.Pack(ppk[:]) + h := sha3.New256() + h.Write(ppk[:]) + h.Read(sk.hpk[:]) + copy(pk.hpk[:], sk.hpk[:]) + + return &pk, &sk +} + +// GenerateKeyPair generates public and private keys using entropy from rand. +// If rand is nil, crypto/rand.Reader will be used. +func GenerateKeyPair(rand io.Reader) (*PublicKey, *PrivateKey, error) { + var seed [KeySeedSize]byte + if rand == nil { + rand = cryptoRand.Reader + } + _, err := io.ReadFull(rand, seed[:]) + if err != nil { + return nil, nil, err + } + pk, sk := NewKeyFromSeed(seed[:]) + return pk, sk, nil +} + +// EncapsulateTo generates a shared key and ciphertext that contains it +// for the public key using randomness from seed and writes the shared key +// to ss and ciphertext to ct. +// +// Panics if ss, ct or seed are not of length SharedKeySize, CiphertextSize +// and EncapsulationSeedSize respectively. +// +// seed may be nil, in which case crypto/rand.Reader is used to generate one. +func (pk *PublicKey) EncapsulateTo(ct, ss []byte, seed []byte) { + if seed == nil { + seed = make([]byte, EncapsulationSeedSize) + if _, err := cryptoRand.Read(seed[:]); err != nil { + panic(err) + } + } else { + if len(seed) != EncapsulationSeedSize { + panic("seed must be of length EncapsulationSeedSize") + } + } + + if len(ct) != CiphertextSize { + panic("ct must be of length CiphertextSize") + } + + if len(ss) != SharedKeySize { + panic("ss must be of length SharedKeySize") + } + + var m [32]byte + copy(m[:], seed) + + // (K', r) = G(m ‖ H(pk)) + var kr [64]byte + g := sha3.New512() + g.Write(m[:]) + g.Write(pk.hpk[:]) + g.Read(kr[:]) + + // c = Kyber.CPAPKE.Enc(pk, m, r) + pk.pk.EncryptTo(ct, m[:], kr[32:]) + + copy(ss, kr[:SharedKeySize]) +} + +// DecapsulateTo computes the shared key which is encapsulated in ct +// for the private key. +// +// Panics if ct or ss are not of length CiphertextSize and SharedKeySize +// respectively. +func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) { + if len(ct) != CiphertextSize { + panic("ct must be of length CiphertextSize") + } + + if len(ss) != SharedKeySize { + panic("ss must be of length SharedKeySize") + } + + // m' = Kyber.CPAPKE.Dec(sk, ct) + var m2 [32]byte + sk.sk.DecryptTo(m2[:], ct) + + // (K'', r') = G(m' ‖ H(pk)) + var kr2 [64]byte + g := sha3.New512() + g.Write(m2[:]) + g.Write(sk.hpk[:]) + g.Read(kr2[:]) + + // c' = Kyber.CPAPKE.Enc(pk, m', r') + var ct2 [CiphertextSize]byte + sk.pk.EncryptTo(ct2[:], m2[:], kr2[32:]) + + var ss2 [SharedKeySize]byte + + // Compute shared secret in case of rejection: ss₂ = PRF(z ‖ c) + prf := sha3.NewShake256() + prf.Write(sk.z[:]) + prf.Write(ct[:CiphertextSize]) + prf.Read(ss2[:]) + + // Set ss2 to the real shared secret if c = c'. + subtle.ConstantTimeCopy( + subtle.ConstantTimeCompare(ct, ct2[:]), + ss2[:], + kr2[:SharedKeySize], + ) + + copy(ss, ss2[:]) +} + +// Packs sk to buf. +// +// Panics if buf is not of size PrivateKeySize. +func (sk *PrivateKey) Pack(buf []byte) { + if len(buf) != PrivateKeySize { + panic("buf must be of length PrivateKeySize") + } + + sk.sk.Pack(buf[:cpapke.PrivateKeySize]) + buf = buf[cpapke.PrivateKeySize:] + sk.pk.Pack(buf[:cpapke.PublicKeySize]) + buf = buf[cpapke.PublicKeySize:] + copy(buf, sk.hpk[:]) + buf = buf[32:] + copy(buf, sk.z[:]) +} + +// Unpacks sk from buf. +// +// Panics if buf is not of size PrivateKeySize. +func (sk *PrivateKey) Unpack(buf []byte) { + if len(buf) != PrivateKeySize { + panic("buf must be of length PrivateKeySize") + } + + sk.sk = new(cpapke.PrivateKey) + sk.sk.Unpack(buf[:cpapke.PrivateKeySize]) + buf = buf[cpapke.PrivateKeySize:] + sk.pk = new(cpapke.PublicKey) + sk.pk.Unpack(buf[:cpapke.PublicKeySize]) + buf = buf[cpapke.PublicKeySize:] + copy(sk.hpk[:], buf[:32]) + copy(sk.z[:], buf[32:]) +} + +// Packs pk to buf. +// +// Panics if buf is not of size PublicKeySize. +func (pk *PublicKey) Pack(buf []byte) { + if len(buf) != PublicKeySize { + panic("buf must be of length PublicKeySize") + } + + pk.pk.Pack(buf) +} + +// Unpacks pk from buf. +// +// Panics if buf is not of size PublicKeySize. +func (pk *PublicKey) Unpack(buf []byte) { + if len(buf) != PublicKeySize { + panic("buf must be of length PublicKeySize") + } + + pk.pk = new(cpapke.PublicKey) + pk.pk.Unpack(buf) + + // Compute cached H(pk) + h := sha3.New256() + h.Write(buf) + h.Read(pk.hpk[:]) +} + +// Boilerplate down below for the KEM scheme API. + +type scheme struct{} + +var sch kem.Scheme = &scheme{} + +// Scheme returns a KEM interface. +func Scheme() kem.Scheme { return sch } + +func (*scheme) Name() string { return "ML-KEM-512" } +func (*scheme) PublicKeySize() int { return PublicKeySize } +func (*scheme) PrivateKeySize() int { return PrivateKeySize } +func (*scheme) SeedSize() int { return KeySeedSize } +func (*scheme) SharedKeySize() int { return SharedKeySize } +func (*scheme) CiphertextSize() int { return CiphertextSize } +func (*scheme) EncapsulationSeedSize() int { return EncapsulationSeedSize } + +func (sk *PrivateKey) Scheme() kem.Scheme { return sch } +func (pk *PublicKey) Scheme() kem.Scheme { return sch } + +func (sk *PrivateKey) MarshalBinary() ([]byte, error) { + var ret [PrivateKeySize]byte + sk.Pack(ret[:]) + return ret[:], nil +} + +func (sk *PrivateKey) Equal(other kem.PrivateKey) bool { + oth, ok := other.(*PrivateKey) + if !ok { + return false + } + if sk.pk == nil && oth.pk == nil { + return true + } + if sk.pk == nil || oth.pk == nil { + return false + } + if !bytes.Equal(sk.hpk[:], oth.hpk[:]) || + subtle.ConstantTimeCompare(sk.z[:], oth.z[:]) != 1 { + return false + } + return sk.sk.Equal(oth.sk) +} + +func (pk *PublicKey) Equal(other kem.PublicKey) bool { + oth, ok := other.(*PublicKey) + if !ok { + return false + } + if pk.pk == nil && oth.pk == nil { + return true + } + if pk.pk == nil || oth.pk == nil { + return false + } + return bytes.Equal(pk.hpk[:], oth.hpk[:]) +} + +func (sk *PrivateKey) Public() kem.PublicKey { + pk := new(PublicKey) + pk.pk = sk.pk + copy(pk.hpk[:], sk.hpk[:]) + return pk +} + +func (pk *PublicKey) MarshalBinary() ([]byte, error) { + var ret [PublicKeySize]byte + pk.Pack(ret[:]) + return ret[:], nil +} + +func (*scheme) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) { + return GenerateKeyPair(cryptoRand.Reader) +} + +func (*scheme) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { + if len(seed) != KeySeedSize { + panic(kem.ErrSeedSize) + } + return NewKeyFromSeed(seed[:]) +} + +func (*scheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) { + ct = make([]byte, CiphertextSize) + ss = make([]byte, SharedKeySize) + + pub, ok := pk.(*PublicKey) + if !ok { + return nil, nil, kem.ErrTypeMismatch + } + pub.EncapsulateTo(ct, ss, nil) + return +} + +func (*scheme) EncapsulateDeterministically(pk kem.PublicKey, seed []byte) ( + ct, ss []byte, err error) { + if len(seed) != EncapsulationSeedSize { + return nil, nil, kem.ErrSeedSize + } + + ct = make([]byte, CiphertextSize) + ss = make([]byte, SharedKeySize) + + pub, ok := pk.(*PublicKey) + if !ok { + return nil, nil, kem.ErrTypeMismatch + } + pub.EncapsulateTo(ct, ss, seed) + return +} + +func (*scheme) Decapsulate(sk kem.PrivateKey, ct []byte) ([]byte, error) { + if len(ct) != CiphertextSize { + return nil, kem.ErrCiphertextSize + } + + priv, ok := sk.(*PrivateKey) + if !ok { + return nil, kem.ErrTypeMismatch + } + ss := make([]byte, SharedKeySize) + priv.DecapsulateTo(ss, ct) + return ss, nil +} + +func (*scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) { + if len(buf) != PublicKeySize { + return nil, kem.ErrPubKeySize + } + var ret PublicKey + ret.Unpack(buf) + return &ret, nil +} + +func (*scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) { + if len(buf) != PrivateKeySize { + return nil, kem.ErrPrivKeySize + } + var ret PrivateKey + ret.Unpack(buf) + return &ret, nil +} diff --git a/kem/mlkem/mlkem768/kyber.go b/kem/mlkem/mlkem768/kyber.go new file mode 100644 index 000000000..b80fb9f42 --- /dev/null +++ b/kem/mlkem/mlkem768/kyber.go @@ -0,0 +1,390 @@ +// Code generated from pkg.templ.go. DO NOT EDIT. + +// Package mlkem768 implements the IND-CCA2 secure key encapsulation mechanism +// ML-KEM-768 as defined in FIPS203. +package mlkem768 + +import ( + "bytes" + "crypto/subtle" + "io" + + cryptoRand "crypto/rand" + "github.com/cloudflare/circl/internal/sha3" + "github.com/cloudflare/circl/kem" + cpapke "github.com/cloudflare/circl/pke/kyber/kyber768" +) + +const ( + // Size of seed for NewKeyFromSeed + KeySeedSize = cpapke.KeySeedSize + 32 + + // Size of seed for EncapsulateTo. + EncapsulationSeedSize = 32 + + // Size of the established shared key. + SharedKeySize = 32 + + // Size of the encapsulated shared key. + CiphertextSize = cpapke.CiphertextSize + + // Size of a packed public key. + PublicKeySize = cpapke.PublicKeySize + + // Size of a packed private key. + PrivateKeySize = cpapke.PrivateKeySize + cpapke.PublicKeySize + 64 +) + +// Type of a ML-KEM-768 public key +type PublicKey struct { + pk *cpapke.PublicKey + + hpk [32]byte // H(pk) +} + +// Type of a ML-KEM-768 private key +type PrivateKey struct { + sk *cpapke.PrivateKey + pk *cpapke.PublicKey + hpk [32]byte // H(pk) + z [32]byte +} + +// NewKeyFromSeed derives a public/private keypair deterministically +// from the given seed. +// +// Panics if seed is not of length KeySeedSize. +func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) { + var sk PrivateKey + var pk PublicKey + + if len(seed) != KeySeedSize { + panic("seed must be of length KeySeedSize") + } + + pk.pk, sk.sk = cpapke.NewKeyFromSeedMLKEM(seed[:cpapke.KeySeedSize]) + sk.pk = pk.pk + copy(sk.z[:], seed[cpapke.KeySeedSize:]) + + // Compute H(pk) + var ppk [cpapke.PublicKeySize]byte + sk.pk.Pack(ppk[:]) + h := sha3.New256() + h.Write(ppk[:]) + h.Read(sk.hpk[:]) + copy(pk.hpk[:], sk.hpk[:]) + + return &pk, &sk +} + +// GenerateKeyPair generates public and private keys using entropy from rand. +// If rand is nil, crypto/rand.Reader will be used. +func GenerateKeyPair(rand io.Reader) (*PublicKey, *PrivateKey, error) { + var seed [KeySeedSize]byte + if rand == nil { + rand = cryptoRand.Reader + } + _, err := io.ReadFull(rand, seed[:]) + if err != nil { + return nil, nil, err + } + pk, sk := NewKeyFromSeed(seed[:]) + return pk, sk, nil +} + +// EncapsulateTo generates a shared key and ciphertext that contains it +// for the public key using randomness from seed and writes the shared key +// to ss and ciphertext to ct. +// +// Panics if ss, ct or seed are not of length SharedKeySize, CiphertextSize +// and EncapsulationSeedSize respectively. +// +// seed may be nil, in which case crypto/rand.Reader is used to generate one. +func (pk *PublicKey) EncapsulateTo(ct, ss []byte, seed []byte) { + if seed == nil { + seed = make([]byte, EncapsulationSeedSize) + if _, err := cryptoRand.Read(seed[:]); err != nil { + panic(err) + } + } else { + if len(seed) != EncapsulationSeedSize { + panic("seed must be of length EncapsulationSeedSize") + } + } + + if len(ct) != CiphertextSize { + panic("ct must be of length CiphertextSize") + } + + if len(ss) != SharedKeySize { + panic("ss must be of length SharedKeySize") + } + + var m [32]byte + copy(m[:], seed) + + // (K', r) = G(m ‖ H(pk)) + var kr [64]byte + g := sha3.New512() + g.Write(m[:]) + g.Write(pk.hpk[:]) + g.Read(kr[:]) + + // c = Kyber.CPAPKE.Enc(pk, m, r) + pk.pk.EncryptTo(ct, m[:], kr[32:]) + + copy(ss, kr[:SharedKeySize]) +} + +// DecapsulateTo computes the shared key which is encapsulated in ct +// for the private key. +// +// Panics if ct or ss are not of length CiphertextSize and SharedKeySize +// respectively. +func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) { + if len(ct) != CiphertextSize { + panic("ct must be of length CiphertextSize") + } + + if len(ss) != SharedKeySize { + panic("ss must be of length SharedKeySize") + } + + // m' = Kyber.CPAPKE.Dec(sk, ct) + var m2 [32]byte + sk.sk.DecryptTo(m2[:], ct) + + // (K'', r') = G(m' ‖ H(pk)) + var kr2 [64]byte + g := sha3.New512() + g.Write(m2[:]) + g.Write(sk.hpk[:]) + g.Read(kr2[:]) + + // c' = Kyber.CPAPKE.Enc(pk, m', r') + var ct2 [CiphertextSize]byte + sk.pk.EncryptTo(ct2[:], m2[:], kr2[32:]) + + var ss2 [SharedKeySize]byte + + // Compute shared secret in case of rejection: ss₂ = PRF(z ‖ c) + prf := sha3.NewShake256() + prf.Write(sk.z[:]) + prf.Write(ct[:CiphertextSize]) + prf.Read(ss2[:]) + + // Set ss2 to the real shared secret if c = c'. + subtle.ConstantTimeCopy( + subtle.ConstantTimeCompare(ct, ct2[:]), + ss2[:], + kr2[:SharedKeySize], + ) + + copy(ss, ss2[:]) +} + +// Packs sk to buf. +// +// Panics if buf is not of size PrivateKeySize. +func (sk *PrivateKey) Pack(buf []byte) { + if len(buf) != PrivateKeySize { + panic("buf must be of length PrivateKeySize") + } + + sk.sk.Pack(buf[:cpapke.PrivateKeySize]) + buf = buf[cpapke.PrivateKeySize:] + sk.pk.Pack(buf[:cpapke.PublicKeySize]) + buf = buf[cpapke.PublicKeySize:] + copy(buf, sk.hpk[:]) + buf = buf[32:] + copy(buf, sk.z[:]) +} + +// Unpacks sk from buf. +// +// Panics if buf is not of size PrivateKeySize. +func (sk *PrivateKey) Unpack(buf []byte) { + if len(buf) != PrivateKeySize { + panic("buf must be of length PrivateKeySize") + } + + sk.sk = new(cpapke.PrivateKey) + sk.sk.Unpack(buf[:cpapke.PrivateKeySize]) + buf = buf[cpapke.PrivateKeySize:] + sk.pk = new(cpapke.PublicKey) + sk.pk.Unpack(buf[:cpapke.PublicKeySize]) + buf = buf[cpapke.PublicKeySize:] + copy(sk.hpk[:], buf[:32]) + copy(sk.z[:], buf[32:]) +} + +// Packs pk to buf. +// +// Panics if buf is not of size PublicKeySize. +func (pk *PublicKey) Pack(buf []byte) { + if len(buf) != PublicKeySize { + panic("buf must be of length PublicKeySize") + } + + pk.pk.Pack(buf) +} + +// Unpacks pk from buf. +// +// Panics if buf is not of size PublicKeySize. +func (pk *PublicKey) Unpack(buf []byte) { + if len(buf) != PublicKeySize { + panic("buf must be of length PublicKeySize") + } + + pk.pk = new(cpapke.PublicKey) + pk.pk.Unpack(buf) + + // Compute cached H(pk) + h := sha3.New256() + h.Write(buf) + h.Read(pk.hpk[:]) +} + +// Boilerplate down below for the KEM scheme API. + +type scheme struct{} + +var sch kem.Scheme = &scheme{} + +// Scheme returns a KEM interface. +func Scheme() kem.Scheme { return sch } + +func (*scheme) Name() string { return "ML-KEM-768" } +func (*scheme) PublicKeySize() int { return PublicKeySize } +func (*scheme) PrivateKeySize() int { return PrivateKeySize } +func (*scheme) SeedSize() int { return KeySeedSize } +func (*scheme) SharedKeySize() int { return SharedKeySize } +func (*scheme) CiphertextSize() int { return CiphertextSize } +func (*scheme) EncapsulationSeedSize() int { return EncapsulationSeedSize } + +func (sk *PrivateKey) Scheme() kem.Scheme { return sch } +func (pk *PublicKey) Scheme() kem.Scheme { return sch } + +func (sk *PrivateKey) MarshalBinary() ([]byte, error) { + var ret [PrivateKeySize]byte + sk.Pack(ret[:]) + return ret[:], nil +} + +func (sk *PrivateKey) Equal(other kem.PrivateKey) bool { + oth, ok := other.(*PrivateKey) + if !ok { + return false + } + if sk.pk == nil && oth.pk == nil { + return true + } + if sk.pk == nil || oth.pk == nil { + return false + } + if !bytes.Equal(sk.hpk[:], oth.hpk[:]) || + subtle.ConstantTimeCompare(sk.z[:], oth.z[:]) != 1 { + return false + } + return sk.sk.Equal(oth.sk) +} + +func (pk *PublicKey) Equal(other kem.PublicKey) bool { + oth, ok := other.(*PublicKey) + if !ok { + return false + } + if pk.pk == nil && oth.pk == nil { + return true + } + if pk.pk == nil || oth.pk == nil { + return false + } + return bytes.Equal(pk.hpk[:], oth.hpk[:]) +} + +func (sk *PrivateKey) Public() kem.PublicKey { + pk := new(PublicKey) + pk.pk = sk.pk + copy(pk.hpk[:], sk.hpk[:]) + return pk +} + +func (pk *PublicKey) MarshalBinary() ([]byte, error) { + var ret [PublicKeySize]byte + pk.Pack(ret[:]) + return ret[:], nil +} + +func (*scheme) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) { + return GenerateKeyPair(cryptoRand.Reader) +} + +func (*scheme) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { + if len(seed) != KeySeedSize { + panic(kem.ErrSeedSize) + } + return NewKeyFromSeed(seed[:]) +} + +func (*scheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) { + ct = make([]byte, CiphertextSize) + ss = make([]byte, SharedKeySize) + + pub, ok := pk.(*PublicKey) + if !ok { + return nil, nil, kem.ErrTypeMismatch + } + pub.EncapsulateTo(ct, ss, nil) + return +} + +func (*scheme) EncapsulateDeterministically(pk kem.PublicKey, seed []byte) ( + ct, ss []byte, err error) { + if len(seed) != EncapsulationSeedSize { + return nil, nil, kem.ErrSeedSize + } + + ct = make([]byte, CiphertextSize) + ss = make([]byte, SharedKeySize) + + pub, ok := pk.(*PublicKey) + if !ok { + return nil, nil, kem.ErrTypeMismatch + } + pub.EncapsulateTo(ct, ss, seed) + return +} + +func (*scheme) Decapsulate(sk kem.PrivateKey, ct []byte) ([]byte, error) { + if len(ct) != CiphertextSize { + return nil, kem.ErrCiphertextSize + } + + priv, ok := sk.(*PrivateKey) + if !ok { + return nil, kem.ErrTypeMismatch + } + ss := make([]byte, SharedKeySize) + priv.DecapsulateTo(ss, ct) + return ss, nil +} + +func (*scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) { + if len(buf) != PublicKeySize { + return nil, kem.ErrPubKeySize + } + var ret PublicKey + ret.Unpack(buf) + return &ret, nil +} + +func (*scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) { + if len(buf) != PrivateKeySize { + return nil, kem.ErrPrivKeySize + } + var ret PrivateKey + ret.Unpack(buf) + return &ret, nil +} diff --git a/kem/mlkem/testdata/ML-KEM-encapDecap-FIPS203/expectedResults.json.gz b/kem/mlkem/testdata/ML-KEM-encapDecap-FIPS203/expectedResults.json.gz new file mode 100644 index 000000000..438079841 Binary files /dev/null and b/kem/mlkem/testdata/ML-KEM-encapDecap-FIPS203/expectedResults.json.gz differ diff --git a/kem/mlkem/testdata/ML-KEM-encapDecap-FIPS203/prompt.json.gz b/kem/mlkem/testdata/ML-KEM-encapDecap-FIPS203/prompt.json.gz new file mode 100644 index 000000000..f8c794216 Binary files /dev/null and b/kem/mlkem/testdata/ML-KEM-encapDecap-FIPS203/prompt.json.gz differ diff --git a/kem/mlkem/testdata/ML-KEM-keyGen-FIPS203/expectedResults.json.gz b/kem/mlkem/testdata/ML-KEM-keyGen-FIPS203/expectedResults.json.gz new file mode 100644 index 000000000..a89e8ac19 Binary files /dev/null and b/kem/mlkem/testdata/ML-KEM-keyGen-FIPS203/expectedResults.json.gz differ diff --git a/kem/mlkem/testdata/ML-KEM-keyGen-FIPS203/prompt.json.gz b/kem/mlkem/testdata/ML-KEM-keyGen-FIPS203/prompt.json.gz new file mode 100644 index 000000000..5de69b62c Binary files /dev/null and b/kem/mlkem/testdata/ML-KEM-keyGen-FIPS203/prompt.json.gz differ diff --git a/kem/mlkem/testdata/README.md b/kem/mlkem/testdata/README.md new file mode 100644 index 000000000..756d8ba81 --- /dev/null +++ b/kem/mlkem/testdata/README.md @@ -0,0 +1,4 @@ +Sources + + 1. https://github.com/usnistgov/ACVP-Server/tree/f38183487eebff2952da0e5a3441371218acfe3f/gen-val/json-files/ML-KEM-encapDecap-FIPS203 + 2. https://github.com/usnistgov/ACVP-Server/tree/f38183487eebff2952da0e5a3441371218acfe3f/gen-val/json-files/ML-KEM-keyGen-FIPS203 diff --git a/kem/schemes/schemes.go b/kem/schemes/schemes.go index a33e7b96e..da5f9839e 100644 --- a/kem/schemes/schemes.go +++ b/kem/schemes/schemes.go @@ -26,6 +26,9 @@ import ( "github.com/cloudflare/circl/kem/kyber/kyber1024" "github.com/cloudflare/circl/kem/kyber/kyber512" "github.com/cloudflare/circl/kem/kyber/kyber768" + "github.com/cloudflare/circl/kem/mlkem/mlkem1024" + "github.com/cloudflare/circl/kem/mlkem/mlkem512" + "github.com/cloudflare/circl/kem/mlkem/mlkem768" ) var allSchemes = [...]kem.Scheme{ @@ -38,6 +41,9 @@ var allSchemes = [...]kem.Scheme{ kyber512.Scheme(), kyber768.Scheme(), kyber1024.Scheme(), + mlkem512.Scheme(), + mlkem768.Scheme(), + mlkem1024.Scheme(), hybrid.Kyber512X25519(), hybrid.Kyber768X25519(), hybrid.Kyber768X448(), diff --git a/kem/schemes/schemes_test.go b/kem/schemes/schemes_test.go index e41840b34..be4c18a54 100644 --- a/kem/schemes/schemes_test.go +++ b/kem/schemes/schemes_test.go @@ -155,6 +155,9 @@ func Example_schemes() { // Kyber512 // Kyber768 // Kyber1024 + // ML-KEM-512 + // ML-KEM-768 + // ML-KEM-1024 // Kyber512-X25519 // Kyber768-X25519 // Kyber768-X448 diff --git a/pke/kyber/kyber1024/kyber.go b/pke/kyber/kyber1024/kyber.go index fb5911fac..4693ea56e 100644 --- a/pke/kyber/kyber1024/kyber.go +++ b/pke/kyber/kyber1024/kyber.go @@ -57,6 +57,9 @@ func GenerateKey(rand io.Reader) (*PublicKey, *PrivateKey, error) { // NewKeyFromSeed derives a public/private key pair using the given seed. // +// Note: does not include the domain separation of ML-KEM (line 1, algorithm 13 +// of FIPS 203). For that use NewKeyFromSeedMLKEM(). +// // Panics if seed is not of length KeySeedSize. func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) { if len(seed) != KeySeedSize { @@ -66,6 +69,21 @@ func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) { return (*PublicKey)(pk), (*PrivateKey)(sk) } +// NewKeyFromSeedMLKEM derives a public/private key pair using the given seed +// using the domain separation of ML-KEM. +// +// Panics if seed is not of length KeySeedSize. +func NewKeyFromSeedMLKEM(seed []byte) (*PublicKey, *PrivateKey) { + if len(seed) != KeySeedSize { + panic("seed must be of length KeySeedSize") + } + var seed2 [33]byte + copy(seed2[:32], seed) + seed2[32] = byte(internal.K) + pk, sk := internal.NewKeyFromSeed(seed2[:]) + return (*PublicKey)(pk), (*PrivateKey)(sk) +} + // EncryptTo encrypts message pt for the public key and writes the ciphertext // to ct using randomness from seed. // diff --git a/pke/kyber/kyber512/kyber.go b/pke/kyber/kyber512/kyber.go index ea9248487..313140828 100644 --- a/pke/kyber/kyber512/kyber.go +++ b/pke/kyber/kyber512/kyber.go @@ -57,6 +57,9 @@ func GenerateKey(rand io.Reader) (*PublicKey, *PrivateKey, error) { // NewKeyFromSeed derives a public/private key pair using the given seed. // +// Note: does not include the domain separation of ML-KEM (line 1, algorithm 13 +// of FIPS 203). For that use NewKeyFromSeedMLKEM(). +// // Panics if seed is not of length KeySeedSize. func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) { if len(seed) != KeySeedSize { @@ -66,6 +69,21 @@ func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) { return (*PublicKey)(pk), (*PrivateKey)(sk) } +// NewKeyFromSeedMLKEM derives a public/private key pair using the given seed +// using the domain separation of ML-KEM. +// +// Panics if seed is not of length KeySeedSize. +func NewKeyFromSeedMLKEM(seed []byte) (*PublicKey, *PrivateKey) { + if len(seed) != KeySeedSize { + panic("seed must be of length KeySeedSize") + } + var seed2 [33]byte + copy(seed2[:32], seed) + seed2[32] = byte(internal.K) + pk, sk := internal.NewKeyFromSeed(seed2[:]) + return (*PublicKey)(pk), (*PrivateKey)(sk) +} + // EncryptTo encrypts message pt for the public key and writes the ciphertext // to ct using randomness from seed. // diff --git a/pke/kyber/kyber768/kyber.go b/pke/kyber/kyber768/kyber.go index 4cecbb1b8..66f7ebb37 100644 --- a/pke/kyber/kyber768/kyber.go +++ b/pke/kyber/kyber768/kyber.go @@ -57,6 +57,9 @@ func GenerateKey(rand io.Reader) (*PublicKey, *PrivateKey, error) { // NewKeyFromSeed derives a public/private key pair using the given seed. // +// Note: does not include the domain separation of ML-KEM (line 1, algorithm 13 +// of FIPS 203). For that use NewKeyFromSeedMLKEM(). +// // Panics if seed is not of length KeySeedSize. func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) { if len(seed) != KeySeedSize { @@ -66,6 +69,21 @@ func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) { return (*PublicKey)(pk), (*PrivateKey)(sk) } +// NewKeyFromSeedMLKEM derives a public/private key pair using the given seed +// using the domain separation of ML-KEM. +// +// Panics if seed is not of length KeySeedSize. +func NewKeyFromSeedMLKEM(seed []byte) (*PublicKey, *PrivateKey) { + if len(seed) != KeySeedSize { + panic("seed must be of length KeySeedSize") + } + var seed2 [33]byte + copy(seed2[:32], seed) + seed2[32] = byte(internal.K) + pk, sk := internal.NewKeyFromSeed(seed2[:]) + return (*PublicKey)(pk), (*PrivateKey)(sk) +} + // EncryptTo encrypts message pt for the public key and writes the ciphertext // to ct using randomness from seed. // diff --git a/pke/kyber/templates/pkg.templ.go b/pke/kyber/templates/pkg.templ.go index 4afec38b9..3834ddc0d 100644 --- a/pke/kyber/templates/pkg.templ.go +++ b/pke/kyber/templates/pkg.templ.go @@ -61,6 +61,9 @@ func GenerateKey(rand io.Reader) (*PublicKey, *PrivateKey, error) { // NewKeyFromSeed derives a public/private key pair using the given seed. // +// Note: does not include the domain separation of ML-KEM (line 1, algorithm 13 +// of FIPS 203). For that use NewKeyFromSeedMLKEM(). +// // Panics if seed is not of length KeySeedSize. func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) { if len(seed) != KeySeedSize { @@ -70,6 +73,21 @@ func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) { return (*PublicKey)(pk), (*PrivateKey)(sk) } +// NewKeyFromSeedMLKEM derives a public/private key pair using the given seed +// using the domain separation of ML-KEM. +// +// Panics if seed is not of length KeySeedSize. +func NewKeyFromSeedMLKEM(seed []byte) (*PublicKey, *PrivateKey) { + if len(seed) != KeySeedSize { + panic("seed must be of length KeySeedSize") + } + var seed2 [33]byte + copy(seed2[:32], seed) + seed2[32] = byte(internal.K) + pk, sk := internal.NewKeyFromSeed(seed2[:]) + return (*PublicKey)(pk), (*PrivateKey)(sk) +} + // EncryptTo encrypts message pt for the public key and writes the ciphertext // to ct using randomness from seed. //