# Copyright (C) 2015 Rob Crittenden
#
# see file 'COPYING' for use and warranty information
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from ipsilon.util.log import Log


class SAMLSession(Log):
    """
    A SAML login session used to track login/logout state.

       session_id - ID of the login session
       provider_id - ID of the SP
       session - the Login session object
       logoutstate - dict containing logout state info

    logout state is a dictionary containing (potentially)
    these attributes:

    relaystate - The relaystate from the Logout Request or Response
    id         - The Logout request id that initiated the logout
    request    - Dump of the initial logout request
    """
    def __init__(self, session_id, provider_id, session,
                 logoutstate=None):

        self.session_id = session_id
        self.provider_id = provider_id
        self.session = session
        self.logoutstate = logoutstate

    def set_logoutstate(self, relaystate, request_id, request=None):
        """
        Record the logout state for this session.

        Replaces any previously stored logoutstate with a new dict
        containing the keys 'relaystate', 'id' (from request_id) and
        'request'.
        """
        self.logoutstate = dict(relaystate=relaystate,
                                id=request_id,
                                request=request)

    def dump(self):
        """Emit each attribute of this session through self.debug()."""
        self.debug('session_id %s' % self.session_id)
        self.debug('provider_id %s' % self.provider_id)
        self.debug('session %s' % self.session)
        self.debug('logoutstate %s' % self.logoutstate)


class SAMLSessionsContainer(Log):
    """
    Store SAML session information.

    The sessions are stored in two dicts, both keyed on provider_id,
    which represent the state that the session is in:

      sessions             - sessions that are logged in
      sessions_logging_out - sessions whose logout has been started

    When a user logs in, add_session() is called and a new SAMLSession
    is created and added to the sessions dict, keyed on provider_id.

    When a user logs out, the next logged-in session is found and moved
    to sessions_logging_out.

    remove_session_by_provider() looks in both dicts when removing a
    session.
    """
    def __init__(self):
        self.sessions = dict()
        self.sessions_logging_out = dict()

    def add_session(self, session_id, provider_id, session):
        """
        Add a new session to the logged-in bucket.

        Drop any existing sessions that might exist for this
        provider. We have no control over the SP's so if it sends
        us another login, accept it.
        """
        samlsession = SAMLSession(session_id, provider_id, session)

        self.remove_session_by_provider(provider_id)
        self.sessions[provider_id] = samlsession
        self.dump()

    def remove_session_by_provider(self, provider_id):
        """
        Remove all instances of this provider from either session pool.

        A provider with no session in either pool is silently ignored.
        """
        self.sessions.pop(provider_id, None)
        self.sessions_logging_out.pop(provider_id, None)

    def find_session_by_provider(self, provider_id):
        """
        Return a given session from either pool.

        The logged-in pool is checked before the logging-out pool.
        Return None if no session for a provider is found.
        """
        if provider_id in self.sessions:
            return self.sessions[provider_id]
        return self.sessions_logging_out.get(provider_id)

    def start_logout(self, session):
        """
        Move a session into the logging_out state.

        Does nothing if the session's provider is already logging out.
        Raises KeyError if the provider has no session in the logged-in
        pool.

        No return value
        """
        if session.provider_id in self.sessions_logging_out:
            return

        # The object stored in the container (not necessarily the one
        # passed in) is what gets moved.
        session = self.sessions.pop(session.provider_id)
        self.sessions_logging_out[session.provider_id] = session

    def get_next_logout(self):
        """
        Get the next session in the logged-in state and move
        it to the logging_out state.  Return the session that is
        found.

        Return None if no more sessions in login state.
        """
        if not self.sessions:
            return None

        # next(iter(...)) picks the same key as the former keys()[0] on
        # Python 2 and, unlike keys()[0], also works on Python 3 where
        # keys() returns a non-indexable view.
        provider_id = next(iter(self.sessions))
        session = self.sessions.pop(provider_id)

        # Drop any stale logging-out entry for this provider before
        # inserting the session that is now logging out.
        self.sessions_logging_out.pop(provider_id, None)
        self.sessions_logging_out[provider_id] = session

        return session

    def get_last_session(self):
        """
        Remove and return the single remaining logging-out session.

        Raises ValueError unless exactly one session remains across both
        pools. Returns None when that remaining session is still in the
        logged-in pool rather than the logging-out pool.
        """
        if self.count() != 1:
            raise ValueError('Not exactly one session left')

        if not self.sessions_logging_out:
            return None

        # count() == 1 and this pool is non-empty, so it holds exactly
        # one entry; popitem() removes and returns it.
        return self.sessions_logging_out.popitem()[1]

    def count(self):
        """
        Return number of active login/logging out sessions.
        """
        return len(self.sessions) + len(self.sessions_logging_out)

    def dump(self):
        """
        Dump every session through self.debug(), logged-in sessions first,
        numbering them sequentially across both pools.
        """
        count = 0
        pools = (('Login Session', self.sessions),
                 ('Logging-out Session', self.sessions_logging_out))
        for label, pool in pools:
            for session in pool.values():
                self.debug('%s: %d' % (label, count))
                session.dump()
                self.debug('-----------------------')
                count += 1


if __name__ == '__main__':
    provider1 = "http://127.0.0.10/saml2"
    provider2 = "http://127.0.0.11/saml2"

    saml_sessions = SAMLSessionsContainer()

    # get_last_session() must refuse to run on an empty container.
    try:
        saml_sessions.get_last_session()
    except ValueError:
        assert saml_sessions.count() == 0
    else:
        raise AssertionError('get_last_session() accepted 0 sessions')

    saml_sessions.add_session("_123456", provider1, "sessiondata")
    saml_sessions.add_session("_789012", provider2, "sessiondata")

    # ...and must also refuse when more than one session remains.
    try:
        saml_sessions.get_last_session()
    except ValueError:
        assert saml_sessions.count() == 2
    else:
        raise AssertionError('get_last_session() accepted 2 sessions')

    testsession = saml_sessions.find_session_by_provider(provider1)
    assert testsession.provider_id == provider1
    assert testsession.session_id == "_123456"
    assert testsession.session == "sessiondata"
    assert saml_sessions.find_session_by_provider("http://unknown") is None

    # Test get_next_logout() by fetching both values out. Do some
    # basic accounting to ensure we get both values eventually.
    providers = [provider1, provider2]
    testsession = saml_sessions.get_next_logout()
    providers.remove(testsession.provider_id)  # should be one of them

    testsession = saml_sessions.get_next_logout()
    assert testsession.provider_id == providers[0]  # should be the other

    # No logged-in sessions remain.
    assert saml_sessions.get_next_logout() is None

    saml_sessions.start_logout(testsession)
    saml_sessions.remove_session_by_provider(provider2)
    assert saml_sessions.count() == 1

    testsession = saml_sessions.get_last_session()
    assert testsession.provider_id == provider1

    saml_sessions.remove_session_by_provider(provider1)
    assert saml_sessions.count() == 0