2025-04-05 12:44:21 +03:00

181 lines
6.9 KiB
Python

import binascii
import logging
from allauth.account.adapter import get_adapter as get_account_adapter
from allauth.account.internal.decorators import login_not_required
from allauth.core.internal import httpkit
from allauth.socialaccount.helpers import (
complete_social_login,
render_authentication_error,
)
from allauth.socialaccount.providers.base.constants import AuthError, AuthProcess
from allauth.socialaccount.providers.base.views import BaseLoginView
from allauth.socialaccount.sessions import LoginSession
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect, JsonResponse
from django.urls import reverse
from django.utils.decorators import method_decorator
from django.views import View
from django.views.decorators.csrf import csrf_exempt
from onelogin.saml2.auth import OneLogin_Saml2_Settings
from onelogin.saml2.errors import OneLogin_Saml2_Error
from .utils import build_auth, build_saml_config, decode_relay_state, get_app_or_404
# Module-level logger named after this module; used for SAML error reporting below.
logger = logging.getLogger(name=__name__)
class SAMLViewMixin:
    """Resolve the SAML app and provider for an organization slug.

    Relies on ``self.request`` being set, as Django's ``View.setup`` does
    before ``dispatch`` runs.
    """

    def get_app(self, organization_slug):
        """Return the app for ``organization_slug`` via ``get_app_or_404``.

        NOTE(review): presumably raises Http404 when no app matches (per the
        helper's name); the helper is not visible here.
        """
        return get_app_or_404(self.request, organization_slug)

    def get_provider(self, organization_slug):
        """Return the provider of the slug's app, bound to the current request."""
        return self.get_app(organization_slug).get_provider(self.request)
@method_decorator(csrf_exempt, name="dispatch")
@method_decorator(login_not_required, name="dispatch")
class ACSView(SAMLViewMixin, View):
    """Assertion Consumer Service endpoint: stash the request, then redirect.

    The incoming request is serialized into the ``saml_acs_session`` login
    session, which is attached to a redirect pointing at the
    ``saml_finish_acs`` URL for the same organization. The SAML response
    itself is processed by the view behind that URL.

    NOTE(review): this view is CSRF-exempt while the finishing view is not
    decorated that way. Presumably the hop exists so the login completes in a
    same-site request where session cookies are available. Confirm.
    """

    def dispatch(self, request, organization_slug):
        finish_url = reverse(
            "saml_finish_acs", kwargs={"organization_slug": organization_slug}
        )
        redirect = HttpResponseRedirect(finish_url)
        stash = LoginSession(request, "saml_acs_session", "saml-acs-session")
        serialized = httpkit.serialize_request(request)
        stash.store.update({"request": serialized})
        # Persist the login session onto the outgoing redirect response.
        stash.save(redirect)
        return redirect


acs = ACSView.as_view()
@method_decorator(login_not_required, name="dispatch")
class FinishACSView(SAMLViewMixin, View):
    """Complete a SAML login from the request stashed in ``saml_acs_session``.

    Restores the serialized ACS request, validates the SAML response with
    python3-saml, builds a social login from it and hands that to
    ``complete_social_login``. Any failure renders the authentication
    error page.
    """

    def dispatch(self, request, organization_slug):
        provider = self.get_provider(organization_slug)
        acs_request = self._pop_acs_request(request)
        if not acs_request:
            logger.error("Unable to finish login, SAML ACS session missing")
            return render_authentication_error(request, provider)

        auth = build_auth(acs_request, provider)
        errors, error_reason = self._process_acs_response(auth)
        if errors:
            # e.g. ['invalid_response']
            # Lazy logger arguments: the message is only formatted if emitted.
            logger.error(
                "Error processing SAML ACS response: %s: %s",
                ", ".join(errors),
                error_reason,
            )
            return render_authentication_error(
                request,
                provider,
                extra_context={
                    "saml_errors": errors,
                    "saml_last_error_reason": error_reason,
                },
            )
        if not auth.is_authenticated():
            return render_authentication_error(
                request, provider, error=AuthError.CANCELLED
            )

        login = provider.sociallogin_from_response(request, auth)
        # (*) If we (the SP) initiated the login, there should be a matching
        # state.
        state_id = auth.get_last_response_in_response_to()
        if state_id:
            login.state = provider.unstash_redirect_state(request, state_id)
        else:
            # IdP initiated SSO: rejected unless the app's "advanced" settings
            # set "reject_idp_initiated_sso" to a falsy value (default True).
            reject = provider.app.settings.get("advanced", {}).get(
                "reject_idp_initiated_sso", True
            )
            if reject:
                logger.error("IdP initiated SSO rejected")
                return render_authentication_error(request, provider)
            # With no stashed state, the RelayState from the original POST is
            # the only source of a post-login redirect target.
            next_url = decode_relay_state(acs_request.POST.get("RelayState"))
            login.state["process"] = AuthProcess.LOGIN
            if next_url:
                login.state["next"] = next_url
        return complete_social_login(request, login)

    @staticmethod
    def _pop_acs_request(request):
        """Return the ACS request stashed by ACSView, or None if absent.

        The login session is deleted whether or not a request was found, so
        a stashed request can be consumed only once.
        """
        acs_session = LoginSession(request, "saml_acs_session", "saml-acs-session")
        acs_request = None
        acs_request_data = acs_session.store.get("request")
        if acs_request_data:
            acs_request = httpkit.deserialize_request(acs_request_data, HttpRequest())
        acs_session.delete()
        return acs_request

    @staticmethod
    def _process_acs_response(auth):
        """Run ``auth.process_response`` and return ``(errors, error_reason)``.

        ``errors`` is empty on success. Errors raised as exceptions take
        precedence. Otherwise the errors recorded by ``auth`` are used. The
        reason reported by ``auth`` wins over the locally derived one.
        """
        errors = []
        error_reason = None
        try:
            # We're doing the check for a valid `InResponeTo` ourselves later
            # on (*) by checking if there is a matching state stashed.
            auth.process_response(request_id=None)
        except binascii.Error:
            # Raised for an undecodable (non-base64) response payload.
            errors = ["invalid_response"]
            error_reason = "Invalid response"
        except OneLogin_Saml2_Error as e:
            errors = ["error"]
            error_reason = str(e)
        if not errors:
            errors = auth.get_errors()
        if errors:
            error_reason = auth.get_last_error_reason() or error_reason
        return errors, error_reason


finish_acs = FinishACSView.as_view()
@method_decorator(csrf_exempt, name="dispatch")
@method_decorator(login_not_required, name="dispatch")
class SLSView(SAMLViewMixin, View):
    """Single Logout Service endpoint.

    Hands the incoming request to python3-saml's ``process_slo``. The local
    session is kept when the user is not authenticated. Otherwise
    ``force_logout`` (the account adapter's ``logout``) is supplied as the
    session-deletion callback.

    Errors produce a plain-text HTTP 400 response. On success the view
    redirects to the URL returned by ``process_slo``, falling back to the
    account adapter's logout redirect URL.
    """

    def dispatch(self, request, organization_slug):
        provider = self.get_provider(organization_slug)
        auth = build_auth(self.request, provider)
        should_logout = request.user.is_authenticated
        account_adapter = get_account_adapter(request)

        def force_logout():
            # Passed to ``process_slo`` as ``delete_session_cb``.
            account_adapter.logout(request)

        redirect_to = None
        error_reason = None
        raised_errors = []
        try:
            redirect_to = auth.process_slo(
                delete_session_cb=force_logout, keep_local_session=not should_logout
            )
        except OneLogin_Saml2_Error as e:
            raised_errors = ["error"]
            error_reason = str(e)
        # Prefer the error codes recorded by ``auth``. Fall back to a generic
        # marker so that an exception raised without recorded errors is
        # reported, instead of being silently treated as a successful logout.
        errors = auth.get_errors() or raised_errors
        if errors:
            error_reason = auth.get_last_error_reason() or error_reason
            logger.error(
                "Error processing SAML SLS response: %s: %s",
                ", ".join(errors),
                error_reason,
            )
            return HttpResponse(error_reason, content_type="text/plain", status=400)
        if not redirect_to:
            redirect_to = account_adapter.get_logout_redirect_url(request)
        return HttpResponseRedirect(redirect_to)


sls = SLSView.as_view()
@method_decorator(login_not_required, name="dispatch")
class MetadataView(SAMLViewMixin, View):
    """Serve this service provider's SAML metadata for ``organization_slug``.

    The metadata is generated from the provider app's settings and checked
    with ``validate_metadata``. If problems are reported, they are returned
    as JSON with HTTP 500. Otherwise the XML is returned as ``text/xml``.
    """

    def dispatch(self, request, organization_slug):
        app_settings = self.get_provider(organization_slug).app.settings
        saml_config = build_saml_config(self.request, app_settings, organization_slug)
        sp_settings = OneLogin_Saml2_Settings(
            settings=saml_config, sp_validation_only=True
        )
        sp_metadata = sp_settings.get_sp_metadata()
        problems = sp_settings.validate_metadata(sp_metadata)
        if not problems:
            return HttpResponse(content=sp_metadata, content_type="text/xml")
        return JsonResponse({"errors": problems}, status=500)


metadata = MetadataView.as_view()
@method_decorator(login_not_required, name="dispatch")
class LoginView(SAMLViewMixin, BaseLoginView):
    """Login entry point for the organization named in the URL.

    The login flow itself lives in ``BaseLoginView`` (not visible here). This
    class only supplies the provider.
    """

    def get_provider(self):
        # Unlike SAMLViewMixin.get_provider, this takes no argument: the slug
        # comes from the URL kwargs. NOTE(review): presumably BaseLoginView
        # calls get_provider() with no arguments; confirm.
        slug = self.kwargs["organization_slug"]
        return self.get_app(slug).get_provider(self.request)


login = LoginView.as_view()