import type { Results } from "@mediapipe/selfie_segmentation";
import { SelfieSegmentation } from "@mediapipe/selfie_segmentation";

/** Image sources accepted as input by `SelfieSegmentationService.run`. */
export type InputType = HTMLImageElement | HTMLVideoElement | HTMLCanvasElement;

/**
 * Thin wrapper around MediaPipe's `SelfieSegmentation` solution that removes
 * the background of an image and returns the cut-out as a data URL.
 *
 * The model assets are loaded from the jsDelivr CDN.
 */
export class SelfieSegmentationService {
  private readonly model: SelfieSegmentation;
  /** Model index last pushed to the underlying solution (see `modelSelection`). */
  public selectedModel = 0;
  /** Selfie-mode flag last pushed to the underlying solution (see `isSelfie`). */
  public selfieMode = false;

  public constructor() {
    this.model = new SelfieSegmentation({
      locateFile: (file: string) =>
        `https://cdn.jsdelivr.net/npm/@mediapipe/selfie_segmentation/${file}`,
    });
    this.setOptions();
  }

  /** Set if this is a selfie
   * @param value If `true` the result will be flipped horizontally
   */
  public set isSelfie(value: boolean) {
    this.selfieMode = value;
    this.setOptions();
  }

  /** Select the model used for segmentation
   * @param value Use 0 for the `General` model or 1 for the `Landscape` model. The `Landscape` model runs faster,
   * but the `General` model is more accurate, giving better segmentation result overall. A new service starts
   * with 0; any value other than exactly 0 or 1 (possible from untyped callers) falls back to 1.
   */
  public set modelSelection(value: 0 | 1) {
    // Strict equality also rejects NaN, fractions and numeric strings, which a
    // plain range comparison would let through to the model options.
    this.selectedModel = value === 0 || value === 1 ? value : 1;
    this.setOptions();
  }

  /** Push the current `selfieMode` and `selectedModel` values to the underlying solution. */
  public setOptions(): void {
    this.model.setOptions({
      selfieMode: this.selfieMode,
      modelSelection: this.selectedModel,
    });
  }

  /** Run the segmentation on the detected human body and remove the background.
   *
   * Each call registers a fresh results callback on the shared model instance, so
   * overlapping calls can interfere with each other; await one call before starting the next.
   * NOTE(review): if `send()` completes without ever delivering results, the returned
   * promise stays pending — confirm this cannot happen with the library version in use.
   *
   * @param input Either `HTMLImageElement`, `HTMLCanvasElement`, or `HTMLVideoElement` is accepted
   * @returns A promise resolving to a data URL of the input with the background removed.
   * The promise rejects (after logging the error) if sending the frame fails or the
   * result cannot be rendered, e.g. when no 2D canvas context is available.
   */
  public run(
    input: InputType,
  ): Promise<string> {
    return new Promise<string>((resolve, reject) => {
      const fail = (err: unknown): void => {
        console.error(err);
        reject(err);
      };

      this.model.onResults((results: Results): void => {
        // Errors raised here must be routed to the promise explicitly: the
        // callback is invoked by the library, not by this executor.
        try {
          resolve(SelfieSegmentationService.renderCutout(results));
        } catch (err) {
          fail(err);
        }
      });

      this.model.send({ image: input }).catch(fail);
    });
  }

  /** Composite the segmentation mask over the source image on a throwaway canvas and encode it as a data URL. */
  private static renderCutout(results: Results): string {
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    if (ctx == null) throw new Error('No canvas context found');

    const { width, height } = results.image;
    canvas.width = width;
    canvas.height = height;

    ctx.drawImage(results.image, 0, 0, width, height);
    // 'destination-atop' keeps the already-drawn image only where the mask
    // overlaps it, which leaves the background transparent.
    ctx.globalCompositeOperation = 'destination-atop';
    ctx.drawImage(results.segmentationMask, 0, 0, width, height);

    return canvas.toDataURL();
  }
}
