import { useRef, useEffect, forwardRef, useImperativeHandle, useState } from 'react'
import { useFrame, useThree } from '@react-three/fiber'
import { useGLTF } from '@react-three/drei'
import { DynamicDrawUsage, InstancedMesh, Matrix4, Vector3, BufferGeometry, Texture, Group, MeshBasicMaterial, OrthographicCamera } from 'three'
import sundialUrl from '/components/kit_02_lunisolar.glb'
import { SundialGLTF } from '@/types/gltfFiles'

/** One drifting InstancedMesh together with the random parameters that animate it. */
type StarLayer = {
  /** Instanced star quads; each layer gets its own cloned geometry and material. */
  mesh: InstancedMesh<BufferGeometry, MeshBasicMaterial>
  /** Drift rate (scaled by frame delta); also sets how fast the UV band cycles. */
  speed: number
  /** Phase offset added before the UV band is chosen. */
  offsetX: number
}

/**
 * Imperative handle that `Starfield` exposes through its `ref`.
 */
export interface StarfieldRef {
  /** The Three.js group the star InstancedMeshes are added to. */
  group: Group;
}

/** Number of parallax layers in the field. */
const NUM_LAYERS = 5
/** Independently drifting meshes created for each layer. */
const NUM_VARIANTS = 3
/** Star instances drawn by each InstancedMesh. */
const STARS_PER_LAYER = 21
/** Seconds of animation after which the field is rebuilt with new random stars. */
const RESTART_AFTER_SECONDS = 60

// Scratch objects reused by the per-frame update so it allocates nothing.
const scratchMatrix = new Matrix4()
const scratchPosition = new Vector3()

/**
 * Creates NUM_LAYERS × NUM_VARIANTS instanced star meshes, scatters their
 * instances randomly inside [-1, 1]² with a small random scale, and adds each
 * mesh to `group`.
 *
 * Every mesh owns a cloned geometry (its UVs are rewritten per frame), its own
 * material and a cloned texture; release them with `disposeStarLayers`.
 *
 * @param group - Parent the meshes are attached to.
 * @param baseGeometry - Source geometry, cloned per mesh.
 * @param baseTexture - Source texture, cloned per material (may be null).
 * @returns The created layers, in creation order.
 */
function createStarLayers(
  group: Group,
  baseGeometry: BufferGeometry,
  baseTexture: Texture | null,
): StarLayer[] {
  const layers: StarLayer[] = []
  const instanceMatrix = new Matrix4()

  for (let n = 0; n < NUM_LAYERS * NUM_VARIANTS; n++) {
    const material = new MeshBasicMaterial({
      map: baseTexture ? baseTexture.clone() : null,
      transparent: true,
      depthWrite: false,
    })
    const mesh = new InstancedMesh(baseGeometry.clone(), material, STARS_PER_LAYER)

    // Small random tilt of up to ±0.05π radians around Z.
    mesh.rotation.z = (Math.random() - 0.5) * Math.PI * 0.1
    // Instance matrices are rewritten every frame by the drift animation.
    mesh.instanceMatrix.setUsage(DynamicDrawUsage)

    const speed = Math.random() * 0.0005 + 0.001
    const offsetX = Math.random() * 3

    for (let j = 0; j < STARS_PER_LAYER; j++) {
      const x = Math.random() * 2 - 1
      const y = Math.random() * 2 - 1
      const scale = Math.random() * 0.005 + 0.003
      // Uniform scale on the diagonal, translation (x, y, 0).
      instanceMatrix.makeScale(scale, scale, scale).setPosition(x, y, 0)
      mesh.setMatrixAt(j, instanceMatrix)
    }

    group.add(mesh)
    layers.push({ mesh, speed, offsetX })
  }

  return layers
}

/**
 * Detaches `layers` from `group` and disposes the geometry, material, cloned
 * texture and instance buffers each mesh owns.
 */
function disposeStarLayers(group: Group, layers: readonly StarLayer[]): void {
  for (const { mesh } of layers) {
    group.remove(mesh)
    mesh.geometry.dispose()
    mesh.material.map?.dispose()
    mesh.material.dispose()
    mesh.dispose()
  }
}

/**
 * Rewrites the V coordinate of every vertex in `geometry` so its quads sample
 * one of four 0.125-high bands inside V ∈ [0, 0.5]; the band advances with
 * `elapsed`. Vertices are treated in groups of four: corners 0 and 2 get the
 * band's lower V, corners 1 and 3 the upper V (assumes that vertex order for
 * the glitter mesh — TODO confirm).
 *
 * All instances of an InstancedMesh share one geometry, so this needs to run
 * only once per mesh, not once per star.
 */
function updateSpriteBand(
  geometry: BufferGeometry,
  speed: number,
  offsetX: number,
  elapsed: number,
): void {
  const uv = geometry.attributes.uv
  if (!uv) return

  const band = Math.floor((speed * elapsed * 500 + offsetX) % 4)
  const vLow = (band / 4) * 0.5
  const vHigh = vLow + 0.125

  for (let v = 0; v < uv.count; v++) {
    const corner = v % 4
    uv.setY(v, corner === 0 || corner === 2 ? vLow : vHigh)
  }
  uv.needsUpdate = true
}

/**
 * Animated background of instanced "glitter" stars taken from the sundial GLB.
 *
 * Renders an empty `<group>` and fills it imperatively with
 * NUM_LAYERS × NUM_VARIANTS InstancedMeshes of STARS_PER_LAYER stars each.
 * Every frame the stars drift toward +x/+y, wrap at the orthographic camera's
 * bounds, and cycle their UV band. After RESTART_AFTER_SECONDS of animation the
 * field is rebuilt with fresh random stars. Nothing animates unless the active
 * camera is an OrthographicCamera.
 *
 * The ref exposes the underlying group (StarfieldRef).
 */
const Starfield = forwardRef<StarfieldRef>((_props, ref) => {
  const { nodes, materials } = useGLTF(sundialUrl) as SundialGLTF
  const groupRef = useRef<Group>(null)
  const starLayersRef = useRef<StarLayer[]>([])
  // Animation time since the last rebuild. Tracked locally rather than by
  // resetting R3F's shared `state.clock`, which other components also read.
  const elapsedRef = useRef(0)
  // Incremented to request a rebuild. A counter (unlike a boolean flag that the
  // effect resets itself) triggers exactly one rebuild per restart.
  const [generation, setGeneration] = useState(0)
  const { camera } = useThree()

  useImperativeHandle(ref, () => ({
    group: groupRef.current!
  }))

  // Build the star meshes; tear down and free them on rebuild or unmount.
  useEffect(() => {
    const group = groupRef.current
    if (!group) return

    const layers = createStarLayers(
      group,
      nodes.glitter.geometry,
      materials.glitter.map as Texture | null,
    )
    starLayersRef.current = layers

    return () => {
      disposeStarLayers(group, layers)
      // Keep useFrame from touching disposed meshes before the next build.
      if (starLayersRef.current === layers) starLayersRef.current = []
    }
  }, [nodes.glitter.geometry, materials.glitter, generation])

  useFrame((_state, delta) => {
    // Wrapping uses the frustum edges, which only OrthographicCamera exposes.
    if (!(camera instanceof OrthographicCamera)) return

    elapsedRef.current += delta
    if (elapsedRef.current > RESTART_AFTER_SECONDS) {
      // Reset here (not in the effect) so the next frames don't re-trigger.
      elapsedRef.current = 0
      setGeneration((g) => g + 1)
    }
    const elapsed = elapsedRef.current

    for (const { mesh, speed, offsetX } of starLayersRef.current) {
      const step = speed * delta * 5

      for (let i = 0; i < STARS_PER_LAYER; i++) {
        mesh.getMatrixAt(i, scratchMatrix)
        scratchPosition.setFromMatrixPosition(scratchMatrix)

        scratchPosition.x += step
        scratchPosition.y += step

        // NOTE(review): positions are mesh-local but compared against camera
        // bounds; this assumes group/mesh transforms stay close to identity.
        if (scratchPosition.x > camera.right) scratchPosition.x = camera.left
        if (scratchPosition.x < camera.left) scratchPosition.x = camera.right
        if (scratchPosition.y > camera.top) scratchPosition.y = camera.bottom
        if (scratchPosition.y < camera.bottom) scratchPosition.y = camera.top

        scratchMatrix.setPosition(scratchPosition)
        mesh.setMatrixAt(i, scratchMatrix)
      }

      mesh.instanceMatrix.needsUpdate = true
      updateSpriteBand(mesh.geometry, speed, offsetX, elapsed)
    }
  })

  return <group ref={groupRef} />
})

// Start loading the GLB as soon as this module is imported, before any
// Starfield instance calls useGLTF for the same URL.
useGLTF.preload(sundialUrl)

export default Starfield
