diff options
Diffstat (limited to 'src/net/tls.c')
-rw-r--r-- | src/net/tls.c | 177 |
1 files changed, 110 insertions, 67 deletions
diff --git a/src/net/tls.c b/src/net/tls.c index fa4b58d49..f5bff7a49 100644 --- a/src/net/tls.c +++ b/src/net/tls.c @@ -36,6 +36,8 @@ #include <gpxe/xfer.h> #include <gpxe/open.h> #include <gpxe/filter.h> +#include <gpxe/asn1.h> +#include <gpxe/x509.h> #include <gpxe/tls.h> static int tls_send_plaintext ( struct tls_session *tls, unsigned int type, @@ -43,6 +45,33 @@ static int tls_send_plaintext ( struct tls_session *tls, unsigned int type, static void tls_clear_cipher ( struct tls_session *tls, struct tls_cipherspec *cipherspec ); +/****************************************************************************** + * + * Utility functions + * + ****************************************************************************** + */ + +/** + * Extract 24-bit field value + * + * @v field24 24-bit field + * @ret value Field value + * + * TLS uses 24-bit integers in several places, which are awkward to + * parse in C. + */ +static unsigned long tls_uint24 ( uint8_t field24[3] ) { + return ( ( field24[0] << 16 ) + ( field24[1] << 8 ) + field24[2] ); +} + +/****************************************************************************** + * + * Cleanup functions + * + ****************************************************************************** + */ + /** * Free TLS session * @@ -57,8 +86,7 @@ static void free_tls ( struct refcnt *refcnt ) { tls_clear_cipher ( tls, &tls->tx_cipherspec_pending ); tls_clear_cipher ( tls, &tls->rx_cipherspec ); tls_clear_cipher ( tls, &tls->rx_cipherspec_pending ); - free ( tls->rsa_mod ); - free ( tls->rsa_pub_exp ); + x509_free_rsa_public_key ( &tls->rsa ); free ( tls->rx_data ); /* Free TLS structure itself */ @@ -622,8 +650,8 @@ static int tls_send_client_hello ( struct tls_session *tls ) { static int tls_send_client_key_exchange ( struct tls_session *tls ) { /* FIXME: Hack alert */ RSA_CTX *rsa_ctx; - RSA_pub_key_new ( &rsa_ctx, tls->rsa_mod, tls->rsa_mod_len, - tls->rsa_pub_exp, tls->rsa_pub_exp_len ); + RSA_pub_key_new ( &rsa_ctx, tls->rsa.modulus, tls->rsa.modulus_len, + tls->rsa.exponent, tls->rsa.exponent_len ); struct { uint32_t type_length; uint16_t encrypted_pre_master_secret_len; @@ -641,8 +669,8 @@ static int tls_send_client_key_exchange ( struct tls_session *tls ) { DBGC ( tls, "RSA encrypting plaintext, modulus, exponent:\n" ); DBGC_HD ( tls, &tls->pre_master_secret, sizeof ( tls->pre_master_secret ) ); - DBGC_HD ( tls, tls->rsa_mod, tls->rsa_mod_len ); - DBGC_HD ( tls, tls->rsa_pub_exp, tls->rsa_pub_exp_len ); + DBGC_HD ( tls, tls->rsa.modulus, tls->rsa.modulus_len ); + DBGC_HD ( tls, tls->rsa.exponent, tls->rsa.exponent_len ); RSA_encrypt ( rsa_ctx, ( const uint8_t * ) &tls->pre_master_secret, sizeof ( tls->pre_master_secret ), key_xchg.encrypted_pre_master_secret, 0 ); @@ -761,17 +789,16 @@ static int tls_new_alert ( struct tls_session *tls, void *data, size_t len ) { } /** - * Receive new Server Hello record + * Receive new Server Hello handshake record * * @v tls TLS session - * @v data Plaintext record - * @v len Length of plaintext record + * @v data Plaintext handshake record + * @v len Length of plaintext handshake record * @ret rc Return status code */ static int tls_new_server_hello ( struct tls_session *tls, void *data, size_t len ) { struct { - uint32_t type_length; uint16_t version; uint8_t random[32]; uint8_t session_id_len; @@ -818,72 +845,74 @@ static int tls_new_server_hello ( struct tls_session *tls, } /** - * Receive new Certificate record + * Receive new Certificate handshake record * * @v tls TLS session - * @v data Plaintext record - * @v len Length of plaintext record + * @v data Plaintext handshake record + * @v len Length of plaintext handshake record * @ret rc Return status code */ static int tls_new_certificate ( struct tls_session *tls, void *data, size_t len ) { struct { - uint32_t type_length; uint8_t length[3]; - uint8_t first_cert_length[3]; - uint8_t asn1_start[0]; + uint8_t certificates[0]; } __attribute__ (( packed )) *certificate = data; - uint8_t *cert = certificate->asn1_start; - int offset = 0; - - /* FIXME */ - (void) len; - - if (asn1_next_obj(cert, &offset, ASN1_SEQUENCE) < 0 || - asn1_next_obj(cert, &offset, ASN1_SEQUENCE) < 0 || - asn1_skip_obj(cert, &offset, ASN1_EXPLICIT_TAG) || - asn1_skip_obj(cert, &offset, ASN1_INTEGER) || - asn1_skip_obj(cert, &offset, ASN1_SEQUENCE) || - asn1_skip_obj(cert, &offset, ASN1_SEQUENCE) || - asn1_skip_obj(cert, &offset, ASN1_SEQUENCE) || - asn1_skip_obj(cert, &offset, ASN1_SEQUENCE) || - asn1_next_obj(cert, &offset, ASN1_SEQUENCE) < 0 || - asn1_skip_obj(cert, &offset, ASN1_SEQUENCE) || - asn1_next_obj(cert, &offset, ASN1_BIT_STRING) < 0) { - DBGC ( tls, "TLS %p invalid certificate\n", tls ); - DBGC_HD ( tls, cert + offset, 64 ); - return -EPERM; - } - - offset++; - - if (asn1_next_obj(cert, &offset, ASN1_SEQUENCE) < 0) { - DBGC ( tls, "TLS %p invalid certificate\n", tls ); - DBGC_HD ( tls, cert + offset, 64 ); - return -EPERM; + struct { + uint8_t length[3]; + uint8_t certificate[0]; + } __attribute__ (( packed )) *element = + ( ( void * ) certificate->certificates ); + size_t elements_len = tls_uint24 ( certificate->length ); + void *end = ( certificate->certificates + elements_len ); + struct asn1_cursor cursor; + int rc; + + /* Sanity check */ + if ( end != ( data + len ) ) { + DBGC ( tls, "TLS %p received overlength Server Certificate\n", + tls ); + DBGC_HD ( tls, data, len ); + return -EINVAL; } - - tls->rsa_mod_len = asn1_get_int(cert, &offset, &tls->rsa_mod); - tls->rsa_pub_exp_len = asn1_get_int(cert, &offset, &tls->rsa_pub_exp); - - DBGC_HD ( tls, tls->rsa_mod, tls->rsa_mod_len ); - DBGC_HD ( tls, tls->rsa_pub_exp, tls->rsa_pub_exp_len ); - return 0; + /* Traverse certificate chain */ + do { + cursor.data = element->certificate; + cursor.len = tls_uint24 ( element->length ); + if ( ( cursor.data + cursor.len ) > end ) { + DBGC ( tls, "TLS %p received corrupt Server " + "Certificate\n", tls ); + DBGC_HD ( tls, data, len ); + return -EINVAL; + } + + // HACK + if ( ( rc = x509_rsa_public_key ( &cursor, + &tls->rsa ) ) != 0 ) { + DBGC ( tls, "TLS %p cannot determine RSA public key: " + "%s\n", tls, strerror ( rc ) ); + return rc; + } + return 0; + + element = ( cursor.data + cursor.len ); + } while ( element != end ); + + return -EINVAL; } /** - * Receive new Server Hello Done record + * Receive new Server Hello Done handshake record * * @v tls TLS session - * @v data Plaintext record - * @v len Length of plaintext record + * @v data Plaintext handshake record + * @v len Length of plaintext handshake record * @ret rc Return status code */ static int tls_new_server_hello_done ( struct tls_session *tls, void *data, size_t len ) { struct { - uint32_t type_length; char next[0]; } __attribute__ (( packed )) *hello_done = data; void *end = hello_done->next; @@ -910,11 +939,11 @@ static int tls_new_server_hello_done ( struct tls_session *tls, } /** - * Receive new Finished record + * Receive new Finished handshake record * * @v tls TLS session - * @v data Plaintext record - * @v len Length of plaintext record + * @v data Plaintext handshake record + * @v len Length of plaintext handshake record * @ret rc Return status code */ static int tls_new_finished ( struct tls_session *tls, @@ -937,33 +966,47 @@ static int tls_new_finished ( struct tls_session *tls, */ static int tls_new_handshake ( struct tls_session *tls, void *data, size_t len ) { - uint8_t *type = data; + struct { + uint8_t type; + uint8_t length[3]; + uint8_t payload[0]; + } __attribute__ (( packed )) *handshake = data; + void *payload = &handshake->payload; + size_t payload_len = tls_uint24 ( handshake->length ); + void *end = ( payload + payload_len ); int rc; - switch ( *type ) { + /* Sanity check */ + if ( end != ( data + len ) ) { + DBGC ( tls, "TLS %p received overlength Handshake\n", tls ); + DBGC_HD ( tls, data, len ); + return -EINVAL; + } + + switch ( handshake->type ) { case TLS_SERVER_HELLO: - rc = tls_new_server_hello ( tls, data, len ); + rc = tls_new_server_hello ( tls, payload, payload_len ); break; case TLS_CERTIFICATE: - rc = tls_new_certificate ( tls, data, len ); + rc = tls_new_certificate ( tls, payload, payload_len ); break; case TLS_SERVER_HELLO_DONE: - rc = tls_new_server_hello_done ( tls, data, len ); + rc = tls_new_server_hello_done ( tls, payload, payload_len ); break; case TLS_FINISHED: - rc = tls_new_finished ( tls, data, len ); + rc = tls_new_finished ( tls, payload, payload_len ); break; default: DBGC ( tls, "TLS %p ignoring handshake type %d\n", - tls, *type ); + tls, handshake->type ); rc = 0; break; } /* Add to handshake digest (except for Hello Requests, which - * are explicitly excludede). + * are explicitly excluded). */ - if ( *type != TLS_HELLO_REQUEST ) + if ( handshake->type != TLS_HELLO_REQUEST ) tls_add_handshake ( tls, data, len ); return rc; |