/*
 * Decompiled with CFR 0.152.
 */
package org.sinytra.adapter.patch.transformer;

import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.IntStream;
import org.objectweb.asm.Handle;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.FieldInsnNode;
import org.objectweb.asm.tree.FieldNode;
import org.objectweb.asm.tree.InsnList;
import org.objectweb.asm.tree.InsnNode;
import org.objectweb.asm.tree.InvokeDynamicInsnNode;
import org.objectweb.asm.tree.LabelNode;
import org.objectweb.asm.tree.LocalVariableNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.VarInsnNode;
import org.sinytra.adapter.patch.analysis.LocalVarAnalyzer;
import org.sinytra.adapter.patch.analysis.LocalVariableLookup;
import org.sinytra.adapter.patch.api.MethodContext;
import org.sinytra.adapter.patch.api.MethodTransform;
import org.sinytra.adapter.patch.api.Patch;
import org.sinytra.adapter.patch.api.PatchContext;
import org.sinytra.adapter.patch.transformer.param.TransformParameters;
import org.sinytra.adapter.patch.util.AdapterUtil;
import org.sinytra.adapter.patch.util.MethodQualifier;
import org.sinytra.adapter.patch.util.OpcodeUtil;

public record ExtractMixin(String targetClass, boolean remove) implements MethodTransform
{
    public ExtractMixin(String targetClass) {
        this(targetClass, true);
    }

    @Override
    public Collection<String> getAcceptedAnnotations() {
        return Set.of("Lcom/llamalad7/mixinextras/injector/WrapWithCondition;", "Lcom/llamalad7/mixinextras/injector/wrapoperation/WrapOperation;", "Lorg/spongepowered/asm/mixin/injection/ModifyConstant;", "Lorg/spongepowered/asm/mixin/injection/ModifyArg;", "Lorg/spongepowered/asm/mixin/injection/Inject;", "Lorg/spongepowered/asm/mixin/injection/Redirect;", "Lorg/spongepowered/asm/mixin/injection/ModifyVariable;", "Lcom/llamalad7/mixinextras/injector/ModifyExpressionValue;");
    }

    @Override
    public Patch.Result apply(ClassNode classNode, MethodNode methodNode, MethodContext methodContext, PatchContext context) {
        boolean isStatic = (methodNode.access & 8) == 8;
        MethodQualifier qualifier = methodContext.getTargetMethodQualifier();
        if (qualifier == null) {
            return Patch.Result.PASS;
        }
        String owner = Objects.requireNonNullElse(qualifier.internalOwnerName(), this.targetClass);
        boolean isInherited = context.environment().inheritanceHandler().isClassInherited(this.targetClass, owner);
        Candidates candidates = ExtractMixin.findCandidates(classNode, methodNode);
        if (!candidates.canMove(classNode, isInherited)) {
            return Patch.Result.PASS;
        }
        ClassNode targetClass = context.environment().dirtyClassLookup().getClass(this.targetClass).orElse(null);
        ClassNode generatedTarget = context.environment().classGenerator().getOrGenerateMixinClass(classNode, this.targetClass, targetClass != null ? targetClass.superName : null);
        context.environment().refmapHolder().copyEntries(classNode.name, generatedTarget.name);
        for (MethodNode method : candidates.methods()) {
            MethodNode transfer;
            if (this.remove) {
                transfer = method;
            } else {
                transfer = new MethodNode(method.access, method.name, method.desc, method.signature, method.exceptions == null ? null : method.exceptions.toArray(new String[0]));
                method.accept((MethodVisitor)transfer);
            }
            generatedTarget.methods.add(transfer);
            ExtractMixin.updateOwnerRefereces(transfer, classNode, this.targetClass);
            if (isStatic || transfer.localVariables == null) continue;
            transfer.localVariables.stream().filter(l -> l.index == 0).findFirst().ifPresent(lvn -> {
                lvn.desc = Type.getObjectType((String)generatedTarget.name).getDescriptor();
            });
        }
        candidates.handleUpdates().forEach(c -> c.accept(generatedTarget));
        Patch.Result result = Patch.Result.PASS;
        if (methodContext.methodAnnotation().getValue("locals").isPresent()) {
            result = result.or(ExtractMixin.recreateLocalVariables(classNode, methodNode, methodContext, context, generatedTarget));
        }
        if (this.remove) {
            context.postApply(() -> classNode.methods.removeAll(candidates.methods));
        }
        return result.or(Patch.Result.APPLY);
    }

    private static Candidates findCandidates(ClassNode classNode, MethodNode methodNode) {
        ArrayList<MethodNode> methods = new ArrayList<MethodNode>();
        ArrayList<Consumer<ClassNode>> handleUpdates = new ArrayList<Consumer<ClassNode>>();
        methods.add(methodNode);
        for (AbstractInsnNode insn : methodNode.instructions) {
            if (!(insn instanceof InvokeDynamicInsnNode)) continue;
            InvokeDynamicInsnNode indy = (InvokeDynamicInsnNode)insn;
            if (indy.bsmArgs.length < 3) continue;
            for (int i = 0; i < indy.bsmArgs.length; ++i) {
                Handle handle;
                Object object = indy.bsmArgs[i];
                if (!(object instanceof Handle) || !(handle = (Handle)object).getOwner().equals(classNode.name) || !handle.getName().startsWith("lambda$" + methodNode.name)) continue;
                int finalI = i;
                classNode.methods.stream().filter(m -> m.name.equals(handle.getName()) && m.desc.equals(handle.getDesc())).findFirst().ifPresent(m -> {
                    methods.add((MethodNode)m);
                    handleUpdates.add(t -> {
                        indy.bsmArgs[finalI] = new Handle(handle.getTag(), t.name, handle.getName(), handle.getDesc(), handle.isInterface());
                    });
                });
            }
        }
        return new Candidates(methods, handleUpdates);
    }

    private static void updateOwnerRefereces(MethodNode methodNode, ClassNode originalClass, String targetClass) {
        for (AbstractInsnNode insn : methodNode.instructions) {
            if (insn instanceof MethodInsnNode) {
                MethodInsnNode minsn = (MethodInsnNode)insn;
                if (minsn.owner.equals(originalClass.name)) {
                    minsn.owner = targetClass;
                    continue;
                }
            }
            if (!(insn instanceof FieldInsnNode)) continue;
            FieldInsnNode finsn = (FieldInsnNode)insn;
            if (!finsn.owner.equals(originalClass.name)) continue;
            finsn.owner = targetClass;
        }
    }

    private static boolean isInheritedField(ClassNode cls, FieldInsnNode finsn, boolean isTargetInherited, List<Runnable> accessUpdates) {
        FieldNode field = cls.fields.stream().filter(f -> f.name.equals(finsn.name)).findFirst().orElse(null);
        if (field != null) {
            if (AdapterUtil.isShadowField(field)) {
                return true;
            }
            if (isTargetInherited) {
                accessUpdates.add(() -> {
                    field.access = ExtractMixin.fixAccess(field.access);
                });
                return true;
            }
        }
        return false;
    }

    private static boolean isInheritedMethod(ClassNode cls, MethodInsnNode minsn, boolean isTargetInherited, List<Runnable> accessUpdates) {
        MethodNode method = cls.methods.stream().filter(m -> m.name.equals(minsn.name) && m.desc.equals(minsn.desc)).findFirst().orElse(null);
        if (method != null) {
            List annotations;
            List list = annotations = method.visibleAnnotations != null ? method.visibleAnnotations : List.of();
            if (AdapterUtil.hasAnnotation(annotations, "Lorg/spongepowered/asm/mixin/Shadow;")) {
                return true;
            }
            if (isTargetInherited) {
                accessUpdates.add(() -> {
                    method.access = ExtractMixin.fixAccess(method.access);
                });
                return true;
            }
        }
        return false;
    }

    private static int fixAccess(int access) {
        int visibility = access & 7;
        if (visibility == 2 || visibility == 0) {
            return access & 0xFFFFFFF8 | 1 | 0x1000;
        }
        return access;
    }

    private static Patch.Result recreateLocalVariables(ClassNode classNode, MethodNode methodNode, MethodContext methodContext, PatchContext context, ClassNode extractClass) {
        AdapterUtil.CapturedLocals capturedLocals = AdapterUtil.getCapturedLocals(methodNode, methodContext);
        if (capturedLocals == null) {
            return Patch.Result.PASS;
        }
        LocalVariableLookup table = capturedLocals.lvt();
        int paramLocalStart = capturedLocals.paramLocalStart();
        LocalVarAnalyzer.CapturedLocalsTransform transform = LocalVarAnalyzer.analyzeCapturedLocals(capturedLocals, methodNode);
        LocalVarAnalyzer.CapturedLocalsUsage usage = transform.getUsage(capturedLocals);
        List<Integer> used = transform.used();
        LocalVariableLookup targetTable = usage.targetTable();
        Int2ObjectMap<InsnList> varInsnLists = usage.varInsnLists();
        Int2IntMap usageCount = usage.usageCount();
        Patch.Result result = transform.remover().apply(classNode, methodNode, methodContext, context);
        if (result == Patch.Result.PASS) {
            return Patch.Result.PASS;
        }
        MethodNode copy = new MethodNode(2 | (capturedLocals.isStatic() ? 8 : 0), "adapter$bridge$" + methodNode.name, methodNode.desc, null, null);
        methodNode.accept((MethodVisitor)copy);
        copy.visibleAnnotations = null;
        copy.invisibleAnnotations = null;
        methodNode.instructions = new InsnList();
        TransformParameters cleanupPatch = TransformParameters.builder().chain(b -> IntStream.range(paramLocalStart, paramLocalStart + used.size()).boxed().sorted(Collections.reverseOrder()).forEach(b::remove)).build();
        Patch.Result cleanupResult = cleanupPatch.apply(classNode, methodNode, methodContext, context);
        if (cleanupResult == Patch.Result.PASS) {
            return Patch.Result.PASS;
        }
        InsnList replacementInsns = new InsnList();
        Type lastVar = Type.getType((String)table.getLast().desc);
        AtomicInteger nextAvailableIndex = new AtomicInteger(methodNode.localVariables.size() - 1 + AdapterUtil.getLVTOffsetForType(lastVar));
        usageCount.forEach((index, count) -> {
            if (count == 1) {
                int usages = 0;
                for (Int2ObjectMap.Entry entry : varInsnLists.int2ObjectEntrySet()) {
                    for (AbstractInsnNode insn : (InsnList)entry.getValue()) {
                        if (!(insn instanceof VarInsnNode)) continue;
                        VarInsnNode varInsn = (VarInsnNode)insn;
                        if (varInsn.var != index) continue;
                        InsnList varInitializers = (InsnList)varInsnLists.get(varInsn.var);
                        ((InsnList)entry.getValue()).insert((AbstractInsnNode)varInsn, varInitializers);
                        ((InsnList)entry.getValue()).remove((AbstractInsnNode)varInsn);
                        ++usages;
                    }
                }
                if (usages > 1) {
                    throw new IllegalStateException("Expected only one reference to variable " + index);
                }
            }
        });
        LabelNode end = new LabelNode();
        HashMap newIndices = new HashMap();
        usageCount.forEach((index, count) -> {
            if (count > 1) {
                LocalVariableNode node = targetTable.getByIndex((int)index);
                Type type = Type.getType((String)node.desc);
                int newIndex = nextAvailableIndex.getAndAdd(AdapterUtil.getLVTOffsetForType(type));
                LabelNode start = new LabelNode();
                methodNode.localVariables.add(new LocalVariableNode(node.name, node.desc, node.signature, start, end, newIndex));
                InsnList insns = (InsnList)varInsnLists.get(index.intValue());
                replacementInsns.add(insns);
                replacementInsns.add((AbstractInsnNode)new VarInsnNode(OpcodeUtil.getStoreOpcode(type.getSort()), newIndex));
                replacementInsns.add((AbstractInsnNode)start);
                varInsnLists.remove(index);
                varInsnLists.forEach((varIndex, varInsns) -> {
                    for (AbstractInsnNode insn : varInsns) {
                        if (!(insn instanceof VarInsnNode)) continue;
                        VarInsnNode varInsn = (VarInsnNode)insn;
                        if (varInsn.var != index) continue;
                        varInsn.var = newIndex;
                    }
                });
                newIndices.put(index, newIndex);
            }
        });
        for (int i = 0; i < paramLocalStart + 1; ++i) {
            Type type = Type.getType((String)table.getByIndex((int)i).desc);
            int opcode = OpcodeUtil.getLoadOpcode(type.getSort());
            replacementInsns.add((AbstractInsnNode)new VarInsnNode(opcode, i));
        }
        used.forEach(ordinal -> {
            LocalVariableNode node = targetTable.getByOrdinal((int)ordinal);
            InsnList insns = (InsnList)varInsnLists.get(node.index);
            if (insns != null) {
                replacementInsns.add(insns);
            } else {
                Type type = Type.getType((String)node.desc);
                int newIndex = newIndices.getOrDefault(node.index, -1);
                if (newIndex == -1) {
                    throw new IllegalArgumentException("Missing new index for var " + node.index);
                }
                replacementInsns.add((AbstractInsnNode)new VarInsnNode(OpcodeUtil.getLoadOpcode(type.getSort()), newIndex));
            }
        });
        replacementInsns.add((AbstractInsnNode)new MethodInsnNode(capturedLocals.isStatic() ? 184 : 182, extractClass.name, copy.name, copy.desc));
        replacementInsns.add((AbstractInsnNode)new LabelNode());
        replacementInsns.add((AbstractInsnNode)new InsnNode(OpcodeUtil.getReturnOpcode(methodNode)));
        replacementInsns.add((AbstractInsnNode)end);
        methodNode.instructions = replacementInsns;
        extractClass.methods.add(copy);
        return Patch.Result.COMPUTE_FRAMES;
    }

    record Candidates(List<MethodNode> methods, List<Consumer<ClassNode>> handleUpdates) {
        public boolean canMove(ClassNode classNode, boolean isInherited) {
            ArrayList<Runnable> accessFixes = new ArrayList<Runnable>();
            for (MethodNode methodNode : this.methods) {
                for (AbstractInsnNode insn : methodNode.instructions) {
                    block5: {
                        block4: {
                            if (!(insn instanceof FieldInsnNode)) break block4;
                            FieldInsnNode finsn = (FieldInsnNode)insn;
                            if (finsn.owner.equals(classNode.name) && !ExtractMixin.isInheritedField(classNode, finsn, isInherited, accessFixes)) break block5;
                        }
                        if (!(insn instanceof MethodInsnNode)) continue;
                        MethodInsnNode minsn = (MethodInsnNode)insn;
                        if (!minsn.owner.equals(classNode.name) || ExtractMixin.isInheritedMethod(classNode, minsn, isInherited, accessFixes)) continue;
                    }
                    return false;
                }
            }
            accessFixes.forEach(Runnable::run);
            return true;
        }
    }
}

