Extend PSS signature support to TLSv1.2

TLSv1.3 introduces PSS-based signature algorithms (sigalgs). Offering these
in a TLSv1.3 client implies that the client is prepared to accept these
sigalgs even in TLSv1.2.

Reviewed-by: Rich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/2157)
diff --git a/ssl/statem/statem_clnt.c b/ssl/statem/statem_clnt.c
index 432dc91..5eec0d1 100644
--- a/ssl/statem/statem_clnt.c
+++ b/ssl/statem/statem_clnt.c
@@ -1824,9 +1824,11 @@
 
 MSG_PROCESS_RETURN tls_process_key_exchange(SSL *s, PACKET *pkt)
 {
-    int al = -1;
+    int al = -1, ispss = 0;
     long alg_k;
     EVP_PKEY *pkey = NULL;
+    EVP_MD_CTX *md_ctx = NULL;
+    EVP_PKEY_CTX *pctx = NULL;
     PACKET save_param_start, signature;
 
     alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
@@ -1865,7 +1867,6 @@
         PACKET params;
         int maxsig;
         const EVP_MD *md = NULL;
-        EVP_MD_CTX *md_ctx;
 
         /*
          * |pkt| now points to the beginning of the signature, so the difference
@@ -1896,6 +1897,7 @@
                 al = SSL_AD_DECODE_ERROR;
                 goto err;
             }
+            ispss = SIGID_IS_PSS(sigalg);
 #ifdef SSL_DEBUG
             fprintf(stderr, "USING TLSv1.2 HASH %s\n", EVP_MD_name(md));
 #endif
@@ -1936,29 +1938,39 @@
             goto err;
         }
 
-        if (EVP_VerifyInit_ex(md_ctx, md, NULL) <= 0
-            || EVP_VerifyUpdate(md_ctx, &(s->s3->client_random[0]),
-                                SSL3_RANDOM_SIZE) <= 0
-            || EVP_VerifyUpdate(md_ctx, &(s->s3->server_random[0]),
-                                SSL3_RANDOM_SIZE) <= 0
-            || EVP_VerifyUpdate(md_ctx, PACKET_data(&params),
-                                PACKET_remaining(&params)) <= 0) {
-            EVP_MD_CTX_free(md_ctx);
+        if (EVP_DigestVerifyInit(md_ctx, &pctx, md, NULL, pkey) <= 0) {
             al = SSL_AD_INTERNAL_ERROR;
             SSLerr(SSL_F_TLS_PROCESS_KEY_EXCHANGE, ERR_R_EVP_LIB);
             goto err;
         }
-        /* TODO(size_t): Convert this call */
-        if (EVP_VerifyFinal(md_ctx, PACKET_data(&signature),
-                            (unsigned int)PACKET_remaining(&signature),
-                            pkey) <= 0) {
+        if (ispss) {
+            if (EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PSS_PADDING) <= 0
+                       /* -1 here means set saltlen to the digest len */
+                    || EVP_PKEY_CTX_set_rsa_pss_saltlen(pctx, -1) <= 0) {
+                al = SSL_AD_INTERNAL_ERROR;
+                SSLerr(SSL_F_TLS_PROCESS_KEY_EXCHANGE, ERR_R_EVP_LIB);
+                goto err;
+            }
+        }
+        if (EVP_DigestVerifyUpdate(md_ctx, &(s->s3->client_random[0]),
+                                   SSL3_RANDOM_SIZE) <= 0
+                || EVP_DigestVerifyUpdate(md_ctx, &(s->s3->server_random[0]),
+                                          SSL3_RANDOM_SIZE) <= 0
+                || EVP_DigestVerifyUpdate(md_ctx, PACKET_data(&params),
+                                          PACKET_remaining(&params)) <= 0) {
+            al = SSL_AD_INTERNAL_ERROR;
+            SSLerr(SSL_F_TLS_PROCESS_KEY_EXCHANGE, ERR_R_EVP_LIB);
+            goto err;
+        }
+        if (EVP_DigestVerifyFinal(md_ctx, PACKET_data(&signature),
+                                  PACKET_remaining(&signature)) <= 0) {
             /* bad signature */
-            EVP_MD_CTX_free(md_ctx);
             al = SSL_AD_DECRYPT_ERROR;
             SSLerr(SSL_F_TLS_PROCESS_KEY_EXCHANGE, SSL_R_BAD_SIGNATURE);
             goto err;
         }
         EVP_MD_CTX_free(md_ctx);
+        md_ctx = NULL;
     } else {
         /* aNULL, aSRP or PSK do not need public keys */
         if (!(s->s3->tmp.new_cipher->algorithm_auth & (SSL_aNULL | SSL_aSRP))
@@ -1986,6 +1998,7 @@
     if (al != -1)
         ssl3_send_alert(s, SSL3_AL_FATAL, al);
     ossl_statem_set_error(s);
+    EVP_MD_CTX_free(md_ctx);
     return MSG_PROCESS_ERROR;
 }
 
diff --git a/ssl/statem/statem_lib.c b/ssl/statem/statem_lib.c
index 5b8ad3a..03efdec 100644
--- a/ssl/statem/statem_lib.c
+++ b/ssl/statem/statem_lib.c
@@ -136,7 +136,7 @@
     void *hdata;
     unsigned char *sig = NULL;
     unsigned char tls13tbs[TLS13_TBS_PREAMBLE_SIZE + EVP_MAX_MD_SIZE];
-    int pktype;
+    int pktype, ispss = 0;
 
     if (s->server) {
         /* Only happens in TLSv1.3 */
@@ -166,7 +166,7 @@
         goto err;
     }
 
-    if (SSL_USE_SIGALGS(s) && !tls12_get_sigandhash(s, pkt, pkey, md)) {
+    if (SSL_USE_SIGALGS(s) && !tls12_get_sigandhash(s, pkt, pkey, md, &ispss)) {
         SSLerr(SSL_F_TLS_CONSTRUCT_CERT_VERIFY, ERR_R_INTERNAL_ERROR);
         goto err;
     }
@@ -186,7 +186,7 @@
         goto err;
     }
 
-    if (SSL_IS_TLS13(s) && pktype == EVP_PKEY_RSA) {
+    if (ispss) {
         if (EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PSS_PADDING) <= 0
                    /* -1 here means set saltlen to the digest len */
                 || EVP_PKEY_CTX_set_rsa_pss_saltlen(pctx, -1) <= 0) {
@@ -243,7 +243,7 @@
     unsigned char *gost_data = NULL;
 #endif
     int al = SSL_AD_INTERNAL_ERROR, ret = MSG_PROCESS_ERROR;
-    int type = 0, j, pktype;
+    int type = 0, j, pktype, ispss = 0;
     unsigned int len;
     X509 *peer;
     const EVP_MD *md = NULL;
@@ -297,6 +297,7 @@
                 al = SSL_AD_DECODE_ERROR;
                 goto f_err;
             }
+            ispss = SIGID_IS_PSS(sigalg);
 #ifdef SSL_DEBUG
             fprintf(stderr, "USING TLSv1.2 HASH %s\n", EVP_MD_name(md));
 #endif
@@ -358,7 +359,7 @@
     }
 #endif
 
-    if (SSL_IS_TLS13(s) && pktype == EVP_PKEY_RSA) {
+    if (ispss) {
         if (EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PSS_PADDING) <= 0
                    /* -1 here means set saltlen to the digest len */
                 || EVP_PKEY_CTX_set_rsa_pss_saltlen(pctx, -1) <= 0) {
diff --git a/ssl/statem/statem_srvr.c b/ssl/statem/statem_srvr.c
index 81a72ad..0573af1 100644
--- a/ssl/statem/statem_srvr.c
+++ b/ssl/statem/statem_srvr.c
@@ -1956,6 +1956,7 @@
     unsigned long type;
     const BIGNUM *r[4];
     EVP_MD_CTX *md_ctx = EVP_MD_CTX_new();
+    EVP_PKEY_CTX *pctx = NULL;
     size_t paramlen, paramoffset;
 
     if (!WPACKET_get_total_written(pkt, &paramoffset)) {
@@ -2212,7 +2213,8 @@
          */
         if (md) {
             unsigned char *sigbytes1, *sigbytes2;
-            unsigned int siglen;
+            size_t siglen;
+            int ispss = 0;
 
             /* Get length of the parameters we have written above */
             if (!WPACKET_get_length(pkt, &paramlen)) {
@@ -2222,7 +2224,7 @@
             }
             /* send signature algorithm */
             if (SSL_USE_SIGALGS(s)) {
-                if (!tls12_get_sigandhash(s, pkt, pkey, md)) {
+                if (!tls12_get_sigandhash(s, pkt, pkey, md, &ispss)) {
                     /* Should never happen */
                     SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE,
                            ERR_R_INTERNAL_ERROR);
@@ -2240,14 +2242,29 @@
              */
             if (!WPACKET_sub_reserve_bytes_u16(pkt, EVP_PKEY_size(pkey),
                                                &sigbytes1)
-                    || EVP_SignInit_ex(md_ctx, md, NULL) <= 0
-                    || EVP_SignUpdate(md_ctx, &(s->s3->client_random[0]),
-                                      SSL3_RANDOM_SIZE) <= 0
-                    || EVP_SignUpdate(md_ctx, &(s->s3->server_random[0]),
-                                      SSL3_RANDOM_SIZE) <= 0
-                    || EVP_SignUpdate(md_ctx, s->init_buf->data + paramoffset,
-                                      paramlen) <= 0
-                    || EVP_SignFinal(md_ctx, sigbytes1, &siglen, pkey) <= 0
+                    || EVP_DigestSignInit(md_ctx, &pctx, md, NULL, pkey) <= 0) {
+                SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE,
+                       ERR_R_INTERNAL_ERROR);
+                goto f_err;
+            }
+            if (ispss) {
+                if (EVP_PKEY_CTX_set_rsa_padding(pctx,
+                                                 RSA_PKCS1_PSS_PADDING) <= 0
+                           /* -1 here means set saltlen to the digest len */
+                        || EVP_PKEY_CTX_set_rsa_pss_saltlen(pctx, -1) <= 0) {
+                    SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE,
+                           ERR_R_EVP_LIB);
+                    goto f_err;
+                }
+            }
+            if (EVP_DigestSignUpdate(md_ctx, &(s->s3->client_random[0]),
+                                     SSL3_RANDOM_SIZE) <= 0
+                    || EVP_DigestSignUpdate(md_ctx, &(s->s3->server_random[0]),
+                                            SSL3_RANDOM_SIZE) <= 0
+                    || EVP_DigestSignUpdate(md_ctx,
+                                            s->init_buf->data + paramoffset,
+                                            paramlen) <= 0
+                    || EVP_DigestSignFinal(md_ctx, sigbytes1, &siglen) <= 0
                     || !WPACKET_sub_allocate_bytes_u16(pkt, siglen, &sigbytes2)
                     || sigbytes1 != sigbytes2) {
                 SSLerr(SSL_F_TLS_CONSTRUCT_SERVER_KEY_EXCHANGE,