Skip to content

Commit

Permalink
Use internal::RsaPrivateKeyToRsa to validate RSA key pairs for JWT
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 721402906
Change-Id: I8f45b84422a880515bd32d4ec57394abcd3bc14a
  • Loading branch information
morambro authored and copybara-github committed Jan 30, 2025
1 parent 3ac3585 commit f944957
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 191 deletions.
6 changes: 4 additions & 2 deletions tink/jwt/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -484,12 +484,13 @@ cc_library(
"//tink:key",
"//tink:partial_key_access_token",
"//tink:restricted_big_integer",
"//tink/internal:bn_util",
"//tink/internal:rsa_util",
"//tink/internal:ssl_unique_ptr",
"//tink/util:status",
"//tink/util:statusor",
"@boringssl//:crypto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:optional",
],
)
Expand Down Expand Up @@ -588,12 +589,13 @@ cc_library(
"//tink:key",
"//tink:partial_key_access_token",
"//tink:restricted_big_integer",
"//tink/internal:bn_util",
"//tink/internal:rsa_util",
"//tink/internal:ssl_unique_ptr",
"//tink/util:status",
"//tink/util:statusor",
"@boringssl//:crypto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:optional",
],
)
Expand Down
6 changes: 4 additions & 2 deletions tink/jwt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -440,14 +440,15 @@ tink_cc_library(
tink::jwt::jwt_rsa_ssa_pkcs1_public_key
tink::jwt::jwt_signature_private_key
absl::status
absl::statusor
absl::optional
crypto
tink::core::big_integer
tink::core::insecure_secret_key_access
tink::core::key
tink::core::partial_key_access_token
tink::core::restricted_big_integer
tink::internal::bn_util
tink::internal::rsa_util
tink::internal::ssl_unique_ptr
tink::util::status
tink::util::statusor
Expand Down Expand Up @@ -540,14 +541,15 @@ tink_cc_library(
tink::jwt::jwt_rsa_ssa_pss_public_key
tink::jwt::jwt_signature_private_key
absl::status
absl::statusor
absl::optional
crypto
tink::core::big_integer
tink::core::insecure_secret_key_access
tink::core::key
tink::core::partial_key_access_token
tink::core::restricted_big_integer
tink::internal::bn_util
tink::internal::rsa_util
tink::internal::ssl_unique_ptr
tink::util::status
tink::util::statusor
Expand Down
99 changes: 16 additions & 83 deletions tink/jwt/jwt_rsa_ssa_pkcs1_private_key.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@

#include "tink/jwt/jwt_rsa_ssa_pkcs1_private_key.h"

#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "tink/internal/rsa_util.h"
#ifdef OPENSSL_IS_BORINGSSL
#include "openssl/base.h"
#endif
#include "openssl/rsa.h"
#include "tink/big_integer.h"
#include "tink/insecure_secret_key_access.h"
#include "tink/internal/bn_util.h"
#include "tink/internal/ssl_unique_ptr.h"
#include "tink/jwt/jwt_rsa_ssa_pkcs1_public_key.h"
#include "tink/key.h"
Expand All @@ -41,88 +44,18 @@ util::Status ValidateKeyPair(
const RestrictedBigInteger& p, const RestrictedBigInteger& q,
const RestrictedBigInteger& d, const RestrictedBigInteger& dp,
const RestrictedBigInteger& dq, const RestrictedBigInteger& q_inv) {
internal::SslUniquePtr<RSA> rsa(RSA_new());
if (rsa.get() == nullptr) {
return util::Status(absl::StatusCode::kInternal,
"Internal RSA allocation error");
}

util::StatusOr<internal::SslUniquePtr<BIGNUM>> n =
internal::StringToBignum(modulus.GetValue());
if (!n.ok()) {
return n.status();
}

util::StatusOr<internal::SslUniquePtr<BIGNUM>> e =
internal::StringToBignum(public_exponent.GetValue());
if (!e.ok()) {
return e.status();
}

util::StatusOr<internal::SslUniquePtr<BIGNUM>> d_bn =
internal::StringToBignum(d.GetSecret(InsecureSecretKeyAccess::Get()));
if (!d_bn.ok()) {
return d_bn.status();
}

util::StatusOr<internal::SslUniquePtr<BIGNUM>> p_bn =
internal::StringToBignum(p.GetSecret(InsecureSecretKeyAccess::Get()));
if (!p_bn.ok()) {
return p_bn.status();
}
util::StatusOr<internal::SslUniquePtr<BIGNUM>> q_bn =
internal::StringToBignum(q.GetSecret(InsecureSecretKeyAccess::Get()));
if (!q_bn.ok()) {
return q_bn.status();
}

util::StatusOr<internal::SslUniquePtr<BIGNUM>> dp_bn =
internal::StringToBignum(dp.GetSecret(InsecureSecretKeyAccess::Get()));
if (!dp_bn.ok()) {
return dp_bn.status();
}
util::StatusOr<internal::SslUniquePtr<BIGNUM>> dq_bn =
internal::StringToBignum(dq.GetSecret(InsecureSecretKeyAccess::Get()));
if (!dq_bn.ok()) {
return dq_bn.status();
}
util::StatusOr<internal::SslUniquePtr<BIGNUM>> q_inv_bn =
internal::StringToBignum(q_inv.GetSecret(InsecureSecretKeyAccess::Get()));
if (!q_inv_bn.ok()) {
return q_inv_bn.status();
}

// Build RSA key from the given values. The RSA object takes ownership of the
// given values after the call.
if (RSA_set0_key(rsa.get(), n->release(), e->release(), d_bn->release()) !=
1 ||
RSA_set0_factors(rsa.get(), p_bn->release(), q_bn->release()) != 1 ||
RSA_set0_crt_params(rsa.get(), dp_bn->release(), dq_bn->release(),
q_inv_bn->release()) != 1) {
return util::Status(absl::StatusCode::kInternal,
"Internal RSA key loading error");
}

// Validate key.
int check_key_status = RSA_check_key(rsa.get());
if (check_key_status == 0) {
return util::Status(absl::StatusCode::kInvalidArgument,
"RSA key pair is not valid");
}

if (check_key_status == -1) {
return util::Status(absl::StatusCode::kInternal,
"An error ocurred while checking the key");
}

#ifdef OPENSSL_IS_BORINGSSL
if (RSA_check_fips(rsa.get()) == 0) {
return util::Status(absl::StatusCode::kInvalidArgument,
"RSA key pair is not valid in FIPS mode");
}
#endif

return util::OkStatus();
absl::StatusOr<internal::SslUniquePtr<RSA>> rsa =
internal::RsaPrivateKeyToRsa(internal::RsaPrivateKey{
/*n=*/std::string(modulus.GetValue()),
/*e=*/std::string(public_exponent.GetValue()),
/*d=*/d.GetSecretData(InsecureSecretKeyAccess::Get()),
/*p=*/p.GetSecretData(InsecureSecretKeyAccess::Get()),
/*q=*/q.GetSecretData(InsecureSecretKeyAccess::Get()),
/*dp=*/dp.GetSecretData(InsecureSecretKeyAccess::Get()),
/*dq=*/dq.GetSecretData(InsecureSecretKeyAccess::Get()),
/*crt=*/q_inv.GetSecretData(InsecureSecretKeyAccess::Get()),
});
return rsa.status();
}

} // namespace
Expand Down
16 changes: 8 additions & 8 deletions tink/jwt/jwt_rsa_ssa_pkcs1_private_key_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ TEST(JwtRsaSsaPkcs1PrivateKeyTest, BuildPrivateKeyValidatesModulus) {

EXPECT_THAT(private_key_modified_modulus.status(),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("RSA key pair is not valid")));
HasSubstr("Modulus size is")));
}

TEST(JwtRsaSsaPkcs1PrivateKeyTest, BuildPrivateKeyValidatesPrimeP) {
Expand All @@ -397,7 +397,7 @@ TEST(JwtRsaSsaPkcs1PrivateKeyTest, BuildPrivateKeyValidatesPrimeP) {

EXPECT_THAT(private_key_modified_prime_p.status(),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("RSA key pair is not valid")));
HasSubstr("Could not load RSA key")));
}

TEST(JwtRsaSsaPkcs1PrivateKeyTest, BuildPrivateKeyValidatesPrimeQ) {
Expand All @@ -422,7 +422,7 @@ TEST(JwtRsaSsaPkcs1PrivateKeyTest, BuildPrivateKeyValidatesPrimeQ) {

EXPECT_THAT(private_key_modified_prime_q.status(),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("RSA key pair is not valid")));
HasSubstr("Could not load RSA key")));
}

TEST(JwtRsaSsaPkcs1PrivateKeyTest, BuildPrivateKeyValidatesPrimeExponentP) {
Expand All @@ -448,7 +448,7 @@ TEST(JwtRsaSsaPkcs1PrivateKeyTest, BuildPrivateKeyValidatesPrimeExponentP) {

EXPECT_THAT(private_key_modified_prime_exponent_p.status(),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("RSA key pair is not valid")));
HasSubstr("Could not load RSA key")));
}

TEST(JwtRsaSsaPkcs1PrivateKeyTest, BuildPrivateKeyValidatesPrimeExponentQ) {
Expand All @@ -474,7 +474,7 @@ TEST(JwtRsaSsaPkcs1PrivateKeyTest, BuildPrivateKeyValidatesPrimeExponentQ) {

EXPECT_THAT(private_key_modified_prime_exponent_q.status(),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("RSA key pair is not valid")));
HasSubstr("Could not load RSA key")));
}

TEST(JwtRsaSsaPkcs1PrivateKeyTest, BuildPrivateKeyValidatesPrivateExponent) {
Expand All @@ -500,7 +500,7 @@ TEST(JwtRsaSsaPkcs1PrivateKeyTest, BuildPrivateKeyValidatesPrivateExponent) {

EXPECT_THAT(private_key_modified_private_exponent.status(),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("RSA key pair is not valid")));
HasSubstr("Could not load RSA key")));
}

TEST(JwtRsaSsaPkcs1PrivateKeyTest, BuildPrivateKeyValidatesCrtCoefficient) {
Expand All @@ -526,7 +526,7 @@ TEST(JwtRsaSsaPkcs1PrivateKeyTest, BuildPrivateKeyValidatesCrtCoefficient) {

EXPECT_THAT(private_key_modified_crt_coefficient.status(),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("RSA key pair is not valid")));
HasSubstr("Could not load RSA key")));
}

TEST(JwtRsaSsaPkcs1PrivateKeyTest, BuildPublicKeyNotSetFails) {
Expand Down Expand Up @@ -727,7 +727,7 @@ TEST(JwtRsaSsaPkcs1PrivateKeyTest, CreateMismatchedKeyPairFails) {

EXPECT_THAT(private_key.status(),
StatusIs(absl::StatusCode ::kInvalidArgument,
HasSubstr("RSA key pair is not valid")));
HasSubstr("Could not load RSA key")));
}

TEST_P(JwtRsaSsaPkcs1PrivateKeyTest, PrivateKeyEquals) {
Expand Down
Loading

0 comments on commit f944957

Please sign in to comment.