import { Sphere } from "@react-three/drei"
import { useFrame } from "@react-three/fiber"
import React, { useEffect, useMemo, useRef } from "react"
import {
	CatmullRomCurve3,
	Color,
	Group,
	HSL,
	MathUtils,
	Mesh,
	MeshBasicMaterial,
	Vector3,
} from "three"
import useMidiStore from "../stores/midi"
import { soundwave } from "../utils/soundwave"

/**
 * Eases `target` towards `to` one component at a time using
 * `MathUtils.damp`, which keeps the motion frame-rate independent.
 *
 * @param target - Vector that is updated in place.
 * @param to - Destination vector; only read.
 * @param step - Damping factor forwarded to `MathUtils.damp` (its `lambda`).
 * @param delta - Elapsed frame time forwarded to `MathUtils.damp`.
 * @returns The same `target` instance, for chaining.
 */
export function damp(
	target: Vector3,
	to: Vector3,
	step: number,
	delta: number
): Vector3 {
	const { x, y, z } = target
	return target.set(
		MathUtils.damp(x, to.x, step, delta),
		MathUtils.damp(y, to.y, step, delta),
		MathUtils.damp(z, to.z, step, delta)
	)
}

/** Number of sample points (and spheres) used to draw the waveform. */
const resolution = 100

/** Evenly spaced points on x in [0, 1), lying flat at y = z = 0. */
function createFlatLine(): Vector3[] {
	return Array.from(
		{ length: resolution },
		(_, i) => new Vector3(i / resolution, 0, 0)
	)
}

/**
 * Draws the current MIDI notes as a line of `resolution` small spheres whose
 * z offsets follow `soundwave(...)` for those notes (each note becomes a wave
 * with frequency `note - 30` and amplitude 0.5). Sphere positions, sphere
 * colour, and the group's position/scale are eased every frame with `damp`.
 *
 * @param index - If set, only the note at this position in the store's notes
 *   is drawn (a flat line if there is no such note); otherwise all notes are.
 * @param color - Target colour of the spheres while visible.
 * @param offset - z position of the group while visible (defaults to 0).
 * @param hidden - When true the group eases to x = -0.45, z = 0 at half
 *   scale, and the spheres fade towards black.
 * @param standByAnimation - Adds a small time-based sine wobble on z.
 */
const SoundwaveDisplay = ({
	index,
	color,
	offset,
	hidden,
	standByAnimation,
}: {
	index?: number
	color: Color
	offset?: number
	hidden: boolean
	standByAnimation?: boolean
}) => {
	// Select a string snapshot rather than a fresh array. Strings compare by
	// value, so the memo and the effect below only re-run when the notes
	// actually change. A new array on every call would invalidate them on every
	// render.
	const notesKey = useMidiStore((state) => Array.from(state.notes).join(","))

	const notes = useMemo(() => {
		const allNotes: number[] =
			notesKey === "" ? [] : notesKey.split(",").map(Number)
		// slice yields [] for an out-of-range (including negative) index
		// instead of [undefined].
		return index !== undefined ? allNotes.slice(index, index + 1) : allNotes
	}, [notesKey, index])

	// NOTE(review): this ref is not attached to any element in this component's
	// JSX, so the scale update in the effect below currently never runs.
	const displayCurve = useRef<Mesh>(null)

	const group = useRef<Group>(null)

	// Built once. The spheres are rendered from these stable objects, so a
	// re-render never re-applies `position` and snaps spheres past the damping.
	const initialPoints = useMemo(createFlatLine, [])

	// Current per-sphere target positions; replaced whenever the notes change.
	const points = useRef<Vector3[]>(initialPoints)

	// Spline through the current target points. It is not otherwise read in
	// this component.
	const spline = useMemo(
		() => new CatmullRomCurve3(initialPoints, false, "catmullrom", 0.5),
		[initialPoints]
	)

	// Reusable scratch objects so the per-frame loop does not allocate.
	const scratch = useMemo(() => {
		const currentHsl: HSL = { h: 0, s: 0, l: 0 }
		const targetHsl: HSL = { h: 0, s: 0, l: 0 }
		return {
			target: new Vector3(),
			hslFrom: new Vector3(),
			hslTo: new Vector3(),
			currentHsl,
			targetHsl,
		}
	}, [])

	// Recompute the target points whenever the selected notes change.
	useEffect(() => {
		const waves = notes.map((note) => ({
			frequency: note - 30,
			amplitude: 0.5,
		}))

		const nextPoints = Array.from({ length: resolution }, (_, i) => {
			const t = i / resolution
			const z = waves.length > 0 ? soundwave(waves, t) * 0.1 : 0
			return new Vector3(t, 0, z)
		})

		if (displayCurve.current) {
			let minZ = Infinity
			let maxZ = -Infinity
			for (const { z } of nextPoints) {
				if (z < minZ) minZ = z
				if (z > maxZ) maxZ = z
			}
			const span = maxZ - minZ
			// A flat line (e.g. no notes) has zero span. Dividing by it would
			// produce an Infinity scale, so leave z unscaled in that case.
			displayCurve.current.scale.set(1, 1, span > 0 ? 10 / span / 2 : 1)
		}

		points.current = nextPoints
		spline.points = nextPoints
	}, [notes, spline])

	useFrame(({ clock }, delta) => {
		const groupObject = group.current
		if (!groupObject) return

		const { target, hslFrom, hslTo, currentHsl, targetHsl } = scratch

		// The colour target is the same for every sphere, so resolve it once.
		if (hidden) {
			targetHsl.h = 0
			targetHsl.s = 0
			targetHsl.l = 0
		} else {
			color.getHSL(targetHsl)
		}

		groupObject.children.forEach((child, i) => {
			const point = points.current[i]
			if (!point) return

			target.copy(point)
			if (standByAnimation) {
				target.z += Math.sin(clock.elapsedTime * 4 + i) * 0.01
			}
			damp(child.position, target, 8, delta)

			const material = (child as Mesh).material as MeshBasicMaterial
			material.color.getHSL(currentHsl)

			// Hue is circular: aim at whichever of toH, toH - 1, toH + 1 is
			// nearest so the fade takes the short way round the colour wheel.
			// Color.setHSL wraps the hue back into [0, 1).
			let toH = targetHsl.h
			if (toH - currentHsl.h > 0.5) toH -= 1
			else if (toH - currentHsl.h < -0.5) toH += 1

			hslFrom.set(currentHsl.h, currentHsl.s, currentHsl.l)
			hslTo.set(toH, targetHsl.s, targetHsl.l)
			damp(hslFrom, hslTo, 8, delta)
			material.color.setHSL(hslFrom.x, hslFrom.y, hslFrom.z)
		})

		damp(
			groupObject.position,
			target.set(hidden ? -0.45 : -0.9, 0, hidden ? 0 : offset ?? 0),
			8,
			delta
		)

		const groupScale = hidden ? 0.5 : 1
		damp(
			groupObject.scale,
			target.set(groupScale * 1.8, groupScale, groupScale),
			8,
			delta
		)
	})

	return (
		<group ref={group} position={[-0.9, 0, 0]}>
			{initialPoints.map((p, i) => (
				<Sphere key={i} scale={0.004} position={p}>
					<meshBasicMaterial />
				</Sphere>
			))}
		</group>
	)
}

export default SoundwaveDisplay
