# -*- coding: shift-jis -*-
import time
import math
import os.path as ospath
import maya.OpenMaya as om
import pymel.core as pm
import pymel.core.datatypes as dt
import pymel.core.nodetypes as nt
from itertools import izip

import pmxio.trutils.maya as mayautils
from pymeshio import converter, pmx
from pymeshio.pmx import reader
from pmxio.language import getCurrentLang
from pmxio.utils import rgbToColor, alphaToGlayColor
from pmxio.trutils import getGroupByCounts
from pmxio.trutils.string import getUnicode
from pmxio.trutils.maya import getSafeName, getDagPath
from pmxio.trutils.maya.skin import getSkinFn
from pmxio.trutils.maya.matrix import getRotationFromVectors
from pmxio.trutils.maya._general import getDependObject

# Constants
DEFAULT_MODEL_NAME = "PmxModel"  # last-resort name used by getModelName
DEFAULT_MATERIAL_NAME = "PmxMaterial"  # fallback used by getMaterialName
DEFAULT_TEXTURE_NAME = "PmxTexture"  # not referenced in this part of the file
NAME_MODE = 2  # passed as ``mode`` to getSafeName (meaning defined there)
SEPARATOR_LENGTH = 32  # not referenced in this part of the file
EDGE_FACTOR_COLORSET_NAME = "EdgeMap"  # color set made by createShape


class MeshData(object):
    u"""Data and Maya nodes for one polygon mesh built from a PMX model.

    Input data (given to the constructor):
        name      -- base name for the transform / shape nodes.
        indices   -- flat list of vertex indices, three per triangle.
        vertices  -- list of pmx vertices referenced by ``indices``.
        materials -- list of pmx materials covering the faces in order.

    Results (filled in by the create*/bind* methods):
        meshNode, transNode -- set by createShape.
        blinnNodes, sgNodes -- appended to by createMaterials.
        skinNode            -- set by bindSkin.
    """

    def __init__(self, name="", indices=None, vertices=None, materials=None):
        u"""Store the source data; no Maya node is created here.

        The list arguments default to ``None`` (replaced by a fresh list)
        so that instances never share one mutable default list.
        Lists passed by the caller are stored as-is, without copying.
        """
        self.name = name
        self.indices = [] if indices is None else indices
        self.vertices = [] if vertices is None else vertices
        self.materials = [] if materials is None else materials
        self.meshNode = None
        self.transNode = None
        self.blinnNodes = []
        self.sgNodes = []
        self.skinNode = None

    def createShape(self, scale, useEdgeFactor, parent=None):
        u"""Create polygon mesh from PMX model.

        To set self.meshNode and self.transNode.

        Positions are multiplied by ``scale``.  The Z component of positions
        and normals is negated, the triangle winding is reversed and the V
        coordinate is flipped (v -> 1 - v) while converting.

        :type scale: float
        :type useEdgeFactor: bool
        :param useEdgeFactor: When true, bake each vertex's edge_factor into
                              a color set named EDGE_FACTOR_COLORSET_NAME.
        :param parent: Optional node (must provide ``fullPath()``) under
                       which the new transform is created.
        """

        # Initialize variables.
        indices = self.indices
        vertices = self.vertices
        numVerts = len(vertices)
        numIndices = len(indices)
        numFaces = numIndices // 3  # Every face is a triangle.
        points = om.MFloatPointArray(numVerts)
        normals = om.MVectorArray(numVerts)
        vtxIds = om.MIntArray(numVerts)
        faceCounts = om.MIntArray(numFaces, 3)
        vtxConnects = om.MIntArray(numIndices)
        uValues = om.MFloatArray(numVerts)
        vValues = om.MFloatArray(numVerts)
        edgeFactors = om.MColorArray(numVerts)

        # Get vertex datas.
        setId = vtxIds.set
        setPoint = points.set
        setNormal = normals.set
        setEF = edgeFactors.set
        s = scale
        for i, v in enumerate(vertices):
            setId(i, i)
            p = v.position.to_tuple()
            setPoint(om.MFloatPoint(p[0] * s, p[1] * s, p[2] * s * -1), i)
            n = v.normal.to_tuple()
            setNormal(om.MVector(n[0], n[1], n[2] * -1), i)
            uv = v.uv
            uValues[i] = uv[0]
            vValues[i] = (uv[1] * -1 + 1.0)  # Reverse v value.
            color = v.edge_factor
            setEF(om.MColor(color, 0.0, 0.0, 1.0), i)

        # Create edge connection data.
        # Each triangle's three indices are written in reverse order.
        setIdx = vtxConnects.set
        for i in xrange(0, numIndices, 3):
            setIdx(indices[i + 2], i)
            setIdx(indices[i + 1], i + 1)
            setIdx(indices[i], i + 2)

        # Create Transform node.
        transFn = om.MFnTransform()
        if parent:
            parentObj = getDependObject(parent.fullPath())
            transObj = transFn.create(parentObj)
        else:
            transObj = transFn.create()

        # Create mesh shape from got datas.
        meshFn = om.MFnMesh()
        meshFn.create(numVerts, numFaces, points,
                      faceCounts, vtxConnects,
                      uValues, vValues, transObj)

        # Set node names.
        transFn.setName(self.name)
        meshFn.setName(self.name + "Shape")

        # Assign uv set.
        meshFn.assignUVs(faceCounts, vtxConnects)

        # Set vertex normals.
        meshFn.setVertexNormals(normals, vtxIds, om.MSpace.kObject)

        # Bake edge factor.(edge width map) TODO: Support kRGB
        if useEdgeFactor:
            colorSetName = EDGE_FACTOR_COLORSET_NAME
            meshFn.createColorSetWithName(colorSetName)
            meshFn.setColors(edgeFactors, colorSetName, om.MFnMesh.kAlpha)
            meshFn.assignColors(vtxConnects, colorSetName)

        # Get PyMEL Nodes
        mesh = pm.PyNode(transFn.fullPathName()).getShape()
        trans = mesh.getParent()
        self.meshNode = mesh
        self.transNode = trans

    def createMaterials(self, fileNodes, connectAlpha):
        """Create shading nodes and to assign to faces.

        To set self.sgNodes and self.blinnNodes.

        Materials are assigned to consecutive face ranges, each covering
        ``vertex_count // 3`` faces.  Materials whose vertex_count is not
        positive are skipped.  Requires self.meshNode (see createShape).

        :param fileNodes: Texture list indexed by material.texture_index;
                          passed through to _createMaterial.
        :type connectAlpha: bool
        """

        faceCount = 0
        baseName = self.meshNode.name() + ".f[%d:%d]"

        for mat in self.materials:

            if mat.vertex_count <= 0:
                continue

            # Create the material.
            sg, blinn = self._createMaterial(mat, fileNodes, connectAlpha)
            self.sgNodes.append(sg)
            self.blinnNodes.append(blinn)

            # Assign it to the faces.
            nextCount = faceCount + mat.vertex_count // 3
            faceComps = baseName % (faceCount, nextCount - 1)
            cmd = "sets -e -fe %s %s;" % (sg, faceComps)
            pm.mel.eval(cmd)  # @UndefinedVariable

            faceCount = nextCount

    def _createMaterial(self, material, fileNodes, connectAlpha):
        u"""Create blinn material from pmx material.

        :type material: pmx.Material
        :param fileNodes: List indexed by texture index.  Entries that are
                          not nt.File yet are handed to createTexture and
                          replaced in place by the created file node.
        :type fileNodes: list
        :type connectAlpha: bool
        :param connectAlpha: Connect the texture's outTransparency when the
                             file node reports an alpha channel.
        :returns: (nt.ShadingEngine, nt.Blinn)
        """

        # Get the data.
        name = getMaterialName(material)
        enName = getUnicode(material.english_name)
        jpName = getUnicode(material.name)
        comment = material.comment
        diffuse = rgbToColor(material.diffuse_color)
        ambient = rgbToColor(material.ambient_color)
        specular = rgbToColor(material.specular_color)
        transparency = alphaToGlayColor(1 - material.alpha)
        # Clamp to at least 2^-10 so that the logarithm below is defined.
        factor = max(material.specular_factor,
                     math.pow(2, -10))
        rollOff = math.pow(0.75, math.log(factor, 2) + 1)

        # Create the blinn shader.
        blinnNode = pm.shadingNode("blinn", n=name, asShader=True)

        # Create the shading group.
        name = blinnNode.name() + "SG"
        sgNode = pm.sets(n=name, renderable=True, noSurfaceShader=True,
                         empty=True)
        blinnNode.outColor >> sgNode.surfaceShader

        # Extra attributes.
        pm.addAttr(blinnNode, ln="pmxName", dt="string")
        pm.addAttr(blinnNode, ln="pmxEnglishName", dt="string")
        pm.addAttr(blinnNode, ln="notes", dt="string")

        # Set attributes.
        blinnNode.color.set(diffuse)
        blinnNode.ambientColor.set(ambient)
        blinnNode.specularColor.set(specular)
        blinnNode.transparency.set(transparency)
        blinnNode.specularRollOff.set(rollOff)
        blinnNode.pmxName.set(jpName)
        blinnNode.pmxEnglishName.set(enName)
        blinnNode.notes.set(comment)

        # Connect the texture.
        # Indices outside the texture list (including negative ones, which
        # mean "no texture") are ignored instead of raising IndexError.
        texIndex = material.texture_index
        if 0 <= texIndex < len(fileNodes):
            fileNode = fileNodes[texIndex]
            if not isinstance(fileNode, nt.File):
                # createTexture returns (fileNode, placeNode); keep only the
                # file node, and cache it so later materials reuse it.
                fileNode = createTexture(fileNode)[0]
                fileNodes[texIndex] = fileNode

            if fileNode:
                fileNode.outColor >> blinnNode.color
                # Connect the transparency.
                if connectAlpha and fileNode.fileHasAlpha.get():
                    fileNode.outTransparency >> blinnNode.transparency

        return (sgNode, blinnNode)

    def bindSkin(self, jointNodes):
        u"""Create SknCluster and to bind Mesh.

        To set self.skinNode.

        Only the joints referenced by at least one vertex become influences.
        Requires self.meshNode (see createShape).

        :type jointNodes: list of nt.Joint
        :param jointNodes: Joints indexed by PMX bone index.
        """

        # Collect the bone indices that are actually used.
        # NOTE(review): a negative index would select a joint counted from
        # the end of the list -- confirm the reader never yields one.
        jointIdxSet = set()
        addIdx = jointIdxSet.add
        for vert in self.vertices:
            d = vert.deform
            if isinstance(d, pmx.Bdef4):
                addIdx(d.index0)
                addIdx(d.index1)
                addIdx(d.index2)
                addIdx(d.index3)
            elif isinstance(d, (pmx.Bdef2, pmx.Sdef)):
                addIdx(d.index0)
                addIdx(d.index1)
            else:
                addIdx(d.index0)

        newJointSet = set(jointNodes[i] for i in jointIdxSet)

        maxInf = 4  # Max influences (original comment: 4 per the PMX spec).
        numInfs = len(newJointSet)
        numVerts = len(self.vertices)
        mesh = self.meshNode

        # Bind the mesh to the joints.
        self.skinNode = pm.skinCluster(newJointSet, mesh, mi=maxInf, tsb=True)
        skinFn = getSkinFn(self.skinNode)

        # Build the bone index => influence index conversion list.
        # Unused joints map to 0; they are never looked up below.
        idxConv = []
        addIdx = idxConv.append
        for joint in jointNodes:
            if joint in newJointSet:
                dagPath = getDagPath(joint.fullPath())
                addIdx(skinFn.indexForInfluenceObject(dagPath))
            else:
                addIdx(0)

        # Build the weight list (numInfs values per vertex).
        # TODO: Handle index values that are out of range.
        weights = []
        for vert in self.vertices:
            ws = [0.0] * numInfs
            d = vert.deform
            if isinstance(d, pmx.Bdef4):
                w = [d.weight0,
                     d.weight1,
                     d.weight2,
                     d.weight3]
                sw = sum(w)
                if sw != 0.0:
                    w[0] /= sw
                    w[1] /= sw
                    w[2] /= sw
                    # Take the remainder to absorb rounding errors.
                    w[3] = 1.0 - (w[0] + w[1] + w[2])
                else:
                    # All weights are zero: normalizing would divide by
                    # zero, so give the whole weight to the first bone.
                    w = [1.0, 0.0, 0.0, 0.0]

                ws[idxConv[d.index0]] += w[0]
                ws[idxConv[d.index1]] += w[1]
                ws[idxConv[d.index2]] += w[2]
                ws[idxConv[d.index3]] += w[3]
            elif isinstance(d, (pmx.Bdef2, pmx.Sdef)):
                w0 = min(d.weight0, 1.0)
                ws[idxConv[d.index0]] += w0
                ws[idxConv[d.index1]] += (1.0 - w0)
            else:
                ws[idxConv[d.index0]] = 1.0

            weights.extend(ws)

        # Convert to API arrays.
        weights = mayautils.listToMDoubleArray(weights)
        jointIdcs = mayautils.listToMIntArray(range(numInfs))

        # Create the vertex components.
        components = mayautils.getVertComponents(range(numVerts))

        # Set the weight values.
        meshPath = mayautils.getDagPath(mesh.fullPath())
        skinFn.setWeights(meshPath, components, jointIdcs, weights, True)


def run(filePath, setting, msgCallback=(lambda m, w=False: None)):
    u"""Import a PMX/PMD model file into the current Maya scene.

    Creates a group, the mesh(es), the joints, the materials and the skin
    clusters, reporting progress through ``msgCallback``.

    :param str filePath: Path to the PMX/PMD file.
    :param setting: Import settings.  Attributes read here: separateMesh,
                    createNs, nsName, grpName, createGrp, scale, jointSize
                    and connectAlpha.
    :type setting: ImportSetting
    :param msgCallback: Callable taking a message and, optionally, a
                        warning flag.  This function always calls it with
                        the message only.  The default ignores messages.
    """

    lang = getCurrentLang()
    langImp = lang.Importer
    msgCallback(langImp.started)
    startTime = time.time()

    # Read the model file.
    # Try reading plus the PMD => PMX conversion first; if anything in that
    # fails, fall back to using the result of the reader as it is.
    try:
        model = reader.read_from_file(filePath)
        model = converter.pmd_to_pmx(model)
    except Exception:
        model = reader.read_from_file(filePath)
    modelName = getModelName(filePath, model)

    # Split the model data into mesh units beforehand.
    if setting.separateMesh:
        # One mesh per material.
        meshDatas = separateModelWithMaterial(modelName, model)
    else:
        # A single mesh.
        meshDatas = [MeshData(modelName, model.indices,
                              model.vertices, model.materials)]

    # Switch the namespace.
    namespaceChanged = False
    currentNs = None
    if setting.createNs:
        name = getSafeName(setting.nsName, mode=NAME_MODE)
        if not name:
            name = modelName
        currentNs = pm.namespaceInfo(currentNamespace=True)
        pm.namespace(set=pm.namespace(add=name))
        namespaceChanged = True

    try:
        # Create the group.
        name = getSafeName(setting.grpName, mode=NAME_MODE)
        if not name:
            name = modelName
        group = pm.group(n=name, em=True, w=True)

        # Create the meshes.
        msgCallback(langImp.createMesh)
        for meshData in meshDatas:
            meshData.createShape(setting.scale, False, group)

        # Create the skeleton (joints).
        msgCallback(langImp.createSkelton)
        jointNodes = createSkelton(model, setting.scale, setting.jointSize,
                                   msgCallback, langImp, group)

        # Create the materials.
        msgCallback(langImp.createMaterial)
        fileNodes = createAllTextures(model)[0]
        for meshData in meshDatas:
            meshData.createMaterials(fileNodes, setting.connectAlpha)

        # Bind the skeleton.
        msgCallback(langImp.createSkin)
        for meshData in meshDatas:
            meshData.bindSkin(jointNodes)
    finally:
        # Restore the namespace, even when the import failed half-way.
        if namespaceChanged:
            pm.namespace(set=currentNs)

    # When no group is requested, move the children to the world and delete
    # the temporary group.
    # NOTE(review): the original comment suggests that creating under a group
    # first is what makes MFnMesh.create undoable -- confirm.
    if not setting.createGrp:
        grpName = group.nodeName()
        childs = group.getChildren()
        # A child named like the group is looked up before re-parenting and
        # given that name again once the group has been deleted.
        rename = None
        for c in childs:
            nodeName = c.nodeName()
            if nodeName == grpName:
                rename = (c, nodeName)
                break
        pm.parent(childs, w=True)
        pm.delete(group)
        if rename:
            rename[0].rename(rename[1])

    pm.select(cl=True)
    span = time.time() - startTime
    msgCallback(u"%s ( %.4f %s )" % (langImp.completed, span, lang.second))


def getModelName(path, model):
    u"""Return a Maya-safe name for the model.

    Candidates are tried in this order; the first non-empty result wins:

    1. the model's english_name,
    2. the file name of ``path`` up to its first ".",
    3. DEFAULT_MODEL_NAME.

    :type path: str
    :type model: pmx.Model
    :rtype: str
    """

    fromModel = getSafeName(model.english_name, mode=NAME_MODE)
    if fromModel:
        return fromModel

    # Everything after the first dot of the file name is dropped.
    stem = ospath.basename(path).split(".")[0]
    fromFile = getSafeName(stem, mode=NAME_MODE)
    if fromFile:
        return fromFile

    return DEFAULT_MODEL_NAME


def getMaterialName(material):
    u"""Return a Maya-safe name for a material.

    The name comes from the material's english_name; when that yields an
    empty result, DEFAULT_MATERIAL_NAME is returned instead.

    :type material: pmx.Material
    :rtype: str
    """

    safeName = getSafeName(material.english_name, mode=NAME_MODE)
    return safeName or DEFAULT_MATERIAL_NAME


def getTextureName(texturePath):
    """Return a Maya-safe node name derived from a texture file path.

    The directory part and the extension of the path are removed before the
    name is sanitized; getSafeName is given "file" as its ``defstr``.

    :rtype: str
    """

    fileName = ospath.basename(getUnicode(texturePath))
    stem, _ext = ospath.splitext(fileName)
    return mayautils.getSafeName(stem, defstr="file", mode=NAME_MODE)


def separateModelWithMaterial(modelName, model):
    u"""Split the model into one MeshData per material.

    The model's index list is grouped by each material's vertex_count.
    Every group gets its own vertex list holding only the vertices it
    references, and its indices are renumbered to point into that list.

    :type modelName: str
    :type model: pmx.Model
    :rtype: list of MeshData
    """

    allVertices = model.vertices
    counts = (m.vertex_count for m in model.materials)
    groupedIndices = getGroupByCounts(counts, model.indices)
    meshDatas = []

    for material, srcIndices in izip(model.materials, groupedIndices):

        # remap[old index] -> new index within this material's vertex list,
        # or None while the vertex has not been copied yet.
        remap = [None] * len(allVertices)
        subIndices = []
        subVertices = []

        for srcIdx in srcIndices:
            dstIdx = remap[srcIdx]
            if dstIdx is None:
                # First use of this vertex: copy it to the end of the list.
                dstIdx = len(subVertices)
                subVertices.append(allVertices[srcIdx])
                remap[srcIdx] = dstIdx
            subIndices.append(dstIdx)

        meshName = "%s_%s" % (modelName, getMaterialName(material))
        meshDatas.append(MeshData(meshName, subIndices, subVertices,
                                  [material]))

    return meshDatas


def createTexture(texturePath):
    u"""Create file node with texture file path.

    A place2dTexture node is created as well and connected to the file
    node.  Setting the path on the file node is best-effort: a failure is
    reported as a Maya warning and both nodes are still returned.

    :type texturePath: str
    :rtype: tuple
    :returns: (nt.File, nt.Place2dTexture)
    """

    # Create the nodes.
    name = getTextureName(texturePath)
    fileNode = pm.shadingNode("file", n=name, asTexture=True)

    p2d = "place2dTexture"
    placeNode = pm.shadingNode(p2d, n=(p2d + name), asUtility=True)

    # Connect the attributes.
    # These have the same name on both nodes.
    sharedAttrs = (
        "coverage",
        "translateFrame",
        "rotateFrame",
        "mirrorU",
        "mirrorV",
        "stagger",
        "wrapU",
        "wrapV",
        "repeatUV",
        "offset",
        "rotateUV",
        "noiseUV",
        "vertexUvOne",
        "vertexUvTwo",
        "vertexUvThree",
        "vertexCameraOne",
    )
    for attrName in sharedAttrs:
        placeNode.attr(attrName) >> fileNode.attr(attrName)
    # These two are named differently on each side.
    placeNode.outUV >> fileNode.uv
    placeNode.outUvFilterSize >> fileNode.uvFilterSize

    # Set the path on the file node.
    try:
        fileNode.fileTextureName.set(texturePath)
    except Exception:
        # Keep going without a path, but do not hide the problem.
        # %r keeps the message ASCII-safe whatever the path contains.
        pm.warning(u"Failed to set texture path: %r" % (texturePath,))

    return (fileNode, placeNode)


def createAllTextures(model):
    """Create all textures from model.textures.

    Each texture entry is joined to the directory of ``model.path`` and
    handed to createTexture.  The returned lists follow the order of
    ``model.textures``.

    :rtype: tuple
    :returns: (list of nt.File, list of nt.Place2dTexture)
    """
    pmxDir = ospath.dirname(getUnicode(model.path))
    fileNodes = []
    placeNodes = []
    for t in model.textures:
        try:
            texPath = ospath.join(pmxDir, t)
        except Exception:
            # The path could not be built; still create the nodes (with an
            # empty path) so that list positions keep matching the texture
            # indices.  KeyboardInterrupt/SystemExit are not caught here.
            texPath = ""

        fileNode, placeNode = createTexture(texPath)
        fileNodes.append(fileNode)
        placeNodes.append(placeNode)

    return (fileNodes, placeNodes)


def createSkelton(model, scale, jointSize, msgCallback, lang, group=None):
    """Build one Maya joint per PMX bone and return them in bone order.

    A joint is parented through the current selection: under its parent
    joint when the bone's parent index refers to an earlier bone, otherwise
    under ``group`` if one is given, otherwise under nothing (selection
    cleared).  A parent index that is not smaller than the bone's own index
    is reported through ``msgCallback`` and ``pm.warning``.

    :type model: pymeshio.pmx.Model
    :param float scale: Global position scale.
    :param float jointSize: Size of joint radius.
    :param msgCallback: Function with a string argument.
    :param lang: Message table; ``boneParentErr`` and ``parentId`` are read.
    :type lang: EnglishLanguage
    :type group: nt.DagNode
    :rtype: list of nt.Joint
    """

    warnFormat = unicode(lang.boneParentErr) + u" (%d: '%s' %s:%d)"
    rotateOrder = "yxz"
    createdJoints = []

    for boneIdx, bone in enumerate(model.bones):

        # Read the bone data.
        parentIdx = bone.parent_index
        jpName = getUnicode(bone.name)
        enName = getUnicode(bone.english_name)
        jointName = mayautils.getSafeName(enName, mode=NAME_MODE,
                                          defstr="joint")
        rawPos = bone.position.to_tuple()
        # Scale the position and negate its Z component.
        worldPos = dt.Vector(rawPos[0] * scale,
                             rawPos[1] * scale,
                             rawPos[2] * scale * -1)
        axisX = dt.Vector(*bone.local_x_vector.to_tuple())
        axisZ = dt.Vector(*bone.local_z_vector.to_tuple())

        # Report an invalid parent reference.
        if parentIdx >= boneIdx:
            warning = warnFormat % (boneIdx, enName, lang.parentId, parentIdx)
            msgCallback(warning)
            pm.warning(warning)

        # Select the node under which pm.joint will create the new joint.
        if 0 <= parentIdx < boneIdx:
            selectTarget = createdJoints[parentIdx]
        elif group:
            selectTarget = group
        else:
            selectTarget = None

        if selectTarget is None:
            pm.select(cl=True)
        else:
            pm.select(selectTarget)

        # Compute the local orientation: negate Z of both axes, normalize,
        # then rebuild Y and Z from cross products.
        axisX.z *= -1
        axisZ.z *= -1
        axisX.normalize()
        axisZ.normalize()
        axisY = axisZ.cross(axisX)
        axisZ = axisX.cross(axisY)
        orientation = getRotationFromVectors(axisX, axisY, axisZ,
                                             rotateOrder)

        # Create the joint.
        newJoint = pm.joint(n=jointName, p=worldPos, rad=jointSize,
                            roo=rotateOrder)
        newJoint.setRotation(orientation, "world")
        createdJoints.append(newJoint)

        # Store the original PMX names in extra string attributes.
        for attrName in ("pmxName", "pmxEnglishName"):
            pm.addAttr(newJoint, ln=attrName, dt="string")
        pm.setAttr(newJoint + ".pmxName", jpName)
        pm.setAttr(newJoint + ".pmxEnglishName", enName)

        # Freeze the rotation.
        pm.makeIdentity(newJoint, apply=True, t=False, r=True, s=False)

    return createdJoints
