summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLukas Slebodnik <lslebodn@redhat.com>2016-08-10 20:05:52 +0200
committerLukas Slebodnik <lslebodn@redhat.com>2016-08-24 13:56:40 +0200
commitc596fc4d75304ff224cbad0aa2aecd3cbe82d2ff (patch)
treed5cc194b3a2277e11c7d9c97e440ce0df1807630
parent5691b2d668541585d2a8ae3ddb834f29d828036e (diff)
downloadsssd-c596fc4d75304ff224cbad0aa2aecd3cbe82d2ff.tar.gz
sssd-c596fc4d75304ff224cbad0aa2aecd3cbe82d2ff.tar.xz
sssd-c596fc4d75304ff224cbad0aa2aecd3cbe82d2ff.zip
sssd_netgroup.py: Resolve nested netgroups
Reviewed-by: Petr Čech <pcech@redhat.com>
-rw-r--r--src/tests/intg/sssd_netgroup.py224
1 files changed, 164 insertions, 60 deletions
diff --git a/src/tests/intg/sssd_netgroup.py b/src/tests/intg/sssd_netgroup.py
index 3525261cb..2c7f76fad 100644
--- a/src/tests/intg/sssd_netgroup.py
+++ b/src/tests/intg/sssd_netgroup.py
@@ -71,49 +71,173 @@ class Netgrent(Structure):
("nip", c_void_p)]
-def call_sssd_setnetgrent(netgroup):
- libnss_sss_path = config.NSS_MODULE_DIR + "/libnss_sss.so.2"
- libnss_sss = cdll.LoadLibrary(libnss_sss_path)
+class NetgroupRetriever(object):
+ def __init__(self, name):
+ self.name = name
+ self.needed_groups = []
+ self.known_groups = []
+ self.netgroups = []
+
+ @staticmethod
+ def _setnetgrent(netgroup):
+ """
+ This private method is ctypes wrapper for
+ enum nss_status _nss_sss_setnetgrent(const char *netgroup,
+ struct __netgrent *result)
+
+ @param string name name of netgroup
+
+ @return (int, POINTER(Netgrent)) (err, result_p)
+ err is a constant from class NssReturnCode and in case of SUCCESS
+ result_p will contain POINTER(Netgrent) which can be used in
+ _getnetgrent_r or _getnetgrent_r.
+ """
+ libnss_sss_path = config.NSS_MODULE_DIR + "/libnss_sss.so.2"
+ libnss_sss = cdll.LoadLibrary(libnss_sss_path)
+
+ func = libnss_sss._nss_sss_setnetgrent
+ func.restype = c_int
+ func.argtypes = [c_char_p, POINTER(Netgrent)]
+
+ result = Netgrent()
+ result_p = POINTER(Netgrent)(result)
+
+ res = func(c_char_p(netgroup), result_p)
+
+ return (int(res), result_p)
+
+ @staticmethod
+ def _getnetgrent_r(result_p, buff, buff_len):
+ """
+ This private method is ctypes wrapper for
+ enum nss_status _nss_sss_getnetgrent_r(struct __netgrent *result,
+ char *buffer, size_t buflen,
+ int *errnop)
+ @param POINTER(Netgrent) result_p pointer to initialized C structure
+ struct __netgrent
+ @param ctypes.c_char_Array buff buffer used by C functions
+ @param int buff_len size of c_char_Array passed as a paramere buff
+
+ @return (int, int, List[(string, string, string])
+ (err, errno, netgroups)
+ if err is NssReturnCode.SUCCESS netgroups will contain list of
+ touples. Each touple will consist of 3 elemets either string or
+ """
+ libnss_sss_path = config.NSS_MODULE_DIR + "/libnss_sss.so.2"
+ libnss_sss = cdll.LoadLibrary(libnss_sss_path)
+
+ func = libnss_sss._nss_sss_getnetgrent_r
+ func.restype = c_int
+ func.argtypes = [POINTER(Netgrent), POINTER(c_char), c_size_t,
+ POINTER(c_int)]
+
+ errno = POINTER(c_int)(c_int(0))
+
+ res = func(result_p, buff, buff_len, errno)
+
+ return (int(res), int(errno[0]), result_p)
+
+ @staticmethod
+ def _endnetgrent(result_p):
+ """
+ This private method is ctypes wrapper for
+ enum nss_status _nss_sss_endnetgrent(struct __netgrent *result)
+
+ @param POINTER(Netgrent) result_p pointer to initialized C structure
+ struct __netgrent
+
+ @return int a constant from class NssReturnCode
+ """
+ libnss_sss_path = config.NSS_MODULE_DIR + "/libnss_sss.so.2"
+ libnss_sss = cdll.LoadLibrary(libnss_sss_path)
+
+ func = libnss_sss._nss_sss_endnetgrent
+ func.restype = c_int
+ func.argtypes = [POINTER(Netgrent)]
+
+ res = func(result_p)
+
+ return int(res)
+
+ def get_netgroups(self):
+ """
+ Function will return netgroup triplets for given user. All nested
+ netgroups will be retieved as part of executions and will content
+ will be merged with direct triplets.
+ Missing nested netgroups will not cause failure and are considered
+ as an empty netgroup without triplets.
+
+ @param string name name of netgroup
+
+ @return (int, int, List[(string, string, string])
+ (err, errno, netgroups)
+ if err is NssReturnCode.SUCCESS netgroups will contain list of
+ touples. Each touple will consist of 3 elemets either string or
+ None (host, user, domain).
+ """
+ res, errno, result = self._flat_fetch_netgroups(self.name)
+ if res != NssReturnCode.SUCCESS:
+ return (res, errno, self.netgroups)
+
+ self.netgroups += result
+
+ while self.needed_groups:
+ name = self.needed_groups.pop(0)
+
+ nest_res, nest_errno, result = self._flat_fetch_netgroups(name)
+ # do not fail for missing nested netgroup
+ if nest_res not in (NssReturnCode.SUCCESS, NssReturnCode.NOTFOUND):
+ return (nest_res, nest_errno, self.netgroups)
+
+ self.netgroups = result + self.netgroups
+
+ return (res, errno, self.netgroups)
+
+ def _flat_fetch_netgroups(self, name):
+ """
+ Function will return netgroup triplets for given user. The nested
+ netgroups will not be returned. Missing nested netgroups will be
+ appended to the array needed_groups
+
+ @param string name name of netgroup
+
+ @return (int, int, List[(string, string, string])
+ (err, errno, netgroups)
+ if err is NssReturnCode.SUCCESS netgroups will contain list of
+ touples. Each touple will consist of 3 elemets either string or
+ None (host, user, domain).
+ """
+ buff_len = 1024 * 1024
+ buff = create_string_buffer(buff_len)
+
+ result = []
+
+ res, result_p = self._setnetgrent(name)
+ if res != NssReturnCode.SUCCESS:
+ return (res, get_errno(), result)
+
+ res, errno, result_p = self._getnetgrent_r(result_p, buff, buff_len)
+ while res == NssReturnCode.SUCCESS:
+ if result_p[0].type == NetgroupType.GROUP_VAL:
+ nested_netgroup = result_p[0].val.group
+ if nested_netgroup not in self.known_groups:
+ self.needed_groups.append(nested_netgroup)
+ self.known_groups.append(nested_netgroup)
- func = libnss_sss._nss_sss_setnetgrent
- func.restype = c_int
- func.argtypes = [c_char_p, POINTER(Netgrent)]
+ if result_p[0].type == NetgroupType.TRIPLE_VAL:
+ result.append((result_p[0].val.triple.host,
+ result_p[0].val.triple.user,
+ result_p[0].val.triple.domain))
- result = Netgrent()
- result_p = POINTER(Netgrent)(result)
-
- res = func(c_char_p(netgroup), result_p)
-
- return (int(res), result_p)
-
-
-def call_sssd_getnetgrent_r(result_p, buff, buff_len):
- libnss_sss_path = config.NSS_MODULE_DIR + "/libnss_sss.so.2"
- libnss_sss = cdll.LoadLibrary(libnss_sss_path)
-
- func = libnss_sss._nss_sss_getnetgrent_r
- func.restype = c_int
- func.argtypes = [POINTER(Netgrent), POINTER(c_char), c_size_t,
- POINTER(c_int)]
-
- errno = POINTER(c_int)(c_int(0))
-
- res = func(result_p, buff, buff_len, errno)
-
- return (int(res), int(errno[0]), result_p)
-
-
-def call_sssd_endnetgrent(result_p):
- libnss_sss_path = config.NSS_MODULE_DIR + "/libnss_sss.so.2"
- libnss_sss = cdll.LoadLibrary(libnss_sss_path)
+ res, errno, result_p = self._getnetgrent_r(result_p, buff,
+ buff_len)
- func = libnss_sss._nss_sss_endnetgrent
- func.restype = c_int
- func.argtypes = [POINTER(Netgrent)]
+ if res != NssReturnCode.RETURN:
+ return (res, errno, result)
- res = func(result_p)
+ res = self._endnetgrent(result_p)
- return int(res)
+ return (res, errno, result)
def get_sssd_netgroups(name):
@@ -129,27 +253,7 @@ def get_sssd_netgroups(name):
Each touple will consist of 3 elemets either string or None
(host, user, domain).
"""
- buff_len = 1024 * 1024
- buff = create_string_buffer(buff_len)
-
- result = []
-
- res, result_p = call_sssd_setnetgrent(name)
- if res != NssReturnCode.SUCCESS:
- return (res, get_errno(), result)
-
- res, errno, result_p = call_sssd_getnetgrent_r(result_p, buff, buff_len)
- while res == NssReturnCode.SUCCESS:
- assert result_p[0].type == NetgroupType.TRIPLE_VAL
- result.append((result_p[0].val.triple.host,
- result_p[0].val.triple.user,
- result_p[0].val.triple.domain))
- res, errno, result_p = call_sssd_getnetgrent_r(result_p, buff,
- buff_len)
-
- if res != NssReturnCode.RETURN:
- return (res, errno, result)
- res = call_sssd_endnetgrent(result_p)
+ retriever = NetgroupRetriever(name)
- return (res, errno, result)
+ return retriever.get_netgroups()