File size: 2,044 Bytes
a8b3f00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import { RETRIEVE_METHOD, type RetrievalConfig } from '@/types/app'
import type {
  DefaultModelResponse,
  Model,
} from '@/app/components/header/account-setting/model-provider-page/declarations'
import { RerankingModeEnum } from '@/models/datasets'

/**
 * Checks whether the retrieval settings have a rerank model available in the
 * one situation where this function demands it.
 *
 * A rerank model is demanded only when `indexMethod` is `'high_quality'`, the
 * search method is hybrid, and the reranking mode is not weighted-score.
 * In every other situation the function returns `true`.
 *
 * A rerank model counts as available when either:
 * - `retrievalConfig.reranking_model` names a model and that model is found in
 *   `rerankModelList` under the named provider, or
 * - no model is named, `isRerankDefaultModelValid` is set, and
 *   `rerankDefaultModel` is present.
 *
 * @param rerankDefaultModel - Default rerank model, used only when the config names no model.
 * @param isRerankDefaultModelValid - Whether the default rerank model may be used as a fallback.
 * @param retrievalConfig - Retrieval settings being checked.
 * @param rerankModelList - Providers, each with its models, searched for the configured model.
 * @param indexMethod - Index method; only `'high_quality'` can make the check fail.
 * @returns `false` when a rerank model is demanded but none is available, otherwise `true`.
 */
export const isReRankModelSelected = ({
  rerankDefaultModel,
  isRerankDefaultModelValid,
  retrievalConfig,
  rerankModelList,
  indexMethod,
}: {
  rerankDefaultModel?: DefaultModelResponse
  isRerankDefaultModelValid: boolean
  retrievalConfig: RetrievalConfig
  rerankModelList: Model[]
  indexMethod?: string
}) => {
  const wantedProvider = retrievalConfig.reranking_model?.reranking_provider_name
  const wantedModel = retrievalConfig.reranking_model?.reranking_model_name

  let hasRerankModel = false
  if (wantedModel) {
    // An explicitly configured model must exist under its provider; the
    // default model is not consulted in this branch.
    const owner = rerankModelList.find(item => item.provider === wantedProvider)
    hasRerankModel = Boolean(owner?.models.find(item => item.model === wantedModel))
  }
  else if (isRerankDefaultModelValid) {
    hasRerankModel = Boolean(rerankDefaultModel)
  }

  const rerankModelDemanded = indexMethod === 'high_quality'
    && retrievalConfig.search_method === RETRIEVE_METHOD.hybrid
    && retrievalConfig.reranking_mode !== RerankingModeEnum.WeightedScore

  if (!rerankModelDemanded)
    return true

  return hasRerankModel
}

/**
 * Fills in the default rerank model on a retrieval config that wants
 * reranking but names no model.
 *
 * The default is applied only when all of the following hold:
 * - `indexMethod` is `'high_quality'`,
 * - `retrievalConfig.reranking_enable` is truthy or the search method is hybrid,
 * - `retrievalConfig.reranking_model` has no (or an empty) model name,
 * - `rerankDefaultModel` is truthy.
 *
 * @param rerankDefaultModel - Source of the provider and model names to apply.
 * @param indexMethod - Index method; anything other than `'high_quality'` leaves the config untouched.
 * @param retrievalConfig - Retrieval settings to complete; never mutated.
 * @returns A shallow copy of `retrievalConfig` with `reranking_model` set from the
 *   default, or the very same `retrievalConfig` object when nothing needs to change.
 */
export const ensureRerankModelSelected = ({
  rerankDefaultModel,
  indexMethod,
  retrievalConfig,
}: {
  rerankDefaultModel: DefaultModelResponse
  retrievalConfig: RetrievalConfig
  indexMethod?: string
}) => {
  // A reranking_model entry with an empty name is treated as "no model chosen".
  const modelAlreadyChosen = Boolean(retrievalConfig.reranking_model?.reranking_model_name)

  if (indexMethod !== 'high_quality')
    return retrievalConfig

  const rerankWanted = retrievalConfig.reranking_enable
    || retrievalConfig.search_method === RETRIEVE_METHOD.hybrid
  if (!rerankWanted)
    return retrievalConfig

  if (modelAlreadyChosen || !rerankDefaultModel)
    return retrievalConfig

  const defaultSelection = {
    reranking_provider_name: rerankDefaultModel.provider.provider,
    reranking_model_name: rerankDefaultModel.model,
  }

  return {
    ...retrievalConfig,
    reranking_model: defaultSelection,
  }
}