diff --git a/age.go b/age.go index 810ba8c3..67d19a85 100644 --- a/age.go +++ b/age.go @@ -48,26 +48,40 @@ import ( "filippo.io/age/internal/stream" ) -// An Identity is a private key or other value that can decrypt an opaque file -// key from a recipient stanza. +// An Identity is passed to Decrypt to unwrap an opaque file key from a +// recipient stanza. It can be for example a secret key like X25519Identity, a +// plugin, or a custom implementation. // -// Unwrap must return an error wrapping ErrIncorrectIdentity for recipient -// stanzas that don't match the identity, any other error will be considered +// Unwrap must return an error wrapping ErrIncorrectIdentity if none of the +// recipient stanzas match the identity, any other error will be considered // fatal. +// +// Most age API users won't need to interact with this directly, and should +// instead pass Recipient implementations to Encrypt and Identity +// implementations to Decrypt. type Identity interface { - Unwrap(block *Stanza) (fileKey []byte, err error) + Unwrap(stanzas []*Stanza) (fileKey []byte, err error) } var ErrIncorrectIdentity = errors.New("incorrect identity for recipient block") -// A Recipient is a public key or other value that can encrypt an opaque file -// key to a recipient stanza. +// A Recipient is passed to Encrypt to wrap an opaque file key to one or more +// recipient stanza(s). It can be for example a public key like X25519Recipient, +// a plugin, or a custom implementation. +// +// Most age API users won't need to interact with this directly, and should +// instead pass Recipient implementations to Encrypt and Identity +// implementations to Decrypt. type Recipient interface { - Wrap(fileKey []byte) (*Stanza, error) + Wrap(fileKey []byte) ([]*Stanza, error) } // A Stanza is a section of the age header that encapsulates the file key as // encrypted to a specific recipient. +// +// Most age API users won't need to interact with this directly, and should +// instead pass Recipient implementations to Encrypt and Identity +// implementations to Decrypt. type Stanza struct { Type string Args []string @@ -96,13 +110,16 @@ func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) { hdr := &format.Header{} for i, r := range recipients { - block, err := r.Wrap(fileKey) + stanzas, err := r.Wrap(fileKey) if err != nil { return nil, fmt.Errorf("failed to wrap key for recipient #%d: %v", i, err) } - hdr.Recipients = append(hdr.Recipients, (*format.Stanza)(block)) - - if block.Type == "scrypt" && len(recipients) != 1 { + for _, s := range stanzas { + hdr.Recipients = append(hdr.Recipients, (*format.Stanza)(s)) + } + } + for _, s := range hdr.Recipients { + if s.Type == "scrypt" && len(hdr.Recipients) != 1 { return nil, errors.New("an scrypt recipient must be the only one") } } @@ -155,25 +172,29 @@ func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) { return nil, errors.New("too many recipients") } - errNoMatch := &NoIdentityMatchError{} - var fileKey []byte -RecipientsLoop: for _, r := range hdr.Recipients { if r.Type == "scrypt" && len(hdr.Recipients) != 1 { return nil, errors.New("an scrypt recipient must be the only one") } - for _, i := range identities { - fileKey, err = i.Unwrap((*Stanza)(r)) - if errors.Is(err, ErrIncorrectIdentity) { - errNoMatch.Errors = append(errNoMatch.Errors, err) - continue - } - if err != nil { - return nil, err - } - - break RecipientsLoop + } + + stanzas := make([]*Stanza, 0, len(hdr.Recipients)) + for _, s := range hdr.Recipients { + stanzas = append(stanzas, (*Stanza)(s)) + } + errNoMatch := &NoIdentityMatchError{} + var fileKey []byte + for _, id := range identities { + fileKey, err = id.Unwrap(stanzas) + if errors.Is(err, ErrIncorrectIdentity) { + errNoMatch.Errors = append(errNoMatch.Errors, err) + continue + } + if err != nil { + return nil, err } + + break } if fileKey == nil { return nil, errNoMatch @@ -192,3 +213,22 @@ RecipientsLoop: return stream.NewReader(streamKey(fileKey, nonce), payload) } + +// multiUnwrap is a helper that implements Identity.Unwrap in terms of a +// function that unwraps a single recipient stanza. +func multiUnwrap(unwrap func(*Stanza) ([]byte, error), stanzas []*Stanza) ([]byte, error) { + for _, s := range stanzas { + fileKey, err := unwrap(s) + if errors.Is(err, ErrIncorrectIdentity) { + // If we ever start returning something interesting wrapping + // ErrIncorrectIdentity, we should let it make its way up through + // Decrypt into NoIdentityMatchError.Errors. + continue + } + if err != nil { + return nil, err + } + return fileKey, nil + } + return nil, ErrIncorrectIdentity +} diff --git a/agessh/agessh.go b/agessh/agessh.go index dc457ab9..8096032f 100644 --- a/agessh/agessh.go +++ b/agessh/agessh.go @@ -68,7 +68,7 @@ func NewRSARecipient(pk ssh.PublicKey) (*RSARecipient, error) { return r, nil } -func (r *RSARecipient) Wrap(fileKey []byte) (*age.Stanza, error) { +func (r *RSARecipient) Wrap(fileKey []byte) ([]*age.Stanza, error) { l := &age.Stanza{ Type: "ssh-rsa", Args: []string{sshFingerprint(r.sshKey)}, @@ -81,7 +81,7 @@ func (r *RSARecipient) Wrap(fileKey []byte) (*age.Stanza, error) { } l.Body = wrappedKey - return l, nil + return []*age.Stanza{l}, nil } type RSAIdentity struct { @@ -102,7 +102,11 @@ func NewRSAIdentity(key *rsa.PrivateKey) (*RSAIdentity, error) { return i, nil } -func (i *RSAIdentity) Unwrap(block *age.Stanza) ([]byte, error) { +func (i *RSAIdentity) Unwrap(stanzas []*age.Stanza) ([]byte, error) { + return multiUnwrap(i.unwrap, stanzas) +} + +func (i *RSAIdentity) unwrap(block *age.Stanza) ([]byte, error) { if block.Type != "ssh-rsa" { return nil, age.ErrIncorrectIdentity } @@ -187,7 +191,7 @@ func ed25519PublicKeyToCurve25519(pk ed25519.PublicKey) ([]byte, error) { const ed25519Label = "age-encryption.org/v1/ssh-ed25519" -func (r *Ed25519Recipient) Wrap(fileKey []byte) (*age.Stanza, error) { +func (r *Ed25519Recipient) Wrap(fileKey []byte) ([]*age.Stanza, error) { ephemeral := make([]byte, curve25519.ScalarSize) if _, err := rand.Read(ephemeral); err != nil { return nil, err @@ -230,7 +234,7 @@ func (r *Ed25519Recipient) Wrap(fileKey []byte) (*age.Stanza, error) { } l.Body = wrappedKey - return l, nil + return []*age.Stanza{l}, nil } type Ed25519Identity struct { @@ -276,7 +280,11 @@ func ed25519PrivateKeyToCurve25519(pk ed25519.PrivateKey) []byte { return out[:curve25519.ScalarSize] } -func (i *Ed25519Identity) Unwrap(block *age.Stanza) ([]byte, error) { +func (i *Ed25519Identity) Unwrap(stanzas []*age.Stanza) ([]byte, error) { + return multiUnwrap(i.unwrap, stanzas) +} + +func (i *Ed25519Identity) unwrap(block *age.Stanza) ([]byte, error) { if block.Type != "ssh-ed25519" { return nil, age.ErrIncorrectIdentity } @@ -323,6 +331,26 @@ func (i *Ed25519Identity) Unwrap(block *age.Stanza) ([]byte, error) { return fileKey, nil } +// multiUnwrap is copied from package age. It's a helper that implements +// Identity.Unwrap in terms of a function that unwraps a single recipient +// stanza. +func multiUnwrap(unwrap func(*age.Stanza) ([]byte, error), stanzas []*age.Stanza) ([]byte, error) { + for _, s := range stanzas { + fileKey, err := unwrap(s) + if errors.Is(err, age.ErrIncorrectIdentity) { + // If we ever start returning something interesting wrapping + // ErrIncorrectIdentity, we should let it make its way up through + // Decrypt into NoIdentityMatchError.Errors. + continue + } + if err != nil { + return nil, err + } + return fileKey, nil + } + return nil, age.ErrIncorrectIdentity +} + // aeadEncrypt and aeadDecrypt are copied from package age. // // They don't limit the file key size because multi-key attacks are irrelevant diff --git a/agessh/agessh_test.go b/agessh/agessh_test.go index a0c25a96..ee6a5dd5 100644 --- a/agessh/agessh_test.go +++ b/agessh/agessh_test.go @@ -14,7 +14,6 @@ import ( "testing" "filippo.io/age/agessh" - "filippo.io/age/internal/format" "golang.org/x/crypto/ssh" ) @@ -41,15 +40,12 @@ func TestSSHRSARoundTrip(t *testing.T) { if _, err := rand.Read(fileKey); err != nil { t.Fatal(err) } - block, err := r.Wrap(fileKey) + stanzas, err := r.Wrap(fileKey) if err != nil { t.Fatal(err) } - b := &bytes.Buffer{} - (*format.Stanza)(block).Marshal(b) - t.Logf("%s", b.Bytes()) - out, err := i.Unwrap(block) + out, err := i.Unwrap(stanzas) if err != nil { t.Fatal(err) } @@ -82,15 +78,12 @@ func TestSSHEd25519RoundTrip(t *testing.T) { if _, err := rand.Read(fileKey); err != nil { t.Fatal(err) } - block, err := r.Wrap(fileKey) + stanzas, err := r.Wrap(fileKey) if err != nil { t.Fatal(err) } - b := &bytes.Buffer{} - (*format.Stanza)(block).Marshal(b) - t.Logf("%s", b.Bytes()) - out, err := i.Unwrap(block) + out, err := i.Unwrap(stanzas) if err != nil { t.Fatal(err) } diff --git a/agessh/encrypted_keys.go b/agessh/encrypted_keys.go index 78757c4b..ebac9dc6 100644 --- a/agessh/encrypted_keys.go +++ b/agessh/encrypted_keys.go @@ -56,20 +56,28 @@ func NewEncryptedSSHIdentity(pubKey ssh.PublicKey, pemBytes []byte, passphrase f var _ age.Identity = &EncryptedSSHIdentity{} // Unwrap implements age.Identity. If the private key is still encrypted, and -// the block matches the public key, it will request the passphrase. The +// any of the stanzas match the public key, it will request the passphrase. The // decrypted private key will be cached after the first successful invocation. -func (i *EncryptedSSHIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err error) { +func (i *EncryptedSSHIdentity) Unwrap(stanzas []*age.Stanza) (fileKey []byte, err error) { if i.decrypted != nil { - return i.decrypted.Unwrap(block) + return i.decrypted.Unwrap(stanzas) } - if block.Type != i.pubKey.Type() { - return nil, age.ErrIncorrectIdentity - } - if len(block.Args) < 1 { - return nil, fmt.Errorf("invalid %v recipient block", i.pubKey.Type()) + var match bool + for _, s := range stanzas { + if s.Type != i.pubKey.Type() { + continue + } + if len(s.Args) < 1 { + return nil, fmt.Errorf("invalid %v recipient block", i.pubKey.Type()) + } + if s.Args[0] != sshFingerprint(i.pubKey) { + continue + } + match = true + break } - if block.Args[0] != sshFingerprint(i.pubKey) { + if !match { return nil, age.ErrIncorrectIdentity } @@ -85,6 +93,8 @@ func (i *EncryptedSSHIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err er switch k := k.(type) { case *ed25519.PrivateKey: i.decrypted, err = NewEd25519Identity(*k) + // TODO: here and below, better check that the two public keys match, + // rather than just the type. if i.pubKey.Type() != ssh.KeyAlgoED25519 { return nil, fmt.Errorf("mismatched private (%s) and public (%s) SSH key types", ssh.KeyAlgoED25519, i.pubKey.Type()) } @@ -100,5 +110,5 @@ func (i *EncryptedSSHIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err er return nil, fmt.Errorf("invalid SSH key: %v", err) } - return i.decrypted.Unwrap(block) + return i.decrypted.Unwrap(stanzas) } diff --git a/cmd/age/encrypted_keys.go b/cmd/age/encrypted_keys.go index 7b8b54af..ed776515 100644 --- a/cmd/age/encrypted_keys.go +++ b/cmd/age/encrypted_keys.go @@ -21,8 +21,8 @@ type LazyScryptIdentity struct { var _ age.Identity = &LazyScryptIdentity{} -func (i *LazyScryptIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err error) { - if block.Type != "scrypt" { +func (i *LazyScryptIdentity) Unwrap(stanzas []*age.Stanza) (fileKey []byte, err error) { + if len(stanzas) != 1 || stanzas[0].Type != "scrypt" { return nil, age.ErrIncorrectIdentity } pass, err := i.Passphrase() @@ -33,7 +33,7 @@ func (i *LazyScryptIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err erro if err != nil { return nil, err } - fileKey, err = ii.Unwrap(block) + fileKey, err = ii.Unwrap(stanzas) if errors.Is(err, age.ErrIncorrectIdentity) { // ScryptIdentity returns ErrIncorrectIdentity for an incorrect // passphrase, which would lead Decrypt to returning "no identity diff --git a/recipients_test.go b/recipients_test.go index 462e2bee..70bee16f 100644 --- a/recipients_test.go +++ b/recipients_test.go @@ -12,7 +12,6 @@ import ( "testing" "filippo.io/age" - "filippo.io/age/internal/format" ) func TestX25519RoundTrip(t *testing.T) { @@ -37,15 +36,12 @@ func TestX25519RoundTrip(t *testing.T) { if _, err := rand.Read(fileKey); err != nil { t.Fatal(err) } - block, err := r.Wrap(fileKey) + stanzas, err := r.Wrap(fileKey) if err != nil { t.Fatal(err) } - b := &bytes.Buffer{} - (*format.Stanza)(block).Marshal(b) - t.Logf("%s", b.Bytes()) - out, err := i.Unwrap(block) + out, err := i.Unwrap(stanzas) if err != nil { t.Fatal(err) } @@ -72,15 +68,12 @@ func TestScryptRoundTrip(t *testing.T) { if _, err := rand.Read(fileKey); err != nil { t.Fatal(err) } - block, err := r.Wrap(fileKey) + stanzas, err := r.Wrap(fileKey) if err != nil { t.Fatal(err) } - b := &bytes.Buffer{} - (*format.Stanza)(block).Marshal(b) - t.Logf("%s", b.Bytes()) - out, err := i.Unwrap(block) + out, err := i.Unwrap(stanzas) if err != nil { t.Fatal(err) } diff --git a/scrypt.go b/scrypt.go index e7c50009..4ce5d356 100644 --- a/scrypt.go +++ b/scrypt.go @@ -61,7 +61,7 @@ func (r *ScryptRecipient) SetWorkFactor(logN int) { const scryptSaltSize = 16 -func (r *ScryptRecipient) Wrap(fileKey []byte) (*Stanza, error) { +func (r *ScryptRecipient) Wrap(fileKey []byte) ([]*Stanza, error) { salt := make([]byte, scryptSaltSize) if _, err := rand.Read(salt[:]); err != nil { return nil, err @@ -85,7 +85,7 @@ func (r *ScryptRecipient) Wrap(fileKey []byte) (*Stanza, error) { } l.Body = wrappedKey - return l, nil + return []*Stanza{l}, nil } // ScryptIdentity is a password-based identity. @@ -121,7 +121,11 @@ func (i *ScryptIdentity) SetMaxWorkFactor(logN int) { i.maxWorkFactor = logN } -func (i *ScryptIdentity) Unwrap(block *Stanza) ([]byte, error) { +func (i *ScryptIdentity) Unwrap(stanzas []*Stanza) ([]byte, error) { + return multiUnwrap(i.unwrap, stanzas) +} + +func (i *ScryptIdentity) unwrap(block *Stanza) ([]byte, error) { if block.Type != "scrypt" { return nil, ErrIncorrectIdentity } diff --git a/x25519.go b/x25519.go index 1a363a99..3cda6d2b 100644 --- a/x25519.go +++ b/x25519.go @@ -63,7 +63,7 @@ func ParseX25519Recipient(s string) (*X25519Recipient, error) { return r, nil } -func (r *X25519Recipient) Wrap(fileKey []byte) (*Stanza, error) { +func (r *X25519Recipient) Wrap(fileKey []byte) ([]*Stanza, error) { ephemeral := make([]byte, curve25519.ScalarSize) if _, err := rand.Read(ephemeral); err != nil { return nil, err @@ -98,7 +98,7 @@ func (r *X25519Recipient) Wrap(fileKey []byte) (*Stanza, error) { } l.Body = wrappedKey - return l, nil + return []*Stanza{l}, nil } // String returns the Bech32 public key encoding of r. @@ -154,7 +154,11 @@ func ParseX25519Identity(s string) (*X25519Identity, error) { return r, nil } -func (i *X25519Identity) Unwrap(block *Stanza) ([]byte, error) { +func (i *X25519Identity) Unwrap(stanzas []*Stanza) ([]byte, error) { + return multiUnwrap(i.unwrap, stanzas) +} + +func (i *X25519Identity) unwrap(block *Stanza) ([]byte, error) { if block.Type != "X25519" { return nil, ErrIncorrectIdentity }