mirror of
https://github.com/mediacms-io/mediacms.git
synced 2025-11-20 13:36:05 -05:00
feat: RBAC + SAML support
This commit is contained in:
0
saml_auth/__init__.py
Normal file
0
saml_auth/__init__.py
Normal file
153
saml_auth/adapter.py
Normal file
153
saml_auth/adapter.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter
|
||||
from allauth.socialaccount.models import SocialApp
|
||||
from allauth.socialaccount.signals import social_account_updated
|
||||
from django.core.files.base import ContentFile
|
||||
from django.dispatch import receiver
|
||||
|
||||
from identity_providers.models import IdentityProviderUserLog
|
||||
from rbac.models import RBACGroup, RBACMembership
|
||||
|
||||
|
||||
class SAMLAccountAdapter(DefaultSocialAccountAdapter):
|
||||
def is_open_for_signup(self, request, socialaccount):
|
||||
return True
|
||||
|
||||
def pre_social_login(self, request, sociallogin):
|
||||
# data = sociallogin.data
|
||||
|
||||
return super().pre_social_login(request, sociallogin)
|
||||
|
||||
def populate_user(self, request, sociallogin, data):
|
||||
user = sociallogin.user
|
||||
user.username = sociallogin.account.uid
|
||||
for item in ["name", "first_name", "last_name"]:
|
||||
if data.get(item):
|
||||
setattr(user, item, data[item])
|
||||
sociallogin.data = data
|
||||
# User is not retrieved through DB. Id is None.
|
||||
|
||||
return user
|
||||
|
||||
def save_user(self, request, sociallogin, form=None):
|
||||
user = super().save_user(request, sociallogin, form)
|
||||
# Runs after new user is created
|
||||
perform_user_actions(user, sociallogin.account)
|
||||
return user
|
||||
|
||||
|
||||
@receiver(social_account_updated)
|
||||
def social_account_updated(sender, request, sociallogin, **kwargs):
|
||||
# Runs after existing user is updated
|
||||
user = sociallogin.user
|
||||
# data is there due to populate_user
|
||||
common_fields = sociallogin.data
|
||||
perform_user_actions(user, sociallogin.account, common_fields)
|
||||
|
||||
|
||||
def perform_user_actions(user, social_account, common_fields=None):
|
||||
# common_fields is data already mapped to the attributes we want
|
||||
if common_fields:
|
||||
# check the following fields, if they are updated from the IDP side, update
|
||||
# the user object too
|
||||
fields_to_update = []
|
||||
for item in ["name", "first_name", "last_name", "email"]:
|
||||
if common_fields.get(item) and common_fields[item] != getattr(user, item):
|
||||
setattr(user, item, common_fields[item])
|
||||
fields_to_update.append(item)
|
||||
if fields_to_update:
|
||||
user.save(update_fields=fields_to_update)
|
||||
|
||||
# extra_data is the plain response from SAML provider
|
||||
|
||||
extra_data = social_account.extra_data
|
||||
# there's no FK from Social Account to Social App
|
||||
social_app = SocialApp.objects.filter(provider_id=social_account.provider).first()
|
||||
saml_configuration = None
|
||||
if social_app:
|
||||
saml_configuration = social_app.saml_configurations.first()
|
||||
|
||||
add_user_logo(user, extra_data)
|
||||
handle_role_mapping(user, extra_data, social_app, saml_configuration)
|
||||
if saml_configuration and saml_configuration.save_saml_response_logs:
|
||||
handle_saml_logs_save(user, extra_data, social_app)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def add_user_logo(user, extra_data):
|
||||
try:
|
||||
if extra_data.get("jpegPhoto") and user.logo.name in ["userlogos/user.jpg", "", None]:
|
||||
base64_string = extra_data.get("jpegPhoto")[0]
|
||||
image_data = base64.b64decode(base64_string)
|
||||
image_content = ContentFile(image_data)
|
||||
user.logo.save('user.jpg', image_content, save=True)
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return True
|
||||
|
||||
|
||||
def handle_role_mapping(user, extra_data, social_app, saml_configuration):
|
||||
if not saml_configuration:
|
||||
return False
|
||||
|
||||
rbac_groups = []
|
||||
role = "member"
|
||||
# get groups key from configuration / attributes mapping
|
||||
groups_key = saml_configuration.groups
|
||||
groups = extra_data.get(groups_key, [])
|
||||
# groups is a list of group_ids here
|
||||
|
||||
if groups:
|
||||
rbac_groups = RBACGroup.objects.filter(identity_provider=social_app, uid__in=groups)
|
||||
|
||||
try:
|
||||
# try to get the role, always use member as fallback
|
||||
role_key = saml_configuration.role
|
||||
role = extra_data.get(role_key, "student")
|
||||
if role and isinstance(role, list):
|
||||
role = role[0]
|
||||
|
||||
# populate global role
|
||||
global_role = social_app.global_roles.filter(name=role).first()
|
||||
if global_role:
|
||||
user.set_role_from_mapping(global_role.map_to)
|
||||
|
||||
group_role = social_app.group_roles.filter(name=role).first()
|
||||
if group_role:
|
||||
if group_role.map_to in ['member', 'contributor', 'manager']:
|
||||
role = group_role.map_to
|
||||
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
|
||||
role = role if role in ['member', 'contributor', 'manager'] else 'member'
|
||||
|
||||
for rbac_group in rbac_groups:
|
||||
membership = RBACMembership.objects.filter(user=user, rbac_group=rbac_group).first()
|
||||
if membership and role != membership.role:
|
||||
membership.role = role
|
||||
membership.save(update_fields=["role"])
|
||||
if not membership:
|
||||
try:
|
||||
# use role from early above
|
||||
membership = RBACMembership.objects.create(user=user, rbac_group=rbac_group, role=role)
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
# if remove_from_groups setting is True and user is part of groups for this
|
||||
# social app that are not included anymore on the response, then remove user from group
|
||||
if saml_configuration.remove_from_groups:
|
||||
for group in user.rbac_groups.filter(identity_provider=social_app):
|
||||
if group not in rbac_groups:
|
||||
group.members.remove(user)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def handle_saml_logs_save(user, extra_data, social_app):
|
||||
# do not save jpegPhoto, if it exists
|
||||
extra_data.pop("jpegPhoto", None)
|
||||
log = IdentityProviderUserLog.objects.create(user=user, identity_provider=social_app, logs=extra_data) # noqa
|
||||
return True
|
||||
123
saml_auth/admin.py
Normal file
123
saml_auth/admin.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import csv
|
||||
import logging
|
||||
|
||||
from django import forms
|
||||
from django.conf import settings
|
||||
from django.contrib import admin
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.utils.html import format_html
|
||||
|
||||
from .models import SAMLConfiguration
|
||||
|
||||
|
||||
class SAMLConfigurationForm(forms.ModelForm):
|
||||
import_csv = forms.FileField(required=False, label="CSV file", help_text="Make sure headers are group_id, name")
|
||||
|
||||
class Meta:
|
||||
model = SAMLConfiguration
|
||||
fields = '__all__'
|
||||
|
||||
def clean_import_csv(self):
|
||||
csv_file = self.cleaned_data.get('import_csv')
|
||||
|
||||
if not csv_file:
|
||||
return csv_file
|
||||
|
||||
if not csv_file.name.endswith('.csv'):
|
||||
raise ValidationError("Uploaded file must be a CSV file.")
|
||||
|
||||
try:
|
||||
decoded_file = csv_file.read().decode('utf-8').splitlines()
|
||||
csv_reader = csv.reader(decoded_file)
|
||||
headers = next(csv_reader, None)
|
||||
if not headers or 'group_id' not in headers or 'name' not in headers:
|
||||
raise ValidationError("CSV file must contain 'group_id' and 'name' headers. " f"Found headers: {', '.join(headers) if headers else 'none'}")
|
||||
csv_file.seek(0)
|
||||
return csv_file
|
||||
|
||||
except csv.Error:
|
||||
raise ValidationError("Invalid CSV file. Please ensure the file is properly formatted.")
|
||||
except UnicodeDecodeError:
|
||||
raise ValidationError("Invalid file encoding. Please upload a CSV file with UTF-8 encoding.")
|
||||
|
||||
|
||||
class SAMLConfigurationAdmin(admin.ModelAdmin):
|
||||
form = SAMLConfigurationForm
|
||||
|
||||
list_display = ['social_app', 'idp_id', 'remove_from_groups', 'save_saml_response_logs', 'view_metadata_url']
|
||||
|
||||
list_filter = ['social_app', 'remove_from_groups', 'save_saml_response_logs']
|
||||
|
||||
search_fields = ['social_app__name', 'idp_id', 'sp_metadata_url']
|
||||
|
||||
fieldsets = [
|
||||
('Provider Settings', {'fields': ['social_app', 'idp_id', 'idp_cert']}),
|
||||
('URLs', {'fields': ['sso_url', 'slo_url', 'sp_metadata_url']}),
|
||||
('Group Management', {'fields': ['remove_from_groups', 'save_saml_response_logs']}),
|
||||
('Attribute Mapping', {'fields': ['uid', 'name', 'email', 'groups', 'first_name', 'last_name', 'user_logo', 'role']}),
|
||||
(
|
||||
'Email Settings',
|
||||
{
|
||||
'fields': [
|
||||
'verified_email',
|
||||
'email_authentication',
|
||||
]
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
def view_metadata_url(self, obj):
|
||||
"""Display metadata URL as a clickable link"""
|
||||
return format_html('<a href="{}" target="_blank">View Metadata</a>', obj.sp_metadata_url)
|
||||
|
||||
view_metadata_url.short_description = 'Metadata'
|
||||
|
||||
def formfield_for_dbfield(self, db_field, **kwargs):
|
||||
field = super().formfield_for_dbfield(db_field, **kwargs)
|
||||
if db_field.name == 'social_app':
|
||||
field.label = 'IDP Config Name'
|
||||
return field
|
||||
|
||||
def get_fieldsets(self, request, obj=None):
|
||||
fieldsets = super().get_fieldsets(request, obj)
|
||||
|
||||
fieldsets = list(fieldsets)
|
||||
|
||||
fieldsets.append(('BULK GROUP MAPPINGS', {'fields': ('import_csv',), 'description': 'Optionally upload a CSV file with group_id and name as headers to add multiple group mappings at once.'}))
|
||||
|
||||
return fieldsets
|
||||
|
||||
def save_model(self, request, obj, form, change):
|
||||
super().save_model(request, obj, form, change)
|
||||
|
||||
csv_file = form.cleaned_data.get('import_csv')
|
||||
if csv_file:
|
||||
from rbac.models import RBACGroup
|
||||
|
||||
try:
|
||||
csv_file.seek(0)
|
||||
decoded_file = csv_file.read().decode('utf-8').splitlines()
|
||||
csv_reader = csv.DictReader(decoded_file)
|
||||
for row in csv_reader:
|
||||
group_id = row.get('group_id')
|
||||
name = row.get('name')
|
||||
|
||||
if group_id and name:
|
||||
if not RBACGroup.objects.filter(uid=group_id, social_app=obj.social_app).exists():
|
||||
try:
|
||||
rbac_group = RBACGroup.objects.create(uid=group_id, name=name, social_app=obj.social_app) # noqa
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
|
||||
|
||||
if getattr(settings, 'USE_SAML', False):
|
||||
for field in SAMLConfiguration._meta.fields:
|
||||
if field.name == 'social_app':
|
||||
field.verbose_name = "ID Provider"
|
||||
|
||||
admin.site.register(SAMLConfiguration, SAMLConfigurationAdmin)
|
||||
|
||||
SAMLConfiguration._meta.app_config.verbose_name = "SAML settings and logs"
|
||||
SAMLConfiguration._meta.verbose_name_plural = "SAML Configuration"
|
||||
6
saml_auth/apps.py
Normal file
6
saml_auth/apps.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class SamlAuthConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
name = 'saml_auth'
|
||||
0
saml_auth/custom/__init__.py
Normal file
0
saml_auth/custom/__init__.py
Normal file
61
saml_auth/custom/provider.py
Normal file
61
saml_auth/custom/provider.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
from allauth.socialaccount.providers.saml.provider import SAMLProvider
|
||||
from django.http import HttpResponseRedirect
|
||||
|
||||
from saml_auth.custom.utils import build_auth
|
||||
|
||||
|
||||
class SAMLAccount(ProviderAccount):
|
||||
pass
|
||||
|
||||
|
||||
class CustomSAMLProvider(SAMLProvider):
|
||||
def _extract(self, data):
|
||||
custom_configuration = self.app.saml_configurations.first()
|
||||
if custom_configuration:
|
||||
provider_config = custom_configuration.saml_provider_settings
|
||||
else:
|
||||
provider_config = self.app.settings
|
||||
|
||||
raw_attributes = data.get_attributes()
|
||||
attributes = {}
|
||||
attribute_mapping = provider_config.get("attribute_mapping", self.default_attribute_mapping)
|
||||
# map configured provider attributes
|
||||
for key, provider_keys in attribute_mapping.items():
|
||||
if isinstance(provider_keys, str):
|
||||
provider_keys = [provider_keys]
|
||||
for provider_key in provider_keys:
|
||||
attribute_list = raw_attributes.get(provider_key, None)
|
||||
# if more than one keys, get them all comma separated
|
||||
if attribute_list is not None and len(attribute_list) > 1:
|
||||
attributes[key] = ",".join(attribute_list)
|
||||
break
|
||||
elif attribute_list is not None and len(attribute_list) > 0:
|
||||
attributes[key] = attribute_list[0]
|
||||
break
|
||||
attributes["email_verified"] = False
|
||||
email_verified = provider_config.get("email_verified", False)
|
||||
if email_verified:
|
||||
if isinstance(email_verified, str):
|
||||
email_verified = email_verified.lower() in ["true", "1", "t", "y", "yes"]
|
||||
attributes["email_verified"] = email_verified
|
||||
# return username as the uid value
|
||||
if "uid" in attributes:
|
||||
attributes["username"] = attributes["uid"]
|
||||
# If we did not find an email, check if the NameID contains the email.
|
||||
if not attributes.get("email") and (
|
||||
data.get_nameid_format() == "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"
|
||||
# Alternatively, if `use_id_for_email` is true, then we always interpret the nameID as email
|
||||
or provider_config.get("use_nameid_for_email", False) # noqa
|
||||
):
|
||||
attributes["email"] = data.get_nameid()
|
||||
|
||||
return attributes
|
||||
|
||||
def redirect(self, request, process, next_url=None, data=None, **kwargs):
|
||||
auth = build_auth(request, self)
|
||||
# If we pass `return_to=None` `auth.login` will use the URL of the
|
||||
# current view.
|
||||
redirect = auth.login(return_to="")
|
||||
self.stash_redirect_state(request, process, next_url, data, state_id=auth.get_last_request_id(), **kwargs)
|
||||
return HttpResponseRedirect(redirect)
|
||||
38
saml_auth/custom/urls.py
Normal file
38
saml_auth/custom/urls.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from django.urls import include, path, re_path
|
||||
|
||||
from . import views
|
||||
|
||||
urlpatterns = [
|
||||
re_path(
|
||||
r"^saml/(?P<organization_slug>[^/]+)/",
|
||||
include(
|
||||
[
|
||||
path(
|
||||
"acs/",
|
||||
views.acs,
|
||||
name="saml_acs",
|
||||
),
|
||||
path(
|
||||
"acs/finish/",
|
||||
views.finish_acs,
|
||||
name="saml_finish_acs",
|
||||
),
|
||||
path(
|
||||
"sls/",
|
||||
views.sls,
|
||||
name="saml_sls",
|
||||
),
|
||||
path(
|
||||
"metadata/",
|
||||
views.metadata,
|
||||
name="saml_metadata",
|
||||
),
|
||||
path(
|
||||
"login/",
|
||||
views.login,
|
||||
name="saml_login",
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
]
|
||||
173
saml_auth/custom/utils.py
Normal file
173
saml_auth/custom/utils.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.models import SocialApp
|
||||
from allauth.socialaccount.providers.saml.provider import SAMLProvider
|
||||
from django.core.cache import cache
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.http import Http404
|
||||
from django.urls import reverse
|
||||
from django.utils.http import urlencode
|
||||
from onelogin.saml2.auth import OneLogin_Saml2_Auth
|
||||
from onelogin.saml2.constants import OneLogin_Saml2_Constants
|
||||
from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser
|
||||
|
||||
|
||||
def get_app_or_404(request, organization_slug):
|
||||
adapter = get_adapter()
|
||||
try:
|
||||
return adapter.get_app(request, provider=SAMLProvider.id, client_id=organization_slug)
|
||||
except SocialApp.DoesNotExist:
|
||||
raise Http404(f"no SocialApp found with client_id={organization_slug}")
|
||||
|
||||
|
||||
def prepare_django_request(request):
|
||||
result = {
|
||||
"https": "on" if request.is_secure() else "off",
|
||||
"http_host": request.META["HTTP_HOST"],
|
||||
"script_name": request.META["PATH_INFO"],
|
||||
"get_data": request.GET.copy(),
|
||||
# 'lowercase_urlencoding': True,
|
||||
"post_data": request.POST.copy(),
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def build_sp_config(request, provider_config, org):
|
||||
acs_url = request.build_absolute_uri(reverse("saml_acs", args=[org]))
|
||||
sls_url = request.build_absolute_uri(reverse("saml_sls", args=[org]))
|
||||
metadata_url = request.build_absolute_uri(reverse("saml_metadata", args=[org]))
|
||||
# SP entity ID generated with the following precedence:
|
||||
# 1. Explicitly configured SP via the SocialApp.settings
|
||||
# 2. Fallback to the SAML metadata urlpattern
|
||||
_sp_config = provider_config.get("sp", {})
|
||||
sp_entity_id = _sp_config.get("entity_id")
|
||||
sp_config = {
|
||||
"entityId": sp_entity_id or metadata_url,
|
||||
"assertionConsumerService": {
|
||||
"url": acs_url,
|
||||
"binding": OneLogin_Saml2_Constants.BINDING_HTTP_POST,
|
||||
},
|
||||
"singleLogoutService": {
|
||||
"url": sls_url,
|
||||
"binding": OneLogin_Saml2_Constants.BINDING_HTTP_REDIRECT,
|
||||
},
|
||||
}
|
||||
avd = provider_config.get("advanced", {})
|
||||
if avd.get("x509cert") is not None:
|
||||
sp_config["x509cert"] = avd["x509cert"]
|
||||
|
||||
if avd.get("x509cert_new"):
|
||||
sp_config["x509certNew"] = avd["x509cert_new"]
|
||||
|
||||
if avd.get("private_key") is not None:
|
||||
sp_config["privateKey"] = avd["private_key"]
|
||||
|
||||
if avd.get("name_id_format") is not None:
|
||||
sp_config["NameIDFormat"] = avd["name_id_format"]
|
||||
|
||||
return sp_config
|
||||
|
||||
|
||||
def fetch_metadata_url_config(idp_config):
|
||||
metadata_url = idp_config["metadata_url"]
|
||||
entity_id = idp_config["entity_id"]
|
||||
cache_key = f"saml.metadata.{metadata_url}.{entity_id}"
|
||||
saml_config = cache.get(cache_key)
|
||||
if saml_config is None:
|
||||
saml_config = OneLogin_Saml2_IdPMetadataParser.parse_remote(
|
||||
metadata_url,
|
||||
entity_id=entity_id,
|
||||
timeout=idp_config.get("metadata_request_timeout", 10),
|
||||
)
|
||||
cache.set(
|
||||
cache_key,
|
||||
saml_config,
|
||||
idp_config.get("metadata_cache_timeout", 60 * 60 * 4),
|
||||
)
|
||||
return saml_config
|
||||
|
||||
|
||||
def build_saml_config(request, provider_config, org):
|
||||
avd = provider_config.get("advanced", {})
|
||||
security_config = {
|
||||
"authnRequestsSigned": avd.get("authn_request_signed", False),
|
||||
"digestAlgorithm": avd.get("digest_algorithm", OneLogin_Saml2_Constants.SHA256),
|
||||
"logoutRequestSigned": avd.get("logout_request_signed", False),
|
||||
"logoutResponseSigned": avd.get("logout_response_signed", False),
|
||||
"requestedAuthnContext": False,
|
||||
"signatureAlgorithm": avd.get("signature_algorithm", OneLogin_Saml2_Constants.RSA_SHA256),
|
||||
"signMetadata": avd.get("metadata_signed", False),
|
||||
"wantAssertionsEncrypted": avd.get("want_assertion_encrypted", False),
|
||||
"wantAssertionsSigned": avd.get("want_assertion_signed", False),
|
||||
"wantMessagesSigned": avd.get("want_message_signed", False),
|
||||
"nameIdEncrypted": avd.get("name_id_encrypted", False),
|
||||
"wantNameIdEncrypted": avd.get("want_name_id_encrypted", False),
|
||||
"allowSingleLabelDomains": avd.get("allow_single_label_domains", False),
|
||||
"rejectDeprecatedAlgorithm": avd.get("reject_deprecated_algorithm", True),
|
||||
"wantNameId": avd.get("want_name_id", False),
|
||||
"wantAttributeStatement": avd.get("want_attribute_statement", True),
|
||||
"allowRepeatAttributeName": avd.get("allow_repeat_attribute_name", True),
|
||||
}
|
||||
saml_config = {
|
||||
"strict": avd.get("strict", True),
|
||||
"security": security_config,
|
||||
}
|
||||
contact_person = provider_config.get("contact_person")
|
||||
if contact_person:
|
||||
saml_config["contactPerson"] = contact_person
|
||||
|
||||
organization = provider_config.get("organization")
|
||||
if organization:
|
||||
saml_config["organization"] = organization
|
||||
|
||||
idp = provider_config.get("idp")
|
||||
if idp is None:
|
||||
raise ImproperlyConfigured("`idp` missing")
|
||||
metadata_url = idp.get("metadata_url")
|
||||
if metadata_url:
|
||||
meta_config = fetch_metadata_url_config(idp)
|
||||
saml_config["idp"] = meta_config["idp"]
|
||||
else:
|
||||
saml_config["idp"] = {
|
||||
"entityId": idp["entity_id"],
|
||||
"x509cert": idp["x509cert"],
|
||||
"singleSignOnService": {"url": idp["sso_url"]},
|
||||
}
|
||||
slo_url = idp.get("slo_url")
|
||||
if slo_url:
|
||||
saml_config["idp"]["singleLogoutService"] = {"url": slo_url}
|
||||
|
||||
saml_config["sp"] = build_sp_config(request, provider_config, org)
|
||||
return saml_config
|
||||
|
||||
|
||||
def encode_relay_state(state):
|
||||
params = {"state": state}
|
||||
return urlencode(params)
|
||||
|
||||
|
||||
def decode_relay_state(relay_state):
|
||||
"""According to the spec, RelayState need not be a URL, yet,
|
||||
``onelogin.saml2` exposes it as ``return_to -- The target URL the user
|
||||
should be redirected to after login``. Also, for an IdP initiated login
|
||||
sometimes a URL is used.
|
||||
"""
|
||||
next_url = None
|
||||
if relay_state:
|
||||
parts = urlparse(relay_state)
|
||||
if parts.scheme or parts.netloc or (parts.path and parts.path.startswith("/")):
|
||||
next_url = relay_state
|
||||
return next_url
|
||||
|
||||
|
||||
def build_auth(request, provider):
|
||||
req = prepare_django_request(request)
|
||||
custom_configuration = provider.app.saml_configurations.first()
|
||||
if custom_configuration:
|
||||
custom_settings = custom_configuration.saml_provider_settings
|
||||
config = build_saml_config(request, custom_settings, provider.app.client_id)
|
||||
else:
|
||||
config = build_saml_config(request, provider.app.settings, provider.app.client_id)
|
||||
auth = OneLogin_Saml2_Auth(req, config)
|
||||
return auth
|
||||
180
saml_auth/custom/views.py
Normal file
180
saml_auth/custom/views.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import binascii
|
||||
import logging
|
||||
|
||||
from allauth.account.adapter import get_adapter as get_account_adapter
|
||||
from allauth.account.internal.decorators import login_not_required
|
||||
from allauth.core.internal import httpkit
|
||||
from allauth.socialaccount.helpers import (
|
||||
complete_social_login,
|
||||
render_authentication_error,
|
||||
)
|
||||
from allauth.socialaccount.providers.base.constants import AuthError, AuthProcess
|
||||
from allauth.socialaccount.providers.base.views import BaseLoginView
|
||||
from allauth.socialaccount.sessions import LoginSession
|
||||
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect, JsonResponse
|
||||
from django.urls import reverse
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views import View
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from onelogin.saml2.auth import OneLogin_Saml2_Settings
|
||||
from onelogin.saml2.errors import OneLogin_Saml2_Error
|
||||
|
||||
from .utils import build_auth, build_saml_config, decode_relay_state, get_app_or_404
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SAMLViewMixin:
|
||||
def get_app(self, organization_slug):
|
||||
app = get_app_or_404(self.request, organization_slug)
|
||||
return app
|
||||
|
||||
def get_provider(self, organization_slug):
|
||||
app = self.get_app(organization_slug)
|
||||
return app.get_provider(self.request)
|
||||
|
||||
|
||||
@method_decorator(csrf_exempt, name="dispatch")
|
||||
@method_decorator(login_not_required, name="dispatch")
|
||||
class ACSView(SAMLViewMixin, View):
|
||||
def dispatch(self, request, organization_slug):
|
||||
url = reverse(
|
||||
"saml_finish_acs",
|
||||
kwargs={"organization_slug": organization_slug},
|
||||
)
|
||||
response = HttpResponseRedirect(url)
|
||||
acs_session = LoginSession(request, "saml_acs_session", "saml-acs-session")
|
||||
acs_session.store.update({"request": httpkit.serialize_request(request)})
|
||||
acs_session.save(response)
|
||||
return response
|
||||
|
||||
|
||||
acs = ACSView.as_view()
|
||||
|
||||
|
||||
@method_decorator(login_not_required, name="dispatch")
|
||||
class FinishACSView(SAMLViewMixin, View):
|
||||
def dispatch(self, request, organization_slug):
|
||||
provider = self.get_provider(organization_slug)
|
||||
acs_session = LoginSession(request, "saml_acs_session", "saml-acs-session")
|
||||
acs_request = None
|
||||
acs_request_data = acs_session.store.get("request")
|
||||
if acs_request_data:
|
||||
acs_request = httpkit.deserialize_request(acs_request_data, HttpRequest())
|
||||
acs_session.delete()
|
||||
if not acs_request:
|
||||
logger.error("Unable to finish login, SAML ACS session missing")
|
||||
return render_authentication_error(request, provider)
|
||||
|
||||
auth = build_auth(acs_request, provider)
|
||||
error_reason = None
|
||||
errors = []
|
||||
try:
|
||||
# We're doing the check for a valid `InResponeTo` ourselves later on
|
||||
# (*) by checking if there is a matching state stashed.
|
||||
auth.process_response(request_id=None)
|
||||
except binascii.Error:
|
||||
errors = ["invalid_response"]
|
||||
error_reason = "Invalid response"
|
||||
except OneLogin_Saml2_Error as e:
|
||||
errors = ["error"]
|
||||
error_reason = str(e)
|
||||
if not errors:
|
||||
errors = auth.get_errors()
|
||||
if errors:
|
||||
# e.g. ['invalid_response']
|
||||
error_reason = auth.get_last_error_reason() or error_reason
|
||||
logger.error("Error processing SAML ACS response: %s: %s" % (", ".join(errors), error_reason))
|
||||
return render_authentication_error(
|
||||
request,
|
||||
provider,
|
||||
extra_context={
|
||||
"saml_errors": errors,
|
||||
"saml_last_error_reason": error_reason,
|
||||
},
|
||||
)
|
||||
if not auth.is_authenticated():
|
||||
return render_authentication_error(request, provider, error=AuthError.CANCELLED)
|
||||
login = provider.sociallogin_from_response(request, auth)
|
||||
# (*) If we (the SP) initiated the login, there should be a matching
|
||||
# state.
|
||||
state_id = auth.get_last_response_in_response_to()
|
||||
if state_id:
|
||||
login.state = provider.unstash_redirect_state(request, state_id)
|
||||
else:
|
||||
# IdP initiated SSO
|
||||
reject = provider.app.settings.get("advanced", {}).get("reject_idp_initiated_sso", True)
|
||||
if reject:
|
||||
logger.error("IdP initiated SSO rejected")
|
||||
return render_authentication_error(request, provider)
|
||||
next_url = decode_relay_state(acs_request.POST.get("RelayState"))
|
||||
login.state["process"] = AuthProcess.LOGIN
|
||||
if next_url:
|
||||
login.state["next"] = next_url
|
||||
return complete_social_login(request, login)
|
||||
|
||||
|
||||
finish_acs = FinishACSView.as_view()
|
||||
|
||||
|
||||
@method_decorator(csrf_exempt, name="dispatch")
|
||||
@method_decorator(login_not_required, name="dispatch")
|
||||
class SLSView(SAMLViewMixin, View):
|
||||
def dispatch(self, request, organization_slug):
|
||||
provider = self.get_provider(organization_slug)
|
||||
auth = build_auth(self.request, provider)
|
||||
should_logout = request.user.is_authenticated
|
||||
account_adapter = get_account_adapter(request)
|
||||
|
||||
def force_logout():
|
||||
account_adapter.logout(request)
|
||||
|
||||
redirect_to = None
|
||||
error_reason = None
|
||||
try:
|
||||
redirect_to = auth.process_slo(delete_session_cb=force_logout, keep_local_session=not should_logout)
|
||||
except OneLogin_Saml2_Error as e:
|
||||
error_reason = str(e)
|
||||
errors = auth.get_errors()
|
||||
if errors:
|
||||
error_reason = auth.get_last_error_reason() or error_reason
|
||||
logger.error("Error processing SAML SLS response: %s: %s" % (", ".join(errors), error_reason))
|
||||
resp = HttpResponse(error_reason, content_type="text/plain")
|
||||
resp.status_code = 400
|
||||
return resp
|
||||
if not redirect_to:
|
||||
redirect_to = account_adapter.get_logout_redirect_url(request)
|
||||
return HttpResponseRedirect(redirect_to)
|
||||
|
||||
|
||||
sls = SLSView.as_view()
|
||||
|
||||
|
||||
@method_decorator(login_not_required, name="dispatch")
|
||||
class MetadataView(SAMLViewMixin, View):
|
||||
def dispatch(self, request, organization_slug):
|
||||
provider = self.get_provider(organization_slug)
|
||||
config = build_saml_config(self.request, provider.app.settings, organization_slug)
|
||||
saml_settings = OneLogin_Saml2_Settings(settings=config, sp_validation_only=True)
|
||||
metadata = saml_settings.get_sp_metadata()
|
||||
errors = saml_settings.validate_metadata(metadata)
|
||||
|
||||
if len(errors) > 0:
|
||||
resp = JsonResponse({"errors": errors})
|
||||
resp.status_code = 500
|
||||
return resp
|
||||
|
||||
return HttpResponse(content=metadata, content_type="text/xml")
|
||||
|
||||
|
||||
metadata = MetadataView.as_view()
|
||||
|
||||
|
||||
@method_decorator(login_not_required, name="dispatch")
|
||||
class LoginView(SAMLViewMixin, BaseLoginView):
|
||||
def get_provider(self):
|
||||
app = self.get_app(self.kwargs["organization_slug"])
|
||||
return app.get_provider(self.request)
|
||||
|
||||
|
||||
login = LoginView.as_view()
|
||||
44
saml_auth/migrations/0001_initial.py
Normal file
44
saml_auth/migrations/0001_initial.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# Generated by Django 5.1.6 on 2025-03-18 17:40
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
('socialaccount', '0006_alter_socialaccount_extra_data'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='SAMLConfiguration',
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('sso_url', models.URLField(help_text='Sign-in URL')),
|
||||
('slo_url', models.URLField(help_text='Sign-out URL')),
|
||||
('sp_metadata_url', models.URLField(help_text='https://host/saml/metadata')),
|
||||
('idp_id', models.URLField(help_text='Identity Provider ID')),
|
||||
('idp_cert', models.TextField(help_text='x509cert')),
|
||||
('uid', models.CharField(help_text='eg eduPersonPrincipalName', max_length=100)),
|
||||
('name', models.CharField(blank=True, help_text='eg displayName', max_length=100, null=True)),
|
||||
('email', models.CharField(blank=True, help_text='eg mail', max_length=100, null=True)),
|
||||
('groups', models.CharField(blank=True, help_text='eg isMemberOf', max_length=100, null=True)),
|
||||
('first_name', models.CharField(blank=True, help_text='eg gn', max_length=100, null=True)),
|
||||
('last_name', models.CharField(blank=True, help_text='eg sn', max_length=100, null=True)),
|
||||
('user_logo', models.CharField(blank=True, help_text='eg jpegPhoto', max_length=100, null=True)),
|
||||
('role', models.CharField(blank=True, help_text='eduPersonPrimaryAffiliation', max_length=100, null=True)),
|
||||
('verified_email', models.BooleanField(default=False, help_text='Mark email as verified')),
|
||||
('email_authentication', models.BooleanField(default=False, help_text='Use email authentication too')),
|
||||
('remove_from_groups', models.BooleanField(default=False, help_text='Automatically remove from groups')),
|
||||
('save_saml_response_logs', models.BooleanField(default=True)),
|
||||
('social_app', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='saml_configurations', to='socialaccount.socialapp')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'SAML Configuration',
|
||||
'verbose_name_plural': 'SAML Configurations',
|
||||
'unique_together': {('social_app', 'idp_id')},
|
||||
},
|
||||
),
|
||||
]
|
||||
0
saml_auth/migrations/__init__.py
Normal file
0
saml_auth/migrations/__init__.py
Normal file
72
saml_auth/models.py
Normal file
72
saml_auth/models.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from allauth.socialaccount.models import SocialApp
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
|
||||
|
||||
class SAMLConfiguration(models.Model):
|
||||
social_app = models.ForeignKey(SocialApp, on_delete=models.CASCADE, related_name='saml_configurations')
|
||||
|
||||
# URLs
|
||||
sso_url = models.URLField(help_text='Sign-in URL')
|
||||
slo_url = models.URLField(help_text='Sign-out URL')
|
||||
sp_metadata_url = models.URLField(help_text='https://host/saml/metadata')
|
||||
idp_id = models.URLField(help_text='Identity Provider ID')
|
||||
|
||||
# Certificates
|
||||
idp_cert = models.TextField(help_text='x509cert')
|
||||
|
||||
# Attribute Mapping Fields
|
||||
uid = models.CharField(max_length=100, help_text='eg eduPersonPrincipalName')
|
||||
name = models.CharField(max_length=100, blank=True, null=True, help_text='eg displayName')
|
||||
email = models.CharField(max_length=100, blank=True, null=True, help_text='eg mail')
|
||||
groups = models.CharField(max_length=100, blank=True, null=True, help_text='eg isMemberOf')
|
||||
first_name = models.CharField(max_length=100, blank=True, null=True, help_text='eg gn')
|
||||
last_name = models.CharField(max_length=100, blank=True, null=True, help_text='eg sn')
|
||||
user_logo = models.CharField(max_length=100, blank=True, null=True, help_text='eg jpegPhoto')
|
||||
role = models.CharField(max_length=100, blank=True, null=True, help_text='eduPersonPrimaryAffiliation')
|
||||
|
||||
verified_email = models.BooleanField(default=False, help_text='Mark email as verified')
|
||||
|
||||
email_authentication = models.BooleanField(default=False, help_text='Use email authentication too')
|
||||
|
||||
remove_from_groups = models.BooleanField(default=False, help_text='Automatically remove from groups')
|
||||
save_saml_response_logs = models.BooleanField(default=True)
|
||||
|
||||
class Meta:
|
||||
verbose_name = 'SAML Configuration'
|
||||
verbose_name_plural = 'SAML Configurations'
|
||||
unique_together = ['social_app', 'idp_id']
|
||||
|
||||
def __str__(self):
|
||||
return f'SAML Config for {self.social_app.name} - {self.idp_id}'
|
||||
|
||||
def clean(self):
|
||||
existing_conf = SAMLConfiguration.objects.filter(social_app=self.social_app)
|
||||
|
||||
if self.pk:
|
||||
existing_conf = existing_conf.exclude(pk=self.pk)
|
||||
|
||||
if existing_conf.exists():
|
||||
raise ValidationError({'social_app': 'Cannot create configuration for the same social app because one configuration already exists.'})
|
||||
|
||||
super().clean()
|
||||
|
||||
@property
|
||||
def saml_provider_settings(self):
|
||||
# provide settings in a way for Social App SAML provider
|
||||
provider_settings = {}
|
||||
provider_settings["sp"] = {"entity_id": self.sp_metadata_url}
|
||||
provider_settings["idp"] = {"slo_url": self.slo_url, "sso_url": self.sso_url, "x509cert": self.idp_cert, "entity_id": self.idp_id}
|
||||
|
||||
provider_settings["attribute_mapping"] = {
|
||||
"uid": self.uid,
|
||||
"name": self.name,
|
||||
"role": self.role,
|
||||
"email": self.email,
|
||||
"groups": self.groups,
|
||||
"first_name": self.first_name,
|
||||
"last_name": self.last_name,
|
||||
}
|
||||
provider_settings["email_verified"] = self.verified_email
|
||||
provider_settings["email_authentication"] = self.email_authentication
|
||||
return provider_settings
|
||||
0
saml_auth/tests.py
Normal file
0
saml_auth/tests.py
Normal file
0
saml_auth/views.py
Normal file
0
saml_auth/views.py
Normal file
Reference in New Issue
Block a user