"""

uritemplate.variable
====================

This module contains the URIVariable class which powers the URITemplate class.

What treasures await you:

- URIVariable class

You see a hammer in front of you.
What do you do?
>

"""

import collections.abc
import enum
import string
import typing as t
import urllib.parse

# A single, non-composite value for a template variable. ``None`` is part of
# the union; ``URIVariable.expand`` skips variables whose value is ``None``.
ScalarVariableValue = t.Union[int, float, complex, str, None]
# Every shape a template variable's value may take: a sequence/list of
# scalars, a mapping of names to scalars, a ``(name, scalar)`` tuple, or a
# bare scalar.
VariableValue = t.Union[
    t.Sequence[ScalarVariableValue],
    t.List[ScalarVariableValue],
    t.Mapping[str, ScalarVariableValue],
    t.Tuple[str, ScalarVariableValue],
    ScalarVariableValue,
]
# Variable name -> value mapping, the argument type of ``URIVariable.expand``.
VariableValueDict = t.Dict[str, VariableValue]


# Character classes; the names and contents follow the "unreserved",
# "gen-delims" and "sub-delims" productions of RFC 3986.
_UNRESERVED_CHARACTERS: t.Final[str] = (
    f"{string.ascii_letters}{string.digits}~-_."
)
_GEN_DELIMS: t.Final[str] = ":/?#[]@"
_SUB_DELIMS: t.Final[str] = "!$&'()*+,;="
# Union of both delimiter sets; used as the ``safe`` set when quoting for the
# reserved ("+") and fragment ("#") operators.
_RESERVED_CHARACTERS: t.Final[str] = f"{_GEN_DELIMS}{_SUB_DELIMS}"


class Operator(enum.Enum):
    """Expression operators of an RFC 6570 URI Template.

    Each member's value is the character that introduces the operator
    inside ``{...}``; ``default`` (the empty string) stands for an
    expression that has no operator.
    """

    # Section 2.2. Expressions
    #      expression    =  "{" [ operator ] variable-list "}"
    #      operator      =  op-level2 / op-level3 / op-reserve
    #      op-level2     =  "+" / "#"
    #      op-level3     =  "." / "/" / ";" / "?" / "&"
    #      op-reserve    =  "=" / "," / "!" / "@" / "|"
    default = ""  # 3.2.2. Simple String Expansion: {var}
    # Operator Level 2 (op-level2)
    reserved = "+"  # 3.2.3. Reserved Expansion: {+var}
    fragment = "#"  # 3.2.4. Fragment Expansion: {#var}
    # Operator Level 3 (op-level3)
    # 3.2.5. Label Expansion with Dot-Prefix: {.var}
    label_with_dot_prefix = "."
    path_segment = "/"  # 3.2.6. Path Segment Expansion: {/var}
    # 3.2.7. Path-Style Parameter Expansion: {;var}
    path_style_parameter = ";"
    form_style_query = "?"  # 3.2.8. Form-Style Query Expansion: {?var}
    # 3.2.9. Form-Style Query Continuation: {&var}
    form_style_query_continuation = "&"
    # Reserved Operators (op-reserve)
    reserved_eq = "="
    reserved_comma = ","
    reserved_bang = "!"
    reserved_at = "@"
    reserved_pipe = "|"

    # TODO: the if-chains below could become ``match`` statements once the
    # minimum supported Python version allows it (``match`` needs 3.10+).

    def reserved_characters(self) -> str:
        """Return the characters callers pass as ``safe`` to :func:`quote`.

        The reserved operator (``+``) yields the reserved set plus ``%``,
        the fragment operator (``#``) yields the reserved set, and every
        other operator yields the empty string.
        """
        if self is Operator.reserved:
            return _RESERVED_CHARACTERS + "%"
        if self is Operator.fragment:
            return _RESERVED_CHARACTERS
        return ""

    def expansion_separator(self) -> str:
        """Identify the separator used during expansion.

        Per `Section 3.2.1. Variable Expansion`_:

        ======  =========
        Type    Separator
        ======  =========
        (none)  ``","``
        ``+``   ``","``
        ``#``   ``","``
        ``.``   ``"."``
        ``/``   ``"/"``
        ``;``   ``";"``
        ``?``   ``"&"``
        ``&``   ``"&"``
        ======  =========

        Every operator not listed above also falls back to ``","``.

        .. _`Section 3.2.1. Variable Expansion`:
            https://www.rfc-editor.org/rfc/rfc6570#section-3.2.1
        """
        if self is Operator.label_with_dot_prefix:
            return "."
        if self is Operator.path_segment:
            return "/"
        if self is Operator.path_style_parameter:
            return ";"
        if self in (
            Operator.form_style_query,
            Operator.form_style_query_continuation,
        ):
            return "&"
        return ","

    def variable_prefix(self) -> str:
        """Return the string emitted before the first expanded variable.

        That is the operator character itself, except for the reserved
        operator (``+``), which contributes no prefix.
        """
        if self is Operator.reserved:
            return ""
        return t.cast(str, self.value)

    def _always_quote(self, value: str) -> str:
        """Percent-encode ``value`` with no additional safe characters."""
        return quote(value, "")

    def _only_quote_unquoted_characters(self, value: str) -> str:
        """Percent-encode ``value`` but keep existing ``%XX`` triplets.

        Reserved characters and already percent-encoded triplets pass
        through untouched; everything else between the triplets is encoded.
        A ``%`` that is not followed by two hex digits is not a triplet and
        is therefore encoded as ``%25``.
        """
        # Function-scope import: ``re`` is not imported at module level.
        import re

        # With a capturing group, ``re.split`` returns the text between
        # matches at even indexes and the matched triplets at odd indexes.
        pieces = re.split(r"(%[0-9A-Fa-f]{2})", value)
        return "".join(
            piece if index % 2 else quote(piece, _RESERVED_CHARACTERS)
            for index, piece in enumerate(pieces)
        )

    def quote(self, value: t.Any) -> str:
        """Percent-encode ``value`` according to this operator's rules.

        Values that are neither ``str`` nor ``bytes`` are converted with
        ``str()``; ``bytes`` are decoded with the default codec. The
        reserved and fragment operators keep reserved characters and
        percent-encoded triplets; all other operators encode everything
        that :func:`quote` does not treat as safe by default.
        """
        if not isinstance(value, (str, bytes)):
            value = str(value)
        if isinstance(value, bytes):
            value = value.decode()

        if self is Operator.reserved or self is Operator.fragment:
            return self._only_quote_unquoted_characters(value)
        return self._always_quote(value)

    @staticmethod
    def from_string(s: str) -> "Operator":
        """Return the operator registered for ``s``.

        Falls back to ``Operator.default`` when ``s`` is not a known
        operator character.
        """
        return _operators.get(s, Operator.default)


# Lookup table from an expression's leading character to its ``Operator``.
# ``Operator.default`` is intentionally absent: an expression without an
# operator is recognised by its first character not being a key here.
_operators: t.Final[t.Dict[str, Operator]] = {
    member.value: member
    for member in (
        Operator.reserved,
        Operator.fragment,
        Operator.label_with_dot_prefix,
        Operator.path_segment,
        Operator.path_style_parameter,
        Operator.form_style_query,
        Operator.form_style_query_continuation,
        Operator.reserved_bang,
        Operator.reserved_pipe,
        Operator.reserved_at,
        Operator.reserved_eq,
        Operator.reserved_comma,
    )
}


class URIVariable:
    """This object validates everything inside the URITemplate object.

    It validates template expansions and will truncate length as decided by
    the template.

    Please note that just like the :class:`URITemplate <URITemplate>`, this
    object's ``__str__`` and ``__repr__`` methods do not return the same
    information. Calling ``str(var)`` will return the original variable.

    This object does the majority of the heavy lifting. The ``URITemplate``
    object finds the variables in the URI and then creates ``URIVariable``
    objects.  Expansions of the URI are handled by each ``URIVariable``
    object. ``URIVariable.expand()`` returns a dictionary of the original
    variable and the expanded value. Check that method's documentation for
    more information.

    Values that are ``None``, empty lists and empty associative arrays are
    treated as undefined and contribute nothing to the expansion; ``None``
    members inside lists and associative arrays are skipped.

    """

    def __init__(self, var: str):
        #: The original string that comes through with the variable
        self.original: str = var
        #: The operator for the variable
        self.operator: Operator = Operator.default
        #: List of variables in this variable
        self.variables: t.List[t.Tuple[str, t.MutableMapping[str, t.Any]]] = (
            []
        )
        #: List of variable names
        self.variable_names: t.List[str] = []
        #: List of defaults passed in
        self.defaults: t.MutableMapping[str, ScalarVariableValue] = {}
        # Parse the variable itself.
        self.parse()

    def __repr__(self) -> str:
        return "URIVariable(%s)" % self

    def __str__(self) -> str:
        return self.original

    def parse(self) -> None:
        """Parse the variable.

        This finds the:
            - operator,
            - variables (with their explode and prefix modifiers), and
            - defaults.

        :raises ValueError: if a prefix modifier (``name:N``) is not a
            base-10 integer.

        """
        var_list_str = self.original
        # Slice instead of indexing so that an empty expression does not
        # raise ``IndexError``.
        if (operator_str := self.original[:1]) in _operators:
            self.operator = Operator.from_string(operator_str)
            var_list_str = self.original[1:]

        for var in var_list_str.split(","):
            default_val = None
            name = var
            # NOTE(sigmavirus24): This is from an earlier draft but is not in
            # the specification
            if "=" in var:
                name, default_val = var.split("=", 1)

            explode = name.endswith("*")
            name = name.rstrip("*")

            prefix: t.Optional[int] = None
            if ":" in name:
                name, prefix_str = name.split(":", 1)
                prefix = int(prefix_str, 10)

            if default_val:
                self.defaults[name] = default_val

            self.variables.append(
                (name, {"explode": explode, "prefix": prefix})
            )

        self.variable_names = [varname for (varname, _) in self.variables]

    @staticmethod
    def _defined_members(
        value: t.Iterable[ScalarVariableValue],
    ) -> t.List[ScalarVariableValue]:
        """Return the members of a list value that are not ``None``."""
        return [member for member in value if member is not None]

    @staticmethod
    def _defined_items(
        value: t.Any,
        items: t.Optional[t.Sequence[t.Tuple[str, ScalarVariableValue]]],
    ) -> t.List[t.Tuple[str, ScalarVariableValue]]:
        """Return the ``(key, value)`` pairs whose value is not ``None``.

        ``items`` is the pair list found by ``is_list_of_tuples`` and is
        used in its given order; otherwise ``value`` is a mapping and its
        items are taken in sorted order.
        """
        pairs = items or sorted(value.items())
        return [(key, val) for key, val in pairs if val is not None]

    @staticmethod
    def _truncate_scalar(value: t.Any, prefix: t.Optional[int]) -> t.Any:
        """Stringify numeric scalars and apply the prefix modifier.

        Numbers are converted first so that ``0`` is not mistaken for an
        empty value and so that a prefix can be applied to them.
        """
        if isinstance(value, (int, float, complex)):
            value = str(value)
        return value[:prefix] if prefix else value

    def _expansion_method(
        self,
    ) -> t.Callable[
        [str, VariableValue, bool, t.Optional[int]], t.Optional[str]
    ]:
        """Pick the expansion helper that matches ``self.operator``."""
        if self.operator in (
            Operator.path_segment,
            Operator.label_with_dot_prefix,
        ):
            return self._label_path_expansion
        if self.operator in (
            Operator.form_style_query,
            Operator.form_style_query_continuation,
        ):
            return self._query_expansion
        if self.operator is Operator.path_style_parameter:
            return self._semi_path_expansion
        return self._string_expansion

    def _query_expansion(
        self,
        name: str,
        value: VariableValue,
        explode: bool,
        prefix: t.Optional[int],
    ) -> t.Optional[str]:
        """Expansion method for the '?' and '&' operators.

        Returns ``None`` when the value is undefined (``None``, or a list
        or associative array without any defined member); an empty scalar
        expands to ``name=``.
        """
        if value is None:
            return None

        tuples, items = is_list_of_tuples(value)

        safe = self.operator.reserved_characters()
        _quote = self.operator.quote
        join_str = self.operator.expansion_separator()

        if list_test(value) and not tuples:
            members = self._defined_members(
                t.cast(t.Sequence[ScalarVariableValue], value)
            )
            if not members:
                return None
            if explode:
                return join_str.join(f"{name}={_quote(v)}" for v in members)
            joined = ",".join(_quote(v) for v in members)
            return f"{name}={joined}"

        if dict_test(value) or tuples:
            pairs = self._defined_items(value, items)
            if not pairs:
                return None
            if explode:
                return join_str.join(
                    f"{quote(k, safe)}={_quote(v)}" for k, v in pairs
                )
            joined = ",".join(
                f"{quote(k, safe)},{_quote(v)}" for k, v in pairs
            )
            return f"{name}={joined}"

        scalar = self._truncate_scalar(value, prefix)
        if scalar:
            return f"{name}={_quote(scalar)}"
        return name + "="

    def _label_path_expansion(
        self,
        name: str,
        value: VariableValue,
        explode: bool,
        prefix: t.Optional[int],
    ) -> t.Optional[str]:
        """Label and path expansion method.

        Expands for operators: '/', '.'

        Returns ``None`` when the value is undefined (``None``, an empty
        non-scalar, or a composite without any defined member).

        """
        join_str = self.operator.expansion_separator()
        safe = self.operator.reserved_characters()

        if value is None or (
            not isinstance(value, (str, int, float, complex))
            and len(value) == 0
        ):
            return None

        tuples, items = is_list_of_tuples(value)

        if list_test(value) and not tuples:
            # Without the explode modifier, members are comma-separated.
            if not explode:
                join_str = ","

            fragments = [
                self.operator.quote(v)
                for v in self._defined_members(
                    t.cast(t.Sequence[ScalarVariableValue], value)
                )
            ]
            return join_str.join(fragments) if fragments else None

        if dict_test(value) or tuples:
            format_str = "%s=%s"
            if not explode:
                format_str = "%s,%s"
                join_str = ","

            expanded = join_str.join(
                format_str % (quote(k, safe), self.operator.quote(v))
                for k, v in self._defined_items(value, items)
            )
            return expanded if expanded else None

        return self.operator.quote(self._truncate_scalar(value, prefix))

    def _semi_path_expansion(
        self,
        name: str,
        value: VariableValue,
        explode: bool,
        prefix: t.Optional[int],
    ) -> t.Optional[str]:
        """Expansion method for ';' operator.

        Returns ``None`` when the value is undefined (``None``, or a list
        or associative array without any defined member); an empty scalar
        expands to the bare ``name``.
        """
        join_str = self.operator.expansion_separator()
        safe = self.operator.reserved_characters()

        if value is None:
            return None

        tuples, items = is_list_of_tuples(value)

        if list_test(value) and not tuples:
            members = self._defined_members(
                t.cast(t.Sequence[ScalarVariableValue], value)
            )
            if not members:
                return None
            if explode:
                return join_str.join(
                    f"{name}={quote(v, safe)}" for v in members
                )
            joined = ",".join(quote(v, safe) for v in members)
            return f"{name}={joined}"

        if dict_test(value) or tuples:
            pairs = self._defined_items(value, items)
            if not pairs:
                return None
            if explode:
                return join_str.join(
                    f"{quote(k, safe)}={self.operator.quote(v)}"
                    for k, v in pairs
                )
            joined = ",".join(
                f"{quote(k, safe)},{self.operator.quote(v)}"
                for k, v in pairs
            )
            return f"{name}={joined}"

        scalar = self._truncate_scalar(value, prefix)
        if scalar:
            return f"{name}={self.operator.quote(scalar)}"

        return name

    def _string_expansion(
        self,
        name: str,
        value: VariableValue,
        explode: bool,
        prefix: t.Optional[int],
    ) -> t.Optional[str]:
        """Expansion method for every operator without a dedicated helper.

        Returns ``None`` when the value is undefined (``None``, or a list
        or associative array without any defined member).
        """
        if value is None:
            return None

        tuples, items = is_list_of_tuples(value)

        if list_test(value) and not tuples:
            members = self._defined_members(
                t.cast(t.Sequence[ScalarVariableValue], value)
            )
            if not members:
                return None
            return ",".join(self.operator.quote(v) for v in members)

        if dict_test(value) or tuples:
            pairs = self._defined_items(value, items)
            if not pairs:
                return None
            format_str = "%s=%s" if explode else "%s,%s"

            return ",".join(
                format_str % (self.operator.quote(k), self.operator.quote(v))
                for k, v in pairs
            )

        return self.operator.quote(self._truncate_scalar(value, prefix))

    def expand(
        self, var_dict: t.Optional[VariableValueDict] = None
    ) -> t.Mapping[str, str]:
        """Expand the variable in question.

        Using ``var_dict`` and the previously parsed defaults, expand this
        variable and subvariables.

        :param dict var_dict: dictionary of key-value pairs to be used during
            expansion
        :returns: dict(variable=value)

        When ``var_dict`` is ``None`` the original expression is returned
        unchanged as the value.

        Examples::

            # (1)
            v = URIVariable('/var')
            expansion = v.expand({'var': 'value'})
            print(expansion)
            # => {'/var': '/value'}

            # (2)
            v = URIVariable('?var,hello,x,y')
            expansion = v.expand({'var': 'value', 'hello': 'Hello World!',
                                  'x': '1024', 'y': '768'})
            print(expansion)
            # => {'?var,hello,x,y':
            #     '?var=value&hello=Hello%20World%21&x=1024&y=768'}

        """
        if var_dict is None:
            return {self.original: self.original}

        # The helper depends only on the operator, so pick it once.
        expansion = self._expansion_method()
        return_values: t.List[str] = []

        for name, opts in self.variables:
            value = var_dict.get(name, None)
            # Fall back to the parsed default for missing/empty values, but
            # never for the empty string or a number: ``0`` is a real value.
            missing = (
                not value
                and value != ""
                and not isinstance(value, (int, float, complex))
            )
            if missing and name in self.defaults:
                value = self.defaults[name]

            if value is None:
                continue

            expanded = expansion(name, value, opts["explode"], opts["prefix"])

            if expanded is not None:
                return_values.append(expanded)

        value = ""
        if return_values:
            value = (
                self.operator.variable_prefix()
                + self.operator.expansion_separator().join(return_values)
            )
        return {self.original: value}


def is_list_of_tuples(
    value: t.Any,
) -> t.Tuple[bool, t.Optional[t.Sequence[t.Tuple[str, ScalarVariableValue]]]]:
    """Detect a non-empty list/tuple made up solely of 2-tuples.

    :returns: ``(True, value)`` when ``value`` qualifies, otherwise
        ``(False, None)``.
    """
    if value and isinstance(value, (list, tuple)):
        if all(isinstance(pair, tuple) and len(pair) == 2 for pair in value):
            return True, value

    return False, None


def list_test(value: t.Any) -> bool:
    """Tell whether ``value`` is a ``list`` or a ``tuple``."""
    return isinstance(value, list) or isinstance(value, tuple)


def dict_test(value: t.Any) -> bool:
    """Tell whether ``value`` should be expanded as an associative array.

    Any ``collections.abc.Mapping`` qualifies, which covers ``dict`` and
    other mutable mappings as well as read-only mappings such as
    ``types.MappingProxyType``.
    """
    return isinstance(value, collections.abc.Mapping)


def _encode(value: t.AnyStr, encoding: str = "utf-8") -> bytes:
    if isinstance(value, str):
        return value.encode(encoding)
    return value


def quote(value: t.Any, safe: str) -> str:
    """Percent-encode ``value``, leaving the characters in ``safe`` as-is.

    Anything that is neither ``str`` nor ``bytes`` is converted with
    ``str()`` before being encoded.
    """
    text = value if isinstance(value, (str, bytes)) else str(value)
    return urllib.parse.quote(_encode(text), safe)
