Python module for flexible SSL HTTP server handling

This module allows flexible retrieval of SSL certificates from a server. Unlike Python's standard ssl module, it handles HTTP proxies and invalid certificates.

When implementing a website scanner, I stumbled on specific features of Python modules like ssl and urllib that made them difficult to use. The ssl module always tries to validate SSL certificates, so you can neither talk to servers with self-signed certificates nor even fetch their certificates. On the other hand, urllib doesn't validate certificates, but it doesn't give access to low-level details such as the certificates themselves.

The module below uses pyOpenSSL and handles all kinds of certificates and proxies well. Sample code can be found at the end of the module, in the unit tests run from the __main__ section.

import socket
import re
import os.path
from OpenSSL import SSL
 
# HEAD request used to fetch the server's response status and headers.
# Placeholder {0} is the value of the Host header.
HTTP_HEAD_FMT = (
    'HEAD / HTTP/1.1\r\n'
    'Host: {0}\r\n'
    'Connection: close\r\n'
    'Accept-Charset: UTF-8,*;q=0.5\r\n'
    'Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n'
    'User-Agent: Mozilla/5.0 (Windows NT 5.1)\r\n'
    '\r\n'
)

# Tunnel request sent to an HTTP proxy.
# Placeholders: {0} is the target host, {1} the target port.
HTTP_CONNECT_FMT = (
    'CONNECT {0}:{1} HTTP/1.0\r\n'
    'Connection: close\r\n'
    '\r\n'
)
 
class ProxyNotSet(Exception):
    """Raised when a proxied connection is attempted without a proxy host and port."""
    pass
 
class HTTPS:
    """
    Minimal HTTPS client built on pyOpenSSL.

    It fetches a server's certificate and HEAD response headers, optionally
    through an HTTP proxy, and tolerates certificates that fail validation.

    Typical use: set_target(), optionally set_proxy() / set_verify(), then
    init(), followed by get_status() / get_headers() / get_header() /
    get_cert().
    """

    # Status line of an HTTP/1.0 or HTTP/1.1 response. The reason phrase is
    # optional because servers may send e.g. "HTTP/1.1 204" with no text.
    __http_re = re.compile(r"^HTTP/1\.[01]\s([0-9]{3})(?:\s(.*))?$")
    # A proxy's reply to CONNECT must be 2xx for the tunnel to be usable.
    __proxy_ok_re = re.compile(r"^HTTP/1\.[01]\s2[0-9]{2}(?:\s|$)")

    def __verify_none(self, conn, cert, errnum, depth, ok):
        """
        Verify callback used when validation is disabled.

        Always returns True, so the handshake completes and even invalid
        certificates can be fetched. It still records whether every
        certificate presented passed OpenSSL's checks.
        """
        # Called once per certificate in the chain. A single failure must
        # stick; a later successful call must not overwrite it.
        if not ok:
            self.__cert_status = False
        return True

    def __verify_builtin(self, conn, cert, errnum, depth, ok):
        """
        Verify callback used when validation against a CA bundle is enabled.

        Returns True so the handshake proceeds. The result of OpenSSL's
        built-in validation is recorded in the certificate status.
        """
        # cert_status starts as True; any certificate in the path that does
        # not verify resets it to False.
        if not ok:
            self.__cert_status = False
            self.__http_status = 'SSL validation failed'
        return True

    def __init__(self):
        # Proxy related
        self.__proxy_host = None
        self.__proxy_port = None
        # Target related
        self.__server = None
        self.__port = None
        # Per-connection state: SSL context, connection, certificate,
        # HTTP results.
        self.__reset()

    def __reset(self):
        """Discard all per-connection state and create a fresh SSL context."""
        self.__ok = True
        # SSL related
        self.__ctx = SSL.Context(SSL.SSLv23_METHOD)
        self.set_verify(False, None)
        self.__ssl = None
        self.__cert = None
        self.__cert_status = True
        # HTTP related
        self.__http_status = None
        self.__http_headers = None
        # Tracked separately from __http_status, which verify callbacks and
        # connection errors also write to.
        self.__headers_fetched = False

    def set_proxy(self, host, port):
        """
        Set proxy hostname and port.
        """
        self.__proxy_host = host
        self.__proxy_port = port

    def set_verify(self, flag, ca_file):
        """
        Choose the server-certificate validation strategy.

        :param flag: if true, validate the chain against the CA bundle in
            ``ca_file``. If false (the default), no CA bundle is loaded.
        :param ca_file: path passed to ``load_verify_locations``; only used
            when ``flag`` is true.

        In both modes certificate errors never abort the handshake. The
        outcome is reported as ``get_cert()['valid']``. set_target() creates
        a fresh SSL context, so call this after set_target().
        """
        ctx = self.__ctx
        if flag:
            ctx.set_verify(SSL.VERIFY_PEER, self.__verify_builtin)
            ctx.load_verify_locations(ca_file)
        else:
            ctx.set_verify(SSL.VERIFY_PEER, self.__verify_none)

    def set_target(self, host, port=443):
        """
        Set a new target for the SSL connection.

        Closes the previous connection, if any, and resets all
        per-connection state, including the SSL context.
        """
        self.__server = host
        self.__port = port
        self.__close()
        self.__reset()

    def __close(self):
        """Shut down and close the current SSL connection, if any."""
        ss = self.__ssl
        if ss is None:
            return
        self.__ssl = None
        try:
            ss.shutdown()
        except SSL.Error:
            # The server usually closes first (the HEAD request asks for
            # "Connection: close"), so a failing close_notify is harmless.
            pass
        finally:
            ss.close()

    def __http_parse(self, data):
        """
        Parse a raw HTTP response (bytes) into status code and headers.

        Header names are stored uppercased. Lines that are neither a status
        line nor ``Name: value`` are skipped rather than raising.
        """
        headers = {}
        for raw in data.splitlines():
            if not raw:
                continue
            # Header bytes are not guaranteed to be UTF-8; one odd byte
            # must not abort the whole parse.
            line = raw.decode('utf-8', errors='replace')
            m = self.__http_re.match(line)
            if m:
                self.__http_status = m.group(1)
                continue
            name, sep, value = line.partition(':')
            if not sep:
                continue
            headers[name.strip().upper()] = value.strip()
        self.__http_headers = headers

    def __ensure_connected(self):
        """
        Connect if not yet connected and no earlier attempt failed.

        :returns: True if an SSL connection is available.
        """
        if self.__ssl is None and self.__ok:
            self.__connect()
        return self.__ssl is not None

    def __fetch_cert(self):
        """Store the peer certificate, connecting first if needed."""
        if self.__ensure_connected():
            self.__cert = self.__ssl.get_peer_certificate()

    def get_cert(self):
        """
        Return a summary of the server certificate, connecting first if needed.

        :returns: a dict with these keys, or None if no certificate could be
            obtained (e.g. the connection failed):

            - ``expired`` (bool)
            - ``valid`` (bool; result recorded by the verify callback)
            - ``subject`` (the certificate's subject name object)
            - ``name`` (the subject's commonName)
        """
        if self.__cert is None:
            self.__fetch_cert()
        cert = self.__cert
        if cert is None:
            return None
        subject = cert.get_subject()
        return {
            'expired': bool(cert.has_expired()),
            'valid': self.__cert_status,
            'subject': subject,
            'name': subject.commonName,
        }

    def init(self):
        """
        Connect (if needed) and fetch the response headers once.

        :returns: True on success, False if the connection or request
            failed. On failure, get_status() holds the error detail.
        """
        if not self.__headers_fetched:
            self.__fetch_headers()
        return self.__ok

    def __fetch_headers(self):
        """Send a HEAD request and parse the response into status and headers."""
        # Mark first, so a failed or status-less attempt is never retried on
        # an already closed connection.
        self.__headers_fetched = True
        if not self.__ensure_connected():
            return
        ss = self.__ssl
        host = self.__server
        if self.__port != 443:
            # The Host header carries the port when it is not the default.
            host = '{0}:{1}'.format(host, self.__port)
        request = HTTP_HEAD_FMT.format(host)
        try:
            ss.sendall(request.encode('utf-8'))
        except SSL.Error as detail:
            self.__fail(detail)
            return
        chunks = []
        while True:
            try:
                data = ss.recv(4096)
            except (SSL.ZeroReturnError, SSL.SysCallError):
                break  # server closed the connection (cleanly or abruptly)
            if not data:
                break
            chunks.append(data)
        self.__http_parse(b''.join(chunks))

    def get_headers(self):
        """Return the parsed headers (uppercased names), or None before init()."""
        return self.__http_headers

    def get_status(self):
        """
        Return the HTTP status.

        This is the status code string after a successful init(); otherwise
        it holds error detail, or None.
        """
        return self.__http_status

    def get_header(self, name):
        """
        Return the value of response header ``name``, or None.

        Lookup is case-insensitive. None is also returned if the header is
        absent or no headers have been fetched yet.
        """
        if not self.__http_headers:
            return None
        return self.__http_headers.get(name.upper())

    def __fail(self, detail):
        """Record a connection-level failure; init() will then return False."""
        self.__ok = False
        self.__http_status = detail

    def __connect(self):
        """Open the TLS connection, tunnelling through the proxy if one is set."""
        if self.__server is None:
            # Without a target, create_connection(None, ...) would resolve
            # to the local machine instead of failing.
            self.__fail('target not set')
            return
        if self.__proxy_host is not None and self.__proxy_port is not None:
            self.__connect_proxy()
        else:
            self.__connect_direct()

    def __connect_proxy(self):
        """Open a CONNECT tunnel through the HTTP proxy, then start TLS over it."""
        if self.__proxy_host is None or self.__proxy_port is None:
            raise ProxyNotSet
        try:
            sock = socket.create_connection((self.__proxy_host, self.__proxy_port))
        except (socket.timeout, socket.error) as detail:
            self.__fail(detail)
            return
        connect_str = HTTP_CONNECT_FMT.format(self.__server, self.__port)
        try:
            sock.sendall(connect_str.encode('utf-8'))
            reply = self.__read_proxy_reply(sock)
        except (socket.timeout, socket.error) as detail:
            sock.close()
            self.__fail(detail)
            return
        status_line = reply.split(b'\r\n', 1)[0].decode('utf-8', errors='replace')
        if not self.__proxy_ok_re.match(status_line):
            # e.g. 407 Proxy Authentication Required: there is no tunnel, so
            # a TLS handshake would only fail with a confusing error.
            sock.close()
            self.__fail(status_line or 'empty reply from proxy')
            return
        self.__start_tls(sock)

    @staticmethod
    def __read_proxy_reply(sock):
        """Read the proxy's CONNECT reply up to the end of its header block."""
        reply = bytearray()
        while b'\r\n\r\n' not in reply:
            chunk = sock.recv(4096)
            if not chunk:
                break
            reply += chunk
        return bytes(reply)

    def __connect_direct(self):
        """Connect straight to the target (IPv4 or IPv6), then start TLS."""
        try:
            sock = socket.create_connection((self.__server, self.__port))
        except (socket.timeout, socket.error) as detail:
            self.__fail(detail)
            return
        self.__start_tls(sock)

    def __start_tls(self, sock):
        """Run the TLS client handshake over the connected socket ``sock``."""
        sock.setblocking(True)
        ss = SSL.Connection(self.__ctx, sock)
        self.__set_sni(ss)
        ss.set_connect_state()
        try:
            ss.do_handshake()
        except SSL.Error as detail:
            # SSL.Error covers SysCallError as well as protocol/alert
            # failures. All must be reported, not raised out of init().
            sock.close()
            self.__fail(detail)
            return
        self.__ssl = ss

    def __set_sni(self, ss):
        """Send the target host name via SNI so the server can pick the right certificate."""
        import ipaddress
        host = self.__server
        if not isinstance(host, str):
            return
        try:
            ipaddress.ip_address(host)
        except ValueError:
            ss.set_tlsext_host_name(host.encode('idna'))
        # IP literals are not sent: SNI only carries host names.
 
import unittest
 
class TestCase(unittest.TestCase):
    """Live-network smoke tests for HTTPS; skipped when direct access is unavailable."""

    # CA bundle expected next to this module for the validation tests.
    CA_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cacert.pem')

    def setUp(self):
        self.h = HTTPS()
        self.need_proxy = False
        # Probe direct connectivity. The timeout keeps setUp from hanging,
        # and the probe socket is closed so it doesn't leak.
        try:
            probe = socket.create_connection(('www.google.com', 80), timeout=5)
        except socket.error:
            self.need_proxy = True
        else:
            probe.close()
        if self.need_proxy:
            # These tests configure no proxy, so they cannot succeed here.
            self.skipTest('no direct network access (proxy needed)')

    def _require_ca_file(self):
        """Skip the current test when the CA bundle is not available."""
        if not os.path.isfile(self.CA_FILE):
            self.skipTest('CA bundle not found: ' + self.CA_FILE)

    def test_no_proxy(self):
        self.h.set_target('mail.google.com', 443)
        self.assertTrue(self.h.init(), 'connection error')

    def test_verify_valid(self):
        self._require_ca_file()
        self.h.set_target('mail.google.com', 443)
        self.h.set_verify(True, self.CA_FILE)
        # A failed connection must fail the test, not silently pass it.
        self.assertTrue(self.h.init(), 'connection error')
        cert = self.h.get_cert()
        self.assertEqual(cert['valid'], True, 'mail.google.com should validate by default')

    def test_verify_invalid(self):
        self._require_ca_file()
        self.h.set_target('cacert.org', 443)
        self.h.set_verify(True, self.CA_FILE)
        self.assertTrue(self.h.init(), 'connection error')
        cert = self.h.get_cert()
        self.assertEqual(cert['valid'], False, 'cacert.org should NOT validate by default')
 
if __name__ == "__main__":
    # Run the unit tests defined in this module.
    unittest.main(module="__main__")