# -*- coding: shift-jis -*-

"""PMX{[W[.

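BoneList(UniqueList)
    Collection for managing Bone instances.
    Its update method can rebuild the parent/child references between bones.

Bone(pmx.Bone)
    The actual bone data class.
    Instances are created through BoneFactory.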
"""

import pymel.core as pm
import pymel.core.nodetypes as nt
from pmxio.pymeshio import pmx, common
from pmxio.utils import NodeObject
from pmxio.trutils.collection import UniqueList
from pmxio.trutils.decorator import memonode


class TailModes(object):
    """Enumeration of the ways a Bone chooses its tail (display target).

    NONE
        Bone.has_tail_bone is always false, so no child bone is used
        as the tail.
    SINGLE
        A bone with exactly one child uses that child as its tail bone.
    AVERAGE
        Same as SINGLE for one child; with two or more children
        Bone.tail_position becomes the mean child position relative to
        the bone.
    """

    NONE, SINGLE, AVERAGE = 1, 2, 3


class TargetModes(object):
    """Enumeration of how BoneFactory.createFromJoints picks its joints.

    TREE
        The given joints are expanded to every joint of the trees they
        belong to.
    INFLUENCE
        The given joints are used exactly as passed in (presumably the
        skin influences collected by the caller -- confirm at call site).
    """

    TREE, INFLUENCE = 1, 2


class BoneFactory(object):
    """Creates Bone objects from Maya nodes with shared export settings.

    The scale, namespace handling and tail mode given to the constructor
    are forwarded to every Bone this factory builds.
    """

    def __init__(self, removeNs=True, scale=1.0, tailMode=TailModes.NONE):
        """Store the settings applied to every created Bone.

        :param removeNs: forwarded to Bone as ``remove_ns``.
        :param scale: factor applied to the world position of each bone.
        :param tailMode: one of the TailModes constants.
        """
        self._removeNs = removeNs
        self._scale = scale
        self._tailMode = tailMode
        # Presumably the store used by @memonode on create() -- confirm
        # against pmxio.trutils.decorator.
        self._cache = {}

    @property
    def instances(self):
        """Return the internal cache dictionary.

        :rtype: dict
        """
        return self._cache

    def _getJointTree(self, joints):
        u"""Return every joint belonging to the same trees as ``joints``.

        An empty ``joints`` yields an empty list.

        :rtype: list of Joint
        """

        # Collapse the input to the distinct root of each joint.
        rootJoints = set(j.root() for j in joints)

        # With no object arguments ``ls`` is not restricted to any
        # hierarchy and would list every joint in the scene, so an empty
        # input is answered directly.
        if not rootJoints:
            return []

        # List every joint at or below the collected roots.
        return pm.ls(rootJoints, ap=True, dag=True, type="joint")

    @memonode
    def create(self, node):
        u"""Return the Bone for ``node``.

        The call goes through @memonode (results are presumably cached
        per node -- confirm in the decorator).  A node that is not a
        Transform produces a Bone holding only default values.

        :rtype: Bone
        """
        return Bone(node, self._scale, self._removeNs, self._tailMode)

    def createFromJoints(self, joints, tgtMode):
        u"""Return the list of Bones for the given joints.

        With ``TargetModes.TREE`` the joints are first expanded to all
        joints of their trees; any other mode uses them as given.  Each
        Bone is obtained through create().

        :rtype: list of Bone
        """

        if tgtMode == TargetModes.TREE:
            joints = self._getJointTree(joints)

        return [self.create(j) for j in joints]


class BoneList(UniqueList):
    """UniqueList of Bone objects that can order and link its members."""

    def sort(self):
        u"""Reorder the bones to follow Maya's DAG listing order.

        Bones whose node is not found in the listing get the key -1 and
        therefore sort ahead of the others.
        """

        # Map every DAG node to its position in the listing.
        order = {}
        for position, dagNode in enumerate(pm.ls(dag=True)):
            order[dagNode] = position

        def dagPosition(bone):
            return order.get(bone.node, -1)

        UniqueList.sort(self, key=dagPosition)

    def updateRelations(self):
        u"""Refresh each bone's index and its parent/children references."""

        # Reset the links up front and index the bones by their node.
        boneByNode = {}
        for position, member in enumerate(self):
            member.index = position
            member.parent = None
            member.children = []
            boneByNode[member.node] = member

        for member in self:

            # Only Transform nodes have a hierarchy to walk.
            current = member.node
            if not isinstance(current, nt.Transform):
                continue

            # Climb the ancestors until one of them is owned by a bone in
            # this list; that bone becomes the parent.
            ancestor = current.getParent()
            while ancestor:
                owner = boneByNode.get(ancestor)  # None when not in the list
                if not owner:
                    ancestor = ancestor.getParent()
                    continue
                member.parent = owner
                owner.children.append(member)
                break

    def update(self):
        u"""Bring the list up to date: sort it, then rebuild the relations."""

        self.sort()
        self.updateRelations()


class Bone(pmx.Bone, NodeObject):
    u"""Concrete bone data class built from a Maya node.

    ``parent_index``, ``tail_index``, ``tail_position`` and ``flag`` are
    computed from ``parent``, ``children``, ``tail_mode`` and the boolean
    switches of the instance; assigning to them is silently ignored.

    Note (from the original author): the number of children combined with
    the tail mode gives nine possible patterns; see bone_tail.xls for
    details.
    """

    # Reference axes used by has_local_axis.  DEF_AXIS_Z is what
    # _fromTransform produces for an identity world matrix (Z is negated).
    # NOTE(review): __init__ passes local_z_vector=(0, 0, 1), which differs
    # from DEF_AXIS_Z, so a bone built without a Transform presumably
    # reports has_local_axis as true -- confirm this is intended.
    DEF_AXIS_X = common.Vector3(1.0, 0.0, 0.0)
    DEF_AXIS_Z = common.Vector3(0.0, 0.0, -1.0)

    def __init__(self, node=None, scale=1.0, remove_ns=True, tail_mode=False):
        u"""Initialise with default values, then read ``node`` if possible.

        :param node: Maya node; only a Transform is read for axes/position.
        :param scale: factor applied to the world position.
        :param remove_ns: forwarded to NodeObject.__init__.
        :param tail_mode: a TailModes value.
            NOTE(review): the default ``False`` is not a TailModes member
            and compares unequal to TailModes.NONE, so has_tail_bone
            treats it as enabled.
        """

        baseDefaults = dict(
            name=u"",
            english_name=u"",
            position=common.Vector3(0.0, 0.0, 0.0),
            parent_index=-1,
            layer=0,
            flag=0,
            tail_position=common.Vector3(0.0, 0.0, 0.0),
            tail_index=-1,
            effect_index=-1,
            effect_factor=0.0,
            fixed_axis=common.Vector3(0.0, 0.0, 0.0),
            local_x_vector=common.Vector3(1.0, 0.0, 0.0),
            local_z_vector=common.Vector3(0.0, 0.0, 1.0),
            external_key=-1,
            ik=None,
        )
        pmx.Bone.__init__(self, **baseDefaults)
        NodeObject.__init__(self, node, remove_ns)

        # Hierarchy links; BoneList.updateRelations fills these in.
        self.index = -1
        self.parent = None
        self.children = []

        # Inputs of the computed ``flag`` / tail properties.
        self.tail_mode = tail_mode
        self.visible = True
        self.rotatable = True
        self.translatable = True
        self.manipulatable = True

        if isinstance(node, nt.Transform):
            self._fromTransform(node, scale)

    def _fromTransform(self, node, scale):
        u"""Read the local axes and the position from a Transform."""

        # Local axes: rows 0 and 2 of the world matrix, Z component negated.
        world = node.getMatrix(worldSpace=True)
        rowX, rowZ = world[0], world[2]
        self.local_x_vector = common.Vector3(rowX[0], rowX[1], rowX[2] * -1)
        self.local_z_vector = common.Vector3(rowZ[0], rowZ[1], rowZ[2] * -1)

        # World position: scaled, Z negated.
        worldPos = node.getTranslation(space="world")
        self.position = common.Vector3(worldPos.x * scale,
                                       worldPos.y * scale,
                                       worldPos.z * scale * -1)

    @property
    def has_local_axis(self):
        u"""Return true when the local axes differ from the default axes.

        :rtype: bool
        """
        owner = self.__class__
        if self.local_x_vector != owner.DEF_AXIS_X:
            return True
        return self.local_z_vector != owner.DEF_AXIS_Z

    @property
    def has_tail_bone(self):
        u"""Return true when the tail (display target) is a bone.

        That is the case when the tail mode is not TailModes.NONE and the
        bone has exactly one child.

        :rtype: bool
        """
        if self.tail_mode == TailModes.NONE:
            return False
        return len(self.children) == 1

    @property
    def parent_index(self):
        u"""Return the parent bone's index, or -1 when there is no parent.

        :rtype: int
        """
        if self.parent is None:
            return -1
        return self.parent.index

    # The setters in this class are deliberate no-ops: the values are
    # always derived.  Presumably they exist so that assignments made by
    # pmx.Bone.__init__ do not fail -- confirm against pmx.Bone.
    @parent_index.setter
    def parent_index(self, val):
        pass

    @property
    def tail_index(self):
        u"""Return the index of the tail bone, or -1 when there is none.

        :rtype: int
        """
        if not self.has_tail_bone:
            return -1
        return self.children[0].index

    @tail_index.setter
    def tail_index(self, val):
        pass

    @property
    def tail_position(self):
        u"""Return the tail offset derived from the child bones.

        With several children in AVERAGE mode this is the mean child
        position relative to this bone; otherwise a default-constructed
        Vector3 is returned.

        :rtype: common.Vector3
        """
        kids = self.children
        if self.tail_mode != TailModes.AVERAGE or len(kids) <= 1:
            # Nothing to average: return the default vector.
            return common.Vector3()

        # Average the child positions component by component.
        count = len(kids)
        total = common.Vector3()
        for child in kids:
            total += child.position
        mean = common.Vector3(total.x / count,
                              total.y / count,
                              total.z / count)
        return mean - self.position

    @tail_position.setter
    def tail_position(self, val):
        pass

    @property
    def flag(self):
        u"""Build and return the bone flag from the instance properties.

        :rtype: int
        """
        switches = (
            (self.visible, pmx.BONEFLAG_IS_VISIBLE),
            (self.rotatable, pmx.BONEFLAG_CAN_ROTATE),
            (self.translatable, pmx.BONEFLAG_CAN_TRANSLATE),
            (self.manipulatable, pmx.BONEFLAG_CAN_MANIPULATE),
            (self.has_tail_bone, pmx.BONEFLAG_TAILPOS_IS_BONE),
            (self.has_local_axis, pmx.BONEFLAG_HAS_LOCAL_COORDINATE),
        )

        bits = 0
        for enabled, bit in switches:
            if enabled:
                bits += bit
        return bits

    @flag.setter
    def flag(self, val):
        pass
