Split out CKE construction PSK pre-amble and RSA into a separate function

The tls_construct_client_key_exchange() function is too long. This splits
out the construction of the PSK pre-amble into a separate function as well
as the RSA construction.

Reviewed-by: Richard Levitte <levitte@openssl.org>
This commit is contained in:
Matt Caswell 2016-07-07 14:42:27 +01:00
parent 0bce0b02d8
commit 13c0ec4ad4

View File

@ -2012,38 +2012,27 @@ MSG_PROCESS_RETURN tls_process_server_done(SSL *s, PACKET *pkt)
return MSG_PROCESS_FINISHED_READING; return MSG_PROCESS_FINISHED_READING;
} }
int tls_construct_client_key_exchange(SSL *s) static int tls_construct_cke_psk_preamble(SSL *s, unsigned char **p,
size_t *pskhdrlen, int *al)
{ {
unsigned char *p;
int n;
#ifndef OPENSSL_NO_PSK #ifndef OPENSSL_NO_PSK
size_t pskhdrlen = 0; int ret = 0;
#endif
unsigned long alg_k;
alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
p = ssl_handshake_start(s);
#ifndef OPENSSL_NO_PSK
if (alg_k & SSL_PSK) {
int psk_err = 1;
/* /*
* The callback needs PSK_MAX_IDENTITY_LEN + 1 bytes to return a * The callback needs PSK_MAX_IDENTITY_LEN + 1 bytes to return a
* \0-terminated identity. The last byte is for us for simulating * \0-terminated identity. The last byte is for us for simulating
* strnlen. * strnlen.
*/ */
char identity[PSK_MAX_IDENTITY_LEN + 1]; char identity[PSK_MAX_IDENTITY_LEN + 1];
size_t identitylen; size_t identitylen = 0;
unsigned char psk[PSK_MAX_PSK_LEN]; unsigned char psk[PSK_MAX_PSK_LEN];
unsigned char *tmppsk; unsigned char *tmppsk = NULL;
char *tmpidentity; char *tmpidentity = NULL;
size_t psklen; size_t psklen = 0;
if (s->psk_client_callback == NULL) { if (s->psk_client_callback == NULL) {
SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
SSL_R_PSK_NO_CLIENT_CB); SSL_R_PSK_NO_CLIENT_CB);
*al = SSL_AD_INTERNAL_ERROR;
goto err; goto err;
} }
@ -2056,58 +2045,62 @@ int tls_construct_client_key_exchange(SSL *s)
if (psklen > PSK_MAX_PSK_LEN) { if (psklen > PSK_MAX_PSK_LEN) {
SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
ERR_R_INTERNAL_ERROR); ERR_R_INTERNAL_ERROR);
goto psk_err; *al = SSL_AD_HANDSHAKE_FAILURE;
goto err;
} else if (psklen == 0) { } else if (psklen == 0) {
SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
SSL_R_PSK_IDENTITY_NOT_FOUND); SSL_R_PSK_IDENTITY_NOT_FOUND);
goto psk_err; *al = SSL_AD_HANDSHAKE_FAILURE;
goto err;
} }
identitylen = strlen(identity); identitylen = strlen(identity);
if (identitylen > PSK_MAX_IDENTITY_LEN) { if (identitylen > PSK_MAX_IDENTITY_LEN) {
SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
ERR_R_INTERNAL_ERROR); ERR_R_INTERNAL_ERROR);
goto psk_err; *al = SSL_AD_HANDSHAKE_FAILURE;
goto err;
} }
tmppsk = OPENSSL_memdup(psk, psklen); tmppsk = OPENSSL_memdup(psk, psklen);
tmpidentity = OPENSSL_strdup(identity); tmpidentity = OPENSSL_strdup(identity);
if (tmppsk == NULL || tmpidentity == NULL) { if (tmppsk == NULL || tmpidentity == NULL) {
OPENSSL_cleanse(identity, sizeof(identity)); SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
OPENSSL_cleanse(psk, psklen); *al = SSL_AD_INTERNAL_ERROR;
OPENSSL_clear_free(tmppsk, psklen); goto err;
OPENSSL_clear_free(tmpidentity, identitylen);
goto memerr;
} }
OPENSSL_free(s->s3->tmp.psk); OPENSSL_free(s->s3->tmp.psk);
s->s3->tmp.psk = tmppsk; s->s3->tmp.psk = tmppsk;
s->s3->tmp.psklen = psklen; s->s3->tmp.psklen = psklen;
tmppsk = NULL;
OPENSSL_free(s->session->psk_identity); OPENSSL_free(s->session->psk_identity);
s->session->psk_identity = tmpidentity; s->session->psk_identity = tmpidentity;
s2n(identitylen, p); tmpidentity = NULL;
memcpy(p, identity, identitylen); s2n(identitylen, *p);
pskhdrlen = 2 + identitylen; memcpy(*p, identity, identitylen);
p += identitylen; *pskhdrlen = 2 + identitylen;
psk_err = 0; *p += identitylen;
psk_err:
ret = 1;
err:
OPENSSL_cleanse(psk, psklen); OPENSSL_cleanse(psk, psklen);
OPENSSL_cleanse(identity, sizeof(identity)); OPENSSL_cleanse(identity, sizeof(identity));
if (psk_err != 0) { OPENSSL_clear_free(tmppsk, psklen);
ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE); OPENSSL_clear_free(tmpidentity, identitylen);
goto err;
}
}
if (alg_k & SSL_kPSK) {
n = 0;
} else
#endif
/* Fool emacs indentation */ return ret;
if (0) { #else
} SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
*al = SSL_AD_INTERNAL_ERROR;
return 0;
#endif
}
static int tls_construct_cke_rsa(SSL *s, unsigned char **p, int *len, int *al)
{
#ifndef OPENSSL_NO_RSA #ifndef OPENSSL_NO_RSA
else if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) {
unsigned char *q; unsigned char *q;
EVP_PKEY *pkey = NULL; EVP_PKEY *pkey = NULL;
EVP_PKEY_CTX *pctx = NULL; EVP_PKEY_CTX *pctx = NULL;
@ -2121,68 +2114,103 @@ psk_err:
*/ */
SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
ERR_R_INTERNAL_ERROR); ERR_R_INTERNAL_ERROR);
goto err; return 0;
} }
pkey = X509_get0_pubkey(s->session->peer); pkey = X509_get0_pubkey(s->session->peer);
if (EVP_PKEY_get0_RSA(pkey) == NULL) { if (EVP_PKEY_get0_RSA(pkey) == NULL) {
SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
ERR_R_INTERNAL_ERROR); ERR_R_INTERNAL_ERROR);
goto err; return 0;
} }
pmslen = SSL_MAX_MASTER_KEY_LENGTH; pmslen = SSL_MAX_MASTER_KEY_LENGTH;
pms = OPENSSL_malloc(pmslen); pms = OPENSSL_malloc(pmslen);
if (pms == NULL) if (pms == NULL) {
goto memerr; SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
ERR_R_MALLOC_FAILURE);
*al = SSL_AD_INTERNAL_ERROR;
return 0;
}
pms[0] = s->client_version >> 8; pms[0] = s->client_version >> 8;
pms[1] = s->client_version & 0xff; pms[1] = s->client_version & 0xff;
if (RAND_bytes(pms + 2, pmslen - 2) <= 0) { if (RAND_bytes(pms + 2, pmslen - 2) <= 0) {
OPENSSL_clear_free(pms, pmslen);
goto err; goto err;
} }
q = p; q = *p;
/* Fix buf for TLS and beyond */ /* Fix buf for TLS and beyond */
if (s->version > SSL3_VERSION) if (s->version > SSL3_VERSION)
p += 2; *p += 2;
pctx = EVP_PKEY_CTX_new(pkey, NULL); pctx = EVP_PKEY_CTX_new(pkey, NULL);
if (pctx == NULL || EVP_PKEY_encrypt_init(pctx) <= 0 if (pctx == NULL || EVP_PKEY_encrypt_init(pctx) <= 0
|| EVP_PKEY_encrypt(pctx, NULL, &enclen, pms, pmslen) <= 0) { || EVP_PKEY_encrypt(pctx, NULL, &enclen, pms, pmslen) <= 0) {
OPENSSL_clear_free(pms, pmslen);
EVP_PKEY_CTX_free(pctx);
SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
ERR_R_EVP_LIB); ERR_R_EVP_LIB);
goto err; goto err;
} }
if (EVP_PKEY_encrypt(pctx, p, &enclen, pms, pmslen) <= 0) { if (EVP_PKEY_encrypt(pctx, *p, &enclen, pms, pmslen) <= 0) {
OPENSSL_clear_free(pms, pmslen);
EVP_PKEY_CTX_free(pctx);
SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE,
SSL_R_BAD_RSA_ENCRYPT); SSL_R_BAD_RSA_ENCRYPT);
goto err; goto err;
} }
n = enclen; *len = enclen;
EVP_PKEY_CTX_free(pctx); EVP_PKEY_CTX_free(pctx);
pctx = NULL; pctx = NULL;
# ifdef PKCS1_CHECK # ifdef PKCS1_CHECK
if (s->options & SSL_OP_PKCS1_CHECK_1) if (s->options & SSL_OP_PKCS1_CHECK_1)
p[1]++; (*p)[1]++;
if (s->options & SSL_OP_PKCS1_CHECK_2) if (s->options & SSL_OP_PKCS1_CHECK_2)
tmp_buf[0] = 0x70; tmp_buf[0] = 0x70;
# endif # endif
/* Fix buf for TLS and beyond */ /* Fix buf for TLS and beyond */
if (s->version > SSL3_VERSION) { if (s->version > SSL3_VERSION) {
s2n(n, q); s2n(*len, q);
n += 2; *len += 2;
} }
s->s3->tmp.pms = pms; s->s3->tmp.pms = pms;
s->s3->tmp.pmslen = pmslen; s->s3->tmp.pmslen = pmslen;
}
return 1;
err:
OPENSSL_clear_free(pms, pmslen);
EVP_PKEY_CTX_free(pctx);
return 0;
#else
SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR);
*al = SSL_AD_INTERNAL_ERROR;
return 0;
#endif #endif
}
int tls_construct_client_key_exchange(SSL *s)
{
unsigned char *p;
int n;
size_t pskhdrlen = 0;
unsigned long alg_k;
int al = -1;
alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
p = ssl_handshake_start(s);
if ((alg_k & SSL_PSK)
&& !tls_construct_cke_psk_preamble(s, &p, &pskhdrlen, &al))
goto err;
if (alg_k & SSL_kPSK) {
n = 0;
} else if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) {
if (!tls_construct_cke_rsa(s, &p, &n, &al))
goto err;
}
#ifndef OPENSSL_NO_DH #ifndef OPENSSL_NO_DH
else if (alg_k & (SSL_kDHE | SSL_kDHEPSK)) { else if (alg_k & (SSL_kDHE | SSL_kDHEPSK)) {
DH *dh_clnt = NULL; DH *dh_clnt = NULL;
@ -2421,9 +2449,7 @@ psk_err:
goto err; goto err;
} }
#ifndef OPENSSL_NO_PSK
n += pskhdrlen; n += pskhdrlen;
#endif
if (!ssl_set_handshake_header(s, SSL3_MT_CLIENT_KEY_EXCHANGE, n)) { if (!ssl_set_handshake_header(s, SSL3_MT_CLIENT_KEY_EXCHANGE, n)) {
ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE); ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
@ -2433,9 +2459,11 @@ psk_err:
return 1; return 1;
memerr: memerr:
ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE); SSLerr(SSL_F_TLS_CONSTRUCT_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE);
al = SSL_AD_INTERNAL_ERROR;
err: err:
if (al != -1)
ssl3_send_alert(s, SSL3_AL_FATAL, al);
OPENSSL_clear_free(s->s3->tmp.pms, s->s3->tmp.pmslen); OPENSSL_clear_free(s->s3->tmp.pms, s->s3->tmp.pmslen);
s->s3->tmp.pms = NULL; s->s3->tmp.pms = NULL;
#ifndef OPENSSL_NO_PSK #ifndef OPENSSL_NO_PSK