summaryrefslogtreecommitdiffstats
path: root/install/migration/migration.py
blob: 6b447f37797cb6e0a319b4e78e3c2f87ab237363 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Authors:
#   Pavel Zuna <pzuna@redhat.com>
#
# Copyright (C) 2009  Red Hat
# see file 'COPYING' for use and warranty information
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""
Password migration script
"""

import errno
import ldap
import cgi
import wsgiref

# Base DN to use. When empty, get_base_dn() discovers it from the
# namingContexts attribute of the server's root DSE.
BASE_DN = ''
# LDAP server that get_base_dn() and bind() connect to.
LDAP_URI = 'ldap://localhost:389'

def wsgi_redirect(start_response, loc):
    """Answer the request with a '302 Found' redirect to *loc*.

    Returns the (empty) response body iterable expected from a WSGI app.
    """
    headers = [('Location', loc)]
    start_response('302 Found', headers)
    body = []
    return body

def get_ui_url(environ):
    """
    Build the absolute URL of the IPA web UI from the current request.

    The full request URL is cut at the last occurrence of SCRIPT_NAME and
    "/ipa/ui" is appended to what remains.

    :raises ValueError: when SCRIPT_NAME does not occur in the request URL.
    """
    # ``import wsgiref`` on its own does not load the ``wsgiref.util``
    # submodule, so import it explicitly instead of relying on some other
    # module having done so already.
    import wsgiref.util

    full_url = wsgiref.util.request_uri(environ)
    index = full_url.rfind(environ.get('SCRIPT_NAME', ''))
    if index == -1:
        raise ValueError('Cannot strip the script URL from full URL "%s"' % full_url)
    return full_url[:index] + "/ipa/ui"

def get_base_dn():
    """
    Retrieve LDAP server base DN.

    Returns BASE_DN when it is set. Otherwise it binds anonymously to
    LDAP_URI and returns the first namingContexts value of the root DSE.
    Returns '' when the value cannot be obtained.
    """
    if BASE_DN:
        return BASE_DN
    conn = None
    try:
        conn = ldap.initialize(LDAP_URI)
        conn.simple_bind_s('', '')
        entries = conn.search_ext_s(
            '', scope=ldap.SCOPE_BASE, attrlist=['namingcontexts']
        )
    except ldap.LDAPError:
        return ''
    finally:
        # Always release the connection, including when the bind or the
        # search failed. Cleanup is best-effort and must not mask the result.
        if conn is not None:
            try:
                conn.unbind_s()
            except ldap.LDAPError:
                pass
    try:
        attrs = entries[0][1]
    except IndexError:
        return ''
    # Attribute names are case-insensitive in LDAP. The server may return
    # "namingContexts" instead of the spelling that was requested.
    for name, values in attrs.items():
        if name.lower() == 'namingcontexts' and values:
            return values[0]
    return ''

def bind(username, password):
    """
    Check a user's credentials by binding to LDAP as that user.

    :param username: value for the uid RDN of the user entry under
        cn=users,cn=accounts. It is DN-escaped, so characters such as ','
        or '=' in form input cannot change which entry is bound.
    :param password: the user's password. An empty password is refused
        before contacting the server, because a simple bind with an empty
        password is an unauthenticated bind, not a credential check.
    :raises IOError: errno.EPERM for missing or rejected credentials;
        errno.EIO when the base DN is unknown or the bind fails otherwise.
    """
    from ldap.dn import escape_dn_chars

    base_dn = get_base_dn()
    if not base_dn:
        raise IOError(errno.EIO, 'Cannot get Base DN')
    if not username or not password:
        raise IOError(errno.EPERM, 'Invalid LDAP credentials for user %s' % username)
    bind_dn = 'uid=%s,cn=users,cn=accounts,%s' % (escape_dn_chars(username), base_dn)
    conn = None
    try:
        conn = ldap.initialize(LDAP_URI)
        conn.simple_bind_s(bind_dn, password)
    except (ldap.INVALID_CREDENTIALS, ldap.UNWILLING_TO_PERFORM,
            ldap.NO_SUCH_OBJECT):
        raise IOError(errno.EPERM, 'Invalid LDAP credentials for user %s' % username)
    except ldap.LDAPError:
        raise IOError(errno.EIO, 'Bind error')
    finally:
        # Release the connection on success and failure alike; cleanup
        # errors must not replace the outcome reported above.
        if conn is not None:
            try:
                conn.unbind_s()
            except ldap.LDAPError:
                pass

def application(environ, start_response):
    """
    WSGI entry point for the password migration form.

    Redirects as follows:
    - non-POST requests go to index.html;
    - a missing username/password field or rejected credentials go to
      invalid.html;
    - any other bind failure goes to error.html;
    - a successful bind goes to the IPA web UI.
    """
    if environ.get('REQUEST_METHOD', None) != 'POST':
        return wsgi_redirect(start_response, 'index.html')

    form_data = cgi.FieldStorage(fp=environ['wsgi.input'], environ=environ)
    if 'username' not in form_data or 'password' not in form_data:
        return wsgi_redirect(start_response, 'invalid.html')

    try:
        # getfirst() tolerates a field submitted more than once, where
        # form_data[name] would be a list without a .value attribute.
        bind(form_data.getfirst('username'), form_data.getfirst('password'))
    except IOError as err:
        if err.errno == errno.EPERM:
            return wsgi_redirect(start_response, 'invalid.html')
        # EIO and any other errno are failures. They must never fall through
        # to the success redirect below.
        return wsgi_redirect(start_response, 'error.html')

    ui_url = get_ui_url(environ)
    return wsgi_redirect(start_response, ui_url)

id='n462' href='#n462'>462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866
# Authors: Rob Crittenden <rcritten@redhat.com>
#
# Copyright (C) 2008  Red Hat
# see file 'COPYING' for use and warranty information
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

# Documentation can be found at http://freeipa.org/page/LdapUpdate

# TODO
# save undo files?

# Directory holding LDAP update files. NOTE(review): not referenced in the
# part of this module reviewed here; presumably used by callers — confirm.
UPDATES_DIR="/usr/share/ipa/updates/"

import sys
from ipaserver.install import installutils
from ipaserver.install import service
from ipaserver import ipaldap
from ipapython import entity, ipautil
from ipalib import util
from ipalib import errors
from ipalib import api
import ldap
from ldap.dn import escape_dn_chars
from ipapython.ipa_log_manager import *
import krbV
import platform
import time
import random
import os
import pwd
import fnmatch
import csv
import inspect
from ipaserver.install.plugins import PRE_UPDATE, POST_UPDATE
from ipaserver.install.plugins import FIRST, MIDDLE, LAST

class BadSyntax(installutils.ScriptError):
    """Error raised when an LDAP update file cannot be parsed.

    Keeps the offending detail in ``value``, a printable explanation in
    ``msg`` and an exit status of 1 in ``rval``.
    """

    def __init__(self, value):
        self.value = value
        self.msg = "There is a syntax error in this update file: \n  %s" % value
        self.rval = 1

    def __str__(self):
        # Same text as repr(self.value).
        return "%r" % (self.value,)

class LDAPUpdate:
    def __init__(self, dm_password, sub_dict=None, live_run=True,
                 online=True, ldapi=False, plugins=False):
        """dm_password = Directory Manager password. When empty, an external
                         (autobind) bind is tried if the effective group id
                         is 0, otherwise a SASL GSSAPI bind is used.
           sub_dict = substitution dictionary. A dict passed by the caller
                      has any missing defaults filled in place; when it is
                      omitted a new dict is created for this instance.
           live_run = Apply the changes or just test
           online = do an online LDAP update; offline updates are not
                    supported and raise RuntimeError
           ldapi = bind using ldapi. This assumes autobind is enabled.
           plugins = execute the pre/post update plugins

           Raises RuntimeError when the hostname (or, with ldapi, the realm)
           cannot be determined, when the server cannot be reached or
           rejects the bind, or when online is False.
        """
        # A shared mutable default ({}) would be filled in by the first
        # instance and silently reused by every later one.
        if sub_dict is None:
            sub_dict = {}
        self.sub_dict = sub_dict
        self.live_run = live_run
        self.dm_password = dm_password
        self.conn = None
        self.modified = False
        self.online = online
        self.ldapi = ldapi
        self.plugins = plugins
        self.pw_name = pwd.getpwuid(os.geteuid()).pw_name

        if sub_dict.get("REALM"):
            self.realm = sub_dict["REALM"]
        else:
            krbctx = krbV.default_context()
            try:
                self.realm = krbctx.default_realm
            except krbV.Krb5Error:
                self.realm = None

        # Derive the suffix whichever way the realm was obtained. It used to
        # be computed only on the Kerberos path, so a caller-supplied REALM
        # made the SUFFIX/ESCAPED_SUFFIX defaults raise UnboundLocalError.
        if self.realm is not None:
            suffix = ipautil.realm_to_suffix(self.realm)
        else:
            suffix = None

        domain = ipautil.get_domain_name()
        libarch = self.__identify_arch()

        fqdn = installutils.get_fqdn()
        if fqdn is None:
            raise RuntimeError("Unable to determine hostname")
        fqhn = fqdn # Save this for the sub_dict variable
        if self.ldapi:
            if self.realm is None:
                raise RuntimeError("Unable to determine the realm for the ldapi socket")
            # Socket path built from the realm with '.' replaced by '-';
            # '/' is URL-encoded as %2f in the ldapi URI.
            fqdn = "ldapi://%%2fvar%%2frun%%2fslapd-%s.socket" % "-".join(
                self.realm.split(".")
            )

        if not self.sub_dict.get("REALM") and self.realm is not None:
            self.sub_dict["REALM"] = self.realm
        if not self.sub_dict.get("FQDN"):
            self.sub_dict["FQDN"] = fqhn
        if not self.sub_dict.get("DOMAIN"):
            self.sub_dict["DOMAIN"] = domain
        if not self.sub_dict.get("SUFFIX") and suffix is not None:
            self.sub_dict["SUFFIX"] = suffix
        if not self.sub_dict.get("ESCAPED_SUFFIX"):
            # Escape the suffix actually in effect, which the caller may have
            # supplied, and skip it when no suffix is known at all.
            effective_suffix = self.sub_dict.get("SUFFIX")
            if effective_suffix is not None:
                self.sub_dict["ESCAPED_SUFFIX"] = escape_dn_chars(effective_suffix)
        if not self.sub_dict.get("LIBARCH"):
            self.sub_dict["LIBARCH"] = libarch
        if not self.sub_dict.get("TIME"):
            self.sub_dict["TIME"] = int(time.time())

        if online:
            # Try out the connection/password
            try:
                conn = ipaldap.IPAdmin(fqdn, ldapi=self.ldapi, realm=self.realm)
                if self.dm_password:
                    conn.do_simple_bind(binddn="cn=directory manager", bindpw=self.dm_password)
                # NOTE(review): this tests the effective *group* id; confirm
                # it is intended rather than os.geteuid().
                elif os.getegid() == 0:
                    try:
                        # autobind
                        conn.do_external_bind(self.pw_name)
                    except errors.NotFound:
                        # Fall back
                        conn.do_sasl_gssapi_bind()
                else:
                    conn.do_sasl_gssapi_bind()
                conn.unbind()
            except (ldap.CONNECT_ERROR, ldap.SERVER_DOWN):
                raise RuntimeError("Unable to connect to LDAP server %s" % fqdn)
            except ldap.INVALID_CREDENTIALS:
                raise RuntimeError("The password provided is incorrect for LDAP server %s" % fqdn)
            except ldap.LOCAL_ERROR as e:
                raise RuntimeError('%s' % e.args[0].get('info', '').strip())
        else:
            raise RuntimeError("Offline updates are not supported.")

    # The following 2 functions were taken from the Python
    # documentation at http://docs.python.org/library/csv.html
    def __utf_8_encoder(self, unicode_csv_data):
        for line in unicode_csv_data:
            yield line.encode('utf-8')

    def __unicode_csv_reader(self, unicode_csv_data, quote_char="'", dialect=csv.excel, **kwargs):
        # csv.py doesn't do Unicode; encode temporarily as UTF-8:
        csv_reader = csv.reader(self.__utf_8_encoder(unicode_csv_data),
                                dialect=dialect, delimiter=',',
                                quotechar=quote_char,
                                skipinitialspace=True,
                                **kwargs)
        for row in csv_reader:
            # decode UTF-8 back to Unicode, cell by cell:
            yield [unicode(cell, 'utf-8') for cell in row]

    def __identify_arch(self):
        """On multi-arch systems some libraries may be in /lib64, /usr/lib64,
           etc.  Determine if a suffix is needed based on the current
           architecture.
        """
        bits = platform.architecture()[0]

        if bits == "64bit":
            return "64"
        else:
            return ""

    def _template_str(self, s):
        """Expand $KEYWORD references in s using self.sub_dict.

           Raises BadSyntax naming the keyword when it has no substitution.
        """
        substitutions = self.sub_dict
        try:
            expanded = ipautil.template_str(s, substitutions)
        except KeyError as missing:
            raise BadSyntax("Unknown template keyword %s" % missing)
        return expanded

    def __parse_values(self, line):
        """Parse a comma-separated string into separate values and convert them
           into a list. This should handle quoted-strings with embedded commas
        """
        if   line[0] == "'":
            quote_char = "'"
        else:
            quote_char = '"'
        reader = self.__unicode_csv_reader([line], quote_char)
        value = []
        for row in reader:
            value = value + row
        return value

    def read_file(self, filename):
        if filename == '-':
            fd = sys.stdin
        else:
            fd = open(filename)
        text = fd.readlines()
        if fd != sys.stdin: fd.close()
        return text

    def __entry_to_entity(self, ent):
        """The Entry class is a bare LDAP entry; the Entity class has a lot
           more helper functions that we need. Copy the entry's attributes and
           its 'dn' into a dict, then build an entity.Entity from it.

           List/tuple values are flattened: empty ones become '' and
           one-element ones become that element. Longer ones are kept as-is.
        """
        flattened = dict(ent.data)
        flattened['dn'] = ent.dn
        for attr in list(flattened.keys()):
            val = flattened[attr]
            if not isinstance(val, (list, tuple)):
                continue
            if not val:
                flattened[attr] = ''
            elif len(val) == 1:
                flattened[attr] = val[0]
        return entity.Entity(flattened)

    def __combine_updates(self, dn_list, all_updates, update):
        """Combine a new update with the list of total updates

           Updates are stored in 2 lists:
               dn_list: contains a unique list of DNs in the updates
               all_updates: the actual updates that need to be applied

           We want to apply the updates from the shortest to the longest
           path so if new child and parent entries are in different updates
           we can be sure the parent gets written first. This also lets
           us apply any schema first since it is in the very short cn=schema.

           dn_list maps DN depth (number of RDNs) to the DNs of that depth.
           When a DN already has an update, both the 'default' and the
           'updates' lists of the new update are appended to it.

           Returns all_updates; dn_list and all_updates are modified in place.
        """
        dn = update.get('dn')
        depth = len(ldap.explode_dn(dn.lower()))
        same_depth = dn_list.get(depth)
        if same_depth:
            if dn not in same_depth:
                same_depth.append(dn)
        else:
            dn_list[depth] = [dn]

        if not all_updates.get(dn):
            all_updates[dn] = update
            return all_updates

        existing = all_updates[dn]
        merged_any = False
        # An update parsed from a file can carry both 'default' and 'updates'.
        # An if/elif here used to drop 'updates' whenever 'default' was also
        # present, so merge each key independently.
        for key in ('default', 'updates'):
            if key in update:
                existing[key] = existing.get(key, []) + update[key]
                merged_any = True
        if not merged_any:
            root_logger.debug("Unknown key in updates %s" % update.keys())

        all_updates[dn] = existing

        return all_updates

    def parse_update_file(self, data, all_updates, dn_list):
        """Parse the update file into a dictionary of lists and apply the update
           for each DN in the file.

           data is an iterable of lines. Blank lines and lines starting with
           '#' are skipped.
           - A "dn:" line starts a new update.
           - Each following "keyword:attr:value" line is stored as
             "attr:value" under 'default' for keyword "default". Every other
             keyword is stored as "keyword:attr:value" under 'updates'.
           - A line beginning with a space continues the previous value.
           Each finished update is merged in via __combine_updates.

           Returns the tuple (all_updates, dn_list).
           Raises BadSyntax for a line before any dn, a malformed line, an
           unknown keyword, or a continuation line with nothing to continue.
        """
        valid_keywords = ["default", "add", "remove", "only", "deleteentry", "replace", "addifnew", "addifexist"]
        update = {}
        d = []          # value list the most recent line was added to
        index = ""      # key of `d` within `update`
        dn = None
        lcount = 0
        for line in data:
            lcount = lcount + 1

            # Strip the trailing newline/whitespace (leading spaces mark a
            # continuation line), then skip comments and empty lines.
            line = line.rstrip()
            if line.startswith('#') or line == '':
                continue

            if line.lower().startswith('dn:'):
                if dn is not None:
                    all_updates = self.__combine_updates(dn_list, all_updates, update)

                update = {}
                # Forget the previous entry's value list so a continuation
                # line cannot append to (and re-attach) another DN's values.
                d = []
                index = ""
                dn = line[3:].strip()
                update['dn'] = self._template_str(dn)
                continue

            if dn is None:
                raise BadSyntax("dn is not defined in the update")

            line = self._template_str(line)
            if line.startswith(' '):
                if not d:
                    raise BadSyntax("Continuation without a preceding value on line %d: %s" % (lcount, line))
                d[-1] = d[-1] + line[1:]
                update[index] = d
                continue
            line = line.strip()
            values = line.split(':', 2)
            if len(values) != 3:
                raise BadSyntax("Bad formatting on line %d: %s" % (lcount, line))

            index = values[0].strip().lower()

            if index not in valid_keywords:
                raise BadSyntax("Unknown keyword %s" % index)

            attr = values[1].strip()
            value = values[2].strip()

            if index == "default":
                new_value = attr + ":" + value
            else:
                new_value = index + ":" + attr + ":" + value
                index = "updates"

            d = update.get(index, [])
            d.append(new_value)
            update[index] = d

        if dn is not None:
            all_updates = self.__combine_updates(dn_list, all_updates, update)

        return (all_updates, dn_list)

    def create_index_task(self, attribute):
        """Create a task entry under cn=index,cn=tasks,cn=config asking the
           server to index attribute, and return the task's DN.

           When live_run is False, neither the initial wait nor the entry
           creation happens; the DN is still computed and returned.
        """
        if self.live_run:
            # Give previous operations a moment to complete
            time.sleep(5)

        # Refresh TIME, plus some randomness, so successive task names are
        # very likely to differ.
        rng = random.SystemRandom()
        now = int(time.time())
        self.sub_dict['TIME'] = now + rng.randint(0, 10000)

        task_cn = self._template_str("indextask_$TIME")
        task_dn = "cn=%s, cn=index, cn=tasks, cn=config" % task_cn

        task = ipaldap.Entry(task_dn)
        task.setValues('objectClass', ['top', 'extensibleObject'])
        task.setValue('cn', task_cn)
        task.setValue('nsInstance', 'userRoot')
        task.setValues('nsIndexAttribute', attribute)

        root_logger.info("Creating task to index attribute: %s", attribute)
        root_logger.debug("Task id: %s", task_dn)

        if self.live_run:
            self.conn.addEntry(task)

        return task_dn

    def monitor_index_task(self, dn):
        """Give a task DN monitor it and wait until it has completed (or failed)
        """

        if not self.live_run:
            # If not doing this live there is nothing to monitor
            return

        # Pause for a moment to give the task time to be created
        time.sleep(1)

        attrlist = ['nstaskstatus', 'nstaskexitcode']
        entry = None

        while True:
            try:
                entry = self.conn.getEntry(dn, ldap.SCOPE_BASE, "(objectclass=*)", attrlist)
            except errors.NotFound, e:
                root_logger.error("Task not found: %s", dn)
                return
            except errors.DatabaseError, e:
                root_logger.error("Task lookup failure %s", e)
                return

            status = entry.getValue('nstaskstatus')
            if status is None:
                # task doesn't have a status yet
                time.sleep(1)
                continue

            if status.lower().find("finished") > -1:
                root_logger.info("Indexing finished")
                break

            root_logger.debug("Indexing in progress")
            time.sleep(1)

        return

    def __create_default_entry(self, dn, default):
        """Build the entry for dn from its "attr:value" default lines and
           return it as an entity.Entity.

           An attribute listed more than once collects all of its values.
           With no default lines the entry holds only its DN; the whole
           entry is then expected to be created via add updates.
        """
        entry = ipaldap.Entry(dn)

        for default_line in default or []:
            # Syntax was already validated while parsing, so one ':' exists
            attr, val = default_line.split(':', 1)
            current = entry.getValues(attr)
            if current:
                # multi-valued attribute
                merged = list(current)
                merged.append(val)
            else:
                merged = val
            entry.setValues(attr, merged)

        return self.__entry_to_entity(entry)

    def __get_entry(self, dn):
        """Search dn itself (base scope), requesting all user attributes
           plus aci, attributeTypes and objectClasses.

           Returns the result of self.conn.getList. NOTE(review): presumably
           a list of ipaldap.Entry objects, not a single Entry — confirm.
        """
        wanted_attrs = ["*", "aci", "attributeTypes", "objectClasses"]
        return self.conn.getList(dn, ldap.SCOPE_BASE, "objectclass=*", wanted_attrs)

    def __update_managed_entries(self):
        """Update and move legacy Managed Entry Plugins.

           Reads the definitions one level below
           'cn=Managed Entries,cn=plugins,cn=config'. Two updates are built
           for every definition and for every template a definition
           references under 'cn=etc,SUFFIX':
           - one recreating it (as 'default' lines) below
             'cn=Managed Entries,cn=etc,SUFFIX';
           - one deleting the old entry ('deleteentry').

           Returns a list of {dn: update} mappings. The list is empty when
           the old definitions container does not exist.
        """
        suffix = ipautil.realm_to_suffix(self.realm)
        searchfilter = '(objectclass=*)'
        old_template_container = 'cn=etc,%s' % suffix
        old_definition_container = 'cn=Managed Entries,cn=plugins,cn=config'
        new = 'cn=Managed Entries,cn=etc,%s' % suffix
        sub = ['cn=Definitions,', 'cn=Templates,']
        new_managed_entries = []
        old_templates = []
        try:
            definitions_managed_entries = self.conn.getList(
                old_definition_container, ldap.SCOPE_ONELEVEL, searchfilter, [])
        except errors.NotFound:
            return new_managed_entries

        for entry in definitions_managed_entries:
            old_definition = {'dn': entry.dn, 'deleteentry': ['dn: %s' % entry.dn]}
            old_template = entry.getValue('managedtemplate')
            # A definition without managedtemplate used to crash here on
            # None.replace(); it is now moved without a template.
            if old_template is not None:
                # Point the definition at the template's new location before
                # it is serialized below, and remember the old template DN.
                entry.setValues('managedtemplate', old_template.replace(old_template_container, sub[1] + new))
                old_templates.append(old_template)
            new_definition = {
                'dn': entry.dn.replace(old_definition_container, sub[0] + new),
                # NOTE(review): presumably str(entry) is "attr: value" text
                # whose first line is the DN; that line is dropped.
                'default': str(entry).strip().replace(': ', ':').split('\n')[1:],
            }
            new_managed_entries.append({
                new_definition['dn']: new_definition,
                old_definition['dn']: old_definition,
            })

        for old_template_dn in old_templates:
            try:
                template = self.conn.getEntry(old_template_dn, ldap.SCOPE_BASE, searchfilter, [])
            except errors.NotFound:
                # Referenced template does not exist; nothing to move
                continue
            old_template_update = {'dn': template.dn, 'deleteentry': ['dn: %s' % template.dn]}
            new_template = {
                'dn': template.dn.replace(old_template_container, sub[1] + new),
                'default': str(template).strip().replace(': ', ':').split('\n')[1:],
            }
            new_managed_entries.append({
                new_template['dn']: new_template,
                old_template_update['dn']: old_template_update,
            })

        if new_managed_entries:
            # NOTE(review): this sorts dicts, which relies on Python 2 dict
            # comparison (TypeError on Python 3); confirm the intended order.
            new_managed_entries.sort(reverse=True)

        return new_managed_entries

    def __apply_updates(self, updates, entry):
        """updates is a list of changes to apply
           entry is the thing to apply them to

           Returns the modified entry
        """
        if not updates:
            return entry

        only = {}
        for u in updates:
            # We already do syntax-parsing so this is safe
            (utype, k, values) = u.split(':',2)
            values = self.__parse_values(values)

            e = entry.getValues(k)
            if not isinstance(e, list):
                if e is None:
                    e = []
                else:
                    e = [e]
            for v in values:
                if utype == 'remove':
                    root_logger.debug("remove: '%s' from %s, current value %s", v, k, e)
                    try:
                        e.remove(v)
                    except ValueError:
                        root_logger.warning("remove: '%s' not in %s", v, k)
                        pass
                    entry.setValues(k, e)
                    root_logger.debug('remove: updated value %s', e)
                elif utype == 'add':
                    root_logger.debug("add: '%s' to %s, current value %s", v, k, e)
                    # Remove it, ignoring errors so we can blindly add it later
                    try:
                        e.remove(v)
                    except ValueError:
                        pass
                    e.append(v)
                    root_logger.debug('add: updated value %s', e)
                    entry.setValues(k, e)
                elif utype == 'addifnew':
                    root_logger.debug("addifnew: '%s' to %s, current value %s", v, k, e)
                    # Only add the attribute if it doesn't exist. Only works
                    # with single-value attributes.
                    if len(e) == 0:
                        e.append(v)
                        root_logger.debug('addifnew: set %s to %s', k, e)
                        entry.setValues(k, e)
                elif utype == 'addifexist':