// Spaces: Build error — hosting-page status banner captured together with this file; not part of the module.
import {
  useCallback,
  useEffect,
  useRef,
  useState,
} from 'react'
import produce from 'immer'
import { isEqual } from 'lodash-es'
import type { ValueSelector, Var } from '../../types'
import { BlockEnum, VarType } from '../../types'
import {
  useIsChatMode, useNodesReadOnly,
  useWorkflow,
} from '../../hooks'
import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types'
import {
  getMultipleRetrievalConfig,
  getSelectedDatasetsMode,
} from './utils'
import { RETRIEVE_TYPE } from '@/types/app'
import { DATASET_DEFAULT } from '@/config'
import type { DataSet } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets'
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
/**
 * State and handlers for the knowledge-retrieval node configuration panel.
 *
 * Wraps `useNodeCrud` so that every write goes through a normalising
 * `setInputs`, which drops whichever retrieval config
 * (`single_retrieval_config` / `multiple_retrieval_config`) does not belong
 * to the current `retrieval_mode`, and mirrors the written value into
 * `inputRef` so effects and async code can read the latest inputs without
 * depending on a render-time closure.
 *
 * @param id - Id of the workflow node being configured.
 * @param payload - Node data handed to `useNodeCrud` as the node's inputs.
 * @returns The current inputs, the change handlers used by the panel, the
 * selected datasets (entries without a `name` filtered out), the rerank-model
 * panel open state, and the single-run controls from `useOneStepRun`.
 */
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
  const { nodesReadOnly: readOnly } = useNodesReadOnly()
  const isChatMode = useIsChatMode()
  const { getBeforeNodesInSameBranch } = useWorkflow()
  const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
  const startNodeId = startNode?.id
  const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)

  // Always holds the value most recently written through `setInputs` below.
  const inputRef = useRef(inputs)

  // Single write path: strip the config of the inactive retrieval mode,
  // persist, then update the ref synchronously.
  const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
    const newInputs = produce(s, (draft) => {
      if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
        delete draft.single_retrieval_config
      else
        delete draft.multiple_retrieval_config
    })
    // NOTE(review): original author's note — doing this clean-up inside the
    // caller's draft "does not work", so it is applied here instead.
    doSetInputs(newInputs)
    inputRef.current = newInputs
  }, [doSetInputs])

  const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
    const newInputs = produce(inputs, (draft) => {
      draft.query_variable_selector = newVar as ValueSelector
    })
    setInputs(newInputs)
  }, [inputs, setInputs])

  const {
    currentProvider,
    currentModel,
  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)

  const {
    modelList: rerankModelList,
    defaultModel: rerankDefaultModel,
  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

  const {
    currentModel: currentRerankModel,
  } = useCurrentProviderAndModel(
    rerankModelList,
    rerankDefaultModel
      ? {
        ...rerankDefaultModel,
        provider: rerankDefaultModel.provider.provider,
      }
      : undefined,
  )

  // Single-retrieval model picker: write provider / name / mode, creating an
  // empty config first when none exists yet.
  const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
    const newInputs = produce(inputRef.current, (draft) => {
      if (!draft.single_retrieval_config) {
        draft.single_retrieval_config = {
          model: {
            provider: '',
            name: '',
            mode: '',
            completion_params: {},
          },
        }
      }
      const draftModel = draft.single_retrieval_config?.model
      draftModel.provider = model.provider
      draftModel.name = model.modelId
      draftModel.mode = model.mode!
    })
    setInputs(newInputs)
  }, [setInputs])

  const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
    // Original author's note: the model held in inputRef can still be the
    // previous one right after a provider change. Unchanged params are
    // skipped so no redundant write is issued.
    if (isEqual(newParams, inputRef.current.single_retrieval_config?.model.completion_params))
      return

    const newInputs = produce(inputRef.current, (draft) => {
      if (!draft.single_retrieval_config) {
        draft.single_retrieval_config = {
          model: {
            provider: '',
            name: '',
            mode: '',
            completion_params: {},
          },
        }
      }
      draft.single_retrieval_config.model.completion_params = newParams
    })
    setInputs(newInputs)
  }, [setInputs])

  // set defaults models
  useEffect(() => {
    const inputs = inputRef.current
    // Nothing to fill in when the active mode already has its model configured.
    if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
      return

    if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
      return

    const newInput = produce(inputs, (draft) => {
      if (currentProvider?.provider && currentModel?.model) {
        const hasSetModel = draft.single_retrieval_config?.model?.provider
        if (!hasSetModel) {
          draft.single_retrieval_config = {
            model: {
              provider: currentProvider?.provider,
              name: currentModel?.model,
              mode: currentModel?.model_properties?.mode as string,
              completion_params: {},
            },
          }
        }
      }
      const multipleRetrievalConfig = draft.multiple_retrieval_config
      draft.multiple_retrieval_config = {
        top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
        score_threshold: multipleRetrievalConfig?.score_threshold,
        reranking_model: multipleRetrievalConfig?.reranking_model,
        reranking_mode: multipleRetrievalConfig?.reranking_mode,
        weights: multipleRetrievalConfig?.weights,
        // Keep an explicit user choice; otherwise enable reranking only when
        // a rerank model is available.
        reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
          ? multipleRetrievalConfig.reranking_enable
          : Boolean(currentRerankModel && rerankDefaultModel),
      }
    })
    setInputs(newInput)
  // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [currentProvider?.provider, currentModel, rerankDefaultModel])

  const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
  const [rerankModelOpen, setRerankModelOpen] = useState(false)

  const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
    const newInputs = produce(inputs, (draft) => {
      draft.retrieval_mode = newMode
      if (newMode === RETRIEVE_TYPE.multiWay) {
        const multipleRetrievalConfig = draft.multiple_retrieval_config
        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
      }
      else {
        const hasSetModel = draft.single_retrieval_config?.model?.provider
        if (!hasSetModel) {
          draft.single_retrieval_config = {
            model: {
              provider: currentProvider?.provider || '',
              name: currentModel?.model || '',
              mode: currentModel?.model_properties?.mode as string,
              completion_params: {},
            },
          }
        }
      }
    })
    setInputs(newInputs)
  }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel])

  const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
    const newInputs = produce(inputs, (draft) => {
      draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
    })
    setInputs(newInputs)
  }, [inputs, setInputs, selectedDatasets, currentRerankModel])

  // datasets
  useEffect(() => {
    // Set by the cleanup so a response arriving after unmount is ignored.
    let cancelled = false

    const loadSelectedDatasets = async () => {
      const datasetIds = inputRef.current.dataset_ids
      if (datasetIds?.length > 0) {
        const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } })
        // Drop the response if the hook unmounted or the selection changed
        // while the request was in flight; applying it would overwrite the
        // newer selection with stale data.
        if (cancelled || !isEqual(inputRef.current.dataset_ids, datasetIds))
          return
        setSelectedDatasets(dataSetsWithDetail)
      }
      // Read the inputs again *after* the await: other effects and handlers
      // may have written through setInputs in the meantime, and producing
      // from a pre-await snapshot would silently discard those writes.
      const newInputs = produce(inputRef.current, (draft) => {
        draft.dataset_ids = datasetIds
      })
      setInputs(newInputs)
    }
    loadSelectedDatasets()

    return () => {
      cancelled = true
    }
  // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [])

  // On mount in chat mode, default an empty query selector to the start
  // node's `sys.query` variable.
  useEffect(() => {
    const inputs = inputRef.current
    let query_variable_selector: ValueSelector = inputs.query_variable_selector
    if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
      query_variable_selector = [startNodeId, 'sys.query']

    setInputs(produce(inputs, (draft) => {
      draft.query_variable_selector = query_variable_selector
    }))
  // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [])

  const handleOnDatasetsChange = useCallback((newDatasets: DataSet[]) => {
    const {
      mixtureHighQualityAndEconomic,
      mixtureInternalAndExternal,
      inconsistentEmbeddingModel,
      allInternal,
      allExternal,
    } = getSelectedDatasetsMode(newDatasets)
    const newInputs = produce(inputs, (draft) => {
      draft.dataset_ids = newDatasets.map(d => d.id)

      if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
        const multipleRetrievalConfig = draft.multiple_retrieval_config
        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel)
      }
    })
    setInputs(newInputs)
    setSelectedDatasets(newDatasets)

    // Open the rerank-model panel for the dataset combinations flagged by
    // getSelectedDatasetsMode below.
    if (
      (allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel))
      || mixtureInternalAndExternal
      || allExternal
    )
      setRerankModelOpen(true)
  }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel])

  // Only string variables are accepted by this filter.
  const filterVar = useCallback((varPayload: Var) => {
    return varPayload.type === VarType.string
  }, [])

  // single run
  const {
    isShowSingleRun,
    hideSingleRun,
    runningStatus,
    handleRun,
    handleStop,
    runInputData,
    setRunInputData,
    runResult,
  } = useOneStepRun<KnowledgeRetrievalNodeType>({
    id,
    data: inputs,
    defaultRunInputData: {
      query: '',
    },
  })

  const query = runInputData.query
  const setQuery = useCallback((newQuery: string) => {
    setRunInputData({
      ...runInputData,
      query: newQuery,
    })
  }, [runInputData, setRunInputData])

  return {
    readOnly,
    inputs,
    handleQueryVarChange,
    filterVar,
    handleRetrievalModeChange,
    handleMultipleRetrievalConfigChange,
    handleModelChanged,
    handleCompletionParamsChange,
    selectedDatasets: selectedDatasets.filter(d => d.name),
    handleOnDatasetsChange,
    isShowSingleRun,
    hideSingleRun,
    runningStatus,
    handleRun,
    handleStop,
    query,
    setQuery,
    runResult,
    rerankModelOpen,
    setRerankModelOpen,
  }
}

export default useConfig