Search code examples
pythonpython-3.xpython-requestsvhostsurllib3

Monkey patching _ssl_wrap_socket in Python requests library isn't executing


We are trying to add HTTPS support to a web server virtual host scanning tool. Said tool uses the python3 requests library, which uses urllib3 under the hood.

We need a way to provide our own SNI hostname so are attempting to monkey patch the _ssl_wrap_socket function of urllib3 to control server_hostname but aren't having much success.

Here is the full code:

from urllib3.util import ssl_
_target_host = None
_orig_wrap_socket = ssl_.ssl_wrap_socket

def _ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
                     ca_certs=None, server_hostname=None,
                     ssl_version=None, ciphers=None, ssl_context=None,
                     ca_cert_dir=None):
    _orig_wrap_socket(sock, keyfile=keyfile, certfile=certfile,
                      cert_reqs=cert_reqs, ca_certs=ca_certs,
                      server_hostname=_target_host, ssl_version=ssl_version,
                      ciphers=ciphers, ssl_context=ssl_context,
                      ca_cert_dir=ca_cert_dir)

ssl_.ssl_wrap_socket = _ssl_wrap_socket

We then call requests.get() further down in the code. The full context can be found on Github (here).

Unfortunately this isn't working as our code never appears to be reached, and we're not sure why. Is there something obvious that we're missing or a better way to approach this issue?

Further Explanation

The following is the full class:

import os
import random

import requests
import hashlib
import pandas as pd
import time
from lib.core.discovered_host import *
import urllib3

DEFAULT_USER_AGENT = 'Mozilla/5.0 (Windows NT 6.1; Win64; x64) '\
                     'AppleWebKit/537.36 (KHTML, like Gecko) '\
                     'Chrome/61.0.3163.100 Safari/537.36'

urllib3.disable_warnings()

from urllib3.util import ssl_



class virtual_host_scanner(object):
    """Virtual host scanning class

    Virtual host scanner has the following properties:

    Attributes:
        wordlist: location to a wordlist file to use with scans
        target: the target for scanning
        port: the port to scan. Defaults to 80
        ignore_http_codes: commad seperated list of http codes to ignore
        ignore_content_length: integer value of content length to ignore
        output: folder to write output file to
    """
    def __init__(self, target, wordlist, **kwargs):
        self.target = target
        self.wordlist = wordlist
        self.base_host = kwargs.get('base_host')
        self.rate_limit = int(kwargs.get('rate_limit', 0))
        self.port = int(kwargs.get('port', 80))
        self.real_port = int(kwargs.get('real_port', 80))
        self.ssl = kwargs.get('ssl', False)
        self.fuzzy_logic = kwargs.get('fuzzy_logic', False)
        self.unique_depth = int(kwargs.get('unique_depth', 1))
        self.ignore_http_codes = kwargs.get('ignore_http_codes', '404')
        self.first_hit = kwargs.get('first_hit')

        self.ignore_content_length = int(
            kwargs.get('ignore_content_length', 0)
        )

        self.add_waf_bypass_headers = kwargs.get(
            'add_waf_bypass_headers',
            False
        )

        # this can be made redundant in future with better exceptions
        self.completed_scan = False

        # this is maintained until likely-matches is refactored to use
        # new class
        self.results = []

        # store associated data for discovered hosts
        # in array for oN, oJ, etc'
        self.hosts = []

        # available user-agents
        self.user_agents = list(kwargs.get('user_agents')) \
            or [DEFAULT_USER_AGENT]

    @property
    def ignore_http_codes(self):
        return self._ignore_http_codes

    @ignore_http_codes.setter
    def ignore_http_codes(self, codes):
        self._ignore_http_codes = [
            int(code) for code in codes.replace(' ', '').split(',')
        ]

    _target_host = None
    _orig_wrap_socket = ssl_.ssl_wrap_socket

    def _ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
                         ca_certs=None, server_hostname=None,
                         ssl_version=None, ciphers=None, ssl_context=None,
                         ca_cert_dir=None):
        print('SHOULD BE PRINTED')
        _orig_wrap_socket(sock, keyfile=keyfile, certfile=certfile,
                          cert_reqs=cert_reqs, ca_certs=ca_certs,
                          server_hostname=_target_host, ssl_version=ssl_version,
                          ciphers=ciphers, ssl_context=ssl_context,
                          ca_cert_dir=ca_cert_dir)

    def scan(self):
        print('fdsa')
        ssl_.ssl_wrap_socket = self._ssl_wrap_socket

        if not self.base_host:
            self.base_host = self.target

        if not self.real_port:
            self.real_port = self.port

        for virtual_host in self.wordlist:
            hostname = virtual_host.replace('%s', self.base_host)

            if self.real_port == 80:
                host_header = hostname
            else:
                host_header = '{}:{}'.format(hostname, self.real_port)

            headers = {
                'User-Agent': random.choice(self.user_agents),
                'Host': host_header,
                'Accept': '*/*'
            }

            if self.add_waf_bypass_headers:
                headers.update({
                    'X-Originating-IP': '127.0.0.1',
                    'X-Forwarded-For': '127.0.0.1',
                    'X-Remote-IP': '127.0.0.1',
                    'X-Remote-Addr': '127.0.0.1'
                })

            dest_url = '{}://{}:{}/'.format(
                'https' if self.ssl else 'http',
                self.target,
                self.port
            )

            _target_host = hostname

            try:
                res = requests.get(dest_url, headers=headers, verify=False)
            except requests.exceptions.RequestException:
                continue

            if res.status_code in self.ignore_http_codes:
                continue

            response_length = int(res.headers.get('content-length', 0))
            if self.ignore_content_length and \
               self.ignore_content_length == response_length:
                continue

            # hash the page results to aid in identifing unique content
            page_hash = hashlib.sha256(res.text.encode('utf-8')).hexdigest()

            self.hosts.append(self.create_host(res, hostname, page_hash))

            # add url and hash into array for likely matches
            self.results.append(hostname + ',' + page_hash)

            if len(self.hosts) >= 1 and self.first_hit:
                break

            # rate limit the connection, if the int is 0 it is ignored
            time.sleep(self.rate_limit)

        self.completed_scan = True

    def likely_matches(self):
        if self.completed_scan is False:
            print("[!] Likely matches cannot be printed "
                  "as a scan has not yet been run.")
            return

        # segment results from previous scan into usable results
        segmented_data = {}
        for item in self.results:
            result = item.split(",")
            segmented_data[result[0]] = result[1]

        dataframe = pd.DataFrame([
            [key, value] for key, value in segmented_data.items()],
            columns=["key_col", "val_col"]
        )

        segmented_data = dataframe.groupby("val_col").filter(
            lambda x: len(x) <= self.unique_depth
        )

        return segmented_data["key_col"].values.tolist()

    def create_host(self, response, hostname, page_hash):
        """
        Creates a host using the responce and the hash.
        Prints current result in real time.
        """
        output = '[#] Found: {} (code: {}, length: {}, hash: {})\n'.format(
            hostname,
            response.status_code,
            response.headers.get('content-length'),
            page_hash
        )

        host = discovered_host()
        host.hostname = hostname
        host.response_code = response.status_code
        host.hash = page_hash
        host.contnet = response.content

        for key, val in response.headers.items():
            output += '  {}: {}\n'.format(key, val)
            host.keys.append('{}: {}'.format(key, val))

        print(output)

        return host

In this case the following line is never being hit:

print('SHOULD BE PRINTED')

This also results in the following log entry on the web server:

[Wed Oct 25 16:37:23.654321 2017] [ssl:error] [pid 1355] AH02032: Hostname provided via SNI and hostname test.test provided via HTTP are different

Which indicates the code was never run also.


Solution

  • Edit-1: No reload needed

    Thanks to @MartijnPieters helping me enhance the answer. There is no reload needed if we directly patch urllib3.connection. But requests package has some changes in the latest version, which made the original answer not work on some version of requests.

    Here is a updated version of the code, which handles all these things

    import requests
    
    try:
        assert requests.__version__ != "2.18.0"
        import requests.packages.urllib3.util.ssl_ as ssl_
        import requests.packages.urllib3.connection as connection
    except (ImportError,AssertionError,AttributeError):
        import urllib3.util.ssl_ as ssl_
        import urllib3.connection as connection
    
    print("Using " + requests.__version__)
    
    def _ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
                         ca_certs=None, server_hostname=None,
                         ssl_version=None, ciphers=None, ssl_context=None,
                         ca_cert_dir=None):
        print('SHOULD BE PRINTED')
        return ssl_.ssl_wrap_socket(sock, keyfile=keyfile, certfile=certfile,
                          cert_reqs=cert_reqs, ca_certs=ca_certs,
                          server_hostname=server_hostname, ssl_version=ssl_version,
                          ciphers=ciphers, ssl_context=ssl_context,
                          ca_cert_dir=ca_cert_dir)
    
    connection.ssl_wrap_socket = _ssl_wrap_socket
    
    res = requests.get("https://www.google.com", verify=True)
    

    The code is also available on

    https://github.com/tarunlalwani/monkey-patch-ssl_wrap_socket

    Original Answer

    So two issues in your code.

    requests doesn't actually directly import urllib3. It does it through its own context using requests.packages

    So the socket you want to overwrite is

    requests.packages.urllib3.util.ssl_.ssl_wrap_socket
    

    Next if you look at connection.py from urllib3/connection.py

    from .util.ssl_ import (
        resolve_cert_reqs,
        resolve_ssl_version,
        ssl_wrap_socket,
        assert_fingerprint,
    )
    

    This is a local import and it can't be overridden on first attempt as the code is loaded when we use import requests. You can easily confirm that by putting a breakpoint and checking the stack trace back to parent file.

    So for the monkey patch to work we need to reload the module once the patching is done, so it takes our patched function

    Below is minimal code showing that interception works this way

    try:
        reload  # Python 2.7
    except NameError:
        try:
            from importlib import reload  # Python 3.4+
        except ImportError:
            from imp import reload  # Python 3.0 - 3.3
    
    def _ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
                         ca_certs=None, server_hostname=None,
                         ssl_version=None, ciphers=None, ssl_context=None,
                         ca_cert_dir=None):
        print('SHOULD BE PRINTED')
        _orig_wrap_socket(sock, keyfile=keyfile, certfile=certfile,
                          cert_reqs=cert_reqs, ca_certs=ca_certs,
                          server_hostname=_target_host, ssl_version=ssl_version,
                          ciphers=ciphers, ssl_context=ssl_context,
                          ca_cert_dir=ca_cert_dir)
    
    import requests
    _orig_wrap_socket = requests.packages.urllib3.util.ssl_.ssl_wrap_socket
    
    requests.packages.urllib3.util.ssl_.ssl_wrap_socket = _ssl_wrap_socket
    reload(requests.packages.urllib3.connection)
    
    res = requests.get("https://www.google.com", verify=True)
    

    Debug breakpoint