import * as ort from 'onnxruntime-web';
import Tesseract from 'tesseract.js';

import cardJson from './card-tempale.json';

// OpenCV.js is expected to be exposed on `window`. `cv` starts out as that
// global value (awaited further down through its `.then`) and is reassigned
// to the resolved module once OpenCV.js reports it is ready.
const globalScope = window as any;
let { cv } = globalScope;

/**
 * Constants
 */

/** Longest side, in pixels, that query images are resized to before detection. */
const MAX_SIZE = 800;
/** Upper bound on the number of ORB keypoints kept per image (ORB `nfeatures`). */
const MAX_FEATURES = 1000;

/** Number of pyramid levels: ORB `nlevels` and the depth of the patch pyramid. */
const PYRAMID_LEVEL = 4;
/** Size ratio between consecutive pyramid levels (ORB `scaleFactor`). */
const SCALE_FACTOR = 1.2;

/** Ratio test: keep a match only if best distance < this × second-best distance. */
const KNN_MATCH_RATIO = 0.8;

/** Minimum number of ratio-test survivors needed before estimating a homography. */
const MIN_MATCH_COUNT = 10;

/** Side length, in pixels, of the square patch sampled around each keypoint. */
const kPatchSize = 32;

/**
 * Builds an image pyramid with PYRAMID_LEVEL levels (at least one).
 *
 * Level 0 is a copy of `inputImage`. Each further level is the previous one
 * resized by a factor of 1 / SCALE_FACTOR in both dimensions. Every returned
 * Mat is a fresh allocation that the caller must `delete()`.
 */
const computeImagePyramid = (/** @type {CvMat} */ inputImage: any) => {
  const levels = [inputImage.clone()];
  const shrink = 1 / SCALE_FACTOR;

  while (levels.length < PYRAMID_LEVEL) {
    const next = levels[levels.length - 1].clone();
    cv.resize(next, next, { width: 0, height: 0 }, shrink, shrink);
    levels.push(next);
  }

  return levels;
};

/**
 * Utilities
 */

/**
 * Samples a kPatchSize×kPatchSize patch around `feature`, rotated by the
 * feature's angle, from the pyramid level selected by its octave.
 *
 * `feature.pt` is multiplied by 1 / SCALE_FACTOR^octave to land on that level,
 * so it is assumed to be in full-resolution (level 0) coordinates. The
 * returned Mat is owned by the caller and must be `delete()`d.
 */
const extractPatch = (feature: any, imagePyramid: any) => {
  const { octave, angle, pt } = feature;
  const level = imagePyramid[octave];
  const toLevel = 1 / SCALE_FACTOR ** octave;
  const cx = pt.x * toLevel;
  const cy = pt.y * toLevel;
  const halfPatch = kPatchSize / 2;

  // Rotate about the keypoint, then shift the keypoint to the patch centre.
  const affine = cv.getRotationMatrix2D({ x: cx, y: cy }, angle, 1.0);
  affine.doublePtr(0, 2)[0] += halfPatch - cx;
  affine.doublePtr(1, 2)[0] += halfPatch - cy;

  const patch = new cv.Mat();
  const patchSize = { width: kPatchSize, height: kPatchSize };
  cv.warpAffine(level, patch, affine, patchSize);
  affine.delete();

  return patch;
};

/**
 * Packs one normalised kPatchSize×kPatchSize patch per keypoint into a single
 * Float32Array of length `count * kPatchSize²`, in keypoint order.
 *
 * Patches are cut with extractPatch from a temporary pyramid built on `image`.
 * Each pixel value v is mapped to v / 128 − 1. The pyramid and every
 * intermediate patch are freed before returning, even if an error is thrown.
 */
const buildPatchBatch = (image: any, keypoints: any, count: number) => {
  const patchLength = kPatchSize * kPatchSize;
  // Preallocated so patches are copied in place instead of spread-pushed.
  const batch = new Float32Array(count * patchLength);

  console.time('computeImagePyramid');
  const imagePyramid = computeImagePyramid(image);
  console.timeEnd('computeImagePyramid');

  try {
    console.time('extractPatch');
    for (let i = 0; i < count; i += 1) {
      const patch = extractPatch(keypoints.get(i), imagePyramid);
      try {
        patch.convertTo(patch, cv.CV_32FC1, 1.0 / 128.0, -1.0);
        batch.set(patch.data32F, i * patchLength);
      } finally {
        patch.delete();
      }
    }
    console.timeEnd('extractPatch');
  } finally {
    imagePyramid.forEach((level: any) => level.delete());
  }

  return batch;
};

/**
 * Detects keypoints in `img` and computes a KNIFT descriptor for each one.
 *
 * A working copy of `img` is scaled so its longest side is MAX_SIZE and
 * converted to grayscale. Keypoints are detected on that copy. One patch per
 * keypoint is sampled, and the patches are run through `ortSession` as a
 * single batch.
 *
 * @param img Source image. It is not modified, and the caller keeps ownership.
 * @param detector Detector exposing `detect(image, keypointVector)`.
 * @param ortSession ONNX Runtime session for the descriptor model.
 * @returns `'Not enough keypoints'` when fewer than 200 keypoints are found.
 *   Otherwise returns `{ keypoints, desc }`: the keypoints rescaled to `img`
 *   coordinates, and an N×40 CV_32F Mat with one row per keypoint. The caller
 *   owns both and must `delete()` them. Every other OpenCV object created here
 *   is freed on all exit paths.
 */
const computeDescriptors = async (
  /** @type {CvMat} */ img: any,
  detector: any,
  /** @type {InferenceSession} */ ortSession: any,
) => {
  console.time('pre-process');
  const image = img.clone();
  const keypoints = new cv.KeyPointVector();
  // Set once `keypoints` is handed to the caller. Until then this function
  // owns it and must free it on early return or exception.
  let keypointsReturned = false;

  try {
    // Normalise the working copy: longest side = MAX_SIZE, single channel.
    const scale = MAX_SIZE / Math.max(image.cols, image.rows);
    cv.resize(image, image, { width: 0, height: 0 }, scale, scale);
    cv.cvtColor(image, image, cv.COLOR_RGB2GRAY);
    console.timeEnd('pre-process');

    console.time('detectKeypoint');
    detector.detect(image, keypoints);
    console.timeEnd('detectKeypoint');

    const N = keypoints.size();
    console.log('keypoints', N);

    if (N < 200) {
      return 'Not enough keypoints';
    }

    const patches = buildPatchBatch(image, keypoints, N);

    // Keypoints were detected on the resized copy; map them back to `img`.
    for (let i = 0; i < N; i += 1) {
      const keypoint = keypoints.get(i);
      keypoint.pt.x /= scale;
      keypoint.pt.y /= scale;
      keypoints.set(i, keypoint);
    }

    const tensor = new ort.Tensor('float32', patches, [N, kPatchSize, kPatchSize, 1]);

    console.time('inference');
    const {
      'siamese_neural_congas_1/feature_compression/normalize_embeddings':
      descriptors,
    } = await ortSession.run({ rgb_to_grayscale_1: tensor }, [
      'siamese_neural_congas_1/feature_compression/normalize_embeddings',
    ]);
    console.timeEnd('inference');

    // One 40-float embedding per keypoint (same width as the template descriptors).
    const desc = cv.matFromArray(N, 40, cv.CV_32F, descriptors.data);

    keypointsReturned = true;
    return {
      keypoints,
      desc,
    };
  } finally {
    image.delete();
    if (!keypointsReturned) {
      keypoints.delete();
    }
  }
};

/**
 * Main function
 */

/**
 * Deletes every OpenCV object in `objects`, newest first, and empties the list.
 */
const deleteAll = (objects: any[]) => {
  while (objects.length > 0) {
    objects.pop().delete();
  }
};

/**
 * Debug aid for a computed homography `M` (template → query).
 *
 * Projects the template outline and the scratch rectangle into the query
 * image, logs the projected corners, draws both quadrilaterals with colour
 * (0, 255, 0, 255) on a copy of `img`, and logs that copy as a JPEG data URI
 * rendered through `canvas`. Neither `img` nor `M` is modified. Every
 * temporary Mat is freed, even if an error is thrown.
 */
const logHomographyDebug = (
  img: any,
  M: any,
  trainSize: any,
  [sT, sL, sB, sR]: number[],
  canvas: HTMLCanvasElement,
) => {
  console.time('debug');
  console.log('train -> query');

  const owned: any[] = [];
  const own = (obj: any) => {
    owned.push(obj);
    return obj;
  };

  try {
    const ptsArr = [
      [0, 0],
      [trainSize.width, 0],
      [trainSize.width, trainSize.height],
      [0, trainSize.height],
    ];
    const pts = own(cv.matFromArray(ptsArr.length, 1, cv.CV_32FC2, ptsArr.flat()));
    const ptsInQuery = own(cv.Mat.zeros(4, 1, cv.CV_32FC2));
    cv.perspectiveTransform(pts, ptsInQuery, M);

    for (let i = 0; i < ptsInQuery.rows; i += 1) {
      console.log(
        `(${pts.floatAt(i, 0)}, ${pts.floatAt(i, 1)}) -> (${ptsInQuery
          .floatAt(i, 0)
          .toFixed(2)}, ${ptsInQuery.floatAt(i, 1).toFixed(2)})`
      );
    }

    const scratchLoc = [
      [sL, sT],
      [sL, sB],
      [sR, sB],
      [sR, sT],
    ];
    const scratch = own(
      cv.matFromArray(scratchLoc.length, 1, cv.CV_32FC2, scratchLoc.flat())
    );
    const dstScratch = own(cv.Mat.zeros(4, 1, cv.CV_32FC2));
    cv.perspectiveTransform(scratch, dstScratch, M);

    for (let i = 0; i < dstScratch.rows; i += 1) {
      console.log(
        `(${scratch.floatAt(i, 0)}, ${scratch.floatAt(
          i,
          1
        )}) -> (${dstScratch.floatAt(i, 0)}, ${dstScratch.floatAt(i, 1)})`
      );
    }

    // polylines takes integer point sets, so convert both outlines to CV_32SC2.
    const dstVec = own(new cv.MatVector());
    const dstInt = own(new cv.Mat());
    ptsInQuery.convertTo(dstInt, cv.CV_32SC2);
    dstVec.push_back(dstInt);

    const dstScratchInt = own(new cv.Mat());
    dstScratch.convertTo(dstScratchInt, cv.CV_32SC2);
    dstVec.push_back(dstScratchInt);

    const debugImg = own(img.clone());
    cv.polylines(debugImg, dstVec, true, new cv.Scalar(0, 255, 0, 255), 2);

    cv.imshow(canvas, debugImg);
    console.log(canvas.toDataURL('image/jpeg'));
  } finally {
    deleteAll(owned);
  }

  console.timeEnd('debug');
};

/**
 * Locates the card template in `imageSource` and OCRs its scratch area.
 *
 * Pipeline:
 * 1. Compute descriptors for the query image.
 * 2. Match them against the template with kNN (k = 2) and a ratio test.
 * 3. Estimate a template→query homography with RANSAC.
 * 4. Warp the query back into template space and crop the scratch rectangle.
 * 5. Run Tesseract on the crop.
 * Intermediate results are logged to the console.
 *
 * @param imageSource Image to process; passed unchanged to `cv.imread`.
 * @param template Decoded template with these fields:
 *   - `size`: `{ width, height }`.
 *   - `loc`: `[top, left, bottom, right]` of the scratch area, in template pixels.
 *   - `kp`: flat `[x0, y0, x1, y1, …]` keypoint coordinates.
 *   - `desc`: descriptor Mat.
 * @param detector Keypoint detector forwarded to computeDescriptors.
 * @param matcher Descriptor matcher exposing `knnMatch`.
 * @param ortSession ONNX Runtime session forwarded to computeDescriptors.
 * @param ocr Tesseract worker.
 * @returns The trimmed OCR text. On failure, returns one of
 *   `'Not enough keypoints'`, `'Not enough matches'` or
 *   `'homography matrix empty!'`. Every OpenCV object allocated here is freed
 *   on all paths, including thrown errors.
 */
const ocrCard2 = async (
  imageSource: any,
  {
    size: trainSize, loc: [sT, sL, sB, sR], kp: trainKp, desc: trainDesc
  }: any,
  detector: any,
  matcher: any,
  /** @type {InferenceSession} */ ortSession: any,
  /** @type {Tesseract.Worker} */ ocr: any
) => {
  console.time('ocrCard');

  // Every OpenCV object allocated below is registered here and freed in
  // `finally`, so neither an early return nor an exception leaks WASM memory.
  const owned: any[] = [];
  const own = (obj: any) => {
    owned.push(obj);
    return obj;
  };

  try {
    console.time('readImage');
    const img = own(cv.imread(imageSource));
    console.timeEnd('readImage');

    const features: any = await computeDescriptors(img, detector, ortSession);
    if (typeof features === 'string') {
      // computeDescriptors reports failure (e.g. too few keypoints) as a string.
      return features;
    }
    const queryKp = own(features.keypoints);
    const queryDesc = own(features.desc);

    console.time('match');
    const cvMatches = own(new cv.DMatchVectorVector());
    matcher.knnMatch(queryDesc, trainDesc, cvMatches, 2);

    // Ratio test: keep the best neighbour only when it is clearly closer than
    // the runner-up.
    const goodMatches = own(new cv.DMatchVector());
    const matchStats: number[][] = [];
    for (let i = 0; i < cvMatches.size(); i += 1) {
      // get() hands back a new copy of the inner vector, which must be deleted.
      const candidates = cvMatches.get(i);
      try {
        if (candidates.size() >= 2) {
          const best = candidates.get(0);
          const second = candidates.get(1);
          matchStats.push([
            best.distance / second.distance,
            best.distance,
            second.distance,
          ]);
          if (best.distance < KNN_MATCH_RATIO * second.distance) {
            goodMatches.push_back(best);
          }
        }
      } finally {
        candidates.delete();
      }
    }
    console.log('matches', matchStats);
    console.timeEnd('match');

    console.log('goodMatches', goodMatches.size());
    console.table(
      Array.from({ length: goodMatches.size() }, (_, i) => goodMatches.get(i)),
      ['queryIdx', 'trainIdx', 'distance']
    );

    if (goodMatches.size() < MIN_MATCH_COUNT) {
      return 'Not enough matches';
    }

    // Pair each surviving match's query keypoint with its template keypoint.
    const queryPoints: number[] = [];
    const trainPoints: number[] = [];
    console.log('query -> train');
    for (let i = 0; i < goodMatches.size(); i += 1) {
      const { queryIdx, trainIdx } = goodMatches.get(i);
      const { x: qx, y: qy } = queryKp.get(queryIdx).pt;
      const tx = trainKp[trainIdx * 2];
      const ty = trainKp[trainIdx * 2 + 1];
      queryPoints.push(qx, qy);
      trainPoints.push(tx, ty);
      console.log(
        `(${qx.toFixed(2)}, ${qy.toFixed(2)}) -> (${tx.toFixed(2)}, ${ty.toFixed(2)})`
      );
    }

    const matchCount = goodMatches.size();
    const queryPointsMat = own(cv.matFromArray(matchCount, 2, cv.CV_32F, queryPoints));
    const trainPointsMat = own(cv.matFromArray(matchCount, 2, cv.CV_32F, trainPoints));

    console.time('findHomography');
    // M maps template (train) coordinates to query-image coordinates.
    const M = own(cv.findHomography(trainPointsMat, queryPointsMat, cv.RANSAC, 5));
    console.timeEnd('findHomography');

    if (M.empty()) {
      return 'homography matrix empty!';
    }

    const tmpCanvas = document.createElement('canvas');
    logHomographyDebug(img, M, trainSize, [sT, sL, sB, sR], tmpCanvas);

    console.time('extractROI');
    // M goes template -> query. WARP_INVERSE_MAP makes warpPerspective read
    // each template-space output pixel (x, y) from the query at M·(x, y).
    const scratchImg = own(new cv.Mat());
    cv.warpPerspective(
      img,
      scratchImg,
      M,
      trainSize,
      // eslint-disable-next-line no-bitwise
      cv.INTER_LINEAR | cv.WARP_INVERSE_MAP
    );

    // Crop the scratch area (template pixel coordinates).
    const roiMat = own(scratchImg.roi(new cv.Rect(sL, sT, sR - sL, sB - sT)));
    cv.imshow(tmpCanvas, roiMat);
    const imgUri = tmpCanvas.toDataURL('image/jpeg');
    console.timeEnd('extractROI');

    console.log(imgUri);

    console.time('ocr');
    const { data } = await ocr.recognize(imgUri);
    console.timeEnd('ocr');

    console.log('ocr', data);

    return data.text.trim();
  } finally {
    deleteAll(owned);
    console.timeEnd('ocrCard');
  }
};

/**
 * Dependencies
 */

/** Replaces the module-level `cv` binding with the value OpenCV.js resolves to. */
const onCvReady = (loadedCv: any) => {
  cv = loadedCv;
  console.log('cv loaded');
};

/** Settles once OpenCV.js has finished initialising and `cv` is usable. */
const waitForCv = cv.then(onCvReady);
/** Location the descriptor model is loaded from. */
const KNIFT_MODEL_URL = '/models/knift_float_1k.onnx';

/** ONNX Runtime session for the descriptor model, using the WebGL execution provider. */
const ortPromise = ort.InferenceSession.create(KNIFT_MODEL_URL, {
  executionProviders: ['webgl'],
});
/**
 * ORB detector created once OpenCV is ready. It uses the same SCALE_FACTOR and
 * PYRAMID_LEVEL as computeImagePyramid and is limited to MAX_FEATURES keypoints.
 */
const detectorPromise = waitForCv.then(function createDetector() {
  return new cv.ORB(MAX_FEATURES, SCALE_FACTOR, PYRAMID_LEVEL);
});

/** Brute-force descriptor matcher (OpenCV default settings), created once OpenCV is ready. */
const matcherPromise = waitForCv.then(function createMatcher() {
  return new cv.BFMatcher();
});
/** Number of float32 columns per stored template descriptor row. */
const TEMPLATE_DESCRIPTOR_SIZE = 40;

/**
 * Decodes a base64 string and reinterprets the raw bytes as float32 values in
 * platform byte order. Throws a RangeError if the decoded byte length is not
 * a multiple of 4.
 */
const base64ToFloat32Array = (base64: string) =>
  new Float32Array(Uint8Array.from(atob(base64), (c) => c.charCodeAt(0)).buffer);

/**
 * Card template, decoded once OpenCV is ready:
 * - `size`, `loc`: passed through from the JSON unchanged.
 * - `desc`: base64 float32 descriptors, as a (rows × 40) CV_32F Mat. The row
 *   count is derived from the decoded data rather than assumed, so templates
 *   with any number of keypoints work.
 * - `kp`: base64 float32 keypoint coordinates, as a flat Float32Array.
 * Rejects if the descriptor data is not a whole number of rows.
 */
const templatePromise = waitForCv
  .then(() => cardJson)
  .then(({
    size, loc, kp, desc
  }: any) => {
    const descData = base64ToFloat32Array(desc);
    if (descData.length % TEMPLATE_DESCRIPTOR_SIZE !== 0) {
      throw new Error(
        `card template: ${descData.length} descriptor values is not a multiple of ${TEMPLATE_DESCRIPTOR_SIZE}`
      );
    }

    return {
      size,
      loc,
      desc: cv.matFromArray(
        descData.length / TEMPLATE_DESCRIPTOR_SIZE,
        TEMPLATE_DESCRIPTOR_SIZE,
        cv.CV_32F,
        descData
      ),
      kp: base64ToFloat32Array(kp),
    };
  });

/**
 * English Tesseract worker configured for the scratch-area crop: single-word
 * page segmentation, with recognition restricted to digits and upper-case
 * letters A–Z.
 */
const tesseractPromise = (async () => {
  const worker = await Tesseract.createWorker('eng');
  await worker.setParameters({
    tessedit_pageseg_mode: Tesseract.PSM.SINGLE_WORD,
    tessedit_char_whitelist: '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ',
  });
  return worker;
})();

/**
 * Public entry point for card OCR.
 *
 * Waits until every lazily initialised dependency is ready: the ONNX session,
 * the decoded card template, the ORB detector, the matcher and the Tesseract
 * worker. Then runs the OCR pipeline on `imageSource`.
 *
 * @param imageSource Image to process; forwarded unchanged to ocrCard2.
 * @returns The recognised scratch-area text, or a short failure message from ocrCard2.
 */
const ocrCard = async (imageSource: any) => {
  const dependencies = await Promise.all([
    ortPromise,
    templatePromise,
    detectorPromise,
    matcherPromise,
    tesseractPromise,
  ]);
  const [session, cardTemplate, orb, bfMatcher, tesseractWorker] = dependencies;

  return ocrCard2(imageSource, cardTemplate, orb, bfMatcher, session, tesseractWorker);
};

export default ocrCard;

// (async () => {
//   await Promise.all([
//     ortPromise,
//     templatePromise,
//     detectorPromise,
//     matcherPromise,
//     tesseractPromise,
//   ]);
//   console.log("Initialized");
// })();
