"""Helpers for AST (Abstract Syntax Tree)."""

import sys
from typing import Dict, List, Optional, Type, overload

if sys.version_info > (3, 8):
    import ast
else:
    try:
        # use typed_ast module if installed
        from typed_ast import ast3 as ast
    except ImportError:
        import ast  # type: ignore


# Source-text symbol for every operator node class the unparser supports.
# ``_UnparseVisitor`` iterates this mapping to register one ``visit_<Name>``
# handler per key, and looks the symbol up by the node's class.
OPERATORS: Dict[Type[ast.AST], str] = {
    ast.Add: "+",
    ast.And: "and",
    ast.BitAnd: "&",
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Invert: "~",
    ast.LShift: "<<",
    ast.MatMult: "@",
    ast.Mult: "*",
    ast.Mod: "%",
    ast.Not: "not",
    ast.Pow: "**",
    ast.Or: "or",
    ast.RShift: ">>",
    ast.Sub: "-",
    ast.UAdd: "+",  # same symbol as ast.Add
    ast.USub: "-",  # same symbol as ast.Sub
}


def parse(code: str, mode: str = 'exec') -> "ast.AST":
    """Parse *code* with the built-in ast or typed_ast module.

    The first attempt asks for "type_comments" support.  If that attempt is
    rejected, the code is parsed again without the option.

    :param code: Source text to parse.
    :param mode: Compilation mode handed to ``ast.parse()``.
    :return: The root node of the parsed tree.
    """
    try:
        # type_comments parameter is available on py38+
        tree = ast.parse(code, mode=mode, type_comments=True)  # type: ignore
    except (SyntaxError, TypeError):
        # Both failures lead to the same plain retry:
        #
        # * SyntaxError: some syntax error found.  To ignore invalid type
        #   comments, parse again without the type_comments parameter
        #   (refs: https://github.com/sphinx-doc/sphinx/issues/8652).
        # * TypeError: the type_comments parameter is not accepted; fall back
        #   to the ast module as is (typed_ast, if installed, is what parses
        #   type comments then).
        tree = ast.parse(code, mode=mode)
    return tree


@overload
def unparse(node: None, code: str = '') -> None:
    ...


@overload
def unparse(node: ast.AST, code: str = '') -> str:
    ...


def unparse(node: Optional[ast.AST], code: str = '') -> Optional[str]:
    """Convert an AST node to its source text.

    :param node: The node to render.  ``None`` and plain strings are handed
                 back unchanged.
    :param code: The source text the node came from; it is passed on to the
                 visitor that does the rendering.
    :return: The rendered text, or ``None`` when *node* is ``None``.
    """
    if node is None or isinstance(node, str):
        # Nothing to render: return the value exactly as it was given.
        return node

    visitor = _UnparseVisitor(code)
    return visitor.visit(node)


# a greatly cut-down version of `ast._Unparser`
class _UnparseVisitor(ast.NodeVisitor):
    """Turn a limited set of AST nodes back into source text.

    Every ``visit_*`` method returns the text for one node type.  Node types
    without a handler end up in :meth:`generic_visit`, which raises
    ``NotImplementedError``.

    :param code: The source text the nodes were parsed from.  When it is
                 non-empty (and on py38+), numeric constants are copied from
                 it instead of being re-rendered with ``repr()``.
    """

    def __init__(self, code: str = '') -> None:
        self.code = code

    def _visit_op(self, node: ast.AST) -> str:
        """Return the source symbol of an operator node."""
        return OPERATORS[node.__class__]
    # Register ``_visit_op`` as the ``visit_<Name>`` handler of every operator
    # class in OPERATORS; ``locals()`` is the class namespace at this point.
    for _op in OPERATORS:
        locals()['visit_{}'.format(_op.__name__)] = _visit_op

    def visit_arg(self, node: ast.arg) -> str:
        """Render a parameter name, with its annotation if it has one."""
        if node.annotation:
            return "%s: %s" % (node.arg, self.visit(node.annotation))
        else:
            return node.arg

    def _visit_arg_with_default(self, arg: ast.arg, default: Optional[ast.AST]) -> str:
        """Unparse a single argument to a string."""
        name = self.visit(arg)
        if default:
            # annotated parameters get spaces around "=", bare ones do not
            if arg.annotation:
                name += " = %s" % self.visit(default)
            else:
                name += "=%s" % self.visit(default)
        return name

    def visit_arguments(self, node: ast.arguments) -> str:
        """Render a full parameter list (the text between the parentheses)."""
        defaults: List[Optional[ast.expr]] = list(node.defaults)
        positionals = len(node.args)
        posonlyargs = 0
        if hasattr(node, "posonlyargs"):  # for py38+
            posonlyargs += len(node.posonlyargs)  # type:ignore
            positionals += posonlyargs
        # left-pad with None so that there is one entry per positional
        # parameter (positional-only ones first, then the regular ones)
        for _ in range(len(defaults), positionals):
            defaults.insert(0, None)

        kw_defaults: List[Optional[ast.expr]] = list(node.kw_defaults)
        for _ in range(len(kw_defaults), len(node.kwonlyargs)):
            kw_defaults.insert(0, None)

        args: List[str] = []
        if hasattr(node, "posonlyargs"):  # for py38+
            for i, arg in enumerate(node.posonlyargs):  # type: ignore
                args.append(self._visit_arg_with_default(arg, defaults[i]))

            if node.posonlyargs:  # type: ignore
                args.append('/')

        for i, arg in enumerate(node.args):
            args.append(self._visit_arg_with_default(arg, defaults[i + posonlyargs]))

        if node.vararg:
            args.append("*" + self.visit(node.vararg))

        # a bare "*" separates keyword-only parameters when there is no *args
        if node.kwonlyargs and not node.vararg:
            args.append('*')
        for i, arg in enumerate(node.kwonlyargs):
            args.append(self._visit_arg_with_default(arg, kw_defaults[i]))

        if node.kwarg:
            args.append("**" + self.visit(node.kwarg))

        return ", ".join(args)

    def visit_Attribute(self, node: ast.Attribute) -> str:
        return "%s.%s" % (self.visit(node.value), node.attr)

    def visit_BinOp(self, node: ast.BinOp) -> str:
        return " ".join(self.visit(e) for e in [node.left, node.op, node.right])

    def visit_BoolOp(self, node: ast.BoolOp) -> str:
        op = " %s " % self.visit(node.op)
        return op.join(self.visit(e) for e in node.values)

    def visit_Call(self, node: ast.Call) -> str:
        """Render a call: positional arguments first, then keywords."""
        args = [self.visit(e) for e in node.args]
        for keyword in node.keywords:
            if keyword.arg is None:
                # ``f(**mapping)`` is stored as a keyword without a name
                args.append("**%s" % self.visit(keyword.value))
            else:
                args.append("%s=%s" % (keyword.arg, self.visit(keyword.value)))
        return "%s(%s)" % (self.visit(node.func), ", ".join(args))

    def visit_Constant(self, node: ast.Constant) -> str:  # type: ignore
        """Render a constant; numbers keep their spelling from ``self.code``."""
        if node.value is Ellipsis:
            return "..."
        elif isinstance(node.value, (int, float, complex)):
            if self.code and sys.version_info > (3, 8):
                # get_source_segment() gives None when it cannot locate the
                # node in the source; fall through to repr() in that case so
                # that this method always returns a string.
                segment = ast.get_source_segment(self.code, node)  # type: ignore
                if segment:
                    return segment
            return repr(node.value)
        else:
            return repr(node.value)

    def visit_Dict(self, node: ast.Dict) -> str:
        """Render a dict display, including ``**mapping`` unpacking."""
        items: List[str] = []
        for key, value in zip(node.keys, node.values):
            if key is None:
                # ``{**mapping}`` is stored as a value whose key is None
                items.append("**" + self.visit(value))
            else:
                items.append(self.visit(key) + ": " + self.visit(value))
        return "{" + ", ".join(items) + "}"

    def visit_Index(self, node: ast.Index) -> str:
        return self.visit(node.value)

    def visit_Lambda(self, node: ast.Lambda) -> str:
        # only the parameters are rendered; the body is replaced by "..."
        return "lambda %s: ..." % self.visit(node.args)

    def visit_List(self, node: ast.List) -> str:
        return "[" + ", ".join(self.visit(e) for e in node.elts) + "]"

    def visit_Name(self, node: ast.Name) -> str:
        return node.id

    def visit_Set(self, node: ast.Set) -> str:
        return "{" + ", ".join(self.visit(e) for e in node.elts) + "}"

    def visit_Subscript(self, node: ast.Subscript) -> str:
        """Render ``value[slice]``; simple tuple slices lose their parentheses."""
        def is_simple_tuple(value: ast.AST) -> bool:
            # non-empty tuple without any starred element
            return (
                isinstance(value, ast.Tuple) and
                bool(value.elts) and
                not any(isinstance(elt, ast.Starred) for elt in value.elts)
            )

        if is_simple_tuple(node.slice):
            elts = ", ".join(self.visit(e) for e in node.slice.elts)  # type: ignore
            return "%s[%s]" % (self.visit(node.value), elts)
        elif isinstance(node.slice, ast.Index) and is_simple_tuple(node.slice.value):
            # same as above, for slices that are wrapped in an Index node
            elts = ", ".join(self.visit(e) for e in node.slice.value.elts)  # type: ignore
            return "%s[%s]" % (self.visit(node.value), elts)
        else:
            return "%s[%s]" % (self.visit(node.value), self.visit(node.slice))

    def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
        """Render a unary operation (``+x``, ``-x``, ``~x`` or ``not x``)."""
        # Only the keyword operator needs a separating space; the symbolic
        # ones are written directly in front of their operand.
        if isinstance(node.op, ast.Not):
            return "%s %s" % (self.visit(node.op), self.visit(node.operand))
        return "%s%s" % (self.visit(node.op), self.visit(node.operand))

    def visit_Tuple(self, node: ast.Tuple) -> str:
        if len(node.elts) == 0:
            return "()"
        elif len(node.elts) == 1:
            # a one-element tuple needs the trailing comma
            return "(%s,)" % self.visit(node.elts[0])
        else:
            return "(" + ", ".join(self.visit(e) for e in node.elts) + ")"

    if sys.version_info < (3, 8):
        # these ast nodes were deprecated in python 3.8
        def visit_Bytes(self, node: ast.Bytes) -> str:
            return repr(node.s)

        def visit_Ellipsis(self, node: ast.Ellipsis) -> str:
            return "..."

        def visit_NameConstant(self, node: ast.NameConstant) -> str:
            return repr(node.value)

        def visit_Num(self, node: ast.Num) -> str:
            return repr(node.n)

        def visit_Str(self, node: ast.Str) -> str:
            return repr(node.s)

    def generic_visit(self, node: ast.AST) -> None:
        """Reject every node type that has no dedicated ``visit_*`` method.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError('Unable to parse %s object' % type(node).__name__)