import { gql } from '@apollo/client';
import _ from 'lodash';
import { useCallback, useEffect, useMemo, useState } from 'react';
import { useNavigate, useSearchParams } from 'react-router-dom';

import ComponentLoading from '../../../../../../components/common/ComponentLoading';
import CoveModal from '../../../../components/CoveModal';
import TrainingPipelineLoading from '../../components/TrainingPipelineLoading';

import {
  GQLRuleStatus,
  useGQLCreateCoveModelSignalCreationSamplingJobMutation,
  useGQLCreateRulesForModelThresholdsMutation,
  useGQLGetCoveModelSignalCreationSessionResultLazyQuery,
  useGQLGetSamplingJobResultsLazyQuery,
  useGQLPoliciesQuery,
  useGQLTrainAndSaveModelMutation,
} from '../../../../../../graphql/generated';
import { safePick } from '../../../../../../utils/misc';
import { JsonOf, jsonParse } from '../../../../../../utils/typescript-types';
import { ModalInfo } from '../../../../types/ModalInfo';
import TrainingPipelineFooter from '../../TrainingPipelineFooter';
import { useGoToScreen, type TrainingPipelineEnabledScreen } from '../types';
import TrainingPipelineSetThresholdsDragAndDropComponent, {
  Threshold,
} from './TrainingPipelineSetThresholdsDragAndDropComponent';
import TrainingPipelineSetThresholdsExplanationComponent from './TrainingPipelineSetThresholdsExplanationComponent';

// GraphQL operations used by this screen. The value of this tagged template is
// not assigned to anything; presumably it exists so GraphQL codegen emits the
// matching `useGQL*` hooks imported from graphql/generated (TODO confirm).
//  - TrainAndSaveModel: calls `finishCoveModelSignalCreationSession` for the
//    given session.
//  - CreateRulesForModelThresholds: calls `createRulesForModel` and returns the
//    ids of the created rules.
//  - GetCoveModelSignalCreationSessionResult: returns a union that is either a
//    success (with the created signal's id), still pending, or an error (with a
//    title).
// NOTE(review): the Policies, GetSamplingJobResults and
// CreateCoveModelSignalCreationSamplingJob operations whose hooks are used below
// are not declared in this template.
gql`
  mutation TrainAndSaveModel($sessionId: CoveModelSignalCreationSessionId!) {
    finishCoveModelSignalCreationSession(sessionId: $sessionId)
  }

  mutation CreateRulesForModelThresholds(
    $input: CreateRulesForCoveModelInput!
  ) {
    createRulesForModel(input: $input) {
      rules {
        id
      }
    }
  }

  query GetCoveModelSignalCreationSessionResult(
    $sessionId: CoveModelSignalCreationSessionId!
  ) {
    getCoveModelSignalCreationSessionResult(sessionId: $sessionId) {
      ... on CoveModelSignalCreatedSuccess {
        signal {
          id
        }
      }
      ... on CoveModelSignalCreatedPending {
        _
      }
      ... on CoveModelSignalCreationError {
        title
      }
    }
  }
`;

/**
 * Screen where the user sets score thresholds for a freshly fine-tuned model.
 *
 * This component does the following:
 * 1. First, we train and save a new model using the labels submitted on the previous screen
 *    (these labels were already persisted to the DB in the previous screen so we don't
 *    need them here). Training is started with `finishCoveModelSignalCreationSession`,
 *    and the session result is then polled every 5s until the model signal is either
 *    created or reported as failed.
 * 2. When the model is deployed and ready for inference, we create a sampling job with
 *    the MODEL_SCORE sampling strategy (100 samples over scores 0–1). The job runs content
 *    through the model and returns that content - along with its model score - so it
 *    can be displayed by score.
 * 3. The user then selects what thresholds they want to set based on those scores.
 * 4. If the user saves the thresholds, we turn them into rules with the chosen status
 *    (LIVE via "Deploy", DRAFT, or BACKGROUND).
 *
 * Reads `policyId`, `sessionId` and `jobId` from the URL search params. `jobId` is only
 * forwarded when navigating back.
 *
 * @throws if `policyId` is missing, or if the org's policies have loaded and none
 *   matches `policyId`.
 */
export default function TrainingPipelineSetThresholdsScreen(props: {
  nextScreen: TrainingPipelineEnabledScreen;
  previousScreen: TrainingPipelineEnabledScreen;
}) {
  const { nextScreen, previousScreen } = props;
  const goToScreen = useGoToScreen();
  const [searchParams] = useSearchParams();
  const policyId = searchParams.get('policyId') ?? '';
  const sessionId = searchParams.get('sessionId') ?? '';
  const jobId = searchParams.get('jobId') ?? '';

  const navigate = useNavigate();

  const [thresholds, setThresholds] = useState<Threshold[]>([]);
  const [modelSignalId, setModelSignalId] = useState<
    { id: string; version: number } | undefined
  >(undefined);
  const [modalInfo, setModalInfo] = useState<ModalInfo>({
    visible: false,
    title: '',
    body: '',
    okText: '',
    onOk: () => {},
    okIsDangerButton: false,
    cancelVisible: false,
  });

  const { data, loading: policiesLoading, error } = useGQLPoliciesQuery();

  if (!policyId) {
    throw new Error('No policy ID provided');
  }

  const policy = data?.myOrg?.policies.find((it) => it.id === policyId);
  if (data?.myOrg && !policy) {
    throw new Error(`Could not find policy with id ${policyId}`);
  }
  const policyName = policy?.name;

  // Use the functional updater. Modal onOk callbacks are created in earlier
  // renders, and spreading their captured `modalInfo` would restore stale
  // modal content.
  const hideModal = () =>
    setModalInfo((prev) => ({ ...prev, visible: false }));

  /**
   * Shows the "training failed" modal. Its "Try Again" button re-runs the
   * train-and-save mutation for the current session and then closes the modal.
   * Declared as a function so it is hoisted above the hooks that reference it.
   */
  function showTrainingFailedModal() {
    setModalInfo({
      visible: true,
      title: 'Failed to save fine-tuned model',
      body: `We encountered an error while trying to save your fine-tuned ${policyName} AI model. Would you like to try again?`,
      okText: 'Try Again',
      onOk: () => {
        trainAndSaveModel({
          variables: {
            sessionId,
          },
        });
        hideModal();
      },
      okIsDangerButton: false,
      cancelVisible: false,
    });
  }

  /**
   * Shows the "backtest failed" modal. Its "Try Again" button creates a new
   * sampling job and then closes the modal.
   */
  function showBacktestFailedModal() {
    setModalInfo({
      visible: true,
      title: 'Failed to run backtest',
      body: `We encountered an error while trying to run a backtest on your ${policyName} AI model. Would you like to try again?`,
      okText: 'Try Again',
      onOk: () => {
        createSamplingJob(getSamplingJobVariables());
        hideModal();
      },
      okIsDangerButton: false,
      cancelVisible: false,
    });
  }

  /* Queries and mutations for training and deploying the model */
  const [
    getSessionResult,
    { data: sessionResultData, stopPolling: stopPollingSessionResult },
  ] = useGQLGetCoveModelSignalCreationSessionResultLazyQuery({
    pollInterval: 5000,
  });
  const [
    trainAndSaveModel,
    {
      loading: trainAndSaveModelLoading,
      error: trainAndSaveModelError,
      data: trainAndSaveModelData,
    },
  ] = useGQLTrainAndSaveModelMutation({
    // Once training has been kicked off, start polling for its result.
    onCompleted: async () =>
      getSessionResult({
        variables: {
          sessionId,
        },
      }),
    onError: () => showTrainingFailedModal(),
  });

  // Train and save the model as soon as the screen mounts. Skip this if the
  // mutation is already in flight, has finished, or has failed. Retries after a
  // failure go through the modal instead.
  useEffect(() => {
    if (
      trainAndSaveModelError ||
      trainAndSaveModelLoading ||
      trainAndSaveModelData
    ) {
      return;
    }

    trainAndSaveModel({
      variables: {
        sessionId,
      },
    });
  }, [
    sessionId,
    trainAndSaveModel,
    trainAndSaveModelData,
    trainAndSaveModelError,
    trainAndSaveModelLoading,
  ]);

  // See if the model has finished training and deploying, and if it has, save
  // the deployed model ID and version and stop polling for the training status.
  useEffect(() => {
    const sessionResult =
      sessionResultData?.getCoveModelSignalCreationSessionResult;
    if (!sessionResult) {
      return;
    }

    switch (sessionResult.__typename) {
      case 'CoveModelSignalCreatedPending':
        // Keep polling.
        break;
      case 'CoveModelSignalCreatedSuccess': {
        stopPollingSessionResult();
        // The signal id is a JSON-encoded object; only its id and version are
        // needed to create rules later.
        const signalId = jsonParse(
          sessionResult.signal.id as JsonOf<{
            type: 'COVE_MODEL';
            id: string;
            version: number;
          }>,
        );
        setModelSignalId(safePick(signalId, ['id', 'version']));
        break;
      }
      case 'CoveModelSignalCreationError':
        stopPollingSessionResult();
        showTrainingFailedModal();
        break;
    }
    // showTrainingFailedModal is recreated every render. Omitting it is safe
    // because it only reads policyName, sessionId and trainAndSaveModel (all
    // listed below) plus the stable state setter.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [
    policyName,
    sessionId,
    sessionResultData?.getCoveModelSignalCreationSessionResult,
    stopPollingSessionResult,
    trainAndSaveModel,
  ]);

  // Request 100 samples spanning the full [0, 1] model-score range.
  const getSamplingJobVariables = useCallback(() => {
    return {
      variables: {
        input: {
          sessionId,
          numSamples: 100,
          samplingStrategy: {
            modelScoreSamplingStrategy: {
              score: {
                min: 0,
                max: 1,
              },
            },
          },
        },
      },
    };
  }, [sessionId]);

  /* Queries and mutations for sampling scores */
  const [createSamplingJob, { data: samplingJobCreatedData }] =
    useGQLCreateCoveModelSignalCreationSamplingJobMutation({
      onError: () => showBacktestFailedModal(),
    });

  const [createRulesForModel, { loading: createRulesForModelLoading }] =
    useGQLCreateRulesForModelThresholdsMutation();

  const initialSamplingJobId =
    samplingJobCreatedData?.createCoveModelSignalCreationSamplingJob.jobId;
  const [
    getSamplingJobResults,
    {
      data: samplingJobResultsData,
      stopPolling: stopPollingSamplingJobResults,
    },
  ] = useGQLGetSamplingJobResultsLazyQuery({
    variables: {
      // Only executed once initialSamplingJobId is set (see effect below).
      jobId: initialSamplingJobId!,
    },
    pollInterval: 5000,
    onCompleted: (response) => {
      const resultType = response.getSamplingJobResults.__typename;
      if (
        resultType === 'SamplingJobFailure' ||
        resultType === 'SamplingJobNotFoundError'
      ) {
        showBacktestFailedModal();
        stopPollingSamplingJobResults();
      }
    },
  });

  // Start sampling scores once the model has been deployed and we've
  // saved the model ID
  useEffect(() => {
    if (modelSignalId) {
      createSamplingJob(getSamplingJobVariables());
    }
  }, [createSamplingJob, getSamplingJobVariables, modelSignalId]);

  // Once the initial sampling job has been created, poll the server for the results
  useEffect(() => {
    if (initialSamplingJobId) {
      getSamplingJobResults();
    }
  }, [getSamplingJobResults, initialSamplingJobId]);

  // undefined while the job is still running (or not started). An empty list
  // on failure/not-found, because those cases are surfaced via the modal in
  // onCompleted above. Otherwise the job's samples.
  const allSamples = useMemo(() => {
    if (
      !samplingJobResultsData ||
      samplingJobResultsData.getSamplingJobResults.__typename ===
        'SamplingJobPending'
    ) {
      return undefined;
    } else if (
      samplingJobResultsData.getSamplingJobResults.__typename !==
      'SamplingJobSuccess'
    ) {
      return [];
    }
    return samplingJobResultsData.getSamplingJobResults.samples;
  }, [samplingJobResultsData]);

  // Phase 1: the model is still training or deploying.
  const isTrainingModel =
    trainAndSaveModelLoading ||
    !sessionResultData ||
    sessionResultData.getCoveModelSignalCreationSessionResult.__typename ===
      'CoveModelSignalCreatedPending';
  // Phase 2: the model is ready, but the sampling results are not in yet.
  const isSamplingScores = !initialSamplingJobId || !allSamples;
  const loading = isTrainingModel || isSamplingScores;

  let loadingTitle = '';
  let loadingSubtitle = '';
  if (isTrainingModel) {
    loadingTitle = `Fine-tuning your ${policyName} model`;
    loadingSubtitle =
      "We're teaching the model to make the exact same decisions you would make.";
  } else if (isSamplingScores) {
    loadingTitle = `Testing your ${policyName} model on real content`;
  }

  if (policiesLoading) {
    return <ComponentLoading />;
  }

  if (error) {
    throw error;
  }

  const modal = (
    <CoveModal
      title={modalInfo.title}
      visible={modalInfo.visible}
      onClose={hideModal}
      footer={[
        ...(modalInfo.cancelText
          ? [
              {
                title: modalInfo.cancelText,
                onClick: modalInfo.onCancel,
                type: 'secondary' as const,
              },
            ]
          : []),
        {
          title: modalInfo.okText,
          onClick: modalInfo.onOk,
          type: 'primary',
        },
      ]}
      hideCloseButton
    >
      <div className="pt-2">{modalInfo.body}</div>
    </CoveModal>
  );

  /**
   * Turns the user's thresholds into rules with the given status, then moves
   * on to the next screen. On failure (or if the model id is not known yet),
   * an informational modal is shown instead.
   */
  const createRules = (status: GQLRuleStatus) => {
    if (!modelSignalId) {
      setModalInfo({
        visible: true,
        title: 'Failed to load model identifier',
        body: `We encountered an error while trying to load your model. Please reload the page and try again.`,
        okText: 'OK',
        onOk: () => {
          hideModal();
        },
        okIsDangerButton: false,
        cancelVisible: false,
      });
      return;
    }

    // Sort the thresholds in descending order of score. Each threshold's upper
    // bound is the previous (higher) threshold's score, or 1.0 for the first
    // one. The upper bound must therefore come from this sorted list, not from
    // the user's unsorted `thresholds`.
    const sortedThresholds = _.sortBy(thresholds, 'score').reverse();

    createRulesForModel({
      variables: {
        input: {
          model: modelSignalId,
          status,
          thresholds: sortedThresholds.map(({ score, action }, idx) => ({
            upperThreshold:
              idx === 0
                ? 1.0
                : Number(sortedThresholds[idx - 1].score.toFixed(3)),
            lowerThreshold: Number(score.toFixed(3)),
            action,
          })),
        },
      },
      onCompleted: () => goToScreen(nextScreen, policyId, sessionId),
      onError: () =>
        setModalInfo({
          visible: true,
          title: 'Failed to create rules for model',
          body: `Please try again.`,
          okText: 'OK',
          onOk: () => {
            hideModal();
          },
          okIsDangerButton: false,
          cancelVisible: false,
        }),
    });
  };

  return (
    <div className="flex flex-col items-center h-full text-start">
      {loading ? (
        <TrainingPipelineLoading
          title={loadingTitle}
          subtitle={loadingSubtitle}
          loading={loading}
        />
      ) : (
        <div className="flex flex-col items-center m-8">
          <div className="w-3/5">
            <div className="text-2xl font-semibold">
              Edit thresholds for your {policyName} policy
            </div>
            <br />
            <div className="text-slate-500">
              Your AI model can make automated decisions, but it needs to know
              when you're comfortable automatically deleting content, and when
              you'd prefer a human to review it first. <br />
              <br />
              You can set those thresholds below, but first we need one piece of
              information to help you make those decisions.
            </div>
            <br />
            {/* We can comment this back in when we know how to accurately estimate
              the amount of content that would fall in each bucket */}
            {/* <div className="mt-4 mb-4 text-base font-semibold">
              How much content is created on your platform every day?
            </div>
            <div className="flex flex-row items-center justify-start p-4 mb-10 border border-solid rounded-md border-slate-200 w-fit">
              <InputNumber
                className="h-auto text-base border-none rounded-full shadow-none outline-none cursor-text active:shadow-none focus:shadow-none text-slate-500 placeholder-slate-300"
                style={{
                  width: (totalContent?.toString().length ?? 0 + 1) * 15,
                  minWidth: '40px',
                }}
                placeholder="0"
                min={0}
                formatter={(value) => Number(value).toLocaleString()}
                controls={false}
                value={totalContent}
                onChange={(value) => setTotalContent(Number(value))}
              />
              <div className="text-slate-500">
                total pieces of content / day
              </div>
            </div> */}
            <div className="mt-4 mb-4 text-lg font-semibold">
              Drag and drop your thresholds
            </div>
            <TrainingPipelineSetThresholdsExplanationComponent />
          </div>
          {allSamples ? (
            <div className="flex flex-col w-full py-8 pt-16">
              <TrainingPipelineSetThresholdsDragAndDropComponent
                entries={allSamples.map((it) => ({
                  ...it,
                  score: it.score!,
                }))}
                thresholds={thresholds}
                setThresholds={setThresholds}
              />
            </div>
          ) : (
            <ComponentLoading />
          )}
        </div>
      )}
      <TrainingPipelineFooter
        primaryButton={{
          title: 'Deploy',
          onClick: () => createRules('LIVE'),
          loading: createRulesForModelLoading,
          disabled: loading,
        }}
        secondaryButtons={[
          {
            title: 'Save as draft',
            onClick: () => createRules('DRAFT'),
            loading: createRulesForModelLoading,
            disabled: loading,
          },
          {
            title: 'Save in background',
            onClick: () => createRules('BACKGROUND'),
            loading: createRulesForModelLoading,
            disabled: loading,
          },
        ]}
        cancelButton={{
          title: 'Discard Model',
          onClick: () => navigate('/dashboard/policies'),
          disabled: loading || createRulesForModelLoading,
        }}
        onBack={() => goToScreen(previousScreen, policyId, sessionId, jobId)}
      />
      {modal}
    </div>
  );
}
