summaryrefslogtreecommitdiffstats
path: root/ipa-python/ipautil.py
blob: 407406de760bc365df07edf8d8156c367cf86d50 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
# Authors: Simo Sorce <ssorce@redhat.com>
#
# Copyright (C) 2007    Red Hat
# see file 'COPYING' for use and warranty information
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License as
# published by the Free Software Foundation; version 2 or later
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.    See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#

# Location of IPA's shared data directory.  Note the trailing slash: callers
# presumably append file names directly to it (TODO confirm against users).
SHARE_DIR = "/usr/share/ipa/"

import string
import tempfile
import logging
import subprocess
import os
import stat
import copy
import readline
import traceback
from types import *

from string import lower
import re
import xmlrpclib
import datetime

def realm_to_suffix(realm_name):
    """Turn a dotted realm name into an LDAP ``dc=`` suffix.

    Every dot-separated label is lower-cased and prefixed with ``dc=``;
    the components are joined with commas, e.g.
    ``EXAMPLE.COM`` -> ``dc=example,dc=com``.
    """
    labels = realm_name.split(".")
    return ",".join("dc=%s" % label.lower() for label in labels)


def template_str(txt, vars):
    """Return *txt* with ``$name`` / ``${name}`` placeholders filled from *vars*.

    Uses ``string.Template.substitute``, so a placeholder missing from
    *vars* raises KeyError.
    """
    template = string.Template(txt)
    result = template.substitute(vars)
    return result

def template_file(infilename, vars):
    """Read the template file *infilename* and substitute *vars* into it.

    Substitution is done by template_str().  The file is always closed
    before returning, even if reading it raises.
    """
    infile = open(infilename)
    try:
        txt = infile.read()
    finally:
        # Explicit close instead of relying on garbage collection.
        infile.close()
    return template_str(txt, vars)

def write_tmp_file(txt):
    """Write *txt* into a fresh NamedTemporaryFile and return the open object.

    The data is flushed so the file can be read by name (``.name``) right
    away.  The file object is returned still open; the caller keeps it alive
    for as long as the temporary file is needed.
    """
    tmp = tempfile.NamedTemporaryFile()
    tmp.write(txt)
    tmp.flush()  # make the contents visible to readers of tmp.name
    return tmp

def run(args, stdin=None):
    """Run the command *args*, log its output, and fail on a non-zero exit.

    :param args: argument vector passed to ``subprocess.Popen``.
    :param stdin: optional data to feed to the child's standard input.  When
        None the child inherits this process's standard input.
    :returns: ``(stdout, stderr)`` as captured from the child.
    :raises subprocess.CalledProcessError: if the command exits non-zero.
    """
    # A pipe must be requested explicitly, otherwise communicate() has no
    # stdin to write to and silently discards the data.
    if stdin is not None:
        stdin_pipe = subprocess.PIPE
    else:
        stdin_pipe = None
    p = subprocess.Popen(args, stdin=stdin_pipe,
                         stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = p.communicate(stdin)
    logging.info(stdout)
    logging.info(stderr)

    if p.returncode != 0:
        raise subprocess.CalledProcessError(p.returncode, ' '.join(args))

    return stdout, stderr

def file_exists(filename):
    """Return True if *filename* names an existing regular file.

    Symbolic links are followed (``os.stat``).  A path that cannot be
    stat()ed -- missing, not accessible, or malformed (e.g. containing a NUL
    byte) -- yields False.  Unrelated exceptions such as KeyboardInterrupt
    propagate instead of being swallowed.
    """
    try:
        mode = os.stat(filename)[stat.ST_MODE]
    except (OSError, TypeError, ValueError):
        # OSError: stat failed; TypeError/ValueError: unusable path value
        # (Python 2 and 3 report embedded NUL bytes differently).
        return False
    return stat.S_ISREG(mode)

def dir_exists(filename):
    """Return True if *filename* names an existing directory.

    Symbolic links are followed (``os.stat``).  A path that cannot be
    stat()ed -- missing, not accessible, or malformed (e.g. containing a NUL
    byte) -- yields False.  Unrelated exceptions such as KeyboardInterrupt
    propagate instead of being swallowed.
    """
    try:
        mode = os.stat(filename)[stat.ST_MODE]
    except (OSError, TypeError, ValueError):
        # OSError: stat failed; TypeError/ValueError: unusable path value
        # (Python 2 and 3 report embedded NUL bytes differently).
        return False
    return stat.S_ISDIR(mode)

class CIDict(dict):
    """
    Case-insensitive but case-respecting dictionary.

    This code is derived from python-ldap's cidict.py module,
    written by stroeder: http://python-ldap.sourceforge.net/

    This version extends 'dict' so it works properly with TurboGears.
    If you extend UserDict, isinstance(foo, dict) returns false.

    Values live in the underlying dict under the lower-cased key, while
    self._keys maps each lower-cased key to the spelling most recently used
    to set it; keys(), items() and copy() report those original spellings.
    Keys must therefore be strings (objects with a .lower() method).
    """

    def __init__(self, default=None):
        super(CIDict, self).__init__()
        self._keys = {}  # lower-cased key -> original spelling
        self.update(default or {})

    def __getitem__(self, key):
        return super(CIDict, self).__getitem__(key.lower())

    def __setitem__(self, key, value):
        lower_key = key.lower()
        self._keys[lower_key] = key
        return super(CIDict, self).__setitem__(lower_key, value)

    def __delitem__(self, key):
        lower_key = key.lower()
        del self._keys[lower_key]
        return super(CIDict, self).__delitem__(lower_key)

    def __contains__(self, key):
        """Case-insensitive membership test (``key in d``)."""
        try:
            lower_key = key.lower()
        except AttributeError:
            # Only string keys can ever be stored, so anything else is absent.
            return False
        return super(CIDict, self).__contains__(lower_key)

    def update(self, dict):
        """Copy every key/value of the mapping *dict* into this one."""
        for key in dict.keys():
            self[key] = dict[key]

    def has_key(self, key):
        return super(CIDict, self).__contains__(key.lower())

    def get(self, key, failobj=None):
        try:
            return self[key]
        except KeyError:
            return failobj

    def keys(self):
        """Return the keys in their original spelling."""
        return self._keys.values()

    def items(self):
        """Return a list of (original-case key, value) pairs."""
        result = []
        for k in self._keys.values():
            result.append((k, self[k]))
        return result

    def copy(self):
        """Return a plain dict keyed by the original key spellings."""
        result = {}
        for k in self._keys.values():
            result[k] = self[k]
        return result

    def iteritems(self):
        # Iterate over a snapshot so mutation during iteration is safe.
        return iter(self.copy().items())

    def iterkeys(self):
        return iter(self.copy().keys())

    def setdefault(self, key, value=None):
        try:
            return self[key]
        except KeyError:
            self[key] = value
            return value

    def pop(self, key, *args):
        try:
            value = self[key]
            del self[key]
            return value
        except KeyError:
            if len(args) == 1:
                return args[0]
            raise

    def popitem(self):
        """Remove and return an arbitrary (original-case key, value) pair."""
        (lower_key, value) = super(CIDict, self).popitem()
        key = self._keys.pop(lower_key)
        return (key, value)

    def clear(self):
        """Remove all entries, including the original-case key map."""
        super(CIDict, self).clear()
        self._keys.clear()


#
# The safe_string_re regexp and needs_base64 function are extracted from the
# python-ldap ldif module, which was
# written by Michael Stroeder <michael@stroeder.com>
# http://python-ldap.sourceforge.net
#
# It was extracted because ipaldap.py is naughtily reaching into the ldif
# module and squashing this regexp.
#
SAFE_STRING_PATTERN = '(^(\000|\n|\r| |:|<)|[\000\n\r\200-\377]+|[ ]+$)'
safe_string_re = re.compile(SAFE_STRING_PATTERN)

def needs_base64(s):
    """Return True if *s* must be base64-encoded, False otherwise.

    A value needs encoding when it starts with NUL, LF, CR, space, ':' or
    '<', contains NUL, LF, CR or any character in the range 0x80-0xFF, or
    ends with one or more spaces.
    """
    unsafe = safe_string_re.search(s)
    return unsafe is not None


def wrap_binary_data(data):
    """Recursively wrap binary strings in ``xmlrpclib.Binary`` for XML-RPC.

    A ``str`` for which needs_base64() is true becomes a Binary object; any
    other ``str`` is returned unchanged.  Lists and tuples are rebuilt as
    lists and dicts as new dicts (keys untouched, values converted).
    Anything else is returned as-is.
    """
    if isinstance(data, str):
        if not needs_base64(data):
            return data
        return xmlrpclib.Binary(data)
    if isinstance(data, (list, tuple)):
        return [wrap_binary_data(item) for item in data]
    if isinstance(data, dict):
        return dict((key, wrap_binary_data(val)) for key, val in data.items())
    return data


def unwrap_binary_data(data):
    """Recursively turn ``xmlrpclib.Binary`` objects back into strings.

    Binary objects become ``str(binary)``.  A ``str`` is returned unchanged.
    Lists and tuples are rebuilt as lists and dicts as new dicts with
    converted values.  Everything else is returned as-is.
    """
    if isinstance(data, xmlrpclib.Binary):
        # The XML-RPC proxy has already decoded the payload; str() yields it.
        return str(data)
    if isinstance(data, str):
        return data
    if isinstance(data, (list, tuple)):
        return [unwrap_binary_data(item) for item in data]
    if isinstance(data, dict):
        return dict((key, unwrap_binary_data(val)) for key, val in data.items())
    return data

class GeneralizedTimeZone(datetime.tzinfo):
    """Fixed-offset tzinfo for the zone suffix of a Generalized Time.

    Accepts "Z" (UTC), a signed hour offset "+HH"/"-HH" optionally followed
    by two minute digits ("+HHMM"/"-HHMM"), or -- as before -- a bare
    unsigned "MM" minute offset.  The sign applies to the whole offset, so
    "-0030" is thirty minutes west of UTC.  It is dst-ignorant: dst() is
    always zero.  Raises ValueError for any other text.
    """
    def __init__(self, offsetstr="Z"):
        super(GeneralizedTimeZone, self).__init__()

        self.name = offsetstr  # reported verbatim by tzname()
        self.houroffset = 0
        self.minoffset = 0

        if offsetstr != "Z":
            negative = False
            if (len(offsetstr) >= 3) and re.match(r'[-+]\d\d', offsetstr):
                # Remember the sign explicitly: int("-00") is 0, so the
                # hour value alone cannot tell us a "-00MM" offset is west.
                negative = offsetstr[0] == "-"
                self.houroffset = int(offsetstr[0:3])
                offsetstr = offsetstr[3:]
            if (len(offsetstr) >= 2) and re.match(r'\d\d', offsetstr):
                self.minoffset = int(offsetstr[0:2])
                offsetstr = offsetstr[2:]
            if len(offsetstr) > 0:
                raise ValueError("invalid Generalized Time zone offset: %r"
                                 % self.name)
            if negative:
                self.minoffset = -self.minoffset

    def utcoffset(self, dt):
        return datetime.timedelta(hours=self.houroffset, minutes=self.minoffset)

    def dst(self, dt):
        return datetime.timedelta(0)

    def tzname(self, dt):
        return self.name


def parse_generalized_time(timestr):
    """Parse a Generalized Time string (as specified in X.680) into a
       datetime object.  Generalized Times are stored inside the
       krbPasswordExpiration attribute in LDAP.

       Layout: YYYYMMDD, then optional HH, MM and SS fields.  Each field
       may carry a decimal fraction introduced by "," or ".".  An hour or
       minute fraction is converted to whole minutes/seconds (the
       sub-second remainder is truncated); a second fraction becomes
       microseconds.  Any remaining text is treated as the zone suffix and
       handed to GeneralizedTimeZone.

       Returns a datetime (naive when there is no zone suffix), or None if
       the string is shorter than 8 characters or if any ValueError is
       raised while parsing (bad digits, out-of-range fields, or an invalid
       zone suffix).

       This method doesn't attempt to be perfect wrt timezones.  If python
       can't be bothered to implement them, how can we..."""

    if len(timestr) < 8:
        return None
    try:
        date = timestr[:8]
        time = timestr[8:]  # remainder after YYYYMMDD, consumed left to right

        year = int(date[:4])
        month = int(date[4:6])
        day = int(date[6:8])

        # NOTE: `min` shadows the builtin here, and `msec` actually holds
        # microseconds (datetime's seventh positional argument).
        hour = min = sec = msec = 0
        tzone = None

        # Hour field, optionally with a fractional hour.
        if (len(time) >= 2) and re.match(r'\d', time[0]):
            hour = int(time[:2])
            time = time[2:]
            if len(time) >= 2 and (time[0] == "," or time[0] == "."):
                hour_fraction = "."
                time = time[1:]
                while (len(time) > 0) and re.match(r'\d', time[0]):
                    hour_fraction += time[0]
                    time = time[1:]
                # Fraction of an hour -> whole seconds -> minutes + seconds.
                total_secs = int(float(hour_fraction) * 3600)
                min, sec = divmod(total_secs, 60)

        # Minute field, optionally with a fractional minute.
        if (len(time) >= 2) and re.match(r'\d', time[0]):
            min = int(time[:2])
            time = time[2:]
            if len(time) >= 2 and (time[0] == "," or time[0] == "."):
                min_fraction = "."
                time = time[1:]
                while (len(time) > 0) and re.match(r'\d', time[0]):
                    min_fraction += time[0]
                    time = time[1:]
                # Fraction of a minute -> whole seconds (remainder dropped).
                sec = int(float(min_fraction) * 60)

        # Second field, optionally with a fractional second.
        if (len(time) >= 2) and re.match(r'\d', time[0]):
            sec = int(time[:2])
            time = time[2:]
            if len(time) >= 2 and (time[0] == "," or time[0] == "."):
                sec_fraction = "."
                time = time[1:]
                while (len(time) > 0) and re.match(r'\d', time[0]):
                    sec_fraction += time[0]
                    time = time[1:]
                # Fraction of a second -> microseconds.
                msec = int(float(sec_fraction) * 1000000)

        # Whatever is left must be a zone suffix ("Z", "+HHMM", ...);
        # GeneralizedTimeZone raises ValueError if it is not.
        if (len(time) > 0):
            tzone = GeneralizedTimeZone(time)

        return datetime.datetime(year, month, day, hour, min, sec, msec, tzone)

    except ValueError:
        return None


def format_list(items, quote=None, page_width=80):
    '''Format a list of items into columns that fit within page_width.

    The items are laid out in sorted order, column-major (down, then
    across).  Each item is left-justified in a field as wide as the longest
    quoted item, with one space between columns.  The caller's sequence is
    not modified.

    The items may optionally be quoted. The quote parameter may either be
    a string, in which case it is added before and after the item, or a
    pair (either a tuple or list), in which case quote[0] is the left hand
    quote and quote[1] is the right hand quote.

    Returns the rows joined with newlines, or '' when there are no items.
    '''
    left_quote = right_quote = ''
    num_items = len(items)
    if not num_items:
        return ''

    if quote is not None:
        # (str, unicode) on Python 2, (str, str) on Python 3.
        if isinstance(quote, (str, type(u''))):
            left_quote = right_quote = quote
        elif isinstance(quote, (tuple, list)):
            left_quote = quote[0]
            right_quote = quote[1]

    # Sort a copy so the caller's list is left untouched.
    items = sorted(items)
    max_len = max(map(len, items))
    max_len += len(left_quote) + len(right_quote)
    # n columns of width max_len separated by single spaces occupy
    # n * (max_len + 1) - 1 characters, which must not exceed page_width.
    # Always keep at least one column, even if an item is wider than the page.
    num_columns = max(1, (page_width + 1) // (max_len + 1))
    num_rows = (num_items + num_columns - 1) // num_columns

    rows = [''] * num_rows
    i = col = 0

    while i < num_items:
        row = 0
        if col == 0:
            separator = ''
        else:
            separator = ' '

        # Fill one column top to bottom.
        while i < num_items and row < num_rows:
            rows[row] += "%s%*s" % (separator, -max_len, "%s%s%s" % (left_quote, items[i], right_quote))
            i += 1
            row += 1
        col += 1
    return '\n'.join(rows)

# One ``key = value`` pair.  The quoted alternative must be tried first:
# otherwise ``\S+`` matches the opening quote and the value is cut at the
# first space.  Inside quotes a backslash escapes the next character.
key_value_re = re.compile(r'''(?P<key>[^\s=]+)\s*=\s*(?:(?P<quote>['"])(?P<qvalue>(?:\\.|(?!(?P=quote))[^\\\n])*)(?P=quote)|(?P<value>\S+))''')
def parse_key_value_pairs(input):
    r''' Given a string composed of key=value pairs parse it and return
    a dict of the key/value pairs. Keys must be a word, a key must be followed
    by an equal sign (=) and a value. The value may be a single word or may be
    quoted. Quotes may be either single or double quotes, but must be balanced.
    Inside the quoted text the same quote used to start the quoted value may be
    used if it is escaped by preceding it with a backslash (\).
    White space between the key, the equal sign, and the value is ignored.
    Values are always strings. Empty values must be specified with an empty
    quoted string; its value after parsing will be an empty string.

    Example: The string

    arg0 = '' arg1 = 1 arg2='two' arg3 = "three's a crowd" arg4 = "this is a \" quote"

    will produce

    arg0=   arg1=1
    arg2=two
    arg3=three's a crowd
    arg4=this is a " quote
    '''

    kv_dict = {}
    for match in key_value_re.finditer(input):
        quote = match.group('quote')
        if quote is not None:
            # Quoted value: only an escaped copy of the enclosing quote is
            # unescaped; other backslash sequences are kept verbatim.
            value = match.group('qvalue').replace('\\' + quote, quote)
        else:
            value = match.group('value')
        kv_dict[match.group('key')] = value
    return kv_dict

class AttributeValueCompleter:
    '''
    Gets input from the user in the form "lhs operator rhs"
    TAB completes partial input.
    lhs completes to a name in @lhs_names
    The lhs is fully parsed if a lhs_delim delimiter is seen, then TAB will
    complete to the operator and a default value.
    Default values for a lhs value can be specified as:
      - a string, all lhs values will use this default
      - a dict, the lhs value is looked up in the dict to return the default or None
      - a callable with a single arg, the lhs value, it returns the default or None

    After creating the completer you must open it to set the terminal
    up, Then get a line of input from the user by calling read_input()
    which returns two values, the lhs and rhs, which might be None if
    lhs or rhs was not parsed.  After you are done getting input you
    should close the completer to restore the terminal.

    Example: (note this is essentially what the convenience function get_pairs() does)

    This will allow the user to autocomplete foo & foobar, both have
    defaults defined in a dict. In addition the foobar attribute must
    be specified before the prompting loop will exit. Also, this
    example shows how to require that each attribute entered by the user
    is valid.

    attrs = ['foo', 'foobar']
    defaults = {'foo' : 'foo_default', 'foobar' : 'foobar_default'}
    mandatory_attrs = ['foobar']

    c = AttributeValueCompleter(attrs, defaults)
    c.open()
    mandatory_attrs_remaining = copy.copy(mandatory_attrs)

    while True:
        if mandatory_attrs_remaining:
            attribute, value = c.read_input("Enter: ", mandatory_attrs_remaining[0])
            try:
                mandatory_attrs_remaining.remove(attribute)
            except ValueError:
                pass
        else:
            attribute, value = c.read_input("Enter: ")
        if attribute is None:
            # Are we done?
            if mandatory_attrs_remaining:
                print "ERROR, you must specify: %s" % (','.join(mandatory_attrs_remaining))
                continue
            else:
                break
        if attribute not in attrs:
            print "ERROR: %s is not a valid attribute" % (attribute)
        else:
            print "got '%s' = '%s'" % (attribute, value)

    c.close()
    print "exiting..."
    '''

    def __init__(self, lhs_names, default_value=None, lhs_regexp=r'^\s*(?P<lhs>[^ =]+)', lhs_delims=' =',
                 operator='=', strip_rhs=True):
        self.lhs_names = lhs_names
        self.default_value = default_value
        # lhs_regexp must have named group 'lhs' which returns the contents of the lhs
        self.lhs_regexp = lhs_regexp
        self.lhs_re = re.compile(self.lhs_regexp)
        self.lhs_delims = lhs_delims
        self.operator = operator
        self.strip_rhs = strip_rhs
        self._reset()

    def _reset(self):
        '''Forget any previously parsed lhs/operator/rhs state.'''
        self.lhs = None
        self.lhs_complete = False
        self.operator_complete = False
        self.rhs = None

    def open(self):
        '''Install this object as the readline completer, saving the old one.'''
        # Save state
        self.prev_completer = readline.get_completer()
        self.prev_completer_delims = readline.get_completer_delims()

        # Set up for ourself
        readline.parse_and_bind("tab: complete")
        readline.set_completer(self.complete)
        readline.set_completer_delims(self.lhs_delims)

    def close(self):
        '''Restore the readline completer saved by open().'''
        # Restore previous state
        readline.set_completer_delims(self.prev_completer_delims)
        readline.set_completer(self.prev_completer)

    def _debug(self):
        '''Write the current parse state to stderr (debugging aid).'''
        import sys
        sys.stderr.write("lhs='%s' lhs_complete=%s operator='%s' operator_complete=%s rhs='%s'\n" %
                         (self.lhs, self.lhs_complete, self.operator, self.operator_complete, self.rhs))

    def parse_input(self):
        '''We are looking for 3 tokens: <lhs,op,rhs>
        Extract as much of each token as possible.
        Set flags indicating if token is fully parsed.
        '''
        try:
            self._reset()
            buf_len = len(self.line_buffer)
            pos = 0
            lhs_match = self.lhs_re.search(self.line_buffer, pos)
            if not lhs_match: return            # no lhs content
            self.lhs = lhs_match.group('lhs')   # get lhs contents
            pos = lhs_match.end('lhs')          # new scanning position
            if pos == buf_len: return           # nothing after lhs, lhs incomplete
            self.lhs_complete = True            # something trails the lhs, lhs is complete
            operator_beg = self.line_buffer.find(self.operator, pos) # locate operator
            if operator_beg == -1: return       # did not find the operator
            self.operator_complete = True       # operator fully parsed
            operator_end = operator_beg + len(self.operator)
            pos = operator_end                  # step over the operator
            self.rhs = self.line_buffer[pos:]
        except Exception:
            # parse_input() is also called from the readline completion
            # callback, so report the problem instead of propagating it.
            import sys
            traceback.print_exc()
            print("Exception in %s.parse_input(): %s" % (self.__class__.__name__, sys.exc_info()[1]))

    def get_default_value(self):
        '''default_value can be a string, a dict, or a callable.
        If it's a string it's a global default for all attributes.
        If it's a dict the default is looked up in the dict index by attribute.
        If it's callable, it is called with 1 parameter, the attribute
        and it should return the default value for the attribute or None.
        Raises ValueError if the attribute (lhs) has not been fully parsed.'''

        if not self.lhs_complete: raise ValueError("attribute not parsed")
        if isinstance(self.default_value, dict):
            return self.default_value.get(self.lhs, None)
        elif isinstance(self.default_value, (str, type(u''))):
            return self.default_value
        elif callable(self.default_value):
            return self.default_value(self.lhs)
        else:
            return None

    def get_lhs_completions(self, text):
        '''Set self.completions to the lhs names starting with text (all if empty).'''
        if text:
            self.completions = [lhs for lhs in self.lhs_names if lhs.startswith(text)]
        else:
            self.completions = self.lhs_names

    def complete(self, text, state):
        '''readline completer callback; returns the state-th completion or None.'''
        self.line_buffer = readline.get_line_buffer()
        self.parse_input()
        if not self.lhs_complete:
            # lhs is not complete, set up to complete the lhs
            if state == 0:
                beg = readline.get_begidx()
                end = readline.get_endidx()
                self.get_lhs_completions(self.line_buffer[beg:end])
            if state >= len(self.completions): return None
            return self.completions[state]

        elif not self.operator_complete:
            # lhs is complete, but the operator is not so we complete
            # by inserting the operator manually.
            # Also try to complete the default value at this time.
            readline.insert_text('%s ' % self.operator)
            default_value = self.get_default_value()
            if default_value is not None:
                readline.insert_text(default_value)
            readline.redisplay()
            return None
        else:
            # lhs and operator are complete, if the rhs is blank
            # (either empty or only whitespace) then attempt
            # to complete by inserting the default value, otherwise
            # there is nothing we can complete to so we're done.
            if self.rhs.strip():
                return None
            default_value = self.get_default_value()
            if default_value is not None:
                readline.insert_text(default_value)
                readline.redisplay()
            return None

    def pre_input_hook(self):
        '''Pre-fill the input line with "<initial_lhs> <operator> ".'''
        readline.insert_text('%s %s ' % (self.initial_lhs, self.operator))
        readline.redisplay()

    def read_input(self, prompt, initial_lhs=None):
        '''Read one line and return (lhs, rhs); (None, None) on EOF.'''
        self.initial_lhs = initial_lhs
        try:
            self._reset()
            if initial_lhs is None:
                readline.set_pre_input_hook(None)
            else:
                readline.set_pre_input_hook(self.pre_input_hook)
            self.line_buffer = raw_input(prompt).strip()
            self.parse_input()
            if self.strip_rhs and self.rhs is not None:
                return self.lhs, self.rhs.strip()
            else:
                return self.lhs, self.rhs
        except EOFError:
            return None, None

    def get_pairs(self, prompt, mandatory_attrs=None, validate_callback=None, must_match=True, value_required=True):
        '''Prompt repeatedly for "name = value" pairs and return them as a dict.

        Input stops at a blank line (or EOF) once every name in
        mandatory_attrs has been entered.  When must_match is true, names not
        in lhs_names are rejected; when value_required is true, a name with
        no value is rejected; validate_callback(name, value), if given, must
        return true for the pair to be accepted.
        '''
        pairs = {}
        if mandatory_attrs:
            mandatory_attrs_remaining = list(mandatory_attrs)
        else:
            mandatory_attrs_remaining = []

        print("Enter name = value")
        print("Press <ENTER> to accept, a blank line terminates input")
        print("Pressing <TAB> will auto completes name, assignment, and value")
        print("")
        while True:
            if mandatory_attrs_remaining:
                attribute, value = self.read_input(prompt, mandatory_attrs_remaining[0])
            else:
                attribute, value = self.read_input(prompt)
            if attribute is None:
                # Are we done?
                if mandatory_attrs_remaining:
                    print("ERROR, you must specify: %s" % (','.join(mandatory_attrs_remaining)))
                    continue
                break
            if value is None and value_required:
                print("ERROR: you must specify a value for %s" % attribute)
                continue
            # The name is validated whether or not a value was supplied.
            if must_match and attribute not in self.lhs_names:
                print("ERROR: %s is not a valid name" % (attribute))
                continue
            if validate_callback is not None:
                if not validate_callback(attribute, value):
                    print("ERROR: %s is not valid for %s" % (value, attribute))
                    continue
            if attribute in mandatory_attrs_remaining:
                mandatory_attrs_remaining.remove(attribute)

            pairs[attribute] = value
        return pairs