Add TLSv1.3 server side external PSK support

Reviewed-by: Rich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/3670)
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index cd1fc2e..2dfa7f6 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -763,6 +763,10 @@
                                                const char *identity,
                                                unsigned char *psk,
                                                unsigned int max_psk_len);
+typedef int (*SSL_psk_find_session_cb_func)(SSL *ssl,
+                                            const unsigned char *identity,
+                                            size_t identity_len,
+                                            SSL_SESSION **sess);
 void SSL_CTX_set_psk_server_callback(SSL_CTX *ctx, SSL_psk_server_cb_func cb);
 void SSL_set_psk_server_callback(SSL *ssl, SSL_psk_server_cb_func cb);
 
diff --git a/ssl/ssl_locl.h b/ssl/ssl_locl.h
index 7dc6cd6..b26db6a 100644
--- a/ssl/ssl_locl.h
+++ b/ssl/ssl_locl.h
@@ -952,6 +952,7 @@
     SSL_psk_client_cb_func psk_client_callback;
     SSL_psk_server_cb_func psk_server_callback;
 # endif
+    SSL_psk_find_session_cb_func psk_find_session_cb;
 
 # ifndef OPENSSL_NO_SRP
     SRP_CTX srp_ctx;            /* ctx for SRP authentication */
@@ -1122,6 +1123,7 @@
     SSL_psk_client_cb_func psk_client_callback;
     SSL_psk_server_cb_func psk_server_callback;
 # endif
+    SSL_psk_find_session_cb_func psk_find_session_cb;
     SSL_CTX *ctx;
     /* Verified chain of peer */
     STACK_OF(X509) *verified_chain;
diff --git a/ssl/statem/extensions.c b/ssl/statem/extensions.c
index d40c34c..7defbb1 100644
--- a/ssl/statem/extensions.c
+++ b/ssl/statem/extensions.c
@@ -1227,17 +1227,27 @@
 
 int tls_psk_do_binder(SSL *s, const EVP_MD *md, const unsigned char *msgstart,
                       size_t binderoffset, const unsigned char *binderin,
-                      unsigned char *binderout,
-                      SSL_SESSION *sess, int sign)
+                      unsigned char *binderout, SSL_SESSION *sess, int sign,
+                      int external)
 {
     EVP_PKEY *mackey = NULL;
     EVP_MD_CTX *mctx = NULL;
     unsigned char hash[EVP_MAX_MD_SIZE], binderkey[EVP_MAX_MD_SIZE];
     unsigned char finishedkey[EVP_MAX_MD_SIZE], tmpbinder[EVP_MAX_MD_SIZE];
     const char resumption_label[] = "res binder";
-    size_t bindersize, hashsize = EVP_MD_size(md);
+    const char external_label[] = "ext binder";
+    const char *label;
+    size_t bindersize, labelsize, hashsize = EVP_MD_size(md);
     int ret = -1;
 
+    if (external) {
+        label = external_label;
+        labelsize = sizeof(external_label) - 1;
+    } else {
+        label = resumption_label;
+        labelsize = sizeof(resumption_label) - 1;
+    }
+
     /* Generate the early_secret */
     if (!tls13_generate_secret(s, md, NULL, sess->master_key,
                                sess->master_key_length,
@@ -1259,10 +1269,8 @@
     }
 
     /* Generate the binder key */
-    if (!tls13_hkdf_expand(s, md, s->early_secret,
-                           (unsigned char *)resumption_label,
-                           sizeof(resumption_label) - 1, hash, binderkey,
-                           hashsize)) {
+    if (!tls13_hkdf_expand(s, md, s->early_secret, (unsigned char *)label,
+                           labelsize, hash, binderkey, hashsize)) {
         SSLerr(SSL_F_TLS_PSK_DO_BINDER, ERR_R_INTERNAL_ERROR);
         goto err;
     }
diff --git a/ssl/statem/extensions_clnt.c b/ssl/statem/extensions_clnt.c
index 8aa795e..2f6d16e 100644
--- a/ssl/statem/extensions_clnt.c
+++ b/ssl/statem/extensions_clnt.c
@@ -894,7 +894,7 @@
     msgstart = WPACKET_get_curr(pkt) - msglen;
 
     if (tls_psk_do_binder(s, md, msgstart, binderoffset, NULL, binder,
-                          s->session, 1) != 1) {
+                          s->session, 1, 0) != 1) {
         SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, ERR_R_INTERNAL_ERROR);
         goto err;
     }
diff --git a/ssl/statem/extensions_srvr.c b/ssl/statem/extensions_srvr.c
index 39e6c07..4e65320 100644
--- a/ssl/statem/extensions_srvr.c
+++ b/ssl/statem/extensions_srvr.c
@@ -686,9 +686,8 @@
     PACKET identities, binders, binder;
     size_t binderoffset, hashsize;
     SSL_SESSION *sess = NULL;
-    unsigned int id, i;
+    unsigned int id, i, ext = 0;
     const EVP_MD *md = NULL;
-    uint32_t ticket_age = 0, now, agesec, agems;
 
     /*
      * If we have no PSK kex mode that we recognise then we can't resume so
@@ -706,7 +705,6 @@
     for (id = 0; PACKET_remaining(&identities) != 0; id++) {
         PACKET identity;
         unsigned long ticket_agel;
-        int ret;
 
         if (!PACKET_get_length_prefixed_2(&identities, &identity)
                 || !PACKET_get_net_4(&identities, &ticket_agel)) {
@@ -714,16 +712,64 @@
             return 0;
         }
 
-        ticket_age = (uint32_t)ticket_agel;
+        if (s->psk_find_session_cb != NULL
+                && s->psk_find_session_cb(s, PACKET_data(&identity),
+                                          PACKET_remaining(&identity), &sess)) {
+            SSL_SESSION *sesstmp = ssl_session_dup(sess, 0);
 
-        ret = tls_decrypt_ticket(s, PACKET_data(&identity),
-                                 PACKET_remaining(&identity), NULL, 0, &sess);
-        if (ret == TICKET_FATAL_ERR_MALLOC || ret == TICKET_FATAL_ERR_OTHER) {
-            *al = SSL_AD_INTERNAL_ERROR;
-            return 0;
+            if (sesstmp == NULL) {
+                *al = SSL_AD_INTERNAL_ERROR;
+                return 0;
+            }
+            SSL_SESSION_free(sess);
+            sess = sesstmp;
+
+            /*
+             * We've just been told to use this session for this context so
+             * make sure the sid_ctx matches up.
+             */
+            memcpy(sess->sid_ctx, s->sid_ctx, s->sid_ctx_length);
+            sess->sid_ctx_length = s->sid_ctx_length;
+            ext = 1;
+        } else {
+            uint32_t ticket_age = 0, now, agesec, agems;
+            int ret = tls_decrypt_ticket(s, PACKET_data(&identity),
+                                         PACKET_remaining(&identity), NULL, 0,
+                                         &sess);
+
+            if (ret == TICKET_FATAL_ERR_MALLOC
+                    || ret == TICKET_FATAL_ERR_OTHER) {
+                *al = SSL_AD_INTERNAL_ERROR;
+                return 0;
+            }
+            if (ret == TICKET_NO_DECRYPT)
+                continue;
+
+            ticket_age = (uint32_t)ticket_agel;
+            now = (uint32_t)time(NULL);
+            agesec = now - (uint32_t)sess->time;
+            agems = agesec * (uint32_t)1000;
+            ticket_age -= sess->ext.tick_age_add;
+
+            /*
+             * For simplicity we do our age calculations in seconds. If the
+             * client does it in ms then it could appear that their ticket age
+             * is longer than ours (our ticket age calculation should always be
+             * slightly longer than the client's due to the network latency).
+             * Therefore we add 1000ms to our age calculation to adjust for
+             * rounding errors.
+             */
+            if (sess->timeout >= (long)agesec
+                    && agems / (uint32_t)1000 == agesec
+                    && ticket_age <= agems + 1000
+                    && ticket_age + TICKET_AGE_ALLOWANCE >= agems + 1000) {
+                /*
+                 * Ticket age is within tolerance and not expired. We allow it
+                 * for early data
+                 */
+                s->ext.early_data_ok = 1;
+            }
         }
-        if (ret == TICKET_NO_DECRYPT)
-            continue;
 
         md = ssl_md(sess->cipher->algorithm2);
         if (md != ssl_md(s->s3->tmp.new_cipher->algorithm2)) {
@@ -732,12 +778,6 @@
             sess = NULL;
             continue;
         }
-
-        /*
-         * TODO(TLS1.3): Somehow we need to handle the case of a ticket renewal.
-         * Ignored for now
-         */
-
         break;
     }
 
@@ -763,7 +803,7 @@
             || tls_psk_do_binder(s, md,
                                  (const unsigned char *)s->init_buf->data,
                                  binderoffset, PACKET_data(&binder), NULL,
-                                 sess, 0) != 1) {
+                                 sess, 0, ext) != 1) {
         *al = SSL_AD_DECODE_ERROR;
         SSLerr(SSL_F_TLS_PARSE_CTOS_PSK, ERR_R_INTERNAL_ERROR);
         goto err;
@@ -771,31 +811,6 @@
 
     sess->ext.tick_identity = id;
 
-    now = (uint32_t)time(NULL);
-    agesec = now - (uint32_t)sess->time;
-    agems = agesec * (uint32_t)1000;
-    ticket_age -= sess->ext.tick_age_add;
-
-
-    /*
-     * For simplicity we do our age calculations in seconds. If the client does
-     * it in ms then it could appear that their ticket age is longer than ours
-     * (our ticket age calculation should always be slightly longer than the
-     * client's due to the network latency). Therefore we add 1000ms to our age
-     * calculation to adjust for rounding errors.
-     */
-    if (sess->timeout >= (long)agesec
-            && agems / (uint32_t)1000 == agesec
-            && ticket_age <= agems + 1000
-            && ticket_age + TICKET_AGE_ALLOWANCE >= agems + 1000) {
-        /*
-         * Ticket age is within tolerance and not expired. We allow it for early
-         * data
-         */
-        s->ext.early_data_ok = 1;
-    }
-
-
     SSL_SESSION_free(s->session);
     s->session = sess;
     return 1;
diff --git a/ssl/statem/statem_locl.h b/ssl/statem/statem_locl.h
index b6ce4cc..1f8c22d 100644
--- a/ssl/statem/statem_locl.h
+++ b/ssl/statem/statem_locl.h
@@ -183,7 +183,7 @@
                              const unsigned char *msgstart,
                              size_t binderoffset, const unsigned char *binderin,
                              unsigned char *binderout,
-                             SSL_SESSION *sess, int sign);
+                             SSL_SESSION *sess, int sign, int external);
 
 /* Server Extension processing */
 int tls_parse_ctos_renegotiate(SSL *s, PACKET *pkt, unsigned int context,