Commit b3992377 authored by rsleevi@chromium.org's avatar rsleevi@chromium.org

Improve support for requesting client certs in tlslite

Currently, tlslite only supports the caller passing in a list of CAs pre-encoded for the TLS CertificateRequest message. This CL improves that, by providing a means of extracting the DER-encoded subject name from an X509 certificate, supplying a list of such names to tlslite's server routines, and having tlslite encode the list of CAs as part of the CertificateRequest.

BUG=47656, 47658
TEST=net_unittests

Review URL: http://codereview.chromium.org/3177015

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@56982 0039d316-1c4b-4281-b951-d872f2087c98
parent bbf7e53f
...@@ -62,7 +62,7 @@ class HTTPSServer(tlslite.api.TLSSocketServerMixIn, StoppableHTTPServer): ...@@ -62,7 +62,7 @@ class HTTPSServer(tlslite.api.TLSSocketServerMixIn, StoppableHTTPServer):
"""This is a specialization of StoppableHTTPerver that add https support.""" """This is a specialization of StoppableHTTPerver that add https support."""
def __init__(self, server_address, request_hander_class, cert_path, def __init__(self, server_address, request_hander_class, cert_path,
ssl_client_auth): ssl_client_auth, ssl_client_cas):
s = open(cert_path).read() s = open(cert_path).read()
x509 = tlslite.api.X509() x509 = tlslite.api.X509()
x509.parse(s) x509.parse(s)
...@@ -70,6 +70,12 @@ class HTTPSServer(tlslite.api.TLSSocketServerMixIn, StoppableHTTPServer): ...@@ -70,6 +70,12 @@ class HTTPSServer(tlslite.api.TLSSocketServerMixIn, StoppableHTTPServer):
s = open(cert_path).read() s = open(cert_path).read()
self.private_key = tlslite.api.parsePEMKey(s, private=True) self.private_key = tlslite.api.parsePEMKey(s, private=True)
self.ssl_client_auth = ssl_client_auth self.ssl_client_auth = ssl_client_auth
self.ssl_client_cas = []
for ca_file in ssl_client_cas:
s = open(ca_file).read()
x509 = tlslite.api.X509()
x509.parse(s)
self.ssl_client_cas.append(x509.subject)
self.session_cache = tlslite.api.SessionCache() self.session_cache = tlslite.api.SessionCache()
StoppableHTTPServer.__init__(self, server_address, request_hander_class) StoppableHTTPServer.__init__(self, server_address, request_hander_class)
...@@ -80,7 +86,8 @@ class HTTPSServer(tlslite.api.TLSSocketServerMixIn, StoppableHTTPServer): ...@@ -80,7 +86,8 @@ class HTTPSServer(tlslite.api.TLSSocketServerMixIn, StoppableHTTPServer):
tlsConnection.handshakeServer(certChain=self.cert_chain, tlsConnection.handshakeServer(certChain=self.cert_chain,
privateKey=self.private_key, privateKey=self.private_key,
sessionCache=self.session_cache, sessionCache=self.session_cache,
reqCert=self.ssl_client_auth) reqCert=self.ssl_client_auth,
reqCAs=self.ssl_client_cas)
tlsConnection.ignoreAbruptClose = True tlsConnection.ignoreAbruptClose = True
return True return True
except tlslite.api.TLSAbruptCloseError: except tlslite.api.TLSAbruptCloseError:
...@@ -1227,10 +1234,16 @@ def main(options, args): ...@@ -1227,10 +1234,16 @@ def main(options, args):
if options.cert: if options.cert:
# let's make sure the cert file exists. # let's make sure the cert file exists.
if not os.path.isfile(options.cert): if not os.path.isfile(options.cert):
print 'specified cert file not found: ' + options.cert + ' exiting...' print 'specified server cert file not found: ' + options.cert + \
' exiting...'
return return
for ca_cert in options.ssl_client_ca:
if not os.path.isfile(ca_cert):
print 'specified trusted client CA file not found: ' + ca_cert + \
' exiting...'
return
server = HTTPSServer(('127.0.0.1', port), TestPageHandler, options.cert, server = HTTPSServer(('127.0.0.1', port), TestPageHandler, options.cert,
options.ssl_client_auth) options.ssl_client_auth, options.ssl_client_ca)
print 'HTTPS server started on port %d...' % port print 'HTTPS server started on port %d...' % port
else: else:
server = StoppableHTTPServer(('127.0.0.1', port), TestPageHandler) server = StoppableHTTPServer(('127.0.0.1', port), TestPageHandler)
...@@ -1297,6 +1310,10 @@ if __name__ == '__main__': ...@@ -1297,6 +1310,10 @@ if __name__ == '__main__':
'the server should use.') 'the server should use.')
option_parser.add_option('', '--ssl-client-auth', action='store_true', option_parser.add_option('', '--ssl-client-auth', action='store_true',
help='Require SSL client auth on every connection.') help='Require SSL client auth on every connection.')
option_parser.add_option('', '--ssl-client-ca', action='append', default=[],
help='Specify that the client certificate request '
'should indicate that it supports the CA contained '
'in the specified certificate file')
option_parser.add_option('', '--file-root-url', default='/files/', option_parser.add_option('', '--file-root-url', default='/files/',
help='Specify a root URL for files served.') help='Specify a root URL for files served.')
option_parser.add_option('', '--never-die', default=False, option_parser.add_option('', '--never-die', default=False,
......
...@@ -10,3 +10,14 @@ Local Modifications: ...@@ -10,3 +10,14 @@ Local Modifications:
http://sourceforge.net/mailarchive/forum.php?thread_name=41C9B18B.2010201%40ag.com&forum_name=tlslite-users http://sourceforge.net/mailarchive/forum.php?thread_name=41C9B18B.2010201%40ag.com&forum_name=tlslite-users
- patches/python26.patch: Replace sha, md5 module imports with hashlib, as - patches/python26.patch: Replace sha, md5 module imports with hashlib, as
they are deprecated in Python 2.6 they are deprecated in Python 2.6
- patches/ca_request.patch: tlslite/X509.py was changed to obtain the
DER-encoded distinguished name for a certificate, without requiring any
addition libraries.
tlslite/utils/ASN1Parser.py was changed to allow obtaining the unparsed
data for an element in a SEQUENCE, in addition to providing the parsed
value (tag and length removed)
tlslite/messages.py was changed from accepting/returning a single byte
array in the CertificateRequest message for the CA names to accept a list
of byte arrays, each containing a DER-encoded distinguished name.
tlslite/TLSConnection.py was changed to take a list of such byte arrays
when creating a TLS server that will request client authentication.
Only in chromium: patches
diff -aur tlslite-0.3.8/tlslite/TLSConnection.py chromium/tlslite/TLSConnection.py
--- tlslite-0.3.8/tlslite/TLSConnection.py 2004-10-06 01:55:37.000000000 -0400
+++ chromium/tlslite/TLSConnection.py 2010-08-18 22:17:30.962786700 -0400
@@ -931,7 +931,8 @@
def handshakeServer(self, sharedKeyDB=None, verifierDB=None,
certChain=None, privateKey=None, reqCert=False,
- sessionCache=None, settings=None, checker=None):
+ sessionCache=None, settings=None, checker=None,
+ reqCAs=None):
"""Perform a handshake in the role of server.
This function performs an SSL or TLS handshake. Depending on
@@ -997,6 +998,11 @@
invoked to examine the other party's authentication
credentials, if the handshake completes succesfully.
+ @type reqCAs: list of L{array.array} of unsigned bytes
+ @param reqCAs: A collection of DER-encoded DistinguishedNames that
+ will be sent along with a certificate request. This does not affect
+ verification.
+
@raise socket.error: If a socket error occurs.
@raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
without a preceding alert.
@@ -1006,13 +1012,14 @@
"""
for result in self.handshakeServerAsync(sharedKeyDB, verifierDB,
certChain, privateKey, reqCert, sessionCache, settings,
- checker):
+ checker, reqCAs):
pass
def handshakeServerAsync(self, sharedKeyDB=None, verifierDB=None,
certChain=None, privateKey=None, reqCert=False,
- sessionCache=None, settings=None, checker=None):
+ sessionCache=None, settings=None, checker=None,
+ reqCAs=None):
"""Start a server handshake operation on the TLS connection.
This function returns a generator which behaves similarly to
@@ -1028,14 +1035,15 @@
sharedKeyDB=sharedKeyDB,
verifierDB=verifierDB, certChain=certChain,
privateKey=privateKey, reqCert=reqCert,
- sessionCache=sessionCache, settings=settings)
+ sessionCache=sessionCache, settings=settings,
+ reqCAs=reqCAs)
for result in self._handshakeWrapperAsync(handshaker, checker):
yield result
def _handshakeServerAsyncHelper(self, sharedKeyDB, verifierDB,
certChain, privateKey, reqCert, sessionCache,
- settings):
+ settings, reqCAs):
self._handshakeStart(client=False)
@@ -1045,6 +1053,8 @@
raise ValueError("Caller passed a certChain but no privateKey")
if privateKey and not certChain:
raise ValueError("Caller passed a privateKey but no certChain")
+ if reqCAs and not reqCert:
+ raise ValueError("Caller passed reqCAs but not reqCert")
if not settings:
settings = HandshakeSettings()
@@ -1380,7 +1390,9 @@
msgs.append(ServerHello().create(self.version, serverRandom,
sessionID, cipherSuite, certificateType))
msgs.append(Certificate(certificateType).create(serverCertChain))
- if reqCert:
+ if reqCert and reqCAs:
+ msgs.append(CertificateRequest().create([], reqCAs))
+ elif reqCert:
msgs.append(CertificateRequest())
msgs.append(ServerHelloDone())
for result in self._sendMsgs(msgs):
diff -aur tlslite-0.3.8/tlslite/X509.py chromium/tlslite/X509.py
--- tlslite-0.3.8/tlslite/X509.py 2004-03-19 21:43:19.000000000 -0400
+++ chromium/tlslite/X509.py 2010-08-18 22:17:30.967787000 -0400
@@ -13,11 +13,15 @@
@type publicKey: L{tlslite.utils.RSAKey.RSAKey}
@ivar publicKey: The subject public key from the certificate.
+
+ @type subject: L{array.array} of unsigned bytes
+ @ivar subject: The DER-encoded ASN.1 subject distinguished name.
"""
def __init__(self):
self.bytes = createByteArraySequence([])
self.publicKey = None
+ self.subject = None
def parse(self, s):
"""Parse a PEM-encoded X.509 certificate.
@@ -63,6 +67,10 @@
else:
subjectPublicKeyInfoIndex = 5
+ #Get the subject
+ self.subject = tbsCertificateP.getChildBytes(\
+ subjectPublicKeyInfoIndex - 1)
+
#Get the subjectPublicKeyInfo
subjectPublicKeyInfoP = tbsCertificateP.getChild(\
subjectPublicKeyInfoIndex)
diff -aur tlslite-0.3.8/tlslite/messages.py chromium/tlslite/messages.py
--- tlslite-0.3.8/tlslite/messages.py 2004-10-06 01:01:24.000000000 -0400
+++ chromium/tlslite/messages.py 2010-08-18 22:17:30.976787500 -0400
@@ -338,8 +338,7 @@
def __init__(self):
self.contentType = ContentType.handshake
self.certificate_types = []
- #treat as opaque bytes for now
- self.certificate_authorities = createByteArraySequence([])
+ self.certificate_authorities = []
def create(self, certificate_types, certificate_authorities):
self.certificate_types = certificate_types
@@ -349,7 +348,13 @@
def parse(self, p):
p.startLengthCheck(3)
self.certificate_types = p.getVarList(1, 1)
- self.certificate_authorities = p.getVarBytes(2)
+ ca_list_length = p.get(2)
+ index = 0
+ self.certificate_authorities = []
+ while index != ca_list_length:
+ ca_bytes = p.getVarBytes(2)
+ self.certificate_authorities.append(ca_bytes)
+ index += len(ca_bytes)+2
p.stopLengthCheck()
return self
@@ -357,7 +362,14 @@
w = HandshakeMsg.preWrite(self, HandshakeType.certificate_request,
trial)
w.addVarSeq(self.certificate_types, 1, 1)
- w.addVarSeq(self.certificate_authorities, 1, 2)
+ caLength = 0
+ #determine length
+ for ca_dn in self.certificate_authorities:
+ caLength += len(ca_dn)+2
+ w.add(caLength, 2)
+ #add bytes
+ for ca_dn in self.certificate_authorities:
+ w.addVarSeq(ca_dn, 1, 2)
return HandshakeMsg.postWrite(self, w, trial)
class ServerKeyExchange(HandshakeMsg):
diff -aur tlslite-0.3.8/tlslite/utils/ASN1Parser.py chromium/tlslite/utils/ASN1Parser.py
--- tlslite-0.3.8/tlslite/utils/ASN1Parser.py 2004-10-06 01:02:40.000000000 -0400
+++ chromium/tlslite/utils/ASN1Parser.py 2010-08-18 22:17:30.979787700 -0400
@@ -16,13 +16,16 @@
#Assuming this is a sequence...
def getChild(self, which):
+ return ASN1Parser(self.getChildBytes(which))
+
+ def getChildBytes(self, which):
p = Parser(self.value)
for x in range(which+1):
markIndex = p.index
p.get(1) #skip Type
length = self._getASN1Length(p)
p.getFixBytes(length)
- return ASN1Parser(p.bytes[markIndex : p.index])
+ return p.bytes[markIndex : p.index]
#Decode the ASN.1 DER length field
def _getASN1Length(self, p):
...@@ -931,7 +931,8 @@ class TLSConnection(TLSRecordLayer): ...@@ -931,7 +931,8 @@ class TLSConnection(TLSRecordLayer):
def handshakeServer(self, sharedKeyDB=None, verifierDB=None, def handshakeServer(self, sharedKeyDB=None, verifierDB=None,
certChain=None, privateKey=None, reqCert=False, certChain=None, privateKey=None, reqCert=False,
sessionCache=None, settings=None, checker=None): sessionCache=None, settings=None, checker=None,
reqCAs=None):
"""Perform a handshake in the role of server. """Perform a handshake in the role of server.
This function performs an SSL or TLS handshake. Depending on This function performs an SSL or TLS handshake. Depending on
...@@ -997,6 +998,11 @@ class TLSConnection(TLSRecordLayer): ...@@ -997,6 +998,11 @@ class TLSConnection(TLSRecordLayer):
invoked to examine the other party's authentication invoked to examine the other party's authentication
credentials, if the handshake completes succesfully. credentials, if the handshake completes succesfully.
@type reqCAs: list of L{array.array} of unsigned bytes
@param reqCAs: A collection of DER-encoded DistinguishedNames that
will be sent along with a certificate request. This does not affect
verification.
@raise socket.error: If a socket error occurs. @raise socket.error: If a socket error occurs.
@raise tlslite.errors.TLSAbruptCloseError: If the socket is closed @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
without a preceding alert. without a preceding alert.
...@@ -1006,13 +1012,14 @@ class TLSConnection(TLSRecordLayer): ...@@ -1006,13 +1012,14 @@ class TLSConnection(TLSRecordLayer):
""" """
for result in self.handshakeServerAsync(sharedKeyDB, verifierDB, for result in self.handshakeServerAsync(sharedKeyDB, verifierDB,
certChain, privateKey, reqCert, sessionCache, settings, certChain, privateKey, reqCert, sessionCache, settings,
checker): checker, reqCAs):
pass pass
def handshakeServerAsync(self, sharedKeyDB=None, verifierDB=None, def handshakeServerAsync(self, sharedKeyDB=None, verifierDB=None,
certChain=None, privateKey=None, reqCert=False, certChain=None, privateKey=None, reqCert=False,
sessionCache=None, settings=None, checker=None): sessionCache=None, settings=None, checker=None,
reqCAs=None):
"""Start a server handshake operation on the TLS connection. """Start a server handshake operation on the TLS connection.
This function returns a generator which behaves similarly to This function returns a generator which behaves similarly to
...@@ -1028,14 +1035,15 @@ class TLSConnection(TLSRecordLayer): ...@@ -1028,14 +1035,15 @@ class TLSConnection(TLSRecordLayer):
sharedKeyDB=sharedKeyDB, sharedKeyDB=sharedKeyDB,
verifierDB=verifierDB, certChain=certChain, verifierDB=verifierDB, certChain=certChain,
privateKey=privateKey, reqCert=reqCert, privateKey=privateKey, reqCert=reqCert,
sessionCache=sessionCache, settings=settings) sessionCache=sessionCache, settings=settings,
reqCAs=reqCAs)
for result in self._handshakeWrapperAsync(handshaker, checker): for result in self._handshakeWrapperAsync(handshaker, checker):
yield result yield result
def _handshakeServerAsyncHelper(self, sharedKeyDB, verifierDB, def _handshakeServerAsyncHelper(self, sharedKeyDB, verifierDB,
certChain, privateKey, reqCert, sessionCache, certChain, privateKey, reqCert, sessionCache,
settings): settings, reqCAs):
self._handshakeStart(client=False) self._handshakeStart(client=False)
...@@ -1045,6 +1053,8 @@ class TLSConnection(TLSRecordLayer): ...@@ -1045,6 +1053,8 @@ class TLSConnection(TLSRecordLayer):
raise ValueError("Caller passed a certChain but no privateKey") raise ValueError("Caller passed a certChain but no privateKey")
if privateKey and not certChain: if privateKey and not certChain:
raise ValueError("Caller passed a privateKey but no certChain") raise ValueError("Caller passed a privateKey but no certChain")
if reqCAs and not reqCert:
raise ValueError("Caller passed reqCAs but not reqCert")
if not settings: if not settings:
settings = HandshakeSettings() settings = HandshakeSettings()
...@@ -1380,7 +1390,9 @@ class TLSConnection(TLSRecordLayer): ...@@ -1380,7 +1390,9 @@ class TLSConnection(TLSRecordLayer):
msgs.append(ServerHello().create(self.version, serverRandom, msgs.append(ServerHello().create(self.version, serverRandom,
sessionID, cipherSuite, certificateType)) sessionID, cipherSuite, certificateType))
msgs.append(Certificate(certificateType).create(serverCertChain)) msgs.append(Certificate(certificateType).create(serverCertChain))
if reqCert: if reqCert and reqCAs:
msgs.append(CertificateRequest().create([], reqCAs))
elif reqCert:
msgs.append(CertificateRequest()) msgs.append(CertificateRequest())
msgs.append(ServerHelloDone()) msgs.append(ServerHelloDone())
for result in self._sendMsgs(msgs): for result in self._sendMsgs(msgs):
......
...@@ -13,11 +13,15 @@ class X509: ...@@ -13,11 +13,15 @@ class X509:
@type publicKey: L{tlslite.utils.RSAKey.RSAKey} @type publicKey: L{tlslite.utils.RSAKey.RSAKey}
@ivar publicKey: The subject public key from the certificate. @ivar publicKey: The subject public key from the certificate.
@type subject: L{array.array} of unsigned bytes
@ivar subject: The DER-encoded ASN.1 subject distinguished name.
""" """
def __init__(self): def __init__(self):
self.bytes = createByteArraySequence([]) self.bytes = createByteArraySequence([])
self.publicKey = None self.publicKey = None
self.subject = None
def parse(self, s): def parse(self, s):
"""Parse a PEM-encoded X.509 certificate. """Parse a PEM-encoded X.509 certificate.
...@@ -63,6 +67,10 @@ class X509: ...@@ -63,6 +67,10 @@ class X509:
else: else:
subjectPublicKeyInfoIndex = 5 subjectPublicKeyInfoIndex = 5
#Get the subject
self.subject = tbsCertificateP.getChildBytes(\
subjectPublicKeyInfoIndex - 1)
#Get the subjectPublicKeyInfo #Get the subjectPublicKeyInfo
subjectPublicKeyInfoP = tbsCertificateP.getChild(\ subjectPublicKeyInfoP = tbsCertificateP.getChild(\
subjectPublicKeyInfoIndex) subjectPublicKeyInfoIndex)
......
...@@ -347,8 +347,7 @@ class CertificateRequest(HandshakeMsg): ...@@ -347,8 +347,7 @@ class CertificateRequest(HandshakeMsg):
def __init__(self): def __init__(self):
self.contentType = ContentType.handshake self.contentType = ContentType.handshake
self.certificate_types = [] self.certificate_types = []
#treat as opaque bytes for now self.certificate_authorities = []
self.certificate_authorities = createByteArraySequence([])
def create(self, certificate_types, certificate_authorities): def create(self, certificate_types, certificate_authorities):
self.certificate_types = certificate_types self.certificate_types = certificate_types
...@@ -358,7 +357,13 @@ class CertificateRequest(HandshakeMsg): ...@@ -358,7 +357,13 @@ class CertificateRequest(HandshakeMsg):
def parse(self, p): def parse(self, p):
p.startLengthCheck(3) p.startLengthCheck(3)
self.certificate_types = p.getVarList(1, 1) self.certificate_types = p.getVarList(1, 1)
self.certificate_authorities = p.getVarBytes(2) ca_list_length = p.get(2)
index = 0
self.certificate_authorities = []
while index != ca_list_length:
ca_bytes = p.getVarBytes(2)
self.certificate_authorities.append(ca_bytes)
index += len(ca_bytes)+2
p.stopLengthCheck() p.stopLengthCheck()
return self return self
...@@ -366,7 +371,14 @@ class CertificateRequest(HandshakeMsg): ...@@ -366,7 +371,14 @@ class CertificateRequest(HandshakeMsg):
w = HandshakeMsg.preWrite(self, HandshakeType.certificate_request, w = HandshakeMsg.preWrite(self, HandshakeType.certificate_request,
trial) trial)
w.addVarSeq(self.certificate_types, 1, 1) w.addVarSeq(self.certificate_types, 1, 1)
w.addVarSeq(self.certificate_authorities, 1, 2) caLength = 0
#determine length
for ca_dn in self.certificate_authorities:
caLength += len(ca_dn)+2
w.add(caLength, 2)
#add bytes
for ca_dn in self.certificate_authorities:
w.addVarSeq(ca_dn, 1, 2)
return HandshakeMsg.postWrite(self, w, trial) return HandshakeMsg.postWrite(self, w, trial)
class ServerKeyExchange(HandshakeMsg): class ServerKeyExchange(HandshakeMsg):
......
...@@ -16,13 +16,16 @@ class ASN1Parser: ...@@ -16,13 +16,16 @@ class ASN1Parser:
#Assuming this is a sequence... #Assuming this is a sequence...
def getChild(self, which): def getChild(self, which):
return ASN1Parser(self.getChildBytes(which))
def getChildBytes(self, which):
p = Parser(self.value) p = Parser(self.value)
for x in range(which+1): for x in range(which+1):
markIndex = p.index markIndex = p.index
p.get(1) #skip Type p.get(1) #skip Type
length = self._getASN1Length(p) length = self._getASN1Length(p)
p.getFixBytes(length) p.getFixBytes(length)
return ASN1Parser(p.bytes[markIndex : p.index]) return p.bytes[markIndex : p.index]
#Decode the ASN.1 DER length field #Decode the ASN.1 DER length field
def _getASN1Length(self, p): def _getASN1Length(self, p):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment