import {
  namedOperations,
  useGQLDeleteLabelsMutation,
  useGQLGetLabelsForPolicyLazyQuery,
  useGQLInternalModelTrainingPageDataQuery,
  useGQLSetLabelsForItemsMutation,
  type GQLGetLabelsForPolicyQuery,
} from '@/graphql/generated';
import { TrashCan } from '@/icons';
import { jsonStringify } from '@/utils/typescript-types';
import { gql } from '@apollo/client';
import { useEffect, useMemo, useState } from 'react';
import { Navigate } from 'react-router-dom';

import ComponentLoading from '@/components/common/ComponentLoading';
import CoveSelect from '@/components/common/CoveSelect';
import FullScreenLoading from '@/components/common/FullScreenLoading';
import CoveBadge from '@/webpages/dashboard/components/CoveBadge';
import CoveButton from '@/webpages/dashboard/components/CoveButton';
import CoveModal from '@/webpages/dashboard/components/CoveModal';
import Table from '@/webpages/dashboard/components/table/Table';

import InternalModelTrainingAddLabelModal from './InternalModelTrainingAddLabelModal';
import { getFieldAtJsonPointer } from './internalModelTrainingUtils';

// GraphQL operations used by this page:
//   - InternalModelTrainingPageData: the current org (with its policies) and
//     the current user, used for the access check and the policy picker.
//   - GetLabelsForPolicy: the labelled items and their labels for one policy.
//   - SetLabelsForItems / DeleteLabels: mutations for flipping and deleting
//     labels.
// The `useGQL*` hooks and `namedOperations` imported from '@/graphql/generated'
// match these operation names and are presumably produced from them by codegen.
// NOTE(review): confirm before renaming an operation; refetchQueries in this
// file refers to `namedOperations.Query.GetLabelsForPolicy` by name.
gql`
  query InternalModelTrainingPageData {
    myOrg {
      id
      email
      name
      policies {
        id
        name
      }
    }
    me {
      id
      email
    }
  }

  query GetLabelsForPolicy($input: GetLabelsForPolicyInput!) {
    getLabelsForPolicy(input: $input) {
      items {
        id
        data
        itemType {
          id
          version
          schemaVariant
        }
      }
      labels {
        id
        itemId
        itemFieldJsonPointers
        violatesPolicy
        labelerId
        labelerType
      }
    }
  }

  mutation SetLabelsForItems($input: SetLabelsForItemsInput!) {
    setLabelsForItems(input: $input)
  }

  mutation DeleteLabels($labelIds: [ID!]!) {
    deleteLabels(labelIds: $labelIds)
  }
`;

/** A single label as returned by the `GetLabelsForPolicy` query. */
type PolicyLabel =
  GQLGetLabelsForPolicyQuery['getLabelsForPolicy']['labels'][number];

/** The `data` payload of an item returned by the `GetLabelsForPolicy` query. */
type PolicyItemData =
  GQLGetLabelsForPolicyQuery['getLabelsForPolicy']['items'][number]['data'];

/**
 * Returns true when the label applies to the item as a whole (its first JSON
 * pointer is the empty string) rather than to one of the item's fields.
 */
function isWholeItemLabel(label: PolicyLabel) {
  return label.itemFieldJsonPointers[0] === '';
}

/**
 * Finds the whole-item ("parent") label that must be flipped to stay
 * consistent with the item's per-field labels once the field label
 * `flippedLabelId` has its `violatesPolicy` value inverted.
 *
 * Invariant being maintained: the parent label is violating if and only if at
 * least one of the item's field labels is violating.
 *
 * @param labels Snapshot of the labels taken *before* the flip. The flipped
 *   label's value is therefore inverted here rather than read as-is.
 * @param flippedLabelId Id of the field label that was flipped.
 * @returns The parent label to flip, or `undefined` when nothing needs to
 *   change: the parent is already consistent, the item has no parent label,
 *   the id is unknown, or the flipped label is itself a parent label.
 */
function findParentLabelToFlip(
  labels: readonly PolicyLabel[],
  flippedLabelId: string,
): PolicyLabel | undefined {
  const flippedLabel = labels.find((it) => it.id === flippedLabelId);
  if (!flippedLabel || isWholeItemLabel(flippedLabel)) {
    return undefined;
  }

  const labelsForItem = labels.filter(
    (it) => it.itemId === flippedLabel.itemId,
  );
  const parentLabel = labelsForItem.find(isWholeItemLabel);
  if (!parentLabel) {
    return undefined;
  }

  // Evaluate the field labels as they will be after the flip.
  const anyFieldViolates = labelsForItem.some(
    (it) =>
      !isWholeItemLabel(it) &&
      (it.id === flippedLabelId ? !it.violatesPolicy : it.violatesPolicy),
  );

  return parentLabel.violatesPolicy === anyFieldViolates
    ? undefined
    : parentLabel;
}

/**
 * Internal debugging page for inspecting and editing the model-training
 * labels attached to one of the org's policies.
 *
 * Only users whose email contains one of the allow-listed domains checked
 * below can see it; everyone else is redirected to "/". Once a policy is
 * selected, its per-field labels are listed with controls to flip or delete
 * each one, alongside summary stats and an "Add Label" modal.
 */
export default function InternalModelTraining() {
  const [selectedPolicyId, setSelectedPolicyId] = useState<string | undefined>(
    undefined,
  );
  const [showErrorModal, setShowErrorModal] = useState(false);
  // Id of the label whose flip mutation is currently in flight, if any.
  const [labelIdToSwap, setLabelIdToSwap] = useState<string | undefined>(
    undefined,
  );
  // Id of the label pending confirmation in the delete modal, if any.
  const [labelIdToDelete, setLabelIdToDelete] = useState<string | undefined>(
    undefined,
  );
  const [showAddLabelModal, setShowAddLabelModal] = useState(false);

  const [
    getLabelsForPolicy,
    { loading: getLabelsForPolicyLoading, data: getLabelsForPolicyData },
  ] = useGQLGetLabelsForPolicyLazyQuery();
  const { data, loading, error } = useGQLInternalModelTrainingPageDataQuery();

  const [setLabels, { loading: setLabelsLoading }] =
    useGQLSetLabelsForItemsMutation({
      refetchQueries: [namedOperations.Query.GetLabelsForPolicy],
      onError: () => setShowErrorModal(true),
    });

  const [deleteLabels, { loading: deleteLabelsLoading }] =
    useGQLDeleteLabelsMutation({
      refetchQueries: [namedOperations.Query.GetLabelsForPolicy],
      onCompleted: () => setLabelIdToDelete(undefined),
      onError: () => setShowErrorModal(true),
    });

  // Fetch the labels whenever a (different) policy is selected.
  useEffect(() => {
    if (selectedPolicyId) {
      getLabelsForPolicy({
        variables: {
          input: {
            policyId: selectedPolicyId,
          },
        },
      });
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [selectedPolicyId]);

  const columns = [
    { Header: 'Item', accessor: 'item' },
    { Header: 'Label', accessor: 'label' },
    {
      Header: '',
      accessor: 'mutations',
      canSort: false,
    },
  ] as const;

  const dataValues = useMemo(() => {
    const labelsForPolicy = getLabelsForPolicyData?.getLabelsForPolicy;
    if (!labelsForPolicy) {
      return undefined;
    }

    // Index items by id once, instead of scanning the item list for every
    // label. The first item with a given id wins, matching Array#find.
    const itemDataById = new Map<string, PolicyItemData>();
    for (const item of labelsForPolicy.items) {
      if (!itemDataById.has(item.id)) {
        itemDataById.set(item.id, item.data);
      }
    }

    // Only per-field labels are listed. Whole-item labels are kept in sync
    // automatically when a field label is flipped (see switchLabel).
    return labelsForPolicy.labels
      .filter((it) => !isWholeItemLabel(it))
      .map((it) => ({
        id: it.id,
        item: itemDataById.get(it.itemId),
        violatingField: it.itemFieldJsonPointers[0],
        label: it.violatesPolicy
          ? ('VIOLATES' as const)
          : ('DOES NOT VIOLATE' as const),
      }));
  }, [getLabelsForPolicyData]);

  const tableData = useMemo(() => {
    /** Inverts a field label, then re-syncs its item's whole-item label. */
    const switchLabel = (labelId: string) => {
      // This is captured when the handler is created, i.e. before the flip,
      // so it still holds the flipped label's *old* value.
      const labels = getLabelsForPolicyData?.getLabelsForPolicy.labels;
      const labeledItem = labels?.find((it) => it.id === labelId);
      if (!labels || !labeledItem) {
        return;
      }

      setLabelIdToSwap(labelId);

      setLabels({
        variables: {
          input: {
            labels: [
              {
                id: labelId,
                violatesPolicy: !labeledItem.violatesPolicy,
              },
            ],
          },
        },
        onCompleted: () => {
          setLabelIdToSwap(undefined);

          // Each field of an item gets its own label, and the item as a whole
          // is violating iff any of its fields is violating. Flipping a field
          // label can therefore require flipping the parent label too.
          //
          // NB: This happens in the background. The loading indicator is not
          // kept on for it, and it is best effort: rolling back the field
          // label if the parent flip fails would be non-trivial.
          const parentLabelToSwap = findParentLabelToFlip(labels, labelId);
          if (parentLabelToSwap) {
            setLabels({
              variables: {
                input: {
                  labels: [
                    {
                      id: parentLabelToSwap.id,
                      violatesPolicy: !parentLabelToSwap.violatesPolicy,
                    },
                  ],
                },
              },
            });
          }
        },
        onError: () => {
          setShowErrorModal(true);
          setLabelIdToSwap(undefined);
        },
      });
    };

    return dataValues?.map((it) => ({
      item: (
        <div>
          {it.item
            ? (() => {
                const fieldData = getFieldAtJsonPointer(
                  it.item,
                  it.violatingField,
                );

                // Render fields carrying a string `url` as a link; anything
                // else is shown as JSON.
                if (
                  typeof fieldData === 'object' &&
                  fieldData !== null &&
                  'url' in fieldData &&
                  typeof fieldData.url === 'string'
                ) {
                  return (
                    <a href={fieldData.url} target="_blank" rel="noreferrer">
                      {fieldData.url}
                    </a>
                  );
                }

                return jsonStringify(fieldData);
              })()
            : undefined}
        </div>
      ),
      label: (
        <CoveBadge
          colorVariant={it.label === 'VIOLATES' ? 'soft-red' : 'soft-green'}
          shapeVariant="pill"
          label={it.label}
        />
      ),
      mutations: (
        <LabelMutationButtons
          isViolating={it.label === 'VIOLATES'}
          onClickFlipLabel={() => switchLabel(it.id)}
          flipLabelLoading={it.id === labelIdToSwap && setLabelsLoading}
          onClickDelete={() => setLabelIdToDelete(it.id)}
        />
      ),
    }));
  }, [
    dataValues,
    getLabelsForPolicyData?.getLabelsForPolicy,
    labelIdToSwap,
    setLabels,
    setLabelsLoading,
  ]);

  if (loading) {
    return <FullScreenLoading />;
  }

  if (error) {
    throw error;
  }

  const userEmail = data?.me?.email;
  if (
    !userEmail?.includes('@getcove.com') &&
    !userEmail?.includes('@example.com') &&
    !userEmail?.includes('@protegoapi.com')
  ) {
    return <Navigate replace to="/" />;
  }

  const policies = data?.myOrg?.policies;

  return (
    <div className="flex flex-col p-8 gap-4 w-fit">
      <div className="text-2xl font-bold">Model Training Label Debugging</div>
      {policies && policies.length > 0 ? (
        <div className="flex flex-col gap-2">
          <div className="text-sm font-bold">Policy</div>
          <div className="w-48">
            <CoveSelect
              placeholder="Select a policy"
              value={selectedPolicyId}
              onSelect={setSelectedPolicyId}
              options={policies
                .map((policy) => ({
                  value: policy.id,
                  label: policy.name,
                }))
                .sort((a, b) => a.label.localeCompare(b.label))}
            />
          </div>
        </div>
      ) : (
        <div className="text-sm text-gray-600">
          This organization has no policies.
        </div>
      )}
      {getLabelsForPolicyLoading ? (
        <ComponentLoading />
      ) : getLabelsForPolicyData?.getLabelsForPolicy && tableData ? (
        <div className="flex flex-col">
          <LabelStats
            labels={getLabelsForPolicyData.getLabelsForPolicy.labels}
            onClickAddLabel={() => setShowAddLabelModal(true)}
          />
          <Table columns={columns} data={tableData} />
        </div>
      ) : undefined}
      <CoveModal
        visible={showErrorModal}
        onClose={() => setShowErrorModal(false)}
        title="Error"
      >
        <div>An error occurred</div>
      </CoveModal>
      <CoveModal
        visible={labelIdToDelete != null}
        onClose={() => setLabelIdToDelete(undefined)}
        title="Delete Label"
        footer={[
          {
            title: 'Cancel',
            onClick: () => setLabelIdToDelete(undefined),
            type: 'secondary',
          },
          {
            title: 'Delete',
            onClick: async () => {
              // The modal is only visible while an id is set; guard anyway
              // rather than sending a mutation with an undefined id.
              if (labelIdToDelete == null) {
                return;
              }
              await deleteLabels({
                variables: { labelIds: [labelIdToDelete] },
              });
            },
            type: 'primary',
            loading: deleteLabelsLoading,
          },
        ]}
      >
        <div>Are you sure you want to delete this label?</div>
      </CoveModal>
      <InternalModelTrainingAddLabelModal
        visible={showAddLabelModal}
        onSave={() => setShowAddLabelModal(false)}
        onClose={() => setShowAddLabelModal(false)}
      />
    </div>
  );
}

/**
 * Per-row action controls for the labels table: one inverts the label's
 * violation value, the other requests deletion of the label.
 *
 * While `flipLabelLoading` is true, the flip control shows a spinner and
 * ignores further activations, so the same flip is not submitted twice.
 * Both controls are `<div>`s, so they are given button semantics and respond
 * to Enter/Space for keyboard users.
 */
function LabelMutationButtons(props: {
  isViolating: boolean;
  onClickFlipLabel: () => void;
  flipLabelLoading: boolean;
  onClickDelete: () => void;
}) {
  const { isViolating, onClickFlipLabel, onClickDelete, flipLabelLoading } =
    props;

  const handleFlipLabel = () => {
    // Ignore repeat activations while the previous flip is still in flight.
    if (!flipLabelLoading) {
      onClickFlipLabel();
    }
  };

  // Builds a keydown handler that triggers `action` on Enter or Space, the
  // keys that activate a native button.
  const activateOnKey =
    (action: () => void) =>
    (event: { key: string; preventDefault: () => void }) => {
      if (event.key === 'Enter' || event.key === ' ') {
        // Prevent Space from scrolling the page.
        event.preventDefault();
        action();
      }
    };

  return (
    <div className="flex flex-row items-center justify-end w-full gap-2">
      <div
        role="button"
        tabIndex={0}
        aria-busy={flipLabelLoading}
        onClick={handleFlipLabel}
        onKeyDown={activateOnKey(handleFlipLabel)}
        title="Flip label"
        className="p-2 border border-solid cursor-pointer rounded-md"
      >
        {flipLabelLoading ? (
          <ComponentLoading />
        ) : isViolating ? (
          'Label as non-violating'
        ) : (
          'Label as violating'
        )}
      </div>
      <div
        role="button"
        tabIndex={0}
        onClick={onClickDelete}
        onKeyDown={activateOnKey(onClickDelete)}
        title="Delete label"
        className="p-2 cursor-pointer"
      >
        <TrashCan className="w-5 h-5 text-gray-600" />
      </div>
    </div>
  );
}

/**
 * Summary bar showing how many of `labels` are violating vs. non-violating,
 * each with its share of the total, plus an "Add Label" button.
 *
 * Percentages are shown with two decimals. With no labels, both shares are
 * shown as 0.00% rather than NaN%.
 */
function LabelStats(props: {
  labels: GQLGetLabelsForPolicyQuery['getLabelsForPolicy']['labels'];
  onClickAddLabel: () => void;
}) {
  const { labels, onClickAddLabel } = props;
  const totalCount = labels.length;
  const violatingCount = labels.filter((it) => it.violatesPolicy).length;
  const nonViolatingCount = labels.filter((it) => !it.violatesPolicy).length;

  // Percentage of all labels, formatted with two decimals. Guarding the empty
  // case avoids 0 / 0 rendering as "NaN".
  const percentageOf = (count: number) =>
    (totalCount === 0 ? 0 : (count / totalCount) * 100).toFixed(2);

  return (
    <div className="flex flex-row justify-between p-4 bg-slate-300">
      <div className="flex flex-row pt-1 gap-4">
        <div className="flex flex-col">
          <div className="text-sm font-bold">Violating Samples</div>
          <div className="text-sm">
            {violatingCount} ({percentageOf(violatingCount)}%)
          </div>
        </div>
        <div className="flex flex-col">
          <div className="text-sm font-bold">Non-Violating Samples</div>
          <div className="text-sm">
            {nonViolatingCount} ({percentageOf(nonViolatingCount)}%)
          </div>
        </div>
      </div>
      <CoveButton title="Add Label" onClick={onClickAddLabel} />
    </div>
  );
}
