package dev.kosmx.playerAnim.api.layered.modifier;

import dev.kosmx.playerAnim.api.PartKey;
import dev.kosmx.playerAnim.api.TransformType;
import dev.kosmx.playerAnim.api.layered.IAnimation;
import dev.kosmx.playerAnim.api.layered.KeyframeAnimationPlayer;
import dev.kosmx.playerAnim.core.util.Vec3f;
import org.jetbrains.annotations.NotNull;

import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;

/**
 * Adjusts body parts during animations.<br>
 * Place this modifier directly over the KeyframeAnimationPlayer, as the first modifier in the animation stack.
 * Fading in and out is only applied when the wrapped animation is a KeyframeAnimationPlayer.
 * <p>
 * The source function receives the {@link PartKey} of each body part. For that part, it returns either an additive
 * adjustment or {@code Optional.empty()} to leave the part untouched.
 * <p>
 * Example use (adjusting the vertical angle of a custom attack animation):
 * <pre>
 * {@code
 * new AdjustmentModifier((partKey) -> {
 *     float rotationX = 0;
 *     float rotationY = 0;
 *     float rotationZ = 0;
 *     float scaleX = 0; // adjustments are added, so 0 leaves the scale unchanged
 *     float scaleY = 0;
 *     float scaleZ = 0;
 *     float offsetX = 0;
 *     float offsetY = 0;
 *     float offsetZ = 0;
 *
 *     var pitch = player.getPitch() / 2F;
 *     pitch = (float) Math.toRadians(pitch);
 *     if (partKey.equals(PartKey.keyForId("body"))) {
 *         rotationX = (-1F) * pitch;
 *     } else if (partKey.equals(PartKey.keyForId("rightArm")) || partKey.equals(PartKey.keyForId("leftArm"))) {
 *         rotationX = pitch;
 *     } else {
 *         return Optional.empty();
 *     }
 *
 *     return Optional.of(new AdjustmentModifier.PartModifier(
 *             new Vec3f(rotationX, rotationY, rotationZ),
 *             new Vec3f(scaleX, scaleY, scaleZ),
 *             new Vec3f(offsetX, offsetY, offsetZ))
 *     );
 * });
 * }
 * </pre>
 */
public class AdjustmentModifier extends AbstractModifier {
    /**
     * Additive adjustment for a single body part.
     * <p>
     * When applied, each component is multiplied by the current fade factor. It is then added to the matching
     * transform value (position, rotation or scale) produced by the wrapped animation.
     */
    public static final class PartModifier {
        private final Vec3f rotation;
        private final Vec3f scale;
        private final Vec3f offset;

        /**
         * Creates a modifier that adjusts rotation and position only.
         * The scale adjustment is {@link Vec3f#ZERO}, so nothing is added to the part's scale.
         *
         * @param rotation rotation added to the part
         * @param offset   offset added to the part's position
         */
        public PartModifier(
                Vec3f rotation,
                Vec3f offset
        ) {
            this(rotation, Vec3f.ZERO, offset);
        }

        /**
         * Creates a modifier that adjusts rotation, scale and position.
         *
         * @param rotation rotation added to the part
         * @param scale    value added to the part's scale
         * @param offset   offset added to the part's position
         */
        public PartModifier(
                Vec3f rotation,
                Vec3f scale,
                Vec3f offset
        ) {
            this.rotation = rotation;
            this.scale = scale;
            this.offset = offset;
        }

        /** @return the rotation added to the part */
        public Vec3f rotation() {
            return rotation;
        }

        /** @return the value added to the part's scale */
        public Vec3f scale() {
            return scale;
        }

        /** @return the offset added to the part's position */
        public Vec3f offset() {
            return offset;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) return true;
            if (obj == null || obj.getClass() != this.getClass()) return false;
            PartModifier that = (PartModifier) obj;
            return Objects.equals(this.rotation, that.rotation) &&
                    Objects.equals(this.scale, that.scale) &&
                    Objects.equals(this.offset, that.offset);
        }

        @Override
        public int hashCode() {
            return Objects.hash(rotation, scale, offset);
        }

        @Override
        public String toString() {
            return "PartModifier[" +
                    "rotation=" + rotation + ", " +
                    "scale=" + scale + ", " +
                    "offset=" + offset + ']';
        }
    }

    /**
     * Whether the adjustment should be increasingly applied
     * between animation.start and animation.begin.
     */
    public boolean fadeIn = true;
    /**
     * Whether the adjustment should be decreasingly applied
     * between animation.end and animation.stop.
     */
    public boolean fadeOut = true;
    /** Whether the adjustment should be applied at all. */
    public boolean enabled = true;

    /** Supplies the adjustment for a body part, or {@code Optional.empty()} to leave that part untouched. */
    protected Function<PartKey, Optional<PartModifier>> source;

    /**
     * @param source supplies the adjustment for each body part; it is queried every time a transform is requested
     */
    public AdjustmentModifier(Function<PartKey, Optional<PartModifier>> source) {
        this.source = source;
    }

    /**
     * Computes the fade-in factor for the current frame.
     *
     * @return a factor in {@code [0, 1]}. It ramps linearly from 0 at tick 0 to 1 at the animation's
     *         {@code beginTick}. It is always 1 in three cases: {@link #fadeIn} is disabled, the wrapped
     *         animation is not a {@link KeyframeAnimationPlayer}, or the animation has no begin phase.
     */
    protected float getFadeIn() {
        IAnimation animation = this.getAnim();
        if (this.fadeIn && animation instanceof KeyframeAnimationPlayer player) {
            float beginTick = (float) player.getData().beginTick;
            // Without a begin phase, the division would yield Infinity or NaN (0/0); treat it as fully faded in.
            if (beginTick <= 0) {
                return 1F;
            }
            float currentTick = player.getTick() + player.getTickDelta();
            return clampUnit(currentTick / beginTick);
        }
        return 1F;
    }

    /**
     * Advances the explicit fade-out countdown started by {@link #fadeOut(int)} by one tick.
     * When the countdown runs out, the explicit fade-out is cleared.
     */
    @Override
    public void tick() {
        super.tick();

        if (remainingFadeout > 0) {
            remainingFadeout -= 1;
            if (remainingFadeout <= 0) {
                instructedFadeout = 0;
            }
        }
    }

    /** Length in ticks of the explicit fade-out requested via {@link #fadeOut(int)}; 0 when none is active. */
    protected int instructedFadeout = 0;
    /** Ticks left in the explicit fade-out; one larger than the requested length when it starts. */
    private int remainingFadeout = 0;

    /**
     * Starts an explicit fade-out. While it is active, it takes precedence over the keyframe-based fade-out.
     * <p>
     * NOTE(review): once the countdown ends, {@link #getFadeOut(float)} falls back to the keyframe-based factor,
     * so the adjustment is applied at full strength again unless the caller disables or removes this modifier.
     *
     * @param fadeOut fade-out length in ticks; values {@code <= 0} have no effect
     */
    public void fadeOut(int fadeOut) {
        instructedFadeout = fadeOut;
        remainingFadeout = fadeOut + 1;
    }

    /**
     * Computes the fade-out factor for the current frame.
     *
     * @param delta partial tick progress, used to interpolate an explicit fade-out between ticks
     * @return a factor in {@code [0, 1]}:
     *         <ul>
     *         <li>during an explicit fade-out, the fraction of it that remains;</li>
     *         <li>otherwise, when {@link #fadeOut} is enabled and the wrapped animation is a
     *             {@link KeyframeAnimationPlayer}, a linear ramp from 1 at {@code endTick} to 0 at
     *             {@code stopTick};</li>
     *         <li>otherwise 1.</li>
     *         </ul>
     */
    protected float getFadeOut(float delta) {
        if (remainingFadeout > 0 && instructedFadeout > 0) {
            float current = Math.max(remainingFadeout - delta, 0);
            return clampUnit(current / ((float) instructedFadeout));
        }
        IAnimation animation = this.getAnim();
        if (this.fadeOut && animation instanceof KeyframeAnimationPlayer player) {
            float currentTick = player.getTick() + player.getTickDelta();
            float ticksUntilStop = player.getData().stopTick - currentTick;
            float length = player.getData().stopTick - player.getData().endTick;
            if (length > 0) {
                // Above 1 before endTick and negative after stopTick; clamp so the adjustment is never inverted.
                return clampUnit(ticksUntilStop / length);
            }
        }
        return 1F;
    }

    /**
     * @deprecated use {@link #get3DTransform(PartKey, TransformType, float, Vec3f)}; this overload only converts
     *             {@code modelName} with {@link PartKey#keyForId(String)} and delegates to it.
     */
    @Override
    @Deprecated(forRemoval = true)
    public @NotNull Vec3f get3DTransform(@NotNull String modelName, @NotNull TransformType type, float tickDelta, @NotNull Vec3f value0) {
        return get3DTransform(PartKey.keyForId(modelName), type, tickDelta, value0);
    }

    /**
     * Returns the wrapped animation's transform for the part. If this modifier is enabled and {@link #source}
     * supplies an adjustment for the part, that adjustment, scaled by the current fade, is added on top.
     */
    @Override
    public @NotNull Vec3f get3DTransform(@NotNull PartKey partKey, @NotNull TransformType type, float tickDelta, @NotNull Vec3f value0) {
        if (!enabled) {
            return super.get3DTransform(partKey, type, tickDelta, value0);
        }

        Optional<PartModifier> partModifier = source.apply(partKey);
        if (!partModifier.isPresent()) {
            return super.get3DTransform(partKey, type, tickDelta, value0);
        }

        // Only compute the fade when there is something to fade.
        float fade = getFadeIn() * getFadeOut(tickDelta);
        Vec3f animated = super.get3DTransform(partKey, type, tickDelta, value0);
        return transformVector(animated, type, partModifier.get(), fade);
    }

    /**
     * Adds the component of {@code partModifier} that matches {@code type}, scaled by {@code fade}, to {@code vector}.
     * {@code BEND} and any other transform type are returned unchanged.
     *
     * @param vector       the transform value produced by the wrapped animation
     * @param type         which transform is being computed
     * @param partModifier the adjustment for this part
     * @param fade         factor applied to the adjustment before it is added
     * @return the adjusted transform value
     */
    protected Vec3f transformVector(Vec3f vector, TransformType type, PartModifier partModifier, float fade) {
        switch (type) {
            case POSITION:
                return vector.add(partModifier.offset().scale(fade));
            case ROTATION:
                return vector.add(partModifier.rotation().scale(fade));
            case SCALE:
                return vector.add(partModifier.scale().scale(fade));
            case BEND:
                break;
        }
        return vector;
    }

    /** Clamps {@code value} to {@code [0, 1]}; NaN is passed through unchanged. */
    private static float clampUnit(float value) {
        return Math.max(0F, Math.min(value, 1F));
    }
}