Harmonise setting the handshake header and closing message construction

Ensure all message types, including CCS and the DTLS HelloVerifyRequest, work
the same way so that the state machine doesn't need to know about special
cases. Put all the special-case logic into ssl_set_handshake_header() and
ssl_close_construct_packet(); the latter now also takes the message type.

Reviewed-by: Rich Salz <rsalz@openssl.org>
diff --git a/ssl/s3_lib.c b/ssl/s3_lib.c
index 630c94d..d19b97a 100644
--- a/ssl/s3_lib.c
+++ b/ssl/s3_lib.c
@@ -2779,6 +2779,10 @@
 
 int ssl3_set_handshake_header(SSL *s, WPACKET *pkt, int htype)
 {
+    /* A CCS is not a handshake message, so it gets no type/length header */
+    if (htype == SSL3_MT_CHANGE_CIPHER_SPEC)
+        return 1;
+
     /* Set the content type and 3 bytes for the message len */
     if (!WPACKET_put_bytes_u8(pkt, htype)
             || !WPACKET_start_sub_packet_u24(pkt))
diff --git a/ssl/ssl_locl.h b/ssl/ssl_locl.h
index 06cf6e6..8a7e1a9 100644
--- a/ssl/ssl_locl.h
+++ b/ssl/ssl_locl.h
@@ -1586,7 +1586,7 @@
     /* Set the handshake header */
     int (*set_handshake_header) (SSL *s, WPACKET *pkt, int type);
     /* Close construction of the handshake message */
-    int (*close_construct_packet) (SSL *s, WPACKET *pkt);
+    int (*close_construct_packet) (SSL *s, WPACKET *pkt, int htype);
     /* Write out handshake message */
     int (*do_write) (SSL *s);
 } SSL3_ENC_METHOD;
@@ -1596,8 +1596,8 @@
         (((unsigned char *)s->init_buf->data) + s->method->ssl3_enc->hhlen)
 # define ssl_set_handshake_header(s, pkt, htype) \
         s->method->ssl3_enc->set_handshake_header((s), (pkt), (htype))
-# define ssl_close_construct_packet(s, pkt) \
-        s->method->ssl3_enc->close_construct_packet((s), (pkt))
+# define ssl_close_construct_packet(s, pkt, htype) \
+        s->method->ssl3_enc->close_construct_packet((s), (pkt), (htype))
 # define ssl_do_write(s)  s->method->ssl3_enc->do_write(s)
 
 /* Values for enc_flags */
@@ -1901,9 +1901,9 @@
 __owur long ssl3_default_timeout(void);
 
 __owur int ssl3_set_handshake_header(SSL *s, WPACKET *pkt, int htype);
-__owur int tls_close_construct_packet(SSL *s, WPACKET *pkt);
+__owur int tls_close_construct_packet(SSL *s, WPACKET *pkt, int htype);
 __owur int dtls1_set_handshake_header(SSL *s, WPACKET *pkt, int htype);
-__owur int dtls1_close_construct_packet(SSL *s, WPACKET *pkt);
+__owur int dtls1_close_construct_packet(SSL *s, WPACKET *pkt, int htype);
 __owur int ssl3_handshake_write(SSL *s);
 
 __owur int ssl_allow_compression(SSL *s);
diff --git a/ssl/statem/statem_clnt.c b/ssl/statem/statem_clnt.c
index 18eaf32..52c07ea 100644
--- a/ssl/statem/statem_clnt.c
+++ b/ssl/statem/statem_clnt.c
@@ -516,69 +516,69 @@
     int (*confunc) (SSL *s, WPACKET *pkt) = NULL;
     int ret = 1, mt;
 
-    if (st->hand_state == TLS_ST_CW_CHANGE) {
-        /* Special case becase it is a different content type */
+    switch (st->hand_state) {
+    default:
+        /* Shouldn't happen: no message is constructed in this state */
+        return 0;
+
+    case TLS_ST_CW_CHANGE:
         if (SSL_IS_DTLS(s))
-            return dtls_construct_change_cipher_spec(s, pkt);
+            confunc = dtls_construct_change_cipher_spec;
+        else
+            confunc = tls_construct_change_cipher_spec;
+        mt = SSL3_MT_CHANGE_CIPHER_SPEC;
+        break;
 
-        return tls_construct_change_cipher_spec(s, pkt);
-    } else {
-        switch (st->hand_state) {
-        default:
-            /* Shouldn't happen */
-            return 0;
+    case TLS_ST_CW_CLNT_HELLO:
+        confunc = tls_construct_client_hello;
+        mt = SSL3_MT_CLIENT_HELLO;
+        break;
 
-        case TLS_ST_CW_CLNT_HELLO:
-            confunc = tls_construct_client_hello;
-            mt = SSL3_MT_CLIENT_HELLO;
-            break;
+    case TLS_ST_CW_CERT:
+        confunc = tls_construct_client_certificate;
+        mt = SSL3_MT_CERTIFICATE;
+        break;
 
-        case TLS_ST_CW_CERT:
-            confunc = tls_construct_client_certificate;
-            mt = SSL3_MT_CERTIFICATE;
-            break;
+    case TLS_ST_CW_KEY_EXCH:
+        confunc = tls_construct_client_key_exchange;
+        mt = SSL3_MT_CLIENT_KEY_EXCHANGE;
+        break;
 
-        case TLS_ST_CW_KEY_EXCH:
-            confunc = tls_construct_client_key_exchange;
-            mt = SSL3_MT_CLIENT_KEY_EXCHANGE;
-            break;
-
-        case TLS_ST_CW_CERT_VRFY:
-            confunc = tls_construct_client_verify;
-            mt = SSL3_MT_CERTIFICATE_VERIFY;
-            break;
+    case TLS_ST_CW_CERT_VRFY:
+        confunc = tls_construct_client_verify;
+        mt = SSL3_MT_CERTIFICATE_VERIFY;
+        break;
 
 #if !defined(OPENSSL_NO_NEXTPROTONEG)
-        case TLS_ST_CW_NEXT_PROTO:
-            confunc = tls_construct_next_proto;
-            mt = SSL3_MT_NEXT_PROTO;
-            break;
+    case TLS_ST_CW_NEXT_PROTO:
+        confunc = tls_construct_next_proto;
+        mt = SSL3_MT_NEXT_PROTO;
+        break;
 #endif
-        case TLS_ST_CW_FINISHED:
-            mt = SSL3_MT_FINISHED;
-            break;
-        }
+    case TLS_ST_CW_FINISHED:
+        mt = SSL3_MT_FINISHED;
+        break;
+    }
 
-        if (!ssl_set_handshake_header(s, pkt, mt)) {
-            SSLerr(SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE,
-                   ERR_R_INTERNAL_ERROR);
-            return 0;
-        }
+    if (!ssl_set_handshake_header(s, pkt, mt)) {
+        SSLerr(SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE,
+               ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
 
-        if (st->hand_state == TLS_ST_CW_FINISHED)
-            ret = tls_construct_finished(s, pkt,
-                                         s->method->
-                                         ssl3_enc->client_finished_label,
-                                         s->method->
-                                         ssl3_enc->client_finished_label_len);
-        else
-            ret = confunc(s, pkt);
+    if (st->hand_state == TLS_ST_CW_FINISHED)
+        ret = tls_construct_finished(s, pkt,
+                                     s->method->
+                                     ssl3_enc->client_finished_label,
+                                     s->method->
+                                     ssl3_enc->client_finished_label_len);
+    else if (confunc != NULL)
+        ret = confunc(s, pkt);
 
-        if (!ret || !ssl_close_construct_packet(s, pkt)) {
-            SSLerr(SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE,
-                   ERR_R_INTERNAL_ERROR);
-            return 0;
-        }
+    if (!ret || !ssl_close_construct_packet(s, pkt, mt)) {
+        SSLerr(SSL_F_OSSL_STATEM_CLIENT_CONSTRUCT_MESSAGE,
+               ERR_R_INTERNAL_ERROR);
+        return 0;
     }
     return 1;
 }
diff --git a/ssl/statem/statem_dtls.c b/ssl/statem/statem_dtls.c
index cc016da..5b90c56 100644
--- a/ssl/statem/statem_dtls.c
+++ b/ssl/statem/statem_dtls.c
@@ -874,41 +874,17 @@
  */
 int dtls_construct_change_cipher_spec(SSL *s, WPACKET *pkt)
 {
-    if (!WPACKET_put_bytes_u8(pkt, SSL3_MT_CCS)) {
-        SSLerr(SSL_F_DTLS_CONSTRUCT_CHANGE_CIPHER_SPEC, ERR_R_INTERNAL_ERROR);
-        goto err;
-    }
-
-    s->d1->handshake_write_seq = s->d1->next_handshake_write_seq;
-    s->init_num = DTLS1_CCS_HEADER_LENGTH;
-
     if (s->version == DTLS1_BAD_VER) {
         s->d1->next_handshake_write_seq++;
 
         if (!WPACKET_put_bytes_u16(pkt, s->d1->handshake_write_seq)) {
             SSLerr(SSL_F_DTLS_CONSTRUCT_CHANGE_CIPHER_SPEC, ERR_R_INTERNAL_ERROR);
-            goto err;
+            ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
+            return 0;
         }
-
-        s->init_num += 2;
-    }
-
-    s->init_off = 0;
-
-    dtls1_set_message_header_int(s, SSL3_MT_CCS, 0,
-                                 s->d1->handshake_write_seq, 0, 0);
-
-    /* buffer the message to handle re-xmits */
-    if (!dtls1_buffer_message(s, 1)) {
-        SSLerr(SSL_F_DTLS_CONSTRUCT_CHANGE_CIPHER_SPEC, ERR_R_INTERNAL_ERROR);
-        goto err    ;
     }
 
     return 1;
-
- err:
-    ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
-    return 0;
 }
 
 #ifndef OPENSSL_NO_SCTP
@@ -1206,35 +1181,48 @@
 {
     unsigned char *header;
 
-    dtls1_set_message_header(s, htype, 0, 0, 0);
-
-    /*
-     * We allocate space at the start for the message header. This gets filled
-     * in later
-     */
-    if (!WPACKET_allocate_bytes(pkt, DTLS1_HM_HEADER_LENGTH, &header)
-            || !WPACKET_start_sub_packet(pkt))
-        return 0;
+    if (htype == SSL3_MT_CHANGE_CIPHER_SPEC) {
+        s->d1->handshake_write_seq = s->d1->next_handshake_write_seq;
+        dtls1_set_message_header_int(s, SSL3_MT_CCS, 0,
+                                     s->d1->handshake_write_seq, 0, 0);
+        if (!WPACKET_put_bytes_u8(pkt, SSL3_MT_CCS))
+            return 0;
+    } else {
+        dtls1_set_message_header(s, htype, 0, 0, 0);
+        /*
+         * Reserve DTLS1_HM_HEADER_LENGTH bytes at the start for the handshake
+         * message header; the body goes in a sub-packet. Filled in later.
+         */
+        if (!WPACKET_allocate_bytes(pkt, DTLS1_HM_HEADER_LENGTH, &header)
+                || !WPACKET_start_sub_packet(pkt))
+            return 0;
+    }
 
     return 1;
 }
 
-int dtls1_close_construct_packet(SSL *s, WPACKET *pkt)
+int dtls1_close_construct_packet(SSL *s, WPACKET *pkt, int htype)
 {
     size_t msglen;
 
-    if (!WPACKET_close(pkt)
+    if ((htype != SSL3_MT_CHANGE_CIPHER_SPEC && !WPACKET_close(pkt))
             || !WPACKET_get_length(pkt, &msglen)
             || msglen > INT_MAX)
         return 0;
-    s->d1->w_msg_hdr.msg_len = msglen - DTLS1_HM_HEADER_LENGTH;
-    s->d1->w_msg_hdr.frag_len = msglen - DTLS1_HM_HEADER_LENGTH;
+
+    if (htype != SSL3_MT_CHANGE_CIPHER_SPEC) {
+        s->d1->w_msg_hdr.msg_len = msglen - DTLS1_HM_HEADER_LENGTH;
+        s->d1->w_msg_hdr.frag_len = msglen - DTLS1_HM_HEADER_LENGTH;
+    }
     s->init_num = (int)msglen;
     s->init_off = 0;
 
-    /* Buffer the message to handle re-xmits */
-    if (!dtls1_buffer_message(s, 0))
-        return 0;
+    if (htype != DTLS1_MT_HELLO_VERIFY_REQUEST) {
+        /* Buffer for re-xmits; a HelloVerifyRequest is never buffered */
+        if (!dtls1_buffer_message(s, htype == SSL3_MT_CHANGE_CIPHER_SPEC
+                                     ? 1 : 0))
+            return 0;
+    }
 
     return 1;
 }
diff --git a/ssl/statem/statem_lib.c b/ssl/statem/statem_lib.c
index cac18cc..fa0032b 100644
--- a/ssl/statem/statem_lib.c
+++ b/ssl/statem/statem_lib.c
@@ -57,11 +57,11 @@
     return (0);
 }
 
-int tls_close_construct_packet(SSL *s, WPACKET *pkt)
+int tls_close_construct_packet(SSL *s, WPACKET *pkt, int htype)
 {
     size_t msglen;
 
-    if (!WPACKET_close(pkt)
+    if ((htype != SSL3_MT_CHANGE_CIPHER_SPEC && !WPACKET_close(pkt))
             || !WPACKET_get_length(pkt, &msglen)
             || msglen > INT_MAX)
         return 0;
@@ -260,9 +260,6 @@
         return 0;
     }
 
-    s->init_num = 1;
-    s->init_off = 0;
-
     return 1;
 }
 
diff --git a/ssl/statem/statem_srvr.c b/ssl/statem/statem_srvr.c
index 46bd5c7..78850a7 100644
--- a/ssl/statem/statem_srvr.c
+++ b/ssl/statem/statem_srvr.c
@@ -625,87 +625,90 @@
     int (*confunc) (SSL *s, WPACKET *pkt) = NULL;
     int ret = 1, mt;
 
-    if (st->hand_state == TLS_ST_SW_CHANGE) {
-        /* Special case becase it is a different content type */
+    switch (st->hand_state) {
+    default:
+        /* Shouldn't happen */
+        return 0;
+
+    case TLS_ST_SW_CHANGE:
         if (SSL_IS_DTLS(s))
-            return dtls_construct_change_cipher_spec(s, pkt);
+            confunc = dtls_construct_change_cipher_spec;
+        else
+            confunc = tls_construct_change_cipher_spec;
+        mt = SSL3_MT_CHANGE_CIPHER_SPEC;
+        break;
 
-        return tls_construct_change_cipher_spec(s, pkt);
-    } else if (st->hand_state == DTLS_ST_SW_HELLO_VERIFY_REQUEST) {
-        /* Special case because we don't call ssl_close_construct_packet() */
-        return dtls_construct_hello_verify_request(s, pkt);
-    } else {
-        switch (st->hand_state) {
-        default:
-            /* Shouldn't happen */
-            return 0;
+    case DTLS_ST_SW_HELLO_VERIFY_REQUEST:
+        confunc = dtls_construct_hello_verify_request;
+        mt = DTLS1_MT_HELLO_VERIFY_REQUEST;
+        break;
 
-        case TLS_ST_SW_HELLO_REQ:
-            /* No construction function needed */
-            mt = SSL3_MT_HELLO_REQUEST;
-            break;
+    case TLS_ST_SW_HELLO_REQ:
+        /* HelloRequest has an empty body: no construction function needed */
+        mt = SSL3_MT_HELLO_REQUEST;
+        break;
 
-        case TLS_ST_SW_SRVR_HELLO:
-            confunc = tls_construct_server_hello;
-            mt = SSL3_MT_SERVER_HELLO;
-            break;
+    case TLS_ST_SW_SRVR_HELLO:
+        confunc = tls_construct_server_hello;
+        mt = SSL3_MT_SERVER_HELLO;
+        break;
 
-        case TLS_ST_SW_CERT:
-            confunc = tls_construct_server_certificate;
-            mt = SSL3_MT_CERTIFICATE;
-            break;
+    case TLS_ST_SW_CERT:
+        confunc = tls_construct_server_certificate;
+        mt = SSL3_MT_CERTIFICATE;
+        break;
 
-        case TLS_ST_SW_KEY_EXCH:
-            confunc = tls_construct_server_key_exchange;
-            mt = SSL3_MT_SERVER_KEY_EXCHANGE;
-            break;
+    case TLS_ST_SW_KEY_EXCH:
+        confunc = tls_construct_server_key_exchange;
+        mt = SSL3_MT_SERVER_KEY_EXCHANGE;
+        break;
 
-        case TLS_ST_SW_CERT_REQ:
-            confunc = tls_construct_certificate_request;
-            mt = SSL3_MT_CERTIFICATE_REQUEST;
-            break;
+    case TLS_ST_SW_CERT_REQ:
+        confunc = tls_construct_certificate_request;
+        mt = SSL3_MT_CERTIFICATE_REQUEST;
+        break;
 
-        case TLS_ST_SW_SRVR_DONE:
-            confunc = tls_construct_server_done;
-            mt = SSL3_MT_SERVER_DONE;
-            break;
+    case TLS_ST_SW_SRVR_DONE:
+        confunc = tls_construct_server_done;
+        mt = SSL3_MT_SERVER_DONE;
+        break;
 
-        case TLS_ST_SW_SESSION_TICKET:
-            confunc = tls_construct_new_session_ticket;
-            mt = SSL3_MT_NEWSESSION_TICKET;
-            break;
+    case TLS_ST_SW_SESSION_TICKET:
+        confunc = tls_construct_new_session_ticket;
+        mt = SSL3_MT_NEWSESSION_TICKET;
+        break;
 
-        case TLS_ST_SW_CERT_STATUS:
-            confunc = tls_construct_cert_status;
-            mt = SSL3_MT_CERTIFICATE_STATUS;
-            break;
+    case TLS_ST_SW_CERT_STATUS:
+        confunc = tls_construct_cert_status;
+        mt = SSL3_MT_CERTIFICATE_STATUS;
+        break;
 
-        case TLS_ST_SW_FINISHED:
-            mt = SSL3_MT_FINISHED;
-            break;
-        }
-
-        if (!ssl_set_handshake_header(s, pkt, mt)) {
-            SSLerr(SSL_F_OSSL_STATEM_SERVER_CONSTRUCT_MESSAGE,
-                   ERR_R_INTERNAL_ERROR);
-            return 0;
-        }
-
-        if (st->hand_state == TLS_ST_SW_FINISHED)
-            ret = tls_construct_finished(s, pkt,
-                                         s->method->
-                                         ssl3_enc->server_finished_label,
-                                         s->method->
-                                         ssl3_enc->server_finished_label_len);
-        else if (confunc != NULL)
-            ret = confunc(s, pkt);
-
-        if (!ret || !ssl_close_construct_packet(s, pkt)) {
-            SSLerr(SSL_F_OSSL_STATEM_SERVER_CONSTRUCT_MESSAGE,
-                   ERR_R_INTERNAL_ERROR);
-            return 0;
-        }
+    case TLS_ST_SW_FINISHED:
+        mt = SSL3_MT_FINISHED;
+        break;
     }
+
+    if (!ssl_set_handshake_header(s, pkt, mt)) {
+        SSLerr(SSL_F_OSSL_STATEM_SERVER_CONSTRUCT_MESSAGE,
+               ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
+
+    if (st->hand_state == TLS_ST_SW_FINISHED)
+        ret = tls_construct_finished(s, pkt,
+                                     s->method->
+                                     ssl3_enc->server_finished_label,
+                                     s->method->
+                                     ssl3_enc->server_finished_label_len);
+    else if (confunc != NULL)
+        ret = confunc(s, pkt);
+
+    if (!ret || !ssl_close_construct_packet(s, pkt, mt)) {
+        SSLerr(SSL_F_OSSL_STATEM_SERVER_CONSTRUCT_MESSAGE,
+               ERR_R_INTERNAL_ERROR);
+        return 0;
+    }
+
     return 1;
 }
 
@@ -881,8 +884,6 @@
 
 int dtls_construct_hello_verify_request(SSL *s, WPACKET *pkt)
 {
-    size_t msglen;
-
     if (s->ctx->app_gen_cookie_cb == NULL ||
         s->ctx->app_gen_cookie_cb(s, s->d1->cookie,
                                   &(s->d1->cookie_len)) == 0 ||
@@ -892,27 +893,12 @@
         return 0;
     }
 
-    if (!ssl_set_handshake_header(s, pkt,
-                                         DTLS1_MT_HELLO_VERIFY_REQUEST)
-            || !dtls_raw_hello_verify_request(pkt, s->d1->cookie,
-                                              s->d1->cookie_len)
-               /*
-                * We don't call close_construct_packet() because we don't want
-                * to buffer this message
-                */
-            || !WPACKET_close(pkt)
-            || !WPACKET_get_length(pkt, &msglen)
-            || !WPACKET_finish(pkt)) {
+    if (!dtls_raw_hello_verify_request(pkt, s->d1->cookie,
+                                       s->d1->cookie_len)) {
         SSLerr(SSL_F_DTLS_CONSTRUCT_HELLO_VERIFY_REQUEST, ERR_R_INTERNAL_ERROR);
         return 0;
     }
 
-    /* number of bytes to write */
-    s->d1->w_msg_hdr.msg_len = msglen - DTLS1_HM_HEADER_LENGTH;
-    s->d1->w_msg_hdr.frag_len = msglen - DTLS1_HM_HEADER_LENGTH;
-    s->init_num = (int)msglen;
-    s->init_off = 0;
-
     return 1;
 }
 
@@ -3002,8 +2988,7 @@
 
             /* Put timeout and length */
             if (!WPACKET_put_bytes_u32(pkt, 0)
-                    || !WPACKET_put_bytes_u16(pkt, 0)
-                    || !ssl_close_construct_packet(s, pkt)) {
+                    || !WPACKET_put_bytes_u16(pkt, 0)) {
                 SSLerr(SSL_F_TLS_CONSTRUCT_NEW_SESSION_TICKET,
                        ERR_R_INTERNAL_ERROR);
                 goto err;