# Copyright (C) 2015 Ipsilon project Contributors, for license see COPYING

from cherrypy import config as cherrypy_config
from ipsilon.util.log import Log
from ipsilon.util.data import SAML2SessionStore
import datetime
from lasso import (
    SAML2_METADATA_BINDING_SOAP,
    SAML2_METADATA_BINDING_REDIRECT,
)

# Values stored in a session's 'logoutstate' field.
LOGGED_IN = 1
INIT_LOGOUT = 2
LOGGING_OUT = 4
LOGGED_OUT = 8


class SAMLSession(Log):
    """
    A SAML login session.

    uuidval - Unique ID stored in the database
    session_id - ID of the login session
    provider_id - ID of the SP
    user - the login name of the user that owns the session
    login_session - the Login session object
    logoutstate - an integer constant representing where in the
                  logout process this request is
    relaystate - where the user will be redirected when logout is
                 complete
    request_id - the logout request ID if initiated from IdP. The
                 logout response will include an InResponseTo value
                 which matches this.
    logout_request - the Logout request object
    expiration_time - the time the login session expires
    supported_logout_mechs - logout mechanisms supported by this session
    """
    def __init__(self, uuidval, session_id, provider_id, user,
                 login_session, logoutstate=None, relaystate=None,
                 logout_request=None, request_id=None,
                 expiration_time=None, supported_logout_mechs=None):

        self.uuidval = uuidval
        self.session_id = session_id
        self.provider_id = provider_id
        self.user = user
        self.login_session = login_session
        self.logoutstate = logoutstate
        self.relaystate = relaystate
        self.request_id = request_id
        self.logout_request = logout_request
        self.expiration_time = expiration_time
        # Avoid a shared mutable default: each session gets its own list.
        if supported_logout_mechs is None:
            supported_logout_mechs = []
        self.supported_logout_mechs = supported_logout_mechs

    def set_logoutstate(self, relaystate=None, request=None,
                        request_id=None):
        """
        Update attributes needed to determine the state of the session
        for logout. The database is not updated when these are set.

        Only truthy arguments overwrite the current values.

        It is expected that this is called prior to start_logout()
        """
        if relaystate:
            self.relaystate = relaystate
        if request:
            self.logout_request = request
        if request_id:
            self.request_id = request_id

    def dump(self):
        """Write the main session attributes to the debug log."""
        self.debug('session_id %s' % self.session_id)
        self.debug('provider_id %s' % self.provider_id)
        self.debug('login session %s' % self.login_session)
        self.debug('logoutstate %s' % self.logoutstate)
        self.debug('logout mech %s' % self.supported_logout_mechs)

    def convert(self):
        """
        Convert this object into something suitable to store in the
        data backend.

        Returns a dict of {uuidval: {field: value}}.

        supported_logout_mechs is deliberately not included: it is only
        written when the session is created (add_session).
        NOTE(review): this relies on the store's update_session() merging
        the given keys into the existing record rather than replacing
        it -- confirm against SAML2SessionStore before changing.
        """
        data = dict()
        data['session_id'] = self.session_id
        data['provider_id'] = self.provider_id
        data['user'] = self.user
        data['login_session'] = self.login_session
        data['logoutstate'] = self.logoutstate
        data['relaystate'] = self.relaystate
        data['logout_request'] = self.logout_request
        data['request_id'] = self.request_id
        data['expiration_time'] = self.expiration_time

        return {self.uuidval: data}


class SAMLSessionFactory(Log):
    """
    Access SAML session information.

    The sessions are stored via the data backend (SAML2SessionStore).

    When a user logs in, add_session() is called. It creates a new
    SAMLSession in the LOGGED_IN state, stores it, and returns it.

    When a user logs out, start_logout() and get_next_logout() move
    sessions out of LOGGED_IN (to INIT_LOGOUT or LOGGING_OUT). Sessions
    stay in the store until remove_session() or
    remove_session_by_session_id() deletes them.

    Methods that scan all sessions (get_session_id_by_provider_id,
    get_next_logout, get_initial_logout, dump) only look at the sessions
    of self.user, which add_session() sets.
    """

    def __init__(self, database_url):
        self._ss = SAML2SessionStore(database_url=database_url)
        self.user = None

    def _iter_user_sessions(self):
        """
        Yield (uuidval, data) pairs for every stored session of self.user.

        Each entry returned by the store's get_user_sessions() is a dict
        keyed by the session uuid; only its first key is used.
        """
        for candidate in self._ss.get_user_sessions(self.user):
            # next(iter(...)) works on Python 2 and 3 alike, whereas
            # candidate.keys()[0] fails on Python 3 dict views.
            uuidval = next(iter(candidate))
            yield uuidval, candidate[uuidval]

    def _data_to_samlsession(self, uuidval, data):
        """
        Convert data from the data backend to a SAMLSession object.

        Missing fields become None (supported_logout_mechs becomes []).
        """
        return SAMLSession(
            uuidval,
            data.get('session_id'),
            data.get('provider_id'),
            data.get('user'),
            data.get('login_session'),
            logoutstate=data.get('logoutstate'),
            relaystate=data.get('relaystate'),
            logout_request=data.get('logout_request'),
            request_id=data.get('request_id'),
            expiration_time=data.get('expiration_time'),
            supported_logout_mechs=data.get('supported_logout_mechs'))

    def add_session(self, session_id, provider_id, user, login_session,
                    request_id, supported_logout_mechs):
        """
        Add a new login session to the table.

        :param session_id: The login session ID
        :param provider_id: The URL of the SP
        :param user: The NameID username
        :param login_session: The lasso Login session
        :param request_id: The request ID of the Logout
        :param supported_logout_mechs: A list of logout protocols supported

        Also sets self.user to user.

        The expiration time is now plus cherrypy's
        'tools.sessions.timeout' (in minutes).

        Returns the new SAMLSession in the LOGGED_IN state.
        """
        self.user = user

        timeout = cherrypy_config['tools.sessions.timeout']
        t = datetime.timedelta(seconds=timeout * 60)
        expiration_time = datetime.datetime.now() + t

        data = {'session_id': session_id,
                'provider_id': provider_id,
                'user': user,
                'login_session': login_session,
                'logoutstate': LOGGED_IN,
                'expiration_time': expiration_time,
                'request_id': request_id,
                'supported_logout_mechs': supported_logout_mechs}

        uuidval = self._ss.new_session(data)

        # Return the logout mechanisms too, so the in-memory object
        # matches what was stored (and what _data_to_samlsession yields).
        return SAMLSession(uuidval, session_id, provider_id, user,
                           login_session, LOGGED_IN,
                           request_id=request_id,
                           expiration_time=expiration_time,
                           supported_logout_mechs=supported_logout_mechs)

    def get_session_by_id(self, session_id):
        """
        Retrieve a session by session ID.

        Returns a SAMLSession, or None if no session matches.
        """
        uuidval, data = self._ss.get_session(session_id=session_id)
        if uuidval is None:
            return None

        return self._data_to_samlsession(uuidval, data)

    def get_session_id_by_provider_id(self, provider_id):
        """
        Return a tuple of session IDs of self.user for provider_id.

        Sessions in every logout state are included, not only LOGGED_IN
        ones. Each ID is returned encoded as UTF-8.
        """
        return tuple(
            data.get('session_id').encode('utf-8')
            for _uuidval, data in self._iter_user_sessions()
            if data.get('provider_id') == provider_id)

    def get_session_by_request_id(self, request_id):
        """
        Retrieve a session by logout request ID.

        Returns a SAMLSession, or None if no session matches.
        """
        uuidval, data = self._ss.get_session(request_id=request_id)
        if uuidval is None:
            return None

        return self._data_to_samlsession(uuidval, data)

    def remove_session(self, samlsession):
        """Remove samlsession from the store; returns the store's result."""
        return self._ss.remove_session(samlsession.uuidval)

    def remove_session_by_session_id(self, session_id):
        """
        Remove the session with the given login session ID.

        Returns the store's remove_session() result, or None if no
        session with that ID exists.
        """
        session = self.get_session_by_id(session_id)
        if session is None:
            # Nothing to remove; previously this raised AttributeError.
            self.debug('No SAML session %s to remove' % session_id)
            return None
        return self._ss.remove_session(session.uuidval)

    def start_logout(self, samlsession, relaystate=None, initial=True):
        """
        Move a session into the logging_out state

        samlsession: the SAMLSession object to start logging out
        relaystate: URL to redirect user to when logout is completed
        initial: boolean to indicate if this session started logout.
                 Only the initial session's relaystate is used.

        The updated session is written to the store.

        No return value
        """
        if initial:
            samlsession.logoutstate = INIT_LOGOUT
        else:
            samlsession.logoutstate = LOGGING_OUT
        if relaystate:
            samlsession.relaystate = relaystate
        datum = samlsession.convert()
        self._ss.update_session(datum)

    def get_next_logout(self, peek=False,
                        logout_mechs=None):
        """
        Get the next session in the logged-in state and move
        it to the logging_out state. Return the session that is
        found.

        :param peek: for IdP-initiated logout we can't remove the
                     session otherwise when the request comes back
                     in the user won't be seen as being logged-on.
                     NOTE(review): this flag is currently not used;
                     the found session is always moved to LOGGING_OUT.
        :param logout_mechs: An ordered list of logout mechanisms
                     you're looking for. For each mechanism in order
                     loop through all sessions. If no sessions of
                     this method are available then try the next
                     mechanism until exhausted. In that case None
                     is returned. Defaults to the HTTP-Redirect binding.

        Returns a tuple of (mechanism, session) or
        (None, None) if no more sessions in LOGGED_IN state.
        """
        # Materialize once: the session list is scanned once per mechanism.
        candidates = list(self._iter_user_sessions())
        if logout_mechs is None:
            logout_mechs = [SAML2_METADATA_BINDING_REDIRECT, ]

        for mech in logout_mechs:
            for uuidval, data in candidates:
                if int(data.get('logoutstate', 0)) != LOGGED_IN:
                    continue
                # A record without mechanisms supports none of them
                # (previously this raised TypeError on `mech in None`).
                if mech not in (data.get('supported_logout_mechs') or []):
                    continue
                samlsession = self._data_to_samlsession(uuidval, data)
                self.start_logout(samlsession, initial=False)
                return (mech, samlsession)
        return (None, None)

    def get_initial_logout(self):
        """
        Get the initial logout request.

        Returns the first session of self.user in INIT_LOGOUT state.

        Raises ValueError if no sessions in INIT_LOGOUT state.
        """
        # FIXME: what does it mean if there are multiple in init? We
        #        just return the first one for now. How do we know
        #        it's the "right" one if multiple logouts are started
        #        at the same time from different SPs?
        for uuidval, data in self._iter_user_sessions():
            if int(data.get('logoutstate', 0)) == INIT_LOGOUT:
                return self._data_to_samlsession(uuidval, data)
        raise ValueError('No SAML session in INIT_LOGOUT state for user %s'
                         % self.user)

    def wipe_data(self):
        """Delete all stored session data via the store."""
        self._ss.wipe_data()

    def dump(self):
        """
        Dump all sessions of self.user to the debug log.
        """
        for count, (uuidval, data) in enumerate(self._iter_user_sessions()):
            samlsession = self._data_to_samlsession(uuidval, data)
            self.debug('session %d: %s' % (count, samlsession.convert()))


if __name__ == '__main__':
    # Self-test against a throwaway sqlite database.
    provider1 = "http://127.0.0.10/saml2"
    provider2 = "http://127.0.0.11/saml2"

    # temporary values to simulate cherrypy
    cherrypy_config['tools.sessions.timeout'] = 60

    factory = SAMLSessionFactory('/tmp/saml2sessions.sqlite')
    factory.wipe_data()

    sess1 = factory.add_session('_123456', provider1, "admin",
                                "", '_1234',
                                [SAML2_METADATA_BINDING_REDIRECT])
    sess2 = factory.add_session('_789012', provider2, "testuser",
                                "", '_7890',
                                [SAML2_METADATA_BINDING_SOAP,
                                 SAML2_METADATA_BINDING_REDIRECT])

    # Test finding sessions by provider
    ids = factory.get_session_id_by_provider_id(provider2)
    assert len(ids) == 1

    sess3 = factory.add_session('_345678', provider2, "testuser",
                                "", '_3456',
                                [SAML2_METADATA_BINDING_REDIRECT])
    ids = factory.get_session_id_by_provider_id(provider2)
    assert len(ids) == 2

    # Test finding sessions by session ID
    test1 = factory.get_session_by_id('_123456')
    assert test1.user == 'admin'
    assert test1.provider_id == provider1

    # Log out and remove the first session
    test1.set_logoutstate('http://www.example.com/idp')
    factory.start_logout(test1, initial=True)
    test1 = factory.get_session_by_id('_123456')
    assert test1.relaystate == 'http://www.example.com/idp'
    factory.remove_session_by_session_id('_123456')

    # Make sure it is gone from the db
    test1 = factory.get_session_by_id('_123456')
    assert test1 is None

    test2 = factory.get_session_by_id('_789012')
    factory.start_logout(test2, initial=True)

    (lmech, test3) = factory.get_next_logout()
    assert test3.session_id == '_345678'

    test4 = factory.get_initial_logout()
    assert test4.session_id == '_789012'

    # Even though we've started logout, make sure we can still find
    # all sessions for a provider.
    ids = factory.get_session_id_by_provider_id(provider2)
    assert len(ids) == 2