]> git.feebdaed.xyz Git - 0xmirror/go.git/commitdiff
crypto/hpke: don't corrupt enc's excess capacity in DHKEM decap
authorFilippo Valsorda <filippo@golang.org>
Thu, 9 Jan 2025 15:56:37 +0000 (16:56 +0100)
committerGopher Robot <gobot@golang.org>
Wed, 10 Dec 2025 21:45:53 +0000 (13:45 -0800)
Caught because the one-shop APIs put the ciphertext after enc in a
single slice, so Recipient.Open would corrupt the ciphertext.

Change-Id: I15fe1dfcc05a5a7f5cd0b4ada21661e66a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/728500
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>

src/crypto/hpke/hpke_test.go
src/crypto/hpke/kem.go

index b54a234fe55221d15c5e154fed51e0841b4afff8..ceb33263a6671718fba70172bec104f693b752ee 100644 (file)
@@ -69,6 +69,183 @@ func Example() {
        // Decrypted message: |-()-|
 }
 
+func TestRoundTrip(t *testing.T) {
+       kems := []KEM{
+               DHKEM(ecdh.P256()),
+               DHKEM(ecdh.P384()),
+               DHKEM(ecdh.P521()),
+               DHKEM(ecdh.X25519()),
+               MLKEM768(),
+               MLKEM1024(),
+               MLKEM768P256(),
+               MLKEM1024P384(),
+               MLKEM768X25519(),
+       }
+       kdfs := []KDF{
+               HKDFSHA256(),
+               HKDFSHA384(),
+               HKDFSHA512(),
+               SHAKE128(),
+               SHAKE256(),
+       }
+       aeads := []AEAD{
+               AES128GCM(),
+               AES256GCM(),
+               ChaCha20Poly1305(),
+       }
+
+       for _, kem := range kems {
+               t.Run(fmt.Sprintf("KEM_%04x", kem.ID()), func(t *testing.T) {
+                       k, err := kem.GenerateKey()
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       kb, err := k.Bytes()
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       kk, err := kem.NewPrivateKey(kb)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       if got, err := kk.Bytes(); err != nil {
+                               t.Fatal(err)
+                       } else if !bytes.Equal(got, kb) {
+                               t.Errorf("re-serialized key mismatch: got %x, want %x", got, kb)
+                       }
+                       pk, err := kem.NewPublicKey(k.PublicKey().Bytes())
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       if got := pk.Bytes(); !bytes.Equal(got, k.PublicKey().Bytes()) {
+                               t.Errorf("re-serialized public key mismatch: got %x, want %x", got, k.PublicKey().Bytes())
+                       }
+
+                       for _, kdf := range kdfs {
+                               t.Run(fmt.Sprintf("KDF_%04x", kdf.ID()), func(t *testing.T) {
+                                       for _, aead := range aeads {
+                                               t.Run(fmt.Sprintf("AEAD_%04x", aead.ID()), func(t *testing.T) {
+                                                       c, err := Seal(pk, kdf, aead, []byte("info"), []byte("plaintext"))
+                                                       if err != nil {
+                                                               t.Fatal(err)
+                                                       }
+                                                       p, err := Open(kk, kdf, aead, []byte("info"), c)
+                                                       if err != nil {
+                                                               t.Fatal(err)
+                                                       }
+                                                       if !bytes.Equal(p, []byte("plaintext")) {
+                                                               t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
+                                                       }
+
+                                                       p, err = Open(kk, kdf, aead, []byte("wrong"), c)
+                                                       if err == nil {
+                                                               t.Errorf("expected error when opening with wrong info, got plaintext %x", p)
+                                                       }
+                                                       c[len(c)-1] ^= 0xFF
+                                                       p, err = Open(kk, kdf, aead, []byte("info"), c)
+                                                       if err == nil {
+                                                               t.Errorf("expected error when opening with corrupted ciphertext, got plaintext %x", p)
+                                                       }
+
+                                                       c, err = Seal(k.PublicKey(), kdf, aead, nil, nil)
+                                                       if err != nil {
+                                                               t.Fatal(err)
+                                                       }
+                                                       p, err = Open(k, kdf, aead, nil, c)
+                                                       if err != nil {
+                                                               t.Fatal(err)
+                                                       }
+                                                       if len(p) != 0 {
+                                                               t.Errorf("unexpected plaintext: got %x, want empty", p)
+                                                       }
+
+                                                       // Test that Seal and Open don't modify the excess capacity of input
+                                                       // slices. This is a regression test for a bug where decap would
+                                                       // append to the enc slice, corrupting the ciphertext if they shared
+                                                       // a backing array.
+                                                       padSlice := func(b []byte) []byte {
+                                                               s := make([]byte, len(b), len(b)+2000)
+                                                               copy(s, b)
+                                                               for i := len(b); i < cap(s); i++ {
+                                                                       s[:cap(s)][i] = 0xAA
+                                                               }
+                                                               return s[:len(b)]
+                                                       }
+                                                       checkSlice := func(name string, s []byte) {
+                                                               for i := len(s); i < cap(s); i++ {
+                                                                       if s[:cap(s)][i] != 0xAA {
+                                                                               t.Errorf("%s: modified byte at index %d beyond slice length", name, i)
+                                                                               return
+                                                                       }
+                                                               }
+                                                       }
+
+                                                       infoS := padSlice([]byte("info"))
+                                                       plaintextS := padSlice([]byte("plaintext"))
+                                                       c, err = Seal(pk, kdf, aead, infoS, plaintextS)
+                                                       if err != nil {
+                                                               t.Fatal(err)
+                                                       }
+                                                       checkSlice("Seal info", infoS)
+                                                       checkSlice("Seal plaintext", plaintextS)
+
+                                                       infoO := padSlice([]byte("info"))
+                                                       ciphertextO := padSlice(c)
+                                                       p, err = Open(kk, kdf, aead, infoO, ciphertextO)
+                                                       if err != nil {
+                                                               t.Fatalf("Open with large capacity slices failed: %v", err)
+                                                       }
+                                                       if !bytes.Equal(p, []byte("plaintext")) {
+                                                               t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
+                                                       }
+                                                       checkSlice("Open info", infoO)
+                                                       checkSlice("Open ciphertext", ciphertextO)
+
+                                                       // Also test the Sender.Seal and Recipient.Open methods.
+                                                       infoSender := padSlice([]byte("info"))
+                                                       enc, sender, err := NewSender(pk, kdf, aead, infoSender)
+                                                       if err != nil {
+                                                               t.Fatal(err)
+                                                       }
+                                                       checkSlice("NewSender info", infoSender)
+
+                                                       aadSeal := padSlice([]byte("aad"))
+                                                       plaintextSeal := padSlice([]byte("plaintext"))
+                                                       ct, err := sender.Seal(aadSeal, plaintextSeal)
+                                                       if err != nil {
+                                                               t.Fatal(err)
+                                                       }
+                                                       checkSlice("Sender.Seal aad", aadSeal)
+                                                       checkSlice("Sender.Seal plaintext", plaintextSeal)
+
+                                                       infoRecipient := padSlice([]byte("info"))
+                                                       encPadded := padSlice(enc)
+                                                       recipient, err := NewRecipient(encPadded, kk, kdf, aead, infoRecipient)
+                                                       if err != nil {
+                                                               t.Fatal(err)
+                                                       }
+                                                       checkSlice("NewRecipient info", infoRecipient)
+                                                       checkSlice("NewRecipient enc", encPadded)
+
+                                                       aadOpen := padSlice([]byte("aad"))
+                                                       ctPadded := padSlice(ct)
+                                                       p, err = recipient.Open(aadOpen, ctPadded)
+                                                       if err != nil {
+                                                               t.Fatalf("Recipient.Open failed: %v", err)
+                                                       }
+                                                       if !bytes.Equal(p, []byte("plaintext")) {
+                                                               t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
+                                                       }
+                                                       checkSlice("Recipient.Open aad", aadOpen)
+                                                       checkSlice("Recipient.Open ciphertext", ctPadded)
+                                               })
+                                       }
+                               })
+                       }
+               })
+       }
+}
+
 func mustDecodeHex(t *testing.T, in string) []byte {
        t.Helper()
        b, err := hex.DecodeString(in)
index 7633aa2b714ca9ae46f7e1136a07776530538526..132e0a754c63f37e23c2874a2b3ba947df988254 100644 (file)
@@ -9,6 +9,7 @@ import (
        "crypto/internal/rand"
        "errors"
        "internal/byteorder"
+       "slices"
 )
 
 // A KEM is a Key Encapsulation Mechanism, one of the three components of an
@@ -377,6 +378,6 @@ func (k *dhKEMPrivateKey) decap(encPubEph []byte) ([]byte, error) {
        if err != nil {
                return nil, err
        }
-       kemContext := append(encPubEph, k.priv.PublicKey().Bytes()...)
+       kemContext := append(slices.Clip(encPubEph), k.priv.PublicKey().Bytes()...)
        return k.kem.extractAndExpand(dhVal, kemContext)
 }