// Severian's picture
// initial commit
// a8b3f00
// raw
// history blame
// 10.6 kB
import {
useCallback,
useEffect,
useRef,
useState,
} from 'react'
import produce from 'immer'
import { isEqual } from 'lodash-es'
import type { ValueSelector, Var } from '../../types'
import { BlockEnum, VarType } from '../../types'
import {
useIsChatMode, useNodesReadOnly,
useWorkflow,
} from '../../hooks'
import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types'
import {
getMultipleRetrievalConfig,
getSelectedDatasetsMode,
} from './utils'
import { RETRIEVE_TYPE } from '@/types/app'
import { DATASET_DEFAULT } from '@/config'
import type { DataSet } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets'
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
/**
 * State and handlers for the Knowledge Retrieval node's configuration panel.
 *
 * Wraps `useNodeCrud` so that every write keeps only the retrieval config that
 * matches the node's `retrieval_mode`. It also:
 * - fills in default model / reranking settings once the model lists resolve;
 * - loads details for the datasets already selected on the node;
 * - exposes the single-run helpers from `useOneStepRun`.
 *
 * @param id - Id of the workflow node being configured.
 * @param payload - The node's data, forwarded to `useNodeCrud`.
 * @returns Node inputs and change handlers, the selected datasets (only entries
 *   with a `name`), single-run state/actions, and the rerank-model popup state.
 */
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
  const { nodesReadOnly: readOnly } = useNodesReadOnly()
  const isChatMode = useIsChatMode()
  const { getBeforeNodesInSameBranch } = useWorkflow()
  const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
  const startNodeId = startNode?.id
  const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)

  // Last value written through `setInputs`; lets callbacks and effects read the
  // freshest inputs without depending on the render-time `inputs` value.
  const inputRef = useRef(inputs)

  const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
    // Only the config that belongs to the active retrieval mode is kept.
    const newInputs = produce(s, (draft) => {
      if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
        delete draft.single_retrieval_config
      else
        delete draft.multiple_retrieval_config
    })
    // NOTE(review): the original author noted that passing the draft through did
    // not work, so the ref is synced manually after the write.
    doSetInputs(newInputs)
    inputRef.current = newInputs
  }, [doSetInputs])

  const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
    const newInputs = produce(inputs, (draft) => {
      draft.query_variable_selector = newVar as ValueSelector
    })
    setInputs(newInputs)
  }, [inputs, setInputs])

  // Text-generation model used for single (one-way) retrieval.
  const {
    currentProvider,
    currentModel,
  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)

  // Rerank models used for multiple (multi-way) retrieval.
  const {
    modelList: rerankModelList,
    defaultModel: rerankDefaultModel,
  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

  const {
    currentModel: currentRerankModel,
  } = useCurrentProviderAndModel(
    rerankModelList,
    rerankDefaultModel
      ? {
        ...rerankDefaultModel,
        provider: rerankDefaultModel.provider.provider,
      }
      : undefined,
  )

  const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
    const newInputs = produce(inputRef.current, (draft) => {
      if (!draft.single_retrieval_config) {
        draft.single_retrieval_config = {
          model: {
            provider: '',
            name: '',
            mode: '',
            completion_params: {},
          },
        }
      }
      // The config is guaranteed to exist after the block above.
      const draftModel = draft.single_retrieval_config.model
      draftModel.provider = model.provider
      draftModel.name = model.modelId
      draftModel.mode = model.mode!
    })
    setInputs(newInputs)
  }, [setInputs])

  const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
    // NOTE: right after a provider change, the model in inputRef.current may
    // still be the previous one.
    if (isEqual(newParams, inputRef.current.single_retrieval_config?.model.completion_params))
      return

    const newInputs = produce(inputRef.current, (draft) => {
      if (!draft.single_retrieval_config) {
        draft.single_retrieval_config = {
          model: {
            provider: '',
            name: '',
            mode: '',
            completion_params: {},
          },
        }
      }
      draft.single_retrieval_config.model.completion_params = newParams
    })
    setInputs(newInputs)
  }, [setInputs])

  // Fill in default models once the model lists are available:
  // - the current text-generation model for single retrieval, if none is set;
  // - top_k / reranking defaults for multiple retrieval.
  useEffect(() => {
    const inputs = inputRef.current
    // Nothing to do when the active mode is already configured.
    if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
      return

    if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
      return

    const newInput = produce(inputs, (draft) => {
      if (currentProvider?.provider && currentModel?.model) {
        const hasSetModel = draft.single_retrieval_config?.model?.provider
        if (!hasSetModel) {
          draft.single_retrieval_config = {
            model: {
              provider: currentProvider?.provider,
              name: currentModel?.model,
              mode: currentModel?.model_properties?.mode as string,
              completion_params: {},
            },
          }
        }
      }
      const multipleRetrievalConfig = draft.multiple_retrieval_config
      draft.multiple_retrieval_config = {
        top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
        score_threshold: multipleRetrievalConfig?.score_threshold,
        reranking_model: multipleRetrievalConfig?.reranking_model,
        reranking_mode: multipleRetrievalConfig?.reranking_mode,
        weights: multipleRetrievalConfig?.weights,
        // An explicit user choice wins; otherwise enable reranking only when a
        // rerank model is available.
        reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
          ? multipleRetrievalConfig.reranking_enable
          : Boolean(currentRerankModel && rerankDefaultModel),
      }
    })
    setInputs(newInput)
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [currentProvider?.provider, currentModel, rerankDefaultModel])

  const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
  const [rerankModelOpen, setRerankModelOpen] = useState(false)

  const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
    const newInputs = produce(inputs, (draft) => {
      draft.retrieval_mode = newMode
      if (newMode === RETRIEVE_TYPE.multiWay) {
        const multipleRetrievalConfig = draft.multiple_retrieval_config
        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
      }
      else {
        const hasSetModel = draft.single_retrieval_config?.model?.provider
        if (!hasSetModel) {
          draft.single_retrieval_config = {
            model: {
              provider: currentProvider?.provider || '',
              name: currentModel?.model || '',
              mode: currentModel?.model_properties?.mode as string,
              completion_params: {},
            },
          }
        }
      }
    })
    setInputs(newInputs)
  }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel])

  const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
    const newInputs = produce(inputs, (draft) => {
      draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
    })
    setInputs(newInputs)
  }, [inputs, setInputs, selectedDatasets, currentRerankModel])

  // Load details for the datasets already selected on this node (mount only).
  useEffect(() => {
    // Set by the cleanup so a late response is ignored after unmount.
    let cancelled = false

    const loadSelectedDatasets = async () => {
      const datasetIds = inputRef.current.dataset_ids
      if (datasetIds?.length > 0) {
        try {
          const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } })
          // Ignore the response if the selection changed while the request was
          // in flight; handleOnDatasetsChange already stored the newer datasets.
          if (!cancelled && isEqual(inputRef.current.dataset_ids, datasetIds))
            setSelectedDatasets(dataSetsWithDetail)
        }
        catch (e) {
          console.error('Failed to load datasets for knowledge retrieval node', e)
        }
      }
      if (cancelled)
        return
      // Re-read the ref after the await: writes made in the meantime (e.g. the
      // default query variable effect below, or user edits) must not be reverted
      // by a pre-await snapshot. setInputs also drops the inactive mode's config.
      setInputs(inputRef.current)
    }

    loadSelectedDatasets()

    return () => {
      cancelled = true
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [])

  // In chat mode, default the query variable to the start node's `sys.query`
  // when none is selected yet (mount only).
  useEffect(() => {
    const inputs = inputRef.current
    let query_variable_selector: ValueSelector = inputs.query_variable_selector
    if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
      query_variable_selector = [startNodeId, 'sys.query']

    setInputs(produce(inputs, (draft) => {
      draft.query_variable_selector = query_variable_selector
    }))
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [])

  const handleOnDatasetsChange = useCallback((newDatasets: DataSet[]) => {
    const {
      mixtureHighQualityAndEconomic,
      mixtureInternalAndExternal,
      inconsistentEmbeddingModel,
      allInternal,
      allExternal,
    } = getSelectedDatasetsMode(newDatasets)
    const newInputs = produce(inputs, (draft) => {
      draft.dataset_ids = newDatasets.map(d => d.id)
      // Read the mode from the same inputs the draft is built from.
      if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
        const multipleRetrievalConfig = draft.multiple_retrieval_config
        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel)
      }
    })
    setInputs(newInputs)
    setSelectedDatasets(newDatasets)

    // Open the rerank-model settings for these dataset combinations.
    if (
      (allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel))
      || mixtureInternalAndExternal
      || allExternal
    )
      setRerankModelOpen(true)
  }, [inputs, setInputs, selectedDatasets, currentRerankModel])

  // Only string variables can be used as the retrieval query.
  const filterVar = useCallback((varPayload: Var) => {
    return varPayload.type === VarType.string
  }, [])

  // single run
  const {
    isShowSingleRun,
    hideSingleRun,
    runningStatus,
    handleRun,
    handleStop,
    runInputData,
    setRunInputData,
    runResult,
  } = useOneStepRun<KnowledgeRetrievalNodeType>({
    id,
    data: inputs,
    defaultRunInputData: {
      query: '',
    },
  })

  const query = runInputData.query
  const setQuery = useCallback((newQuery: string) => {
    setRunInputData({
      ...runInputData,
      query: newQuery,
    })
  }, [runInputData, setRunInputData])

  return {
    readOnly,
    inputs,
    handleQueryVarChange,
    filterVar,
    handleRetrievalModeChange,
    handleMultipleRetrievalConfigChange,
    handleModelChanged,
    handleCompletionParamsChange,
    selectedDatasets: selectedDatasets.filter(d => d.name),
    handleOnDatasetsChange,
    isShowSingleRun,
    hideSingleRun,
    runningStatus,
    handleRun,
    handleStop,
    query,
    setQuery,
    runResult,
    rerankModelOpen,
    setRerankModelOpen,
  }
}

export default useConfig