From 5d96bfa9a9e67f6f49b3c25f6b0a514cc0b0328a Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Sun, 31 Jan 2021 21:59:06 +0100 Subject: [PATCH] age: make Identity and Recipient work on multiple stanzas This is a breaking change, but like the other changes to these interfaces it should not matter to consumers of the API that don't implement custom Recipients or Identities, which is all of them so far, as far as I can tell. It became clear working on plugins that we might want Recipient to return multiple recipient stanzas, for example if the plugin recipient is an alias or a group. The Identity side is less important, but it might help avoid round-trips and it makes sense to keep things symmetric. --- age.go | 92 ++++++++++++++++++++++++++++----------- agessh/agessh.go | 40 ++++++++++++++--- agessh/agessh_test.go | 15 ++----- agessh/encrypted_keys.go | 30 ++++++++----- cmd/age/encrypted_keys.go | 6 +-- recipients_test.go | 15 ++----- scrypt.go | 10 +++-- x25519.go | 10 +++-- 8 files changed, 145 insertions(+), 73 deletions(-) diff --git a/age.go b/age.go index 810ba8c3..67d19a85 100644 --- a/age.go +++ b/age.go @@ -48,26 +48,40 @@ import ( "filippo.io/age/internal/stream" ) -// An Identity is a private key or other value that can decrypt an opaque file -// key from a recipient stanza. +// An Identity is passed to Decrypt to unwrap an opaque file key from a +// recipient stanza. It can be for example a secret key like X25519Identity, a +// plugin, or a custom implementation. // -// Unwrap must return an error wrapping ErrIncorrectIdentity for recipient -// stanzas that don't match the identity, any other error will be considered +// Unwrap must return an error wrapping ErrIncorrectIdentity if none of the +// recipient stanzas match the identity, any other error will be considered // fatal. +// +// Most age API users won't need to interact with this directly, and should +// instead pass Recipient implementations to Encrypt and Identity +// implementations to Decrypt. type Identity interface { - Unwrap(block *Stanza) (fileKey []byte, err error) + Unwrap(stanzas []*Stanza) (fileKey []byte, err error) } var ErrIncorrectIdentity = errors.New("incorrect identity for recipient block") -// A Recipient is a public key or other value that can encrypt an opaque file -// key to a recipient stanza. +// A Recipient is passed to Encrypt to wrap an opaque file key to one or more +// recipient stanza(s). It can be for example a public key like X25519Recipient, +// a plugin, or a custom implementation. +// +// Most age API users won't need to interact with this directly, and should +// instead pass Recipient implementations to Encrypt and Identity +// implementations to Decrypt. type Recipient interface { - Wrap(fileKey []byte) (*Stanza, error) + Wrap(fileKey []byte) ([]*Stanza, error) } // A Stanza is a section of the age header that encapsulates the file key as // encrypted to a specific recipient. +// +// Most age API users won't need to interact with this directly, and should +// instead pass Recipient implementations to Encrypt and Identity +// implementations to Decrypt. type Stanza struct { Type string Args []string @@ -96,13 +110,16 @@ func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) { hdr := &format.Header{} for i, r := range recipients { - block, err := r.Wrap(fileKey) + stanzas, err := r.Wrap(fileKey) if err != nil { return nil, fmt.Errorf("failed to wrap key for recipient #%d: %v", i, err) } - hdr.Recipients = append(hdr.Recipients, (*format.Stanza)(block)) - - if block.Type == "scrypt" && len(recipients) != 1 { + for _, s := range stanzas { + hdr.Recipients = append(hdr.Recipients, (*format.Stanza)(s)) + } + } + for _, s := range hdr.Recipients { + if s.Type == "scrypt" && len(hdr.Recipients) != 1 { return nil, errors.New("an scrypt recipient must be the only one") } } @@ -155,25 +172,29 @@ func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) { return nil, errors.New("too many recipients") } - errNoMatch := &NoIdentityMatchError{} - var fileKey []byte -RecipientsLoop: for _, r := range hdr.Recipients { if r.Type == "scrypt" && len(hdr.Recipients) != 1 { return nil, errors.New("an scrypt recipient must be the only one") } - for _, i := range identities { - fileKey, err = i.Unwrap((*Stanza)(r)) - if errors.Is(err, ErrIncorrectIdentity) { - errNoMatch.Errors = append(errNoMatch.Errors, err) - continue - } - if err != nil { - return nil, err - } - - break RecipientsLoop + } + + stanzas := make([]*Stanza, 0, len(hdr.Recipients)) + for _, s := range hdr.Recipients { + stanzas = append(stanzas, (*Stanza)(s)) + } + errNoMatch := &NoIdentityMatchError{} + var fileKey []byte + for _, id := range identities { + fileKey, err = id.Unwrap(stanzas) + if errors.Is(err, ErrIncorrectIdentity) { + errNoMatch.Errors = append(errNoMatch.Errors, err) + continue + } + if err != nil { + return nil, err } + + break } if fileKey == nil { return nil, errNoMatch @@ -192,3 +213,22 @@ RecipientsLoop: return stream.NewReader(streamKey(fileKey, nonce), payload) } + +// multiUnwrap is a helper that implements Identity.Unwrap in terms of a +// function that unwraps a single recipient stanza. +func multiUnwrap(unwrap func(*Stanza) ([]byte, error), stanzas []*Stanza) ([]byte, error) { + for _, s := range stanzas { + fileKey, err := unwrap(s) + if errors.Is(err, ErrIncorrectIdentity) { + // If we ever start returning something interesting wrapping + // ErrIncorrectIdentity, we should let it make its way up through + // Decrypt into NoIdentityMatchError.Errors. + continue + } + if err != nil { + return nil, err + } + return fileKey, nil + } + return nil, ErrIncorrectIdentity +} diff --git a/agessh/agessh.go b/agessh/agessh.go index dc457ab9..8096032f 100644 --- a/agessh/agessh.go +++ b/agessh/agessh.go @@ -68,7 +68,7 @@ func NewRSARecipient(pk ssh.PublicKey) (*RSARecipient, error) { return r, nil } -func (r *RSARecipient) Wrap(fileKey []byte) (*age.Stanza, error) { +func (r *RSARecipient) Wrap(fileKey []byte) ([]*age.Stanza, error) { l := &age.Stanza{ Type: "ssh-rsa", Args: []string{sshFingerprint(r.sshKey)}, @@ -81,7 +81,7 @@ func (r *RSARecipient) Wrap(fileKey []byte) (*age.Stanza, error) { } l.Body = wrappedKey - return l, nil + return []*age.Stanza{l}, nil } type RSAIdentity struct { @@ -102,7 +102,11 @@ func NewRSAIdentity(key *rsa.PrivateKey) (*RSAIdentity, error) { return i, nil } -func (i *RSAIdentity) Unwrap(block *age.Stanza) ([]byte, error) { +func (i *RSAIdentity) Unwrap(stanzas []*age.Stanza) ([]byte, error) { + return multiUnwrap(i.unwrap, stanzas) +} + +func (i *RSAIdentity) unwrap(block *age.Stanza) ([]byte, error) { if block.Type != "ssh-rsa" { return nil, age.ErrIncorrectIdentity } @@ -187,7 +191,7 @@ func ed25519PublicKeyToCurve25519(pk ed25519.PublicKey) ([]byte, error) { const ed25519Label = "age-encryption.org/v1/ssh-ed25519" -func (r *Ed25519Recipient) Wrap(fileKey []byte) (*age.Stanza, error) { +func (r *Ed25519Recipient) Wrap(fileKey []byte) ([]*age.Stanza, error) { ephemeral := make([]byte, curve25519.ScalarSize) if _, err := rand.Read(ephemeral); err != nil { return nil, err @@ -230,7 +234,7 @@ func (r *Ed25519Recipient) Wrap(fileKey []byte) (*age.Stanza, error) { } l.Body = wrappedKey - return l, nil + return []*age.Stanza{l}, nil } type Ed25519Identity struct { @@ -276,7 +280,11 @@ func ed25519PrivateKeyToCurve25519(pk ed25519.PrivateKey) []byte { return out[:curve25519.ScalarSize] } -func (i *Ed25519Identity) Unwrap(block *age.Stanza) ([]byte, error) { +func (i *Ed25519Identity) Unwrap(stanzas []*age.Stanza) ([]byte, error) { + return multiUnwrap(i.unwrap, stanzas) +} + +func (i *Ed25519Identity) unwrap(block *age.Stanza) ([]byte, error) { if block.Type != "ssh-ed25519" { return nil, age.ErrIncorrectIdentity } @@ -323,6 +331,26 @@ func (i *Ed25519Identity) Unwrap(block *age.Stanza) ([]byte, error) { return fileKey, nil } +// multiUnwrap is copied from package age. It's a helper that implements +// Identity.Unwrap in terms of a function that unwraps a single recipient +// stanza. +func multiUnwrap(unwrap func(*age.Stanza) ([]byte, error), stanzas []*age.Stanza) ([]byte, error) { + for _, s := range stanzas { + fileKey, err := unwrap(s) + if errors.Is(err, age.ErrIncorrectIdentity) { + // If we ever start returning something interesting wrapping + // ErrIncorrectIdentity, we should let it make its way up through + // Decrypt into NoIdentityMatchError.Errors. + continue + } + if err != nil { + return nil, err + } + return fileKey, nil + } + return nil, age.ErrIncorrectIdentity +} + // aeadEncrypt and aeadDecrypt are copied from package age. // // They don't limit the file key size because multi-key attacks are irrelevant diff --git a/agessh/agessh_test.go b/agessh/agessh_test.go index a0c25a96..ee6a5dd5 100644 --- a/agessh/agessh_test.go +++ b/agessh/agessh_test.go @@ -14,7 +14,6 @@ import ( "testing" "filippo.io/age/agessh" - "filippo.io/age/internal/format" "golang.org/x/crypto/ssh" ) @@ -41,15 +40,12 @@ func TestSSHRSARoundTrip(t *testing.T) { if _, err := rand.Read(fileKey); err != nil { t.Fatal(err) } - block, err := r.Wrap(fileKey) + stanzas, err := r.Wrap(fileKey) if err != nil { t.Fatal(err) } - b := &bytes.Buffer{} - (*format.Stanza)(block).Marshal(b) - t.Logf("%s", b.Bytes()) - out, err := i.Unwrap(block) + out, err := i.Unwrap(stanzas) if err != nil { t.Fatal(err) } @@ -82,15 +78,12 @@ func TestSSHEd25519RoundTrip(t *testing.T) { if _, err := rand.Read(fileKey); err != nil { t.Fatal(err) } - block, err := r.Wrap(fileKey) + stanzas, err := r.Wrap(fileKey) if err != nil { t.Fatal(err) } - b := &bytes.Buffer{} - (*format.Stanza)(block).Marshal(b) - t.Logf("%s", b.Bytes()) - out, err := i.Unwrap(block) + out, err := i.Unwrap(stanzas) if err != nil { t.Fatal(err) } diff --git a/agessh/encrypted_keys.go b/agessh/encrypted_keys.go index 78757c4b..ebac9dc6 100644 --- a/agessh/encrypted_keys.go +++ b/agessh/encrypted_keys.go @@ -56,20 +56,28 @@ func NewEncryptedSSHIdentity(pubKey ssh.PublicKey, pemBytes []byte, passphrase f var _ age.Identity = &EncryptedSSHIdentity{} // Unwrap implements age.Identity. If the private key is still encrypted, and -// the block matches the public key, it will request the passphrase. The +// any of the stanzas match the public key, it will request the passphrase. The // decrypted private key will be cached after the first successful invocation. -func (i *EncryptedSSHIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err error) { +func (i *EncryptedSSHIdentity) Unwrap(stanzas []*age.Stanza) (fileKey []byte, err error) { if i.decrypted != nil { - return i.decrypted.Unwrap(block) + return i.decrypted.Unwrap(stanzas) } - if block.Type != i.pubKey.Type() { - return nil, age.ErrIncorrectIdentity - } - if len(block.Args) < 1 { - return nil, fmt.Errorf("invalid %v recipient block", i.pubKey.Type()) + var match bool + for _, s := range stanzas { + if s.Type != i.pubKey.Type() { + continue + } + if len(s.Args) < 1 { + return nil, fmt.Errorf("invalid %v recipient block", i.pubKey.Type()) + } + if s.Args[0] != sshFingerprint(i.pubKey) { + continue + } + match = true + break } - if block.Args[0] != sshFingerprint(i.pubKey) { + if !match { return nil, age.ErrIncorrectIdentity } @@ -85,6 +93,8 @@ func (i *EncryptedSSHIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err er switch k := k.(type) { case *ed25519.PrivateKey: i.decrypted, err = NewEd25519Identity(*k) + // TODO: here and below, better check that the two public keys match, + // rather than just the type. if i.pubKey.Type() != ssh.KeyAlgoED25519 { return nil, fmt.Errorf("mismatched private (%s) and public (%s) SSH key types", ssh.KeyAlgoED25519, i.pubKey.Type()) } @@ -100,5 +110,5 @@ func (i *EncryptedSSHIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err er return nil, fmt.Errorf("invalid SSH key: %v", err) } - return i.decrypted.Unwrap(block) + return i.decrypted.Unwrap(stanzas) } diff --git a/cmd/age/encrypted_keys.go b/cmd/age/encrypted_keys.go index 7b8b54af..ed776515 100644 --- a/cmd/age/encrypted_keys.go +++ b/cmd/age/encrypted_keys.go @@ -21,8 +21,8 @@ type LazyScryptIdentity struct { var _ age.Identity = &LazyScryptIdentity{} -func (i *LazyScryptIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err error) { - if block.Type != "scrypt" { +func (i *LazyScryptIdentity) Unwrap(stanzas []*age.Stanza) (fileKey []byte, err error) { + if len(stanzas) != 1 || stanzas[0].Type != "scrypt" { return nil, age.ErrIncorrectIdentity } pass, err := i.Passphrase() @@ -33,7 +33,7 @@ func (i *LazyScryptIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err erro if err != nil { return nil, err } - fileKey, err = ii.Unwrap(block) + fileKey, err = ii.Unwrap(stanzas) if errors.Is(err, age.ErrIncorrectIdentity) { // ScryptIdentity returns ErrIncorrectIdentity for an incorrect // passphrase, which would lead Decrypt to returning "no identity diff --git a/recipients_test.go b/recipients_test.go index 462e2bee..70bee16f 100644 --- a/recipients_test.go +++ b/recipients_test.go @@ -12,7 +12,6 @@ import ( "testing" "filippo.io/age" - "filippo.io/age/internal/format" ) func TestX25519RoundTrip(t *testing.T) { @@ -37,15 +36,12 @@ func TestX25519RoundTrip(t *testing.T) { if _, err := rand.Read(fileKey); err != nil { t.Fatal(err) } - block, err := r.Wrap(fileKey) + stanzas, err := r.Wrap(fileKey) if err != nil { t.Fatal(err) } - b := &bytes.Buffer{} - (*format.Stanza)(block).Marshal(b) - t.Logf("%s", b.Bytes()) - out, err := i.Unwrap(block) + out, err := i.Unwrap(stanzas) if err != nil { t.Fatal(err) } @@ -72,15 +68,12 @@ func TestScryptRoundTrip(t *testing.T) { if _, err := rand.Read(fileKey); err != nil { t.Fatal(err) } - block, err := r.Wrap(fileKey) + stanzas, err := r.Wrap(fileKey) if err != nil { t.Fatal(err) } - b := &bytes.Buffer{} - (*format.Stanza)(block).Marshal(b) - t.Logf("%s", b.Bytes()) - out, err := i.Unwrap(block) + out, err := i.Unwrap(stanzas) if err != nil { t.Fatal(err) } diff --git a/scrypt.go b/scrypt.go index e7c50009..4ce5d356 100644 --- a/scrypt.go +++ b/scrypt.go @@ -61,7 +61,7 @@ func (r *ScryptRecipient) SetWorkFactor(logN int) { const scryptSaltSize = 16 -func (r *ScryptRecipient) Wrap(fileKey []byte) (*Stanza, error) { +func (r *ScryptRecipient) Wrap(fileKey []byte) ([]*Stanza, error) { salt := make([]byte, scryptSaltSize) if _, err := rand.Read(salt[:]); err != nil { return nil, err @@ -85,7 +85,7 @@ func (r *ScryptRecipient) Wrap(fileKey []byte) (*Stanza, error) { } l.Body = wrappedKey - return l, nil + return []*Stanza{l}, nil } // ScryptIdentity is a password-based identity. @@ -121,7 +121,11 @@ func (i *ScryptIdentity) SetMaxWorkFactor(logN int) { i.maxWorkFactor = logN } -func (i *ScryptIdentity) Unwrap(block *Stanza) ([]byte, error) { +func (i *ScryptIdentity) Unwrap(stanzas []*Stanza) ([]byte, error) { + return multiUnwrap(i.unwrap, stanzas) +} + +func (i *ScryptIdentity) unwrap(block *Stanza) ([]byte, error) { if block.Type != "scrypt" { return nil, ErrIncorrectIdentity } diff --git a/x25519.go b/x25519.go index 1a363a99..3cda6d2b 100644 --- a/x25519.go +++ b/x25519.go @@ -63,7 +63,7 @@ func ParseX25519Recipient(s string) (*X25519Recipient, error) { return r, nil } -func (r *X25519Recipient) Wrap(fileKey []byte) (*Stanza, error) { +func (r *X25519Recipient) Wrap(fileKey []byte) ([]*Stanza, error) { ephemeral := make([]byte, curve25519.ScalarSize) if _, err := rand.Read(ephemeral); err != nil { return nil, err @@ -98,7 +98,7 @@ func (r *X25519Recipient) Wrap(fileKey []byte) (*Stanza, error) { } l.Body = wrappedKey - return l, nil + return []*Stanza{l}, nil } // String returns the Bech32 public key encoding of r. @@ -154,7 +154,11 @@ func ParseX25519Identity(s string) (*X25519Identity, error) { return r, nil } -func (i *X25519Identity) Unwrap(block *Stanza) ([]byte, error) { +func (i *X25519Identity) Unwrap(stanzas []*Stanza) ([]byte, error) { + return multiUnwrap(i.unwrap, stanzas) +} + +func (i *X25519Identity) unwrap(block *Stanza) ([]byte, error) { if block.Type != "X25519" { return nil, ErrIncorrectIdentity }