diff --git a/apiserver/plane/authentication/adapter/oauth.py b/apiserver/plane/authentication/adapter/oauth.py index 348f32f66..c4e78573c 100644 --- a/apiserver/plane/authentication/adapter/oauth.py +++ b/apiserver/plane/authentication/adapter/oauth.py @@ -3,6 +3,7 @@ import requests # Django imports from django.utils import timezone +from django.db import DatabaseError, IntegrityError # Module imports from plane.db.models import Account @@ -12,6 +13,7 @@ from plane.authentication.adapter.error import ( AuthenticationException, AUTHENTICATION_ERROR_CODES, ) +from plane.utils.exception_logger import log_exception class OauthAdapter(Adapter): @@ -97,20 +99,48 @@ class OauthAdapter(Adapter): self.user_data = data def create_update_account(self, user): - account, created = Account.objects.update_or_create( - user=user, - provider=self.provider, - provider_account_id=self.user_data.get("user").get("provider_id"), - defaults={ - "access_token": self.token_data.get("access_token"), - "refresh_token": self.token_data.get("refresh_token", None), - "access_token_expired_at": self.token_data.get( + try: + # Check if the account already exists + account = Account.objects.filter( + user=user, + provider=self.provider, + provider_account_id=self.user_data.get("user").get( + "provider_id" + ), + ).first() + # Update the account if it exists + if account: + account.access_token = self.token_data.get("access_token") + account.refresh_token = self.token_data.get( + "refresh_token", None + ) + account.access_token_expired_at = self.token_data.get( "access_token_expired_at" - ), - "refresh_token_expired_at": self.token_data.get( + ) + account.refresh_token_expired_at = self.token_data.get( "refresh_token_expired_at" - ), - "last_connected_at": timezone.now(), - "id_token": self.token_data.get("id_token", ""), - }, - ) + ) + account.last_connected_at = timezone.now() + account.id_token = self.token_data.get("id_token", "") + account.save() + # Create a new account if it does not exist + else: + Account.objects.create( + user=user, + provider=self.provider, + provider_account_id=self.user_data.get("user", {}).get( + "provider_id" + ), + access_token=self.token_data.get("access_token"), + refresh_token=self.token_data.get("refresh_token", None), + access_token_expired_at=self.token_data.get( + "access_token_expired_at" + ), + refresh_token_expired_at=self.token_data.get( + "refresh_token_expired_at" + ), + last_connected_at=timezone.now(), + id_token=self.token_data.get("id_token", ""), + ) + except (DatabaseError, IntegrityError) as e: + log_exception(e)