// Decrypted message: |-()-|
}
+func TestRoundTrip(t *testing.T) {
+ kems := []KEM{
+ DHKEM(ecdh.P256()),
+ DHKEM(ecdh.P384()),
+ DHKEM(ecdh.P521()),
+ DHKEM(ecdh.X25519()),
+ MLKEM768(),
+ MLKEM1024(),
+ MLKEM768P256(),
+ MLKEM1024P384(),
+ MLKEM768X25519(),
+ }
+ kdfs := []KDF{
+ HKDFSHA256(),
+ HKDFSHA384(),
+ HKDFSHA512(),
+ SHAKE128(),
+ SHAKE256(),
+ }
+ aeads := []AEAD{
+ AES128GCM(),
+ AES256GCM(),
+ ChaCha20Poly1305(),
+ }
+
+ for _, kem := range kems {
+ t.Run(fmt.Sprintf("KEM_%04x", kem.ID()), func(t *testing.T) {
+ k, err := kem.GenerateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+ kb, err := k.Bytes()
+ if err != nil {
+ t.Fatal(err)
+ }
+ kk, err := kem.NewPrivateKey(kb)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, err := kk.Bytes(); err != nil {
+ t.Fatal(err)
+ } else if !bytes.Equal(got, kb) {
+ t.Errorf("re-serialized key mismatch: got %x, want %x", got, kb)
+ }
+ pk, err := kem.NewPublicKey(k.PublicKey().Bytes())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got := pk.Bytes(); !bytes.Equal(got, k.PublicKey().Bytes()) {
+ t.Errorf("re-serialized public key mismatch: got %x, want %x", got, k.PublicKey().Bytes())
+ }
+
+ for _, kdf := range kdfs {
+ t.Run(fmt.Sprintf("KDF_%04x", kdf.ID()), func(t *testing.T) {
+ for _, aead := range aeads {
+ t.Run(fmt.Sprintf("AEAD_%04x", aead.ID()), func(t *testing.T) {
+ c, err := Seal(pk, kdf, aead, []byte("info"), []byte("plaintext"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ p, err := Open(kk, kdf, aead, []byte("info"), c)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(p, []byte("plaintext")) {
+ t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
+ }
+
+ p, err = Open(kk, kdf, aead, []byte("wrong"), c)
+ if err == nil {
+ t.Errorf("expected error when opening with wrong info, got plaintext %x", p)
+ }
+ c[len(c)-1] ^= 0xFF
+ p, err = Open(kk, kdf, aead, []byte("info"), c)
+ if err == nil {
+ t.Errorf("expected error when opening with corrupted ciphertext, got plaintext %x", p)
+ }
+
+ c, err = Seal(k.PublicKey(), kdf, aead, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ p, err = Open(k, kdf, aead, nil, c)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(p) != 0 {
+ t.Errorf("unexpected plaintext: got %x, want empty", p)
+ }
+
+ // Test that Seal and Open don't modify the excess capacity of input
+ // slices. This is a regression test for a bug where decap would
+ // append to the enc slice, corrupting the ciphertext if they shared
+ // a backing array.
+ padSlice := func(b []byte) []byte {
+ s := make([]byte, len(b), len(b)+2000)
+ copy(s, b)
+ for i := len(b); i < cap(s); i++ {
+ s[:cap(s)][i] = 0xAA
+ }
+ return s[:len(b)]
+ }
+ checkSlice := func(name string, s []byte) {
+ for i := len(s); i < cap(s); i++ {
+ if s[:cap(s)][i] != 0xAA {
+ t.Errorf("%s: modified byte at index %d beyond slice length", name, i)
+ return
+ }
+ }
+ }
+
+ infoS := padSlice([]byte("info"))
+ plaintextS := padSlice([]byte("plaintext"))
+ c, err = Seal(pk, kdf, aead, infoS, plaintextS)
+ if err != nil {
+ t.Fatal(err)
+ }
+ checkSlice("Seal info", infoS)
+ checkSlice("Seal plaintext", plaintextS)
+
+ infoO := padSlice([]byte("info"))
+ ciphertextO := padSlice(c)
+ p, err = Open(kk, kdf, aead, infoO, ciphertextO)
+ if err != nil {
+ t.Fatalf("Open with large capacity slices failed: %v", err)
+ }
+ if !bytes.Equal(p, []byte("plaintext")) {
+ t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
+ }
+ checkSlice("Open info", infoO)
+ checkSlice("Open ciphertext", ciphertextO)
+
+ // Also test the Sender.Seal and Recipient.Open methods.
+ infoSender := padSlice([]byte("info"))
+ enc, sender, err := NewSender(pk, kdf, aead, infoSender)
+ if err != nil {
+ t.Fatal(err)
+ }
+ checkSlice("NewSender info", infoSender)
+
+ aadSeal := padSlice([]byte("aad"))
+ plaintextSeal := padSlice([]byte("plaintext"))
+ ct, err := sender.Seal(aadSeal, plaintextSeal)
+ if err != nil {
+ t.Fatal(err)
+ }
+ checkSlice("Sender.Seal aad", aadSeal)
+ checkSlice("Sender.Seal plaintext", plaintextSeal)
+
+ infoRecipient := padSlice([]byte("info"))
+ encPadded := padSlice(enc)
+ recipient, err := NewRecipient(encPadded, kk, kdf, aead, infoRecipient)
+ if err != nil {
+ t.Fatal(err)
+ }
+ checkSlice("NewRecipient info", infoRecipient)
+ checkSlice("NewRecipient enc", encPadded)
+
+ aadOpen := padSlice([]byte("aad"))
+ ctPadded := padSlice(ct)
+ p, err = recipient.Open(aadOpen, ctPadded)
+ if err != nil {
+ t.Fatalf("Recipient.Open failed: %v", err)
+ }
+ if !bytes.Equal(p, []byte("plaintext")) {
+ t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
+ }
+ checkSlice("Recipient.Open aad", aadOpen)
+ checkSlice("Recipient.Open ciphertext", ctPadded)
+ })
+ }
+ })
+ }
+ })
+ }
+}
+
func mustDecodeHex(t *testing.T, in string) []byte {
t.Helper()
b, err := hex.DecodeString(in)