3

I am writing a lambda function that takes a list of CW Log Groups and runs an "export to s3" task on each of them.

I am writing automated tests using pytest and I'm using moto.mock_logs (among others), but the create_export_task() client method is not yet implemented in moto (it raises NotImplementedError).

To continue using moto.mock_logs for all other methods, I am trying to patch just that single create_export_task() method using mock.patch, but it's unable to find the correct object to patch (ImportError).

I successfully used mock.Mock() to provide me just the functionality that I need, but I'm wondering if I can do the same with mock.patch()?

Working Code: lambda.py

# lambda.py
"""Export CloudWatch Logs to S3 every 24 hours."""
import logging
import os
from time import time
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class CloudWatchLogsS3Archive:
    """Export CloudWatch Log Groups to an S3 bucket.

    The epoch-millisecond timestamp of the last successful export for
    each log group is kept in SSM Parameter Store so that later runs
    only export log events newer than the previous export.
    """

    # Adaptive retry mode lets botocore back off and retry automatically
    # when CreateExportTask throttles (only one task may run at a time).
    botocore_config = Config(retries={"max_attempts": 10, "mode": "adaptive"})

    def __init__(self, s3_bucket, account_id) -> None:
        """Record destination details and build the AWS clients.

        :param s3_bucket: destination S3 bucket for export tasks
        :param account_id: 12-digit AWS account id, used as the S3 key prefix
        """
        self.s3_bucket = s3_bucket
        self.account_id = account_id
        self.extra_args = {}
        self.log_groups = []
        self.log_groups_to_export = []
        self.logs = boto3.client("logs", config=self.botocore_config)
        self.ssm = boto3.client("ssm", config=self.botocore_config)

    def check_valid_inputs(self):
        """Check that required inputs are present and valid.

        :raises ValueError: if the account id is not 12 characters long
        """
        if len(self.account_id) != 12:
            # Use the module-level logger (was `logging.error`) so the
            # record carries this module's name like every other log
            # call in the class.
            logger.error("Account Id must be valid 12-digit AWS account id")
            raise ValueError("Account Id must be valid 12-digit AWS account id")

    def collect_log_groups(self):
        """Yield the name of every CloudWatch Log Group in the account."""
        paginator = self.logs.get_paginator("describe_log_groups")
        for page in paginator.paginate():
            for lg in page["logGroups"]:
                yield lg["logGroupName"]

    def get_last_export_time(self, Name) -> str:
        """Get the time of the last export from SSM Parameter Store.

        :param Name: SSM parameter name (the log group name)
        :returns: epoch milliseconds as a string; "0" (export from the
            beginning) when the parameter has never been written
        :raises ClientError: for any error other than ParameterNotFound
        """
        try:
            return self.ssm.get_parameter(Name=Name)["Parameter"]["Value"]
        except (self.ssm.exceptions.ParameterNotFound, ClientError) as exc:
            logger.warning(*exc.args)
            if exc.response["Error"]["Code"] == "ParameterNotFound":  # type: ignore
                return "0"
            raise

    def set_export_time(self):
        """Return the current time in epoch milliseconds."""
        return round(time() * 1000)

    def put_export_time(self, put_time, Name):
        """Persist *put_time* under parameter *Name* in SSM Parameter Store."""
        self.ssm.put_parameter(Name=Name, Value=str(put_time), Overwrite=True)

    def create_export_tasks(
        self, log_group_name, fromTime, toTime, s3_bucket, account_id
    ):
        """Create a new CloudWatchLogs export task for one log group.

        Errors are logged rather than raised so one failing group does
        not abort the export of the remaining groups.
        """
        try:
            response = self.logs.create_export_task(
                logGroupName=log_group_name,
                fromTime=int(fromTime),
                to=toTime,
                destination=s3_bucket,
                destinationPrefix="{}/{}".format(account_id, log_group_name.strip("/")),
            )
            logger.info("✔   Task created: %s" % response["taskId"])
        except self.logs.exceptions.LimitExceededException:
            """The Boto3 standard retry mode will catch throttling errors and
            exceptions, and will back off and retry them for you."""
            logger.warning(
                "⚠   Too many concurrently running export tasks "
                "(LimitExceededException); backing off and retrying..."
            )
        except Exception as e:
            logger.exception(
                "✖   Error exporting %s: %s",
                log_group_name,
                getattr(e, "message", repr(e)),
            )


def lambda_handler(event, context):
    """Lambda entry point: export every CloudWatch Log Group to S3.

    Required environment variables:
        S3_BUCKET:  destination bucket for the export tasks.
        ACCOUNT_ID: 12-digit AWS account id used as the S3 key prefix.
    """
    s3_bucket = os.environ["S3_BUCKET"]
    account_id = os.environ["ACCOUNT_ID"]
    archiver = CloudWatchLogsS3Archive(s3_bucket, account_id)
    archiver.check_valid_inputs()
    for log_group_name in archiver.collect_log_groups():
        fromTime = archiver.get_last_export_time(log_group_name)
        toTime = archiver.set_export_time()
        archiver.create_export_tasks(
            log_group_name, fromTime, toTime, s3_bucket, account_id
        )
        # BUG FIX: arguments were passed in swapped order. The signature is
        # put_export_time(put_time, Name), so the original call stored the
        # log-group name as the timestamp value under a parameter named by
        # the timestamp, breaking last-export tracking entirely.
        archiver.put_export_time(toTime, log_group_name)

Test Code (pytest): test_lambda.py

# test_lambda.py
"""Test Lambda Function"""
import os
from unittest import mock

import boto3
import moto
import pytest


@pytest.fixture(autouse=True)
def f_aws_credentials():
    """Mocked AWS credentials for moto.

    This is a side-effect fixture (returns None): it seeds the process
    environment so any boto3 client created downstream never picks up
    real credentials.

    BUG FIX: ``autouse=True`` must be an argument to the decorator.
    Written as a plain function parameter (the original code), pytest
    interprets it as a *fixture request* named ``autouse`` and the
    fixture is never applied automatically.
    """
    os.environ["AWS_ACCESS_KEY_ID"] = "testing"
    os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
    os.environ["AWS_SECURITY_TOKEN"] = "testing"
    os.environ["AWS_SESSION_TOKEN"] = "testing"
    os.environ["AWS_DEFAULT_REGION"] = "us-east-1"


@moto.mock_logs
@moto.mock_ssm
def test_create_export_tasks():
    """create_export_tasks() must call logs.create_export_task with the
    exact keyword arguments derived from its inputs.

    moto does not implement create_export_task (it raises
    NotImplementedError), so that single client method is replaced with a
    mock.Mock while moto keeps backing every other logs/ssm call.
    """
    from cloudwatch_logs_s3_archive import CloudWatchLogsS3Archive

    c = CloudWatchLogsS3Archive("bucket", "123412341234")
    logs = boto3.client("logs")
    logs.create_log_group(logGroupName="/log-exporter-last-export/first")
    logs.create_log_group(logGroupName="/log-exporter-last-export/second")
    logs.create_log_group(logGroupName="/log-exporter-last-export/third")
    log_group_name = "/log-exporter-last-export/first"
    s3_bucket = "s3_bucket"
    account_id = 123412341234
    toTime = c.set_export_time()
    fromTime = c.get_last_export_time(Name="/log-exporter-last-export/first")
    c.logs.create_export_task = mock.Mock(
        return_value={"taskId": "I am mocked via mock.Mock"}
    )
    c.create_export_tasks(
        "/log-exporter-last-export/first", fromTime, toTime, "s3_bucket", 123412341234
    )
    # BUG FIX: the original `...assert_called` (no parentheses) was a bare
    # attribute access — a no-op that can never fail. Call the assertion
    # helper instead; assert_called_once_with also subsumes the separate
    # `.called` check.
    c.logs.create_export_task.assert_called_once_with(
        logGroupName=log_group_name,
        fromTime=int(fromTime),
        to=toTime,
        destination=s3_bucket,
        destinationPrefix="{}/{}".format(account_id, log_group_name.strip("/")),
    )

Felipe Alvarez
  • 3,720
  • 2
  • 33
  • 42

1 Answer

2

I'm wondering if I can do the same with mock.patch()?

Sure, by using mock.patch.object():

with mock.patch.object(
    c.logs,
    'create_export_task',
    return_value={"taskId": "I am mocked via mock.Mock"}
):
    c.create_export_tasks(
        "/log-exporter-last-export/first", fromTime, toTime, "s3_bucket", 123412341234
    )
    assert c.logs.create_export_task.called

If you don't like the context manager usage, I recommend installing the pytest-mock plugin alongside pytest which provides a handy mocker fixture. Your test would look like

def test_create_export_tasks(mocker):
    ...
    mocker.patch.object(
        c.logs,
        'create_export_task',
        return_value={"taskId": "I am mocked via mock.Mock"}
    )
    c.create_export_tasks(
        "/log-exporter-last-export/first", fromTime, toTime, "s3_bucket", 123412341234
    )
    assert c.logs.create_export_task.called

mocker is basically a proxy to unittest.mock module and offers the same functions and methods, except that it clears all patches at the end of the test automatically, so there's one less thing to care about.

hoefling
  • 59,418
  • 12
  • 147
  • 194