Tags: python, postgresql, apache-beam, dataflow, ssh-tunnel

How to set up an SSH tunnel in Google Cloud Dataflow to an external database server?


I am having trouble getting my Apache Beam pipeline to work on Cloud Dataflow with DataflowRunner.

The first step of the pipeline is to connect to an external PostgreSQL server and extract some data. The server is hosted on a VM that is only reachable from outside through SSH on port 22. I can't change these firewall rules, so I can only connect to the DB server via SSH tunneling, aka port forwarding.

In my code I use the Python library sshtunnel. It works perfectly when the pipeline is launched from my development computer with DirectRunner:

import apache_beam as beam
from apache_beam.io import WriteToText
from sshtunnel import open_tunnel

# ReadFromSQL and PostgresWrapper come from the pysql_beam library.

# The tunnel is opened here, in the process that launches the pipeline.
with open_tunnel(
    (user_options.ssh_tunnel_host, user_options.ssh_tunnel_port),
    ssh_username=user_options.ssh_tunnel_user,
    ssh_password=user_options.ssh_tunnel_password,
    remote_bind_address=(user_options.dbhost, user_options.dbport)
) as tunnel:
    with beam.Pipeline(options=pipeline_options) as p:
        (
            p
            | "Read data" >> ReadFromSQL(
                host=tunnel.local_bind_host,
                port=tunnel.local_bind_port,
                username=user_options.dbusername,
                password=user_options.dbpassword,
                database=user_options.dbname,
                wrapper=PostgresWrapper,
                query=select_query
            )
            | "Format CSV" >> DictToCSV(headers)
            | "Write CSV" >> WriteToText(user_options.export_location)
        )

The same code, launched with DataflowRunner inside a non-default VPC where all ingress is denied, egress is unrestricted, and Cloud NAT is configured, fails with this message:

psycopg2.OperationalError: could not connect to server: Connection refused Is the server running on host "0.0.0.0" and accepting TCP/IP connections on port 41697? [while running 'Read data/Read']

So, obviously something is wrong with my tunnel, but I cannot spot what exactly. I was beginning to wonder whether a direct SSH tunnel was even possible through Cloud NAT, until I found this blog post: https://cloud.google.com/blog/products/gcp/guide-to-common-cloud-dataflow-use-case-patterns-part-1 which states:

A core strength of Cloud Dataflow is that you can call external services for data enrichment. For example, you can call a micro service to get additional data for an element. Within a DoFn, call-out to the service (usually done via HTTP). You have full control to make any type of connection that you choose, so long as the firewall rules you set up within your project/network allow it.

So it should be possible to set up this tunnel! I don't want to give up, but I don't know what to try next. Any ideas?

Thanks for reading


Solution

  • Problem solved! I can't believe I spent two full days on this... I was looking in completely the wrong direction.

    The issue was not with some Dataflow or GCP networking configuration, and as far as I can tell...

    You have full control to make any type of connection that you choose, so long as the firewall rules you set up within your project/network allow it

    is true.

    The problem was, of course, in my code; it just only showed up in a distributed environment. I had made the mistake of opening the tunnel from the main pipeline process instead of from the workers. So the SSH tunnel was up, but only between the main process and the target server, not between the workers and the target server!
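    The symptom can be reproduced without Dataflow. The sketch below is only an illustration and is not part of the original pipeline: a plain listening socket stands in for the tunnel's local endpoint, and `can_connect` and `demo` are hypothetical helper names. It shows that a port number is only useful on the machine where something is actually listening on it, which is why the workers got "Connection refused" for a port that had been bound on the launching machine.

```python
import socket


def can_connect(host, port, timeout=1.0):
    """Return True if something accepts TCP connections on (host, port)."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # includes ConnectionRefusedError and timeouts
        return False


def demo():
    # Stand-in for the tunnel's local endpoint on the launching machine.
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.bind(("127.0.0.1", 0))  # port 0: let the OS pick a free port
    listener.listen(1)
    port = listener.getsockname()[1]

    reachable_with_listener = can_connect("127.0.0.1", port)

    # Closing the listener mimics a worker, where this port was never bound.
    listener.close()
    reachable_without_listener = can_connect("127.0.0.1", port)
    return reachable_with_listener, reachable_without_listener


if __name__ == "__main__":
    print(demo())
```

    On a Dataflow worker the situation corresponds to the second check: the port number arrives with the serialized pipeline, but the listener stayed behind on the machine that launched it.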

    To fix this, I had to change the DoFn that runs the query so that it wraps the query execution in the tunnel:

    import logging

    from sshtunnel import open_tunnel

    # `sql` is the pysql_beam module that provides SQLSourceDoFn and SQLSource.


    class TunnelledSQLSourceDoFn(sql.SQLSourceDoFn):
        """Wraps SQLSourceDoFn in an SSH tunnel."""

        def __init__(self, *args, **kwargs):
            self.dbport = kwargs["port"]
            self.dbhost = kwargs["host"]
            self.args = args
            self.kwargs = kwargs
            super().__init__(*args, **kwargs)

        def process(self, query, *args, **kwargs):
            # Remote side of the SSH tunnel
            remote_address = (self.dbhost, self.dbport)
            ssh_tunnel = (self.kwargs["ssh_host"], self.kwargs["ssh_port"])
            # The tunnel is opened here, so it runs on the worker itself.
            with open_tunnel(
                ssh_tunnel,
                ssh_username=self.kwargs["ssh_user"],
                ssh_password=self.kwargs["ssh_password"],
                remote_bind_address=remote_address,
                set_keepalive=10.0
            ) as tunnel:
                # Point the SQL source at the local end of the tunnel.
                # Only the port is replaced; "host" is passed through unchanged.
                forwarded_port = tunnel.local_bind_port
                self.kwargs["port"] = forwarded_port
                source = sql.SQLSource(*self.args, **self.kwargs)
                sql.SQLSouceInput._build_value(source, source.runtime_params)
                logging.info("Processing - %s", query)
                for records, schema in source.client.read(query):
                    for row in records:
                        yield source.client.row_as_dict(row, schema)
    

    As you can see, I had to override some parts of the pysql_beam library.

    With this change, each worker opens its own tunnel for each request. It's probably possible to optimize this behavior, but it's enough for my needs.
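    One possible optimization, sketched here under assumptions and not tested on Dataflow, is to open the tunnel once per DoFn instance instead of once per query. Beam's `DoFn` has `setup()` and `teardown()` hooks that it calls around the lifetime of an instance on a worker, so a tunnel opened in `setup()` can be reused by every `process()` call. The class name `TunnelPerInstance` and the two injected callables (`tunnel_factory`, `connect`) are hypothetical; the sketch is plain Python so that it can be read and run without Beam or an SSH server.

```python
class TunnelPerInstance:
    """Open one tunnel in setup(), reuse it in process(), close it in teardown().

    Intended to be combined with a Beam DoFn; kept library-free for illustration.
    """

    def __init__(self, tunnel_factory, connect):
        # Hypothetical injection points (assumptions, not a library API):
        #   tunnel_factory() -> started tunnel with .local_bind_port and .stop()
        #   connect(port)    -> function run_query(query) returning an iterable of rows
        # Only configuration is stored here, so nothing is opened on the
        # machine that builds and submits the pipeline.
        self._tunnel_factory = tunnel_factory
        self._connect = connect
        self._tunnel = None
        self._run_query = None

    def setup(self):
        # Runs on the worker: the tunnel is opened where the queries execute.
        self._tunnel = self._tunnel_factory()
        self._run_query = self._connect(self._tunnel.local_bind_port)

    def process(self, query):
        # Every query reuses the tunnel opened in setup().
        yield from self._run_query(query)

    def teardown(self):
        # Close the tunnel if one was opened; safe to call more than once.
        if self._tunnel is not None:
            self._tunnel.stop()
            self._tunnel = None
            self._run_query = None
```

    To use this in a pipeline, the class could be mixed into a DoFn, for example `class TunnelledQuery(TunnelPerInstance, beam.DoFn)`. With sshtunnel, `tunnel_factory` would create the forwarder with `open_tunnel(...)` and call `.start()` on it before returning it, since the forwarder is only started automatically when it is used as a context manager. Beam treats `teardown()` as best effort, so keeping `set_keepalive` and tolerating dropped tunnels is still a good idea.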