Hamza Adami
2021-11-17 | 56 min read

GO Serverless! Part 4 - Realtime Interactive and Secure Applications with AWS Websocket API Gateway

If you didn't read the previous part, we highly advise you to do that!


Previously, we built simple Flask and FastAPI applications, deployed them as Lambda Functions and exposed them through AWS API Gateway V2 as HTTP APIs.

Throughout this part, we will build a Generic, Reusable and Pluggable Websocket Stack (SUMU) that can be hooked to any application to make it interactive and provide realtime capability. The stack should be:

  • Interoperable - integrate with any application.
  • Secure - support Websocket Secured (WSS) and clients' JWT authorization during websocket upgrade request.
  • Fast - provide fast message delivery.

Additionally, we will build a simple demo application implementing the basic features that a modern websocket-compatible application needs.

TLDR

SUMU

SUMU (live in Japanese) is a Generic, Reusable and Pluggable Websocket Stack that consists of the following components:

  • Connections Store: a DynamoDB table where all users' active connections are cached. SUMU automatically adds new connections to the table, deletes closed connections from it and prunes stale connections using the DynamoDB TTL attribute. Additionally, it streams INSERT and DELETE events so other apps can track users' presence.

  • Integration Inputs/Outputs: SUMU provides an Input SNS Topic and an Input SQS Queue to receive notification requests from external applications seeking to notify the connected users. It also provides an Output SNS Topic and an Output SQS Queue for external applications to receive messages from connected users.

  • Websocket Request JWT Authorizer: a request JWT Authorizer hooked to the connection route and capable of integrating with any JWT identity provider (Firebase, Cognito, Auth0...) in order to verify the JWT token's signature, expiration time and allowed audiences, and to return authorization policies to APIGW.

  • Websockets API Gateway: SUMU provides a Websocket API Gateway with connection and disconnection routes integrated with DynamoDB for connection tracking, a keepalive (ping/pong) route to avoid idle connection termination, and two messaging (publish/send) routes integrated with SNS/SQS to fan out users' messages to backend applications.

  • Websockets Notifications Async Pusher: a serverless and fast AWS API Gateway websocket notifications pusher built using Python AsyncIO to support asynchronous, concurrent, non-blocking IO calls to the DynamoDB connections store and the API Gateway management API. This makes it suitable for receiving notification requests from backend applications and broadcasting them to multiple users in a fast and cost-effective way.

  • Presence Watchdog: a connections tracker that follows all users' connections and notifies backend applications about users' presence. It can fan out an ONLINE presence event whenever a user connects, and an OFFLINE presence event whenever a user terminates all their connections from all devices.

Usage

SUMU is built using pure AWS serverless technologies and it can be provisioned with just 2 Terraform modules, the actual SUMU module and a helper module to expose the Websocket API with a custom domain.

# The actual SUMU stack
module "sumu" {
  source      = "git::https://github.com/obytes/terraform-aws-sumu//modules/serverless"
  prefix      = local.prefix
  common_tags = local.common_tags

  # Authorizer
  issuer_jwks_uri         = "https://www.googleapis.com/service_accounts/v1/metadata/x509/securetoken@system.gserviceaccount.com"
  authorized_audiences    = ["sumu-websocket", ]
  verify_token_expiration = true

  s3_artifacts = {
     arn    = aws_s3_bucket.artifacts.arn
     bucket = aws_s3_bucket.artifacts.bucket
  }
  github = {
     owner          = "obytes"
     webhook_secret = "not-secret"
     connection_arn = "arn:aws:codestar-connections:us-east-1:{ACCOUNT_ID}:connection/{CONNECTION_ID}"
  }
  github_repository               = {
    authorizer = {
      name   = "apigw-jwt-authorizer"
      branch = "main"
    }
    pusher = {
      name   = "apigw-websocket-pusher"
      branch = "main"
    }
  }
  ci_notifications_slack_channels = {
     info  = "ci-info"
     alert = "ci-alert"
  }

  stage_name      = "mvp"
  apigw_endpoint  = "https://live.kodhive.com/push"
  presence_source = "topic"
}

# The Websocket API Exposer
module "gato" {
  source      = "git::https://github.com/obytes/terraform-aws-gato//modules/core-route53"
  prefix      = local.prefix
  common_tags = local.common_tags

  # DNS
  r53_zone_id = aws_route53_zone.prerequisite.zone_id
  cert_arn    = aws_acm_certificate.prerequisite.arn
  domain_name = "kodhive.com"
  sub_domains = {
    stateless = "api"
    statefull = "live"
  }

  # Rest APIS
  http_apis = []

  ws_apis = [
    {
      id    = module.sumu.ws_api_id
      key   = "live"
      stage = module.sumu.ws_api_stage_name
    }
  ]
}

Connecting users

You can connect to SUMU Websocket API using the native Websocket component or a helper package like Sockette. The authorization query string is required to establish the connection with the API.

import Sockette from "sockette";

function connect(accessToken: string): void {
    let endpoint = `wss://live.kodhive.com/push?authorization=${accessToken}`;
    setConnecting(true);
    ws = new Sockette(endpoint, {
        timeout: 5e3,
        maxAttempts: 5,
        onopen: e => {keepAlive();setConnected(true);setConnecting(false);},
        onmessage: e => {
            console.log(JSON.parse(e.data).message)
        },
    });
}

To keep the user's connection active, you can send ping messages periodically. API Gateway will respond with a pong immediately.

let keepAliveInterval: any = null;
function ping() {
    if (ws && connected) {
        ws.json({action: 'ping'});
    } else message.error("Not yet connected!")
}
function keepAlive() {
    if (ws && connected) {
        clearInterval(keepAliveInterval)
        keepAliveInterval = setInterval(ping, 3 * 60 * 1000) // Every 3 minutes
    } else message.error("Not yet connected!")
}
keepAlive()

Sending messages from clients to backend

SUMU is integrated with SNS and SQS. Clients can send messages to the SQS queue or publish them to the SNS topic. The message should be a JSON string that contains the action and the actual message:

  • Send a message to backend applications through SQS:
function send(type: string, msg: {}): void {
    if (ws && connected) {
        ws.json({
            action: "send",
            message: {type: type, message: msg}
        });
        message.success("Message sent!");
    } else message.error("Not yet connected!")
}
  • Publish a message to backend applications through SNS:
function publish(type: string, msg: {}): void {
    if (ws && connected) {
        ws.json({
            action: "publish",
            message: {type: type, message: msg}
        });
        message.success("Message published!");
    } else message.error("Not yet connected!")
}
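Sockette's ws.json simply serializes the object to JSON text before sending it, so the wire format of the three client actions can be sketched as follows. This is a minimal illustration in Python; the "chat" type and the message contents are hypothetical, and only the top-level action key is what the gateway uses to pick a route.

```python
import json


def frame(action, message=None):
    """Serialize a client frame; the top-level "action" key selects the route."""
    payload = {"action": action}
    if message is not None:
        payload["message"] = message
    return json.dumps(payload)


# Hypothetical payloads, for illustration only
ping_frame = frame("ping")
send_frame = frame("send", {"type": "chat", "message": {"text": "hello"}})
publish_frame = frame("publish", {"type": "chat", "message": {"text": "hello"}})
```

Any websocket client that can send these JSON strings can talk to the API, regardless of the language it is written in.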

Subscribing backends to clients' messages

For instance, you can subscribe a Lambda Function as a backend processor of clients' messages by creating an SNS subscription and allowing SNS to invoke the Lambda Function:

resource "aws_sns_topic_subscription" "_" {
  topic_arn = var.messages_topic_arn
  protocol  = "lambda"
  endpoint  = module.server.lambda["alias_arn"]
}

resource "aws_lambda_permission" "with_sns" {
  statement_id  = "AllowExecutionFromSNS"
  action        = "lambda:InvokeFunction"
  function_name = module.server.lambda["arn"]
  qualifier     = module.server.lambda["alias"]
  principal     = "sns.amazonaws.com"
  source_arn    = var.messages_topic_arn
}

In addition to Lambda, you can create subscriptions to publish messages to HTTP webhook endpoints, SMS and Emails.
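On the Lambda side, a subscribed function receives the standard SNS event envelope. The sketch below shows one way to unwrap it; it assumes the published message is a JSON string (the exact body produced by the gateway integration is not shown in this part), and handle_sns is a hypothetical name.

```python
import json


def handle_sns(event, context):
    """Unwrap clients' messages from an SNS-triggered Lambda event."""
    messages = []
    for record in event.get("Records", []):
        # SNS delivers the published payload as a string under Sns.Message
        messages.append(json.loads(record["Sns"]["Message"]))
    # A real processor would act on each message here instead of returning them
    return messages
```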

Polling clients' messages from backends

In case a backend application wants to process clients' messages in batches, you can create an SQS event source and give the Lambda Function permission to receive messages from the queue:

resource "aws_lambda_event_source_mapping" "_" {
  enabled                            = true
  batch_size                         = 10
  event_source_arn                   = var.messages_queue_arn
  function_name                      = module.server.lambda["alias_arn"]
  maximum_batching_window_in_seconds = 0 # Do not wait until batch size is fulfilled
}

data "aws_iam_policy_document" "policy" {
  statement {
    actions = [
      "sqs:ChangeMessageVisibility",
      "sqs:ChangeMessageVisibilityBatch",
      "sqs:DeleteMessage",
      "sqs:DeleteMessageBatch",
      "sqs:GetQueueAttributes",
      "sqs:ReceiveMessage"
    ]

    resources = [
      var.messages_queue_arn
    ]
  }
}

SQS is a better fit than SNS if you want to avoid hitting the Lambda concurrency limit, which defaults to 1,000 concurrent executions per account per Region (it can be increased to 100,000 through an AWS service request).
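With the event source mapping above, the function receives up to 10 records per invocation. A minimal sketch of a batch processor could look like this; it assumes each record body is a JSON object with type and message keys (mirroring what the clients send), and handle_batch is a hypothetical name. Exceptions are deliberately not caught, so a malformed record fails the whole batch and SQS makes it visible again for a retry.

```python
import json
from collections import defaultdict


def handle_batch(event, context):
    """Group a batch of SQS records by message type and report the counts."""
    by_type = defaultdict(list)
    for record in event["Records"]:
        # json.loads raises on malformed input, which fails (and retries) the batch
        body = json.loads(record["body"])
        by_type[body.get("type", "unknown")].append(body.get("message"))
    # A real processor would handle each group here
    return {msg_type: len(items) for msg_type, items in by_type.items()}
```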

Notifying clients from backend

Backend applications can push notifications to users connected to the AWS API Gateway Websocket by sending a notification request to one of the services integrated with the Pusher (SNS or SQS). Notification requests should meet the following format:

For multicast notifications, the message should be a JSON String that contains the list of users and the actual data:

import json

message = {
    "users": ["783304b1-2320-44db-8f58-09c3035a686b", "a280aa41-d99b-4e1c-b126-6f39720633cc"],
    "data": {"type": "notification", "message": "A message sent to multiple users"}
}
message_to_send = json.dumps(message)

For broadcast notifications, use the same format but omit the users list or provide an empty one:

import json

message = {
    "data": {"type": "announcement", "message": "A broadcast to all users"}
}
message_to_send = json.dumps(message)

For exclusion notifications, instead of providing a users list, provide a list of excluded users:

import json

message = {
    "exclude_users": ["783304b1-2320-44db-8f58-09c3035a686b"],
    "data": {
      "type": "announcement",
      "message": {
        "user_id": "783304b1-2320-44db-8f58-09c3035a686b",
        "status": "OFFLINE"
      }
    }
}
message_to_send = json.dumps(message)
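The three formats above differ only in which optional key is present, so a backend could build them with one small helper. The sketch below is an illustration, not part of SUMU: build_notification is a hypothetical name, and rejecting a request that carries both users and exclude_users is our own assumption about what a sensible request looks like.

```python
import json


def build_notification(data, users=None, exclude_users=None):
    """Build a multicast, broadcast or exclusion notification request."""
    if users and exclude_users:
        # Assumption: a request targets users or excludes users, never both
        raise ValueError("Provide either users or exclude_users, not both")
    request = {"data": data}
    if users:
        request["users"] = list(users)
    if exclude_users:
        request["exclude_users"] = list(exclude_users)
    return json.dumps(request)


broadcast = build_notification({"type": "announcement", "message": "A broadcast to all users"})
multicast = build_notification(
    {"type": "notification", "message": "Hello"},
    users=["783304b1-2320-44db-8f58-09c3035a686b"],
)
```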

Notification requests through SNS

The SUMU Pusher subscribes to the notifications SNS Topic. Whenever backend applications publish notification requests to SNS, the latter quickly notifies the Pusher by sending the notification request to the subscribed Pusher Lambda.

This will result in a fast delivery because this approach does not introduce a polling mechanism and SNS will notify the Pusher whenever a notification request is available.

However, at scale SNS will trigger a Pusher Lambda invocation for every notification request. Given that the Lambda concurrency limit is 1,000 per account (it can be increased to 100,000 through a support ticket), notification requests will be throttled for very large applications that can have more than 100,000 concurrent inflight messages.

Publish to SNS when you have a small/medium application with a moderate number of users

import os
import json
import time
import boto3

message = {
    "users": ["783304b1-2320-44db-8f58-09c3035a686b", "a280aa41-d99b-4e1c-b126-6f39720633cc"],
    "data": {
        "type": "notification",
        "message": {
            "text": "Your order has been fulfilled!",
            "timestamp": int(time.time())
        }
    }
}
boto3.client("sns").publish(
    TargetArn=os.environ["NOTIFICATIONS_TOPIC_ARN"],
    Message=json.dumps(message),
)

Sending notification requests through SQS

Unlike SNS, when notification requests are sent to the SQS queue, the Pusher Lambda Function is configured with the queue as an event source, and the Lambda service periodically polls notification requests from the queue using long polling.

As a result, notification requests are processed in batches, which comes with many benefits:

  • Fewer Lambda Invocations - to not reach the Lambda Concurrency Limit.
  • Concurrent Notifications - as the pusher uses AsyncIO, it will be able to process batches of SQS Records concurrently.
  • Low cost - thanks to SQS Batches and fewer Lambda Invocations.

The Pusher can match the speed and performance of SNS because the SQS queue's receive_wait_time_seconds is set to 20, which makes the Lambda service use long polling instead of short polling. In addition, the Lambda service runs a background poller with five parallel connections, each long-polling for up to 20 seconds, which ensures that the Lambda receives notification requests as soon as they arrive in the queue.

AWS: the automatic scaling behavior of Lambda is designed to keep polling costs low when a queue is empty while simultaneously letting us scale up to high throughput when the queue is being used heavily. When an SQS event source mapping is initially created and enabled, or when messages first appear after a period with no traffic, then the Lambda service will begin polling the SQS queue using five parallel long-polling connections. The Lambda service monitors the number of inflight messages, and when it detects that this number is trending up, it will increase the polling frequency by 20 ReceiveMessage requests per minute and the function concurrency by 60 calls per minute. As long as the queue remains busy it will continue to scale until it hits the function concurrency limits.

With 5 parallel connections each sending 3 ReceiveMessage requests per minute, an idle queue costs 15 requests every minute: 900 every hour, 21,600 every day and 648,000 every month.

AWS gives you one million requests for free every month. After that it's only $0.40 per million requests, so the cost of consuming messages from SQS is very low.
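The idle polling arithmetic can be checked with a few lines of Python. The figures are the ones quoted in this article (5 connections, a 20-second wait, a 30-day month, one million free requests, $0.40 per million afterwards); the variable names are ours.

```python
CONNECTIONS = 5           # parallel long-polling connections
WAIT_SECONDS = 20         # receive_wait_time_seconds
FREE_REQUESTS = 1_000_000
PRICE_PER_MILLION = 0.40  # USD

requests_per_minute = CONNECTIONS * (60 // WAIT_SECONDS)
requests_per_hour = requests_per_minute * 60
requests_per_day = requests_per_hour * 24
requests_per_month = requests_per_day * 30

# Requests above the free tier are the only ones that are billed
billable = max(0, requests_per_month - FREE_REQUESTS)
idle_cost = billable / 1_000_000 * PRICE_PER_MILLION
print(requests_per_month, idle_cost)
```

Under these figures an idle queue stays inside the free tier, so polling alone costs nothing until real traffic pushes the request count past one million.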

Send to SQS when you have a large application with millions of concurrent inflight messages.

import os
import json
import time
import boto3

message = {
    "users": ["783304b1-2320-44db-8f58-09c3035a686b", "a280aa41-d99b-4e1c-b126-6f39720633cc"],
    "data": {
        "type": "notification",
        "message": {
            "text": "Your order has been fulfilled!",
            "timestamp": int(time.time())
        }
    }
}
boto3.client("sqs").send_message(
  QueueUrl=os.environ.get("NOTIFICATIONS_QUEUE_URL"),
  MessageBody=json.dumps(message),
)
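When a backend has many notification requests to enqueue at once, SQS also accepts up to 10 messages per SendMessageBatch call. The sketch below shows one way to chunk requests accordingly; send_in_batches and chunked are hypothetical helpers, and the client is passed in as a parameter (for example boto3.client("sqs")) so the function stays easy to test.

```python
import json

MAX_BATCH = 10  # SQS accepts at most 10 entries per SendMessageBatch call


def chunked(items, size=MAX_BATCH):
    """Yield consecutive slices of at most `size` items."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


def send_in_batches(sqs_client, queue_url, notifications):
    """Send notification requests in batches and return the failed entries."""
    failed = []
    for batch_no, batch in enumerate(chunked(notifications)):
        entries = [
            {"Id": f"{batch_no}-{i}", "MessageBody": json.dumps(n)}
            for i, n in enumerate(batch)
        ]
        response = sqs_client.send_message_batch(QueueUrl=queue_url, Entries=entries)
        # Entries that SQS rejected are reported under "Failed"
        failed.extend(response.get("Failed", []))
    return failed
```

A caller would typically retry or log whatever comes back in the returned list.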

A live demo application is deployed at https://sumu.kodhive.com, hurry up and test it, the live demo will be removed after 1 month from publishing the article!

That's all 🎉. Enjoy your serverless Websocket API. However, if you are not in a rush, continue the article to see how we've built it. You will not regret it.

“You take the blue pill, the story ends, you wake up in your bed and believe whatever you want to believe. You take the red pill, you stay in wonderland, and I show you how deep the rabbit hole goes.” Morpheus

Integration Mechanism

Source: Integration

The first thing to think about when building reusable and interoperable stacks is how dependent applications will integrate with them in terms of input/output.


In order for our stack to be interoperable with other external applications, we will expose:

  • Two Outputs - SNS Topic and SQS Queue for clients' messages.
  • Two Inputs - SNS Topic and SQS Queue for backend's notification requests.

Output

  • Messages SNS Topic: where the clients' messages will be published, external applications' backend processors can subscribe to this topic and SNS will broadcast published messages to the subscribers as they arrive.
# components/integration/messages.tf
resource "aws_sns_topic" "messages" {
  name = "${local.prefix}-messages"
}
  • Messages SQS Queue: where the clients' messages will be sent, external applications' backend processors can be configured with the queue as an event source, long poll messages and process them in batches.
# components/integration/messages.tf
# Receives users messages from "send" route
resource "aws_sqs_queue" "messages" {
  name                        = "${local.prefix}-messages"
  fifo_queue                  = false
  content_based_deduplication = false

  delay_seconds              = 0      # No delay
  max_message_size           = 262144 # 256 KiB
  receive_wait_time_seconds  = 20     # Long Polling not Short Polling
  message_retention_seconds  = 345600 # 4 Days
  visibility_timeout_seconds = 120*6  # 12 minutes

  redrive_policy = jsonencode({
    deadLetterTargetArn = aws_sqs_queue.messages_dlq.arn
    maxReceiveCount     = 5 # Move failed messages to DLQ after 5 failures
  })

  tags = local.common_tags
}

resource "aws_sqs_queue" "messages_dlq" {
  name = "${local.prefix}-messages-dlq"
}

Input

  • Notifications SNS Topic: where external applications' backend processors publish notification requests. The SUMU Pusher, which we will build later in this article, subscribes to this topic and fulfills notification requests by posting them to the target users' connections.
# components/integration/notifications.tf
resource "aws_sns_topic" "notifications" {
  name = "${local.prefix}-notifications"
}
  • Notifications SQS Queue: where external applications can send notification requests for batch processing by the SUMU Pusher. This levels the load on the Lambda service, so the Pusher can process notification requests in batches without hitting the Lambda concurrency limit and without paying for IO wait time, as the requests are processed concurrently.
# components/integration/notifications.tf
# Receives backend notifications requests
resource "aws_sqs_queue" "notifications" {
  name                        = "${local.prefix}-notifications"
  fifo_queue                  = false
  content_based_deduplication = false

  delay_seconds              = 0      # No delay
  max_message_size           = 262144 # 256 KiB
  receive_wait_time_seconds  = 20     # Long Polling not Short Polling
  message_retention_seconds  = 345600 # 4 Days
  visibility_timeout_seconds = 120*6  # 12 minutes, 6 times the lambda timeout (AWS Recommendation)

  redrive_policy = jsonencode({
    deadLetterTargetArn = aws_sqs_queue.notifications_dlq.arn
    maxReceiveCount     = 5 # Move failed messages to DLQ after 5 failures
  })

  tags = local.common_tags
}

resource "aws_sqs_queue" "notifications_dlq" {
  name = "${local.prefix}-notifications-dlq"
}

Load leveling

Do you remember the Load Leveling Pattern from Part 1? We are following the same pattern here by creating an SQS queue, which lets external applications poll the queue only as fast as they can process messages, without crashing low-resource backends.

Also, this will allow the SUMU Notifications Pusher Lambda to process notifications from the queue in batches instead of one by one. As a result, we will stay clear of the Lambda concurrency limit because there will be fewer invocations.

We have created two SQS queues and a dead letter queue for each of them, where failed clients' messages and failed backend notification requests are placed for further investigation. The redrive policy moves failed messages to the dead letter queues after 5 failures.

We have set delay_seconds to 0 to not delay notifications, and receive_wait_time_seconds to 20 seconds to use long polling instead of short polling, because it is both fast and cost effective. The visibility_timeout_seconds is set to 6 times the timeout configured for the Pusher Lambda function, as per the AWS recommendation.

AWS: In almost all cases, Amazon SQS long polling is preferable to short polling. Long-polling requests let your queue consumers receive messages as soon as they arrive in your queue while reducing the number of empty ReceiveMessageResponse instances returned.

Users Connections Store

Source: Users Connection Store

AWS API Gateway does not provide a builtin functionality to store and manage users' connections, so we have to build a websocket connections store and tracking mechanism.

The first thing we have to have is a connections store, and for that we will leverage DynamoDB because of its:

  • Performance and scalability - Amazon DynamoDB provides high throughput at very low latency and gives us the ability to auto-scale by tracking how close our usage is to the upper bounds.

  • Events stream - DynamoDB Streams lets us capture item-level changes, optionally with the item data before and after each change.

  • Time To Live - allows us to set a per-item timestamp after which DynamoDB deletes the expired item from the table. Deletion happens in the background some time after the timestamp passes, not at the exact instant.

  • Cost effective - the one-year free tier allows more than 40 million database operations per month, and pricing is based on throughput (reads/writes per second) rather than storage.

# components/database/connections.tf
resource "aws_dynamodb_table" "connections" {
  name = "SumuConnections"

  ################
  # Billing
  ################
  billing_mode   = "PAY_PER_REQUEST"
  read_capacity  = 0
  write_capacity = 0

  ##################
  # Index/Sort Keys
  ##################
  hash_key  = "user_id"
  range_key = "connection_id"

  ##################
  # Attributes
  ##################
  # User (Partition Key)
  attribute {
    name = "user_id"
    type = "S"
  }

  # WS Connection ID (Sort Key)
  attribute {
    name = "connection_id"
    type = "S"
  }

  ttl {
    enabled        = true
    attribute_name = "delete_at"
  }

  stream_enabled   = true
  stream_view_type = "KEYS_ONLY"

  tags = {
    table_name = "SumuConnections"
  }
}

The configuration for the connections table is as follows:

Billing mode: pay per request for a Serverless Mode.

Partition Key: stores the user ID, which is the sub claim of the user's JWT, so that the Messages' Pusher can retrieve all connections of a specific user.

Sort Key: holds the user's websocket connection_id, generated after successful authorization and connection upgrade. Every user can have multiple connections. We will not create an inverted-schema Global Secondary Index on this range key because we never need to look up the user of a specific connection; we only need to list the connections of a specific user.

TTL: the Time To Live of each connection. We generate it after the connection upgrade and set it to a timestamp 2 hours in the future, after which DynamoDB considers the connection stale and deletes it automatically. It is +2h because the maximum duration of an API Gateway Websocket connection is 2 hours; a connection that remains in the table for more than 2 hours is stale and should be pruned.

Stream: we've enabled DynamoDB Streams to fan out INSERT and DELETE events, so that the Presence Watchdog we will build later can track users' online and offline events.
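To make the schema concrete, here is a hedged sketch of what a connection item and a stream record look like in DynamoDB's low-level attribute format. build_connection_item and parse_stream_record are hypothetical helpers, not SUMU code. Note that the TTL attribute must be a Number holding epoch seconds, and that DynamoDB Streams reports deletions with the event name REMOVE.

```python
import time

MAX_CONNECTION_SECONDS = 2 * 60 * 60  # 2 hours, the maximum connection duration


def build_connection_item(user_id, connection_id, now=None):
    """Build a connection item whose delete_at TTL is 2 hours in the future."""
    now = int(time.time()) if now is None else int(now)
    return {
        "user_id": {"S": user_id},
        "connection_id": {"S": connection_id},
        "delete_at": {"N": str(now + MAX_CONNECTION_SECONDS)},
    }


def parse_stream_record(record):
    """Extract (event name, user id, connection id) from a KEYS_ONLY stream record."""
    keys = record["dynamodb"]["Keys"]
    return record["eventName"], keys["user_id"]["S"], keys["connection_id"]["S"]
```

Because the stream view type is KEYS_ONLY, the two key attributes are all a consumer gets, which is enough to tell which user gained or lost which connection.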

Websocket JWT Request Authorizer

Source: AWS API Gateway JWT Authorizer

Clients will connect to the WebSocket API when initiating a WebSocket Upgrade Request. If the request succeeds, the $connect route is executed while the connection is being established.

Because the WebSocket connection is a stateful connection, we can configure the authorization on the $connect route only. AuthN/AuthZ will be performed only at connection time. If the $connect request fails due to AuthN/AuthZ failure, the connection will not be made.

AWS API Gateway supports multiple mechanisms for authentication and authorization: AWS IAM Roles/Policies, IAM Tags and Custom Lambda Authorizers.

To support most third-party applications, we will go with a Lambda JWT Authorizer, because JWT is widely used and considered a standard for web applications. Below is the handler for authorization requests:

import os
import time

import requests
from cachecontrol import CacheControl
from cachecontrol.caches import FileCache
from jose import jwt, jws, jwk, JWTError
from jose.utils import base64url_decode

FIREBASE_JWK_URI = "https://www.googleapis.com/service_accounts/v1/metadata/x509/securetoken@system.gserviceaccount.com"


def search_for_key(token, keys, construct=True):
    # get the kid from the headers prior to verification
    headers = jwt.get_unverified_headers(token)
    kid = headers["kid"]
    # search for the kid in the downloaded public keys
    key = list(filter(lambda k: k["kid"] == kid, keys))
    if not key:
        raise JWTError(f"Public key not found in jwks.json")
    else:
        key = key[0]
    if construct:
        return jwk.construct(key)
    else:
        return key


def get_public_key(token):
    """
    Because Google's public keys are only changed infrequently (on the order of once per day),
    we can take advantage of caching to reduce latency and the potential for network errors.
    """
    jwks_uri = os.environ["JWT_ISSUER_JWKS_URI"]
    sess = CacheControl(requests.Session(), cache=FileCache("/tmp/jwks-cache"))
    request = sess.get(jwks_uri)
    ks = request.json()
    keys = []
    #
    if jwks_uri == FIREBASE_JWK_URI:
        for k, v in ks.items():
            keys.append({
                "alg": "RS256",
                "kid": k,
                "pem": v
            })
        return search_for_key(token, keys, construct=False)
    else:
        keys = ks["keys"]
        return search_for_key(token, keys, construct=True)


def valid_signature(token, key):
    if isinstance(key, dict):
        # verify the signature, exception should be thrown if verification failed
        jws.verify(token, key["pem"], [key["alg"]], verify=True)
    else:
        # get the last two sections of the token,
        # message and signature (encoded in base64)
        message, encoded_signature = str(token).rsplit('.', 1)
        # decode the signature
        decoded_signature = base64url_decode(encoded_signature.encode("utf-8"))
        # verify the signature
        if not key.verify(message.encode("utf8"), decoded_signature):
            raise JWTError("Signature verification failed")
    return True


def decode(token, verify_expiration=True, authorized_audiences=None):
    # since we passed the verification, we can now safely
    # use the unverified claims
    claims = jwt.get_unverified_claims(token)
    # additionally we can verify the token expiration
    if verify_expiration:
        if time.time() > claims["exp"]:
            raise JWTError("Token is expired")
    # and the Audience
    if authorized_audiences:
        # OID TOKEN (aud), OAUTH ACCESS TOKEN (client_id)
        aud = claims.get("aud", claims.get("client_id"))
        if not aud:
            raise JWTError("Token does not have aud nor client_id attribute")
        if aud not in authorized_audiences:
            raise JWTError("Token was not issued for this audience")
    # now we can use the claims
    return claims


def verify(token):
    key = get_public_key(token)
    if valid_signature(token, key):
        # an unset or empty variable means "do not check the audience"
        raw_audiences = os.environ.get("JWT_AUTHORIZED_AUDIENCES", "")
        authorized_audiences = [a for a in raw_audiences.split(",") if a]
        return decode(
            token,
            verify_expiration=os.environ.get("JWT_VERIFY_EXPIRATION", "true") == "true",
            authorized_audiences=authorized_audiences or None
        )


def generate_policy(principal, effect, reason=None):
    auth_response = {
        "principalId": principal,
        "policyDocument": {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Action": "execute-api:Invoke",
                    "Effect": effect,
                    "Resource": os.environ.get("AUTHORIZED_APIS", "*").split(",")
                }
            ]
        }
    }
    if reason:
        auth_response.update({
            'context': {
                'error': reason
            }
        })
    return auth_response


def check_auth(token):
    if not token:
        return generate_policy("rogue", "Deny", reason="Missing Access Token")
    try:
        claims = verify(token)
        if claims:
            return generate_policy(claims["sub"], "Allow")
    except Exception as e:
        return generate_policy("rogue", "Deny", reason=str(e))


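def handle(event, context):
    # API Gateway may send null for missing headers or query strings, so default to {}
    headers = event.get("headers") or {}
    query = event.get("queryStringParameters") or {}
    token = headers.get("Authorization", query.get("authorization"))
    policy = check_auth(token)
    return policy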

The authorizer will be able to:

  • Integrate with any JWT token provider.
  • Verify user's JWT token integrity using provider's public keys.
  • Verify token expiration.
  • Check that the token is issued for an audience in the allowed audiences list.
  • Bind the JWT's sub user attribute to the API Gateway authorization policy.
  • Return a deny or allow authorization policy to API Gateway.
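The expiration and audience checks can be exercised in isolation, without any network call. The sketch below is a standalone variant and not the handler's code: validate_claims is a hypothetical name, it raises ValueError instead of JWTError to avoid the jose dependency, and it additionally accepts an aud claim given as a list, which the JWT specification allows. It must only ever run on claims whose signature has already been verified.

```python
import time


def validate_claims(claims, authorized_audiences=None, verify_expiration=True, now=None):
    """Check exp and aud on claims that were ALREADY signature-verified."""
    now = time.time() if now is None else now
    if verify_expiration and now > claims["exp"]:
        raise ValueError("Token is expired")
    if authorized_audiences:
        # ID tokens carry "aud", OAuth access tokens may carry "client_id"
        aud = claims.get("aud", claims.get("client_id"))
        if not aud:
            raise ValueError("Token does not have aud nor client_id attribute")
        token_audiences = [aud] if isinstance(aud, str) else list(aud)
        if not set(token_audiences) & set(authorized_audiences):
            raise ValueError("Token was not issued for this audience")
    return claims
```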

Deploy the Authorizer

To deploy the Authorizer as a Lambda Function, we will use the Terraform modules from Part 3:

  • Create the Lambda Function
# modules/serverless/components.tf
module "authorizer" {
  source      = "git::https://github.com/obytes/terraform-aws-codeless-lambda.git//modules/lambda"
  prefix      = "${local.prefix}-authorizer"
  common_tags = local.common_tags

  handler = "app.main.handle"
  envs    = {
    AUTHORIZED_APIS          = join(",", module.gateway.authorized_apis)
    JWT_ISSUER_JWKS_URI      = var.issuer_jwks_uri
    JWT_AUTHORIZED_AUDIENCES = join(",", var.authorized_audiences)
    JWT_VERIFY_EXPIRATION    = var.verify_token_expiration
  }
}
  • Deploy the Lambda Function
# modules/serverless/ci.tf
module "authorizer_ci" {
  source      = "git::https://github.com/obytes/terraform-aws-lambda-ci.git//modules/ci"
  prefix      = "${local.prefix}-authorizer-ci"
  common_tags = var.common_tags

  # Lambda
  lambda                   = module.authorizer.lambda
  app_src_path             = "sources"
  packages_descriptor_path = "sources/requirements/lambda.txt"

  # Github
  s3_artifacts      = var.s3_artifacts
  github            = var.github
  pre_release       = var.pre_release
  github_repository = var.github_repositories.authorizer

  # Notifications
  ci_notifications_slack_channels = var.ci_notifications_slack_channels
}

Websocket AWS API Gateway

Source: AWS API Gateway Websocket API

After preparing the integration, the connections store and the request authorizer components we will now set up our API Gateway and integrate it with those components, but first we have to create an IAM Role for API Gateway to assume and attach an IAM Policy to it so it can integrate with those components:

# components/gateway/iam.tf
# Integration role
resource "aws_iam_role" "_" {
  name               = var.prefix
  assume_role_policy = data.aws_iam_policy_document.assume_policy.json
  path               = "/"
}
data "aws_iam_policy_document" "assume_policy" {
  statement {
    effect = "Allow"
    actions = ["sts:AssumeRole"]
    principals {
      type        = "Service"
      identifiers = ["apigateway.amazonaws.com"]
    }
  }
}
data "aws_iam_policy_document" "policy" {
  statement {
    effect = "Allow"
    actions = [
      "lambda:InvokeFunction",
    ]
    resources = [
      var.request_authorizer.alias_arn
    ]
  }
  statement {
    effect = "Allow"
    actions = [
      "dynamodb:PutItem",
      "dynamodb:DeleteItem",
    ]

    resources = [
      var.connections_table.arn,
    ]
  }
  statement {
    effect = "Allow"
    actions = [
      "sns:Publish",
    ]

    resources = [
      var.messages_topic.arn,
    ]
  }
  statement {
    effect = "Allow"
    actions = [
      "sqs:SendMessage",
    ]

    resources = [
      var.messages_queue.arn,
    ]
  }
}
resource "aws_iam_policy" "_" {
  name   = local.prefix
  policy = data.aws_iam_policy_document.policy.json
  path   = "/"
}
resource "aws_iam_role_policy_attachment" "_" {
  policy_arn = aws_iam_policy._.arn
  role       = aws_iam_role._.name
}

With permissions in place, let's create an API Gateway V2 Websocket API resource, where we specify the selection expressions for routes and API keys:

# components/gateway/api.tf
resource "aws_apigatewayv2_api" "_" {
  name          = "${local.prefix}-ws-api"
  description   = "Sumu Websockets API"
  protocol_type = "WEBSOCKET"

  route_selection_expression   = "$request.body.action"
  api_key_selection_expression = "$request.header.x-api-key"

  tags = local.common_tags
}

Next, we integrate the API Gateway Websocket API with the Lambda JWT Authorizer we created earlier. We set the authorizer type to REQUEST because we are going to use a request authorizer.

The JavaScript WebSocket API offers no way to specify additional headers for the client/browser to send. To work around that, we instruct API Gateway to read the authorization token from the query string, using route.request.querystring.authorization as the identity source.

# components/gateway/authorizers.tf
resource "aws_apigatewayv2_authorizer" "request" {
  name                       = "${var.prefix}-request-authz"
  api_id                     = aws_apigatewayv2_api._.id
  authorizer_type            = "REQUEST"
  identity_sources           = [
    "route.request.querystring.authorization",
  ]
  authorizer_uri             = var.request_authorizer.invoke_arn
  authorizer_credentials_arn = aws_iam_role._.arn
}
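On the client side, this means the token travels in the connection URL rather than in a header. A minimal sketch of how a client could build that URL, assuming a hypothetical stage endpoint and a hypothetical helper named build_ws_url (neither is part of SUMU):

```python
from urllib.parse import urlencode, urlsplit, urlunsplit


def build_ws_url(endpoint: str, jwt_token: str) -> str:
    """Append the JWT as the `authorization` query string parameter.

    The parameter name matches the identity source
    route.request.querystring.authorization configured on the authorizer.
    """
    parts = urlsplit(endpoint)
    query = urlencode({"authorization": jwt_token})
    if parts.query:
        # Keep any query parameters already present on the endpoint.
        query = f"{parts.query}&{query}"
    return urlunsplit((parts.scheme, parts.netloc, parts.path, query, parts.fragment))


# Hypothetical endpoint and token, for illustration only.
url = build_ws_url(
    "wss://example.execute-api.eu-west-1.amazonaws.com/live",
    "header.payload.signature",
)
print(url)
```

The resulting URL is what a client would pass to its websocket library when opening the connection.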

After integrating the Request Authorizer with API Gateway and in addition to the IAM Role, we need an additional Lambda Permission to give API Gateway access to invoke the Authorizer Lambda Function:

# components/gateway/permission.tf
resource "aws_lambda_permission" "allow_apigw" {
  statement_id  = local.prefix
  action        = "lambda:InvokeFunction"
  function_name = var.request_authorizer.name
  qualifier     = var.request_authorizer.alias
  principal     = "apigateway.amazonaws.com"
  source_arn    = "${aws_apigatewayv2_api._.execution_arn}/${aws_apigatewayv2_stage._.name}/*"
}

Connection route

Setting up an integration for $connect is optional because API Gateway provides a default $connect route. However, in our case we want to build custom $connect and $disconnect routes for several reasons:

  • We want to be notified when clients connect and disconnect.
  • We want to throttle connections and control who connects.
  • We want our backend to publish users presence messages (online, offline) back to clients using Fanout SNS Topics.
  • We want to store each connection ID and other information into a database (Amazon DynamoDB).

Unlike approaches that involve Lambda Functions to manage users' connections, we will follow a faster and more cost-effective approach: instead of going through a proxy Lambda Function to persist the connection, we will call the DynamoDB PutItem action directly from API Gateway.

First, we need a request template that acts as the DynamoDB PutItem action body. We map the principalId from the authorizer context to the table's Hash Key user_id, and the generated connectionId to the table's Range Key connection_id.

# components/gateway/connect/ddb_put.json
{
  "TableName": "SumuConnections",
  "Item": {
    "user_id": {
      "S": "$context.authorizer.principalId"
    },
    "connection_id": {
      "S": "$context.connectionId"
    },
    #set($delete_connection_at = ($context.requestTimeEpoch / 1000) + 7200)
    "delete_at": {
      "N": "$delete_connection_at"
    }
  }
}

As per the AWS docs, the maximum connection duration for WebSocket APIs is 2 hours. To make sure stale connections are eventually deleted from DynamoDB, we set the connection's TTL attribute delete_at to the request time plus 2 hours (7200 seconds).
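The arithmetic in the template can be checked with a small sketch. The function and constant names below are hypothetical; the only facts taken from the template are that requestTimeEpoch is in milliseconds and that 7200 seconds are added:

```python
CONNECTION_MAX_AGE_SECONDS = 2 * 60 * 60  # 2 hours, the 7200 used in the template


def compute_delete_at(request_time_epoch_ms: int) -> int:
    """Mirror the template: epoch milliseconds -> epoch seconds, plus 2 hours."""
    return request_time_epoch_ms // 1000 + CONNECTION_MAX_AGE_SECONDS


# An arbitrary request time, expressed in epoch milliseconds.
print(compute_delete_at(1637107200000))
```

The division matters because DynamoDB TTL compares the attribute against the current time in epoch seconds, while API Gateway reports the request time in milliseconds.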


Next, we will configure the integration with DynamoDB using the integration resource, we will provide:

  • The integration URI: that will be called by API Gateway, including the action type PutItem

  • The credentials ARN: the ARN of the IAM role we created earlier, which allows the dynamodb:PutItem action.

  • The request template: the request body template we've created earlier.

# components/gateway/connect/integration.tf
resource "aws_apigatewayv2_integration" "ack_presence" {
  api_id                 = var.api_id
  description            = "Acknowledge user presence"

  passthrough_behavior   = "WHEN_NO_MATCH"
  payload_format_version = "1.0"

  # Upstream
  integration_type     = "AWS"
  integration_uri      = "arn:aws:apigateway:${data.aws_region.current.name}:dynamodb:action/PutItem"
  connection_type      = "INTERNET"
  credentials_arn      = var.credentials_arn
  integration_method   = "POST"
  timeout_milliseconds = 29000

  request_templates = {
    "application/json" = file("${path.module}/ddb_put.json")
  }

  lifecycle {
    ignore_changes = [
      passthrough_behavior
    ]
  }
}

After setting up the integration, we will add a new route that will receive the connection requests, and authorize the users with our JWT Request Authorizer before establishing the connection:

# components/gateway/connect/route.tf
resource "aws_apigatewayv2_route" "connect" {
  api_id         = var.api_id

  # UPSTREAM
  target         = "integrations/${aws_apigatewayv2_integration.ack_presence.id}"
  route_key      = "$connect"
  operation_name = "Acknowledge user presence"

  # AUTHORIZATION
  authorizer_id      = var.request_authorizer_id
  authorization_type = "CUSTOM"
  api_key_required   = false

  route_response_selection_expression = "$default"
}

Lastly, we add an integration response and a route response to return to the client:

# components/gateway/connect/response.tf
resource "aws_apigatewayv2_integration_response" "hello" {
  api_id                   = var.api_id
  integration_id           = aws_apigatewayv2_integration.ack_presence.id
  integration_response_key = "/200/"
}

resource "aws_apigatewayv2_route_response" "hello" {
  api_id             = var.api_id
  route_id           = aws_apigatewayv2_route.connect.id
  route_response_key = "$default"
}

Disconnection route

The $disconnect route is executed after the connection is closed. The connection can be closed by the client or by the server after:

  • IDLE connection timeout - 10 minutes, if the client did not send keepalive (ping/pong) requests.
  • Maximum connection duration: even if the client sends keepalive requests, it will be disconnected by API Gateway after 2 hours.

As the connection is already closed when the route is executed, $disconnect is a best-effort event. API Gateway will try its best to deliver the $disconnect event to our integration, but it cannot guarantee delivery. This is why we use the DynamoDB TTL attribute to delete stale connections automatically after 2 hours.


As with the connection route, we need a request template. This time, however, it is used to delete the closed connection.

# components/gateway/disconnect/ddb_delete.json
{
  "TableName": "SumuConnections",
  "Key": {
    "user_id": {
      "S": "$context.authorizer.principalId"
    },
    "connection_id": {
      "S": "$context.connectionId"
    }
  }
}

The integration request is the same as for the connection route. The only changes are the DynamoDB action, which becomes DeleteItem, and the reference to the new deletion request template:

# components/gateway/disconnect/integration.tf
resource "aws_apigatewayv2_integration" "ack_absence" {
  api_id                 = var.api_id
  description            = "Acknowledge user absence"

  passthrough_behavior   = "WHEN_NO_MATCH"
  payload_format_version = "1.0"

  # Upstream
  integration_type     = "AWS"
  integration_uri      = "arn:aws:apigateway:${data.aws_region.current.name}:dynamodb:action/DeleteItem"
  connection_type      = "INTERNET"
  credentials_arn      = var.credentials_arn
  integration_method   = "POST"
  timeout_milliseconds = 29000

  request_templates = {
    "application/json" = file("${path.module}/ddb_delete.json")
  }

  lifecycle {
    ignore_changes = [
      passthrough_behavior
    ]
  }
}

When the connection is terminated, gracefully or not, API Gateway triggers the disconnection route, so let's create it. This time we don't need the authorizer because, as we said, it's only needed for the connection route:

# components/gateway/disconnect/route.tf
resource "aws_apigatewayv2_route" "disconnect" {
  api_id         = var.api_id

  # UPSTREAM
  target         = "integrations/${aws_apigatewayv2_integration.ack_absence.id}"
  route_key      = "$disconnect"
  operation_name = "Acknowledge user absence"

  # AUTHORIZATION
  authorization_type = "NONE"
  api_key_required   = false

  route_response_selection_expression = "$default"
}

For graceful connection termination, we need the disconnect route to return responses to clients:

# components/gateway/disconnect/response.tf
resource "aws_apigatewayv2_integration_response" "bye" {
  api_id                   = var.api_id
  integration_id           = aws_apigatewayv2_integration.ack_absence.id
  integration_response_key = "/200/"
}

resource "aws_apigatewayv2_route_response" "bye" {
  api_id             = var.api_id
  route_id           = aws_apigatewayv2_route.disconnect.id
  route_response_key = "$default"
}

Keep Alive route

To ensure clients' connections are not considered idle by API Gateway after the 10-minute timeout, we will implement a ping/pong mechanism that serves as a keepalive and as a means to verify that the remote client is still responsive. Clients may send a ping request periodically after the connection is established and before it is closed.

To start, we need two templates: a request template for the MOCK service that receives the ping request, and a response template for the actual pong response sent to the client:

# components/gateway/keepalive/ping.json
{
  "statusCode" : 200
}
# components/gateway/keepalive/pong.json
{
  "type": "pong",
  "statusCode" : 200,
  "connectionId" : "$context.connectionId"
}

We need an integration to a MOCK service that will receive the ping request from the ping route and return HTTP_OK status code:

# components/gateway/keepalive/integration.tf
resource "aws_apigatewayv2_integration" "ping" {
  api_id      = var.api_id
  description = "Receive ping frame from client"

  # Upstream
  integration_type = "MOCK"

  template_selection_expression = "200"
  request_templates = {
    "200" = file("${path.module}/ping.json")
  }

  lifecycle {
    ignore_changes = [
      passthrough_behavior
    ]
  }
}

We also need a route to receive the ping request from the client and forward it to the MOCK integration:

# components/gateway/keepalive/route.tf
resource "aws_apigatewayv2_route" "ping" {
  api_id = var.api_id

  # UPSTREAM
  target         = "integrations/${aws_apigatewayv2_integration.ping.id}"
  route_key      = "ping"
  operation_name = "Ping websocket server"

  # AUTHORIZATION
  authorization_type = "NONE"
  api_key_required   = false

  route_response_selection_expression = "$default"
}

Upon receipt of a ping message, API Gateway should reply with a pong response as soon as is practical, unless it has already received a disconnect request:

# components/gateway/keepalive/response.tf
resource "aws_apigatewayv2_integration_response" "ping" {
  api_id                   = var.api_id
  integration_id           = aws_apigatewayv2_integration.ping.id
  integration_response_key = "/200/" # must be /XXX/ or $default

  template_selection_expression = "200"
  response_templates            = {
    "200" = file("${path.module}/pong.json")
  }
}

resource "aws_apigatewayv2_route_response" "ping" {
  api_id             = var.api_id
  route_id           = aws_apigatewayv2_route.ping.id
  route_response_key = "$default" # must be default
}

Thanks to the ping route we've created, users can just send this ping websocket message periodically to keep their connection active:

{"action": "ping"}

JSON messages sent to API Gateway websocket routes should contain an action that matches the target route; otherwise they are sent to the default route.
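To make the routing rule concrete, here is a rough approximation in plain Python of how the $request.body.action selection expression maps an incoming frame to a route key. The select_route helper and the KNOWN_ROUTES set are hypothetical; the real evaluation happens inside API Gateway:

```python
import json

# Custom route keys defined in this article ($connect/$disconnect are lifecycle routes).
KNOWN_ROUTES = {"ping", "publish", "send"}


def select_route(raw_message: str) -> str:
    """Approximate how $request.body.action picks a route key."""
    try:
        body = json.loads(raw_message)
    except json.JSONDecodeError:
        # Non-JSON frames cannot match any action.
        return "$default"
    action = body.get("action") if isinstance(body, dict) else None
    if isinstance(action, str) and action in KNOWN_ROUTES:
        return action
    return "$default"


print(select_route('{"action": "ping"}'))
print(select_route('{"action": "unknown"}'))
print(select_route("not json"))
```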

Publish route

We have dealt with the connection lifecycle routes, and now we need to create the actual messaging routes, which receive users' messages and publish them to backend applications. Before starting, note that in an event-driven architecture it's crucial to agree on a format that messages should respect:

{
  "action": "publish",
  "message": {
    "type": "type of the message (eg: call_op, send_message ...)",
    "message": {
      "the actual": "message"
    }
  }
}
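A client could assemble such a frame with a small helper. This is only a sketch: build_publish_frame is a hypothetical name, and the message type and payload used in the example are made up:

```python
import json


def build_publish_frame(message_type: str, payload: dict) -> str:
    """Wrap a payload in the agreed envelope, ready to send over the websocket."""
    return json.dumps({
        "action": "publish",  # must match the route key
        "message": {
            "type": message_type,
            "message": payload,
        },
    })


# Hypothetical message type and payload.
frame = build_publish_frame("send_message", {"text": "hello"})
print(frame)
```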

To publish a message, the action should always be set to publish, because that is the route key we will use. But first, let's create our integration, and this time we will do things differently: we will publish every received message to an SNS topic, so external applications interested in the message can subscribe to the topic:

# components/gateway/publish/integration.tf
resource "aws_apigatewayv2_integration" "publish" {
  api_id                 = var.api_id
  description            = "Publish websocket message through SNS"

  passthrough_behavior   = "WHEN_NO_MATCH"
  payload_format_version = "1.0"

  # Upstream
  integration_type     = "AWS"
  integration_uri      = "arn:aws:apigateway:${data.aws_region.current.name}:sns:action/Publish"
  connection_type      = "INTERNET"
  credentials_arn      = var.credentials_arn
  integration_method   = "POST"
  timeout_milliseconds = 5000

  request_parameters = {
    "integration.request.querystring.TopicArn" = "'${var.messages_topic_arn}'"
    "integration.request.querystring.Message"  = "route.request.body.message"
    # Sender ID Attribute
    "integration.request.querystring.MessageAttributes.entry.1.Name"              = "'user_id'",
    "integration.request.querystring.MessageAttributes.entry.1.Value.DataType"    = "'String'",
    "integration.request.querystring.MessageAttributes.entry.1.Value.StringValue" = "context.authorizer.principalId",
    # Message Timestamp Attribute
    "integration.request.querystring.MessageAttributes.entry.2.Name"              = "'timestamp'",
    "integration.request.querystring.MessageAttributes.entry.2.Value.DataType"    = "'Number'",
    "integration.request.querystring.MessageAttributes.entry.2.Value.StringValue" = "context.requestTimeEpoch",
    # Message Source Attribute
    "integration.request.querystring.MessageAttributes.entry.3.Name"              = "'source'",
    "integration.request.querystring.MessageAttributes.entry.3.Value.DataType"    = "'String'",
    "integration.request.querystring.MessageAttributes.entry.3.Value.StringValue" = "'apigw.route.publish'",
  }

  lifecycle {
    ignore_changes = [
      passthrough_behavior
    ]
  }
}

Same as DynamoDB, we've set the integration uri to the SNS Publish action and provided the integration with the IAM role credentials ARN allowed to call that action.

This time we didn't provide a request template, because SNS does not allow messages to be sent in the request body and requires API Gateway to send them in the request query string. For that, we used the request_parameters attribute to map API Gateway request attributes to the SNS request query string.

For every websocket message, the user_id attribute is added to the message attributes, so external applications can identify the message originator in a safe and secure way. We also map the message timestamp at AWS API Gateway, because backend applications should not trust timestamps provided by clients and should only trust the ones generated by API Gateway.
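On the consuming side, a backend Lambda subscribed to the topic can read those attributes from the SNS record it receives. The sketch below assumes the standard SNS-to-Lambda record shape and assumes the message body arrives as a JSON string; read_sns_message and the sample record are hypothetical:

```python
import json


def read_sns_message(record: dict) -> dict:
    """Extract the gateway-set attributes and the client payload from one SNS record."""
    sns = record["Sns"]
    attributes = sns["MessageAttributes"]
    return {
        # Set by API Gateway from the authorizer context, not by the client.
        "user_id": attributes["user_id"]["Value"],
        # requestTimeEpoch is expressed in milliseconds.
        "timestamp_ms": int(attributes["timestamp"]["Value"]),
        "source": attributes["source"]["Value"],
        # Assumption: the message body was serialized as JSON.
        "message": json.loads(sns["Message"]),
    }


# Hand-written sample record, trimmed to the fields used above.
sample_sns_record = {
    "Sns": {
        "Message": '{"type": "send_message", "message": {"text": "hello"}}',
        "MessageAttributes": {
            "user_id": {"Type": "String", "Value": "user-123"},
            "timestamp": {"Type": "Number", "Value": "1637107200000"},
            "source": {"Type": "String", "Value": "apigw.route.publish"},
        },
    }
}
print(read_sns_message(sample_sns_record)["user_id"])
```

Reading user_id from the attributes rather than from the message body is the point of the mapping: the body is client-controlled, the attributes are not.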

After setting up the integration, now we need a route to receive messages from users and publish them to the SNS integration:

# components/gateway/publish/route.tf
resource "aws_apigatewayv2_route" "publish" {
  api_id         = var.api_id

  # UPSTREAM
  target         = "integrations/${aws_apigatewayv2_integration.publish.id}"
  route_key      = "publish"
  operation_name = "Publish websocket message through SNS"

  # AUTHORIZATION
  authorization_type = "NONE"
  api_key_required   = false

  route_response_selection_expression = "$default"
}

Finally, we will create the integration response and the route response:

# components/gateway/publish/response.tf
resource "aws_apigatewayv2_integration_response" "publish" {
  api_id                   = var.api_id
  integration_id           = aws_apigatewayv2_integration.publish.id
  integration_response_key = "/200/"
}

resource "aws_apigatewayv2_route_response" "publish" {
  api_id             = var.api_id
  route_id           = aws_apigatewayv2_route.publish.id
  route_response_key = "$default"
}

Send route

As we agreed earlier, SUMU should also support sending clients' messages to an SQS queue for batch processing, and for that we need to integrate the AWS API Gateway Websocket API with SQS.

We will follow the same steps as for the SNS integration, but this time the SQS query parameters have different names, and the action is SendMessage instead of Publish.

# components/gateway/send/integration.tf
resource "aws_apigatewayv2_integration" "send" {
  api_id                 = var.api_id
  description            = "Send websocket message through SQS"

  passthrough_behavior   = "WHEN_NO_MATCH"
  payload_format_version = "1.0"

  # Upstream
  integration_type     = "AWS"
  integration_uri      = "arn:aws:apigateway:${data.aws_region.current.name}:sqs:action/SendMessage"
  connection_type      = "INTERNET"
  credentials_arn      = var.credentials_arn
  integration_method   = "POST"
  timeout_milliseconds = 5000

  request_parameters = {
    "integration.request.querystring.QueueUrl"    = "'${var.messages_queue_url}'"
    "integration.request.querystring.MessageBody" = "route.request.body.message"
    # Sender ID Attribute
    "integration.request.querystring.MessageAttributes.1.Name"              = "'user_id'"
    "integration.request.querystring.MessageAttributes.1.Value.DataType"    = "'String'"
    "integration.request.querystring.MessageAttributes.1.Value.StringValue" = "context.authorizer.principalId"
    # Message Timestamp Attribute
    "integration.request.querystring.MessageAttributes.2.Name"              = "'timestamp'"
    "integration.request.querystring.MessageAttributes.2.Value.DataType"    = "'Number'"
    "integration.request.querystring.MessageAttributes.2.Value.StringValue" = "context.requestTimeEpoch"
    # Message Source Attribute
    "integration.request.querystring.MessageAttributes.3.Name"              = "'source'"
    "integration.request.querystring.MessageAttributes.3.Value.DataType"    = "'String'"
    "integration.request.querystring.MessageAttributes.3.Value.StringValue" = "'apigw.route.send'"
  }

  lifecycle {
    ignore_changes = [
      passthrough_behavior
    ]
  }
}

Same as the publish messaging route, we need a send route:

# components/gateway/send/route.tf
resource "aws_apigatewayv2_route" "send" {
  api_id         = var.api_id

  # UPSTREAM
  target         = "integrations/${aws_apigatewayv2_integration.send.id}"
  route_key      = "send"
  operation_name = "Send websocket message through SQS"

  # AUTHORIZATION
  authorization_type = "NONE"
  api_key_required   = false

  route_response_selection_expression = "$default"
}

And integration/route responses:

resource "aws_apigatewayv2_integration_response" "send" {
  api_id                   = var.api_id
  integration_id           = aws_apigatewayv2_integration.send.id
  integration_response_key = "/200/"
}

resource "aws_apigatewayv2_route_response" "send" {
  api_id             = var.api_id
  route_id           = aws_apigatewayv2_route.send.id
  route_response_key = "$default"
}
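A backend Lambda polling the queue sees the same information in a slightly different shape, since SQS records delivered to Lambda use camelCase keys. A minimal sketch, assuming the standard SQS-to-Lambda record shape and a JSON message body; read_sqs_message and the sample record are hypothetical:

```python
import json


def read_sqs_message(record: dict) -> dict:
    """Extract the gateway-set attributes and the client payload from one SQS record."""
    attributes = record["messageAttributes"]
    return {
        "user_id": attributes["user_id"]["stringValue"],
        "timestamp_ms": int(attributes["timestamp"]["stringValue"]),
        "source": attributes["source"]["stringValue"],
        # Assumption: the message body was serialized as JSON.
        "message": json.loads(record["body"]),
    }


# Hand-written sample record, trimmed to the fields used above.
sample_sqs_record = {
    "body": '{"type": "call_op", "message": {"op": "refresh"}}',
    "messageAttributes": {
        "user_id": {"dataType": "String", "stringValue": "user-123"},
        "timestamp": {"dataType": "Number", "stringValue": "1637107200000"},
        "source": {"dataType": "String", "stringValue": "apigw.route.send"},
    },
}
print(read_sqs_message(sample_sqs_record)["source"])
```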

Deploying the Gateway

After preparing all routes, the only remaining step is to deploy our API.

# components/gateway/stages.tf
resource "aws_apigatewayv2_stage" "_" {
  name        = var.stage_name
  api_id      = aws_apigatewayv2_api._.id
  description = "Default Stage"
  auto_deploy = true

  access_log_settings {
    format          = jsonencode(local.access_logs_format)
    destination_arn = aws_cloudwatch_log_group.access.arn
  }

  default_route_settings {
    logging_level          = null
    throttling_burst_limit = 5000
    throttling_rate_limit  = 10000
  }

  lifecycle {
    ignore_changes = [
      deployment_id,
    ]
  }
}

The stage is set to auto deploy, so changes to routes and integrations are deployed automatically and we don't have to trigger a deployment after each change.

Presence Watchdog

Source: Presence Watchdog


Many applications (e.g. chat and gaming applications) require user presence tracking: they want to be notified whenever a user goes offline and whenever that user comes back online.

These presence events can then be dispatched by backend applications to all other users or a subset of users (eg, Friends, Groups). To provide this capability we need:

  • Connections/Disconnections stream: a stream of INSERT and DELETE events from the DynamoDB connections table.
  • Presence watchdog: a Lambda Function that listens to the stream and decides whether the user is offline or still online.
  • Presence Events Source: where the watchdog will publish/send presence events. We will use the already created messages topic/queue. External applications can choose a topic or a queue as the presence events' source, then either subscribe to the messages topic or poll the queue for presence events, and broadcast the events to other parties.

We already have the DynamoDB stream enabled and the SNS topic in place, so let's create the presence watchdog Lambda Function:

# components/presence/sources/main.py
import json
import logging
import os
import time

import boto3
from boto3.dynamodb.conditions import Key

logger = logging.getLogger()
logger.setLevel(logging.WARNING)
PRESENCE_SOURCE = os.environ.get("PRESENCE_SOURCE", "queue")


def user_still_online(user_id: str) -> bool:
    """
    Check whether the user still has at least one active connection
    :param user_id: user principal id
    :return: True if still online, False if disconnected from all devices
    """
    resource = boto3.resource("dynamodb")
    connections_table = resource.Table(os.environ["CONNECTIONS_TABLE"])
    active_connections = connections_table.query(
        KeyConditionExpression=Key("user_id").eq(user_id)
    ).get("Items", [])
    return len(active_connections) > 0


def publish_user_presence(user_id: str, present: bool = True, event_time: float = 0):
    """
    Notify online/offline events
    :param user_id: user principal id
    :param present: True if online False if offline
    :param event_time: useful for precedence check if user
    connects/disconnects rapidly and events came unordered
    """
    event = json.dumps({
        "type": "presence",
        "message": {
            "user_id": user_id,
            "status": "ONLINE" if present else "OFFLINE",
            "timestamp": event_time
        }
    })
    attributes = dict(
        source={"DataType": "String", "StringValue": "lambda.presence.watchdog", },
        user_id={"DataType": "String", "StringValue": user_id, },
        timestamp={"DataType": "Number", "StringValue": f"{int(time.time())}", },
    )
    if PRESENCE_SOURCE == "topic":
        boto3.client("sns").publish(
            TargetArn=os.environ.get("MESSAGES_TOPIC_ARN"),
            Message=event,
            MessageAttributes=attributes
        )
    elif PRESENCE_SOURCE == "queue":
        boto3.client("sqs").send_message(
            QueueUrl=os.environ.get("MESSAGES_QUEUE_URL"),
            MessageBody=event,
            MessageAttributes=attributes
        )
    else:
        print("Subscribe to presence directly from DynamoDB stream")


def handler(event, context):
    print(event)
    try:
        for record in event["Records"]:
            event_time = record["dynamodb"]["ApproximateCreationDateTime"]
            user_id = record["dynamodb"]["Keys"]["user_id"]["S"]
            if record["eventName"] == "INSERT":
                print(f"user {user_id} is online, notify!")
                publish_user_presence(user_id, True, event_time)
            elif record["eventName"] == "REMOVE":
                print(f"user {user_id} gone offline!, check other user devices...")
                if not user_still_online(user_id):
                    print(f"user {user_id} gone offline from all devices!, notify!")
                    publish_user_presence(user_id, False, event_time)
                else:
                    print(f"user {user_id} still online on other devices, skip!")
    except Exception as error:
        logger.exception(error)

When a user connects, API Gateway adds the connection to the connections table. DynamoDB then emits an INSERT stream record that is caught by our watchdog, and the watchdog fans out an online event through the SNS Topic or SQS Queue, depending on which presence source the external application has chosen.

When a user disconnects, API Gateway removes the connection from the connections table. DynamoDB then emits a REMOVE stream record that is caught by our watchdog. If the user still has active connections on other devices, the watchdog just skips the event; if it was the last connection, the watchdog fans out an offline event to the SNS Topic or SQS Queue.
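Stripped of the AWS calls, the watchdog's decision rule fits in a few lines. The sketch below is a simplification for illustration: presence_event is a hypothetical helper, and remaining_connections stands in for the result of the DynamoDB query:

```python
from typing import Optional


def presence_event(event_name: str, remaining_connections: int) -> Optional[str]:
    """Return the presence status to broadcast, or None when nothing should be sent."""
    if event_name == "INSERT":
        # Every new connection is announced as an online event.
        return "ONLINE"
    if event_name == "REMOVE" and remaining_connections == 0:
        # The last connection was closed: the user is offline everywhere.
        return "OFFLINE"
    # The user is still connected from another device, or the event is not relevant.
    return None


print(presence_event("INSERT", 1))
print(presence_event("REMOVE", 2))
print(presence_event("REMOVE", 0))
```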

For the Lambda function to receive connection/disconnection stream records, we have to subscribe it to the DynamoDB stream. To make sure the presence Lambda does not retry indefinitely on a malformed event, we set maximum_retry_attempts to 5 and configure the source mapping to ignore old events.

# components/presence/stream.tf
resource "aws_lambda_event_source_mapping" "presence_stream" {
  enabled                = true
  event_source_arn       = var.connections_table.stream_arn
  function_name          = aws_lambda_function.function.arn

  starting_position             = "LATEST"
  maximum_retry_attempts        = 5 # Retry up to five times
  maximum_record_age_in_seconds = 60 # Ignore offline/online events older than 1 minute
}

We also need to add the required DynamoDB/SNS/SQS permissions to the Lambda Function role so it can Query DynamoDB, Publish to the SNS topic and SendMessage to the SQS queue:

# components/presence/iam.tf
data "aws_iam_policy_document" "custom_policy_doc" {
  statement {
    effect = "Allow"
    actions = [
      "dynamodb:DescribeStream",
      "dynamodb:GetRecords",
      "dynamodb:GetShardIterator",
      "dynamodb:ListStreams"
    ]
    resources = [
      "*",
    ]
  }

  statement {
    actions = [
      "dynamodb:Query",
    ]

    resources = [
      var.connections_table.arn,
    ]
  }

  statement {
    actions = [
      "sns:Publish",
    ]

    resources = [
      var.messages_topic.arn
    ]
  }

  statement {
    actions = [
      "sqs:SendMessage",
    ]

    resources = [
      var.messages_queue.arn
    ]
  }
}

Finally, we create our Lambda Function:

data "archive_file" "lambda_zip" {
  type        = "zip"
  output_path = "${path.module}/sources/dist.zip"
  source_dir  = "${path.module}/sources"
}
resource "aws_lambda_function" "function" {
  function_name = local.prefix
  role          = aws_iam_role.role.arn
  # runtime
  runtime = var.runtime
  handler = var.handler
  # resources
  memory_size = var.memory_size
  timeout     = var.timeout
  # package
  filename         = data.archive_file.lambda_zip.output_path
  source_code_hash = data.archive_file.lambda_zip.output_base64sha256

  environment {
    variables = {
      CONNECTIONS_TABLE  = var.connections_table.name
      MESSAGES_TOPIC_ARN = var.messages_topic.arn
      MESSAGES_QUEUE_URL = var.messages_queue.url
      PRESENCE_SOURCE    = var.presence_source
    }
  }

  tags = merge(
    local.common_tags,
    {description = var.description}
  )
  depends_on = [data.archive_file.lambda_zip]
}

Websocket Messages Pusher

Source: AWS API Gateway Websocket Pusher


We have built the Websocket API for clients to send messages to backend applications. In addition, we need a Websocket Notifications Pusher that lets backend applications push messages to connected clients. To achieve this we need:

  • Notifications SNS Topic: where external applications will publish the notification requests.
  • Notifications SQS Queue: where external applications will send notification requests.
  • Connections Table: a table with all active users' connections.
  • Python AsyncIO Lambda Function: an async lambda function that will receive notification requests from SNS/SQS and send the notifications to users asynchronously.

We already have the SNS Topic, SQS Queue and the Connections Table, so we only need to build the Lambda function, subscribe it to SNS Topic and hook it to SQS event source.

We've picked Python AsyncIO because we want to take advantage of the event loop, which comes with two benefits:

  • Fast Notification: we can't instruct AWS API Gateway to send a message to multiple users, because API Gateway forces us to post a message to one connection at a time. Likewise, we can't fetch the connections of a list of users from DynamoDB in one call, as we can query only by a single user_id. post_to_connection and query are I/O operations, and we don't want to block our pusher while awaiting each response from API Gateway and DynamoDB. With the AsyncIO event loop, however, we can issue many requests concurrently without waiting for each one to finish, so notifications are delivered faster than with sequential processing.

  • Cost effective: with sequential processing, you pay AWS for both the compute time and the time spent waiting between each post_to_connection and query request. With async processing, the waits overlap, so the billed Lambda duration is much closer to the actual compute time.
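The difference between the two approaches can be illustrated without any AWS dependency. In the sketch below, asyncio.sleep stands in for the network round trip of a single post; fake_post, the 50 ms delay and the connection ids are all made up for the illustration:

```python
import asyncio
import time


async def fake_post(connection_id: str, delay: float = 0.05) -> str:
    """Stand-in for one I/O-bound call such as posting to a connection."""
    await asyncio.sleep(delay)
    return connection_id


async def notify_sequentially(connection_ids):
    # Each call waits for the previous one: total time grows with the number of calls.
    return [await fake_post(connection_id) for connection_id in connection_ids]


async def notify_concurrently(connection_ids):
    # All calls are in flight at once: total time is roughly one round trip.
    return list(await asyncio.gather(*(fake_post(c) for c in connection_ids)))


def timed(coroutine_function, connection_ids):
    """Run a coroutine function to completion and measure the elapsed wall time."""
    started = time.perf_counter()
    result = asyncio.run(coroutine_function(connection_ids))
    return result, time.perf_counter() - started


ids = ["conn-a", "conn-b", "conn-c", "conn-d"]
_, sequential_seconds = timed(notify_sequentially, ids)
_, concurrent_seconds = timed(notify_concurrently, ids)
print(f"sequential: {sequential_seconds:.2f}s, concurrent: {concurrent_seconds:.2f}s")
```

With four simulated calls, the sequential run takes about four delays while the concurrent run takes about one, which is the effect the pusher relies on for both speed and billed duration.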

AWS: If you develop an AWS Lambda function with Node.js, you can call multiple web services without waiting for a response due to its asynchronous nature. All requests are initiated almost in parallel, so you can get results much faster than a series of sequential calls to each web service. Considering the maximum execution duration for Lambda, it is beneficial for I/O bound tasks to run in parallel.

We will follow AWS recommendation, but we will do it the Python way 😎.

import asyncio
import json
import os
import time
from datetime import timedelta
from typing import List, Set

import aioboto3
from boto3.dynamodb.conditions import Key
from botocore.exceptions import ClientError

session = aioboto3.Session()


def handler(event, context):
    """
    :param event: SNS or SQS event
    :param context: Lambda Context
    :return: AsyncIO loop
    """
    print(event)
    parser_func = parse_sns_record if "Sns" in event["Records"][0] else parse_sqs_record
    pusher = Pusher(
        event["Records"],
        parser_func
    )
    loop = asyncio.get_event_loop()
    return loop.run_until_complete(pusher.notify_all_records())


class Pusher:
    """
    Asynchronous Batch Notifications Pusher
    """

    def __init__(self, records: List, parse_record_func):
        """
        :param records: SQS Records (Notifications Tasks)
        """
        self.endpoint_url = os.environ["APIGW_ENDPOINT"]
        self.connections_table_name = os.environ["CONNECTIONS_TABLE_NAME"]
        self.records = records
        self.start_time = time.time()
        self.stale_connections = []
        self.total_notified_connections = 0
        self.deleted_stale_connections = 0
        self.parse_record_func = parse_record_func

    @staticmethod
    async def retrieve_all_users_connections(table, exclude_users_ids: List[str]):
        """
        Coroutine to retrieve the connections of all users, minus the excluded ones
        :param table: connections table
        :param exclude_users_ids: list of user ids to exclude
        :return: List of connections
        """
        params = get_exclusion_filter(exclude_users_ids)
        result = await table.scan(**params)
        connections = result.get("Items")
        while result.get("LastEvaluatedKey"):
            result = await table.scan(ExclusiveStartKey=result["LastEvaluatedKey"], **params)
            connections.extend(result["Items"])
        return connections

    @staticmethod
    async def retrieve_user_connections(table, user_id: str):
        """
        Coroutine to retrieve single user connections
        :param table: connections table
        :param user_id: the user id (Hash Key)
        :return: List of connections
        """
        result = await table.query(
            KeyConditionExpression=Key("user_id").eq(user_id)
        )
        return result.get("Items", [])

    async def notify_connection(self, apigw, user_id: str, connection_id: str, data: str):
        """
        Coroutine to notify single user connection
        :param apigw: APIGatewayManagementAPI client
        :param user_id: the user id (Hash Key)
        :param connection_id: API Gateway connection id
        :param data: binary data
        """
        try:
            await apigw.post_to_connection(
                Data=data,
                ConnectionId=connection_id
            )
            self.total_notified_connections += 1
        except ClientError as error:
            if error.response['Error']['Code'] == 'GoneException':
                self.stale_connections.append(
                    {'user_id': user_id, 'connection_id': connection_id}
                )

            else:
                print(error)

    async def notify_user(self, table, apigw, user_id: str, data):
        """
        Coroutine to notify all connections of a single user
        :param table: connections table
        :param apigw: APIGatewayManagementAPI client
        :param user_id: user_id
        :param data: binary data
        :return: None (notifications are sent as a side effect)
        """
        connections = await self.retrieve_user_connections(table, user_id)
        # If user has active connections (Online), then notify.
        if len(connections):
            notifications = [
                self.notify_connection(
                    apigw, user_id, connection["connection_id"], data
                )
                for connection in connections
            ]

            # asyncio.wait() no longer accepts bare coroutines on Python 3.11+
            await asyncio.gather(*notifications)

    async def notify_selected_users(self, table, apigw, users_ids: Set[str], data):
        """
        Coroutine to notify all connections of selected users
        :param table: connections table
        :param apigw: APIGatewayManagementAPI client
        :param users_ids: List of users' ids
        :param data: binary data to send to all users
        """
        notifications = [
            self.notify_user(table, apigw, user_id, data) for user_id in users_ids
        ]

        # asyncio.wait() no longer accepts bare coroutines on Python 3.11+
        await asyncio.gather(*notifications)

    async def notify_all_users(self, table, apigw, exclude_users_ids: List[str], data):
        """
        Coroutine to notify all connections of all users
        :param table: connections table
        :param apigw: APIGatewayManagementAPI client
        :param exclude_users_ids: ids of the users that must not be notified
        :param data: binary data to send to all users
        """
        connections = await self.retrieve_all_users_connections(table, exclude_users_ids)
        if len(connections):
            notifications = [
                self.notify_connection(
                    apigw, connection["user_id"], connection["connection_id"], data
                )
                for connection in connections
            ]

            # asyncio.wait() no longer accepts bare coroutines on Python 3.11+
            await asyncio.gather(*notifications)

    async def delete_all_stale_connections(self, table):
        """
        Coroutine to delete all stale connections
        :param table: connections table
        """
        async with table.batch_writer() as batch:
            for stale_connection in self.stale_connections:
                await batch.delete_item(Key=stale_connection)

    async def notify_all_records(self):
        """
        Coroutine to notify all connections of all or selected users in all SQS batch records
        """
        async with session.resource("dynamodb") as ddb:
            table = await ddb.Table(self.connections_table_name)
            async with session.client("apigatewaymanagementapi", endpoint_url=self.endpoint_url) as apigw:
                notifications = []
                for record in self.records:
                    users, exclude_users, data = self.parse_record_func(record)
                    if users:
                        notifications.append(self.notify_selected_users(table, apigw, users, data))
                    else:
                        notifications.append(self.notify_all_users(table, apigw, exclude_users, data))
                # asyncio.wait() no longer accepts bare coroutines on Python 3.11+
                await asyncio.gather(*notifications)
            await self.delete_all_stale_connections(table)
        await self.print_stats()

    async def print_stats(self):
        elapsed = (time.time() - self.start_time)
        total_elapsed_human = str(timedelta(seconds=elapsed))
        print(f"[STATS] Processed {len(self.records)} SQS records")
        print(f"[STATS] Notified {self.total_notified_connections} connections")
        print(f"[STATS] Finished in {total_elapsed_human}")
        print(f"[STATS] Deleted {len(self.stale_connections)} stale connections")


####################
# Helpers Functions
####################
def get_unique_users(users: List[str]):
    return set(users or [])


def parse_sqs_record(record):
    body = json.loads(record["body"])
    users = get_unique_users(body.get("users", []))
    exclude_users = get_unique_users(body.get("exclude_users", []))
    data = json.dumps(body["data"])
    return users, exclude_users, data


def parse_sns_record(record):
    message = json.loads(record["Sns"]["Message"])
    users = get_unique_users(message.get("users", []))
    exclude_users = get_unique_users(message.get("exclude_users", []))
    data = json.dumps(message["data"])
    return users, exclude_users, data


def get_exclusion_filter(exclude_users_ids):
    if exclude_users_ids:
        excluded_users = ', '.join([f":id{idx}" for idx, _ in enumerate(exclude_users_ids)])
        return dict(
            ExpressionAttributeNames={
                "#user_id": "user_id"
            },
            FilterExpression=f"NOT(#user_id in ({excluded_users}))",
            ExpressionAttributeValues={f":id{idx}": user_id for idx, user_id in enumerate(exclude_users_ids)}
        )
    else:
        return {}
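
The concurrent fan-out and stale-connection bookkeeping in the Pusher above can be tried without any AWS resources. The following is only a minimal sketch: `GoneError`, `fake_post_to_connection` and `fan_out` are hypothetical stand-ins invented for illustration, and the sleep simulates the network round trip that `post_to_connection` would make.

```python
import asyncio


class GoneError(Exception):
    """Hypothetical stand-in for API Gateway's GoneException."""


async def fake_post_to_connection(connection_id: str, data: str, gone: set) -> None:
    # Simulates one network round trip; a real pusher would call
    # apigw.post_to_connection(ConnectionId=..., Data=...) here.
    await asyncio.sleep(0.01)
    if connection_id in gone:
        raise GoneError(connection_id)


async def fan_out(connection_ids: list, data: str, gone: set):
    notified, stale = [], []

    async def notify(connection_id: str) -> None:
        try:
            await fake_post_to_connection(connection_id, data, gone)
            notified.append(connection_id)
        except GoneError:
            # Remember the dead connection so it can be deleted afterwards
            stale.append(connection_id)

    # gather() schedules every coroutine at once, so the total time is
    # roughly one round trip instead of the sum of all round trips.
    await asyncio.gather(*(notify(c) for c in connection_ids))
    return sorted(notified), sorted(stale)


def run_demo():
    return asyncio.run(fan_out(["c1", "c2", "c3"], '{"hello": "world"}', gone={"c2"}))
```

Calling `run_demo()` notifies `c1` and `c3` and reports `c2` as stale, which mirrors how the Pusher collects `stale_connections` before deleting them in one batch.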

Let's deep dive into what we have built and what it can offer:

  • Multiple sources: it receives notification requests from both SQS and SNS, which makes it suitable for processing a single notification request from SNS as well as a high volume of notification requests in batches when polling from SQS.

  • Multicast Notifications: the pusher can send messages to only a subset of users (e.g. users in a chat room).

  • Broadcast Notifications: sometimes you just want to send the same message to all connected users (e.g. broad announcements).

  • Exclusion Notifications: the pusher can broadcast messages to all users except a list of excluded users (e.g. online/offline presence events can be sent to everyone except the originator).

  • Stale Connections Pruning: it detects stale connections and deletes them from the DynamoDB connections store in case API Gateway missed cleaning them.

  • Asynchronous Processing: the pusher uses AsyncIO to notify multiple users/connections concurrently instead of waiting for each in-flight request to DynamoDB/API Gateway in turn, so you don't pay AWS for the waiting time 😎

  • Batch Processing: when SQS is the event source, the Pusher processes batches of notification requests, also concurrently.

  • Duplicate Users Detection: it detects duplicate users in a notification request and reduces them to a unique set of users, to avoid double notifications.
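
To make the multicast, broadcast and exclusion modes concrete, here is a small sketch of the request bodies an external application could put on the input queue or topic. The `users`, `exclude_users` and `data` keys come from the parsers above; `build_request` and `delivery_mode` are hypothetical helpers written for this illustration and are not part of SUMU.

```python
import json
from typing import Iterable, Optional


def build_request(data: dict,
                  users: Optional[Iterable[str]] = None,
                  exclude_users: Optional[Iterable[str]] = None) -> str:
    """Serialize a notification request in the shape the pusher parsers read."""
    body = {"data": data}
    if users:
        body["users"] = list(users)
    if exclude_users:
        body["exclude_users"] = list(exclude_users)
    return json.dumps(body)


def delivery_mode(raw_body: str) -> str:
    """Mirror the pusher's decision: selected users first, otherwise everyone."""
    body = json.loads(raw_body)
    users = set(body.get("users") or [])              # duplicates collapse here
    exclude_users = set(body.get("exclude_users") or [])
    if users:
        return f"multicast to {len(users)} user(s)"
    if exclude_users:
        return f"broadcast except {len(exclude_users)} user(s)"
    return "broadcast"
```

For example, `delivery_mode(build_request({"text": "hi"}, users=["u1", "u2", "u1"]))` reports a multicast to 2 users because the duplicate `u1` is collapsed, while a request without `users` falls through to a broadcast.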

Now that we have the pusher Lambda function, we have to configure it to receive messages from SQS/SNS. For SQS we need an event source mapping. We've set the batch size to 10 and maximum_batching_window_in_seconds to 0, so the Lambda starts consuming as soon as messages are available instead of waiting for a full batch of 10:

# components/pusher/sqs_event_source.tf
resource "aws_lambda_event_source_mapping" "_" {
  enabled                            = true
  batch_size                         = 10
  event_source_arn                   = var.notifications_queue_arn
  function_name                      = module.pusher.lambda["alias_arn"]
  maximum_batching_window_in_seconds = 0 # Do not wait until batch size is fulfilled
}
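
With this mapping in place, a single invocation can carry up to 10 SQS records, whereas an SNS invocation typically carries one. The sketch below shows the two event shapes and how the handler's `"Sns" in event["Records"][0]` check selects the parser. The events are trimmed to the fields the parsers read, and `pick_parser` and `parse_event` are hypothetical names introduced here.

```python
import json


def parse_sqs_record(record):
    body = json.loads(record["body"])
    return (set(body.get("users") or []),
            set(body.get("exclude_users") or []),
            json.dumps(body["data"]))


def parse_sns_record(record):
    message = json.loads(record["Sns"]["Message"])
    return (set(message.get("users") or []),
            set(message.get("exclude_users") or []),
            json.dumps(message["data"]))


def pick_parser(event):
    # Only SNS records carry an "Sns" key, so the first record is enough to decide
    return parse_sns_record if "Sns" in event["Records"][0] else parse_sqs_record


def parse_event(event):
    parser = pick_parser(event)
    return [parser(record) for record in event["Records"]]


request = json.dumps({"users": ["u1"], "data": {"text": "hi"}})
broadcast = json.dumps({"data": {"text": "to everyone"}})

# Trimmed-down events: real ones include more metadata (ids, timestamps, ARNs)
sqs_event = {"Records": [
    {"eventSource": "aws:sqs", "body": request},
    {"eventSource": "aws:sqs", "body": broadcast},
]}
sns_event = {"Records": [
    {"EventSource": "aws:sns", "Sns": {"Message": request}},
]}
```

Both `parse_event(sqs_event)` and `parse_event(sns_event)` yield the same `(users, exclude_users, data)` tuples, which is what lets the Pusher treat the two sources uniformly.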

For SNS, we will create a subscription and give SNS permission to invoke the Lambda Function:

# components/pusher/sns_subscription.tf
resource "aws_sns_topic_subscription" "_" {
  topic_arn = var.notifications_topic_arn
  protocol  = "lambda"
  endpoint  = module.pusher.lambda["alias_arn"]
}

resource "aws_lambda_permission" "_" {
  statement_id  = "AllowExecutionFromSNS"
  action        = "lambda:InvokeFunction"
  function_name = module.pusher.lambda["arn"]
  qualifier     = module.pusher.lambda["alias"]
  principal     = "sns.amazonaws.com"
  source_arn    = var.notifications_topic_arn
}

In addition to the Lambda SQS event source, we need to give the Lambda access permissions to DynamoDB, API Gateway and SQS:

# components/pusher/iam.tf
data "aws_iam_policy_document" "custom_policy_doc" {
  statement {
    actions = [
      "dynamodb:Query",
      "dynamodb:Scan",
      "dynamodb:DeleteItem",
      "dynamodb:BatchWriteItem"
    ]

    resources = [
      var.connections_table.arn,
    ]
  }

  statement {
    actions = [
      "execute-api:ManageConnections",
    ]

    resources = [
      "${var.agma_arn}/*"
    ]
  }

  statement {
    actions = [
      "sqs:ChangeMessageVisibility",
      "sqs:ChangeMessageVisibilityBatch",
      "sqs:DeleteMessage",
      "sqs:DeleteMessageBatch",
      "sqs:GetQueueAttributes",
      "sqs:ReceiveMessage"
    ]

    resources = [
      var.notifications_queue_arn
    ]
  }
}

That's it, we have built a fast asynchronous AWS API Gateway websocket notifications pusher.

Expose it!

Source: AWS API Gateway APIs Exposer


After deploying SUMU, we need to expose it to the outer world with a beautiful domain name instead of the ugly one generated by AWS. For that, we will use the module from Part 3.

We need these prerequisites before exposing our API:

  • An AWS Route 53 or Cloudflare zone.
  • An AWS ACM Certificate for the subdomain that we will use with our API.
  • An A record for the apex domain (the custom domain creation will fail otherwise).

If you already have these prerequisites, let's create our custom API Gateway domain, which will replace the default invoke URL provided by API Gateway:

module "gato" {
  source      = "git::https://github.com/obytes/terraform-aws-gato//modules/core-route53"
  prefix      = local.prefix
  common_tags = local.common_tags

  # DNS
  r53_zone_id = aws_route53_zone.prerequisite.zone_id
  cert_arn    = aws_acm_certificate.prerequisite.arn
  domain_name = "kodhive.com"
  sub_domains = {
    stateless = "api"
    statefull = "live"
  }

  # Rest APIS
  http_apis = []

  ws_apis = [
    {
      id    = module.sumu.ws_api_id
      key   = "push"
      stage = module.sumu.ws_api_stage_name
    }
  ]
}

We have created an API Gateway Mapping to map the deployed API Stage mvp to the custom domain name live.kodhive.com, and we have chosen push as the mapping key for SUMU. Since this is a Websocket API, clients connect with the wss scheme, so the exposed URL will be: wss://live.kodhive.com/push
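
As a quick illustration of how the pieces of the mapping combine, the sketch below assembles the endpoint from the subdomain, domain name and mapping key, and optionally appends the `authorization` query parameter that the demo client uses later in this article. `websocket_url` is a hypothetical helper, not something provided by the module.

```python
from typing import Optional
from urllib.parse import urlencode


def websocket_url(sub_domain: str, domain_name: str, mapping_key: str,
                  token: Optional[str] = None) -> str:
    """Build the public websocket endpoint for a custom domain mapping."""
    url = f"wss://{sub_domain}.{domain_name}/{mapping_key}"
    if token:
        # The demo client passes the JWT as an "authorization" query parameter
        url += "?" + urlencode({"authorization": token})
    return url
```

With the values used above, `websocket_url("live", "kodhive.com", "push")` gives the endpoint that clients should open.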

Demo Application

Source: AWS Sumu Demo


In this section, we are going to build a demo application implementing the following features:

  • Sign in using Google SSO with Firebase.
  • Connect to the SUMU websocket API Gateway.
  • Disconnect from the SUMU websocket API Gateway.
  • Send a periodic KeepAlive (Ping/Pong) to keep connections active.
  • Publish messages to the messages SNS Topic through API Gateway.
  • Send messages to the messages SQS Queue through API Gateway.
  • Retrieve the list of connected users.
  • Receive users' ONLINE/OFFLINE events and update the connected users list accordingly.
  • Implement multiple publish and send push modes:
    • UNICAST: sends an echo message that will be received by just the sender.
    • MULTICAST: sends a message to multiple selected users except the sender.
    • BROADCAST: sends a message to all connected users except the sender.

The live demo application is located at https://sumu.kodhive.com/.

Server

The server is a simple Lambda Function that processes SNS and SQS messages coming from the clients and sends notifications back to them. It handles presence events and publish/send events, and returns the list of connected users to requesters.

#
import json
import os
from typing import Callable, Any

import boto3


class Channels:

    def __init__(self, records):
        self.sns = boto3.client("sns")
        self.sqs = boto3.client("sqs")
        self.records = records

    def publish_message(self, message):
        self.sns.publish(
            TargetArn=os.environ.get("NOTIFICATIONS_TOPIC_ARN"),
            Message=json.dumps(message),
        )

    def send_message(self, message):
        self.sqs.send_message(
            QueueUrl=os.environ.get("NOTIFICATIONS_QUEUE_URL"),
            MessageBody=json.dumps(message),
        )

    def parse_message(self):
        if "Sns" in self.records[0]:
            value_key = "Value"
            payload = self.records[0]["Sns"]
            sxs_message = json.loads(payload["Message"])
            attributes = payload["MessageAttributes"]
            push = self.publish_message
        else:
            value_key = "stringValue"
            payload = self.records[0]
            sxs_message = json.loads(payload["body"])
            attributes = payload["messageAttributes"]
            push = self.send_message

        attrs = {key: item[value_key] for key, item in attributes.items()}
        return sxs_message, attrs, push


def handler(event, context):
    """
    Demo processor that does not have any real utility
    Just used to illustrate how backend applications can interact with SUMU
    :param event: events coming from SUMU integration (SNS or SQS)
    :param context: Lambda Context
    """
    print(event)
    channels = Channels(event["Records"])
    handle_event(*channels.parse_message())


def handle_presence(message: Any, originator_id):
    """
    Broadcast the presence event to all users except originator
    :param message: message
    :param originator_id: originator user
    """
    return {
        "exclude_users": [originator_id],
        "data": {
            "type": "presence",
            "message": message,
        }
    }


def handle_message(message: Any, message_timestamp: int, sender_id: str):
    """
    Send a chat message from originator client to receivers clients
    When push mode is:
    1 - UNICAST: echo message to its sender.
    2 - MULTICAST: send message to multiple clients except originators
    3 - BROADCAST: send message to all clients except originators
    :param message: message to push
    :param message_timestamp: messages timestamp
    :param sender_id: sender id
    :return:
    """
    # Decide push mode
    push_mode = message["push_mode"]
    users = None
    exclude_users = None
    if push_mode == "UNICAST":
        users = [sender_id]
    elif push_mode == "MULTICAST":
        try:
            users = message["users"]
        except KeyError:
            print("[WARNING] Push mode is multicast. however, no users provided. Skip!")
            return
    elif push_mode == "BROADCAST":
        exclude_users = [sender_id, ]
    return {
        "users": users,
        "exclude_users": exclude_users,
        "data": {
            "type": "message",
            "message": {
                "text": message["text"],
                "user_id": sender_id,
                "timestamp": message_timestamp
            }
        }
    }


def handle_connections_request(requester_id):
    """
    Get all distinct connected users except requester user
    :param requester_id: requester
    :return: message
    """
    connections = get_connections([requester_id])
    connected_users = list(set([connection["user_id"] for connection in connections]))
    return {
        "users": [requester_id],
        "data": {
            "type": "connected_users",
            "message": {
                "users": connected_users,
            }
        }
    }


def handle_event(sxs_message: Any, attributes: dict, push: Callable):
    """
    Presence/Routes events, could be:
    1 - coming from APIGW->DDB->Watchdog-SNS (Subscription)
    2 - coming from APIGW->DDB->Watchdog-SQS (Polling)
    3 - directly from APIGW->DDB (Stream)
    Broadcast the event to all users except originator
    1 - SNS->PUSHER-APIGW
    2 - SQS->PUSHER-APIGW
    :param sxs_message:
    :param attributes:
    :param push: channel push method
    """
    caller_id = attributes["user_id"]
    event_source = attributes["source"]
    message = sxs_message["message"]
    m_type = sxs_message["type"]
    m = None

    if event_source == "lambda.presence.watchdog" and m_type == "presence":
        # Handle events coming from presence watchdog
        m = handle_presence(message, caller_id)
    elif event_source.startswith("apigw.route."):
        # Handle events coming from "apigw.route.publish" and "apigw.route.send"
        if m_type == "message":
            m = handle_message(message, int(attributes["timestamp"]), caller_id)
        elif m_type == "get_connected_users":
            m = handle_connections_request(caller_id)
    if m:
        push(m)


def get_connections(exclude_users_ids):
    excluded_users = ', '.join([f":id{idx}" for idx, _ in enumerate(exclude_users_ids)])
    params = dict(
        ExpressionAttributeNames={
            "#user_id": "user_id"
        },
        FilterExpression=f"NOT(#user_id in ({excluded_users}))",
        ExpressionAttributeValues={f":id{idx}": user_id for idx, user_id in enumerate(exclude_users_ids)}
    )
    resource = boto3.resource("dynamodb")
    connections_table = resource.Table(os.environ["CONNECTIONS_TABLE"])
    result = connections_table.scan(**params)
    connections = result.get("Items")
    while result.get("LastEvaluatedKey"):
        result = connections_table.scan(ExclusiveStartKey=result["LastEvaluatedKey"], **params)
        connections.extend(result["Items"])
    return connections
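
One detail in `parse_message` is easy to miss: Lambda delivers message attributes under different keys depending on the source (`MessageAttributes` with `Value` for SNS, `messageAttributes` with `stringValue` for SQS). The sketch below isolates that normalization. `flatten_attributes` is a hypothetical helper, and the sample records are trimmed to the fields being read.

```python
def flatten_attributes(record: dict) -> dict:
    """Return message attributes as a plain {name: value} dict for SNS or SQS records."""
    if "Sns" in record:
        attributes, value_key = record["Sns"]["MessageAttributes"], "Value"
    else:
        attributes, value_key = record["messageAttributes"], "stringValue"
    return {name: attribute[value_key] for name, attribute in attributes.items()}


sns_record = {"Sns": {"MessageAttributes": {
    "user_id": {"Type": "String", "Value": "u1"},
    "source": {"Type": "String", "Value": "apigw.route.publish"},
}}}

sqs_record = {"messageAttributes": {
    "user_id": {"dataType": "String", "stringValue": "u1"},
    "source": {"dataType": "String", "stringValue": "apigw.route.send"},
}}
```

After flattening, `handle_event` can read `attributes["user_id"]` and `attributes["source"]` without caring which channel the event arrived on.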

Client

The client is a React Application that integrates with Firebase for authentication and with SUMU Websocket API:

// app/client/src/index.tsx
import React from 'react';
import ReactDOM from 'react-dom';
import App from './App';
import {AuthProvider} from "./context/AuthProvider";
import {SumuProvider} from "./context/SumuProvider";

ReactDOM.render(
  <React.StrictMode>
      <AuthProvider>
          <SumuProvider>
              <App />
          </SumuProvider>
      </AuthProvider>
  </React.StrictMode>,
  document.getElementById('root')
);

The auth provider is responsible for authenticating users with Firebase. It renders the Auth component while the user is not signed in; otherwise, it renders its child component, which is SumuProvider:

// app/client/src/context/AuthProvider.tsx
import React, {useState, useEffect} from 'react';
import {initializeApp} from 'firebase/app';
import {getAuth, signInWithPopup, GoogleAuthProvider, signOut, onAuthStateChanged} from "firebase/auth";
import {Spin, message} from 'antd';
import Auth from '../containers/Auth'

initializeApp({
    "apiKey": process.env.REACT_APP_FIREBASE_API_KEY,
    "authDomain": process.env.REACT_APP_FIREBASE_AUTH_DOMAIN,
    "projectId": process.env.REACT_APP_FIREBASE_PROJECT_ID,
    "measurementId": process.env.REACT_APP_FIREBASE_MEASUREMENT_ID
});


interface AuthContextData {
    user: any;
    loading: boolean;

    login(): void;

    logout(): void;
}

const initial = {
    user: null,
};

const AuthContext = React.createContext<AuthContextData>(initial as AuthContextData);
const auth = getAuth()

function AuthProvider({children}: any) {
    const [loading, setLoading] = useState<boolean>(true);
    const [bootstrapping, setBootstrapping] = useState<boolean>(true);
    const [user, setUser] = useState<any>(null);

    function login() {
        setLoading(true);
        const provider = new GoogleAuthProvider();
        signInWithPopup(auth, provider).then(function (result) {
            setUser(result.user);
            setLoading(false)
        }).catch(function (error: any) {
            console.log(error.message);
            message.error("Unable to sign in");
            setLoading(false)
        });
    }

    function logout() {
        signOut(auth).then(function () {
            // Sign-out successful.
        }).catch(function (error: any) {
            console.log(error)
        });
    }

    /** ======================
     *  Hooks
     ---------------------- */
    useEffect(() => {
        setBootstrapping(true);
        setLoading(true);
        onAuthStateChanged(auth, (user: any) => {
            setUser(user);
            setBootstrapping(false);
            setLoading(false)
        });
    }, []);

    return (
        <AuthContext.Provider value={
            {
                user: user,
                loading: loading,
                login: login,
                logout: logout
            }
        }>
            {
                user ? children : bootstrapping ?
                    <Spin spinning={loading} size="large" style={{
                        width: "100%",
                        height: "100vh",
                        lineHeight: "100vh"
                    }}/> :
                    <Auth/>
            }
        </AuthContext.Provider>
    )
}

const useAuth = () => React.useContext(AuthContext);
export {AuthProvider, useAuth}

SumuProvider is a wrapper around SUMU functions. It always renders the main application and provides it with the connect, disconnect, ping, send and publish functions, in addition to the connecting, connected and connectedUsers states:

// app/client/src/context/SumuProvider.tsx
import React, {useState, useEffect} from 'react'
import {message, notification} from 'antd';
import Sockette from "sockette";

import {useAuth} from "./AuthProvider";

interface SumuContextData {
    connecting: boolean;
    connected: boolean;
    connectedUsers: any;

    connect(): void;

    disconnect(raise?: boolean): void;

    ping(): void;

    send(type: string, message: {}): void;

    publish(type: string, message: {}): void;
}

const initial = {
    connecting: false,
    connected: false
};

const SumuContext = React.createContext<SumuContextData>(initial as SumuContextData);
let keepAliveInterval: any = null;

function SumuProvider({children}: any) {

    const [connected, setConnected] = useState<boolean>(false);
    const [connecting, setConnecting] = useState<boolean>(false);
    const [ws, setWS] = useState<Sockette | null>(null);
    const {user} = useAuth();
    const [connectedUsers, setConnectedUsers] = useState(new Set());

    function connect(): void {
        if (connected || connecting) {
            message.error("Already connected!")
        } else {
            // Initiate connection through react hook
            setConnecting(true);
        }
    }

    function disconnect(raise = true): void {
        if (ws && connected && !connecting) {
            clearInterval(keepAliveInterval);
            console.log("Closing connections");
            ws.close()
        } else {
            if (raise) {
                message.error("Already disconnected!")
            }
        }
    }

    function send(type: string, msg: {}): void {
        if (ws && connected) {
            console.log("Send message");
            ws.json({
                action: 'send',
                message: {
                    type: type,
                    message: msg
                }
            });
            message.success("Message sent!");
        } else message.error("Not yet connected!")
    }

    function publish(type: string, msg: {}): void {
        if (ws && connected) {
            console.log("Publish message");
            ws.json({
                action: 'publish',
                message: {
                    type: type,
                    message: msg
                }
            });
            message.success("Message published!");
        } else message.error("Not yet connected!")
    }


    function ping() {
        if (ws && connected) {
            console.log("Send ping")
            ws.json({action: 'ping'});
        } else message.error("Not yet connected!")
    }

    function keepAlive() {
        if (ws && connected) {
            console.log("Keep alive")
            let interval = 3 * 60 * 1000 // Every 3 minutes
            clearInterval(keepAliveInterval)
            keepAliveInterval = setInterval(ping, interval)
        } else message.error("Not yet connected!")
    }

    /** ======================
     *  Hooks
     ---------------------- */
    useEffect(() => {
        if (connected && !connecting) {
            keepAlive();
            publish("get_connected_users", {usage: "contact"});
        }
        // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [connected, connecting]);


    useEffect(() => {
        return () => {
            if (ws) {
                console.log("Tear down")
                clearInterval(keepAliveInterval);
                ws.close();
            }
        };
    }, [ws]);

    useEffect(() => {
        if (connecting) {
            user.getIdToken().then((accessToken: string) => {
                let endpoint = `${process.env.REACT_APP_WEBSOCKET_URL}?authorization=${accessToken}`;
                let sumuWebsocket = new Sockette(
                    endpoint,
                    {
                        timeout: 5e3,
                        maxAttempts: 5,
                        onopen: e => {
                            notification.success({
                                message: "Connected",
                                placement: 'bottomLeft'
                            });
                            setConnected(true)
                            setConnecting(false)
                        },
                        onmessage: messageHandler,
                        onreconnect: e => {
                            notification.warning({
                                message: "Reconnecting...",
                                placement: "bottomLeft"
                            });
                        },
                        onmaximum: e => {
                            notification.error({
                                message: "Could not connect to server, stop attempting!",
                                placement: "bottomLeft"
                            });
                            setConnected(false)
                        },
                        onclose: e => {
                            console.log("Closed!", e);
                            notification.error({
                                message: "Disconnected!",
                                placement: 'bottomLeft'
                            });
                            setConnected(false)
                        },
                        onerror: e => {
                            console.log("Error:", e);
                            setConnected(false)
                        },
                    }
                );
                setWS(sumuWebsocket)
            });
        }
    }, [connecting])

    const messageHandler = (e: any) => {
        let payload = JSON.parse(e.data);
        let m = payload.message;
        switch (payload.type) {
            case "connected_users":
                setConnectedUsers(new Set(m["users"]))
                break;
            case "message":
                notification.warning({
                    message: "New message",
                    description: m.text,
                    placement: "topRight"
                });
                break;
            case "presence":
                let presence = `${m.user_id} is ${m.status}`;
                if (m.status === "OFFLINE") {
                    notification.error({
                        message: "Presence",
                        description: presence,
                        placement: "topRight"
                    });
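                    // Functional update: this handler is captured once by Sockette,
                    // so reading connectedUsers directly would see a stale Set.
                    setConnectedUsers(prev => {
                        const next = new Set(prev);
                        next.delete(m.user_id);
                        return next;
                    });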
                } else if (m.status === "ONLINE") {
                    notification.success({
                        message: "Presence",
                        description: presence,
                        placement: "topRight"
                    });
                    // Functional update avoids the stale connectedUsers closure
                    setConnectedUsers(prev => new Set(prev).add(m.user_id));
                }


                break;
            case "pong":
                notification.warning({
                    message: "Keep Alive",
                    description: "Received Pong from API Gateway",
                    placement: "bottomLeft"
                });
                break;
            default:
                break;
        }
    }

    return (
        <SumuContext.Provider value={
            {
                connecting: connecting,
                connected: connected,
                connectedUsers: connectedUsers,
                connect: connect,
                disconnect: disconnect,
                ping: ping,
                send: send,
                publish: publish
            }
        }>
            {
                children
            }
        </SumuContext.Provider>
    )
}

const useSumu = () => React.useContext(SumuContext);
export {SumuProvider, useSumu}
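
The way the client keeps its list of connected users in sync is independent of React, so it can be sketched in a few lines of Python for readers who want to test the logic in isolation. The frame shapes (`type`, `message.users`, `message.user_id`, `message.status`) follow the message handler above; `apply_frame` is a hypothetical name, and the sketch assumes unknown frame types should leave the state untouched.

```python
import json


def apply_frame(connected: frozenset, raw_frame: str) -> frozenset:
    """Return the new set of connected users after one incoming websocket frame."""
    payload = json.loads(raw_frame)
    message = payload.get("message") or {}
    frame_type = payload.get("type")

    if frame_type == "connected_users":
        # Full snapshot requested right after connecting
        return frozenset(message["users"])
    if frame_type == "presence":
        if message["status"] == "ONLINE":
            return connected | {message["user_id"]}
        if message["status"] == "OFFLINE":
            return connected - {message["user_id"]}
    # "message", "pong" and unknown frames do not change presence
    return connected
```

Because every call derives the next state from the previous one, the result never depends on a stale copy of the set, which is the same property a functional state update gives the React handler.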

Finally, we have the main application that uses all the elements provided by SumuProvider and AuthProvider:

import React, {useEffect, useState} from 'react';
import {Col, Row, Layout, Button, Card, Radio, Input, Typography, Divider, message, Select} from "antd";
import {CheckCircleOutlined, CloseCircleOutlined, SyncOutlined} from "@ant-design/icons";
import './App.less';

import {useAuth} from "./context/AuthProvider";
import {useSumu} from "./context/SumuProvider";

import Code from "./containers/Code";

const {Content, Header} = Layout;
const {TextArea} = Input;
const {Paragraph, Text} = Typography;

const App = (props: any) => {

    const {logout} = useAuth();
    const {connect, disconnect, publish, send, ping, connected, connecting, connectedUsers} = useSumu()
    const [pushMode, setPushMode] = useState<string>("UNICAST");
    const [msg, setMsg] = useState<string>("");
    const [users, setUsers] = useState<string[] | undefined>(undefined);

    useEffect(() => {
        connect()
        // eslint-disable-next-line react-hooks/exhaustive-deps
    }, []);

    function doPublish() {
        if (pushMode === "MULTICAST" && !(users && users.length !== 0)){
            message.error("Push mode is multicast, please select users!")
            return
        }

        if (msg) {
            publish("message", {
                text: msg,
                push_mode: pushMode,
                users: users
            })
            setMsg("")
        } else
            message.error("Please enter the message!")
    }

    function doSend() {
        if (pushMode === "MULTICAST" && !(users && users.length !== 0)){
            message.error("Push mode is multicast, please select users!")
            return
        }
        if (msg) {
            send("message", {
                text: msg,
                push_mode: pushMode,
                users: users
            })
            setMsg("")
        } else
            message.error("Please enter the message!")
    }

    function doPing() {
        ping()
    }

    function doConnect() {
        connect()
    }

    function doDisconnect() {
        disconnect()
    }

    // Map the connection state to a color: blue while connecting,
    // green when connected, red when disconnected.
    function stateColor(): string {
        if (connecting) return "#00B0FF";
        return connected ? "#00BFA6" : "#F50057";
    }

    function connectedUsersOptions() {
        const options:any = []
        connectedUsers.forEach((user:any) => {
            options.push({
                label: user,
                value: user
            })
        });
        return options;
    }


    return (<>
        <Layout>
            <Header
                style={{
                    backgroundColor: "#141414",
                    borderBottom: `1px solid ${stateColor()}`
                }}
            >
                <Row align={"middle"} justify={"space-between"} style={{width: "100%", height: "100%"}}>
                    {connected && !connecting && <CheckCircleOutlined style={{fontSize: 25, color: stateColor()}}/>}
                    {!connected && !connecting && <CloseCircleOutlined style={{fontSize: 25, color: stateColor()}}/>}
                    {connecting && <SyncOutlined spin style={{fontSize: 25, color: stateColor()}}/>}

                    {connectedUsers.size>0 && <Text style={{color: "#00BFA6"}}>{`${connectedUsers.size}`} Users Connected</Text>}
                    {connectedUsers.size===0 && <Text style={{color: "#F50057"}}>No User Connected</Text>}

                    <Button
                        onClick={e => logout()}
                        type="primary"
                        ghost
                    >
                        Logout
                    </Button>
                </Row>
            </Header>

            <Content
                style={{
                    padding: '0 50px',
                    paddingTop: 10,
                    backgroundColor: "#282c34"
                }}
            >
                <Row>
                    <Row style={{width: "100%" , marginBottom: 16}} align={"middle"} gutter={16}>


                        <Col span={8} style={{display: 'flex'}}>
                            <Code
                                title={"Publish message"} code={"publish"}
                                onClick={doPublish}
                                disabled={!connected}
                            />
                        </Col>

                        <Col span={8} style={{display: 'flex'}}>
                            {/*<Row style={{width: "100%"}}>*/}
                            {/*    <Image src={"/arch.svg"}/>*/}
                            {/*</Row>*/}

                            <Card title={"Message Params"} style={{width: "100%"}}>
                                <Row style={{width: "100%"}} justify={"center"}>
                                    <Radio.Group
                                        defaultValue="UNICAST"
                                        onChange={e => setPushMode(e.target.value)}
                                        style={{marginTop: 16, marginBottom: 16}}
                                    >
                                        <Radio.Button value="UNICAST">UNICAST</Radio.Button>
                                        <Radio.Button value="MULTICAST">MULTICAST</Radio.Button>
                                        <Radio.Button value="BROADCAST">BROADCAST</Radio.Button>
                                    </Radio.Group>
                                </Row>
                                <Row>
                                    {pushMode === "UNICAST" &&
                                    <Paragraph type={"secondary"}>The message will be sent to <Text strong>just
                                        you!</Text></Paragraph>}
                                    {pushMode === "MULTICAST" &&
                                    <Paragraph type={"secondary"}>The message will be sent to <Text strong>selected
                                        users!</Text></Paragraph>}
                                    {pushMode === "BROADCAST" &&
                                    <Paragraph type={"secondary"}>The message will be sent to <Text strong>all
                                        users except you!</Text></Paragraph>}
                                    <TextArea rows={2}
                                              maxLength={150}
                                              allowClear
                                              showCount
                                              autoSize={{minRows: 3, maxRows: 3}}
                                              placeholder={"Your message"}
                                              onChange={e => setMsg(e.target.value)}
                                              style={{width: "100%"}}
                                              value={msg}
                                    />
                                </Row>
                                {pushMode === "MULTICAST" &&
                                    <>
                                        <Divider/>
                                        <Select
                                            mode="multiple"
                                            maxTagCount="responsive"
                                            placeholder="Select users..."
                                            options={connectedUsersOptions()}
                                            value={users}
                                            style={{width: "100%"}}
                                            onChange={(selectedUsers: string[]) => {setUsers(selectedUsers);}}
                                        />
                                    </>
                                }
                            </Card>

                        </Col>


                        <Col span={8} style={{display: 'flex'}}>
                            <Code
                                title={"Send message"} code={"send"}
                                onClick={doSend}
                                disabled={!connected}
                            />
                        </Col>
                    </Row>

                    <Row style={{width: "100%"}} gutter={16}>
                        <Col span={8} style={{display: 'flex'}}>
                            <Code
                                title={"Connect"} code={"connect"}
                                onClick={doConnect}
                                loading={connecting}
                                disabled={connected}
                            />
                        </Col>

                        <Col span={8} style={{display: 'flex'}}>
                            <Code
                                title={"Disconnect"} code={"disconnect"}
                                onClick={doDisconnect}
                                disabled={!connected}
                            />
                        </Col>

                        <Col span={8} style={{display: 'flex'}}>
                            <Code
                                title={"Ping"} code={"ping"}
                                onClick={doPing}
                                disabled={!connected}
                            />
                        </Col>
                    </Row>
                </Row>
            </Content>
        </Layout>
    </>)

}

export default App;
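
The `doPublish` and `doSend` handlers above run the same two checks before they call `publish` or `send`. As a minimal sketch, assuming we wanted to share that logic, the checks could be pulled into a pure function. The `validateMessage` helper and the `PushMode` type are hypothetical names used for illustration only; they are not part of SUMU or of the demo application.

```typescript
// Hypothetical helper (not part of the demo app): it mirrors the checks that
// doPublish and doSend perform before calling publish()/send().
type PushMode = "UNICAST" | "MULTICAST" | "BROADCAST";

// Returns an error message when the input is invalid, or null when it is valid.
function validateMessage(pushMode: PushMode, msg: string, users?: string[]): string | null {
    // Multicast needs at least one selected recipient.
    if (pushMode === "MULTICAST" && !(users && users.length !== 0)) {
        return "Push mode is multicast, please select users!";
    }
    // An empty message is never sent.
    if (!msg) {
        return "Please enter the message!";
    }
    return null;
}

// Example calls: multicast without recipients, an empty message, a valid broadcast.
console.log(validateMessage("MULTICAST", "hello", []));
console.log(validateMessage("UNICAST", ""));
console.log(validateMessage("BROADCAST", "hello"));
```

Inside the component, each handler could then call this helper, pass a returned string to `message.error`, and go on to `publish` or `send` only when the result is `null`.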

What's next?


Throughout this article we've seen how we can leverage AWS serverless technologies to build a reusable websocket stack.

We've also built a realtime, reactive demo web application that leverages SUMU and implements all the features it provides.

We didn't tell you how we deployed the demo client web application, so that will be the topic of our next article (it's not deployed on Netlify 😉).

Share if you like the article and stay tuned for the next article!
