Jose López
2020-04-03 | 5 min read

Autoscale Celery workers on ECS Fargate based on RabbitMQ metrics.

Asynchronous task queues let parts of a program run in a separate process or on a separate machine. Celery is a task queue for Python. It communicates via messages, usually using a broker (e.g. RabbitMQ) to mediate between clients and workers.

If you are reading this, I'm assuming that you're already using Celery and RabbitMQ. In this article, we are going to leverage ECS Fargate's scaling capabilities to quickly create and destroy workers based on RabbitMQ queue depth metrics.

The Terraform code used in this article is available in the OBytes public GitHub repositories.

Lambda function to push RabbitMQ metrics to CloudWatch

The first thing we need to do is gather RabbitMQ metrics and push them to CloudWatch. For this, we are going to use a Lambda function that uses the boto3 and requests libraries to fetch the list of queues and their message counts from the RabbitMQ management API, and then publishes this data to AWS CloudWatch.

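#!/usr/bin/env python3
from base64 import b64decode
import os
import time
import urllib.parse

import boto3
# botocore.vendored.requests is deprecated and no longer provides
# requests.get in current botocore releases; package the real `requests`
# library with the Lambda deployment instead.
import requests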


def get_queue_depths_and_publish_to_cloudwatch(host,
                                               port,
                                               username,
                                               password,
                                               vhost,
                                               namespace):
    """
    Calls the RabbitMQ API to get a list of queues and populate cloudwatch

    :param host:
    :param port:
    :param username:
    :param password:
    :param vhost:
    :param namespace:
    :return:
    """
    depths = get_queue_depths(host, port, username, password, vhost)
    publish_depths_to_cloudwatch(depths, namespace)


def get_queue_depths(host, port, username, password, vhost):
    """
    Get a list of queues and their message counts

    :param host:
    :param port:
    :param username:
    :param password:
    :param vhost:
    :return:
    """
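    # List the queues in the target vhost only, so the per-queue lookups
    # below resolve against the same vhost
    try:
        r = requests.get('https://{}:{}/api/queues/{}'.format(
            host, port, urllib.parse.quote(vhost, safe='')),
            auth=requests.auth.HTTPBasicAuth(username, password),
            # timeout (seconds) keeps a hung broker from eating the whole run
            timeout=10)
        # Treat HTTP errors (e.g. 401 for bad credentials) as failures too
        r.raise_for_status()
    except requests.exceptions.RequestException:
        log('rabbitmq_connection_failures')
        print("ERROR: Could not connect to {}:{} with user {}".format(
            host, port, username))
        return {}

    queues = r.json()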
    total = 0
    depths = {}
    for q in queues:

        # Ignore celery and pyrabbit queues
        if q['name'] == "aliveness-test":
            continue
        elif q['name'].endswith('.pidbox') or q['name'].startswith('celeryev.'):
            continue

        # Get individual queue counts
        try:
            r = requests.get('https://{}:{}/api/queues/{}/{}'.format(
                host,
                port,
                # quote(safe='') encodes "/" (the default vhost) as %2F and
                # spaces as %20; quote_plus would turn spaces into "+"
                urllib.parse.quote(vhost, safe=''),
                urllib.parse.quote(q['name'], safe='')),
                auth=requests.auth.HTTPBasicAuth(username, password),
                timeout=10)
        except requests.exceptions.RequestException:
            log('queue_depth_failure', tags=['queue:{}'.format(q['name'])])
            # Skip this queue rather than abandoning the remaining ones
            continue

        # Only parse the body of successful responses
        qr = r.json() if r.status_code == 200 else {}
        if 'messages' in qr:
            queue_depth = qr['messages']
            depths[q['name']] = queue_depth
            total = total + int(queue_depth)
        else:
            log('queue_depth_failure', tags=['queue:{}'.format(q['name'])])

    depths['total'] = str(total)
    return depths


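def publish_depths_to_cloudwatch(depths, namespace):
    """
    Publish one CloudWatch metric per queue (plus 'total') in `namespace`.

    :param depths: dict mapping queue name to message count
    :param namespace: CloudWatch namespace, e.g. '<queue_group>.rabbitmq.depth'
    :return: None
    """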
    cloudwatch = boto3.client(
        'cloudwatch', region_name=os.environ.get("AWS_REGION"))
    for q in depths:
        try:
            cloudwatch.put_metric_data(
                Namespace=namespace,
                MetricData=[{
                    'MetricName': q,
                    'Timestamp': time.time(),
                    'Value': int(depths[q]),
                    'Unit': 'Count',
                }])
            log(namespace, 'gauge', depths[q], [
                'queue:' + q
            ])
        except Exception as e:
            print(str(e))
            log('cloudwatch_put_metric_error')


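def lambda_handler(event, context):

    # The metric namespace is derived from the function name: e.g. a
    # (hypothetical) function named "myapp-rabbitmq-metrics" publishes
    # to the "myapp.rabbitmq.depth" namespace.
    queue_group = context.function_name.split('-', 1)[0]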

    host = os.environ.get("RABBITMQ_HOST")
    port = os.environ.get("RABBITMQ_PORT")
    user = os.environ.get("RABBITMQ_USER")
    pw = os.environ.get("RABBITMQ_PASS")
    get_queue_depths_and_publish_to_cloudwatch(
        host=host,
        port=port,
        username=user,
        password=boto3.client('kms').decrypt(CiphertextBlob=b64decode(pw))[
            'Plaintext'].decode('utf8').replace('\n', ''),
        vhost="/",
        namespace=queue_group + ".rabbitmq.depth")


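def log(metric_name, metric_type='count', metric_value=1, tags=None):
    """
    Print a metric line in the MONITORING|... format below, so a log-based
    metrics collector can pick it up from CloudWatch Logs.

    :param metric_name: suffix appended to 'rabbitmq_cloudwatch.'
    :param metric_type: e.g. 'count' or 'gauge'
    :param metric_value: numeric value to report
    :param tags: optional list of 'key:value' strings
    :return: None
    """
    # MONITORING|unix_epoch_timestamp|metric_value|metric_type|my.metric.name|#tag1:value,tag2
    print("MONITORING|{}|{}|{}|{}|#{}".format(
        int(time.time()),
        metric_value,
        metric_type,
        'rabbitmq_cloudwatch.' + metric_name, ','.join(tags or [])))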


if __name__ == "__main__":
    lambda_handler(event=None, context=None)
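
The filtering and totaling rules can be checked without a live broker. Below is a minimal sketch that applies the same rules to a sample /api/queues payload. summarize_queue_depths is a hypothetical helper, not part of the function above; it reads the messages field that the management API includes in each queue entry, which may be absent before the broker has collected statistics for a queue.

```python
def summarize_queue_depths(queues):
    """Return {queue_name: depth, ..., 'total': total} for a /api/queues payload.

    Hypothetical helper: applies the same filters as get_queue_depths, skipping
    the aliveness-test queue, Celery's *.pidbox control queues and celeryev.*
    event queues.
    """
    depths = {}
    total = 0
    for q in queues:
        name = q['name']
        if (name == 'aliveness-test' or name.endswith('.pidbox')
                or name.startswith('celeryev.')):
            continue
        # 'messages' can be missing before the broker has collected stats
        if 'messages' not in q:
            continue
        depths[name] = int(q['messages'])
        total += depths[name]
    depths['total'] = total
    return depths


sample = [
    {'name': 'celery', 'messages': 12},
    {'name': 'reports', 'messages': 3},
    {'name': 'worker1@host.celery.pidbox', 'messages': 0},
    {'name': 'celeryev.1234', 'messages': 50},
    {'name': 'aliveness-test', 'messages': 0},
    {'name': 'fresh-queue'},
]
print(summarize_queue_depths(sample))
# {'celery': 12, 'reports': 3, 'total': 15}
```

Since the queue listing already carries the counts, a helper like this could also replace the per-queue requests if one API call per run is preferred.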

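The function needs IAM permissions to publish CloudWatch metrics (cloudwatch:PutMetricData), to write its logs to CloudWatch Logs, and to decrypt (kms:Decrypt) with the KMS key used to encrypt the RABBITMQ_PASS variable. We also need to allow network access from this Lambda function to the RabbitMQ management API (port 15672 by default; the code above uses https://, and the conventional TLS port for the management plugin is 15671). Finally, the function has to be invoked on a schedule (for example, every minute) so the metrics stay current.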
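
A minimal sketch of such a policy follows, written as a Python dict so it can be serialized to JSON. The KMS key ARN is a placeholder, and the "*" resource for the log actions is a simplification you would likely narrow to the function's own log group.

```python
import json


def metrics_lambda_policy(kms_key_arn):
    """Build an IAM policy document for the metrics Lambda (a sketch).

    kms_key_arn is a placeholder for the key that encrypts RABBITMQ_PASS.
    """
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                # PutMetricData does not support resource-level scoping
                "Effect": "Allow",
                "Action": "cloudwatch:PutMetricData",
                "Resource": "*",
            },
            {
                # Lambda's own logging; can be narrowed to its log group ARN
                "Effect": "Allow",
                "Action": [
                    "logs:CreateLogGroup",
                    "logs:CreateLogStream",
                    "logs:PutLogEvents",
                ],
                "Resource": "*",
            },
            {
                # Needed to decrypt the RABBITMQ_PASS environment variable
                "Effect": "Allow",
                "Action": "kms:Decrypt",
                "Resource": kms_key_arn,
            },
        ],
    }


print(json.dumps(metrics_lambda_policy(
    "arn:aws:kms:eu-west-1:123456789012:key/EXAMPLE"), indent=2))
```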

Autoscale Celery workers based on RabbitMQ queue depth metrics

Now that we have our RabbitMQ queue depth metrics on CloudWatch, we need to configure autoscaling on ECS Fargate to scale the number of workers based on these data.

The autoscaling values will depend on your setup, and the best way to find good ones is trial and error. If your workload has peaks (e.g. Celery tasks that sync products with the DB) that enqueue a lot of tasks in a short span of time, you'll want to scale out to a large number of workers to process them quickly. If your workload is steady, you'll want to scale more gradually.
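
One rough way to reason about the upper bound for peaks is to ask how many workers would drain a burst within an acceptable time. The sketch below assumes you have measured the average task duration and know the concurrency each worker runs with (Celery's --concurrency option); workers_needed and all numbers are hypothetical. It ignores the minutes Fargate needs to start tasks, so treat the result as a lower bound.

```python
import math


def workers_needed(queue_depth, avg_task_seconds, concurrency, drain_seconds,
                   min_workers=0, max_workers=50):
    """Estimate how many workers drain `queue_depth` tasks within `drain_seconds`.

    Each worker runs `concurrency` tasks in parallel, each taking about
    `avg_task_seconds`, so one worker completes
    concurrency * drain_seconds / avg_task_seconds tasks in the window.
    """
    if queue_depth <= 0:
        return min_workers
    per_worker = concurrency * drain_seconds / avg_task_seconds
    needed = math.ceil(queue_depth / per_worker)
    # Clamp to the capacity limits you would set on the scalable target
    return max(min_workers, min(needed, max_workers))


# A burst of 6,000 product-sync tasks, ~2 s each, 4 processes per worker,
# drained within 5 minutes: 4 * 300 / 2 = 600 tasks per worker -> 10 workers.
print(workers_needed(6000, avg_task_seconds=2, concurrency=4, drain_seconds=300))
```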

First, we need to create a CloudWatch metric alarm on the RabbitMQ queue depth metrics; this alarm triggers the autoscaling policies. For barely used test environments, it can make sense to run 0 workers by default and start them only when messages arrive on the queue. This delays the processing of new tasks (Fargate takes a few minutes to start new tasks), but the savings may be worth it.
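
The article's own resources are written in Terraform. As a comparable sketch, the function below builds the keyword arguments that boto3's CloudWatch put_metric_alarm accepts, pointed at the namespace and the 'total' metric the Lambda above publishes. The alarm name, the 'Maximum' statistic and the default threshold of 0 (fire as soon as any message is waiting, which suits the scale-from-zero setup) are assumptions.

```python
def queue_depth_alarm_params(queue_group, scaling_policy_arn, threshold=0,
                             period_seconds=60):
    """Keyword arguments for CloudWatch put_metric_alarm (a sketch).

    Namespace and metric name match what the Lambda publishes:
    '<queue_group>.rabbitmq.depth' / 'total'. The alarm name is hypothetical.
    """
    return {
        'AlarmName': '{}-rabbitmq-queue-depth-high'.format(queue_group),
        'Namespace': '{}.rabbitmq.depth'.format(queue_group),
        'MetricName': 'total',
        'Statistic': 'Maximum',
        'Period': period_seconds,
        'EvaluationPeriods': 1,
        'Threshold': float(threshold),
        'ComparisonOperator': 'GreaterThanThreshold',
        # If the Lambda misses a run, don't treat the gap as a breach
        'TreatMissingData': 'notBreaching',
        # ARN of the scale-out policy this alarm should invoke
        'AlarmActions': [scaling_policy_arn],
    }


params = queue_depth_alarm_params('myapp', 'EXAMPLE-POLICY-ARN')
print(params['Namespace'])  # myapp.rabbitmq.depth
```

With credentials configured, boto3.client('cloudwatch').put_metric_alarm(**params) would create the alarm; the real policy ARN is returned when the scaling policy is created.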

Then, we need to configure the autoscaling target and policies. As mentioned above, these values should be tuned to your workload, and you can configure as many scaling steps as you need.
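
To illustrate how steps work, the sketch below builds a StepScalingPolicyConfiguration dictionary in the shape Application Auto Scaling's put_scaling_policy accepts, plus a hypothetical pick_adjustment helper that mimics how a step is chosen when the metric is above the alarm threshold (bounds are offsets from the threshold; lower bound inclusive, upper bound exclusive). The step values, the ExactCapacity adjustment type and the cooldown are illustrative assumptions, not the article's Terraform settings.

```python
def step_scaling_config(steps, adjustment_type='ExactCapacity', cooldown=60):
    """StepScalingPolicyConfiguration for Application Auto Scaling (a sketch).

    `steps` is a list of (lower, upper, adjustment) tuples; bounds are offsets
    from the alarm threshold and `upper` may be None for the open-ended step.
    """
    adjustments = []
    for lower, upper, adjustment in steps:
        step = {'MetricIntervalLowerBound': float(lower),
                'ScalingAdjustment': adjustment}
        if upper is not None:
            step['MetricIntervalUpperBound'] = float(upper)
        adjustments.append(step)
    return {
        'AdjustmentType': adjustment_type,
        'StepAdjustments': adjustments,
        'Cooldown': cooldown,
        'MetricAggregationType': 'Maximum',
    }


def pick_adjustment(config, metric_value, threshold):
    """Return the ScalingAdjustment that applies above the alarm threshold.

    Hypothetical helper: lower bound inclusive, upper bound exclusive, both
    relative to the threshold. Returns None if no step matches.
    """
    delta = metric_value - threshold
    for step in config['StepAdjustments']:
        upper = step.get('MetricIntervalUpperBound', float('inf'))
        if step['MetricIntervalLowerBound'] <= delta < upper:
            return step['ScalingAdjustment']
    return None


# Hypothetical steps with an alarm threshold of 0 and ExactCapacity:
# under 100 messages -> 1 worker, 100 to 999 -> 5, 1000 or more -> 20.
config = step_scaling_config([(0, 100, 1), (100, 1000, 5), (1000, None, 20)])
print(pick_adjustment(config, 250, threshold=0))  # 5
```

The config would be passed as StepScalingPolicyConfiguration to boto3.client('application-autoscaling').put_scaling_policy(...) with PolicyType='StepScaling', ServiceNamespace='ecs', ScalableDimension='ecs:service:DesiredCount' and a ResourceId of the form service/<cluster>/<service>, after registering the service with register_scalable_target (MinCapacity=0 for the scale-to-zero setup). Scaling back in needs its own policy, for example one triggered by a second alarm when the queue is empty.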

Once everything is set up, the Auto Scaling tab of the ECS service should show something similar to the image shown below.

And your Celery workers will autoscale based on queue metrics!

And that's all, folks! If you have any questions, feel free to reach out to me in the comments or on Twitter (@kstromeiraos).
