Implement server side of PSK extension parsing

Reviewed-by: Rich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/2259)
diff --git a/ssl/statem/extensions.c b/ssl/statem/extensions.c
index f1a1675..95bfe75 100644
--- a/ssl/statem/extensions.c
+++ b/ssl/statem/extensions.c
@@ -279,7 +279,8 @@
         TLSEXT_TYPE_psk,
         EXT_CLIENT_HELLO | EXT_TLS1_3_SERVER_HELLO | EXT_TLS_IMPLEMENTATION_ONLY
         | EXT_TLS1_3_ONLY,
-        NULL, NULL, tls_parse_stoc_psk, NULL, tls_construct_ctos_psk, NULL
+        NULL, tls_parse_ctos_psk, tls_parse_stoc_psk, NULL,
+        tls_construct_ctos_psk, NULL
     }
 };
 
@@ -1002,3 +1003,129 @@

     return 1;
 }
+
+/*
+ * Create (|sign| != 0) or verify (|sign| == 0) a TLS 1.3 PSK binder.
+ *
+ * The binder is an HMAC over a hash of the first |binderoffset| bytes at
+ * |msgstart| (the ClientHello up to the start of the binders), keyed via
+ * the early secret derived from |sess->master_key|. As a side effect that
+ * early secret is stored in |s->early_secret|.
+ *
+ * When signing, the binder is written to |binderout|, which must have room
+ * for EVP_MD_size(|md|) bytes, and |binderin| is ignored. When verifying,
+ * |binderin| must point at EVP_MD_size(|md|) bytes and |binderout| is
+ * ignored.
+ *
+ * Returns 1 if the binder was created or verified, 0 if it was computed but
+ * does not match |binderin|, and -1 on error.
+ */
+int tls_psk_do_binder(SSL *s, const EVP_MD *md, const unsigned char *msgstart,
+                      size_t binderoffset, const unsigned char *binderin,
+                      unsigned char *binderout,
+                      SSL_SESSION *sess, int sign)
+{
+    EVP_PKEY *mackey = NULL;
+    EVP_MD_CTX *mctx = NULL;
+    unsigned char hash[EVP_MAX_MD_SIZE], binderkey[EVP_MAX_MD_SIZE];
+    unsigned char finishedkey[EVP_MAX_MD_SIZE], tmpbinder[EVP_MAX_MD_SIZE];
+    const char resumption_label[] = "resumption psk binder key";
+    int mdsize = EVP_MD_size(md);
+    size_t hashsize, bindersize;
+    int ret = -1;
+
+    /*
+     * All intermediate values below are written to EVP_MAX_MD_SIZE stack
+     * buffers, so reject an invalid digest size (EVP_MD_size() returns a
+     * negative value on error) before it is used as a length, as well as a
+     * missing buffer for the requested operation.
+     */
+    if (mdsize <= 0 || mdsize > EVP_MAX_MD_SIZE
+            || (sign && binderout == NULL)
+            || (!sign && binderin == NULL)) {
+        SSLerr(SSL_F_TLS_PSK_DO_BINDER, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
+    hashsize = (size_t)mdsize;
+
+    /* Generate the early_secret */
+    if (!tls13_generate_secret(s, md, NULL, sess->master_key,
+                               sess->master_key_length,
+                               (unsigned char *)&s->early_secret)) {
+        SSLerr(SSL_F_TLS_PSK_DO_BINDER, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
+
+    /*
+     * Create the handshake hash for the binder key...the messages so far are
+     * empty!
+     */
+    mctx = EVP_MD_CTX_new();
+    if (mctx == NULL
+            || EVP_DigestInit_ex(mctx, md, NULL) <= 0
+            || EVP_DigestFinal_ex(mctx, hash, NULL) <= 0) {
+        SSLerr(SSL_F_TLS_PSK_DO_BINDER, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
+
+    /* Generate the binder key */
+    if (!tls13_hkdf_expand(s, md, s->early_secret,
+                           (unsigned char *)resumption_label,
+                           sizeof(resumption_label) - 1, hash, binderkey,
+                           hashsize)) {
+        SSLerr(SSL_F_TLS_PSK_DO_BINDER, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
+
+    /* Generate the finished key */
+    if (!tls13_derive_finishedkey(s, md, binderkey, finishedkey, hashsize)) {
+        SSLerr(SSL_F_TLS_PSK_DO_BINDER, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
+
+    /*
+     * Get a hash of the ClientHello up to the start of the binders.
+     * TODO(TLS1.3): This will need to be tweaked when we implement
+     * HelloRetryRequest to include the digest of the previous messages here.
+     */
+    if (EVP_DigestInit_ex(mctx, md, NULL) <= 0
+            || EVP_DigestUpdate(mctx, msgstart, binderoffset) <= 0
+            || EVP_DigestFinal_ex(mctx, hash, NULL) <= 0) {
+        SSLerr(SSL_F_TLS_PSK_DO_BINDER, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
+
+    mackey = EVP_PKEY_new_mac_key(EVP_PKEY_HMAC, NULL, finishedkey, hashsize);
+    if (mackey == NULL) {
+        SSLerr(SSL_F_TLS_PSK_DO_BINDER, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
+
+    /* When verifying, compute the expected binder into a scratch buffer */
+    if (!sign)
+        binderout = tmpbinder;
+
+    bindersize = hashsize;
+    if (EVP_DigestSignInit(mctx, NULL, md, NULL, mackey) <= 0
+            || EVP_DigestSignUpdate(mctx, hash, hashsize) <= 0
+            || EVP_DigestSignFinal(mctx, binderout, &bindersize) <= 0
+            || bindersize != hashsize) {
+        SSLerr(SSL_F_TLS_PSK_DO_BINDER, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
+
+    if (sign) {
+        ret = 1;
+    } else {
+        /* HMAC keys can't do EVP_DigestVerify* - use CRYPTO_memcmp instead */
+        ret = (CRYPTO_memcmp(binderin, binderout, hashsize) == 0);
+    }
+
+ err:
+    OPENSSL_cleanse(binderkey, sizeof(binderkey));
+    OPENSSL_cleanse(finishedkey, sizeof(finishedkey));
+    EVP_PKEY_free(mackey);
+    EVP_MD_CTX_free(mctx);
+
+    return ret;
+}
diff --git a/ssl/statem/extensions_clnt.c b/ssl/statem/extensions_clnt.c
index eb8cfa3..8c66332 100644
--- a/ssl/statem/extensions_clnt.c
+++ b/ssl/statem/extensions_clnt.c
@@ -663,16 +663,10 @@
                            int *al)
 {
 #ifndef OPENSSL_NO_TLS1_3
-    const SSL_CIPHER *cipher;
     uint32_t now, ages, agems;
-    size_t hashsize, bindersize, binderoffset, msglen;
+    size_t hashsize, binderoffset, msglen;
     unsigned char *binder = NULL, *msgstart = NULL;
-    EVP_PKEY *mackey = NULL;
     const EVP_MD *md;
-    EVP_MD_CTX *mctx = NULL;
-    unsigned char hash[EVP_MAX_MD_SIZE], binderkey[EVP_MAX_MD_SIZE];
-    unsigned char finishedkey[EVP_MAX_MD_SIZE];
-    const char resumption_label[] = "resumption psk binder key";
     int ret = 0;
 
     s->session->ext.tick_identity = TLSEXT_PSK_BAD_IDENTITY;
@@ -719,17 +713,12 @@
      */
     agems += s->session->ext.tick_age_add;
 
-    cipher = ssl3_get_cipher_by_id(s->session->cipher_id);
-    if (cipher == NULL) {
+    md = ssl_cipher_get_handshake_md(s->session->cipher_id);
+    if (md == NULL) {
         /* Don't recognise this cipher so we can't use the session. Ignore it */
         return 1;
     }
-    md = ssl_md(cipher->algorithm2);
-    if (md == NULL) {
-        /* Shouldn't happen!! */
-        SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, ERR_R_INTERNAL_ERROR);
-        return 0;
-    }
+
     hashsize = EVP_MD_size(md);
 
     /* Create the extension, but skip over the binder for now */
@@ -757,60 +746,8 @@
 
     msgstart = WPACKET_get_curr(pkt) - msglen;
 
-    /* Generate the early_secret */
-    if (!tls13_generate_secret(s, md, NULL, s->session->master_key,
-                               s->session->master_key_length,
-                               (unsigned char *)&s->early_secret)) {
-        SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-
-    /*
-     * Create the handshake hash for the binder key...the messages so far are
-     * empty!
-     */
-    mctx = EVP_MD_CTX_new();
-    if (mctx == NULL
-            || EVP_DigestInit_ex(mctx, md, NULL) <= 0
-            || EVP_DigestFinal_ex(mctx, hash, NULL) <= 0) {
-        SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-
-    /* Generate the binder key */
-    if (!tls13_hkdf_expand(s, md, s->early_secret,
-                           (unsigned char *)resumption_label,
-                           sizeof(resumption_label) - 1, hash, binderkey,
-                           hashsize)) {
-        SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-
-    /* Generate the finished key */
-    if (!tls13_derive_finishedkey(s, md, binderkey, finishedkey, hashsize)) {
-        SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-
-    /*
-     * Get a hash of the ClientHello up to the start of the binders.
-     * TODO(TLS1.3): This will need to be tweaked when we implement
-     * HelloRetryRequest to include the digest of the previous messages here.
-     */
-    if (EVP_DigestInit_ex(mctx, md, NULL) <= 0
-            || EVP_DigestUpdate(mctx, msgstart, binderoffset) <= 0
-            || EVP_DigestFinal_ex(mctx, hash, NULL) <= 0) {
-        SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-
-    mackey = EVP_PKEY_new_mac_key(EVP_PKEY_HMAC, NULL, finishedkey, hashsize);
-    bindersize = hashsize;
-    if (binderkey == NULL
-            || EVP_DigestSignInit(mctx, NULL, md, NULL, mackey) <= 0
-            || EVP_DigestSignUpdate(mctx, hash, hashsize) <= 0
-            || EVP_DigestSignFinal(mctx, binder, &bindersize) <= 0
-            || bindersize != hashsize) {
+    if (tls_psk_do_binder(s, md, msgstart, binderoffset, NULL, binder,
+                          s->session, 1) != 1) {
         SSLerr(SSL_F_TLS_CONSTRUCT_CTOS_PSK, ERR_R_INTERNAL_ERROR);
         goto err;
     }
@@ -819,11 +756,6 @@
 
     ret = 1;
  err:
-    OPENSSL_cleanse(binderkey, sizeof(binderkey));
-    OPENSSL_cleanse(finishedkey, sizeof(finishedkey));
-    EVP_PKEY_free(mackey);
-    EVP_MD_CTX_free(mctx);
-
     return ret;
 #else
     return 1;
diff --git a/ssl/statem/extensions_srvr.c b/ssl/statem/extensions_srvr.c
index 1e10a10..314cd5a 100644
--- a/ssl/statem/extensions_srvr.c
+++ b/ssl/statem/extensions_srvr.c
@@ -655,10 +655,9 @@
         return 0;
     }
 
-    if (!s->hit
-            && !PACKET_memdup(&supported_groups_list,
-                              &s->session->ext.supportedgroups,
-                              &s->session->ext.supportedgroups_len)) {
+    if (!PACKET_memdup(&supported_groups_list,
+                       &s->session->ext.supportedgroups,
+                       &s->session->ext.supportedgroups_len)) {
         *al = SSL_AD_DECODE_ERROR;
         return 0;
     }
@@ -680,6 +679,133 @@
     return 1;
 }

+/*
+ * Parse the pre_shared_key extension of a ClientHello.
+ *
+ * The first offered identity whose ticket decrypts and whose cipher has a
+ * known handshake digest is selected, and the binder at the same index is
+ * verified over |s->init_buf| up to the start of the binders list. On
+ * success the ticket's session replaces |s->session| and records the index
+ * of the chosen identity. If no identity is usable the extension is ignored.
+ *
+ * Returns 1 on success (or when the extension is ignored) and 0 on failure,
+ * in which case |*al| holds the alert to send.
+ */
+int tls_parse_ctos_psk(SSL *s, PACKET *pkt, X509 *x, size_t chainidx, int *al)
+{
+    PACKET identities, binders, binder;
+    size_t binderoffset, hashsize;
+    SSL_SESSION *sess = NULL;
+    unsigned int id, i;
+    const EVP_MD *md = NULL;
+    int binderok;
+
+    if (!PACKET_get_length_prefixed_2(pkt, &identities)) {
+        *al = SSL_AD_DECODE_ERROR;
+        return 0;
+    }
+
+    for (id = 0; PACKET_remaining(&identities) != 0; id++) {
+        PACKET identity;
+        unsigned long ticket_age;
+        int ret;
+
+        if (!PACKET_get_length_prefixed_2(&identities, &identity)
+                || !PACKET_get_net_4(&identities, &ticket_age)) {
+            *al = SSL_AD_DECODE_ERROR;
+            return 0;
+        }
+
+        ret = tls_decrypt_ticket(s, PACKET_data(&identity),
+                                 PACKET_remaining(&identity), NULL, 0, &sess);
+        if (ret == TICKET_FATAL_ERR_MALLOC || ret == TICKET_FATAL_ERR_OTHER) {
+            *al = SSL_AD_INTERNAL_ERROR;
+            return 0;
+        }
+        if (ret == TICKET_NO_DECRYPT || sess == NULL) {
+            /*
+             * Not a ticket we can use. Guard on |sess| too so that any other
+             * result that did not produce a session is skipped rather than
+             * dereferenced below.
+             */
+            SSL_SESSION_free(sess);
+            sess = NULL;
+            continue;
+        }
+
+        md = ssl_cipher_get_handshake_md(sess->cipher_id);
+        if (md == NULL) {
+            /*
+             * Don't recognise this cipher so we can't use the session.
+             * Ignore it
+             */
+            SSL_SESSION_free(sess);
+            sess = NULL;
+            continue;
+        }
+
+        /*
+         * TODO(TLS1.3): Somehow we need to handle the case of a ticket renewal.
+         * Ignored for now
+         */
+
+        break;
+    }
+
+    if (sess == NULL)
+        return 1;
+
+    /* The binder covers the message up to, but not including, the binders */
+    binderoffset = PACKET_data(pkt) - (const unsigned char *)s->init_buf->data;
+
+    hashsize = EVP_MD_size(md);
+
+    if (!PACKET_get_length_prefixed_2(pkt, &binders)) {
+        *al = SSL_AD_DECODE_ERROR;
+        goto err;
+    }
+
+    /* Step through the binders to the one matching the chosen identity */
+    for (i = 0; i <= id; i++) {
+        if (!PACKET_get_length_prefixed_1(&binders, &binder)) {
+            *al = SSL_AD_DECODE_ERROR;
+            goto err;
+        }
+    }
+
+    if (PACKET_remaining(&binder) != hashsize) {
+        *al = SSL_AD_DECODE_ERROR;
+        SSLerr(SSL_F_TLS_PARSE_CTOS_PSK, SSL_R_LENGTH_MISMATCH);
+        goto err;
+    }
+
+    binderok = tls_psk_do_binder(s, md,
+                                 (const unsigned char *)s->init_buf->data,
+                                 binderoffset, PACKET_data(&binder), NULL,
+                                 sess, 0);
+    if (binderok == 0) {
+        /* Computed fine but differs from what the client sent */
+        *al = SSL_AD_DECRYPT_ERROR;
+        SSLerr(SSL_F_TLS_PARSE_CTOS_PSK, SSL_R_DIGEST_CHECK_FAILED);
+        goto err;
+    }
+    if (binderok != 1) {
+        *al = SSL_AD_INTERNAL_ERROR;
+        SSLerr(SSL_F_TLS_PARSE_CTOS_PSK, ERR_R_INTERNAL_ERROR);
+        goto err;
+    }
+
+    sess->ext.tick_identity = id;
+    SSL_SESSION_free(s->session);
+    s->session = sess;
+
+    return 1;
+ err:
+    /* |sess| only passes to |s| on success, so release it on every failure */
+    SSL_SESSION_free(sess);
+    return 0;
+}
+
 /*
  * Add the server's renegotiation binding
  */
diff --git a/ssl/statem/statem_locl.h b/ssl/statem/statem_locl.h
index 99f67e5..8079f30 100644
--- a/ssl/statem/statem_locl.h
+++ b/ssl/statem/statem_locl.h
@@ -166,6 +166,12 @@
 __owur int tls_construct_extensions(SSL *s, WPACKET *pkt, unsigned int context,
                                     X509 *x, size_t chainidx, int *al);
 
+__owur int tls_psk_do_binder(SSL *s, const EVP_MD *md,
+                             const unsigned char *msgstart,
+                             size_t binderoffset, const unsigned char *binderin,
+                             unsigned char *binderout,
+                             SSL_SESSION *sess, int sign);
+
 /* Server Extension processing */
 int tls_parse_ctos_renegotiate(SSL *s, PACKET *pkt, X509 *x, size_t chainidx,
                                int *al);
@@ -202,6 +208,7 @@
 int tls_parse_ctos_ems(SSL *s, PACKET *pkt, X509 *x, size_t chainidx, int *al);
 int tls_parse_ctos_psk_kex_modes(SSL *s, PACKET *pkt, X509 *x, size_t chainidx,
                                  int *al);
+int tls_parse_ctos_psk(SSL *s, PACKET *pkt, X509 *x, size_t chainidx, int *al);
 
 int tls_construct_stoc_renegotiate(SSL *s, WPACKET *pkt, X509 *x,
                                    size_t chainidx, int *al);