amazon-web-services, jupyter-notebook, amazon-athena, amazon-sagemaker, pyathena

Query a table/database in Athena from a Notebook instance


I have set up different Athena workgroups for different teams so that I can separate their queries and their query results. The users would like to query the tables available to them from their notebook instances (JupyterLab). I am having difficulty finding code that queries a table from within the user's specific workgroup; everything I have found only runs queries in the primary workgroup.

The code I am currently using is shown below.

from pyathena import connect
import pandas as pd
conn = connect(s3_staging_dir='<ATHENA QUERY RESULTS LOCATION>',
               region_name='<YOUR REGION, for example, us-west-2>')


df = pd.read_sql("SELECT * FROM <DATABASE-NAME>.<YOUR TABLE NAME> limit 8;", conn)
df

This code does not work because the users only have permission to run queries from their specific workgroups, so they get errors when it is executed. It also does not meet the requirement of keeping each user's queries in a user-specific workgroup.

Any suggestions on how I can alter the code so that I can run the queries within a specific workgroup from the notebook instance?


Solution

  • The documentation of pyathena is not super extensive, but after looking into the source code we can see that connect simply creates an instance of the Connection class.

    def connect(*args, **kwargs):
        from pyathena.connection import Connection
        return Connection(*args, **kwargs)
    

    Now, after looking at the signature of Connection.__init__ on GitHub, we can see the parameter work_group=None, which is named the same way as one of the parameters of start_query_execution from the official AWS Python SDK boto3. Here is what their documentation says about it:

    WorkGroup (string) -- The name of the workgroup in which the query is being started.
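
    For reference, here is a minimal sketch of the equivalent call made directly with boto3, the same API call that pyathena issues under the hood; all angle-bracket values are placeholders:

    import boto3

    # Build an Athena client and submit the query to a specific workgroup.
    # All angle-bracket values below are placeholders.
    athena = boto3.client('athena', region_name='<YOUR REGION, for example, us-west-2>')

    response = athena.start_query_execution(
        QueryString='SELECT * FROM <DATABASE-NAME>.<YOUR TABLE NAME> limit 8;',
        QueryExecutionContext={'Database': '<DATABASE-NAME>'},
        ResultConfiguration={'OutputLocation': '<ATHENA QUERY RESULTS LOCATION>'},
        WorkGroup='<USER SPECIFIC WORKGROUP>'
    )
    print(response['QueryExecutionId'])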

    After following the usages and imports in Connection, we end up at the BaseCursor class, which under the hood calls start_query_execution while unpacking a dictionary of parameters assembled by the BaseCursor._build_start_query_execution_request method. That is exactly where we see the familiar syntax for submitting queries to AWS Athena, in particular the following part:

    if self._work_group or work_group:
        request.update({
            'WorkGroup': work_group if work_group else self._work_group
        })
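
    As a side note, the snippet above also suggests that the workgroup can be overridden for a single query at the cursor level. Here is a minimal sketch, assuming your pyathena version accepts a work_group keyword on execute:

    from pyathena import connect

    conn = connect(
        s3_staging_dir='<ATHENA QUERY RESULTS LOCATION>',
        region_name='<YOUR REGION, for example, us-west-2>'
    )

    cursor = conn.cursor()
    # Override the connection-level workgroup for this single query
    # (assumes execute() accepts work_group in your pyathena version).
    cursor.execute(
        "SELECT * FROM <DATABASE-NAME>.<YOUR TABLE NAME> limit 8;",
        work_group='<USER SPECIFIC WORKGROUP>'
    )
    print(cursor.fetchall())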
    

    So this should do the trick for your case:

    import pandas as pd
    from pyathena import connect
    
    
    conn = connect(
        s3_staging_dir='<ATHENA QUERY RESULTS LOCATION>',
        region_name='<YOUR REGION, for example, us-west-2>',
        work_group='<USER SPECIFIC WORKGROUP>'
    )
    
    df = pd.read_sql("SELECT * FROM <DATABASE-NAME>.<YOUR TABLE NAME> limit 8;", conn)
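
    Note that with work_group set on the connection, every query issued through conn (including the pd.read_sql call above) runs in that workgroup. Keep in mind that if the workgroup is configured to override client-side settings, Athena writes the results to the workgroup's own result location rather than the s3_staging_dir you pass in; otherwise the staging directory you supply is used.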