summaryrefslogtreecommitdiffstats
path: root/formats/command.py
blob: b92a5c8175b89cd1fb5fe29435798392fe0bf0bd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# -*- coding: UTF-8 -*-
# Copyright 2015 Red Hat, Inc.
# Part of clufter project
# Licensed under GPLv2+ (a copy included | http://gnu.org/licenses/gpl-2.0.txt)
"""Format representing merged/isolated (1/2 levels) of single command to exec"""
__author__ = "Jan Pokorný <jpokorny @at@ Red Hat .dot. com>"

try:
    from collections import OrderedDict
except ImportError:
    from ordereddict import OrderedDict
from logging import getLogger

log = getLogger(__name__)

from ..format import SimpleFormat
from ..protocol import Protocol
from ..utils import head_tail
from ..utils_func import apply_intercalate


class command(SimpleFormat):
    """Single command to execute, in several interconvertible forms

    Supported protocols:
    - SEPARATED (native): list of tuples, the first holding the leading
      command word(s), the others each grouping an option with its
      argument(s), possibly followed by a group of trailing words
    - MERGED: flat sequence of command-line tokens
    - DICT: OrderedDict mapping each option to the list of its argument
      groups; words preceding the first option are kept under '__cmd__',
      non-option words after it under '__args__'
    - BYTESTRING: the command as a single space-separated string
    """
    native_protocol = SEPARATED = Protocol('separated')
    BYTESTRING = SimpleFormat.BYTESTRING
    DICT = Protocol('dict')
    MERGED = Protocol('merged')

    @staticmethod
    def _escape(base, qs=("'", '"')):
        """Enquote items of base that contain a space or a boundary quote

        An item is enquoted when it contains a space or starts/ends with
        any of the quote characters in qs.  The first quote character not
        occurring in the item is used.  When all of them occur, the last
        one (qs[-1]) is used and its occurrences are backslash-escaped.
        The first item of qs is assumed not to be escapable inside its own
        enquotion (like shell single quotes), hence RuntimeError is raised
        when qs[-1] is that very character.

        Returns a new list; base itself is not modified.
        """
        ret = []
        for b in base:
            if ' ' in b or any(b.startswith(q) or b.endswith(q) for q in qs):
                for q in qs:
                    if q not in b:
                        use_q = q
                        break
                else:
                    # every quote character occurs: fall back to the last
                    # one, escaping its inner occurrences
                    use_q = qs[-1]
                    if use_q != qs[0]:
                        b = b.replace(use_q, '\\' + use_q)
                    else:
                        raise RuntimeError('cannot quote the argument')
                b = b.join((use_q, use_q))
            ret.append(b)
        return ret

    @SimpleFormat.producing(BYTESTRING, chained=True)
    def get_bytestring(self, *protodecl):
        """Return command as canonical single string"""
        # chained fallback
        return ' '.join(self.MERGED(protect_safe=True))

    @SimpleFormat.producing(SEPARATED, protect=True)
    def get_separated(self, *protodecl):
        """Group the MERGED tokens into tuples

        A new group is started by every option token (starting with '-',
        except the bare '-') and by any token following a lone '--'.
        With the 'magic_split' flag set, each non-option token is further
        split on '::' and closes the current group.  The last accumulated
        group is split after its second item (option + single argument
        by convention, the rest being trailing words).
        """
        merged = self.MERGED()
        merged.reverse()  # so that pop() yields tokens in original order
        ret, acc = [], []
        while merged:
            i = merged.pop()
            if acc == ['--'] or i is None or i.startswith('-') and i != '-':
                if acc:
                    ret.append(tuple(acc))
                # None is the sentinel pushed by the magic split below
                acc = [] if i is None else [i]
            elif self._dict.get('magic_split', False):
                acc.extend(i.split('::'))  # magic "::"-split
                merged.append(None)  # close this group on the next round
            else:
                acc.append(i)
        # expect that, by convention, option takes at most a single argument
        ret.extend(filter(bool, (tuple(acc[:2]), tuple(acc[2:]))))
        return ret

    @SimpleFormat.producing(MERGED, protect=True)
    def get_merged(self, *protodecl):
        """Return the command as a flat sequence of tokens

        Derived, in order of preference, from the BYTESTRING representation
        (shell-split, optionally re-enquoted, '-a=b' normalized to '-a',
        'b'), the DICT one, or finally the SEPARATED one; the result is
        passed through apply_intercalate.
        """
        # try to look (indirectly) if we have "separated" at hand first
        if self.BYTESTRING in self._representations:  # break the possible loop
            from shlex import split
            ret = split(self.BYTESTRING())
            if self._dict.get('enquote', True):
                ret = self._escape(ret)
            # heuristic(!) method to normalize: '-a=b' -> '-a', 'b';
            # a new list is built so that splitting one lexeme does not
            # shift the positions of the lexemes yet to be processed
            normalized = []
            for lexeme in ret:
                if (lexeme.count('=') == 1 and lexeme.startswith('-') and
                    ('"' not in lexeme or lexeme.count('"') % 2) and
                    ("'" not in lexeme or lexeme.count("'") % 2)):
                    normalized.extend(lexeme.split('='))
                else:
                    normalized.append(lexeme)
            ret = normalized
        elif self.DICT in self._representations:  # break the possible loop (2)
            d = self.DICT(protect_safe=True)
            if not isinstance(d, OrderedDict):
                log.warning("'{0}' format: not backed by OrderedDict".format(
                    self.__class__.name
                ))
            ret = list(d.get('__cmd__', ()))
            # items() rather than iteritems(): same result, also on Python 3;
            # an option with no recorded values is emitted bare
            ret.extend((k, v) for k, vs in d.items() for v in (vs or ((), ))
                                  if k not in ('__cmd__', '__args__'))
            ret.extend(d.get('__args__', ()))
        else:
            ret = self.SEPARATED(protect_safe=True)
        return apply_intercalate(ret)

    @SimpleFormat.producing(DICT, protect=True)
    # not a perfectly bijective mapping, this is a bit lossy representation,
    # on the other hand it canonicalizes the notation when turned to other forms
    def get_dict(self, *protodecl):
        """Return the SEPARATED groups as an OrderedDict

        Option groups are stored under the option name, each occurrence
        appending its argument part.  Non-option groups contribute all of
        their words: under '__cmd__' while no option has been seen yet,
        under '__args__' afterwards.
        """
        separated = self.SEPARATED()
        separated.reverse()  # so that pop() yields groups in original order
        ret = OrderedDict()
        arg_bucket = '__cmd__'
        while separated:
            group = separated.pop()
            head, tail = head_tail(group)
            if head.startswith('-') and head != '-':
                arg_bucket = '__args__'
                ret.setdefault(head, []).append(tail)
            else:
                # keep every word of a positional group, not just the first
                # (previously the rest was silently dropped)
                bucket = ret.setdefault(arg_bucket, [])
                if isinstance(group, (tuple, list)):
                    bucket.extend(group)
                else:
                    bucket.append(head)
        return ret