import {
  GQLViolatesPolicy,
  useGQLCreateCoveModelSignalCreationSamplingJobMutation,
  useGQLGetItemTypesByIdentifiersLazyQuery,
  useGQLGetSamplingJobResultsLazyQuery,
  useGQLOrgDataForModelTrainingQuery,
  useGQLSubmitLabelsMutation,
} from '@/graphql/generated';
import { ChevronDown, CornerDownLeft, LeftArrowBox } from '@/icons';
import { getPrimaryContentFields } from '@/utils/itemUtils';
import type { ItemTypeFieldFieldData } from '@/webpages/dashboard/item_types/itemTypeUtils';
import FieldsComponent from '@/webpages/dashboard/mrt/manual_review_job/v2/ManualReviewJobFieldsComponent';
import type { ModalInfo } from '@/webpages/dashboard/types/ModalInfo';
import { gql } from '@apollo/client';
import { isContainerType } from '@protego-api/types';
import _ from 'lodash';
import { useEffect, useMemo, useRef, useState } from 'react';
import { Navigate, useSearchParams } from 'react-router-dom';

import TrainingPipelineLoading from '../../components/TrainingPipelineLoading';
import TrainingPipelineSelectableFieldsComponent from '../../components/TrainingPipelineSelectableFieldsComponent';
import CoveModal from '@/webpages/dashboard/components/CoveModal';

import TrainingPipelineFooter from '../../TrainingPipelineFooter';
import { useGoToScreen, type TrainingPipelineEnabledScreen } from '../types';
import TrainingPipelineLabelingProgressComponent from './TrainingPipelineLabelingProgressComponent';
import {
  getLabeledItemsForSubmission,
  type LabeledSample,
} from './trainingPipelineLabelingUtils';
import TrainingPipelineLabelSamplesModal from './TrainingPipelineLabelSamplesModal';
import TrainingPipelinePolicyDrawer from './TrainingPipelinePolicyDrawer';

// GraphQL operations used by this screen. The `useGQL...` hooks imported from
// '@/graphql/generated' appear to be generated from these definitions
// (presumably by GraphQL codegen) — keep operation names and fields in sync.
//  - CreateCoveModelSignalCreationSamplingJob: starts a sampling job and
//    returns its jobId.
//  - SubmitLabels: sends the user's labels via `labelItems`.
//  - GetSamplingJobResults: returns the samples and sampling strategy when the
//    job succeeded, or an error `title` when it failed or was not found.
gql`
  mutation CreateCoveModelSignalCreationSamplingJob(
    $input: CreateCoveModelSignalCreationSamplingJobInput!
  ) {
    createCoveModelSignalCreationSamplingJob(input: $input) {
      jobId
    }
  }

  mutation SubmitLabels($input: LabelItemsInput!) {
    labelItems(input: $input) {
      _
    }
  }

  query GetSamplingJobResults($jobId: ID!) {
    getSamplingJobResults(jobId: $jobId) {
      ... on SamplingJobSuccess {
        samples {
          item {
            itemId
            itemType {
              id
              version
              schemaVariant
            }
            data
          }
          score
        }
        samplingStrategy
      }
      ... on SamplingJobFailure {
        title
      }
      ... on SamplingJobNotFoundError {
        title
      }
    }
  }
`;

/**
 * Displays the primary content fields of one sample so the user can label it.
 *
 * When there is more than one thing the user could flag — several primary
 * fields, or a single container field holding several values — the fields are
 * rendered as selectable so the user can mark the specific violating ones.
 * Otherwise the fields are rendered read-only, without labels.
 *
 * @param props.sample - The sample being labeled.
 * @param props.primaryFields - Primary content fields extracted from the sample.
 * @param props.selectedFieldJsonPointers - JSON pointers of the fields the
 *   user has selected for this sample.
 * @param props.onChangeSelectedFields - Called with the updated selection.
 * @param props.options.unblurAllMedia - Forwarded to the field renderers.
 */
function SampleCard(props: {
  sample: LabeledSample;
  primaryFields: ItemTypeFieldFieldData[];
  selectedFieldJsonPointers: string[];
  onChangeSelectedFields: (fieldJsonPointers: string[]) => void;
  options: { unblurAllMedia?: boolean };
}) {
  const { primaryFields, options } = props;
  const itemId = props.sample.item.itemId;
  const itemTypeId = props.sample.item.itemType.id;
  const fieldCount = primaryFields.length;

  // True when the user has more than one value to choose between.
  const hasMultipleLabelableValues = (): boolean => {
    if (fieldCount === 0) {
      return false;
    }
    if (fieldCount > 1) {
      return true;
    }
    // Exactly one field: it is only worth making selectable when it is a
    // container holding more than one value.
    const onlyField = primaryFields[0];
    return (
      isContainerType(onlyField.type) &&
      Object.keys(onlyField.value ?? []).length > 1
    );
  };

  const renderBody = () => {
    if (fieldCount === 0) {
      return (
        <div>
          No primary fields found for this item with ID {itemId}
        </div>
      );
    }

    if (!hasMultipleLabelableValues()) {
      return (
        <FieldsComponent
          fields={primaryFields}
          itemTypeId={itemTypeId}
          options={{
            hideLabels: true,
            unblurAllMedia: options.unblurAllMedia,
            transparentBackground: true,
          }}
        />
      );
    }

    return (
      <TrainingPipelineSelectableFieldsComponent
        fields={primaryFields}
        selectedFieldJsonPointers={props.selectedFieldJsonPointers}
        onChangeSelectedFields={props.onChangeSelectedFields}
        itemTypeId={itemTypeId}
        options={{
          hideLabels: fieldCount === 1,
          unblurAllMedia: options.unblurAllMedia,
        }}
      />
    );
  };

  return (
    <div className="w-full p-6 h-[65vh] overflow-y-scroll scrollbar-hide text-lg border border-gray-200 border-solid rounded-md bg-white text-start">
      {renderBody()}
    </div>
  );
}

/**
 * State for the modal shown when the user's labels are too one-sided to train
 * a model (no violating samples at the halfway mark, or no non-violating
 * samples at the end).
 */
type UnbalancedLabelingModalInfo = {
  /** Which checkpoint triggered the modal. */
  labelingStage: 'partially_complete' | 'complete';
  /** Whether any sample had been labeled as violating when it was triggered. */
  hasPositiveLabels: boolean;
  /** Whether the modal is currently shown. */
  visible: boolean;
};

/** Number of samples requested from the sampling job for one labeling session. */
const SAMPLING_JOB_NUM_SAMPLES = 200;

/**
 * Builds the variables for `createCoveModelSignalCreationSamplingJob`.
 * Shared by the initial request and the "Try Again" retry so the two can't
 * drift apart.
 */
function getSamplingJobVariables(sessionId: string) {
  return {
    input: {
      sessionId,
      numSamples: SAMPLING_JOB_NUM_SAMPLES,
      samplingStrategy: {
        followModelGuidanceStrategy: {
          _: true,
        },
      },
    },
  };
}

/**
 * Training-pipeline screen where the user labels sampled items against a
 * policy.
 *
 * Flow, as implemented below:
 * 1. Reads `policyId`, `sessionId`, and an optional `jobId` from the URL.
 * 2. Without a `jobId`, starts a sampling job for the session and writes the
 *    returned `jobId` back into the URL so a refresh resumes the same job.
 * 3. Polls the sampling job every 5 seconds until it is no longer pending.
 * 4. Shows the samples one at a time. The user labels each one with the
 *    buttons or with keyboard shortcuts: ArrowLeft = not violating,
 *    Enter = violating, `p` = open the policy drawer.
 * 5. Once every sample is labeled, "Continue" submits the labels and
 *    navigates to `nextScreen`.
 *
 * Redirects to the policies page when `policyId` or `sessionId` is missing.
 * Throws on query errors or when the policy cannot be found.
 *
 * @param props.nextScreen - Screen to go to after labels are submitted.
 * @param props.previousScreen - Screen to return to when the user goes back.
 */
export default function TrainingPipelineLabelSamplesScreen(props: {
  nextScreen: TrainingPipelineEnabledScreen;
  previousScreen: TrainingPipelineEnabledScreen;
}) {
  const { nextScreen, previousScreen } = props;
  const goToScreen = useGoToScreen();
  const [searchParams] = useSearchParams();
  const policyId = searchParams.get('policyId') ?? '';
  const sessionId = searchParams.get('sessionId') ?? '';
  const jobId = searchParams.get('jobId') ?? undefined;

  const [state, setState] = useState<{
    allSamples: LabeledSample[];
    // This is a map from Item Id to an array of the JsonPointers of the
    // fields that the user has selected for that item.
    selectedSampleFields: Record<string, string[]>;
    currentSampleIndex: number;
    // Button highlighted while its keyboard shortcut is held down.
    highlightedButton: 'none' | 'does_not_violate' | 'violates';
    unbalancedLabelingModalInfo: UnbalancedLabelingModalInfo;
    modalInfo: ModalInfo;
    jobId: string | undefined;
    moreOptionsVisible: boolean;
    policyDrawerOpen: boolean;
  }>({
    allSamples: [],
    selectedSampleFields: {},
    currentSampleIndex: 0,
    highlightedButton: 'none',
    unbalancedLabelingModalInfo: {
      visible: false,
      hasPositiveLabels: false,
      labelingStage: 'partially_complete',
    },
    modalInfo: {
      visible: false,
      title: '',
      body: '',
      okText: '',
      onOk: () => {},
      okIsDangerButton: false,
      cancelVisible: false,
    },
    jobId: undefined,
    moreOptionsVisible: false,
    policyDrawerOpen: false,
  });

  // Adopt a jobId that is already present in the URL (e.g. after a refresh).
  useEffect(() => {
    if (jobId) {
      setState((prev) => ({ ...prev, jobId }));
    }
  }, [jobId]);

  const showModal = (modalInfo: ModalInfo) =>
    setState((prev) => ({ ...prev, modalInfo }));
  const hideModal = () =>
    setState((prev) => ({
      ...prev,
      modalInfo: { ...prev.modalInfo, visible: false },
    }));
  const setUnbalancedLabelingModalInfo = (
    modalInfo: UnbalancedLabelingModalInfo,
  ) =>
    setState((prev) => ({ ...prev, unbalancedLabelingModalInfo: modalInfo }));

  // Clears every label back to "not violating" and restarts at the first
  // sample.
  const resetLabelingExercise = () =>
    setState((prev) => ({
      ...prev,
      allSamples: prev.allSamples.map((it) => ({
        ...it,
        violatesPolicy: GQLViolatesPolicy.False,
      })),
      currentSampleIndex: 0,
      unbalancedLabelingModalInfo: {
        visible: false,
        hasPositiveLabels: false,
        labelingStage: 'partially_complete',
      },
    }));

  const setSelectedFields = (selectedFields: Record<string, string[]>) =>
    setState((prev) => ({ ...prev, selectedSampleFields: selectedFields }));

  const [
    startSamplingJob,
    { loading: startSamplingJobLoading, error: startSamplingJobError },
  ] = useGQLCreateCoveModelSignalCreationSamplingJobMutation({
    onCompleted: (response) => {
      const newJobId = response.createCoveModelSignalCreationSamplingJob.jobId;
      setState((prev) => ({
        ...prev,
        jobId: newJobId,
      }));

      // Persist the jobId in the URL so a refresh reuses this job (see the
      // `jobId` guard in the effect below) instead of starting a new one.
      const queryParams = new URLSearchParams(window.location.search);
      queryParams.set('jobId', newJobId);
      window.history.replaceState(
        null,
        '',
        `${window.location.pathname}?${queryParams.toString()}`,
      );
    },
    onError: () =>
      showModal({
        visible: true,
        title: 'Failed to train model',
        body: `We encountered an error while trying to train your AI model. Would you like to try again?`,
        okText: 'Try Again',
        onOk: () => {
          hideModal();
          startSamplingJob({ variables: getSamplingJobVariables(sessionId) });
        },
        okIsDangerButton: false,
        cancelVisible: false,
      }),
  });

  // Start a sampling job on mount unless one is already running, has failed,
  // or was restored from the URL. Skip when there is no sessionId: the
  // component redirects away in that case, so the request would be wasted.
  useEffect(() => {
    if (
      !sessionId ||
      startSamplingJobLoading ||
      startSamplingJobError ||
      jobId
    ) {
      return;
    }

    startSamplingJob({ variables: getSamplingJobVariables(sessionId) });
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [startSamplingJob, sessionId]);

  // Poll the server every 5 seconds to see if the sampling job is done.
  // When it's done, the query will return the samples to be labeled.
  // Note: the query is lazy because we have to be able to retry it
  const [
    getSamplingJobResults,
    {
      loading: samplingJobResultsLoading,
      error: samplingJobResultsError,
      data: samplingJobResultsData,
      stopPolling,
    },
  ] = useGQLGetSamplingJobResultsLazyQuery({
    variables: {
      // Safe: the query is only triggered once state.jobId is defined.
      jobId: state.jobId!,
    },
    pollInterval: 5000,
    onCompleted: (response) => {
      if (response.getSamplingJobResults.__typename === 'SamplingJobFailure') {
        showModal({
          visible: true,
          title: 'Error fetching items',
          body: 'We encountered an issue while trying to fetch items for you to label. Would you like to try again?',
          okText: 'Try Again',
          onOk: () => {
            getSamplingJobResults();
            hideModal();
          },
          okIsDangerButton: false,
          cancelVisible: false,
        });
        stopPolling();
      } else if (
        response.getSamplingJobResults.__typename === 'SamplingJobNotFoundError'
      ) {
        showModal({
          visible: true,
          title: 'Error fetching items',
          body: 'We encountered an issue while trying to fetch items for you to label. Please go back and save your policy definition again.',
          okText: 'Go Back',
          onOk: () => goToScreen(previousScreen, policyId, sessionId),
          okIsDangerButton: false,
          cancelVisible: false,
        });
        stopPolling();
      }
    },
  });

  const {
    data: orgData,
    loading: orgDataLoading,
    error: orgDataError,
  } = useGQLOrgDataForModelTrainingQuery();

  const [getItemTypes, { data: itemTypesData }] =
    useGQLGetItemTypesByIdentifiersLazyQuery();
  const itemTypes = itemTypesData?.itemTypes;

  // Because the getSamplingJobResults query is lazy, we have to trigger it once
  // when the component loads
  useEffect(() => {
    if (state.jobId !== undefined) {
      getSamplingJobResults();
    }
  }, [getSamplingJobResults, state.jobId]);

  // Once the samples query finishes, stop polling and initialize the selected
  // field map
  useEffect(() => {
    if (
      samplingJobResultsData &&
      samplingJobResultsData.getSamplingJobResults.__typename !==
        'SamplingJobPending'
    ) {
      stopPolling();

      if (
        samplingJobResultsData.getSamplingJobResults.__typename ===
        'SamplingJobSuccess'
      ) {
        // Initialize the selected field map with an empty array for each sample
        // item. This will be used to store the fields that the user selects for
        // each sample. Specifically, the keys of selectedFields are item IDs
        // and the values are arrays of the JsonPointers of the fields that
        // the user selected.
        setSelectedFields(
          samplingJobResultsData.getSamplingJobResults.samples.reduce(
            (acc, it) => {
              acc[it.item.itemId] = [] as string[];
              return acc;
            },
            {} as Record<string, string[]>,
          ),
        );
      }
    }
  }, [samplingJobResultsData, stopPolling]);

  /*
    This check is to prevent users from trying to submit labels for a model
    where they found only positive results or only negative results. Without
    samples for both, the model won't be able to perform properly. In order to
    handle this, we check if the user has labeled both positive and negative
    samples at two particular points: 50% completion and 100% completion.

    In the 50% completion case, if the user hasn't labeled any samples as
    violating, we show a modal that offers two options: edit the policy text
    or restart labeling. In the 100% completion case, we tell users to contact
    us to help them set up models, because we won't be able to train a model
    with all positive labels (or all negative labels, but that's handled in the
    50% completion check).
  */
  const hasPositiveLabels = state.allSamples.some(
    (sample) => sample.violatesPolicy === 'TRUE',
  );
  const hasNegativeLabels = state.allSamples.some(
    (sample) => sample.violatesPolicy === 'FALSE',
  );
  useEffect(() => {
    const { allSamples, currentSampleIndex } = state;

    // When we have this few samples, we won't be able to make a model regardless
    if (allSamples.length < 25) {
      return;
    }

    // When we're 50% through the samples, check to see if we've found any
    // positive samples, and if not, prompt the user to confirm
    if (currentSampleIndex === Math.floor(allSamples.length / 2)) {
      // If we have positive labels, aka violating content, we let the user continue
      if (hasPositiveLabels) {
        return;
      }

      setUnbalancedLabelingModalInfo({
        visible: true,
        hasPositiveLabels,
        labelingStage: 'partially_complete',
      });
    }

    // When we're through all the samples, if we still don't have any negative
    // labels, prompt the user to get in touch with us
    if (currentSampleIndex >= allSamples.length - 1) {
      // If we have negative labels, aka non-violating content, we let the user continue
      if (hasNegativeLabels) {
        return;
      }

      setUnbalancedLabelingModalInfo({
        visible: true,
        hasPositiveLabels,
        labelingStage: 'complete',
      });
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [hasPositiveLabels, hasNegativeLabels, state.currentSampleIndex]);

  const { samples, samplingStrategy } =
    samplingJobResultsData?.getSamplingJobResults.__typename ===
    'SamplingJobSuccess'
      ? samplingJobResultsData.getSamplingJobResults
      : { samples: [], samplingStrategy: undefined };

  // Once we've finished fetching samples, set it to the allSamples state and
  // fetch the relevant item type versions
  useEffect(() => {
    if (samples.length > 0 && state.allSamples.length === 0) {
      setState((prev) => ({
        ...prev,
        allSamples: samples.map((sample) => ({
          ...sample,
          violatesPolicy: GQLViolatesPolicy.False,
        })),
      }));
      getItemTypes({
        variables: {
          identifiers: _.uniqBy(
            samples.map((it) => ({
              id: it.item.itemType.id,
              version: it.item.itemType.version,
              schemaVariant: 'ORIGINAL',
            })),
            (it) => `${it.id}-${it.version}`,
          ),
        },
      });
    }
  }, [getItemTypes, samples, state.allSamples.length]);

  // Item type (matching id and version) of the sample currently on screen.
  const itemType = useMemo(() => {
    // Guard on the array we actually index, not on `samples`.
    const currentSample = state.allSamples[state.currentSampleIndex];
    if (currentSample == null || !itemTypes) {
      return undefined;
    }
    return itemTypes.find(
      (it) =>
        it.id === currentSample.item.itemType.id &&
        it.version === currentSample.item.itemType.version,
    );
  }, [itemTypes, state.allSamples, state.currentSampleIndex]);

  const [submitLabels, { loading: submitLabelsLoading }] =
    useGQLSubmitLabelsMutation();

  // Records a label for the sample at `index` and advances to the sample
  // after it. Everything is derived from `prev` (never from the render-time
  // `state`) so concurrent updates such as field selections are not
  // clobbered, and two quick labels of the same sample cannot skip the next
  // one.
  const assignLabelToSample = (
    index: number,
    violatesPolicy: GQLViolatesPolicy,
  ) =>
    setState((prev) => {
      const newSamples = [...prev.allSamples];
      newSamples.splice(index, 1, {
        ...prev.allSamples[index],
        violatesPolicy,
      });
      return {
        ...prev,
        allSamples: newSamples,
        currentSampleIndex: index + 1,
        highlightedButton: 'none',
      };
    });
  const setHighlightedButton = (
    button: 'none' | 'does_not_violate' | 'violates',
  ) => setState((prev) => ({ ...prev, highlightedButton: button }));

  // Create a reference to the div element containing the buttons that need to
  // be triggered with the arrow keys.
  const containerRef = useRef<HTMLDivElement>(null);

  // Set focus to the div element on component mount. This is required to be
  // able to use the arrow keys to trigger the buttons.
  useEffect(() => {
    if (containerRef.current) {
      containerRef.current.focus();
    }
  }, [state.allSamples.length]);

  if (samplingJobResultsError) {
    throw samplingJobResultsError;
  }

  if (orgDataError) {
    throw orgDataError;
  }

  if (!policyId || !sessionId) {
    return <Navigate replace to="/dashboard/models_and_policies/policies" />;
  }
  const isLoading =
    state.jobId == null ||
    samplingJobResultsLoading ||
    orgDataLoading ||
    samplingJobResultsData?.getSamplingJobResults.__typename ===
      'SamplingJobPending';

  if (isLoading) {
    return (
      <div className="flex flex-col items-center justify-between h-full">
        <TrainingPipelineLoading
          title="Creating your Custom AI model"
          subtitle="This can take up to 5 minutes. Please don't close this window."
          loading={true}
        />
      </div>
    );
  }

  const policy = orgData?.myOrg?.policies.find((it) => it.id === policyId);
  if (policy == null) {
    throw new Error(`Policy with ID ${policyId} not found`);
  }

  // Non-empty primary content fields of the current sample ([] once every
  // sample has been labeled).
  const primaryFields = samples[state.currentSampleIndex]
    ? getPrimaryContentFields(
        itemType?.baseFields ?? [],
        samples[state.currentSampleIndex].item.data,
      ).filter(
        (it) =>
          it.value != null &&
          !(Array.isArray(it.value) && it.value.length === 0),
      )
    : [];
  // With several fields, "Not Violating" is only allowed when nothing is
  // selected, and "Violating" only when something is. The `primaryFields`
  // check short-circuits before indexing `samples` past the end.
  const shouldDisableDoesNotViolateButton =
    primaryFields.length > 1 &&
    state.selectedSampleFields[samples[state.currentSampleIndex].item.itemId]
      ?.length > 0;
  const shouldDisableViolatesButton =
    primaryFields.length > 1 &&
    state.selectedSampleFields[samples[state.currentSampleIndex].item.itemId]
      ?.length < 1;

  const violatesButton = (
    <div
      className={`flex bg-white items-center justify-center p-4 border-gray-200 border-solid rounded-md border select-none gap-2 ${
        shouldDisableViolatesButton
          ? 'cursor-not-allowed fill-gray-300'
          : // Need to include this so that the button is highlighted when the user presses the right arrow key
          state.highlightedButton === 'violates'
          ? '!bg-red-300 !border-red-100'
          : 'cursor-pointer fill-gray-500  hover:bg-red-50 hover:border-red-100 !active:bg-red-300 !active:border-red-100'
      }`}
      onClick={() =>
        !shouldDisableViolatesButton &&
        assignLabelToSample(state.currentSampleIndex, 'TRUE')
      }
    >
      <div
        className={`font-semibold ${
          shouldDisableViolatesButton ? 'text-gray-300' : 'text-gray-500'
        }`}
      >
        Mark Fields as Violating
      </div>
      <CornerDownLeft className="w-4 h-4 p-0.5 border border-solid rounded-sm" />
    </div>
  );
  const doesNotViolateButton = (
    <div
      className={`flex bg-white items-center p-4 border-gray-200 border-solid rounded-md border justify-center gap-2 select-none ${
        shouldDisableDoesNotViolateButton
          ? ' cursor-not-allowed fill-gray-300'
          : // Need to include this so that the button is highlighted when the user presses the left arrow key
          state.highlightedButton === 'does_not_violate'
          ? '!bg-emerald-300 !border-emerald-100'
          : 'cursor-pointer fill-gray-500 hover:border-emerald-100 hover:bg-emerald-50 !active:bg-emerald-300 !active:border-emerald-100'
      }`}
      onClick={() =>
        !shouldDisableDoesNotViolateButton &&
        assignLabelToSample(state.currentSampleIndex, 'FALSE')
      }
    >
      <LeftArrowBox className="w-4 h-4" />
      <div
        className={`font-semibold ${
          shouldDisableDoesNotViolateButton ? 'text-gray-300' : 'text-gray-500'
        }`}
      >
        Not Violating
      </div>
    </div>
  );

  const moreOptionsButton = (
    <div className="flex flex-col">
      <div
        className="flex items-center justify-center gap-2 p-4 font-semibold text-gray-500 border border-gray-500 border-solid rounded-md cursor-pointer select-none fill-gray-500 bg-inherit hover:bg-gray-100"
        onClick={(e) => {
          setState((prevState) => ({
            ...prevState,
            moreOptionsVisible: !prevState.moreOptionsVisible,
          }));
          // Keep the container's onClick from immediately closing the menu.
          e.stopPropagation();
        }}
      >
        <>More Options</>
        <ChevronDown className="w-4 h-4" />
      </div>
      <div
        className="flex flex-col gap-2 mt-1 bg-white rounded-md shadow-md select-none"
        style={{
          maxHeight: state.moreOptionsVisible ? '200px' : '0px',
          overflow: 'hidden',
          transition: 'max-height 0.3s ease-in-out, padding 0.3s ease-in-out',
          padding: state.moreOptionsVisible ? '16px' : '0px 16px',
        }}
      >
        <div
          className="flex flex-row items-center gap-2 px-2 cursor-pointer"
          onClick={() =>
            setState((prev) => ({
              ...prev,
              policyDrawerOpen: !prev.policyDrawerOpen,
            }))
          }
        >
          <div className="px-2 py-1 text-xs border border-gray-500 border-solid rounded-md">
            P
          </div>
          <>Preview Policy</>
        </div>
        <div className="my-2 divider" />
        <div className="px-2 text-xs text-gray-500">
          UNSURE IF CONTENT IS VIOLATING
        </div>
        <div
          className={`px-2 ${
            shouldDisableDoesNotViolateButton
              ? 'text-secondary cursor-not-allowed'
              : 'cursor-pointer'
          }`}
          onClick={() => {
            if (shouldDisableDoesNotViolateButton) {
              return;
            }

            assignLabelToSample(state.currentSampleIndex, 'EDGE_CASE');
          }}
        >
          It's an edge case
        </div>
        <div
          className={`px-2 ${
            shouldDisableDoesNotViolateButton
              ? 'text-secondary cursor-not-allowed'
              : 'cursor-pointer'
          }`}
          onClick={() => {
            if (shouldDisableDoesNotViolateButton) {
              return;
            }

            assignLabelToSample(state.currentSampleIndex, 'NEEDS_CONTEXT');
          }}
        >
          Needs more context
        </div>
      </div>
    </div>
  );

  const rightColumn = (
    <div className="flex flex-col justify-between">
      <div className="flex flex-col gap-4">
        <div className="flex flex-col gap-2">
          <div className="text-base font-semibold">
            Does this violate your {policy.name} policy?
          </div>
          <div className="text-base">Click to mark all relevant fields</div>
        </div>
        {violatesButton}
        {doesNotViolateButton}
        {moreOptionsButton}
      </div>
      <TrainingPipelineLabelingProgressComponent
        currentIndex={state.currentSampleIndex}
        total={samples.length}
      />
    </div>
  );

  const labelingScreen = (
    <div
      ref={containerRef}
      className="flex flex-row w-full py-4 space-x-8 grow"
      // Keyboard shortcuts: keydown highlights the target button, keyup
      // commits the label. `p` opens the policy drawer.
      onKeyDown={(e) => {
        if (e.repeat || state.highlightedButton !== 'none') {
          return;
        }

        if (e.key === 'ArrowLeft' && !shouldDisableDoesNotViolateButton) {
          setHighlightedButton('does_not_violate');
        } else if (e.key === 'Enter' && !shouldDisableViolatesButton) {
          setHighlightedButton('violates');
        } else if (e.key === 'p') {
          setState((prev) => ({ ...prev, policyDrawerOpen: true }));
        }
      }}
      onKeyUp={(e) => {
        if (e.repeat || state.currentSampleIndex >= samples.length) {
          return;
        }

        if (
          e.key === 'ArrowLeft' &&
          state.highlightedButton === 'does_not_violate' &&
          !shouldDisableDoesNotViolateButton
        ) {
          assignLabelToSample(state.currentSampleIndex, 'FALSE');
        } else if (
          e.key === 'Enter' &&
          state.highlightedButton === 'violates' &&
          !shouldDisableViolatesButton
        ) {
          assignLabelToSample(state.currentSampleIndex, 'TRUE');
        }
      }}
      tabIndex={-1}
      onClick={() =>
        setState((prevState) => ({
          ...prevState,
          moreOptionsVisible: false,
        }))
      }
    >
      <div className="flex flex-col items-stretch w-3/5 pl-16 text-start">
        {state.currentSampleIndex < state.allSamples.length && (
          <div className="flex flex-col items-stretch">
            <SampleCard
              sample={state.allSamples[state.currentSampleIndex]}
              primaryFields={primaryFields}
              selectedFieldJsonPointers={
                state.selectedSampleFields[
                  samples[state.currentSampleIndex].item.itemId
                ] ?? []
              }
              onChangeSelectedFields={(fieldJsonPointers) => {
                const itemId = samples[state.currentSampleIndex].item.itemId;
                // Functional update so a selection change is never lost to a
                // stale copy of selectedSampleFields.
                setState((prev) => ({
                  ...prev,
                  selectedSampleFields: {
                    ...prev.selectedSampleFields,
                    [itemId]: fieldJsonPointers,
                  },
                }));
              }}
              options={{ unblurAllMedia: false }}
            />
          </div>
        )}
        {samples.length > 0 && state.currentSampleIndex >= samples.length && (
          <div className="pt-6 text-lg font-semibold">All samples labeled!</div>
        )}
      </div>
      <div className="self-stretch w-px bg-gray-200" />
      {rightColumn}
    </div>
  );

  // Submits all labels once every sample has been labeled. On failure, shows
  // a modal that hides itself and retries.
  const onSubmit = () => {
    if (samples.length === 0 || state.currentSampleIndex < samples.length) {
      return;
    }
    submitLabels({
      variables: {
        input: {
          labeledItems: getLabeledItemsForSubmission({
            samples: state.allSamples,
            selectedSampleFields: state.selectedSampleFields,
            policyId,
            itemTypes: itemTypes ?? [],
            // This non-null assertion is safe because if we have items to label,
            // we necessarily need to have received a sampling strategy as well
            // based on the graphql definition of SamplingJobSuccess
            // eslint-disable-next-line @typescript-eslint/no-unnecessary-type-assertion
            samplingStrategy: samplingStrategy!,
          }),
        },
      },
      onCompleted: () =>
        goToScreen(nextScreen, policyId, sessionId, state.jobId),
      onError: () => {
        showModal({
          visible: true,
          title: 'Error saving labels',
          body: 'We encountered an issue while trying to save your labels. Would you like to try again?',
          okText: 'Try Again',
          onOk: () => {
            hideModal();
            onSubmit();
          },
          okIsDangerButton: false,
          cancelVisible: false,
        });
      },
    });
  };

  const modal = (
    <CoveModal
      title={state.modalInfo.title}
      visible={state.modalInfo.visible}
      onClose={hideModal}
      footer={[
        ...(state.modalInfo.cancelText
          ? [
              {
                title: state.modalInfo.cancelText,
                onClick: state.modalInfo.onCancel,
                type: 'secondary' as const,
              },
            ]
          : []),
        {
          title: state.modalInfo.okText,
          onClick: state.modalInfo.onOk,
          type: 'primary',
        },
      ]}
      hideCloseButton
    >
      <div className="pt-2">{state.modalInfo.body}</div>
    </CoveModal>
  );

  return (
    <div className="flex flex-col items-center justify-between h-full">
      {labelingScreen}
      {modal}
      <TrainingPipelineFooter
        primaryButton={{
          title: 'Continue',
          onClick: onSubmit,
          disabled:
            samples.length === 0 || state.currentSampleIndex < samples.length,
          loading: submitLabelsLoading,
        }}
        onBack={() =>
          showModal({
            visible: true,
            title: 'Are you sure you want to go back?',
            body: "You'll lose all the progress you made labeling examples if you go back and edit the policy definition.",
            okText: 'Keep Going',
            okIsDangerButton: false,
            onOk: hideModal,
            cancelText: 'Start Over',
            cancelVisible: true,
            onCancel: () => goToScreen(previousScreen, policyId, sessionId),
          })
        }
      />
      <TrainingPipelineLabelSamplesModal
        hasPositiveLabels={state.unbalancedLabelingModalInfo.hasPositiveLabels}
        visible={state.unbalancedLabelingModalInfo.visible}
        labelingStage={state.unbalancedLabelingModalInfo.labelingStage}
        policyName={policy.name}
        // NOTE(review): unlike the other back-navigation paths, this omits
        // sessionId — confirm that is intentional.
        onRedefinePolicy={() => goToScreen(previousScreen, policyId)}
        onRestartLabeling={resetLabelingExercise}
      />
      <TrainingPipelinePolicyDrawer
        policy={{
          name: policy.name,
          policyText: policy.policyText ?? '',
          enforcementGuidelines: policy.enforcementGuidelines ?? '',
        }}
        open={state.policyDrawerOpen}
        onClose={() => {
          setState((prev) => ({ ...prev, policyDrawerOpen: false }));

          // Restore focus so the keyboard shortcuts keep working.
          if (containerRef.current) {
            containerRef.current.focus();
          }
        }}
      />
    </div>
  );
}
