|
| 1 | +from unittest.case import skipIf |
| 2 | + |
| 3 | +from integration.config.service_names import WEBSOCKET_API |
| 4 | +from integration.helpers.base_test import BaseTest |
| 5 | +from integration.helpers.resource import current_region_does_not_support |
| 6 | + |
| 7 | + |
| 8 | +@skipIf(current_region_does_not_support([WEBSOCKET_API]), "WebSocketApi is not supported in this region") |
| 9 | +class TestWebSocketApiWithAuth(BaseTest): |
| 10 | + |
| 11 | + def test_websocket_api_with_iam_auth(self): |
| 12 | + """ |
| 13 | + Creates a WebSocket API with an IAM authorizer |
| 14 | + """ |
| 15 | + self.create_and_verify_stack("combination/websocket_api_with_iam_auth") |
| 16 | + |
| 17 | + websocket_api_list = self.get_stack_resources("AWS::ApiGatewayV2::Api") |
| 18 | + self.assertEqual(len(websocket_api_list), 1) |
| 19 | + |
| 20 | + stages = self.get_api_v2_stack_stages() |
| 21 | + |
| 22 | + self.assertEqual(len(stages), 1) |
| 23 | + self.assertEqual(stages[0]["StageName"], "default") |
| 24 | + |
| 25 | + websocket_resource = websocket_api_list[0] |
| 26 | + websocket_api_id = websocket_resource["PhysicalResourceId"] |
| 27 | + api_v2_client = self.client_provider.api_v2_client |
| 28 | + routes_list = api_v2_client.get_routes(ApiId=websocket_api_id)["Items"] |
| 29 | + route = routes_list[0] |
| 30 | + self.assertEqual(route["AuthorizationType"], "AWS_IAM") |
| 31 | + |
| 32 | + def test_none_auth(self): |
| 33 | + self.create_and_verify_stack("combination/websocket_api_with_none_auth") |
| 34 | + |
| 35 | + websocket_api_list = self.get_stack_resources("AWS::ApiGatewayV2::Api") |
| 36 | + self.assertEqual(len(websocket_api_list), 1) |
| 37 | + |
| 38 | + websocket_resource = websocket_api_list[0] |
| 39 | + websocket_api_id = websocket_resource["PhysicalResourceId"] |
| 40 | + api_v2_client = self.client_provider.api_v2_client |
| 41 | + routes_list = api_v2_client.get_routes(ApiId=websocket_api_id)["Items"] |
| 42 | + route = routes_list[0] |
| 43 | + self.assertEqual(route["AuthorizationType"], "NONE") |
| 44 | + |
| 45 | + def test_websocket_api_with_lambda_auth_config(self): |
| 46 | + """ |
| 47 | + Creates a WebSocket API with a Lambda authorizer |
| 48 | + """ |
| 49 | + self.create_and_verify_stack("combination/websocket_api_with_lambda_auth") |
| 50 | + |
| 51 | + websocket_api_list = self.get_stack_resources("AWS::ApiGatewayV2::Api") |
| 52 | + self.assertEqual(len(websocket_api_list), 1) |
| 53 | + |
| 54 | + websocket_resource = websocket_api_list[0] |
| 55 | + websocket_api_id = websocket_resource["PhysicalResourceId"] |
| 56 | + api_v2_client = self.client_provider.api_v2_client |
| 57 | + |
| 58 | + route_list = api_v2_client.get_routes(ApiId=websocket_api_id)["Items"] |
| 59 | + self.assertEqual(len(route_list), 1) |
| 60 | + route = route_list[0] |
| 61 | + self.assertEqual(route["AuthorizationType"], "CUSTOM") |
| 62 | + self.assertIsNotNone(route["AuthorizerId"]) |
| 63 | + |
| 64 | + authorizer_list = api_v2_client.get_authorizers(ApiId=websocket_api_id)["Items"] |
| 65 | + self.assertEqual(len(authorizer_list), 1) |
| 66 | + lambda_auth = authorizer_list[0] |
| 67 | + # Not sure this is returning properly either |
| 68 | + self.assertEqual(lambda_auth["AuthorizerType"], "REQUEST") |
| 69 | + # Verify authorizer URI contains expected components |
| 70 | + authorizer_uri = lambda_auth["AuthorizerUri"] |
| 71 | + self.assertIn("lambda:path/2015-03-31/functions", authorizer_uri) |
| 72 | + self.assertIn("MyAuthFn", authorizer_uri) |
| 73 | + self.assertIn("/invocations", authorizer_uri) |
| 74 | + |
| 75 | + # Same authorizer coming from the route and from the authorizers |
| 76 | + self.assertEqual(route["AuthorizerId"], lambda_auth["AuthorizerId"]) |
0 commit comments