[WEB-4533] feat: read replica functionality (#7453)

* feat: read replica functionality

* fix: set use_read_replica to false

* chore: add use_read_replica to external APIs

* chore: remove use_read_replica on read endpoints

* chore: remove md files

* Updated all the necessary endpoints to use read replica

---------

Co-authored-by: Dheeraj Kumar Ketireddy <dheeru0198@gmail.com>
This commit is contained in:
Sangeetha 2025-07-28 17:41:02 +05:30 committed by GitHub
parent b1162395ed
commit 84879ee3bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 884 additions and 6 deletions

View file

@ -405,6 +405,8 @@ class UserServerAssetEndpoint(BaseAPIView):
class GenericAssetEndpoint(BaseAPIView):
"""This endpoint is used to upload generic assets that can be later bound to entities."""
use_read_replica = True
@asset_docs(
operation_id="get_generic_asset",
summary="Get presigned URL for asset download",

View file

@ -20,6 +20,7 @@ from plane.api.middleware.api_authentication import APIKeyAuthentication
from plane.api.rate_limit import ApiKeyRateThrottle, ServiceTokenRateThrottle
from plane.utils.exception_logger import log_exception
from plane.utils.paginator import BasePaginator
from plane.utils.core.mixins import ReadReplicaControlMixin
class TimezoneMixin:
@ -36,11 +37,15 @@ class TimezoneMixin:
timezone.deactivate()
class BaseAPIView(TimezoneMixin, GenericAPIView, BasePaginator):
class BaseAPIView(
TimezoneMixin, GenericAPIView, ReadReplicaControlMixin, BasePaginator
):
authentication_classes = [APIKeyAuthentication]
permission_classes = [IsAuthenticated]
use_read_replica = False
def filter_queryset(self, queryset):
for backend in list(self.filter_backends):
queryset = backend().filter_queryset(self.request, queryset, self)

View file

@ -86,6 +86,7 @@ class CycleListCreateAPIEndpoint(BaseAPIView):
model = Cycle
webhook_event = "cycle"
permission_classes = [ProjectEntityPermission]
use_read_replica = True
def get_queryset(self):
return (
@ -373,6 +374,7 @@ class CycleDetailAPIEndpoint(BaseAPIView):
model = Cycle
webhook_event = "cycle"
permission_classes = [ProjectEntityPermission]
use_read_replica = True
def get_queryset(self):
return (
@ -633,6 +635,7 @@ class CycleArchiveUnarchiveAPIEndpoint(BaseAPIView):
"""Cycle Archive and Unarchive Endpoint"""
permission_classes = [ProjectEntityPermission]
use_read_replica = True
def get_queryset(self):
return (
@ -831,6 +834,7 @@ class CycleIssueListCreateAPIEndpoint(BaseAPIView):
model = CycleIssue
webhook_event = "cycle_issue"
permission_classes = [ProjectEntityPermission]
use_read_replica = True
def get_queryset(self):
return (
@ -1045,6 +1049,7 @@ class CycleIssueDetailAPIEndpoint(BaseAPIView):
webhook_event = "cycle_issue"
bulk = True
permission_classes = [ProjectEntityPermission]
use_read_replica = True
def get_queryset(self):
return (

View file

@ -51,8 +51,10 @@ class IntakeIssueListCreateAPIEndpoint(BaseAPIView):
"""Intake Work Item List and Create Endpoint"""
serializer_class = IntakeIssueSerializer
model = Intake
permission_classes = [ProjectLitePermission]
use_read_replica = True
def get_queryset(self):
intake = Intake.objects.filter(
@ -214,6 +216,7 @@ class IntakeIssueDetailAPIEndpoint(BaseAPIView):
serializer_class = IntakeIssueSerializer
model = IntakeIssue
use_read_replica = True
filterset_fields = ["status"]

View file

@ -156,6 +156,7 @@ class WorkspaceIssueAPIEndpoint(BaseAPIView):
webhook_event = "issue"
permission_classes = [ProjectEntityPermission]
serializer_class = IssueSerializer
use_read_replica = True
@property
def project_identifier(self):
@ -231,6 +232,7 @@ class IssueListCreateAPIEndpoint(BaseAPIView):
webhook_event = "issue"
permission_classes = [ProjectEntityPermission]
serializer_class = IssueSerializer
use_read_replica = True
def get_queryset(self):
return (
@ -495,6 +497,7 @@ class IssueDetailAPIEndpoint(BaseAPIView):
webhook_event = "issue"
permission_classes = [ProjectEntityPermission]
serializer_class = IssueSerializer
use_read_replica = True
def get_queryset(self):
return (
@ -822,6 +825,7 @@ class LabelListCreateAPIEndpoint(BaseAPIView):
serializer_class = LabelSerializer
model = Label
permission_classes = [ProjectMemberPermission]
use_read_replica = True
def get_queryset(self):
return (
@ -949,6 +953,7 @@ class LabelDetailAPIEndpoint(BaseAPIView):
serializer_class = LabelSerializer
model = Label
permission_classes = [ProjectMemberPermission]
use_read_replica = True
@label_docs(
operation_id="get_labels",
@ -1057,6 +1062,7 @@ class IssueLinkListCreateAPIEndpoint(BaseAPIView):
serializer_class = IssueLinkSerializer
model = IssueLink
permission_classes = [ProjectEntityPermission]
use_read_replica = True
def get_queryset(self):
return (
@ -1163,6 +1169,7 @@ class IssueLinkDetailAPIEndpoint(BaseAPIView):
model = IssueLink
serializer_class = IssueLinkSerializer
use_read_replica = True
def get_queryset(self):
return (
@ -1319,6 +1326,7 @@ class IssueCommentListCreateAPIEndpoint(BaseAPIView):
model = IssueComment
webhook_event = "issue_comment"
permission_classes = [ProjectLitePermission]
use_read_replica = True
def get_queryset(self):
return (
@ -1477,6 +1485,7 @@ class IssueCommentDetailAPIEndpoint(BaseAPIView):
model = IssueComment
webhook_event = "issue_comment"
permission_classes = [ProjectLitePermission]
use_read_replica = True
def get_queryset(self):
return (
@ -1657,6 +1666,7 @@ class IssueCommentDetailAPIEndpoint(BaseAPIView):
class IssueActivityListAPIEndpoint(BaseAPIView):
permission_classes = [ProjectEntityPermission]
use_read_replica = True
@issue_activity_docs(
operation_id="list_work_item_activities",
@ -1712,6 +1722,7 @@ class IssueActivityDetailAPIEndpoint(BaseAPIView):
"""Issue Activity Detail Endpoint"""
permission_classes = [ProjectEntityPermission]
use_read_replica = True
@issue_activity_docs(
operation_id="retrieve_work_item_activity",
@ -1770,6 +1781,7 @@ class IssueAttachmentListCreateAPIEndpoint(BaseAPIView):
serializer_class = IssueAttachmentSerializer
model = FileAsset
permission_classes = [ProjectEntityPermission]
use_read_replica = True
@issue_attachment_docs(
operation_id="create_work_item_attachment",
@ -1977,6 +1989,7 @@ class IssueAttachmentDetailAPIEndpoint(BaseAPIView):
serializer_class = IssueAttachmentSerializer
permission_classes = [ProjectEntityPermission]
model = FileAsset
use_read_replica = True
@issue_attachment_docs(
operation_id="delete_work_item_attachment",
@ -2146,6 +2159,8 @@ class IssueAttachmentDetailAPIEndpoint(BaseAPIView):
class IssueSearchEndpoint(BaseAPIView):
"""Endpoint to search across multiple fields in the issues"""
use_read_replica = True
@extend_schema(
operation_id="search_work_items",
tags=["Work Items"],

View file

@ -24,9 +24,8 @@ from plane.utils.openapi import (
class WorkspaceMemberAPIEndpoint(BaseAPIView):
permission_classes = [
WorkSpaceAdminPermission,
]
permission_classes = [WorkSpaceAdminPermission]
use_read_replica = True
@extend_schema(
operation_id="get_workspace_members",
@ -92,6 +91,7 @@ class WorkspaceMemberAPIEndpoint(BaseAPIView):
# API endpoint to get and insert users inside the workspace
class ProjectMemberAPIEndpoint(BaseAPIView):
permission_classes = [ProjectMemberPermission]
use_read_replica = True
@extend_schema(
operation_id="get_project_members",

View file

@ -80,6 +80,7 @@ class ModuleListCreateAPIEndpoint(BaseAPIView):
model = Module
webhook_event = "module"
permission_classes = [ProjectEntityPermission]
use_read_replica = True
def get_queryset(self):
return (
@ -282,6 +283,7 @@ class ModuleDetailAPIEndpoint(BaseAPIView):
permission_classes = [ProjectEntityPermission]
serializer_class = ModuleSerializer
webhook_event = "module"
use_read_replica = True
def get_queryset(self):
return (
@ -550,6 +552,7 @@ class ModuleIssueListCreateAPIEndpoint(BaseAPIView):
model = ModuleIssue
webhook_event = "module_issue"
permission_classes = [ProjectEntityPermission]
use_read_replica = True
def get_queryset(self):
return (
@ -769,6 +772,7 @@ class ModuleIssueDetailAPIEndpoint(BaseAPIView):
model = ModuleIssue
webhook_event = "module_issue"
bulk = True
use_read_replica = True
permission_classes = [ProjectEntityPermission]
@ -916,6 +920,7 @@ class ModuleIssueDetailAPIEndpoint(BaseAPIView):
class ModuleArchiveUnarchiveAPIEndpoint(BaseAPIView):
permission_classes = [ProjectEntityPermission]
use_read_replica = True
def get_queryset(self):
return (

View file

@ -67,6 +67,7 @@ class ProjectListCreateAPIEndpoint(BaseAPIView):
model = Project
webhook_event = "project"
permission_classes = [ProjectBasePermission]
use_read_replica = True
def get_queryset(self):
return (
@ -331,6 +332,7 @@ class ProjectDetailAPIEndpoint(BaseAPIView):
webhook_event = "project"
permission_classes = [ProjectBasePermission]
use_read_replica = True
def get_queryset(self):
return (

View file

@ -38,6 +38,7 @@ class StateListCreateAPIEndpoint(BaseAPIView):
serializer_class = StateSerializer
model = State
permission_classes = [ProjectEntityPermission]
use_read_replica = True
def get_queryset(self):
return (
@ -164,6 +165,7 @@ class StateDetailAPIEndpoint(BaseAPIView):
serializer_class = StateSerializer
model = State
permission_classes = [ProjectEntityPermission]
use_read_replica = True
def get_queryset(self):
return (

View file

@ -24,6 +24,7 @@ from rest_framework.viewsets import ModelViewSet
from plane.authentication.session import BaseSessionAuthentication
from plane.utils.exception_logger import log_exception
from plane.utils.paginator import BasePaginator
from plane.utils.core.mixins import ReadReplicaControlMixin
class TimezoneMixin:
@ -40,7 +41,7 @@ class TimezoneMixin:
timezone.deactivate()
class BaseViewSet(TimezoneMixin, ModelViewSet, BasePaginator):
class BaseViewSet(TimezoneMixin, ReadReplicaControlMixin, ModelViewSet, BasePaginator):
model = None
permission_classes = [IsAuthenticated]
@ -53,6 +54,8 @@ class BaseViewSet(TimezoneMixin, ModelViewSet, BasePaginator):
search_fields = []
use_read_replica = False
def get_queryset(self):
try:
return self.model.objects.all()
@ -149,7 +152,7 @@ class BaseViewSet(TimezoneMixin, ModelViewSet, BasePaginator):
return expand if expand else None
class BaseAPIView(TimezoneMixin, APIView, BasePaginator):
class BaseAPIView(TimezoneMixin, ReadReplicaControlMixin, APIView, BasePaginator):
permission_classes = [IsAuthenticated]
filter_backends = (DjangoFilterBackend, SearchFilter)
@ -160,6 +163,8 @@ class BaseAPIView(TimezoneMixin, APIView, BasePaginator):
search_fields = []
use_read_replica = False
def filter_queryset(self, queryset):
for backend in list(self.filter_backends):
queryset = backend().filter_queryset(self.request, queryset, self)

View file

@ -19,6 +19,7 @@ from plane.db.models import IssueActivity, IssueComment, CommentReaction, Intake
class IssueActivityEndpoint(BaseAPIView):
permission_classes = [ProjectEntityPermission]
use_read_replica = True
@method_decorator(gzip_page)
@allow_permission([ROLE.ADMIN, ROLE.MEMBER, ROLE.GUEST])

View file

@ -232,6 +232,8 @@ class NotificationViewSet(BaseViewSet, BasePaginator):
class UnreadNotificationEndpoint(BaseAPIView):
use_read_replica = True
@allow_permission(
allowed_roles=[ROLE.ADMIN, ROLE.MEMBER, ROLE.GUEST], level="WORKSPACE"
)

View file

@ -46,6 +46,7 @@ class ProjectViewSet(BaseViewSet):
serializer_class = ProjectListSerializer
model = Project
webhook_event = "project"
use_read_replica = True
def get_queryset(self):
sort_order = ProjectMember.objects.filter(

View file

@ -312,6 +312,7 @@ class ProjectMemberUserEndpoint(BaseAPIView):
class UserProjectRolesEndpoint(BaseAPIView):
permission_classes = [WorkspaceUserPermission]
use_read_replica = True
def get(self, request, slug):
project_members = ProjectMember.objects.filter(

View file

@ -44,6 +44,7 @@ from django.views.decorators.vary import vary_on_cookie
class UserEndpoint(BaseViewSet):
serializer_class = UserSerializer
model = User
use_read_replica = True
def get_object(self):
return self.request.user

View file

@ -177,6 +177,7 @@ class WorkSpaceViewSet(BaseViewSet):
class UserWorkSpacesEndpoint(BaseAPIView):
search_fields = ["name"]
filterset_fields = ["owner"]
use_read_replica = True
def get(self, request):
fields = [field for field in request.GET.get("fields", "").split(",") if field]

View file

@ -12,6 +12,7 @@ from plane.utils.cache import cache_response
class WorkspaceEstimatesEndpoint(BaseAPIView):
permission_classes = [WorkspaceEntityPermission]
use_read_replica = True
@cache_response(60 * 60 * 2)
def get(self, request, slug):

View file

@ -14,6 +14,8 @@ from plane.app.permissions import allow_permission, ROLE
class WorkspaceFavoriteEndpoint(BaseAPIView):
use_read_replica = True
@allow_permission(allowed_roles=[ROLE.ADMIN, ROLE.MEMBER], level="WORKSPACE")
def get(self, request, slug):
# the second filter is to check if the user is a member of the project

View file

@ -12,6 +12,7 @@ from plane.utils.cache import cache_response
class WorkspaceLabelsEndpoint(BaseAPIView):
permission_classes = [WorkspaceViewerPermission]
use_read_replica = True
@cache_response(60 * 60 * 2)
def get(self, request, slug):

View file

@ -28,6 +28,7 @@ class WorkSpaceMemberViewSet(BaseViewSet):
model = WorkspaceMember
search_fields = ["member__display_name", "member__first_name"]
use_read_replica = True
def get_queryset(self):
return self.filter_queryset(
@ -214,6 +215,8 @@ class WorkspaceMemberUserViewsEndpoint(BaseAPIView):
class WorkspaceMemberUserEndpoint(BaseAPIView):
use_read_replica = True
def get(self, request, slug):
draft_issue_count = (
DraftIssue.objects.filter(

View file

@ -11,6 +11,7 @@ from plane.app.permissions import allow_permission, ROLE
class QuickLinkViewSet(BaseViewSet):
model = WorkspaceUserLink
use_read_replica = True
def get_serializer_class(self):
return WorkspaceUserLinkSerializer

View file

@ -12,6 +12,7 @@ from plane.app.permissions import allow_permission, ROLE
class UserRecentVisitViewSet(BaseViewSet):
model = UserRecentVisit
use_read_replica = True
def get_serializer_class(self):
return WorkspaceRecentVisitSerializer

View file

@ -13,6 +13,7 @@ from collections import defaultdict
class WorkspaceStatesEndpoint(BaseAPIView):
permission_classes = [WorkspaceEntityPermission]
use_read_replica = True
@cache_response(60 * 60 * 2)
def get(self, request, slug):

View file

@ -12,6 +12,7 @@ from plane.app.serializers import StickySerializer
class WorkspaceStickyViewSet(BaseViewSet):
serializer_class = StickySerializer
model = Sticky
use_read_replica = True
def get_queryset(self):
return self.filter_queryset(

View file

@ -13,6 +13,7 @@ from rest_framework import status
class WorkspaceUserPreferenceViewSet(BaseAPIView):
model = WorkspaceUserPreference
use_read_replica = True
def get_serializer_class(self):
return WorkspaceUserPreferenceSerializer

View file

@ -0,0 +1,162 @@
"""
Database routing middleware for read replica selection.
This middleware determines whether database queries should be routed to
read replicas or the primary database based on HTTP method and view configuration.
"""
import logging
from typing import Callable, Optional
from django.http import HttpRequest, HttpResponse
from plane.utils.core import (
set_use_read_replica,
clear_read_replica_context,
)
logger = logging.getLogger("plane.api")
class ReadReplicaRoutingMiddleware:
"""
Middleware for intelligent database routing to read replicas.
Routing Logic:
Non-GET requests (POST, PUT, DELETE, PATCH) Primary database
GET requests:
- View has use_read_replica=False Primary database
- View has use_read_replica=True Read replica
- View has no use_read_replica attribute Primary database (safe default)
The middleware supports both Django CBVs and DRF APIViews/ViewSets.
Context is properly isolated per request to prevent data leakage.
"""
# HTTP methods that are considered read-only by default
READ_ONLY_METHODS = {"GET", "HEAD", "OPTIONS"}
def __init__(self, get_response):
"""
Initialize the middleware with the next middleware/view in the chain.
Args:
get_response: The next middleware or view function
"""
self.get_response = get_response
def __call__(self, request: HttpRequest) -> HttpResponse:
"""
Process the request and determine database routing.
Args:
request: The HTTP request object
Returns:
HttpResponse: The HTTP response from the view
"""
# For non-read operations, set primary database immediately
if request.method not in self.READ_ONLY_METHODS:
set_use_read_replica(False)
logger.debug(f"Routing {request.method} {request.path} to primary database")
try:
# Process the request through the middleware chain
response = self.get_response(request)
return response
finally:
# Always clean up context, even if an exception occurs
# This prevents context leakage between requests
clear_read_replica_context()
def process_view(
self,
request: HttpRequest,
view_func: Callable,
view_args: tuple,
view_kwargs: dict,
) -> None:
"""
Hook called just before Django calls the view.
This is more efficient than resolving URLs in __call__ since Django
provides the view function directly.
Args:
request: The HTTP request object
view_func: The view function to be called
view_args: Positional arguments for the view
view_kwargs: Keyword arguments for the view
"""
# Only process read operations (write operations already handled in __call__)
if request.method in self.READ_ONLY_METHODS:
use_replica = self._should_use_read_replica(view_func)
set_use_read_replica(use_replica)
db_type = "read replica" if use_replica else "primary database"
logger.debug(f"Routing {request.method} {request.path} to {db_type}")
# Return None to continue normal request processing
return None
def _should_use_read_replica(self, view_func: Callable) -> bool:
"""
Determine if the view should use read replica based on its configuration.
Args:
view_func: The view function to inspect
Returns:
bool: True if should use read replica, False for primary database
"""
use_replica_attr = self._get_use_replica_attribute(view_func)
# Default to primary database for GET requests if no explicit setting
# This ensures only views that explicitly opt-in use read replicas
if use_replica_attr is None:
return False
return bool(use_replica_attr)
def _get_use_replica_attribute(self, view_func: Callable) -> Optional[bool]:
"""
Extract the use_read_replica attribute from various view types.
Args:
view_func: The view function to inspect
Returns:
Optional[bool]: The use_read_replica setting, or None if not found
"""
# Return None if view_func is None to prevent AttributeError
if view_func is None:
return None
# Check function-based view attribute
use_replica = getattr(view_func, "use_read_replica", None)
if use_replica is not None:
return use_replica
# Check Django CBV wrapper
if hasattr(view_func, "view_class"):
use_replica = getattr(view_func.view_class, "use_read_replica", None)
if use_replica is not None:
return use_replica
# Check DRF wrapper (APIView / ViewSet)
if hasattr(view_func, "cls"):
use_replica = getattr(view_func.cls, "use_read_replica", None)
if use_replica is not None:
return use_replica
return None
def process_exception(self, request: HttpRequest, exception: Exception) -> None:
"""
Handle exceptions that occur during view processing.
This provides an additional safety net for context cleanup when views
raise exceptions, complementing the try/finally in __call__.
Args:
request: The HTTP request object
exception: The exception that was raised
Returns:
None: Don't handle the exception, just clean up context
"""
# Clean up context on exception as a safety measure
# The try/finally in __call__ should handle most cases, but this
# provides extra protection specifically for view exceptions
clear_read_replica_context()
logger.debug(
f"Cleaned up read replica context due to exception: {type(exception).__name__}"
)
# Return None to let the exception continue propagating
return None

View file

@ -149,6 +149,29 @@ else:
}
}
if os.environ.get("ENABLE_READ_REPLICA", "0") == "1":
if bool(os.environ.get("DATABASE_READ_REPLICA_URL")):
# Parse database configuration from $DATABASE_URL
DATABASES["replica"] = dj_database_url.parse(
os.environ.get("DATABASE_READ_REPLICA_URL")
)
else:
DATABASES["replica"] = {
"ENGINE": "django.db.backends.postgresql",
"NAME": os.environ.get("POSTGRES_READ_REPLICA_DB"),
"USER": os.environ.get("POSTGRES_READ_REPLICA_USER"),
"PASSWORD": os.environ.get("POSTGRES_READ_REPLICA_PASSWORD"),
"HOST": os.environ.get("POSTGRES_READ_REPLICA_HOST"),
"PORT": os.environ.get("POSTGRES_READ_REPLICA_PORT", "5432"),
}
# Database Routers
DATABASE_ROUTERS = ["plane.utils.core.dbrouters.ReadReplicaRouter"]
# Add middleware at the end for read replica routing
MIDDLEWARE.append("plane.middleware.db_routing.ReadReplicaRoutingMiddleware")
# Redis Config
REDIS_URL = os.environ.get("REDIS_URL")
REDIS_SSL = REDIS_URL and "rediss" in REDIS_URL

View file

@ -0,0 +1,433 @@
"""
Unit tests for ReadReplicaRoutingMiddleware.
This module contains comprehensive tests for the ReadReplicaRoutingMiddleware
that handles intelligent database routing to read replicas based on HTTP methods
and view configuration.
Test Organization:
- TestReadReplicaRoutingMiddleware: Core middleware functionality
- TestProcessView: process_view method behavior
- TestReplicaDecisionLogic: Decision logic for replica usage
- TestAttributeDetection: View attribute detection methods
- TestExceptionHandling: Exception handling and cleanup
- TestRealViewIntegration: Real Django/DRF view integration
- TestEdgeCases: Edge cases and error conditions
"""
import pytest
from unittest.mock import Mock, patch
from django.http import HttpResponse
from django.test import RequestFactory
from django.views import View
from rest_framework.views import APIView
from rest_framework.viewsets import ViewSet
from plane.middleware.db_routing import ReadReplicaRoutingMiddleware
# Pytest fixtures
@pytest.fixture
def mock_get_response():
"""Fixture for mocked get_response callable."""
return Mock(return_value=HttpResponse())
@pytest.fixture
def middleware(mock_get_response):
"""Fixture for ReadReplicaRoutingMiddleware instance."""
return ReadReplicaRoutingMiddleware(mock_get_response)
@pytest.fixture
def request_factory():
"""Fixture for Django RequestFactory."""
return RequestFactory()
@pytest.fixture
def mock_view_func():
"""Fixture for a basic mocked view function."""
view = Mock()
view.use_read_replica = True
return view
@pytest.fixture
def get_request(request_factory):
"""Fixture for a GET request."""
return request_factory.get("/api/test/")
@pytest.fixture
def post_request(request_factory):
"""Fixture for a POST request."""
return request_factory.post("/api/test/")
@pytest.mark.unit
class TestReadReplicaRoutingMiddleware:
"""Test cases for ReadReplicaRoutingMiddleware core functionality."""
def test_middleware_initialization(self, middleware, mock_get_response):
"""Test middleware initializes correctly with expected attributes."""
assert middleware.get_response == mock_get_response
assert hasattr(middleware, "READ_ONLY_METHODS")
assert "GET" in middleware.READ_ONLY_METHODS
assert "HEAD" in middleware.READ_ONLY_METHODS
assert "OPTIONS" in middleware.READ_ONLY_METHODS
def test_read_only_methods_constant(self, middleware):
"""Test READ_ONLY_METHODS contains expected HTTP methods."""
expected_methods = {"GET", "HEAD", "OPTIONS"}
assert middleware.READ_ONLY_METHODS == expected_methods
@patch("plane.middleware.db_routing.set_use_read_replica")
@patch("plane.middleware.db_routing.clear_read_replica_context")
def test_call_routes_write_methods_to_primary(
self, mock_clear, mock_set, middleware, post_request, mock_get_response
):
"""Test __call__ routes write methods to primary database."""
response = middleware(post_request)
mock_set.assert_called_once_with(False) # Primary database
mock_clear.assert_called_once()
assert response == mock_get_response.return_value
@patch("plane.middleware.db_routing.clear_read_replica_context")
def test_call_with_read_methods_waits_for_process_view(
self, mock_clear, middleware, get_request, mock_get_response
):
"""Test __call__ with read methods waits for process_view."""
response = middleware(get_request)
mock_clear.assert_called_once()
assert response == mock_get_response.return_value
@patch("plane.middleware.db_routing.clear_read_replica_context")
def test_call_always_cleans_up_context(self, mock_clear, middleware, get_request):
"""Test __call__ always cleans up context."""
middleware(get_request)
mock_clear.assert_called_once()
@patch("plane.middleware.db_routing.clear_read_replica_context")
def test_call_cleans_up_context_on_exception(
self, mock_clear, middleware, get_request, mock_get_response
):
"""Test __call__ cleans up context even if get_response raises."""
mock_get_response.side_effect = Exception("Test exception")
with pytest.raises(Exception, match="Test exception"):
middleware(get_request)
mock_clear.assert_called_once()
@pytest.mark.unit
class TestProcessView:
"""Test cases for process_view method functionality."""
@patch("plane.middleware.db_routing.set_use_read_replica")
def test_with_read_method_and_replica_true(self, mock_set, middleware, get_request):
"""Test process_view with GET request and use_read_replica=True."""
view_func = Mock()
view_func.use_read_replica = True
result = middleware.process_view(get_request, view_func, (), {})
mock_set.assert_called_once_with(True)
assert result is None
@patch("plane.middleware.db_routing.set_use_read_replica")
def test_with_read_method_and_replica_false(
self, mock_set, middleware, get_request
):
"""Test process_view with GET request and use_read_replica=False."""
view_func = Mock()
view_func.use_read_replica = False
result = middleware.process_view(get_request, view_func, (), {})
mock_set.assert_called_once_with(False)
assert result is None
@patch("plane.middleware.db_routing.set_use_read_replica")
def test_with_read_method_and_no_replica_attribute(
self, mock_set, middleware, get_request
):
"""Test process_view with GET request and no use_read_replica attr."""
view_func = Mock(spec=[]) # No use_read_replica attribute
result = middleware.process_view(get_request, view_func, (), {})
mock_set.assert_called_once_with(False) # Default to primary
assert result is None
def test_with_write_method_ignores_view_attributes(self, middleware, post_request):
"""Test process_view with write methods ignores view attributes."""
view_func = Mock()
view_func.use_read_replica = True # This should be ignored for POST
result = middleware.process_view(post_request, view_func, (), {})
assert result is None # Should not process for write methods
@pytest.mark.unit
class TestReplicaDecisionLogic:
"""Test cases for replica decision logic methods."""
def test_should_use_read_replica_with_true_attribute(self, middleware):
"""Test _should_use_read_replica returns True for True attribute."""
view_func = Mock()
view_func.use_read_replica = True
result = middleware._should_use_read_replica(view_func)
assert result is True
def test_should_use_read_replica_with_false_attribute(self, middleware):
"""Test _should_use_read_replica returns False for False attribute."""
view_func = Mock()
view_func.use_read_replica = False
result = middleware._should_use_read_replica(view_func)
assert result is False
def test_should_use_read_replica_with_no_attribute_defaults_false(self, middleware):
"""Test _should_use_read_replica defaults to False for missing attr."""
view_func = Mock(spec=[]) # No use_read_replica attribute
result = middleware._should_use_read_replica(view_func)
assert result is False
@pytest.mark.unit
class TestAttributeDetection:
"""Test cases for view attribute detection methods."""
def test_get_use_replica_attribute_function_based_view(self, middleware):
"""Test _get_use_replica_attribute with function-based view."""
# Test with True
view_func = Mock()
view_func.use_read_replica = True
result = middleware._get_use_replica_attribute(view_func)
assert result is True
# Test with False
view_func.use_read_replica = False
result = middleware._get_use_replica_attribute(view_func)
assert result is False
# Test with no attribute
view_func = Mock(spec=[])
result = middleware._get_use_replica_attribute(view_func)
assert result is None
def test_get_use_replica_attribute_django_cbv(self, middleware):
"""Test _get_use_replica_attribute with Django CBV wrapper."""
view_class = Mock()
view_class.use_read_replica = True
view_func = Mock()
view_func.view_class = view_class
# Remove use_read_replica from view_func to ensure it checks view_class
del view_func.use_read_replica
result = middleware._get_use_replica_attribute(view_func)
assert result is True
def test_get_use_replica_attribute_drf_wrapper(self, middleware):
"""Test _get_use_replica_attribute with DRF wrapper."""
# Create a real object to avoid Mock issues
class ViewClass:
use_read_replica = True
class ViewFunc:
cls = ViewClass()
view_func = ViewFunc()
result = middleware._get_use_replica_attribute(view_func)
assert result is True
def test_get_use_replica_attribute_priority_order(self, middleware):
"""Test attribute priority: direct > view_class > cls."""
view_func = Mock()
view_func.use_read_replica = True # Direct attribute (highest priority)
# Add conflicting attributes with lower priority
view_class = Mock()
view_class.use_read_replica = False
view_func.view_class = view_class
cls = Mock()
cls.use_read_replica = False
view_func.cls = cls
result = middleware._get_use_replica_attribute(view_func)
assert result is True # Should use direct attribute
@pytest.mark.parametrize(
"value,expected",
[
(True, True),
(False, False),
(1, True),
(0, False),
("yes", True),
("", False),
([], False),
([1], True),
(None, False),
],
)
def test_should_use_read_replica_truthy_falsy_values(
self, middleware, value, expected
):
"""Test _should_use_read_replica with various truthy/falsy values."""
# Create a real object to test the attribute handling
class TestView:
pass
view_func = TestView()
view_func.use_read_replica = value
result = middleware._should_use_read_replica(view_func)
assert result == expected
@pytest.mark.unit
class TestExceptionHandling:
"""Test cases for exception handling and cleanup."""
@patch("plane.middleware.db_routing.clear_read_replica_context")
def test_process_exception_cleans_up_context(
self, mock_clear, middleware, request_factory
):
"""Test process_exception cleans up context."""
request = request_factory.get("/api/test/")
exception = Exception("Test exception")
result = middleware.process_exception(request, exception)
mock_clear.assert_called_once()
assert result is None # Don't handle the exception
@patch("plane.middleware.db_routing.set_use_read_replica")
@patch("plane.middleware.db_routing.clear_read_replica_context")
def test_integration_full_request_cycle(
self, mock_clear, mock_set, middleware, request_factory, mock_get_response
):
"""Test complete request cycle from __call__ through process_view."""
request = request_factory.get("/api/test/")
view_func = Mock()
view_func.use_read_replica = True
# Call middleware and process_view manually
response = middleware(request)
middleware.process_view(request, view_func, (), {})
mock_set.assert_called_once_with(True)
mock_clear.assert_called_once()
assert response == mock_get_response.return_value
@pytest.mark.unit
class TestRealViewIntegration:
"""Test middleware with real Django/DRF view classes."""
@patch("plane.middleware.db_routing.set_use_read_replica")
def test_with_django_class_based_view(self, mock_set, middleware, request_factory):
"""Test middleware with actual Django CBV."""
class TestView(View):
use_read_replica = True
# Simulate Django's URL resolver creating a view wrapper
view_func = TestView.as_view()
request = request_factory.get("/api/test/")
middleware.process_view(request, view_func, (), {})
mock_set.assert_called_once_with(True)
@patch("plane.middleware.db_routing.set_use_read_replica")
def test_with_drf_api_view(self, mock_set, middleware, request_factory):
"""Test middleware with DRF APIView."""
class TestAPIView(APIView):
use_read_replica = True
# Simulate DRF's URL pattern creating a view wrapper
view_func = TestAPIView.as_view()
request = request_factory.get("/api/test/")
middleware.process_view(request, view_func, (), {})
mock_set.assert_called_once_with(True)
@patch("plane.middleware.db_routing.set_use_read_replica")
def test_with_drf_viewset(self, mock_set, middleware, request_factory):
"""Test middleware with DRF ViewSet."""
class TestViewSet(ViewSet):
use_read_replica = True
# Simulate DRF router creating viewset action
view_func = TestViewSet.as_view({"get": "list"})
request = request_factory.get("/api/test/")
middleware.process_view(request, view_func, (), {})
mock_set.assert_called_once_with(True)
@pytest.mark.unit
class TestEdgeCases:
"""Test edge cases and error conditions."""
def test_process_view_with_none_view_func(self, middleware, request_factory):
"""Test process_view handles None view_func gracefully."""
request = request_factory.get("/api/test/")
result = middleware.process_view(request, None, (), {})
assert result is None # Should not crash
def test_get_use_replica_attribute_with_attribute_error(self, middleware):
"""Test _get_use_replica_attribute with view that raises AttributeError."""
# Create a view class that raises AttributeError on access
class ProblematicView:
def __getattr__(self, name):
if name == "use_read_replica":
raise AttributeError("Simulated attribute error")
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
view_func = ProblematicView()
result = middleware._get_use_replica_attribute(view_func)
assert result is None # Should handle gracefully
def test_multiple_exception_calls_are_safe(self, middleware, request_factory):
"""Test that multiple calls to process_exception don't cause issues."""
request = request_factory.get("/api/test/")
exception = Exception("Test exception")
# Call multiple times
result1 = middleware.process_exception(request, exception)
result2 = middleware.process_exception(request, exception)
assert result1 is None # Both should return None safely
assert result2 is None

View file

@ -0,0 +1,21 @@
"""
Core utilities for Plane database routing and request scoping.
This package contains essential components for managing read replica routing
and request-scoped context in the Plane application.
"""
from .dbrouters import ReadReplicaRouter
from .mixins import ReadReplicaControlMixin
from .request_scope import (
set_use_read_replica,
should_use_read_replica,
clear_read_replica_context,
)
__all__ = [
"ReadReplicaRouter",
"ReadReplicaControlMixin",
"set_use_read_replica",
"should_use_read_replica",
"clear_read_replica_context",
]

View file

@ -0,0 +1,73 @@
"""
Database router for read replica selection.
This router determines which database to use for read/write operations
based on the request context set by the ReadReplicaRoutingMiddleware.
"""
import logging
from typing import Type
from django.db import models
from .request_scope import should_use_read_replica
logger = logging.getLogger("plane.db")
class ReadReplicaRouter:
"""
Database router that directs read operations to replica when appropriate.
This router works in conjunction with ReadReplicaRoutingMiddleware to:
- Route read operations to replica database when request context allows
- Always route write operations to primary database
- Ensure migrations only run on primary database
"""
def db_for_read(self, model: Type[models.Model], **hints) -> str:
"""
Determine which database to use for read operations.
Args:
model: The Django model class being queried
**hints: Additional routing hints
Returns:
str: Database alias ('replica' or 'default')
"""
if should_use_read_replica():
logger.debug(f"Routing read for {model._meta.label} to replica database")
return "replica"
else:
logger.debug(f"Routing read for {model._meta.label} to primary database")
return "default"
def db_for_write(self, model: Type[models.Model], **hints) -> str:
"""
Determine which database to use for write operations.
All write operations always go to the primary database to ensure
data consistency and avoid replication lag issues.
Args:
model: The Django model class being written to
**hints: Additional routing hints
Returns:
str: Always returns 'default' (primary database)
"""
logger.debug(f"Routing write for {model._meta.label} to primary database")
return "default"
def allow_migrate(
self, db: str, app_label: str, model_name: str = None, **hints
) -> bool:
"""
Ensure migrations only run on the primary database.
Args:
db: Database alias
app_label: Application label
model_name: Model name (optional)
**hints: Additional routing hints
Returns:
bool: True if migration is allowed on this database
"""
# Only allow migrations on the primary database
allowed = db == "default"
if not allowed:
logger.debug(f"Blocking migration for {app_label} on {db} database")
return allowed

View file

@ -0,0 +1,11 @@
"""
Core mixins for read replica functionality.
This package provides mixins for different aspects of read replica management
in Django and Django REST Framework applications.
"""
from .view import ReadReplicaControlMixin
__all__ = [
"ReadReplicaControlMixin",
]

View file

@ -0,0 +1,20 @@
"""
Mixins for Django REST Framework views.
"""
class ReadReplicaControlMixin:
"""
Mixin to control read replica usage in DRF views.
Set use_read_replica = True/False to route read operations to
replica/primary database. Works with ReadReplicaRoutingMiddleware.
Usage:
class MyViewSet(ReadReplicaControlMixin, ModelViewSet):
use_read_replica = True # Use replica for GET requests
Note:
- Only affects GET, HEAD, OPTIONS requests
- Write operations always use primary database
- Defaults to True for safe replica usage
"""
use_read_replica: bool = True

View file

@ -0,0 +1,72 @@
"""
Database routing utilities for read replica selection.
This module provides request-scoped context management for database routing,
specifically for determining when to use read replicas vs primary database.
Used in conjunction with middleware and DRF views that set use_read_replica=True.
The context is maintained per request to ensure proper isolation between
concurrent requests in async environments.
"""
from asgiref.local import Local
__all__ = [
"set_use_read_replica",
"should_use_read_replica",
"clear_read_replica_context",
]
# Request-scoped context storage for database routing preferences
# Uses asgiref.local.Local which provides ContextVar under the hood
# This ensures proper context isolation per request in async environments
_db_routing_context = Local()
def set_use_read_replica(use_replica: bool) -> None:
"""
Mark the current request context to use read replica database.
This function sets a request-scoped flag that determines database routing.
The context is isolated per request to ensure thread safety in async environments.
This function is typically called from:
- Middleware that detects read-only operations
- DRF views with use_read_replica=True attribute
- API endpoints that only perform read operations
Args:
use_replica (bool): True to route database queries to read replica,
False to use primary database
Note:
The context is automatically isolated per request and should be
cleared at the end of each request using clear_read_replica_context().
"""
_db_routing_context.use_read_replica = bool(use_replica)
def should_use_read_replica() -> bool:
"""
Check if the current request should use read replica database.
This function reads the request-scoped context to determine database routing.
It's called by the database router to decide which connection to use.
Returns:
bool: True if queries should be routed to read replica,
False if they should use primary database (default)
Note:
Returns False by default if no context is set for the current request.
The context is automatically isolated per request.
"""
return getattr(_db_routing_context, "use_read_replica", False)
def clear_read_replica_context() -> None:
"""
Clear the read replica context for the current request.
This function should be called at the end of each request to ensure
that context doesn't leak between requests. Typically called from
middleware during request cleanup.
This is important for:
- Preventing context leakage between requests
- Ensuring clean state for each new request
- Proper memory management in long-running processes
"""
try:
delattr(_db_routing_context, "use_read_replica")
except AttributeError:
pass