# -*- coding: shift-jis -*-

import pymel.core as pm
import pymel.core.nodetypes as nt


def getRenderSets(mesh, expandMesh=True):
    u"""Return the render sets (shading groups) a mesh is assigned to.

    Only members that belong to ``mesh`` are reported for each set.

    :type mesh: Mesh or Transform
    :param mesh: The mesh shape, or a transform whose shape is used.
    :param bool expandMesh: If True, a whole-mesh membership is reported as
        the mesh's face component (``mesh.faces``) rather than the Mesh node.
    :rtype: dict
    :returns: ShadingEngine: list of Component -- for each set containing
        the mesh, the set members that belong to ``mesh``.
    """
    # A transform is resolved to its shape node.
    if isinstance(mesh, nt.Transform):
        mesh = mesh.getShape()

    # Rendering sets (type=1), i.e. shading groups, that contain the mesh.
    sgNodes = pm.listSets(type=1, object=mesh)

    result = {}
    for sgNode in sgNodes:
        owned = result.setdefault(sgNode, [])
        for member in sgNode.members():
            if isinstance(member, nt.Mesh):
                # Whole-object membership: only keep it if it is our mesh.
                if mesh == member:
                    owned.append(member.faces if expandMesh else member)
            elif mesh == member.node():
                # Component membership (e.g. a face range) on our mesh.
                owned.append(member)
    return result


def getShadingMap(mesh):
    u"""Look up which shading group each face of a mesh is assigned to.

    If a face is found in more than one set, the set that comes later in
    the returned list wins.

    :type mesh: pymel.core.nodetypes.Mesh or Transform
    :param mesh: The mesh shape, or a transform whose shape is used.
    :returns: (list of SG nodes, per-face index into that list, where -1
        marks a face that no listed shading group contains)
    :rtype: (list of ShadingEngine, list of int)
    """
    # A transform is resolved to its shape node.
    if isinstance(mesh, nt.Transform):
        mesh = mesh.getShape()

    # Rendering sets (type=1), i.e. shading groups, that contain the mesh.
    sgNodes = pm.listSets(type=1, object=mesh)

    faceCount = mesh.numFaces()
    sgMap = [-1] * faceCount
    for sgIndex, sgNode in enumerate(sgNodes):
        for member in sgNode.members():
            # A whole-mesh member stands for every face of that mesh.
            faces = member.faces if isinstance(member, nt.Mesh) else member
            if faces.node() == mesh:
                for faceId in faces.indices():
                    sgMap[faceId] = sgIndex
    return sgNodes, sgMap


def getShadingGroups(mesh):
    u"""Return the mesh's face IDs grouped by the shading group they use.

    Every face ID appears in exactly one group. Faces with no shading group
    are collected under a trailing ``None`` entry, which is only present
    when at least one such face exists. A shading group that owns none of
    this mesh's faces still gets an (empty) list.

    :type mesh: pymel.core.nodetypes.Mesh or Transform
    :returns: (list of SG nodes, list of face-ID lists parallel to it)
    :rtype: (list of ShadingEngine and None, list of list of int)
    """
    sgNodes, sgMap = getShadingMap(mesh)
    groupCount = len(sgNodes)

    # One bucket per SG plus a final bucket for unassigned faces (index -1).
    buckets = [[] for _ in range(groupCount + 1)]
    for faceId, sgIndex in enumerate(sgMap):
        slot = groupCount if sgIndex == -1 else sgIndex
        buckets[slot].append(faceId)

    if buckets[-1]:
        # Pair the unassigned-face bucket with a None "shading group".
        sgNodes.append(None)
    else:
        buckets.pop()
    return sgNodes, buckets
