Spaces:
Build error
Build error
import datetime | |
import logging | |
import time | |
import uuid | |
import click | |
from celery import shared_task | |
from sqlalchemy import func | |
from core.indexing_runner import IndexingRunner | |
from core.model_manager import ModelManager | |
from core.model_runtime.entities.model_entities import ModelType | |
from extensions.ext_database import db | |
from extensions.ext_redis import redis_client | |
from libs import helper | |
from models.dataset import Dataset, Document, DocumentSegment | |
# NOTE(review): `shared_task` is imported and the docstring describes `.delay(...)`
# usage, but no `@shared_task` decorator is visible on this function. Confirm
# whether the decorator line was lost; without it `.delay` does not exist.
def batch_create_segment_to_index_task(
    job_id: str, content: list, dataset_id: str, document_id: str, tenant_id: str, user_id: str
):
    """
    Async batch create segment to index.

    Creates one ``DocumentSegment`` per item in ``content`` for the target
    document, passes the new segments to ``IndexingRunner.batch_add_segments``,
    commits the session, and records the job outcome in Redis under the key
    ``segment_batch_import_<job_id>`` ("completed" or "error", expiring after
    600 seconds).

    :param job_id: batch import job id; used to build the Redis status key.
    :param content: list of dicts; each must have a "content" key and, when the
        document's ``doc_form`` is "qa_model", an "answer" key.
    :param dataset_id: id of the dataset the segments are created in.
    :param document_id: id of the document the segments are appended to.
    :param tenant_id: tenant id stored on each new segment.
    :param user_id: user id stored as ``created_by`` on each new segment.

    Exceptions are not propagated: any failure is logged, the session is
    rolled back, and the Redis status key is set to "error".

    Usage: batch_create_segment_to_index_task.delay(
        job_id, content, dataset_id, document_id, tenant_id, user_id
    )
    """
    logging.info(click.style("Start batch create segment jobId: {}".format(job_id), fg="green"))
    start_at = time.perf_counter()
    indexing_cache_key = "segment_batch_import_{}".format(job_id)
    try:
        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
        if not dataset:
            raise ValueError("Dataset not exist.")

        dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
        if not dataset_document:
            raise ValueError("Document not exist.")

        if (
            not dataset_document.enabled
            or dataset_document.archived
            or dataset_document.indexing_status != "completed"
        ):
            raise ValueError("Document is not available.")

        # Token counts are only computed when the dataset uses an embedding model.
        embedding_model = None
        if dataset.indexing_technique == "high_quality":
            model_manager = ModelManager()
            embedding_model = model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )

        # Read the current highest position once and number the new segments
        # sequentially after it. Re-querying inside the loop only works if the
        # session autoflushes the segments added so far (otherwise every new
        # segment would get the same position), and costs one query per segment.
        max_position = (
            db.session.query(func.max(DocumentSegment.position))
            .filter(DocumentSegment.document_id == dataset_document.id)
            .scalar()
        )
        next_position = max_position + 1 if max_position else 1

        document_segments = []
        for segment in content:
            segment_content = segment["content"]
            doc_id = str(uuid.uuid4())
            segment_hash = helper.generate_text_hash(segment_content)
            # calc embedding use tokens
            tokens = (
                embedding_model.get_text_embedding_num_tokens(texts=[segment_content])
                if embedding_model
                else 0
            )
            # Naive UTC timestamp, shared by indexing_at and completed_at.
            now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            segment_document = DocumentSegment(
                tenant_id=tenant_id,
                dataset_id=dataset_id,
                document_id=document_id,
                index_node_id=doc_id,
                index_node_hash=segment_hash,
                position=next_position,
                content=segment_content,
                word_count=len(segment_content),
                tokens=tokens,
                created_by=user_id,
                indexing_at=now,
                status="completed",
                completed_at=now,
            )
            if dataset_document.doc_form == "qa_model":
                segment_document.answer = segment["answer"]
            db.session.add(segment_document)
            document_segments.append(segment_document)
            next_position += 1

        # add index to db
        indexing_runner = IndexingRunner()
        indexing_runner.batch_add_segments(document_segments, dataset)
        db.session.commit()
        redis_client.setex(indexing_cache_key, 600, "completed")
        end_at = time.perf_counter()
        logging.info(
            click.style("Segment batch created job: {} latency: {}".format(job_id, end_at - start_at), fg="green")
        )
    except Exception as e:
        logging.exception("Segments batch created index failed:{}".format(str(e)))
        redis_client.setex(indexing_cache_key, 600, "error")
        # Discard any partially added segments so the session is not left dirty.
        db.session.rollback()