bb-plane-fork/apps/api/plane/app/views/external/base.py

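"""Views for external integrations: LLM-backed text generation endpoints
and a proxy to the Unsplash image API."""
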
# Python imports
import os
from typing import Dict, List, Tuple

# Third party imports
from openai import OpenAI
import requests
from rest_framework import status
from rest_framework.response import Response

# Module imports
from plane.app.permissions import ROLE, allow_permission
from plane.app.serializers import ProjectLiteSerializer, WorkspaceLiteSerializer
from plane.db.models import Project, Workspace
from plane.license.utils.instance_value import get_configuration_value
from plane.utils.exception_logger import log_exception

from ..base import BaseAPIView


class LLMProvider:
    """Base class for LLM provider configurations"""

    name: str = ""
    models: List[str] = []
    default_model: str = ""

    @classmethod
    def get_config(cls) -> Dict[str, str | List[str]]:
        return {
            "name": cls.name,
            "models": cls.models,
            "default_model": cls.default_model,
        }


class OpenAIProvider(LLMProvider):
    name = "OpenAI"
    models = ["gpt-3.5-turbo", "gpt-4o-mini", "gpt-4o", "o1-mini", "o1-preview"]
    default_model = "gpt-4o-mini"


class AnthropicProvider(LLMProvider):
    name = "Anthropic"
    models = [
        "claude-3-5-sonnet-20240620",
        "claude-3-haiku-20240307",
        "claude-3-opus-20240229",
        "claude-3-sonnet-20240229",
        "claude-2.1",
        "claude-2",
        "claude-instant-1.2",
        "claude-instant-1",
    ]
    default_model = "claude-3-sonnet-20240229"


class GeminiProvider(LLMProvider):
    name = "Gemini"
    models = ["gemini-pro", "gemini-1.5-pro-latest", "gemini-pro-vision"]
    default_model = "gemini-pro"


SUPPORTED_PROVIDERS = {
    "openai": OpenAIProvider,
    "anthropic": AnthropicProvider,
    "gemini": GeminiProvider,
}
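
# Illustrative: each provider class advertises its configuration via get_config().
# SUPPORTED_PROVIDERS["anthropic"].get_config() yields
# {"name": "Anthropic", "models": [...], "default_model": "claude-3-sonnet-20240229"}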


def get_llm_config() -> Tuple[str | None, str | None, str | None]:
    """
    Helper to get LLM configuration values.

    Returns a tuple of (api_key, model, provider), or (None, None, None)
    when the configuration is missing or invalid.
    """
    api_key, provider_key, model = get_configuration_value(
        [
            {
                "key": "LLM_API_KEY",
                "default": os.environ.get("LLM_API_KEY", None),
            },
            {
                "key": "LLM_PROVIDER",
                "default": os.environ.get("LLM_PROVIDER", "openai"),
            },
            {
                "key": "LLM_MODEL",
                "default": os.environ.get("LLM_MODEL", None),
            },
        ]
    )

    provider = SUPPORTED_PROVIDERS.get(provider_key.lower())
    if not provider:
        log_exception(ValueError(f"Unsupported provider: {provider_key}"))
        return None, None, None

    if not api_key:
        log_exception(ValueError(f"Missing API key for provider: {provider.name}"))
        return None, None, None

    # If no model is specified, fall back to the provider's default
    if not model:
        model = provider.default_model

    # Validate that the provider supports the requested model
    if model not in provider.models:
        log_exception(
            ValueError(
                f"Model {model} not supported by {provider.name}. "
                f"Supported models: {', '.join(provider.models)}"
            )
        )
        return None, None, None

    return api_key, model, provider_key
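
# Example configuration (values illustrative, not real credentials):
#   LLM_API_KEY=sk-...
#   LLM_PROVIDER=openai       # one of: openai, anthropic, gemini
#   LLM_MODEL=gpt-4o-mini     # optional; falls back to the provider's default_model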


def get_llm_response(
    task: str, prompt: str, api_key: str, model: str, provider: str
) -> Tuple[str | None, str | None]:
    """Helper to get an LLM completion response. Returns (text, error)."""
    final_text = task + "\n" + prompt
    try:
        # For Gemini, prepend the provider name to the model identifier.
        # (The "gemini/<model>" form assumes an OpenAI-compatible gateway
        # such as a LiteLLM proxy; the plain OpenAI API does not accept it.)
        if provider.lower() == "gemini":
            model = f"gemini/{model}"

        client = OpenAI(api_key=api_key)
        chat_completion = client.chat.completions.create(
            model=model, messages=[{"role": "user", "content": final_text}]
        )
        text = chat_completion.choices[0].message.content
        return text, None
    except Exception as e:
        log_exception(e)
        # Map known SDK error types to user-facing messages
        error_type = e.__class__.__name__
        if error_type == "AuthenticationError":
            return None, f"Invalid API key for {provider}"
        elif error_type == "RateLimitError":
            return None, f"Rate limit exceeded for {provider}"
        else:
            return None, f"Error occurred while generating response from {provider}"


class GPTIntegrationEndpoint(BaseAPIView):
    @allow_permission([ROLE.ADMIN, ROLE.MEMBER])
    def post(self, request, slug, project_id):
        api_key, model, provider = get_llm_config()
        if not api_key or not model or not provider:
            return Response(
                {"error": "LLM provider API key and model are required"},
                status=status.HTTP_400_BAD_REQUEST,
            )

        task = request.data.get("task", False)
        if not task:
            return Response(
                {"error": "Task is required"}, status=status.HTTP_400_BAD_REQUEST
            )

        # Default the prompt to an empty string so the string concatenation
        # in get_llm_response cannot fail on a non-string value
        text, error = get_llm_response(
            task, request.data.get("prompt", ""), api_key, model, provider
        )
        if not text and error:
            return Response(
                {"error": "An internal error has occurred."},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        workspace = Workspace.objects.get(slug=slug)
        project = Project.objects.get(pk=project_id)
        return Response(
            {
                "response": text,
                "response_html": text.replace("\n", "<br/>"),
                "project_detail": ProjectLiteSerializer(project).data,
                "workspace_detail": WorkspaceLiteSerializer(workspace).data,
            },
            status=status.HTTP_200_OK,
        )


class WorkspaceGPTIntegrationEndpoint(BaseAPIView):
    @allow_permission(allowed_roles=[ROLE.ADMIN, ROLE.MEMBER], level="WORKSPACE")
    def post(self, request, slug):
        api_key, model, provider = get_llm_config()
        if not api_key or not model or not provider:
            return Response(
                {"error": "LLM provider API key and model are required"},
                status=status.HTTP_400_BAD_REQUEST,
            )

        task = request.data.get("task", False)
        if not task:
            return Response(
                {"error": "Task is required"}, status=status.HTTP_400_BAD_REQUEST
            )

        # Same empty-string prompt default as above to keep concatenation safe
        text, error = get_llm_response(
            task, request.data.get("prompt", ""), api_key, model, provider
        )
        if not text and error:
            return Response(
                {"error": "An internal error has occurred."},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        return Response(
            {
                "response": text,
                "response_html": text.replace("\n", "<br/>"),
            },
            status=status.HTTP_200_OK,
        )
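
# Example request body for either endpoint (illustrative):
#   {"task": "Write a short summary", "prompt": "<selected issue text>"}
# The project-level endpoint additionally returns project_detail and
# workspace_detail; the workspace-level endpoint returns only the response text.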


class UnsplashEndpoint(BaseAPIView):
    def get(self, request):
        (UNSPLASH_ACCESS_KEY,) = get_configuration_value(
            [
                {
                    "key": "UNSPLASH_ACCESS_KEY",
                    "default": os.environ.get("UNSPLASH_ACCESS_KEY"),
                }
            ]
        )

        # Check the Unsplash access key; return an empty list if it is unset
        if not UNSPLASH_ACCESS_KEY:
            return Response([], status=status.HTTP_200_OK)

        # Query parameters
        query = request.GET.get("query", False)
        page = request.GET.get("page", 1)
        per_page = request.GET.get("per_page", 20)

        url = (
            f"https://api.unsplash.com/search/photos/?client_id={UNSPLASH_ACCESS_KEY}&query={query}&page={page}&per_page={per_page}"
            if query
            else f"https://api.unsplash.com/photos/?client_id={UNSPLASH_ACCESS_KEY}&page={page}&per_page={per_page}"
        )

        headers = {"Content-Type": "application/json"}
        # Bound the upstream call so a slow Unsplash response cannot hang the worker
        resp = requests.get(url=url, headers=headers, timeout=30)
        return Response(resp.json(), status=resp.status_code)
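
# Illustrative call (the Plane route path is hypothetical; see the app's
# urls.py for the actual binding):
#   GET /api/unsplash/?query=mountains&page=1&per_page=20
# relays Unsplash's JSON body and HTTP status code to the client unchanged.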