import * as tf from '@tensorflow/tfjs';

import { generateDataset } from '../utils/transformers';
import { roundNumber } from '../utils/helpers';
import { makeRequest } from '../../../lib/infrastructure';

import { L2 } from '../utils/tensorflow';

/**
 * Builds a dataset from `modelFile` and runs it through the layers model
 * stored at `/model/model.json`.
 *
 * @param {Object} params
 * @param {*} params.modelFile - Forwarded untouched to `generateDataset`.
 * @returns {Promise<Array|undefined>} The prediction converted with
 *   `arraySync()`; an empty array when `generateDataset` yields a falsy
 *   value; `undefined` when any step throws (the error is logged).
 */
export const generatePrediction = async ({ modelFile }) => {
  try {
    const dataset = await generateDataset(modelFile);

    if (!dataset) {
      return [];
    }

    // L2 must be registered before loading so the serialized model can be
    // deserialized (presumably it is a custom class referenced by model.json).
    tf.serialization.registerClass(L2);

    let inputTensor;
    let model;
    let prediction;
    try {
      inputTensor = tf.tensor(dataset);
      model = await tf.loadLayersModel('/model/model.json');

      // Tensors expose `shape`, not `length`; the leading dimension is the
      // number of samples. Fall back to the library default batch size when
      // it is missing or zero, because `predict` requires a positive integer.
      const [sampleCount] = inputTensor.shape;
      const predictArgs = sampleCount > 0 ? { batchSize: sampleCount } : {};

      prediction = model.predict(inputTensor, predictArgs);
      return prediction.arraySync();
    } finally {
      // Tensor memory is not garbage collected; release everything created
      // by this call, including the weights of the freshly loaded model.
      tf.dispose([inputTensor, prediction]);
      if (model) {
        model.dispose();
      }
    }
  } catch (error) {
    console.log('error: ', error);
    return undefined;
  }
};

/**
 * Requests the stress models that belong to a user.
 *
 * @param {*} userId - Sent as the `userId` query-string parameter.
 * @returns {Promise<*>} Resolves with whatever `makeRequest` yields.
 */
export const getStressModels = async userId => {
  const query = new URLSearchParams({ userId }).toString();

  return makeRequest({ path: `/stressmodels?${query}` });
};

/**
 * Requests a single stress model by its identifier.
 *
 * @param {*} stressModelId - Interpolated into the request path.
 * @returns {Promise<*>} Resolves with whatever `makeRequest` yields.
 */
export const getStressModel = async stressModelId => {
  const path = `/stressmodels/${stressModelId}`;

  return makeRequest({ path });
};

/**
 * Sends a POST request that creates a stress model.
 *
 * @param {*} stressModelConfig - Passed through unchanged as the request body.
 * @returns {Promise<*>} Resolves with whatever `makeRequest` yields.
 */
export const createStressModel = async stressModelConfig => {
  const request = {
    path: '/stressmodels',
    method: 'POST',
    body: stressModelConfig,
  };

  return makeRequest(request);
};

/**
 * Sends a DELETE request for the stress model with the given identifier.
 *
 * @param {*} stressModelId - Interpolated into the request path.
 * @returns {Promise<*>} Resolves with whatever `makeRequest` yields.
 */
export const removeStressModel = async stressModelId => {
  const path = `/stressmodels/${stressModelId}`;

  return makeRequest({ path, method: 'DELETE' });
};

/**
 * Classifies a prediction by majority vote over its rounded entries.
 *
 * Entries for which `roundNumber` returns 1 count as stressed, entries for
 * which it returns 0 count as not stressed, and any other result is ignored.
 *
 * @param {Array} prediction - Values handed one by one to `roundNumber`.
 * @returns {'stressed'|'not-stressed'|'neutral'} `'neutral'` on a tie.
 */
export const calculateIfPredictionStressedOrNot = prediction => {
  // Counts the entries whose rounded value equals `target`.
  const countRoundedTo = target =>
    prediction.reduce(
      (total, value) => (roundNumber(value) === target ? total + 1 : total),
      0,
    );

  const stressedCount = countRoundedTo(1);
  const relaxedCount = countRoundedTo(0);

  if (stressedCount > relaxedCount) {
    return 'stressed';
  }

  if (stressedCount < relaxedCount) {
    return 'not-stressed';
  }

  return 'neutral';
};
