import maya.api.OpenMaya as api
import maya.cmds as cmds
import maya.OpenMayaUI as omui

from qt import toQtObject, QtWidgets

# Names of the Maya optionVars that store this tool's preferences.
# WIREFRAME_OPACITY_NAME holds the wireframe opacity as a 0-100 value
# (see get_opacity/set_opacity); SYMMETRY_ENABLED_NAME holds the symmetry
# toggle (see get_symmetry/set_symmetry).
WIREFRAME_OPACITY_NAME = "muWireframeOpacity"
SYMMETRY_ENABLED_NAME  = "muSymmetryEnabled"

def set_reference_mesh():
    """Register the single selected mesh as the plugin's reference mesh.

    Does nothing when the selection does not resolve to exactly one mesh.
    """
    selected_mesh = getSingleSelectedMesh()
    if not selected_mesh:
        return
    cmds.muSetReferenceMesh(me=[selected_mesh])

def mesh_to_layer(deformer):
    """Bake the difference between two selected meshes into a deformer layer.

    Requires exactly two non-intermediate mesh shapes in the selection, and
    exactly one layer of *deformer* with its ``selected`` flag on. The first
    mesh is passed as the base mesh, the second as the offset mesh.

    NOTE(review): with ``dag=1`` the order returned by ``cmds.ls`` may not
    match click order; confirm which mesh ends up as base vs offset.

    Args:
        deformer: Name of a ``muCorrectiveShape`` node.

    Raises:
        RuntimeError: reported through ``cmds.error`` when the deformer,
            the mesh selection or the layer selection is invalid.
    """
    if not is_deformer_valid(deformer):
        cmds.error("invalid deformer")
        return

    # Guard against a None result so len() below cannot raise TypeError.
    sel = cmds.ls(sl=1, dag=1, type="mesh", noIntermediate=True) or []
    if len(sel) != 2:
        cmds.error("Select base mesh, offset mesh and layer.")
        return

    selected_layers = get_deformer_selected_layers(deformer)
    if len(selected_layers) != 1:
        cmds.error("Select base mesh, offset mesh and layer.")
        return

    base_mesh, offset_mesh = sel
    cmds.muMeshToLayer(d=deformer, li=selected_layers[0], bm=base_mesh, om=offset_mesh)

def get_deformer_selected_layers(deformer):
    """Return the logical indices of *deformer*'s layers whose ``selected`` flag is on.

    Args:
        deformer: Name of a ``muCorrectiveShape`` node.

    Returns:
        A list of layer indices, in the order reported by ``getAttr``. The
        list is empty when the deformer has no layers.
    """
    # getAttr(..., mi=True) yields None for a multi with no elements;
    # normalise it so iteration is safe.
    indices = cmds.getAttr(deformer + ".layer", mi=1) or []
    return [
        idx for idx in indices
        if cmds.getAttr("{}.layer[{}].selected".format(deformer, idx))
    ]

def get_opacity():
    """Return the stored wireframe opacity (0-100 scale).

    If the optionVar does not exist yet, it is initialised to 50 via
    ``set_opacity`` and 50 is returned.
    """
    default_opacity = 50
    if not cmds.optionVar(exists=WIREFRAME_OPACITY_NAME):
        set_opacity(default_opacity)
        return default_opacity
    return cmds.optionVar(q=WIREFRAME_OPACITY_NAME)

def set_opacity(value):
    """Store *value* (0-100) and apply it to every ``muMesh`` node.

    Each ``muMesh.wireframeOpacity`` receives the value rescaled to 0-1.
    """
    cmds.optionVar(fv=(WIREFRAME_OPACITY_NAME, value))
    normalized = float(value) / 100.0
    for shape in cmds.ls(type="muMesh"):
        cmds.setAttr(shape + ".wireframeOpacity", normalized)

def get_symmetry():
    """Return whether symmetry is enabled.

    If the optionVar does not exist yet, it is initialised to False via
    ``set_symmetry`` and False is returned.
    """
    if not cmds.optionVar(exists=SYMMETRY_ENABLED_NAME):
        set_symmetry(False)
        return False
    return bool(cmds.optionVar(q=SYMMETRY_ENABLED_NAME))

def set_symmetry(value):
    """Persist the symmetry toggle and push it to an active mush3d context.

    Args:
        value: Flag converted with ``int()``; stored in the
            ``SYMMETRY_ENABLED_NAME`` optionVar.
    """
    enabled = int(value)
    # The toggle is an integer, so store it with iv rather than fv (float).
    cmds.optionVar(iv=(SYMMETRY_ENABLED_NAME, enabled))

    # Only update the tool when the current context is a mush3d context.
    ctx = cmds.currentCtx()
    if "mush3dCtxCmd" in ctx:
        cmds.mush3dCtxCmd(ctx, edit=True, symmetry=enabled)

def is_deformer_valid(name):
    """Return True when *name* is non-empty, exists, and is a ``muCorrectiveShape`` node."""
    return (
        bool(name)
        and cmds.objExists(name)
        and cmds.nodeType(name) == "muCorrectiveShape"
    )

def get_mesh_from_deformer(deformer):
    """Return the first geometry driven by a ``muCorrectiveShape`` *deformer*.

    Returns None for an empty name, a missing node, a node of another type,
    or a deformer with no geometry.
    """
    is_corrective = (
        deformer
        and cmds.objExists(deformer)
        and cmds.ls(deformer, type="muCorrectiveShape")
    )
    if not is_corrective:
        return None
    geometry = cmds.deformer(deformer, q=1, g=1) or []
    return geometry[0] if geometry else None

def update_deformer_layer_selection(deformer, selected_indices):
    """Sync every layer's ``selected`` flag with *selected_indices*.

    Layers whose index appears in *selected_indices* get 1; all others get 0.
    A deformer without layers is left untouched.
    """
    for layer_idx in cmds.getAttr(deformer + ".layer", mi=1) or []:
        flag = int(layer_idx in selected_indices)
        cmds.setAttr("{}.layer[{}].selected".format(deformer, layer_idx), flag)

def get_selected_meshes():
    """Return the non-intermediate mesh shapes under the current selection (may be empty)."""
    meshes = cmds.ls(sl=1, dag=1, type="mesh", noIntermediate=True)
    return meshes or []

def getSingleSelectedMesh():
    """Return the name of the one selected non-intermediate mesh shape.

    When the selection resolves to zero or several meshes, an empty list
    (not None) is returned. Callers in this module only test its truthiness.
    """
    candidates = cmds.ls(sl=1, dag=1, type="mesh", noIntermediate=True)
    if len(candidates) == 1:
        return candidates[0]
    return []

def get_selected_layers_indices(deformer):
    """Return the indices of *deformer*'s layers whose ``selected`` flag is on.

    Returns an empty list when the deformer has no layers.
    """
    layer_plug = deformer + ".layer"
    all_indices = cmds.getAttr(layer_plug, mi=1) or []
    return [
        layer_idx
        for layer_idx in all_indices
        if cmds.getAttr("{}[{}].selected".format(layer_plug, layer_idx))
    ]

def createDeformer():
    """Create a ``muCorrectiveShape`` on the selected meshes and record bind points.

    The user's selection is restored afterwards, even if creating the
    deformer raises.

    Returns:
        The new deformer's name, or None (after a warning) when no mesh is
        selected.
    """
    meshes = get_selected_meshes()
    if not meshes:
        cmds.warning("No Mesh Selected.")
        return None

    previous_selection = cmds.ls(sl=1) or []
    cmds.select(meshes, replace=True)
    try:
        deformer = cmds.deformer(type="muCorrectiveShape")[0]
    finally:
        # Put the user's selection back whether or not creation succeeded.
        if previous_selection:
            cmds.select(previous_selection, replace=True)
        else:
            cmds.select(clear=True)

    # set reference points for cloning and wrinkle
    cmds.muSetReferenceMesh(me=meshes)

    # Store each mesh's current points as a flat [x0, y0, z0, x1, ...] float
    # list in bindPoints[i].
    # NOTE(review): assumes the deformer's geometry indices are 0..n-1 in the
    # same order as `meshes` — confirm.
    for i, mesh in enumerate(meshes):
        dag = api.MSelectionList().add(mesh).getDagPath(0)
        points = api.MFnMesh(dag).getPoints()
        flat = [float(c) for p in points for c in (p.x, p.y, p.z)]
        cmds.setAttr("{}.bindPoints[{}]".format(deformer, i), flat, type='floatArray')

    return deformer

def addLayer(deformer):
    """Append a new layer to *deformer*, seeded from each mesh's current points.

    For every deformed geometry, ``layer[n].perGeo[geometryIndex].basePoints``
    and ``.offsetPoints`` are both set to that mesh's points, flattened as
    ``[x0, y0, z0, x1, ...]``. ``n`` is one past the highest existing layer
    index, or 0 when there are no layers.

    Returns:
        The logical index of the new layer.

    Raises:
        RuntimeError: reported through ``cmds.error`` when the geometry and
            geometry-index queries disagree in length.
    """
    geometry = cmds.deformer(deformer, q=1, g=1) or []
    geo_indices = cmds.deformer(deformer, q=1, gi=1) or []

    if len(geometry) != len(geo_indices):
        cmds.error("Invalid deformer input connections.")
        return

    existing = cmds.getAttr("{}.layer".format(deformer), mi=1) or []
    new_index = (max(existing) + 1) if existing else 0
    layer_plug = "{}.layer[{}]".format(deformer, new_index)

    for mesh_name, geo_idx in zip(geometry, geo_indices):
        dag_path = api.MSelectionList().add(mesh_name).getDagPath(0)
        coords = []
        for pt in api.MFnMesh(dag_path).getPoints():
            coords.extend((float(pt.x), float(pt.y), float(pt.z)))

        per_geo = "{}.perGeo[{}]".format(layer_plug, geo_idx)
        cmds.setAttr(per_geo + ".basePoints", coords, type='floatArray')
        cmds.setAttr(per_geo + ".offsetPoints", coords, type='floatArray')

    return new_index

def delete_layers(deformer, indices_to_delete):
    """Remove the given ``layer[]`` elements from *deformer*.

    Each element is removed with ``removeMultiInstance(..., b=True)``, which
    breaks its connections first.
    """
    plugs = ["{}.layer[{}]".format(deformer, layer_idx) for layer_idx in indices_to_delete]
    for plug in plugs:
        cmds.removeMultiInstance(plug, b=True)

def ensurePluginLoaded():
    """Make sure the ``mush3d.mll`` plugin is loaded.

    Returns:
        True if the plugin was already loaded or loaded successfully. False
        if loading raised, in which case a warning is issued.
    """
    # NOTE(review): ".mll" is the Windows plugin extension; confirm whether
    # other platforms need a different name.
    plugin = "mush3d.mll"
    if cmds.pluginInfo(plugin, query=True, loaded=True):
        return True
    try:
        cmds.loadPlugin(plugin)
    except Exception as err:
        cmds.warning("Failed to load mush3d.mll plugin: {}".format(err))
        return False
    return True

def connect_tweak(deformer, layer_idx):
    """Connect a layer's per-geometry ``tweakLocation[0]`` to each deformed mesh.

    The ``perGeo`` element is addressed by the deformer's geometry index
    (``deformer -q -gi``), the same index ``addLayer`` uses when writing
    per-geometry data. It is not addressed by position in the geometry list;
    the two differ when geometry indices are not contiguous from 0.

    Args:
        deformer: Name of a ``muCorrectiveShape`` node.
        layer_idx: Logical index of the layer whose tweak outputs to connect.

    Raises:
        RuntimeError: reported through ``cmds.error`` when the geometry and
            geometry-index queries disagree in length.
    """
    meshes = cmds.deformer(deformer, q=True, g=True) or []
    geo_indices = cmds.deformer(deformer, q=True, gi=True) or []
    if len(meshes) != len(geo_indices):
        cmds.error("Invalid deformer input connections.")
        return

    for mesh, geo_idx in zip(meshes, geo_indices):
        src = "{}.layer[{}].perGeo[{}].tweakLocation[0]".format(deformer, layer_idx, geo_idx)
        dst = "{}.tweakLocation".format(mesh)

        # Skip the connect when this exact source already drives the destination.
        existing = cmds.listConnections(dst, s=True, d=False, plugs=True) or []
        if src not in existing:
            cmds.connectAttr(src, dst, force=True)

def create_attrFieldSliderGrp(attr, parent=None):
    """Build a Maya ``attrFieldSliderGrp`` bound to *attr* and return it as a Qt widget.

    The control is created inside a temporary window and converted with
    ``toQtObject``. When *parent* is a ``QtWidgets.QWidget``, the control is
    re-parented to it. QLabel children are hidden and QSlider children get a
    groove stylesheet. The temporary window is then deleted.

    Args:
        attr: Attribute plug the control edits.
        parent: Optional Qt widget that takes ownership of the control.

    Returns:
        The Qt object produced by ``toQtObject`` for the control.
    """
    # NOTE(review): when *parent* is None the control still belongs to
    # dummy_win, so cmds.deleteUI(dummy_win) below likely destroys it and the
    # returned object may be dangling — confirm callers always pass a parent.
    dummy_win = cmds.window()
    cmds.columnLayout()
    # NOTE(review): cw3 is the short form of columnWidth3, so the same flag is
    # passed twice below — confirm Maya accepts this.
    slider = cmds.attrFieldSliderGrp(
        attribute=attr,
        precision=2,
        columnWidth3=[1, 40, 60],  # label, field, slider width
        cw3=[1, 40, 60],           # Maya is picky, cw3 alias helps
    )

    # Resolve the Maya control name to its underlying Qt widget.
    ptr = omui.MQtUtil.findControl(slider)
    qt_widget = toQtObject(ptr)

    # Re-parent before the temporary window is deleted.
    if parent and isinstance(parent, QtWidgets.QWidget):
        qt_widget.setParent(parent)

    # Try to hide the label: every QLabel child is hidden (the label is
    # usually the first child).
    children = qt_widget.children()
    if children:
        for child in children:
            if isinstance(child, QtWidgets.QLabel):
                child.setVisible(False)
            # NOTE(review): the selector targets a vertical groove; confirm
            # the slider in this control is actually vertical.
            if isinstance(child, QtWidgets.QSlider):
                child.setStyleSheet("QSlider::groove:vertical { background: #2b2b2b; }")

    cmds.deleteUI(dummy_win)

    return qt_widget

def create_attrControlGrp(attr, parent=None):
    """Build a Maya ``attrControlGrp`` bound to *attr* and return it as a Qt widget.

    The control is created inside a temporary window and converted with
    ``toQtObject``. When *parent* is a ``QtWidgets.QWidget``, it is
    re-parented there before the temporary window is deleted.

    Args:
        attr: Attribute plug the control edits.
        parent: Optional Qt widget that takes ownership of the control.

    Returns:
        The Qt object produced by ``toQtObject`` for the control.
    """
    # NOTE(review): when *parent* is None the control still belongs to
    # dummy_win, so deleting dummy_win below likely destroys it — confirm
    # callers always pass a parent.
    dummy_win = cmds.window()
    cmds.columnLayout()
    slider = cmds.attrControlGrp(attribute=attr)

    # Resolve the Maya control name to its underlying Qt widget.
    ptr = omui.MQtUtil.findControl(slider)
    qt_widget = toQtObject(ptr)

    # Re-parent before the temporary window is deleted.
    if parent and isinstance(parent, QtWidgets.QWidget):
        qt_widget.setParent(parent)

    cmds.deleteUI(dummy_win)
    return qt_widget

def get_muCorrectiveShapes_from_mesh(mesh):
    """Return the ``muCorrectiveShape`` nodes in *mesh*'s history, in history order.

    DAG objects are pruned from the history query.
    """
    found = []
    for node in cmds.listHistory(mesh, pruneDagObjects=True) or []:
        if cmds.nodeType(node) == "muCorrectiveShape":
            found.append(node)
    return found

def listMuDeformers():
    """Return the unique ``muCorrectiveShape`` nodes affecting the selected meshes.

    Nodes are returned in first-seen order across the selected meshes.
    """
    unique = []
    seen = set()
    for mesh in get_selected_meshes():
        for node in get_muCorrectiveShapes_from_mesh(mesh):
            if node in seen:
                continue
            seen.add(node)
            unique.append(node)
    return unique

def _vertex_points(flat_points, vertex_ids):
    """Return ``[x, y, z]`` triples for *vertex_ids* from a flat ``[x0, y0, z0, ...]`` list."""
    return [
        [flat_points[v * 3], flat_points[v * 3 + 1], flat_points[v * 3 + 2]]
        for v in vertex_ids
    ]


def solve_selected_layers(deformer_name):
    """Create a ``muPointSolver`` that drives the intensity of the selected layers.

    Steps:
      1. Query the area vertex ids with ``muGetAreaVertices``.
      2. Find the plug feeding the deformer's ``input``.
      3. Build feature 0 from ``bindPoints[0]``. Build features 1..n from
         ``perGeo[0].basePoints`` of each selected layer. Every feature is
         restricted to the area vertices.
      4. Create the solver, wire its mesh input, vertex ids and features.
      5. Connect ``output[i + 1]`` to the ``intensity`` of the i-th selected
         layer.

    All attribute reads happen before the solver is created, so a failed
    read does not leave an orphan ``muPointSolver`` node behind.

    Args:
        deformer_name: Name of a ``muCorrectiveShape`` node.

    Raises:
        RuntimeError: reported through ``cmds.error`` when no area vertices
            or no input mesh can be found.
    """
    # get area vertices
    verts = cmds.muGetAreaVertices(d=deformer_name)
    if not verts:
        cmds.error("can't get vertices from selected layers.")
        return

    # Get the mesh plug before the deformer. listConnections returns None
    # when nothing is connected, so check before indexing.
    connections = cmds.listConnections(
        deformer_name + ".input",
        shapes=True, source=True, destination=False, plugs=True,
    ) or []
    if not connections:
        cmds.error("can't get input mesh for solver.")
        return
    input_attr = connections[0]

    # Feature 0: bind points of the first deformed mesh.
    features = [_vertex_points(cmds.getAttr(deformer_name + ".bindPoints[0]"), verts)]

    # Features 1..n: base points of each selected layer.
    # NOTE(review): only perGeo[0] is read; confirm single-geometry use.
    selected_layers_indices = []
    for layer_idx in cmds.getAttr(deformer_name + ".layer", mi=1) or []:
        layer_attr = "{}.layer[{}]".format(deformer_name, layer_idx)
        if not cmds.getAttr(layer_attr + ".selected"):
            continue
        selected_layers_indices.append(layer_idx)
        base_points = cmds.getAttr(layer_attr + ".perGeo[0].basePoints")
        features.append(_vertex_points(base_points, verts))

    # create solver and wire its inputs
    solver = cmds.createNode("muPointSolver")
    cmds.connectAttr(input_attr, solver + ".inMesh")
    cmds.setAttr(solver + ".inVertexIds", verts, type="Int32Array")

    # Assign features to the solver.
    # NOTE(review): Maya's pointArray is documented with (x, y, z, w) entries;
    # these are 3-component — confirm the solver reads them as intended.
    for i, feature in enumerate(features):
        cmds.setAttr(solver + ".features[{}]".format(i), len(feature), *feature, type="pointArray")

    # output[0] is left unconnected (presumably it pairs with the bind
    # feature); output[i + 1] drives the i-th selected layer.
    for i, layer_idx in enumerate(selected_layers_indices):
        cmds.connectAttr(
            solver + ".output[{}]".format(i + 1),
            "{}.layer[{}].intensity".format(deformer_name, layer_idx),
            f=1,
        )

def extract_vector_map():
    """Run ``muExtractVectorDisplacement`` on the single selected mesh, if there is one."""
    target = getSingleSelectedMesh()
    if not target:
        return
    cmds.muExtractVectorDisplacement(m=target)