PROV: Add a DER to RSA-PSS deserializer implementation

Reviewed-by: Shane Lontis <shane.lontis@oracle.com>
(Merged from https://github.com/openssl/openssl/pull/12492)
diff --git a/crypto/err/openssl.txt b/crypto/err/openssl.txt
index e5ed28b..a99648a 100644
--- a/crypto/err/openssl.txt
+++ b/crypto/err/openssl.txt
@@ -2878,6 +2878,7 @@
 PROV_R_INVALID_MODE_INT:126:invalid mode int
 PROV_R_INVALID_PADDING_MODE:168:invalid padding mode
 PROV_R_INVALID_PSS_SALTLEN:169:invalid pss saltlen
+PROV_R_INVALID_RSA_KEY:217:invalid rsa key
 PROV_R_INVALID_SALT_LENGTH:112:invalid salt length
 PROV_R_INVALID_SEED_LENGTH:154:invalid seed length
 PROV_R_INVALID_SIGNATURE_SIZE:179:invalid signature size
diff --git a/providers/common/include/prov/providercommonerr.h b/providers/common/include/prov/providercommonerr.h
index f5fd37d..bdc39e4 100644
--- a/providers/common/include/prov/providercommonerr.h
+++ b/providers/common/include/prov/providercommonerr.h
@@ -101,6 +101,7 @@
 # define PROV_R_INVALID_MODE_INT                          126
 # define PROV_R_INVALID_PADDING_MODE                      168
 # define PROV_R_INVALID_PSS_SALTLEN                       169
+# define PROV_R_INVALID_RSA_KEY                           217
 # define PROV_R_INVALID_SALT_LENGTH                       112
 # define PROV_R_INVALID_SEED_LENGTH                       154
 # define PROV_R_INVALID_SIGNATURE_SIZE                    179
diff --git a/providers/common/provider_err.c b/providers/common/provider_err.c
index 7a5c41b..e65ce96 100644
--- a/providers/common/provider_err.c
+++ b/providers/common/provider_err.c
@@ -96,6 +96,7 @@
     "invalid padding mode"},
     {ERR_PACK(ERR_LIB_PROV, 0, PROV_R_INVALID_PSS_SALTLEN),
     "invalid pss saltlen"},
+    {ERR_PACK(ERR_LIB_PROV, 0, PROV_R_INVALID_RSA_KEY), "invalid rsa key"},
     {ERR_PACK(ERR_LIB_PROV, 0, PROV_R_INVALID_SALT_LENGTH),
     "invalid salt length"},
     {ERR_PACK(ERR_LIB_PROV, 0, PROV_R_INVALID_SEED_LENGTH),
diff --git a/providers/defltprov.c b/providers/defltprov.c
index 7ab006a..466b790 100644
--- a/providers/defltprov.c
+++ b/providers/defltprov.c
@@ -537,6 +537,8 @@
 static const OSSL_ALGORITHM deflt_deserializer[] = {
     { "RSA", "provider=default,fips=yes,input=der",
       der_to_rsa_deserializer_functions },
+    { "RSA-PSS", "provider=default,fips=yes,input=der",
+      der_to_rsapss_deserializer_functions },
 
     { "DER", "provider=default,fips=yes,input=pem",
       pem_to_der_deserializer_functions },
diff --git a/providers/implementations/include/prov/implementations.h b/providers/implementations/include/prov/implementations.h
index 4890f11..b02f0c6 100644
--- a/providers/implementations/include/prov/implementations.h
+++ b/providers/implementations/include/prov/implementations.h
@@ -360,4 +360,5 @@
 extern const OSSL_DISPATCH ec_param_pem_serializer_functions[];
 
 extern const OSSL_DISPATCH der_to_rsa_deserializer_functions[];
+extern const OSSL_DISPATCH der_to_rsapss_deserializer_functions[];
 extern const OSSL_DISPATCH pem_to_der_deserializer_functions[];
diff --git a/providers/implementations/keymgmt/rsa_kmgmt.c b/providers/implementations/keymgmt/rsa_kmgmt.c
index 21a35d7..7ed280e 100644
--- a/providers/implementations/keymgmt/rsa_kmgmt.c
+++ b/providers/implementations/keymgmt/rsa_kmgmt.c
@@ -628,6 +628,7 @@
       (void (*)(void))rsapss_gen_settable_params },
     { OSSL_FUNC_KEYMGMT_GEN, (void (*)(void))rsa_gen },
     { OSSL_FUNC_KEYMGMT_GEN_CLEANUP, (void (*)(void))rsa_gen_cleanup },
+    { OSSL_FUNC_KEYMGMT_LOAD, (void (*)(void))rsa_load },
     { OSSL_FUNC_KEYMGMT_FREE, (void (*)(void))rsa_freedata },
     { OSSL_FUNC_KEYMGMT_GET_PARAMS, (void (*) (void))rsa_get_params },
     { OSSL_FUNC_KEYMGMT_GETTABLE_PARAMS, (void (*) (void))rsa_gettable_params },
diff --git a/providers/implementations/serializers/deserialize_der2rsa.c b/providers/implementations/serializers/deserialize_der2rsa.c
index 6854c7e..7506654 100644
--- a/providers/implementations/serializers/deserialize_der2rsa.c
+++ b/providers/implementations/serializers/deserialize_der2rsa.c
@@ -16,10 +16,12 @@
 #include <openssl/core_dispatch.h>
 #include <openssl/core_names.h>
 #include <openssl/crypto.h>
+#include <openssl/err.h>
 #include <openssl/params.h>
 #include <openssl/x509.h>
 #include "prov/bio.h"
 #include "prov/implementations.h"
+#include "prov/providercommonerr.h"
 #include "serializer_local.h"
 
 static OSSL_FUNC_deserializer_newctx_fn der2rsa_newctx;
@@ -37,10 +39,12 @@
 struct der2rsa_ctx_st {
     PROV_CTX *provctx;
 
+    int type;
+
     struct pkcs8_encrypt_ctx_st sc;
 };
 
-static void *der2rsa_newctx(void *provctx)
+static struct der2rsa_ctx_st *der2rsa_newctx_int(void *provctx)
 {
     struct der2rsa_ctx_st *ctx = OPENSSL_zalloc(sizeof(*ctx));
 
@@ -52,6 +56,24 @@
     return ctx;
 }
 
+static void *der2rsa_newctx(void *provctx)
+{
+    struct der2rsa_ctx_st *ctx = der2rsa_newctx_int(provctx);
+
+    if (ctx != NULL)
+        ctx->type = EVP_PKEY_RSA;
+    return ctx;
+}
+
+static void *der2rsapss_newctx(void *provctx)
+{
+    struct der2rsa_ctx_st *ctx = der2rsa_newctx_int(provctx);
+
+    if (ctx != NULL)
+        ctx->type = EVP_PKEY_RSA_PSS;
+    return ctx;
+}
+
 static void der2rsa_freectx(void *vctx)
 {
     struct der2rsa_ctx_st *ctx = vctx;
@@ -166,7 +188,7 @@
     }
 
     derp = der;
-    if ((pkey = d2i_PrivateKey_ex(EVP_PKEY_RSA, NULL, &derp, der_len,
+    if ((pkey = d2i_PrivateKey_ex(ctx->type, NULL, &derp, der_len,
                                   libctx, NULL)) != NULL) {
         /* Tear out the RSA pointer from the pkey */
         rsa = EVP_PKEY_get1_RSA(pkey);
@@ -177,10 +199,27 @@
 
     if (rsa != NULL) {
         OSSL_PARAM params[3];
+        char *object_type = NULL;
+
+        switch (RSA_test_flags(rsa, RSA_FLAG_TYPE_MASK)) {
+        case RSA_FLAG_TYPE_RSA:
+            object_type = "RSA";
+            break;
+        case RSA_FLAG_TYPE_RSASSAPSS:
+            object_type = "RSA-PSS";
+            break;
+        default:
+            ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_RSA_KEY,
+                           "Expected the RSA type to be %d or %d, but got %d",
+                           RSA_FLAG_TYPE_RSA, RSA_FLAG_TYPE_RSASSAPSS,
+                           RSA_test_flags(rsa, RSA_FLAG_TYPE_MASK));
+            goto end;
+        }
+
 
         params[0] =
             OSSL_PARAM_construct_utf8_string(OSSL_DESERIALIZER_PARAM_DATA_TYPE,
-                                             "RSA", 0);
+                                             object_type, 0);
         /* The address of the key becomes the octet string */
         params[1] =
             OSSL_PARAM_construct_octet_string(OSSL_DESERIALIZER_PARAM_REFERENCE,
@@ -189,17 +228,18 @@
 
         ok = data_cb(params, data_cbarg);
     }
+ end:
     RSA_free(rsa);
 
     return ok;
 }
 
-static int der2rsa_export_object(void *vctx,
-                                 const void *reference, size_t reference_sz,
-                                 OSSL_CALLBACK *export_cb, void *export_cbarg)
+static int der2rsa_export_object_int(void *vctx,
+                                     const void *reference, size_t reference_sz,
+                                     OSSL_FUNC_keymgmt_export_fn *rsa_export,
+                                     OSSL_CALLBACK *export_cb,
+                                     void *export_cbarg)
 {
-    OSSL_FUNC_keymgmt_export_fn *rsa_export =
-        ossl_prov_get_keymgmt_rsa_export();
     void *keydata;
 
     if (reference_sz == sizeof(keydata) && rsa_export != NULL) {
@@ -212,6 +252,26 @@
     return 0;
 }
 
+static int der2rsa_export_object(void *vctx,
+                                 const void *reference, size_t reference_sz,
+                                 OSSL_CALLBACK *export_cb,
+                                 void *export_cbarg)
+{
+    return der2rsa_export_object_int(vctx, reference, reference_sz,
+                                     ossl_prov_get_keymgmt_rsa_export(),
+                                     export_cb, export_cbarg);
+}
+
+static int der2rsapss_export_object(void *vctx,
+                                    const void *reference, size_t reference_sz,
+                                    OSSL_CALLBACK *export_cb,
+                                    void *export_cbarg)
+{
+    return der2rsa_export_object_int(vctx, reference, reference_sz,
+                                     ossl_prov_get_keymgmt_rsapss_export(),
+                                     export_cb, export_cbarg);
+}
+
 const OSSL_DISPATCH der_to_rsa_deserializer_functions[] = {
     { OSSL_FUNC_DESERIALIZER_NEWCTX, (void (*)(void))der2rsa_newctx },
     { OSSL_FUNC_DESERIALIZER_FREECTX, (void (*)(void))der2rsa_freectx },
@@ -229,3 +289,21 @@
       (void (*)(void))der2rsa_export_object },
     { 0, NULL }
 };
+
+const OSSL_DISPATCH der_to_rsapss_deserializer_functions[] = {
+    { OSSL_FUNC_DESERIALIZER_NEWCTX, (void (*)(void))der2rsapss_newctx },
+    { OSSL_FUNC_DESERIALIZER_FREECTX, (void (*)(void))der2rsa_freectx },
+    { OSSL_FUNC_DESERIALIZER_GETTABLE_PARAMS,
+      (void (*)(void))der2rsa_gettable_params },
+    { OSSL_FUNC_DESERIALIZER_GET_PARAMS,
+      (void (*)(void))der2rsa_get_params },
+    { OSSL_FUNC_DESERIALIZER_SETTABLE_CTX_PARAMS,
+      (void (*)(void))der2rsa_settable_ctx_params },
+    { OSSL_FUNC_DESERIALIZER_SET_CTX_PARAMS,
+      (void (*)(void))der2rsa_set_ctx_params },
+    { OSSL_FUNC_DESERIALIZER_DESERIALIZE,
+      (void (*)(void))der2rsa_deserialize },
+    { OSSL_FUNC_DESERIALIZER_EXPORT_OBJECT,
+      (void (*)(void))der2rsapss_export_object },
+    { 0, NULL }
+};
diff --git a/providers/implementations/serializers/serializer_local.h b/providers/implementations/serializers/serializer_local.h
index a94418b..f1d2fe7 100644
--- a/providers/implementations/serializers/serializer_local.h
+++ b/providers/implementations/serializers/serializer_local.h
@@ -38,9 +38,11 @@
 OSSL_FUNC_keymgmt_export_fn *ossl_prov_get_keymgmt_export(const OSSL_DISPATCH *fns);
 
 OSSL_FUNC_keymgmt_new_fn *ossl_prov_get_keymgmt_rsa_new(void);
+OSSL_FUNC_keymgmt_new_fn *ossl_prov_get_keymgmt_rsapss_new(void);
 OSSL_FUNC_keymgmt_free_fn *ossl_prov_get_keymgmt_rsa_free(void);
 OSSL_FUNC_keymgmt_import_fn *ossl_prov_get_keymgmt_rsa_import(void);
 OSSL_FUNC_keymgmt_export_fn *ossl_prov_get_keymgmt_rsa_export(void);
+OSSL_FUNC_keymgmt_export_fn *ossl_prov_get_keymgmt_rsapss_export(void);
 OSSL_FUNC_keymgmt_new_fn *ossl_prov_get_keymgmt_dh_new(void);
 OSSL_FUNC_keymgmt_free_fn *ossl_prov_get_keymgmt_dh_free(void);
 OSSL_FUNC_keymgmt_import_fn *ossl_prov_get_keymgmt_dh_import(void);
diff --git a/providers/implementations/serializers/serializer_rsa.c b/providers/implementations/serializers/serializer_rsa.c
index d2a5459..9250d49 100644
--- a/providers/implementations/serializers/serializer_rsa.c
+++ b/providers/implementations/serializers/serializer_rsa.c
@@ -27,6 +27,11 @@
     return ossl_prov_get_keymgmt_new(rsa_keymgmt_functions);
 }
 
+OSSL_FUNC_keymgmt_new_fn *ossl_prov_get_keymgmt_rsapss_new(void)
+{
+    return ossl_prov_get_keymgmt_new(rsapss_keymgmt_functions);
+}
+
 OSSL_FUNC_keymgmt_free_fn *ossl_prov_get_keymgmt_rsa_free(void)
 {
     return ossl_prov_get_keymgmt_free(rsa_keymgmt_functions);
@@ -42,6 +47,11 @@
     return ossl_prov_get_keymgmt_export(rsa_keymgmt_functions);
 }
 
+OSSL_FUNC_keymgmt_export_fn *ossl_prov_get_keymgmt_rsapss_export(void)
+{
+    return ossl_prov_get_keymgmt_export(rsapss_keymgmt_functions);
+}
+
 int ossl_prov_print_rsa(BIO *out, RSA *rsa, int priv)
 {
     const char *modulus_label;