import {extend} from "@react-three/fiber";
import {BufferGeometry, Float32BufferAttribute} from "three";

/**
 * Append a flat triangle-fan cap for one closed ring of points.
 *
 * The cap gets its own copy of the ring's vertices, so computeVertexNormals()
 * does not average cap normals with the side-wall normals.
 *
 * @param {ArrayLike<number>[]} ring - Ring points, read as p[0], p[1], p[2].
 * @param {boolean} reverse - When true, emit triangles as (0, i+1, i) instead
 *   of (0, i, i+1), i.e. traverse the ring backwards.
 * @param {number[]} vertices - Flat xyz position array; appended to in place.
 * @param {number[]} uvs - Flat uv array; appended to in place.
 * @param {number[]} indices - Triangle index array; appended to in place.
 */
function appendCapFan(ring, reverse, vertices, uvs, indices) {
  const count = ring.length;
  // Fewer than three points cannot form a triangle; adding vertices anyway
  // would only leave unreferenced entries in the buffers.
  if (count < 3) {
    return;
  }

  const offset = vertices.length / 3;
  for (let i = 0; i < count; i++) {
    const p = ring[i];
    vertices.push(p[0], p[1], p[2]);
    // Planar projection onto XY: x, y in [-1, 1] map to u, v in [0, 1].
    uvs.push(0.5 + p[0] / 2, 0.5 + p[1] / 2);
  }

  // Fan around ring[0]: only a correct triangulation for convex rings (or
  // rings star-shaped around their first point).
  for (let i = 1; i < count - 1; i++) {
    if (reverse) {
      indices.push(offset, offset + i + 1, offset + i);
    } else {
      indices.push(offset, offset + i, offset + i + 1);
    }
  }
}

/**
 * Lofted extrusion between two closed rings of points with the same point
 * count, optionally capped at either end.
 *
 * Point k of shape1 is joined to point k of shape2, and each ring is treated
 * as closed (the last point connects back to the first). Every side quad gets
 * its own four vertices, so normals are not smoothed across neighbouring
 * quads. Side UVs run u = k / count around the ring and v = 0 (shape1) to
 * v = 1 (shape2).
 *
 * All faces, caps included, are wound consistently. Whether they face outward
 * depends on the rings' orientation: e.g. a ring that is counter-clockwise in
 * XY, with shape2 placed further along +Z, produces outward-facing faces.
 *
 * @param {ArrayLike<number>[]} [shape1=[]] - Start ring; points read as p[0..2].
 * @param {ArrayLike<number>[]} [shape2=[]] - End ring; same length as shape1.
 * @param {{capStart?: boolean, capEnd?: boolean}} [options] - Which ends to cap.
 * @throws {Error} If shape1 and shape2 have different lengths.
 */
export class CustomTaperedExtrusion extends BufferGeometry {
  // The shapes default to empty rings because BufferGeometry#clone() calls
  // `new this.constructor()` with no arguments before copying attributes over.
  constructor(shape1 = [], shape2 = [], options = { capStart: false, capEnd: false }) {
    super();

    if (shape1.length !== shape2.length) {
      throw new Error(`Shapes must have the same number of vertices, shape1: ${shape1.length}, shape2:${shape2.length}`);
    }

    // Tolerate `null` and partially specified option objects.
    const { capStart = false, capEnd = false } = options || {};

    const vertexCount = shape1.length;
    const vertices = [];
    const indices = [];
    const uvs = [];

    for (let i = 0; i < vertexCount; i++) {
      const iNext = (i + 1) % vertexCount;
      const a = shape1[i];
      const b = shape1[iNext];
      const c = shape2[i];
      const d = shape2[iNext];

      vertices.push(
        a[0], a[1], a[2],
        b[0], b[1], b[2],
        c[0], c[1], c[2],
        d[0], d[1], d[2]
      );

      // Two triangles (a, b, c) and (c, b, d). Together they traverse the
      // shape1 ring forward (a -> b) and the shape2 ring backward (d -> c).
      const base = i * 4;
      indices.push(
        base, base + 1, base + 2,
        base + 2, base + 1, base + 3
      );

      const u0 = i / vertexCount;
      const u1 = (i + 1) / vertexCount;
      uvs.push(
        u0, 0,
        u1, 0,
        u0, 1,
        u1, 1
      );
    }

    // Each cap must traverse its ring opposite to the side walls so the mesh is
    // consistently oriented: shape1 backwards, shape2 forwards.
    if (capStart) {
      appendCapFan(shape1, true, vertices, uvs, indices);
    }
    if (capEnd) {
      appendCapFan(shape2, false, vertices, uvs, indices);
    }

    this.setAttribute('position', new Float32BufferAttribute(vertices, 3));
    this.setAttribute('uv', new Float32BufferAttribute(uvs, 2));
    this.setIndex(indices);
    this.computeVertexNormals();
  }
}
// Register the geometry in react-three-fiber's catalogue so it can be used
// declaratively as a JSX element (camelCased tag name, constructor arguments
// passed via `args`), e.g. <customTaperedExtrusion args={[shape1, shape2, options]} />.
extend({ CustomTaperedExtrusion });