import { ShaderLib, ShaderMaterial, Sprite, Texture, UniformsUtils } from "three";

/**
 * A three.js `Sprite` that displays one tile of a sprite-sheet texture.
 *
 * The sheet is treated as a grid of `tilesHoriz` x `tilesVert` tiles. The
 * currently shown tile is selected by writing the `columnIndex` / `rowIndex`
 * uniforms of a custom `ShaderMaterial`; the texture itself is never modified.
 */
export class AnimatedSprite extends Sprite {
  private readonly tilesHoriz: number;
  private readonly tilesVert: number;
  private verticalOffset: number;
  private tiles: number;
  private _currentFrame: number = 0;

  /**
   * @param texture Sprite-sheet texture bound to the `map` uniform.
   * @param tilesHoriz Number of tile columns in the sheet.
   * @param tilesVert Number of tile rows in the sheet.
   * @param initialTile Frame index displayed right after construction.
   * @param verticalOffset Initial value of the `verticalOffset` uniform
   *   consumed by the fragment shader.
   */
  constructor(texture: Texture, tilesHoriz: number, tilesVert: number, initialTile: number, verticalOffset: number) {
    // Start from the stock sprite uniforms and layer the sheet-specific ones on top.
    const uniforms = {
      ...UniformsUtils.clone(ShaderLib.sprite.uniforms),
      map: { type: "t", value: texture },
      columnsCount: { type: "f", value: tilesHoriz },
      rowsCount: { type: "f", value: tilesVert },
      columnIndex: { type: "f", value: 0 },
      rowIndex: { type: "f", value: 0 },
      verticalOffset: { type: "f", value: verticalOffset },
    };

    const spriteMaterial = new ShaderMaterial({
      uniforms,
      vertexShader,
      fragmentShader,
      transparent: true,
    });

    // Workaround so the default raycasting works with this material: these two
    // properties are not declared on ShaderMaterial, hence the suppressions.
    // @ts-ignore
    spriteMaterial.sizeAttenuation = true;
    // @ts-ignore
    spriteMaterial.rotation = 0;

    super(spriteMaterial as any);

    this.tilesHoriz = tilesHoriz;
    this.tilesVert = tilesVert;
    this.verticalOffset = verticalOffset;
    this.tiles = tilesHoriz * tilesVert;
    this._currentFrame = initialTile;

    this.offsetTexture();
  }

  /** Shows the tile with the given frame index (no range check is performed). */
  public setFrame(frameID: number) {
    this._currentFrame = frameID;
    this.offsetTexture();
  }

  /** Stores the new vertical offset and forwards it to the shader uniform. */
  public setVerticalOffset(verticalOffset: number) {
    this.verticalOffset = verticalOffset;
    // @ts-ignore
    const uniforms = this.material.uniforms;
    uniforms.verticalOffset.value = verticalOffset;
  }

  /** Frame index most recently set through the constructor or `setFrame`. */
  get currentFrame(): number {
    return this._currentFrame;
  }

  /**
   * Pushes the column/row of the current frame into the material uniforms.
   * Frames are numbered row by row; the row index is flipped
   * (`tilesVert - row - 1`) before being written — presumably so that frame 0
   * is the top-left tile of the sheet; confirm against the texture orientation.
   */
  private offsetTexture() {
    if (!this.material) {
      return;
    }

    const column = this._currentFrame % this.tilesHoriz;
    const row = Math.floor(this._currentFrame / this.tilesHoriz);

    // @ts-ignore
    const uniforms = this.material.uniforms;
    uniforms.columnIndex.value = column;
    uniforms.rowIndex.value = this.tilesVert - row - 1;
  }
}

/**
 * Plays a contiguous range of frames (`frameStart`..`frameEnd`, inclusive) on
 * an `AnimatedSprite`, advancing one frame every `frameDisplayDuration`
 * milliseconds of accumulated `update()` time.
 *
 * Emits two events through `addEventListener`:
 *  - `"loop"`     — each time a looping animation wraps back to `frameStart`;
 *  - `"finished"` — when a non-looping animation reaches its end and pauses.
 * Listeners are called with a single `{ type }` object.
 */
export class SpriteAnimation {
  private mustLoop: boolean = false;
  private paused: boolean = true;
  private clampWhenFinished: boolean = true;
  private readonly frameStart: number;
  private readonly frameEnd: number;
  private readonly frameDisplayDuration: number;
  private sprite: AnimatedSprite;
  private currentDisplayTime: number = 0;
  private eventListeners: Map<string, Function[]> = new Map();

  /**
   * @param sprite Sprite whose frame is driven by this animation.
   * @param frameStart First frame index of the animation.
   * @param frameEnd Last frame index of the animation (inclusive).
   * @param frameDisplayDuration Time each frame stays on screen, in milliseconds.
   */
  constructor(sprite: AnimatedSprite, frameStart: number, frameEnd: number, frameDisplayDuration: number) {
    this.sprite = sprite;
    this.frameStart = frameStart;
    this.frameEnd = frameEnd;
    this.frameDisplayDuration = frameDisplayDuration;
  }

  /** Registers `listener` for events of the given `type`. */
  addEventListener(type: string, listener: Function) {
    let bucket = this.eventListeners.get(type);
    if (!bucket) {
      bucket = [];
      this.eventListeners.set(type, bucket);
    }
    bucket.push(listener);
  }

  /**
   * Unregisters `listener` from events of the given `type`; other listeners of
   * that type stay registered. Unknown types/listeners are ignored.
   */
  removeEventListener(type: string, listener: Function) {
    const bucket = this.eventListeners.get(type);
    if (bucket) {
      // Build a new array (instead of splicing) so that removing a listener
      // while an event is being dispatched does not disturb the iteration.
      this.eventListeners.set(
        type,
        bucket.filter((l) => l !== listener)
      );
    }
  }

  /** Shows the sprite and plays the animation once from `frameStart`. */
  playOnce() {
    this.mustLoop = false;
    this.paused = false;
    this.sprite.visible = true;
    this.sprite.setFrame(this.frameStart);
  }

  /**
   * Un-pauses the animation and shows the sprite. When the current frame lies
   * strictly between `frameStart` and `frameEnd`, it is first rewound to
   * `frameStart`.
   */
  resume() {
    // NOTE(review): rewinding when the frame is *inside* the range looks inverted
    // for a "resume" (one would expect a rewind only when the frame is outside
    // the range). Behaviour kept as-is — confirm the intent with callers.
    if (this.sprite.currentFrame > this.frameStart && this.sprite.currentFrame < this.frameEnd) {
      this.sprite.setFrame(this.frameStart);
    }
    this.paused = false;
    this.sprite.visible = true;
  }

  /** Shows the sprite and plays the animation from `frameStart`, looping forever. */
  playLoop() {
    this.mustLoop = true;
    this.paused = false;
    this.sprite.visible = true;
    this.sprite.setFrame(this.frameStart);
  }

  /** Lets the current cycle run to its end and then stop (disables looping). */
  pauseNextEnd() {
    this.mustLoop = false;
  }

  /** Freezes the animation on the current frame. */
  pause() {
    this.paused = true;
  }

  /** Freezes the animation and rewinds the sprite to `frameStart`. */
  stop() {
    this.paused = true;
    this.sprite.setFrame(this.frameStart);
  }

  /** Enables or disables looping without touching the play/pause state. */
  setLoop(loop: boolean) {
    this.mustLoop = loop;
  }

  /**
   * Advances the animation by `delta` seconds (converted to milliseconds
   * internally). Does nothing while paused.
   *
   * Several frames may be consumed in one call when `delta` exceeds the frame
   * duration. Once the animation pauses (because it finished, or because a
   * listener paused it) no further frames are consumed in that call, so
   * `"finished"` fires at most once per run.
   */
  update(delta: number) {
    if (this.paused) {
      return;
    }

    // A zero or negative frame duration would make the catch-up loop below
    // spin forever; treat such an animation as static.
    if (!(this.frameDisplayDuration > 0)) {
      return;
    }

    this.currentDisplayTime += delta * 1000;

    while (!this.paused && this.currentDisplayTime > this.frameDisplayDuration) {
      this.currentDisplayTime -= this.frameDisplayDuration;
      this.sprite.setFrame(this.sprite.currentFrame + 1);

      if (this.sprite.currentFrame <= this.frameEnd) {
        continue;
      }

      // Ran past the last frame: rewind, then either loop or finish.
      this.sprite.setFrame(this.frameStart);

      if (this.mustLoop) {
        this.eventListeners.get("loop")?.forEach((listener) => {
          listener({
            type: "loop",
          });
        });
      } else {
        this.paused = true;
        // Only the clamped mode notifies listeners; the non-clamped mode simply
        // pauses (a delayed notification was sketched here but never enabled).
        if (this.clampWhenFinished) {
          this.callFinishedListeners();
        }
      }
    }

    if (this.paused) {
      // Drop whole-frame surplus left over from a large delta so that the next
      // play does not immediately fast-forward through several frames.
      this.currentDisplayTime %= this.frameDisplayDuration;
    }
  }

  /** Notifies every `"finished"` listener. */
  private callFinishedListeners() {
    this.eventListeners.get("finished")?.forEach((listener) => {
      listener({
        type: "finished",
      });
    });
  }
}

// GLSL vertex shader used by AnimatedSprite's ShaderMaterial.
//
// UV selection: the quad's `uv` is scaled to the size of one tile
// (1/columnsCount, 1/rowsCount) and shifted by (columnIndex, rowIndex) tiles,
// so `vUV` addresses the selected tile in whole-texture coordinates.
//
// Positioning: the quad is placed at the model-view transformed origin, then its
// xy corner offset — scaled by the lengths of the first two model-matrix columns
// and rotated by `rotation` around `center` — is added in view space, which keeps
// the quad facing the camera.
// NOTE(review): this looks modelled on three.js's built-in sprite vertex shader;
// verify against the three.js version in use before changing the includes.
// NOTE(review): the `uvOffset` attribute is declared but not read in main().
const vertexShader = `
uniform float columnsCount;
uniform float rowsCount;
uniform float columnIndex;
uniform float rowIndex;
uniform float rotation;
uniform vec2 center;
#include <common>
#include <uv_pars_vertex>
#include <fog_pars_vertex>
#include <logdepthbuf_pars_vertex>
#include <clipping_planes_pars_vertex>

attribute vec2 uvOffset;
varying vec2 vUV;

void main() {
  float partWidth = 1.0 / columnsCount;
  float partHeight = 1.0 / rowsCount;

  vec2 offset = vec2(partWidth * columnIndex, partHeight * rowIndex);
  vUV = uv * vec2(partWidth, partHeight) + offset;

  vec4 mvPosition = modelViewMatrix * vec4( 0.0, 0.0, 0.0, 1.0 );
  vec2 scale;
  scale.x = length( vec3( modelMatrix[ 0 ].x, modelMatrix[ 0 ].y, modelMatrix[ 0 ].z ) );
  scale.y = length( vec3( modelMatrix[ 1 ].x, modelMatrix[ 1 ].y, modelMatrix[ 1 ].z ) );

  vec2 alignedPosition = ( position.xy - ( center - vec2( 0.5 ) ) ) * scale;
  vec2 rotatedPosition;
  rotatedPosition.x = cos( rotation ) * alignedPosition.x - sin( rotation ) * alignedPosition.y;
  rotatedPosition.y = sin( rotation ) * alignedPosition.x + cos( rotation ) * alignedPosition.y;
  mvPosition.xy += rotatedPosition;

  gl_Position = projectionMatrix * mvPosition;

  #include <logdepthbuf_vertex>
  #include <clipping_planes_vertex>
  #include <fog_vertex>
}                
`;

// GLSL fragment shader used by AnimatedSprite's ShaderMaterial.
//
// Samples `map` at `vUV` and uses `verticalOffset` as a horizontal cut line:
//  - verticalOffset >= 0: fragments with vUV.y >  verticalOffset are textured,
//                         the rest are fully transparent;
//  - verticalOffset <  0: fragments with vUV.y > -verticalOffset are fully
//                         transparent, the rest are textured.
// Note that `vUV` is the whole-texture coordinate produced by the vertex shader
// (tile offset already applied), so the threshold is compared in sheet space,
// not relative to the current tile.
// NOTE(review): `encodings_fragment` is a three.js shader chunk name that may
// differ between three.js releases — confirm it exists in the version in use.
const fragmentShader = `
varying vec2 vUV;
uniform sampler2D map;
uniform float verticalOffset;

void main() {
  if (verticalOffset >= 0.0) {
    if (vUV.y > verticalOffset) {
      gl_FragColor = texture2D(map, vUV);
    } else {
      gl_FragColor = vec4(0.0, 0.0, 0.0, 0.0);
    }
  } else {
    if (vUV.y > (verticalOffset * -1.0)) {
      gl_FragColor = vec4(0.0, 0.0, 0.0, 0.0);
    } else {
      gl_FragColor = texture2D(map, vUV);
    }
  }
  #include <encodings_fragment>
}
`;
