Add test for query invalidation after new provider added

Reviewed-by: Matt Caswell <matt@openssl.org>
Reviewed-by: Richard Levitte <levitte@openssl.org>
Reviewed-by: Paul Dale <pauli@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/18269)
diff --git a/test/fake_rsaprov.c b/test/fake_rsaprov.c
index e4833a6..d250332 100644
--- a/test/fake_rsaprov.c
+++ b/test/fake_rsaprov.c
@@ -93,6 +93,41 @@
     return fake_rsa_import_key_types;
 }
 
+static void *fake_rsa_gen_init(void *provctx, int selection,
+                               const OSSL_PARAM params[])
+{
+    unsigned char *gctx = NULL;
+
+    if (!TEST_ptr(gctx = OPENSSL_malloc(1)))
+        return NULL;
+
+    *gctx = 1;
+
+    return gctx;
+}
+
+static void *fake_rsa_gen(void *genctx, OSSL_CALLBACK *osslcb, void *cbarg)
+{
+    unsigned char *gctx = genctx;
+    static const unsigned char inited[] = { 1 };
+    unsigned char *keydata;
+
+    if (!TEST_ptr(gctx)
+        || !TEST_mem_eq(gctx, sizeof(*gctx), inited, sizeof(inited)))
+        return NULL;
+
+    if (!TEST_ptr(keydata = fake_rsa_keymgmt_new(NULL)))
+        return NULL;
+
+    *keydata = 2;
+    return keydata;
+}
+
+static void fake_rsa_gen_cleanup(void *genctx)
+{
+   OPENSSL_free(genctx);
+}
+
 static const OSSL_DISPATCH fake_rsa_keymgmt_funcs[] = {
     { OSSL_FUNC_KEYMGMT_NEW, (void (*)(void))fake_rsa_keymgmt_new },
     { OSSL_FUNC_KEYMGMT_FREE, (void (*)(void))fake_rsa_keymgmt_free} ,
@@ -102,6 +137,9 @@
     { OSSL_FUNC_KEYMGMT_IMPORT, (void (*)(void))fake_rsa_keymgmt_import },
     { OSSL_FUNC_KEYMGMT_IMPORT_TYPES,
         (void (*)(void))fake_rsa_keymgmt_imptypes },
+    { OSSL_FUNC_KEYMGMT_GEN_INIT, (void (*)(void))fake_rsa_gen_init },
+    { OSSL_FUNC_KEYMGMT_GEN, (void (*)(void))fake_rsa_gen },
+    { OSSL_FUNC_KEYMGMT_GEN_CLEANUP, (void (*)(void))fake_rsa_gen_cleanup },
     { 0, NULL }
 };
 
diff --git a/test/provider_pkey_test.c b/test/provider_pkey_test.c
index d360c0c..dc59326 100644
--- a/test/provider_pkey_test.c
+++ b/test/provider_pkey_test.c
@@ -115,6 +115,66 @@
     return ret;
 }
 
+static int test_alternative_keygen_init(void)
+{
+    EVP_PKEY_CTX *ctx = NULL;
+    OSSL_PROVIDER *deflt = NULL;
+    OSSL_PROVIDER *fake_rsa = NULL;
+    const OSSL_PROVIDER *provider;
+    const char *provname;
+    int ret = 0;
+
+    if (!TEST_ptr(deflt = OSSL_PROVIDER_load(libctx, "default")))
+        goto end;
+
+    /* first try without the fake RSA provider loaded */
+    if (!TEST_ptr(ctx = EVP_PKEY_CTX_new_from_name(libctx, "RSA", NULL)))
+        goto end;
+
+    if (!TEST_int_gt(EVP_PKEY_keygen_init(ctx), 0))
+        goto end;
+
+    if (!TEST_ptr(provider = EVP_PKEY_CTX_get0_provider(ctx)))
+        goto end;
+
+    if (!TEST_ptr(provname = OSSL_PROVIDER_get0_name(provider)))
+        goto end;
+
+    if (!TEST_str_eq(provname, "default"))
+        goto end;
+
+    EVP_PKEY_CTX_free(ctx);
+    ctx = NULL;
+
+    /* now load fake RSA and try again */
+    if (!TEST_ptr(fake_rsa = fake_rsa_start(libctx)))
+        return 0;
+
+    if (!TEST_ptr(ctx = EVP_PKEY_CTX_new_from_name(libctx, "RSA",
+                                                   "?provider=fake-rsa")))
+        goto end;
+
+    if (!TEST_int_gt(EVP_PKEY_keygen_init(ctx), 0))
+        goto end;
+
+    if (!TEST_ptr(provider = EVP_PKEY_CTX_get0_provider(ctx)))
+        goto end;
+
+    if (!TEST_ptr(provname = OSSL_PROVIDER_get0_name(provider)))
+        goto end;
+
+    if (!TEST_str_eq(provname, "fake-rsa"))
+        goto end;
+
+    ret = 1;
+
+end:
+    fake_rsa_finish(fake_rsa);
+    OSSL_PROVIDER_unload(deflt);
+    EVP_PKEY_CTX_free(ctx);
+    return ret;
+}
+
 int setup_tests(void)
 {
     libctx = OSSL_LIB_CTX_new();
@@ -122,6 +182,7 @@
         return 0;
 
     ADD_TEST(test_pkey_sig);
+    ADD_TEST(test_alternative_keygen_init);
 
     return 1;
 }