Search code examples
Tags: python-2.7, mocking, nose

How to mock AWS Athena functionality?


I have an `_Athena` class that wraps AWS Athena functionality:

class _Athena:
    """
    Creates connection and runs queries against
    AWS Athena.
    """

    def __init__(self, workgroup, database, params, query_string):
        """
        :param workgroup: Athena workgroup name.
        :param database: Athena database name.
        :param params: configuration handed to the module helpers that build
            the S3 output location (``_get_bucket``/``_get_path``) and the
            Athena client (``_get_client``).
        :param query_string: the query text; also quoted in error messages.
        """
        self._workgroup = workgroup
        self._database = database
        self.params = params
        self._output_location = 's3://{bucket}/{path}'.format(
            bucket=_get_bucket(self.params),
            path=_get_path(self.params))
        self._client = _get_client(self.params)
        self._query_string = query_string

    def _get_query_details(self, query_id, poll_interval=5, timeout=None):
        """
        Gets Athena query details, polling until the query finishes.

        :param query_id: id of athena query.
        :type query_id: str
        :param poll_interval: seconds to sleep between status checks while
            the query is in any non-terminal state.
        :type poll_interval: int or float
        :param timeout: approximate upper bound, in seconds, on how long to
            keep polling; ``None`` (the default) polls indefinitely.
        :type timeout: int or float or None
        :returns: the raw ``get_query_execution`` response once the query
            state is SUCCEEDED.
        :rtype: dict
        :raises Exception: if the query state is FAILED or CANCELLED, or if
            ``timeout`` elapses before the query succeeds.
        """
        # time.time() rather than time.monotonic(): this code targets 2.7.
        deadline = None if timeout is None else time.time() + timeout
        while True:
            response_get_query_details = self._client.get_query_execution(
                QueryExecutionId=query_id
            )
            query_execution = response_get_query_details['QueryExecution']
            status = query_execution['Status']['State']
            LOGGER.info('Athena query status %s', status)
            if status in ('FAILED', 'CANCELLED'):
                LOGGER.error(response_get_query_details)
                raise Exception('Athena query with the string "{}" failed or'
                                ' was cancelled'.format(self._query_string))
            if status == 'SUCCEEDED':
                # The output location is only logged, so a response lacking
                # ResultConfiguration (e.g. a minimal mocked one) must not
                # turn a successful query into a KeyError.
                location = query_execution.get(
                    'ResultConfiguration', {}).get('OutputLocation')
                LOGGER.info("Athena output location: %s", location)
                return response_get_query_details
            # Any other state (e.g. queued or running): wait, then poll again.
            if deadline is not None and time.time() >= deadline:
                raise Exception('Athena query with the string "{}" did not'
                                ' finish within {} seconds'.format(
                                    self._query_string, timeout))
            time.sleep(poll_interval)

I want to mock `_get_query_details` as part of my unit test. Here is the function that I wrote to test it:

# The question's (failing) test for _Athena._get_query_details.
# NOTE(review): this builds its own client. Presumably it is not the same
# object _Athena stores in self._client (built by _get_client); if so,
# patching it can never reach the code under test.
client = boto3.client('athena', 'us-east-1')


    # mock.patch decorators are applied bottom-up: the _Athena patch is passed
    # first (mock_class), the get_query_execution patch second (mock_client).
    @mock.patch.object(client, 'get_query_execution')
    @mock.patch('shared_utilities.verto_athena._Athena')
    def test_get_query_details(mock_class, mock_client):
        # A real SUCCEEDED response also carries
        # 'ResultConfiguration' -> 'OutputLocation', which
        # _get_query_details reads for logging.
        success = {'QueryExecution': {'Status': {'State': 'SUCCEEDED'}}}
        mock_client.return_value = success
        # Why this fails: _Athena itself is replaced by a MagicMock, so this
        # calls a MagicMock attribute instead of the real method and gets
        # back another MagicMock -- hence the AssertionError in the question.
        result = mock_class.return_value._get_query_details('id')
        # `tools` is presumably nose.tools.
        tools.assert_equal(success, result)

However, it fails with the following error:

AssertionError: {'QueryExecution': {'Status': {'State': 'SUCCEEDED'}}} != <MagicMock name='_Athena()._get_query_details()' id='4786490336'>

In principle this test should pass, since the state is SUCCEEDED. Any ideas what I might be doing wrong here?


Solution

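  • I was able to achieve that as follows. Create a real `_Athena` instance and patch `get_query_execution` on that instance's own client, instead of mocking the `_Athena` class itself. Mocking the class is why the original test got back a `MagicMock`.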

       params = {
           "region": "us-east-1",
           "bucket": "temp-prod",
           "path": "path/to/obj"
       }

       # A real _Athena instance: only its client call is mocked below, so the
       # real _get_query_details logic is what runs under test.
       # NOTE(review): constructing it runs _get_client(params) at import
       # time; confirm that has no side effects (e.g. network calls) in the
       # test environment.
       _athena = module._Athena('workgroup',
                                'database',
                                params,
                                'select * from query'
                                )


       @mock.patch.object(_athena._client, 'get_query_execution')
       def test_get_query_details_on_success(mock_get_query_execution):
           """A SUCCEEDED status returns the raw get_query_execution response."""
           success = {
               'QueryExecution': {
                   'Status': {
                       'State': 'SUCCEEDED'
                   },
                   'ResultConfiguration': {
                       'OutputLocation': 's3://output'
                   }
               }
           }
           mock_get_query_execution.return_value = success
           result = _athena._get_query_details('id')
           tools.assert_equal(success, result)
           # Also pin that the query id is forwarded to the client unchanged.
           mock_get_query_execution.assert_called_once_with(
               QueryExecutionId='id')


       @mock.patch.object(_athena._client, 'get_query_execution')
       def test_get_query_details_on_failure(mock_get_query_execution):
           """A FAILED status is surfaced as an exception."""
           mock_get_query_execution.return_value = {
               'QueryExecution': {
                   'Status': {
                       'State': 'FAILED'
                   }
               }
           }
           tools.assert_raises(Exception, _athena._get_query_details, 'id')