/*
 * Decompiled with CFR 0.152.
 */
package com.hxzhitang.tongdarailway.railway.planner;

import com.hxzhitang.tongdarailway.Tongdarailway;
import com.hxzhitang.tongdarailway.railway.RailwayBuilder;
import com.hxzhitang.tongdarailway.railway.RegionPos;
import com.hxzhitang.tongdarailway.railway.planner.StationPlanner;
import com.hxzhitang.tongdarailway.structure.TrackPutInfo;
import com.hxzhitang.tongdarailway.util.AStarPathfinder;
import com.hxzhitang.tongdarailway.util.AdaptiveHeightSampler;
import com.hxzhitang.tongdarailway.util.CurveRoute;
import com.hxzhitang.tongdarailway.util.MyMth;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
import net.minecraft.core.BlockPos;
import net.minecraft.server.level.ServerLevel;
import net.minecraft.server.level.WorldGenRegion;
import net.minecraft.world.level.LevelHeightAccessor;
import net.minecraft.world.level.chunk.ChunkGenerator;
import net.minecraft.world.level.levelgen.Heightmap;
import net.minecraft.world.level.levelgen.RandomState;
import net.minecraft.world.phys.Vec3;

public class RoutePlanner {
    /**
     * Region this planner is bound to. It is the centre tile of the grid built by
     * {@link #getCostMap} and the reference region for the world-coordinate conversion
     * in {@link #handlePath}.
     */
    private final RegionPos regionPos;

    /**
     * Creates a planner for a single region.
     *
     * @param regionPos the region whose routes are planned; stored as-is without validation
     */
    public RoutePlanner(RegionPos regionPos) {
        this.regionPos = regionPos;
    }

    /**
     * Builds a 768x768 grid from the height maps of this planner's region and its four
     * edge-adjacent neighbours, laid out as a 3x3 mosaic of 256-cell tiles.
     *
     * <p>The four diagonal tiles are never sampled; their cells keep
     * {@link Integer#MAX_VALUE}. Each tile's height map is taken from the
     * {@code RailwayBuilder} cache for the level seed when a builder exists, otherwise it
     * is computed directly.
     *
     * @param level world-gen region supplying the seed and the backing server level
     * @return the assembled grid, indexed {@code [x][z]}
     */
    public int[][] getCostMap(WorldGenRegion level) {
        final int tile = 128 * 2;
        final int[][] mosaic = new int[tile * 3][tile * 3];
        for (int[] row : mosaic) {
            Arrays.fill(row, Integer.MAX_VALUE);
        }

        for (int dx = -1; dx <= 1; ++dx) {
            for (int dz = -1; dz <= 1; ++dz) {
                // Only the centre tile and its edge neighbours are filled.
                if (dx != 0 && dz != 0) {
                    continue;
                }
                RegionPos tilePos = new RegionPos(this.regionPos.x() + dx, this.regionPos.z() + dz);
                RailwayBuilder builder = RailwayBuilder.getInstance(level.getSeed());
                int[][] tileHeights;
                if (builder == null) {
                    tileHeights = this.getHeightMap(level.getLevel(), tilePos);
                } else {
                    tileHeights = builder.regionHeightMap.computeIfAbsent(tilePos, key -> this.getHeightMap(level.getLevel(), tilePos));
                }

                int offsetX = (dx + 1) * tile;
                int offsetZ = (dz + 1) * tile;
                for (int x = 0; x < tileHeights.length; ++x) {
                    for (int z = 0; z < tileHeights[0].length; ++z) {
                        mosaic[offsetX + x][offsetZ + z] = tileHeights[x][z];
                    }
                }
            }
        }
        return mosaic;
    }

    /**
     * Samples the generated surface height of one region into a 256x256 grid.
     *
     * <p>Heights are read from the chunk generator via {@code getBaseHeight} with the
     * {@code WORLD_SURFACE_WG} heightmap. Sampler coordinates are scaled by 8 and offset
     * by the region origin ({@code regionPos * 128 * 16}) to obtain world coordinates.
     *
     * <p>If the quad-tree build is interrupted, the interrupt flag is restored and the
     * failure is logged with its stack trace. The image is still generated from whatever
     * the sampler holds at that point, as it was before this handler was changed.
     *
     * @param serverLevel level whose generator and random state are sampled
     * @param regionPos   region to sample
     * @return the height grid produced by {@code generateImage(256, 256)}
     */
    private int[][] getHeightMap(ServerLevel serverLevel, RegionPos regionPos) {
        ChunkGenerator gen = serverLevel.getChunkSource().getGenerator();
        RandomState cfg = serverLevel.getChunkSource().randomState();
        AdaptiveHeightSampler sampler = new AdaptiveHeightSampler(10.0, 3, 4, (x, z) -> {
            int wx = (int)(x * 8.0 + (double)(regionPos.x() * 128 * 16));
            int wz = (int)(z * 8.0 + (double)(regionPos.z() * 128 * 16));
            return gen.getBaseHeight(wx, wz, Heightmap.Types.WORLD_SURFACE_WG, (LevelHeightAccessor)serverLevel, cfg);
        });
        try {
            long startTime = System.currentTimeMillis();
            sampler.buildQuadTree(256.0);
            long endTime = System.currentTimeMillis();
            Tongdarailway.LOGGER.info(" Build HeightMap time: {}ms", (Object)(endTime - startTime));
        }
        catch (InterruptedException e) {
            // Restore the flag so code further up the stack can still observe the interruption.
            Thread.currentThread().interrupt();
            // Pass the exception itself: getMessage() may be null and drops the stack trace.
            Tongdarailway.LOGGER.error("Interrupted while building the region height map", e);
        }
        finally {
            sampler.shutdown();
        }
        return sampler.generateImage(256, 256);
    }

    /**
     * Turns a raw grid path into a finished route: heights are assigned and smoothed, the
     * points are converted with {@code AStarPathfinder.pic2RegionPos}, the path is split
     * into straight runs, runs of two points or fewer are dropped, and the remaining runs
     * are joined into both a curve and a list of track placements.
     *
     * @param way               grid path to process
     * @param costMap           grid passed to {@link #handleHeight} as its height source
     * @param connectionGenInfo endpoints and directions of the connection being built
     * @param level             server level, used for the sea level
     * @return the composite curve together with the track placement list
     */
    public ResultWay getWay(List<int[]> way, int[][] costMap, StationPlanner.ConnectionGenInfo connectionGenInfo, ServerLevel level) {
        List<int[]> regionWay = this.handleHeight(way, level, costMap, connectionGenInfo)
                .stream()
                .map(AStarPathfinder::pic2RegionPos)
                .toList();

        List<List<int[]>> runs = StraightPathFinder.findStraightPaths(regionWay);
        runs.removeIf(run -> run.size() <= 2);

        List<List<Vec3>> worldRuns = this.handlePath(runs, 2.0);
        CurveRoute.CompositeCurve curve = this.connectPaths(connectionGenInfo, worldRuns);
        List<TrackPutInfo> placements = this.connectTrack(connectionGenInfo, worldRuns);
        return new ResultWay(curve, placements);
    }

    /**
     * Assigns a height to every point of a grid path and smooths the resulting profile.
     *
     * <ol>
     *   <li>Each point takes its height from {@code heightMap}, clamped to
     *       {@code [seaLevel + 5, seaLevel + 100]}; the first and last points are pinned
     *       to the connection's start and end heights.</li>
     *   <li>The profile is passed through {@link #adjustmentHeight}.</li>
     *   <li>If the profile spans at least two height units and is longer than the window,
     *       interior points are replaced by a moving average whose half-width is
     *       {@code (max - min) / 2 + 1}; window positions beyond either end reuse the
     *       nearest endpoint.</li>
     *   <li>On sufficiently long paths the first and last {@code halfWidth + 10} points
     *       are blended linearly towards the connection heights.</li>
     * </ol>
     *
     * @param path      grid points; index 0 and 1 address {@code heightMap}
     * @param level     server level, read for its sea level
     * @param heightMap height source indexed {@code [p[0]][p[1]]}
     * @param con       connection whose {@code connectStart()[2]} / {@code connectEnd()[2]}
     *                  give the endpoint heights
     * @return one int array per point with every component rounded to the nearest integer
     */
    public List<int[]> handleHeight(List<int[]> path, ServerLevel level, int[][] heightMap, StationPlanner.ConnectionGenInfo con) {
        final int sea = level.getSeaLevel();

        List<double[]> profile = new LinkedList<double[]>();
        for (int[] cell : path) {
            int clamped = Math.min(Math.max(heightMap[cell[0]][cell[1]], sea + 5), sea + 100);
            profile.add(new double[]{cell[0], cell[1], clamped});
        }
        profile.getFirst()[2] = con.connectStart()[2];
        profile.getLast()[2] = con.connectEnd()[2];
        profile = RoutePlanner.adjustmentHeight(profile);

        // Integer height range of the profile; both bounds are 0 when the profile is empty.
        int top = profile.isEmpty() ? 0 : (int) profile.get(0)[2];
        int bottom = top;
        for (double[] p : profile) {
            int y = (int) p[2];
            top = Math.max(top, y);
            bottom = Math.min(bottom, y);
        }
        int radius = (top - bottom) / 2 + 1;

        if (profile.size() > radius * 2 && radius * 2 >= 3) {
            int lastIdx = profile.size() - 1;
            List<double[]> smoothed = new ArrayList<double[]>();
            smoothed.add(profile.getFirst());
            for (int i = 1; i < lastIdx; ++i) {
                double total = 0.0;
                int samples = 0;
                for (int j = i - radius; j <= i + radius; ++j) {
                    // Out-of-range window slots fall back to the nearest endpoint.
                    int src = Math.max(0, Math.min(j, lastIdx));
                    total += profile.get(src)[2];
                    ++samples;
                }
                double[] here = profile.get(i);
                smoothed.add(new double[]{here[0], here[1], total / (double) samples});
            }
            smoothed.add(profile.getLast());
            profile = smoothed;

            double startHeight = con.connectStart()[2];
            double endHeight = con.connectEnd()[2];
            if (profile.size() > radius * 2 + 20) {
                int ramp = radius + 10;
                for (int i = 1; i < ramp; ++i) {
                    double t = (double) i / (double) ramp;
                    double[] head = profile.get(i);
                    double[] tail = profile.get(profile.size() - 1 - i);
                    head[2] = startHeight * (1.0 - t) + head[2] * t;
                    tail[2] = endHeight * (1.0 - t) + tail[2] * t;
                }
            }
        }

        List<int[]> rounded = new ArrayList<int[]>(profile.size());
        for (double[] p : profile) {
            int[] cell = new int[p.length];
            for (int k = 0; k < p.length; ++k) {
                cell[k] = (int) Math.round(p[k]);
            }
            rounded.add(cell);
        }
        return rounded;
    }

    /**
     * Converts straight runs of grid points into world-space point lists, first trimming
     * the facing ends of neighbouring runs that lie closer together than
     * {@code cutDistance}.
     *
     * <p>The input is deep-copied, so neither the lists nor the point arrays passed in are
     * modified. The horizontal gap between the end of one run and the start of the next is
     * measured in grid units; when it is below {@code cutDistance}, {@code cutDistance / 2}
     * points (truncated) are removed from each facing end. Each remaining point
     * {@code {p0, p1, p2}} becomes {@code Vec3(p0 * 8, p2, p1 * 8)} and is passed through
     * {@code MyMth.inRegionPos2WorldPos} with this planner's region.
     *
     * @param straightPaths runs of points, each point having at least three components
     * @param cutDistance   gap threshold below which adjacent run ends are trimmed
     * @return one list of world positions per input run, in the same order
     */
    public List<List<Vec3>> handlePath(List<List<int[]>> straightPaths, double cutDistance) {
        List<List<int[]>> working = new ArrayList<>();
        for (List<int[]> source : straightPaths) {
            List<int[]> copy = new ArrayList<>();
            for (int[] p : source) {
                copy.add(new int[]{p[0], p[1], p[2]});
            }
            working.add(copy);
        }

        for (int i = 0; i + 1 < working.size(); ++i) {
            List<int[]> current = working.get(i);
            List<int[]> following = working.get(i + 1);
            if (current.size() < 2 || following.size() < 2) {
                continue;
            }
            int[] tail = current.getLast();
            int[] head = following.getFirst();
            double gap = Math.sqrt(Math.pow(head[0] - tail[0], 2.0) + Math.pow(head[1] - tail[1], 2.0));
            if (gap < cutDistance) {
                this.trimPathEnd(current, (int)(cutDistance / 2.0));
                this.trimPathStart(following, (int)(cutDistance / 2.0));
            }
        }

        List<List<Vec3>> result = new ArrayList<>();
        for (List<int[]> run : working) {
            List<Vec3> worldRun = new ArrayList<>();
            for (int[] p : run) {
                Vec3 scaled = new Vec3((double) p[0], (double) p[2], (double) p[1]).multiply(8.0, 1.0, 8.0);
                worldRun.add(MyMth.inRegionPos2WorldPos(this.regionPos, scaled));
            }
            result.add(worldRun);
        }
        return result;
    }

    /**
     * Joins the straight runs into one composite curve running from the connection start
     * to the connection end.
     *
     * <p>The curve consists of a cubic Bezier from {@code con.start()} onto whichever end
     * of the run chain (first point or last point) is closer to it, a line segment for
     * every consecutive point pair inside each run, a cubic Bezier between each run and
     * the next, and a final cubic Bezier from the chain end closer to {@code con.end()}
     * onto {@code con.end()}.
     *
     * <p>Every run is expected to hold at least two points; the first and last runs are
     * indexed on that assumption.
     *
     * @param con     connection endpoints and directions
     * @param pathPoi world-space runs, in path order
     * @return the assembled curve; an empty curve when {@code pathPoi} is null or empty
     */
    private CurveRoute.CompositeCurve connectPaths(StationPlanner.ConnectionGenInfo con, List<List<Vec3>> pathPoi) {
        CurveRoute.CompositeCurve curve = new CurveRoute.CompositeCurve();
        if (pathPoi == null || pathPoi.isEmpty()) {
            return curve;
        }

        // Horizontal unit directions pointing outward from each end of the chain.
        List<Vec3> headRun = pathPoi.getFirst();
        Vec3 headPos = headRun.getFirst();
        Vec3 headDir = headPos.subtract(headRun.get(1)).multiply(1.0, 0.0, 1.0).normalize();
        List<Vec3> tailRun = pathPoi.getLast();
        Vec3 tailPos = tailRun.getLast();
        Vec3 tailDir = tailPos.subtract(tailRun.get(tailRun.size() - 2)).multiply(1.0, 0.0, 1.0).normalize();

        boolean startNearTail = headPos.distanceTo(con.start()) > tailPos.distanceTo(con.start());
        Vec3 entryPos = startNearTail ? tailPos : headPos;
        Vec3 entryDir = startNearTail ? tailDir : headDir;
        curve.addSegment(CurveRoute.CubicBezier.getCubicBezier(con.start(), con.startDir(), entryPos.subtract(con.start()), entryDir));

        int runCount = pathPoi.size();
        for (int i = 0; i < runCount; ++i) {
            List<Vec3> run = pathPoi.get(i);
            for (int j = 1; j < run.size(); ++j) {
                curve.addSegment(new CurveRoute.LineSegment(run.get(j - 1), run.get(j)));
            }
            if (i == runCount - 1) {
                break;
            }

            // Bridge the gap to the next run using each run's overall horizontal heading.
            List<Vec3> next = pathPoi.get(i + 1);
            Vec3 runSpan = new Vec3(run.getLast().x(), 0.0, run.getLast().z())
                    .subtract(new Vec3(run.getFirst().x(), 0.0, run.getFirst().z()));
            Vec3 nextSpanReversed = new Vec3(next.getFirst().x(), 0.0, next.getFirst().z())
                    .subtract(new Vec3(next.getLast().x(), 0.0, next.getLast().z()));
            Vec3 outgoing = runSpan.normalize();
            Vec3 incoming = nextSpanReversed.normalize();
            curve.addSegment(CurveRoute.CubicBezier.getCubicBezier(run.getLast(), outgoing, next.getFirst().subtract(run.getLast()), incoming));
        }

        boolean endNearTail = headPos.distanceTo(con.end()) > tailPos.distanceTo(con.end());
        Vec3 exitPos = endNearTail ? tailPos : headPos;
        Vec3 exitDir = endNearTail ? tailDir : headDir;
        curve.addSegment(CurveRoute.CubicBezier.getCubicBezier(exitPos, exitDir, con.end().subtract(exitPos), con.endDir()));
        return curve;
    }

    /**
     * Produces the ordered list of track placements for a planned connection.
     *
     * <p>The list holds, in order: a Bezier piece from {@code con.start()} onto the first
     * run; for every consecutive point pair inside a run either one straight piece per
     * block (when both points share the same height) or a single Bezier piece (when they
     * do not); a Bezier piece between each run and the next; and a closing Bezier piece
     * towards {@code con.end()}.
     *
     * <p>Every run is expected to hold at least two points; the first and last runs are
     * indexed on that assumption.
     *
     * @param con     connection endpoints and directions
     * @param pathPoi world-space runs, in path order
     * @return the placements; an empty list when {@code pathPoi} is null or empty
     */
    private List<TrackPutInfo> connectTrack(StationPlanner.ConnectionGenInfo con, List<List<Vec3>> pathPoi) {
        ArrayList<TrackPutInfo> trackPutInfos = new ArrayList<TrackPutInfo>();
        // Same guard as connectPaths: without it an empty run list fails on getFirst().
        if (pathPoi == null || pathPoi.isEmpty()) {
            return trackPutInfos;
        }

        List<Vec3> startSegment = pathPoi.getFirst();
        Vec3 firstPos = startSegment.getFirst();
        Vec3 firstDir = firstPos.subtract(startSegment.get(1)).multiply(1.0, 0.0, 1.0).normalize();
        List<Vec3> lastSegment = pathPoi.getLast();
        Vec3 lastPos = lastSegment.getLast();
        Vec3 lastDir = lastPos.subtract(lastSegment.get(lastSegment.size() - 2)).multiply(1.0, 0.0, 1.0).normalize();

        // Entry curve from the connection start onto the first run.
        trackPutInfos.add(TrackPutInfo.getByDir(new BlockPos((int)con.start().x, (int)con.start().y, (int)con.start().z), con.startDir(), new TrackPutInfo.BezierInfo(con.start(), con.startDir(), firstPos.subtract(con.start()), firstDir)));

        for (int i = 0; i < pathPoi.size(); ++i) {
            List<Vec3> segment = pathPoi.get(i);
            for (int j = 0; j < segment.size() - 1; ++j) {
                Vec3 pA = segment.get(j);
                Vec3 pB = segment.get(j + 1);
                if (pA.y == pB.y) {
                    // Level stretch: one straight piece per block. The step count is the
                    // larger axis delta so that runs along Z alone (dx == 0) are covered
                    // too; counting |dx| only placed nothing for them.
                    double steps = Math.max(Math.abs(pB.x - pA.x), Math.abs(pB.z - pA.z));
                    for (int k = 0; (double)k < steps; ++k) {
                        int x = (int)(pA.x + (double)(MyMth.getSign(pB.x - pA.x) * k));
                        int z = (int)(pA.z + (double)(MyMth.getSign(pB.z - pA.z) * k));
                        trackPutInfos.add(TrackPutInfo.getByDir(new BlockPos(x, (int)pA.y, z), pB.subtract(pA), null));
                    }
                    continue;
                }
                // Height changes between the two points: a single sloped Bezier piece.
                Vec3 dir = pB.subtract(pA).multiply(1.0, 0.0, 1.0).normalize();
                trackPutInfos.add(TrackPutInfo.getByDir(new BlockPos((int)pA.x, (int)pA.y, (int)pA.z), dir, new TrackPutInfo.BezierInfo(pA, dir, pB.subtract(pA), dir.reverse())));
            }
            if (i >= pathPoi.size() - 1) {
                continue;
            }

            // Curve joining this run to the next, using each run's overall horizontal heading.
            List<Vec3> nextSegment = pathPoi.get(i + 1);
            Vec3 vecA = new Vec3(segment.getLast().x(), 0.0, segment.getLast().z()).subtract(new Vec3(segment.getFirst().x(), 0.0, segment.getFirst().z()));
            Vec3 vecB = new Vec3(nextSegment.getFirst().x(), 0.0, nextSegment.getFirst().z()).subtract(new Vec3(nextSegment.getLast().x(), 0.0, nextSegment.getLast().z()));
            Vec3 prevDirection = vecA.normalize();
            Vec3 currentDirection = vecB.normalize();
            Vec3 pos = segment.getLast();
            trackPutInfos.add(TrackPutInfo.getByDir(new BlockPos((int)pos.x, (int)pos.y, (int)pos.z), prevDirection, new TrackPutInfo.BezierInfo(pos, prevDirection, nextSegment.getFirst().subtract(segment.getLast()), currentDirection)));
        }

        // NOTE(review): this piece is anchored at con.end() while its Bezier starts at
        // lastPos; every other piece anchors at its Bezier start. Left unchanged - confirm intent.
        trackPutInfos.add(TrackPutInfo.getByDir(new BlockPos((int)con.end().x, (int)con.end().y, (int)con.end().z), con.endDir(), new TrackPutInfo.BezierInfo(lastPos, lastDir, con.end().subtract(lastPos), con.endDir())));
        return trackPutInfos;
    }

    /**
     * Drops up to {@code trimLength} points from the front of {@code path}, in place.
     *
     * <p>Nothing is removed when the path has {@code trimLength + 1} points or fewer, and
     * at least two points are always kept.
     *
     * @param path       mutable run of points
     * @param trimLength number of leading points to remove
     */
    private void trimPathStart(List<int[]> path, int trimLength) {
        if (path.size() <= trimLength + 1) {
            return;
        }
        int drop = Math.min(trimLength, path.size() - 2);
        if (drop > 0) {
            path.subList(0, drop).clear();
        }
    }

    /**
     * Drops up to {@code trimLength} points from the back of {@code path}, in place.
     *
     * <p>Nothing is removed when the path has {@code trimLength + 1} points or fewer, and
     * at least two points are always kept.
     *
     * @param path       mutable run of points
     * @param trimLength number of trailing points to remove
     */
    private void trimPathEnd(List<int[]> path, int trimLength) {
        if (path.size() <= trimLength + 1) {
            return;
        }
        int drop = Math.min(trimLength, path.size() - 2);
        if (drop > 0) {
            path.subList(path.size() - drop, path.size()).clear();
        }
    }

    /**
     * Levels out stretches of a height profile where following the terrain would cost
     * noticeably more travel than going straight across (or through) it.
     *
     * <p>Heights are first made relative to the straight line between the first and last
     * point. Walking the path, whenever the integer relative height changes at a point,
     * the next later point with the same integer relative height is looked up. If the
     * accumulated travel between the two ({@code 1 + |height step|} per point) exceeds the
     * index distance times {@code sec} times 4 for a descent ("bridge") or times 3 for an
     * ascent ("tunnel"), every point up to and including that later point is held at the
     * current relative height. Finally the straight-line trend is added back.
     *
     * <p>The result always has as many points as the input. The input list and its arrays
     * are not modified.
     *
     * @param path points as {@code {a, b, height, ...}}; only the first three components
     *             are read
     * @return a new list of {@code {a, b, adjustedHeight}} arrays, or an empty list when
     *         {@code path} has fewer than two points
     */
    private static List<double[]> adjustmentHeight(List<double[]> path) {
        int count = path.size();
        if (count < 2) {
            return new LinkedList<double[]>();
        }

        // Snapshot into an array: callers may pass a LinkedList, where get(i) is O(n).
        double[][] points = path.toArray(new double[0][]);
        double hStart = points[0][2];
        double hEnd = points[count - 1][2];
        double pNum = count - 1;

        // Per point: height relative to the start-end trend line, and accumulated travel.
        double[] relative = new double[count];
        double[] travelled = new double[count];
        double distance = 0.0;
        for (int i = 0; i < count; ++i) {
            double raw = points[i][2];
            relative[i] = raw - hStart * ((pNum - (double)i) / pNum) - hEnd * ((double)i / pNum);
            if (i > 0) {
                distance += 1.0 + Math.abs(raw - points[i - 1][2]);
            }
            travelled[i] = distance;
        }

        // For each point, the index of the next later point sharing its truncated relative
        // height (-1 if none). Built back-to-front so each lookup is O(1) instead of the
        // repeated linear indexOf scans this replaces.
        int[] nextSameLevel = new int[count];
        HashMap<Integer, Integer> nearestAhead = new HashMap<Integer, Integer>();
        for (int i = count - 1; i >= 0; --i) {
            Integer ahead = nearestAhead.put((int)relative[i], i);
            nextSameLevel[i] = ahead == null ? -1 : ahead;
        }

        // Length of the straight start-to-end line per path step.
        double sec = Math.sqrt(Math.pow(count, 2.0) + Math.pow(Math.abs(hStart - hEnd), 2.0)) / (double)count;

        ArrayList<double[]> adjustedPath = new ArrayList<double[]>(count);
        for (int j = 0; j < count; ++j) {
            double level = relative[j];
            adjustedPath.add(new double[]{points[j][0], points[j][1], level});
            if (j == count - 1) {
                continue;
            }
            int hd = (int)relative[j + 1] - (int)level;
            int next = nextSameLevel[j];
            if (hd == 0 || next < 0) {
                continue;
            }
            double indexSpan = next - j;
            double walked = travelled[next] - travelled[j];
            boolean conditionBridge = hd < 0 && indexSpan * 4.0 * sec < walked;
            boolean conditionTunnel = hd > 0 && indexSpan * 3.0 * sec < walked;
            if (!conditionBridge && !conditionTunnel) {
                continue;
            }
            // Hold the current level across the dip/bump, then resume after it.
            for (int k = j + 1; k <= next; ++k) {
                adjustedPath.add(new double[]{points[k][0], points[k][1], level});
            }
            j = next;
        }

        // Add the start-end trend back onto the relative heights.
        for (int i = 0; i < adjustedPath.size(); ++i) {
            double[] p = adjustedPath.get(i);
            p[2] = p[2] + (hStart * ((pNum - (double)i) / pNum) + hEnd * ((double)i / pNum));
        }
        return adjustedPath;
    }

    /**
     * Splits a grid path into maximal runs of constant compass heading.
     */
    public static class StraightPathFinder {
        /** The eight unit steps; a heading is identified by its index in this table. */
        private static final int[][] DIRECTIONS = new int[][]{{0, 1}, {1, 1}, {1, 0}, {1, -1}, {0, -1}, {-1, -1}, {-1, 0}, {-1, 1}};

        /**
         * Cuts {@code path} into consecutive runs in which every step has the same heading.
         * Neighbouring runs share their junction point: the last point of one run is the
         * first point of the next.
         *
         * @param path points whose first two components are the grid coordinates
         * @return the runs in path order, each holding at least two points; an empty list
         *         when {@code path} is null or has fewer than two points
         */
        public static List<List<int[]>> findStraightPaths(List<int[]> path) {
            ArrayList<List<int[]>> runs = new ArrayList<List<int[]>>();
            if (path == null || path.size() < 2) {
                return runs;
            }
            int lastIndex = path.size() - 1;
            int runStart = 0;
            while (runStart < lastIndex) {
                int heading = StraightPathFinder.getDirection(path.get(runStart), path.get(runStart + 1));
                // Extend the run while the following step keeps the same heading.
                int runEnd = runStart + 1;
                while (runEnd < lastIndex
                        && StraightPathFinder.getDirection(path.get(runEnd), path.get(runEnd + 1)) == heading) {
                    ++runEnd;
                }
                runs.add(new ArrayList<int[]>(path.subList(runStart, runEnd + 1)));
                runStart = runEnd;
            }
            return runs;
        }

        /**
         * Returns the {@link #DIRECTIONS} index of the step from {@code point1} to
         * {@code point2}, reduced to its per-axis sign, or {@code -1} when the two points
         * share the same grid cell.
         */
        private static int getDirection(int[] point1, int[] point2) {
            int stepX = StraightPathFinder.unit(point2[0] - point1[0]);
            int stepZ = StraightPathFinder.unit(point2[1] - point1[1]);
            int index = 0;
            for (int[] candidate : DIRECTIONS) {
                if (candidate[0] == stepX && candidate[1] == stepZ) {
                    return index;
                }
                ++index;
            }
            return -1;
        }

        /** Reduces a non-zero delta to a unit step by dividing by its magnitude; zero stays zero. */
        private static int unit(int delta) {
            return delta == 0 ? 0 : delta / Math.abs(delta);
        }
    }

    /**
     * Outcome of planning one connection, as returned by {@link RoutePlanner#getWay}.
     *
     * @param way           composite curve describing the route
     * @param trackPutInfos track placements for the same route; the list is stored as
     *                      given, without a defensive copy
     */
    public record ResultWay(CurveRoute.CompositeCurve way, List<TrackPutInfo> trackPutInfos) {
    }
}

