"""A script to load a dataset or set of datasets from a config for a Lilac instance. | |
Usage: | |
poetry run python -m lilac.load \ | |
--output_dir=demo_data \ | |
--config_path=lilac_hf_space.yml | |
""" | |
import gc
import os
import shutil

import click
import dask
import psutil
from distributed import Client

from .concepts.db_concept import DiskConceptDB, DiskConceptModelDB
from .config import EmbeddingConfig, SignalConfig, read_config
from .data.dataset_duckdb import DatasetDuckDB
from .data_loader import process_source
from .db_manager import get_dataset, list_datasets, remove_dataset_from_cache
from .schema import ROWID, PathTuple
from .tasks import TaskManager, TaskStepId
from .utils import DebugTimer, get_datasets_dir


@click.command()
@click.option('--output_dir', required=True, type=str, help='The output directory to write to.')
@click.option('--config_path', required=True, type=str, help='The path to the config yml file.')
@click.option('--overwrite', is_flag=True, default=False, help='Reload all data from scratch.')
def load_command(output_dir: str, config_path: str, overwrite: bool) -> None:
  """Run the source loader as a binary."""
  load(output_dir, config_path, overwrite)


def load(output_dir: str, config_path: str, overwrite: bool) -> None:
  """Load datasets, embeddings, and signals from a config into a Lilac instance."""
  old_data_path = os.environ.get('LILAC_DATA_PATH')
  os.environ['LILAC_DATA_PATH'] = output_dir
  # Turn off debug logging.
  if 'DEBUG' in os.environ:
    del os.environ['DEBUG']
  # Use views to avoid loading duckdb tables into RAM since we aren't query heavy.
  os.environ['DUCKDB_USE_VIEWS'] = '1'
  config = read_config(config_path)

  # Explicitly create a dask client in sync mode.
  dask.config.set({'distributed.worker.daemon': False})
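  # Give the dask workers roughly 2/3 of system RAM; the remainder presumably leaves
  # headroom for duckdb and the main process.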
  total_memory_gb = psutil.virtual_memory().total / (1024**3) * 2 / 3
  task_manager = TaskManager(Client(memory_limit=f'{total_memory_gb} GB'))

  if overwrite:
    shutil.rmtree(get_datasets_dir(output_dir), ignore_errors=True)

  existing_datasets = [f'{d.namespace}/{d.dataset_name}' for d in list_datasets(output_dir)]

  print()
  print('*** Load datasets ***')
  if overwrite:
    datasets_to_load = config.datasets
  else:
    datasets_to_load = [
      d for d in config.datasets if f'{d.namespace}/{d.name}' not in existing_datasets
    ]
    skipped_datasets = [
      d for d in config.datasets if f'{d.namespace}/{d.name}' in existing_datasets
    ]
    print('Skipping loaded datasets:', ', '.join([d.name for d in skipped_datasets]))

  with DebugTimer(f'Loading datasets: {", ".join([d.name for d in datasets_to_load])}'):
    for d in datasets_to_load:
      shutil.rmtree(os.path.join(output_dir, d.name), ignore_errors=True)
      task_id = task_manager.task_id(f'Load dataset {d.namespace}/{d.name}')
      task_manager.execute(task_id, process_source, output_dir, d, (task_id, 0))
    task_manager.wait()

  print()
  total_num_rows = 0
  for d in datasets_to_load:
    dataset = DatasetDuckDB(d.namespace, d.name)
    num_rows = dataset.select_rows([ROWID], limit=1).total_num_rows
    print(f'{d.namespace}/{d.name} loaded with {num_rows:,} rows.')
    # Free up RAM.
    del dataset
    total_num_rows += num_rows

  print(f'Done loading {len(datasets_to_load)} datasets with {total_num_rows:,} rows.')

  print('*** Dataset settings ***')
  for d in config.datasets:
    if d.settings:
      dataset = DatasetDuckDB(d.namespace, d.name)
      dataset.update_settings(d.settings)

  print()
  print('*** Compute embeddings ***')
  with DebugTimer('Loading embeddings'):
    for d in config.datasets:
      dataset = DatasetDuckDB(d.namespace, d.name)

      # If embeddings are explicitly set, use only those.
      embeddings = d.embeddings or []
      # If embeddings are not explicitly set, use the media paths and preferred embedding from
      # settings.
      if not embeddings:
        if d.settings and d.settings.ui:
          for path in d.settings.ui.media_paths or []:
            if d.settings.preferred_embedding:
              embeddings.append(
                EmbeddingConfig(path=path, embedding=d.settings.preferred_embedding))

      for e in embeddings:
        if e not in dataset.config().embeddings:
          print('Scheduling', e)
          task_id = task_manager.task_id(f'Compute embedding {e.embedding} on {d.name}:{e.path}')
          task_manager.execute(task_id, _compute_embedding, d.namespace, d.name, e, output_dir,
                               overwrite, (task_id, 0))
        else:
          print(f'Embedding {e.embedding} already exists for {d.name}:{e.path}. Skipping.')
      del dataset

      # Wait for all embeddings for each dataset to reduce the memory pressure.
      task_manager.wait()

  print()
  print('*** Compute signals ***')
  with DebugTimer('Computing signals'):
    for d in config.datasets:
      dataset = DatasetDuckDB(d.namespace, d.name)

      # If signals are explicitly set, use only those.
      signals = d.signals or []
      # If signals are not explicitly set, use the media paths and config.signals.
      if not signals:
        if d.settings and d.settings.ui:
          for path in d.settings.ui.media_paths or []:
            for signal in config.signals or []:
              signals.append(SignalConfig(path=path, signal=signal))

      # Separate signals by path to avoid computing the same signal in parallel, which can cause
      # issues with taking too much RAM.
      path_signals: dict[PathTuple, list[SignalConfig]] = {}
      for s in signals:
        path_signals.setdefault(s.path, []).append(s)

      for path, signals_for_path in path_signals.items():
        for s in signals_for_path:
          if s not in dataset.config().signals:
            task_id = task_manager.task_id(f'Compute signal {s.signal} on {d.name}:{s.path}')
            task_manager.execute(task_id, _compute_signal, d.namespace, d.name, s, output_dir,
                                 overwrite, (task_id, 0))
          else:
            print(f'Signal {s.signal} already exists for {d.name}:{s.path}. Skipping.')

        # Wait for all signals for each path to reduce the memory pressure.
        task_manager.wait()
      del dataset

  print()
  print('*** Compute model caches ***')
  with DebugTimer('Computing model caches'):
    concept_db = DiskConceptDB(output_dir)
    concept_model_db = DiskConceptModelDB(concept_db)
    if config.concept_model_cache_embeddings:
      for concept_info in concept_db.list():
        for embedding in config.concept_model_cache_embeddings:
          concept_model_db.sync(
            concept_info.namespace, concept_info.name, embedding_name=embedding, create=True)

  print()
  print('Done!')

  if old_data_path:
    os.environ['LILAC_DATA_PATH'] = old_data_path


def _compute_signal(namespace: str, name: str, signal_config: SignalConfig, output_dir: str,
                    overwrite: bool, task_step_id: TaskStepId) -> None:
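  """Compute a signal on one dataset column, skipping it if an identical signal already exists."""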
  os.environ['LILAC_DATA_PATH'] = output_dir
  os.environ['DUCKDB_USE_VIEWS'] = '1'
  # Turn off debug logging.
  if 'DEBUG' in os.environ:
    del os.environ['DEBUG']

  # Recompute unconditionally when overwriting.
  compute_signal = overwrite
  dataset = get_dataset(namespace, name)
  if not compute_signal:
    # Otherwise, compute only if the signal is missing or its config has changed.
    field = dataset.manifest().data_schema.get_field(signal_config.path)
    signal_field = (field.fields or {}).get(signal_config.signal.key())
    if not signal_field or signal_field.signal != signal_config.signal.dict():
      compute_signal = True
  if compute_signal:
    dataset.compute_signal(signal_config.signal, signal_config.path, task_step_id)

  # Free up RAM.
  remove_dataset_from_cache(namespace, name)
  del dataset
  gc.collect()


def _compute_embedding(namespace: str, name: str, embedding_config: EmbeddingConfig,
                       output_dir: str, overwrite: bool, task_step_id: TaskStepId) -> None:
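  """Compute an embedding on one dataset column, skipping it if it was already computed."""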
  os.environ['LILAC_DATA_PATH'] = output_dir
  os.environ['DUCKDB_USE_VIEWS'] = '1'
  # Turn off debug logging.
  if 'DEBUG' in os.environ:
    del os.environ['DEBUG']

  # Recompute unconditionally when overwriting; otherwise only if the embedding is missing.
  compute_embedding = overwrite
  dataset = get_dataset(namespace, name)
  if not compute_embedding:
    field = dataset.manifest().data_schema.get_field(embedding_config.path)
    embedding_field = (field.fields or {}).get(embedding_config.embedding)
    if not embedding_field:
      compute_embedding = True
  if compute_embedding:
    dataset.compute_embedding(embedding_config.embedding, embedding_config.path, task_step_id)

  # Free up RAM.
  remove_dataset_from_cache(namespace, name)
  del dataset
  gc.collect()


if __name__ == '__main__':
  load_command()