import {BufferGeometry, Float32BufferAttribute, Vector3} from "three";
import {extend} from "@react-three/fiber";
import {calculateNewPointWithLengthVector} from "../../../helpers/helpers";
import * as THREE from "three";

/**
 * BufferGeometry built by sweeping a polyline along one extrusion vector.
 *
 * Every input point contributes a "bottom" vertex (the point itself) and a
 * "top" vertex (the point moved by `calculateNewPointWithLengthVector`).
 * Consecutive bottom/top pairs are joined into quads that form the side
 * walls; optionally both ends are closed with triangle-fan caps.
 *
 * Vertex layout: side vertices are interleaved as
 * [bottom0, top0, bottom1, top1, ...]. When caps are built, the bottom cap
 * and then the top cap follow, each as its centre vertex followed by its own
 * copy of the ring.
 */
class CustomLinearExtrudedPointShape extends THREE.BufferGeometry {
  /**
   * @param {Array<[number, number, number]>} [points=[]] - Polyline vertices as
   *   [x, y, z] tuples. Defaults to empty so that BufferGeometry#clone(), which
   *   invokes the constructor without arguments, does not throw.
   * @param {{x: number, y: number, z: number}} lengthVector - Extrusion vector.
   *   It is passed to `calculateNewPointWithLengthVector` to obtain each top
   *   vertex, and its direction is the top-cap normal (negated for the bottom cap).
   * @param {{caps?: boolean}} [options={ caps: true }] - `caps` (default true)
   *   closes both ends. Caps are only built for 3 or more points.
   */
  constructor(points = [], lengthVector, options = { caps: true }) {
    super();

    const { caps = true } = options || {};
    const positions = [];
    const normals = [];
    const uvs = [];
    const indices = [];
    const lastIndex = points.length - 1;

    // Appends one vertex with all of its attributes and returns its index.
    const addVertex = (position, u, v, normal) => {
      positions.push(position.x, position.y, position.z);
      uvs.push(u, v);
      normals.push(normal.x, normal.y, normal.z);
      return positions.length / 3 - 1;
    };

    // U runs 0..1 along the polyline; guarded so a single point is not 0 / 0.
    const uAt = (i) => (lastIndex > 0 ? i / lastIndex : 0);

    const zero = new THREE.Vector3();
    const bottomRing = [];
    const topRing = [];

    // Side vertices, interleaved bottom/top per point. Normals start at zero
    // and are accumulated from the adjacent side faces below.
    points.forEach(([x, y, z], i) => {
      const bottom = new THREE.Vector3(x, y, z);
      // Hand the helper its own vector in case it mutates its input.
      const extruded = calculateNewPointWithLengthVector(
        new THREE.Vector3(x, y, z),
        lengthVector
      );
      const top = new THREE.Vector3(extruded.x, extruded.y, extruded.z);

      addVertex(bottom, uAt(i), 0, zero);
      addVertex(top, uAt(i), 1, zero);
      bottomRing.push(bottom);
      topRing.push(top);
    });

    // Side walls: one quad (two triangles) per polyline segment.
    // NOTE(review): no wall joins the last point back to the first, while the
    // caps below treat the ring as closed; presumably callers repeat the first
    // point to close the outline — confirm.
    const along = new THREE.Vector3();
    const rise = new THREE.Vector3();
    const faceNormal = new THREE.Vector3();
    for (let i = 0; i < lastIndex; i++) {
      const bottomA = 2 * i;
      const topA = bottomA + 1;
      const bottomB = bottomA + 2;
      const topB = bottomA + 3;

      indices.push(bottomA, bottomB, topA, bottomB, topB, topA);

      // Segment direction x extrusion direction: the right-hand-rule normal of
      // the triangle (bottomA, bottomB, topA) pushed above.
      along.subVectors(bottomRing[i + 1], bottomRing[i]);
      rise.subVectors(topRing[i], bottomRing[i]);
      faceNormal.crossVectors(along, rise).normalize();

      // Corners shared by two segments sum both face normals; the final
      // normalization turns that sum into an averaged (smooth) normal.
      for (const index of [bottomA, bottomB, topA, topB]) {
        normals[index * 3] += faceNormal.x;
        normals[index * 3 + 1] += faceNormal.y;
        normals[index * 3 + 2] += faceNormal.z;
      }
    }

    // Closes one end with a triangle fan around the ring's centroid. The cap
    // gets its own copy of the ring vertices so they carry the flat cap normal
    // rather than the side-wall normals. A centroid fan is only correct for
    // rings that are star-shaped about their centroid (e.g. convex outlines).
    const addCap = (ring, v, normal, reverseWinding) => {
      const centroid = ring
        .reduce((sum, p) => sum.add(p), new THREE.Vector3())
        .divideScalar(ring.length);
      const centerIndex = addVertex(centroid, 0.5, 0.5, normal);
      const ringIndices = ring.map((p, i) => addVertex(p, uAt(i), v, normal));

      ringIndices.forEach((current, i) => {
        const next = ringIndices[(i + 1) % ringIndices.length];
        if (reverseWinding) {
          indices.push(next, current, centerIndex);
        } else {
          indices.push(current, next, centerIndex);
        }
      });
    };

    // Fewer than 3 points only yields degenerate fan triangles (and an empty
    // ring would give a NaN centroid), so caps are skipped in that case.
    if (caps && points.length >= 3) {
      const extrusionDir = new THREE.Vector3(
        lengthVector.x,
        lengthVector.y,
        lengthVector.z
      ).normalize();
      addCap(bottomRing, 0, extrusionDir.clone().negate(), true);
      addCap(topRing, 1, extrusionDir, false);
    }

    // Convert the accumulated side normals into unit vectors.
    const scratch = new THREE.Vector3();
    for (let i = 0; i < normals.length; i += 3) {
      scratch.fromArray(normals, i).normalize().toArray(normals, i);
    }

    this.setAttribute("position", new THREE.Float32BufferAttribute(positions, 3));
    this.setAttribute("normal", new THREE.Float32BufferAttribute(normals, 3));
    this.setAttribute("uv", new THREE.Float32BufferAttribute(uvs, 2));
    this.setIndex(indices);
  }

  /**
   * Rebuilds the geometry from a plain object holding the constructor
   * arguments (`points`, `lengthVector`, `options`).
   *
   * NOTE(review): this class sets no `parameters`, so three's
   * BufferGeometry#toJSON presumably does not emit these fields — confirm
   * where `data` is produced.
   *
   * @param {{points: Array<[number, number, number]>,
   *          lengthVector: {x: number, y: number, z: number},
   *          options?: {caps?: boolean}}} data
   * @returns {CustomLinearExtrudedPointShape}
   */
  static fromJSON(data) {
    const { lengthVector } = data;
    // Parsed JSON holds plain {x, y, z} objects, not Vector3 instances.
    return new CustomLinearExtrudedPointShape(
      data.points,
      lengthVector &&
        new THREE.Vector3(lengthVector.x, lengthVector.y, lengthVector.z),
      data.options
    );
  }
}


// Register the geometry class with react-three-fiber's element catalogue,
// keyed by its class name, so it can be used as a JSX element.
extend({ CustomLinearExtrudedPointShape: CustomLinearExtrudedPointShape });
