From 84879ee3bdf52e3a8bd3c78f3615429a249d1e40 Mon Sep 17 00:00:00 2001 From: Sangeetha Date: Mon, 28 Jul 2025 17:41:02 +0530 Subject: [PATCH] [WEB-4533] feat: read replica functionality (#7453) * feat: read replica functionality * fix: set use_read_replica to false * chore: add use_read_replica to external APIs * chore: remove use_read_replica on read endpoints * chore: remove md files * Updated all the necessary endpoints to use read replica --------- Co-authored-by: Dheeraj Kumar Ketireddy --- apps/api/plane/api/views/asset.py | 2 + apps/api/plane/api/views/base.py | 7 +- apps/api/plane/api/views/cycle.py | 5 + apps/api/plane/api/views/intake.py | 3 + apps/api/plane/api/views/issue.py | 15 + apps/api/plane/api/views/member.py | 6 +- apps/api/plane/api/views/module.py | 5 + apps/api/plane/api/views/project.py | 2 + apps/api/plane/api/views/state.py | 2 + apps/api/plane/app/views/base.py | 9 +- apps/api/plane/app/views/issue/activity.py | 1 + apps/api/plane/app/views/notification/base.py | 2 + apps/api/plane/app/views/project/base.py | 1 + apps/api/plane/app/views/project/member.py | 1 + apps/api/plane/app/views/user/base.py | 1 + apps/api/plane/app/views/workspace/base.py | 1 + .../api/plane/app/views/workspace/estimate.py | 1 + .../api/plane/app/views/workspace/favorite.py | 2 + apps/api/plane/app/views/workspace/label.py | 1 + apps/api/plane/app/views/workspace/member.py | 3 + .../plane/app/views/workspace/quick_link.py | 1 + .../plane/app/views/workspace/recent_visit.py | 1 + apps/api/plane/app/views/workspace/state.py | 1 + apps/api/plane/app/views/workspace/sticky.py | 1 + .../app/views/workspace/user_preference.py | 1 + apps/api/plane/middleware/db_routing.py | 162 +++++++ apps/api/plane/settings/common.py | 23 + .../plane/tests/unit/middleware/__init__.py | 0 .../tests/unit/middleware/test_db_routing.py | 433 ++++++++++++++++++ apps/api/plane/utils/core/__init__.py | 21 + apps/api/plane/utils/core/dbrouters.py | 73 +++ apps/api/plane/utils/core/mixins/__init__.py | 11 + apps/api/plane/utils/core/mixins/view.py | 20 + apps/api/plane/utils/core/request_scope.py | 72 +++ 34 files changed, 884 insertions(+), 6 deletions(-) create mode 100644 apps/api/plane/middleware/db_routing.py create mode 100644 apps/api/plane/tests/unit/middleware/__init__.py create mode 100644 apps/api/plane/tests/unit/middleware/test_db_routing.py create mode 100644 apps/api/plane/utils/core/__init__.py create mode 100644 apps/api/plane/utils/core/dbrouters.py create mode 100644 apps/api/plane/utils/core/mixins/__init__.py create mode 100644 apps/api/plane/utils/core/mixins/view.py create mode 100644 apps/api/plane/utils/core/request_scope.py diff --git a/apps/api/plane/api/views/asset.py b/apps/api/plane/api/views/asset.py index 061a79010..2e668c15d 100644 --- a/apps/api/plane/api/views/asset.py +++ b/apps/api/plane/api/views/asset.py @@ -405,6 +405,8 @@ class UserServerAssetEndpoint(BaseAPIView): class GenericAssetEndpoint(BaseAPIView): """This endpoint is used to upload generic assets that can be later bound to entities.""" + use_read_replica = True + @asset_docs( operation_id="get_generic_asset", summary="Get presigned URL for asset download", diff --git a/apps/api/plane/api/views/base.py b/apps/api/plane/api/views/base.py index a4c14cf0d..ea5bcba02 100644 --- a/apps/api/plane/api/views/base.py +++ b/apps/api/plane/api/views/base.py @@ -20,6 +20,7 @@ from plane.api.middleware.api_authentication import APIKeyAuthentication from plane.api.rate_limit import ApiKeyRateThrottle, ServiceTokenRateThrottle from plane.utils.exception_logger import log_exception from plane.utils.paginator import BasePaginator +from plane.utils.core.mixins import ReadReplicaControlMixin class TimezoneMixin: @@ -36,11 +37,15 @@ class TimezoneMixin: timezone.deactivate() -class BaseAPIView(TimezoneMixin, GenericAPIView, BasePaginator): +class BaseAPIView( + TimezoneMixin, GenericAPIView, ReadReplicaControlMixin, BasePaginator +): authentication_classes = [APIKeyAuthentication] permission_classes = [IsAuthenticated] + use_read_replica = False + def filter_queryset(self, queryset): for backend in list(self.filter_backends): queryset = backend().filter_queryset(self.request, queryset, self) diff --git a/apps/api/plane/api/views/cycle.py b/apps/api/plane/api/views/cycle.py index e7a7b8fcc..e10d3d16e 100644 --- a/apps/api/plane/api/views/cycle.py +++ b/apps/api/plane/api/views/cycle.py @@ -86,6 +86,7 @@ class CycleListCreateAPIEndpoint(BaseAPIView): model = Cycle webhook_event = "cycle" permission_classes = [ProjectEntityPermission] + use_read_replica = True def get_queryset(self): return ( @@ -373,6 +374,7 @@ class CycleDetailAPIEndpoint(BaseAPIView): model = Cycle webhook_event = "cycle" permission_classes = [ProjectEntityPermission] + use_read_replica = True def get_queryset(self): return ( @@ -633,6 +635,7 @@ class CycleArchiveUnarchiveAPIEndpoint(BaseAPIView): """Cycle Archive and Unarchive Endpoint""" permission_classes = [ProjectEntityPermission] + use_read_replica = True def get_queryset(self): return ( @@ -831,6 +834,7 @@ class CycleIssueListCreateAPIEndpoint(BaseAPIView): model = CycleIssue webhook_event = "cycle_issue" permission_classes = [ProjectEntityPermission] + use_read_replica = True def get_queryset(self): return ( @@ -1045,6 +1049,7 @@ class CycleIssueDetailAPIEndpoint(BaseAPIView): webhook_event = "cycle_issue" bulk = True permission_classes = [ProjectEntityPermission] + use_read_replica = True def get_queryset(self): return ( diff --git a/apps/api/plane/api/views/intake.py b/apps/api/plane/api/views/intake.py index 3ee977d2a..1ea9c73fd 100644 --- a/apps/api/plane/api/views/intake.py +++ b/apps/api/plane/api/views/intake.py @@ -51,8 +51,10 @@ class IntakeIssueListCreateAPIEndpoint(BaseAPIView): """Intake Work Item List and Create Endpoint""" serializer_class = IntakeIssueSerializer + model = Intake permission_classes = [ProjectLitePermission] + use_read_replica = True def get_queryset(self): intake = Intake.objects.filter( @@ -214,6 +216,7 @@ class IntakeIssueDetailAPIEndpoint(BaseAPIView): serializer_class = IntakeIssueSerializer model = IntakeIssue + use_read_replica = True filterset_fields = ["status"] diff --git a/apps/api/plane/api/views/issue.py b/apps/api/plane/api/views/issue.py index 5ae15ea2e..6b7256dd1 100644 --- a/apps/api/plane/api/views/issue.py +++ b/apps/api/plane/api/views/issue.py @@ -156,6 +156,7 @@ class WorkspaceIssueAPIEndpoint(BaseAPIView): webhook_event = "issue" permission_classes = [ProjectEntityPermission] serializer_class = IssueSerializer + use_read_replica = True @property def project_identifier(self): @@ -231,6 +232,7 @@ class IssueListCreateAPIEndpoint(BaseAPIView): webhook_event = "issue" permission_classes = [ProjectEntityPermission] serializer_class = IssueSerializer + use_read_replica = True def get_queryset(self): return ( @@ -495,6 +497,7 @@ class IssueDetailAPIEndpoint(BaseAPIView): webhook_event = "issue" permission_classes = [ProjectEntityPermission] serializer_class = IssueSerializer + use_read_replica = True def get_queryset(self): return ( @@ -822,6 +825,7 @@ class LabelListCreateAPIEndpoint(BaseAPIView): serializer_class = LabelSerializer model = Label permission_classes = [ProjectMemberPermission] + use_read_replica = True def get_queryset(self): return ( @@ -949,6 +953,7 @@ class LabelDetailAPIEndpoint(BaseAPIView): serializer_class = LabelSerializer model = Label permission_classes = [ProjectMemberPermission] + use_read_replica = True @label_docs( operation_id="get_labels", @@ -1057,6 +1062,7 @@ class IssueLinkListCreateAPIEndpoint(BaseAPIView): serializer_class = IssueLinkSerializer model = IssueLink permission_classes = [ProjectEntityPermission] + use_read_replica = True def get_queryset(self): return ( @@ -1163,6 +1169,7 @@ class IssueLinkDetailAPIEndpoint(BaseAPIView): model = IssueLink serializer_class = IssueLinkSerializer + use_read_replica = True def get_queryset(self): return ( @@ -1319,6 +1326,7 @@ class IssueCommentListCreateAPIEndpoint(BaseAPIView): model = IssueComment webhook_event = "issue_comment" permission_classes = [ProjectLitePermission] + use_read_replica = True def get_queryset(self): return ( @@ -1477,6 +1485,7 @@ class IssueCommentDetailAPIEndpoint(BaseAPIView): model = IssueComment webhook_event = "issue_comment" permission_classes = [ProjectLitePermission] + use_read_replica = True def get_queryset(self): return ( @@ -1657,6 +1666,7 @@ class IssueCommentDetailAPIEndpoint(BaseAPIView): class IssueActivityListAPIEndpoint(BaseAPIView): permission_classes = [ProjectEntityPermission] + use_read_replica = True @issue_activity_docs( operation_id="list_work_item_activities", @@ -1712,6 +1722,7 @@ class IssueActivityDetailAPIEndpoint(BaseAPIView): """Issue Activity Detail Endpoint""" permission_classes = [ProjectEntityPermission] + use_read_replica = True @issue_activity_docs( operation_id="retrieve_work_item_activity", @@ -1770,6 +1781,7 @@ class IssueAttachmentListCreateAPIEndpoint(BaseAPIView): serializer_class = IssueAttachmentSerializer model = FileAsset permission_classes = [ProjectEntityPermission] + use_read_replica = True @issue_attachment_docs( operation_id="create_work_item_attachment", @@ -1977,6 +1989,7 @@ class IssueAttachmentDetailAPIEndpoint(BaseAPIView): serializer_class = IssueAttachmentSerializer permission_classes = [ProjectEntityPermission] model = FileAsset + use_read_replica = True @issue_attachment_docs( operation_id="delete_work_item_attachment", @@ -2146,6 +2159,8 @@ class IssueAttachmentDetailAPIEndpoint(BaseAPIView): class IssueSearchEndpoint(BaseAPIView): """Endpoint to search across multiple fields in the issues""" + use_read_replica = True + @extend_schema( operation_id="search_work_items", tags=["Work Items"], diff --git a/apps/api/plane/api/views/member.py b/apps/api/plane/api/views/member.py index a6d7176d7..8ae662520 100644 --- a/apps/api/plane/api/views/member.py +++ b/apps/api/plane/api/views/member.py @@ -24,9 +24,8 @@ from plane.utils.openapi import ( class WorkspaceMemberAPIEndpoint(BaseAPIView): - permission_classes = [ - WorkSpaceAdminPermission, - ] + permission_classes = [WorkSpaceAdminPermission] + use_read_replica = True @extend_schema( operation_id="get_workspace_members", @@ -92,6 +91,7 @@ class WorkspaceMemberAPIEndpoint(BaseAPIView): # API endpoint to get and insert users inside the workspace class ProjectMemberAPIEndpoint(BaseAPIView): permission_classes = [ProjectMemberPermission] + use_read_replica = True @extend_schema( operation_id="get_project_members", diff --git a/apps/api/plane/api/views/module.py b/apps/api/plane/api/views/module.py index e0392dfba..63112cd66 100644 --- a/apps/api/plane/api/views/module.py +++ b/apps/api/plane/api/views/module.py @@ -80,6 +80,7 @@ class ModuleListCreateAPIEndpoint(BaseAPIView): model = Module webhook_event = "module" permission_classes = [ProjectEntityPermission] + use_read_replica = True def get_queryset(self): return ( @@ -282,6 +283,7 @@ class ModuleDetailAPIEndpoint(BaseAPIView): permission_classes = [ProjectEntityPermission] serializer_class = ModuleSerializer webhook_event = "module" + use_read_replica = True def get_queryset(self): return ( @@ -550,6 +552,7 @@ class ModuleIssueListCreateAPIEndpoint(BaseAPIView): model = ModuleIssue webhook_event = "module_issue" permission_classes = [ProjectEntityPermission] + use_read_replica = True def get_queryset(self): return ( @@ -769,6 +772,7 @@ class ModuleIssueDetailAPIEndpoint(BaseAPIView): model = ModuleIssue webhook_event = "module_issue" bulk = True + use_read_replica = True permission_classes = [ProjectEntityPermission] @@ -916,6 +920,7 @@ class ModuleIssueDetailAPIEndpoint(BaseAPIView): class ModuleArchiveUnarchiveAPIEndpoint(BaseAPIView): permission_classes = [ProjectEntityPermission] + use_read_replica = True def get_queryset(self): return ( diff --git a/apps/api/plane/api/views/project.py b/apps/api/plane/api/views/project.py index b89129a7f..da946e3c3 100644 --- a/apps/api/plane/api/views/project.py +++ b/apps/api/plane/api/views/project.py @@ -67,6 +67,7 @@ class ProjectListCreateAPIEndpoint(BaseAPIView): model = Project webhook_event = "project" permission_classes = [ProjectBasePermission] + use_read_replica = True def get_queryset(self): return ( @@ -331,6 +332,7 @@ class ProjectDetailAPIEndpoint(BaseAPIView): webhook_event = "project" permission_classes = [ProjectBasePermission] + use_read_replica = True def get_queryset(self): return ( diff --git a/apps/api/plane/api/views/state.py b/apps/api/plane/api/views/state.py index 327c6c890..7b5d842de 100644 --- a/apps/api/plane/api/views/state.py +++ b/apps/api/plane/api/views/state.py @@ -38,6 +38,7 @@ class StateListCreateAPIEndpoint(BaseAPIView): serializer_class = StateSerializer model = State permission_classes = [ProjectEntityPermission] + use_read_replica = True def get_queryset(self): return ( @@ -164,6 +165,7 @@ class StateDetailAPIEndpoint(BaseAPIView): serializer_class = StateSerializer model = State permission_classes = [ProjectEntityPermission] + use_read_replica = True def get_queryset(self): return ( diff --git a/apps/api/plane/app/views/base.py b/apps/api/plane/app/views/base.py index 92c374966..4cefb75a1 100644 --- a/apps/api/plane/app/views/base.py +++ b/apps/api/plane/app/views/base.py @@ -24,6 +24,7 @@ from rest_framework.viewsets import ModelViewSet from plane.authentication.session import BaseSessionAuthentication from plane.utils.exception_logger import log_exception from plane.utils.paginator import BasePaginator +from plane.utils.core.mixins import ReadReplicaControlMixin class TimezoneMixin: @@ -40,7 +41,7 @@ class TimezoneMixin: timezone.deactivate() -class BaseViewSet(TimezoneMixin, ModelViewSet, BasePaginator): +class BaseViewSet(TimezoneMixin, ReadReplicaControlMixin, ModelViewSet, BasePaginator): model = None permission_classes = [IsAuthenticated] @@ -53,6 +54,8 @@ class BaseViewSet(TimezoneMixin, ModelViewSet, BasePaginator): search_fields = [] + use_read_replica = False + def get_queryset(self): try: return self.model.objects.all() @@ -149,7 +152,7 @@ class BaseViewSet(TimezoneMixin, ModelViewSet, BasePaginator): return expand if expand else None -class BaseAPIView(TimezoneMixin, APIView, BasePaginator): +class BaseAPIView(TimezoneMixin, ReadReplicaControlMixin, APIView, BasePaginator): permission_classes = [IsAuthenticated] filter_backends = (DjangoFilterBackend, SearchFilter) @@ -160,6 +163,8 @@ class BaseAPIView(TimezoneMixin, APIView, BasePaginator): search_fields = [] + use_read_replica = False + def filter_queryset(self, queryset): for backend in list(self.filter_backends): queryset = backend().filter_queryset(self.request, queryset, self) diff --git a/apps/api/plane/app/views/issue/activity.py b/apps/api/plane/app/views/issue/activity.py index 91b973f11..b9ef58ffd 100644 --- a/apps/api/plane/app/views/issue/activity.py +++ b/apps/api/plane/app/views/issue/activity.py @@ -19,6 +19,7 @@ from plane.db.models import IssueActivity, IssueComment, CommentReaction, Intake class IssueActivityEndpoint(BaseAPIView): permission_classes = [ProjectEntityPermission] + use_read_replica = True @method_decorator(gzip_page) @allow_permission([ROLE.ADMIN, ROLE.MEMBER, ROLE.GUEST]) diff --git a/apps/api/plane/app/views/notification/base.py b/apps/api/plane/app/views/notification/base.py index e84cf4d29..329599c15 100644 --- a/apps/api/plane/app/views/notification/base.py +++ b/apps/api/plane/app/views/notification/base.py @@ -232,6 +232,8 @@ class NotificationViewSet(BaseViewSet, BasePaginator): class UnreadNotificationEndpoint(BaseAPIView): + use_read_replica = True + @allow_permission( allowed_roles=[ROLE.ADMIN, ROLE.MEMBER, ROLE.GUEST], level="WORKSPACE" ) diff --git a/apps/api/plane/app/views/project/base.py b/apps/api/plane/app/views/project/base.py index 1da2aa84b..b4ee113c4 100644 --- a/apps/api/plane/app/views/project/base.py +++ b/apps/api/plane/app/views/project/base.py @@ -46,6 +46,7 @@ class ProjectViewSet(BaseViewSet): serializer_class = ProjectListSerializer model = Project webhook_event = "project" + use_read_replica = True def get_queryset(self): sort_order = ProjectMember.objects.filter( diff --git a/apps/api/plane/app/views/project/member.py b/apps/api/plane/app/views/project/member.py index 60d960fe5..0b09c1366 100644 --- a/apps/api/plane/app/views/project/member.py +++ b/apps/api/plane/app/views/project/member.py @@ -312,6 +312,7 @@ class ProjectMemberUserEndpoint(BaseAPIView): class UserProjectRolesEndpoint(BaseAPIView): permission_classes = [WorkspaceUserPermission] + use_read_replica = True def get(self, request, slug): project_members = ProjectMember.objects.filter( diff --git a/apps/api/plane/app/views/user/base.py b/apps/api/plane/app/views/user/base.py index 4eca872f3..08389d50c 100644 --- a/apps/api/plane/app/views/user/base.py +++ b/apps/api/plane/app/views/user/base.py @@ -44,6 +44,7 @@ from django.views.decorators.vary import vary_on_cookie class UserEndpoint(BaseViewSet): serializer_class = UserSerializer model = User + use_read_replica = True def get_object(self): return self.request.user diff --git a/apps/api/plane/app/views/workspace/base.py b/apps/api/plane/app/views/workspace/base.py index 922b39cc9..a37624d2a 100644 --- a/apps/api/plane/app/views/workspace/base.py +++ b/apps/api/plane/app/views/workspace/base.py @@ -177,6 +177,7 @@ class WorkSpaceViewSet(BaseViewSet): class UserWorkSpacesEndpoint(BaseAPIView): search_fields = ["name"] filterset_fields = ["owner"] + use_read_replica = True def get(self, request): fields = [field for field in request.GET.get("fields", "").split(",") if field] diff --git a/apps/api/plane/app/views/workspace/estimate.py b/apps/api/plane/app/views/workspace/estimate.py index beef2a8ec..8b0981f9e 100644 --- a/apps/api/plane/app/views/workspace/estimate.py +++ b/apps/api/plane/app/views/workspace/estimate.py @@ -12,6 +12,7 @@ from plane.utils.cache import cache_response class WorkspaceEstimatesEndpoint(BaseAPIView): permission_classes = [WorkspaceEntityPermission] + use_read_replica = True @cache_response(60 * 60 * 2) def get(self, request, slug): diff --git a/apps/api/plane/app/views/workspace/favorite.py b/apps/api/plane/app/views/workspace/favorite.py index ad2f24883..ee126fa5b 100644 --- a/apps/api/plane/app/views/workspace/favorite.py +++ b/apps/api/plane/app/views/workspace/favorite.py @@ -14,6 +14,8 @@ from plane.app.permissions import allow_permission, ROLE class WorkspaceFavoriteEndpoint(BaseAPIView): + use_read_replica = True + @allow_permission(allowed_roles=[ROLE.ADMIN, ROLE.MEMBER], level="WORKSPACE") def get(self, request, slug): # the second filter is to check if the user is a member of the project diff --git a/apps/api/plane/app/views/workspace/label.py b/apps/api/plane/app/views/workspace/label.py index c93cd44c8..11ca6b913 100644 --- a/apps/api/plane/app/views/workspace/label.py +++ b/apps/api/plane/app/views/workspace/label.py @@ -12,6 +12,7 @@ from plane.utils.cache import cache_response class WorkspaceLabelsEndpoint(BaseAPIView): permission_classes = [WorkspaceViewerPermission] + use_read_replica = True @cache_response(60 * 60 * 2) def get(self, request, slug): diff --git a/apps/api/plane/app/views/workspace/member.py b/apps/api/plane/app/views/workspace/member.py index 7743ff4cd..84985cec3 100644 --- a/apps/api/plane/app/views/workspace/member.py +++ b/apps/api/plane/app/views/workspace/member.py @@ -28,6 +28,7 @@ class WorkSpaceMemberViewSet(BaseViewSet): model = WorkspaceMember search_fields = ["member__display_name", "member__first_name"] + use_read_replica = True def get_queryset(self): return self.filter_queryset( @@ -214,6 +215,8 @@ class WorkspaceMemberUserViewsEndpoint(BaseAPIView): class WorkspaceMemberUserEndpoint(BaseAPIView): + use_read_replica = True + def get(self, request, slug): draft_issue_count = ( DraftIssue.objects.filter( diff --git a/apps/api/plane/app/views/workspace/quick_link.py b/apps/api/plane/app/views/workspace/quick_link.py index b7decea95..104ca00d2 100644 --- a/apps/api/plane/app/views/workspace/quick_link.py +++ b/apps/api/plane/app/views/workspace/quick_link.py @@ -11,6 +11,7 @@ from plane.app.permissions import allow_permission, ROLE class QuickLinkViewSet(BaseViewSet): model = WorkspaceUserLink + use_read_replica = True def get_serializer_class(self): return WorkspaceUserLinkSerializer diff --git a/apps/api/plane/app/views/workspace/recent_visit.py b/apps/api/plane/app/views/workspace/recent_visit.py index 4fe15b513..e1c50c8b6 100644 --- a/apps/api/plane/app/views/workspace/recent_visit.py +++ b/apps/api/plane/app/views/workspace/recent_visit.py @@ -12,6 +12,7 @@ from plane.app.permissions import allow_permission, ROLE class UserRecentVisitViewSet(BaseViewSet): model = UserRecentVisit + use_read_replica = True def get_serializer_class(self): return WorkspaceRecentVisitSerializer diff --git a/apps/api/plane/app/views/workspace/state.py b/apps/api/plane/app/views/workspace/state.py index 08bc2be28..3a7d767fa 100644 --- a/apps/api/plane/app/views/workspace/state.py +++ b/apps/api/plane/app/views/workspace/state.py @@ -13,6 +13,7 @@ from collections import defaultdict class WorkspaceStatesEndpoint(BaseAPIView): permission_classes = [WorkspaceEntityPermission] + use_read_replica = True @cache_response(60 * 60 * 2) def get(self, request, slug): diff --git a/apps/api/plane/app/views/workspace/sticky.py b/apps/api/plane/app/views/workspace/sticky.py index 4870a6abe..8b9654716 100644 --- a/apps/api/plane/app/views/workspace/sticky.py +++ b/apps/api/plane/app/views/workspace/sticky.py @@ -12,6 +12,7 @@ from plane.app.serializers import StickySerializer class WorkspaceStickyViewSet(BaseViewSet): serializer_class = StickySerializer model = Sticky + use_read_replica = True def get_queryset(self): return self.filter_queryset( diff --git a/apps/api/plane/app/views/workspace/user_preference.py b/apps/api/plane/app/views/workspace/user_preference.py index 7cfa740e8..8bcf6b309 100644 --- a/apps/api/plane/app/views/workspace/user_preference.py +++ b/apps/api/plane/app/views/workspace/user_preference.py @@ -13,6 +13,7 @@ from rest_framework import status class WorkspaceUserPreferenceViewSet(BaseAPIView): model = WorkspaceUserPreference + use_read_replica = True def get_serializer_class(self): return WorkspaceUserPreferenceSerializer diff --git a/apps/api/plane/middleware/db_routing.py b/apps/api/plane/middleware/db_routing.py new file mode 100644 index 000000000..dc7ff3fa3 --- /dev/null +++ b/apps/api/plane/middleware/db_routing.py @@ -0,0 +1,162 @@ +""" +Database routing middleware for read replica selection. +This middleware determines whether database queries should be routed to +read replicas or the primary database based on HTTP method and view configuration. +""" + +import logging +from typing import Callable, Optional + +from django.http import HttpRequest, HttpResponse + +from plane.utils.core import ( + set_use_read_replica, + clear_read_replica_context, +) + +logger = logging.getLogger("plane.api") + + +class ReadReplicaRoutingMiddleware: + """ + Middleware for intelligent database routing to read replicas. + Routing Logic: + • Non-GET requests (POST, PUT, DELETE, PATCH) ➜ Primary database + • GET requests: + - View has use_read_replica=False ➜ Primary database + - View has use_read_replica=True ➜ Read replica + - View has no use_read_replica attribute ➜ Primary database (safe default) + The middleware supports both Django CBVs and DRF APIViews/ViewSets. + Context is properly isolated per request to prevent data leakage. + """ + + # HTTP methods that are considered read-only by default + READ_ONLY_METHODS = {"GET", "HEAD", "OPTIONS"} + + def __init__(self, get_response): + """ + Initialize the middleware with the next middleware/view in the chain. + Args: + get_response: The next middleware or view function + """ + self.get_response = get_response + + def __call__(self, request: HttpRequest) -> HttpResponse: + """ + Process the request and determine database routing. + Args: + request: The HTTP request object + Returns: + HttpResponse: The HTTP response from the view + """ + # For non-read operations, set primary database immediately + if request.method not in self.READ_ONLY_METHODS: + set_use_read_replica(False) + logger.debug(f"Routing {request.method} {request.path} to primary database") + + try: + # Process the request through the middleware chain + response = self.get_response(request) + return response + finally: + # Always clean up context, even if an exception occurs + # This prevents context leakage between requests + clear_read_replica_context() + + def process_view( + self, + request: HttpRequest, + view_func: Callable, + view_args: tuple, + view_kwargs: dict, + ) -> None: + """ + Hook called just before Django calls the view. + This is more efficient than resolving URLs in __call__ since Django + provides the view function directly. + Args: + request: The HTTP request object + view_func: The view function to be called + view_args: Positional arguments for the view + view_kwargs: Keyword arguments for the view + """ + # Only process read operations (write operations already handled in __call__) + if request.method in self.READ_ONLY_METHODS: + use_replica = self._should_use_read_replica(view_func) + set_use_read_replica(use_replica) + + db_type = "read replica" if use_replica else "primary database" + logger.debug(f"Routing {request.method} {request.path} to {db_type}") + + # Return None to continue normal request processing + return None + + def _should_use_read_replica(self, view_func: Callable) -> bool: + """ + Determine if the view should use read replica based on its configuration. + Args: + view_func: The view function to inspect + Returns: + bool: True if should use read replica, False for primary database + """ + use_replica_attr = self._get_use_replica_attribute(view_func) + + # Default to primary database for GET requests if no explicit setting + # This ensures only views that explicitly opt-in use read replicas + if use_replica_attr is None: + return False + + return bool(use_replica_attr) + + def _get_use_replica_attribute(self, view_func: Callable) -> Optional[bool]: + """ + Extract the use_read_replica attribute from various view types. + Args: + view_func: The view function to inspect + Returns: + Optional[bool]: The use_read_replica setting, or None if not found + """ + # Return None if view_func is None to prevent AttributeError + if view_func is None: + return None + + # Check function-based view attribute + use_replica = getattr(view_func, "use_read_replica", None) + if use_replica is not None: + return use_replica + + # Check Django CBV wrapper + if hasattr(view_func, "view_class"): + use_replica = getattr(view_func.view_class, "use_read_replica", None) + if use_replica is not None: + return use_replica + + # Check DRF wrapper (APIView / ViewSet) + if hasattr(view_func, "cls"): + use_replica = getattr(view_func.cls, "use_read_replica", None) + if use_replica is not None: + return use_replica + + return None + + def process_exception(self, request: HttpRequest, exception: Exception) -> None: + """ + Handle exceptions that occur during view processing. + This provides an additional safety net for context cleanup when views + raise exceptions, complementing the try/finally in __call__. + Args: + request: The HTTP request object + exception: The exception that was raised + Returns: + None: Don't handle the exception, just clean up context + """ + # Clean up context on exception as a safety measure + # The try/finally in __call__ should handle most cases, but this + # provides extra protection specifically for view exceptions + clear_read_replica_context() + logger.debug( + f"Cleaned up read replica context due to exception: {type(exception).__name__}" + ) + + # Return None to let the exception continue propagating + return None diff --git a/apps/api/plane/settings/common.py b/apps/api/plane/settings/common.py index 8d59f8192..efe3c1496 100644 --- a/apps/api/plane/settings/common.py +++ b/apps/api/plane/settings/common.py @@ -149,6 +149,29 @@ else: } } + +if os.environ.get("ENABLE_READ_REPLICA", "0") == "1": + if bool(os.environ.get("DATABASE_READ_REPLICA_URL")): + # Parse database configuration from $DATABASE_URL + DATABASES["replica"] = dj_database_url.parse( + os.environ.get("DATABASE_READ_REPLICA_URL") + ) + else: + DATABASES["replica"] = { + "ENGINE": "django.db.backends.postgresql", + "NAME": os.environ.get("POSTGRES_READ_REPLICA_DB"), + "USER": os.environ.get("POSTGRES_READ_REPLICA_USER"), + "PASSWORD": os.environ.get("POSTGRES_READ_REPLICA_PASSWORD"), + "HOST": os.environ.get("POSTGRES_READ_REPLICA_HOST"), + "PORT": os.environ.get("POSTGRES_READ_REPLICA_PORT", "5432"), + } + + # Database Routers + DATABASE_ROUTERS = ["plane.utils.core.dbrouters.ReadReplicaRouter"] + # Add middleware at the end for read replica routing + MIDDLEWARE.append("plane.middleware.db_routing.ReadReplicaRoutingMiddleware") + + # Redis Config REDIS_URL = os.environ.get("REDIS_URL") REDIS_SSL = REDIS_URL and "rediss" in REDIS_URL diff --git a/apps/api/plane/tests/unit/middleware/__init__.py b/apps/api/plane/tests/unit/middleware/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/api/plane/tests/unit/middleware/test_db_routing.py b/apps/api/plane/tests/unit/middleware/test_db_routing.py new file mode 100644 index 000000000..73f222140 --- /dev/null +++ b/apps/api/plane/tests/unit/middleware/test_db_routing.py @@ -0,0 +1,433 @@ +""" +Unit tests for ReadReplicaRoutingMiddleware. +This module contains comprehensive tests for the ReadReplicaRoutingMiddleware +that handles intelligent database routing to read replicas based on HTTP methods +and view configuration. +Test Organization: +- TestReadReplicaRoutingMiddleware: Core middleware functionality +- TestProcessView: process_view method behavior +- TestReplicaDecisionLogic: Decision logic for replica usage +- TestAttributeDetection: View attribute detection methods +- TestExceptionHandling: Exception handling and cleanup +- TestRealViewIntegration: Real Django/DRF view integration +- TestEdgeCases: Edge cases and error conditions +""" + +import pytest +from unittest.mock import Mock, patch + +from django.http import HttpResponse +from django.test import RequestFactory +from django.views import View +from rest_framework.views import APIView +from rest_framework.viewsets import ViewSet + +from plane.middleware.db_routing import ReadReplicaRoutingMiddleware + + +# Pytest fixtures +@pytest.fixture +def mock_get_response(): + """Fixture for mocked get_response callable.""" + return Mock(return_value=HttpResponse()) + + +@pytest.fixture +def middleware(mock_get_response): + """Fixture for ReadReplicaRoutingMiddleware instance.""" + return ReadReplicaRoutingMiddleware(mock_get_response) + + +@pytest.fixture +def request_factory(): + """Fixture for Django RequestFactory.""" + return RequestFactory() + + +@pytest.fixture +def mock_view_func(): + """Fixture for a basic mocked view function.""" + view = Mock() + view.use_read_replica = True + return view + + +@pytest.fixture +def get_request(request_factory): + """Fixture for a GET request.""" + return request_factory.get("/api/test/") + + +@pytest.fixture +def post_request(request_factory): + """Fixture for a POST request.""" + return request_factory.post("/api/test/") + + +@pytest.mark.unit +class TestReadReplicaRoutingMiddleware: + """Test cases for ReadReplicaRoutingMiddleware core functionality.""" + + def test_middleware_initialization(self, middleware, mock_get_response): + """Test middleware initializes correctly with expected attributes.""" + assert middleware.get_response == mock_get_response + assert hasattr(middleware, "READ_ONLY_METHODS") + assert "GET" in middleware.READ_ONLY_METHODS + assert "HEAD" in middleware.READ_ONLY_METHODS + assert "OPTIONS" in middleware.READ_ONLY_METHODS + + def test_read_only_methods_constant(self, middleware): + """Test READ_ONLY_METHODS contains expected HTTP methods.""" + expected_methods = {"GET", "HEAD", "OPTIONS"} + assert middleware.READ_ONLY_METHODS == expected_methods + + @patch("plane.middleware.db_routing.set_use_read_replica") + @patch("plane.middleware.db_routing.clear_read_replica_context") + def test_call_routes_write_methods_to_primary( + self, mock_clear, mock_set, middleware, post_request, mock_get_response + ): + """Test __call__ routes write methods to primary database.""" + response = middleware(post_request) + + mock_set.assert_called_once_with(False) # Primary database + mock_clear.assert_called_once() + assert response == mock_get_response.return_value + + @patch("plane.middleware.db_routing.clear_read_replica_context") + def test_call_with_read_methods_waits_for_process_view( + self, mock_clear, middleware, get_request, mock_get_response + ): + """Test __call__ with read methods waits for process_view.""" + response = middleware(get_request) + + mock_clear.assert_called_once() + assert response == mock_get_response.return_value + + @patch("plane.middleware.db_routing.clear_read_replica_context") + def test_call_always_cleans_up_context(self, mock_clear, middleware, get_request): + """Test __call__ always cleans up context.""" + middleware(get_request) + + mock_clear.assert_called_once() + + @patch("plane.middleware.db_routing.clear_read_replica_context") + def test_call_cleans_up_context_on_exception( + self, mock_clear, middleware, get_request, mock_get_response + ): + """Test __call__ cleans up context even if get_response raises.""" + mock_get_response.side_effect = Exception("Test exception") + + with pytest.raises(Exception, match="Test exception"): + middleware(get_request) + + mock_clear.assert_called_once() + + +@pytest.mark.unit +class TestProcessView: + """Test cases for process_view method functionality.""" + + @patch("plane.middleware.db_routing.set_use_read_replica") + def test_with_read_method_and_replica_true(self, mock_set, middleware, get_request): + """Test process_view with GET request and use_read_replica=True.""" + view_func = Mock() + view_func.use_read_replica = True + + result = middleware.process_view(get_request, view_func, (), {}) + + mock_set.assert_called_once_with(True) + assert result is None + + @patch("plane.middleware.db_routing.set_use_read_replica") + def test_with_read_method_and_replica_false( + self, mock_set, middleware, get_request + ): + """Test process_view with GET request and use_read_replica=False.""" + view_func = Mock() + view_func.use_read_replica = False + + result = middleware.process_view(get_request, view_func, (), {}) + + mock_set.assert_called_once_with(False) + assert result is None + + @patch("plane.middleware.db_routing.set_use_read_replica") + def test_with_read_method_and_no_replica_attribute( + self, mock_set, middleware, get_request + ): + """Test process_view with GET request and no use_read_replica attr.""" + view_func = Mock(spec=[]) # No use_read_replica attribute + + result = middleware.process_view(get_request, view_func, (), {}) + + mock_set.assert_called_once_with(False) # Default to primary + assert result is None + + def test_with_write_method_ignores_view_attributes(self, middleware, post_request): + """Test process_view with write methods ignores view attributes.""" + view_func = Mock() + view_func.use_read_replica = True # This should be ignored for POST + + result = middleware.process_view(post_request, view_func, (), {}) + + assert result is None # Should not process for write methods + + +@pytest.mark.unit +class TestReplicaDecisionLogic: + """Test cases for replica decision logic methods.""" + + def test_should_use_read_replica_with_true_attribute(self, middleware): + """Test _should_use_read_replica returns True for True attribute.""" + view_func = Mock() + view_func.use_read_replica = True + + result = middleware._should_use_read_replica(view_func) + + assert result is True + + def test_should_use_read_replica_with_false_attribute(self, middleware): + """Test _should_use_read_replica returns False for False attribute.""" + view_func = Mock() + view_func.use_read_replica = False + + result = middleware._should_use_read_replica(view_func) + + assert result is False + + def test_should_use_read_replica_with_no_attribute_defaults_false(self, middleware): + """Test _should_use_read_replica defaults to False for missing attr.""" + view_func = Mock(spec=[]) # No use_read_replica attribute + + result = middleware._should_use_read_replica(view_func) + + assert result is False + + +@pytest.mark.unit +class TestAttributeDetection: + """Test cases for view attribute detection methods.""" + + def test_get_use_replica_attribute_function_based_view(self, middleware): + """Test _get_use_replica_attribute with function-based view.""" + # Test with True + view_func = Mock() + view_func.use_read_replica = True + result = middleware._get_use_replica_attribute(view_func) + assert result is True + + # Test with False + view_func.use_read_replica = False + result = middleware._get_use_replica_attribute(view_func) + assert result is False + + # Test with no attribute + view_func = Mock(spec=[]) + result = middleware._get_use_replica_attribute(view_func) + assert result is None + + def test_get_use_replica_attribute_django_cbv(self, middleware): + """Test _get_use_replica_attribute with Django CBV wrapper.""" + view_class = Mock() + view_class.use_read_replica = True + view_func = Mock() + view_func.view_class = view_class + # Remove use_read_replica from view_func to ensure it checks view_class + del view_func.use_read_replica + + result = middleware._get_use_replica_attribute(view_func) + + assert result is True + + def test_get_use_replica_attribute_drf_wrapper(self, middleware): + """Test _get_use_replica_attribute with DRF wrapper.""" + + # Create a real object to avoid Mock issues + class ViewClass: + use_read_replica = True + + class ViewFunc: + cls = ViewClass() + + view_func = ViewFunc() + + result = middleware._get_use_replica_attribute(view_func) + + assert result is True + + def test_get_use_replica_attribute_priority_order(self, middleware): + """Test attribute priority: direct > view_class > cls.""" + view_func = Mock() + view_func.use_read_replica = True # Direct attribute (highest priority) + + # Add conflicting attributes with lower priority + view_class = Mock() + view_class.use_read_replica = False + view_func.view_class = view_class + + cls = Mock() + cls.use_read_replica = False + view_func.cls = cls + + result = middleware._get_use_replica_attribute(view_func) + + assert result is True # Should use direct attribute + + @pytest.mark.parametrize( + "value,expected", + [ + (True, True), + (False, False), + (1, True), + (0, False), + ("yes", True), + ("", False), + ([], False), + ([1], True), + (None, False), + ], + ) + def test_should_use_read_replica_truthy_falsy_values( + self, middleware, value, expected + ): + """Test _should_use_read_replica with various truthy/falsy values.""" + + # Create a real object to test the attribute handling + class TestView: + pass + + view_func = TestView() + view_func.use_read_replica = value + + result = middleware._should_use_read_replica(view_func) + + assert result == expected + + +@pytest.mark.unit +class TestExceptionHandling: + """Test cases for exception handling and cleanup.""" + + @patch("plane.middleware.db_routing.clear_read_replica_context") + def test_process_exception_cleans_up_context( + self, mock_clear, middleware, request_factory + ): + """Test process_exception cleans up context.""" + request = request_factory.get("/api/test/") + exception = Exception("Test exception") + + result = middleware.process_exception(request, exception) + + mock_clear.assert_called_once() + assert result is None # Don't handle the exception + + @patch("plane.middleware.db_routing.set_use_read_replica") + @patch("plane.middleware.db_routing.clear_read_replica_context") + def test_integration_full_request_cycle( + self, mock_clear, mock_set, middleware, request_factory, mock_get_response + ): + """Test complete request cycle from __call__ through process_view.""" + request = request_factory.get("/api/test/") + view_func = Mock() + view_func.use_read_replica = True + + # Call middleware and process_view manually + response = middleware(request) + middleware.process_view(request, view_func, (), {}) + + mock_set.assert_called_once_with(True) + mock_clear.assert_called_once() + assert response == mock_get_response.return_value + + +@pytest.mark.unit +class TestRealViewIntegration: + """Test middleware with real Django/DRF view classes.""" + + @patch("plane.middleware.db_routing.set_use_read_replica") + def test_with_django_class_based_view(self, mock_set, middleware, request_factory): + """Test middleware with actual Django CBV.""" + + class TestView(View): + use_read_replica = True + + # Simulate Django's URL resolver creating a view wrapper + view_func = TestView.as_view() + request = request_factory.get("/api/test/") + + middleware.process_view(request, view_func, (), {}) + + mock_set.assert_called_once_with(True) + + @patch("plane.middleware.db_routing.set_use_read_replica") + def test_with_drf_api_view(self, mock_set, middleware, request_factory): + """Test middleware with DRF APIView.""" + + class TestAPIView(APIView): + use_read_replica = True + + # Simulate DRF's URL pattern creating a view wrapper + view_func = TestAPIView.as_view() + request = request_factory.get("/api/test/") + + middleware.process_view(request, view_func, (), {}) + + mock_set.assert_called_once_with(True) + + @patch("plane.middleware.db_routing.set_use_read_replica") + def test_with_drf_viewset(self, mock_set, middleware, request_factory): + """Test middleware with DRF ViewSet.""" + + class TestViewSet(ViewSet): + use_read_replica = True + + # Simulate DRF router creating viewset action + view_func = TestViewSet.as_view({"get": "list"}) + request = request_factory.get("/api/test/") + + middleware.process_view(request, view_func, (), {}) + + mock_set.assert_called_once_with(True) + + +@pytest.mark.unit +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_process_view_with_none_view_func(self, middleware, request_factory): + """Test process_view handles None view_func gracefully.""" + request = request_factory.get("/api/test/") + + result = middleware.process_view(request, None, (), {}) + + assert result is None # Should not crash + + def test_get_use_replica_attribute_with_attribute_error(self, middleware): + """Test _get_use_replica_attribute with view that raises AttributeError.""" + + # Create a view class that raises AttributeError on access + class ProblematicView: + def __getattr__(self, name): + if name == "use_read_replica": + raise AttributeError("Simulated attribute error") + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + view_func = ProblematicView() + + result = middleware._get_use_replica_attribute(view_func) + + assert result is None # Should handle gracefully + + def test_multiple_exception_calls_are_safe(self, middleware, request_factory): + """Test that multiple calls to process_exception don't cause issues.""" + request = request_factory.get("/api/test/") + exception = Exception("Test exception") + + # Call multiple times + result1 = middleware.process_exception(request, exception) + result2 = middleware.process_exception(request, exception) + + assert result1 is None # Both should return None safely + assert result2 is None diff --git a/apps/api/plane/utils/core/__init__.py b/apps/api/plane/utils/core/__init__.py new file mode 100644 index 000000000..37c6e3741 --- /dev/null +++ b/apps/api/plane/utils/core/__init__.py @@ -0,0 +1,21 @@ +""" +Core utilities for Plane database routing and request scoping. +This package contains essential components for managing read replica routing +and request-scoped context in the Plane application. +""" + +from .dbrouters import ReadReplicaRouter +from .mixins import ReadReplicaControlMixin +from .request_scope import ( + set_use_read_replica, + should_use_read_replica, + clear_read_replica_context, +) + +__all__ = [ + "ReadReplicaRouter", + "ReadReplicaControlMixin", + "set_use_read_replica", + "should_use_read_replica", + "clear_read_replica_context", +] diff --git a/apps/api/plane/utils/core/dbrouters.py b/apps/api/plane/utils/core/dbrouters.py new file mode 100644 index 000000000..2c5b67a27 --- /dev/null +++ b/apps/api/plane/utils/core/dbrouters.py @@ -0,0 +1,73 @@ +""" +Database router for read replica selection. +This router determines which database to use for read/write operations +based on the request context set by the ReadReplicaRoutingMiddleware. +""" + +import logging +from typing import Type + +from django.db import models + +from .request_scope import should_use_read_replica + +logger = logging.getLogger("plane.db") + + +class ReadReplicaRouter: + """ + Database router that directs read operations to replica when appropriate. + This router works in conjunction with ReadReplicaRoutingMiddleware to: + - Route read operations to replica database when request context allows + - Always route write operations to primary database + - Ensure migrations only run on primary database + """ + + def db_for_read(self, model: Type[models.Model], **hints) -> str: + """ + Determine which database to use for read operations. + Args: + model: The Django model class being queried + **hints: Additional routing hints + Returns: + str: Database alias ('replica' or 'default') + """ + if should_use_read_replica(): + logger.debug(f"Routing read for {model._meta.label} to replica database") + return "replica" + else: + logger.debug(f"Routing read for {model._meta.label} to primary database") + return "default" + + def db_for_write(self, model: Type[models.Model], **hints) -> str: + """ + Determine which database to use for write operations. + All write operations always go to the primary database to ensure + data consistency and avoid replication lag issues. + Args: + model: The Django model class being written to + **hints: Additional routing hints + Returns: + str: Always returns 'default' (primary database) + """ + logger.debug(f"Routing write for {model._meta.label} to primary database") + return "default" + + def allow_migrate( + self, db: str, app_label: str, model_name: str = None, **hints + ) -> bool: + """ + Ensure migrations only run on the primary database. + Args: + db: Database alias + app_label: Application label + model_name: Model name (optional) + **hints: Additional routing hints + Returns: + bool: True if migration is allowed on this database + """ + # Only allow migrations on the primary database + allowed = db == "default" + if not allowed: + logger.debug(f"Blocking migration for {app_label} on {db} database") + return allowed diff --git a/apps/api/plane/utils/core/mixins/__init__.py b/apps/api/plane/utils/core/mixins/__init__.py new file mode 100644 index 000000000..cedd9d455 --- /dev/null +++ b/apps/api/plane/utils/core/mixins/__init__.py @@ -0,0 +1,11 @@ +""" +Core mixins for read replica functionality. +This package provides mixins for different aspects of read replica management +in Django and Django REST Framework applications. +""" + +from .view import ReadReplicaControlMixin + +__all__ = [ + "ReadReplicaControlMixin", +] diff --git a/apps/api/plane/utils/core/mixins/view.py b/apps/api/plane/utils/core/mixins/view.py new file mode 100644 index 000000000..e15ec6771 --- /dev/null +++ b/apps/api/plane/utils/core/mixins/view.py @@ -0,0 +1,20 @@ +""" +Mixins for Django REST Framework views. +""" + + +class ReadReplicaControlMixin: + """ + Mixin to control read replica usage in DRF views. + Set use_read_replica = True/False to route read operations to + replica/primary database. Works with ReadReplicaRoutingMiddleware. + Usage: + class MyViewSet(ReadReplicaControlMixin, ModelViewSet): + use_read_replica = True # Use replica for GET requests + Note: + - Only affects GET, HEAD, OPTIONS requests + - Write operations always use primary database + - Defaults to True for safe replica usage + """ + + use_read_replica: bool = True diff --git a/apps/api/plane/utils/core/request_scope.py b/apps/api/plane/utils/core/request_scope.py new file mode 100644 index 000000000..b09e77101 --- /dev/null +++ b/apps/api/plane/utils/core/request_scope.py @@ -0,0 +1,72 @@ +""" +Database routing utilities for read replica selection. +This module provides request-scoped context management for database routing, +specifically for determining when to use read replicas vs primary database. +Used in conjunction with middleware and DRF views that set use_read_replica=True. +The context is maintained per request to ensure proper isolation between +concurrent requests in async environments. +""" + +from asgiref.local import Local + +__all__ = [ + "set_use_read_replica", + "should_use_read_replica", + "clear_read_replica_context", +] + +# Request-scoped context storage for database routing preferences +# Uses asgiref.local.Local which provides ContextVar under the hood +# This ensures proper context isolation per request in async environments +_db_routing_context = Local() + + +def set_use_read_replica(use_replica: bool) -> None: + """ + Mark the current request context to use read replica database. + This function sets a request-scoped flag that determines database routing. + The context is isolated per request to ensure thread safety in async environments. + This function is typically called from: + - Middleware that detects read-only operations + - DRF views with use_read_replica=True attribute + - API endpoints that only perform read operations + Args: + use_replica (bool): True to route database queries to read replica, + False to use primary database + Note: + The context is automatically isolated per request and should be + cleared at the end of each request using clear_read_replica_context(). + """ + _db_routing_context.use_read_replica = bool(use_replica) + + +def should_use_read_replica() -> bool: + """ + Check if the current request should use read replica database. + This function reads the request-scoped context to determine database routing. + It's called by the database router to decide which connection to use. + Returns: + bool: True if queries should be routed to read replica, + False if they should use primary database (default) + Note: + Returns False by default if no context is set for the current request. + The context is automatically isolated per request. + """ + return getattr(_db_routing_context, "use_read_replica", False) + + +def clear_read_replica_context() -> None: + """ + Clear the read replica context for the current request. + This function should be called at the end of each request to ensure + that context doesn't leak between requests. Typically called from + middleware during request cleanup. + This is important for: + - Preventing context leakage between requests + - Ensuring clean state for each new request + - Proper memory management in long-running processes + """ + try: + delattr(_db_routing_context, "use_read_replica") + except AttributeError: + pass