Spaces:
Runtime error
Runtime error
File size: 19,359 Bytes
e4f9cbe 51b77d2 e4f9cbe 55dc3dd e4f9cbe 51b77d2 e4f9cbe 51b77d2 e4f9cbe 51b77d2 e4f9cbe 2226ee3 e4f9cbe 2226ee3 e4f9cbe 2226ee3 e4f9cbe 3fcddc3 e4f9cbe cf614fd e4f9cbe 2226ee3 e4f9cbe 3fcddc3 e4f9cbe 3fcddc3 e4f9cbe 2226ee3 e4f9cbe 55dc3dd e4f9cbe 55dc3dd e4f9cbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 |
"""Item: an individual entry in the dataset."""
import csv
import io
from collections import deque
from datetime import datetime
from enum import Enum
from typing import Any, Optional, Union, cast
import pyarrow as pa
from pydantic import BaseModel, StrictInt, StrictStr, validator
# File name of the dataset manifest; parquet file paths can be relative to this file.
MANIFEST_FILENAME = 'manifest.json'
# NOTE(review): presumably the file name prefix of the written parquet shards -- confirm at callers.
PARQUET_FILENAME_PREFIX = 'data'

# We choose `__rowid__` inspired by the standard `rowid` pseudocolumn in DBs:
# https://docs.oracle.com/cd/B19306_01/server.102/b14200/pseudocolumns008.htm
UUID_COLUMN = '__rowid__'

# Path part that matches every element of a repeated (list) field.
PATH_WILDCARD = '*'
# Reserved key holding the value of a node that has both a dtype and child fields.
VALUE_KEY = '__value__'
SIGNAL_METADATA_KEY = '__metadata__'
# Names of the start and end offsets stored for a string span.
TEXT_SPAN_START_FEATURE = 'start'
TEXT_SPAN_END_FEATURE = 'end'

# Python doesn't work with recursive types. These types provide some notion of type-safety.
Scalar = Union[bool, datetime, int, float, str, bytes]
Item = Any

# Contains a string field name, a wildcard for repeateds, or a specific integer index for repeateds.
# This path represents a path to a particular column.
# Examples:
# ['article', 'field'] represents {'article': {'field': VALUES}}
# ['article', '*', 'field'] represents {'article': [{'field': VALUES}, {'field': VALUES}]}
# ['article', '0', 'field'] represents the 'field' of the first element of the repeated 'article',
# i.e. {'article': [{'field': VALUES}, ...]} (digit parts are used as integer indices).
PathTuple = tuple[StrictStr, ...]
Path = Union[PathTuple, StrictStr]
PathKeyedItem = tuple[Path, Item]

# These fields are for python only and are not written to a schema.
RichData = Union[str, bytes]
# A key made of string names and/or integer indices.
VectorKey = tuple[Union[StrictStr, StrictInt], ...]
PathKey = VectorKey
class DataType(str, Enum):
  """Enum holding the dtype for a field.

  Members are also `str` instances, so a member compares equal to its string value.
  """
  STRING = 'string'
  # Contains {start, end} offset integers with a reference_column.
  STRING_SPAN = 'string_span'
  BOOLEAN = 'boolean'

  # Ints.
  INT8 = 'int8'
  INT16 = 'int16'
  INT32 = 'int32'
  INT64 = 'int64'
  UINT8 = 'uint8'
  UINT16 = 'uint16'
  UINT32 = 'uint32'
  UINT64 = 'uint64'

  # Floats.
  FLOAT16 = 'float16'
  FLOAT32 = 'float32'
  FLOAT64 = 'float64'

  ### Time ###
  # Time of day (no time zone).
  TIME = 'time'
  # Calendar date (year, month, day), no time zone.
  DATE = 'date'
  # An "Instant" stored as number of microseconds (µs) since 1970-01-01 00:00:00+00 (UTC time zone).
  TIMESTAMP = 'timestamp'
  # Time span, stored as microseconds.
  INTERVAL = 'interval'

  BINARY = 'binary'

  EMBEDDING = 'embedding'

  NULL = 'null'

  def __repr__(self) -> str:
    """Return the bare string value (e.g. `string`) instead of the default enum repr."""
    return self.value
class SignalInputType(str, Enum):
  """Enum holding the signal input type.

  Members are also `str` instances, so a member compares equal to its string value.
  """
  TEXT = 'text'
  TEXT_EMBEDDING = 'text_embedding'
  IMAGE = 'image'

  def __repr__(self) -> str:
    """Return the bare string value (e.g. `text`) instead of the default enum repr."""
    return self.value
# Maps each signal input type to the field dtypes it can be computed over.
# NOTE(review): `SignalInputType.TEXT_EMBEDDING` has no entry here -- confirm whether that is
# intentional or whether it should map to `DataType.EMBEDDING`.
SIGNAL_TYPE_TO_VALID_DTYPES: dict[SignalInputType, list[DataType]] = {
  SignalInputType.TEXT: [DataType.STRING, DataType.STRING_SPAN],
  SignalInputType.IMAGE: [DataType.BINARY],
}
def signal_type_supports_dtype(input_type: SignalInputType, dtype: DataType) -> bool:
  """Returns True if the signal compute type supports the dtype.

  Args:
    input_type: The input type of the signal.
    dtype: The dtype of the field the signal would run over.

  Returns
    True when `dtype` is listed for `input_type` in `SIGNAL_TYPE_TO_VALID_DTYPES`. Input types
    that have no entry in that mapping support no dtype, so the result is False.
  """
  # Not every `SignalInputType` member has an entry in the mapping; a plain subscript would raise
  # a KeyError for those instead of answering the question.
  return dtype in SIGNAL_TYPE_TO_VALID_DTYPES.get(input_type, ())
# A named bin: (name, start, end). A `None` start or end marks the open-ended side of a bin.
Bin = tuple[str, Optional[Union[float, int]], Optional[Union[float, int]]]
class Field(BaseModel):
  """Holds information for a field in the schema.

  A field is a struct (`fields`), a list (`repeated_field`) or a value (`dtype`). A field with a
  `dtype` may also have child `fields`, but `dtype` and `repeated_field` are mutually exclusive.
  """
  repeated_field: Optional['Field'] = None
  fields: Optional[dict[str, 'Field']] = None
  dtype: Optional[DataType] = None
  # Defined as the serialized signal when this field is the root result of a signal.
  signal: Optional[dict[str, Any]] = None
  # Maps a named bin to a tuple of (start, end) values.
  bins: Optional[list[Bin]] = None
  categorical: Optional[bool] = None

  @validator('fields')
  def either_fields_or_repeated_field_is_defined(
      cls, fields: Optional[dict[str, 'Field']],
      values: dict[str, Any]) -> Optional[dict[str, 'Field']]:
    """Error if both `fields` and `repeated_field` are defined, or a reserved name is used."""
    if not fields:
      return fields
    if values.get('repeated_field'):
      raise ValueError('Both "fields" and "repeated_field" should not be defined')
    if VALUE_KEY in fields:
      raise ValueError(f'{VALUE_KEY} is a reserved field name.')
    return fields

  @validator('dtype', always=True)
  def infer_default_dtype(cls, dtype: Optional[DataType],
                          values: dict[str, Any]) -> Optional[DataType]:
    """Checks that the field defines a consistent combination of dtype, fields, repeated_field."""
    if dtype and values.get('repeated_field'):
      raise ValueError('dtype and repeated_field cannot both be defined.')
    if not values.get('repeated_field') and not values.get('fields') and not dtype:
      raise ValueError('One of "fields", "repeated_field", or "dtype" should be defined')
    return dtype

  @validator('bins')
  def validate_bins(cls, bins: Optional[list[Bin]]) -> Optional[list[Bin]]:
    """Validate the bins.

    There must be at least two bins, the first bin must have no start, the last bin must have no
    end, and every bin must start where the previous one ends.
    """
    # The validator also runs when `bins=None` is passed explicitly (e.g. when parsing a
    # serialized field that includes nulls). There is nothing to validate in that case.
    if bins is None:
      return bins
    if len(bins) < 2:
      raise ValueError('Please specify at least two bins.')
    _, first_start, _ = bins[0]
    if first_start is not None:
      raise ValueError('The first bin should have a `None` start value.')
    _, _, last_end = bins[-1]
    if last_end is not None:
      raise ValueError('The last bin should have a `None` end value.')
    # Walk consecutive pairs; `i` is the index of the second bin of each pair.
    for i, ((_, _, prev_end), (_, start, _)) in enumerate(zip(bins, bins[1:]), start=1):
      if start != prev_end:
        raise ValueError(
          f'Bin {i} start ({start}) should be equal to the previous bin end {prev_end}.')
    return bins

  @validator('categorical')
  def validate_categorical(cls, categorical: Optional[bool],
                           values: dict[str, Any]) -> Optional[bool]:
    """Validate the categorical field."""
    # `dtype` is missing from `values` when its own validation failed. Use `.get` so that the
    # dtype error is reported instead of an unrelated KeyError raised from here.
    if categorical and is_float(values.get('dtype')):
      raise ValueError('Categorical fields cannot be float dtypes.')
    return categorical

  def __str__(self) -> str:
    return _str_field(self, indent=0)

  def __repr__(self) -> str:
    return f' {self.__class__.__name__}::{self.json(exclude_none=True, indent=2)}'
class Schema(BaseModel):
  """Database schema."""
  fields: dict[str, Field]
  # Cached leafs.
  _leafs: Optional[dict[PathTuple, Field]] = None

  class Config:
    arbitrary_types_allowed = True
    underscore_attrs_are_private = True

  @property
  def leafs(self) -> dict[PathTuple, Field]:
    """Return all the leaf fields in the schema. A leaf is defined as a node that contains a value.

    NOTE: Leafs may contain children. Leafs can be found as any node that has a dtype defined.

    The result is computed once, breadth-first, and cached on the instance.
    """
    # Compare against None so that an empty result is cached as well.
    if self._leafs is not None:
      return self._leafs
    result: dict[PathTuple, Field] = {}
    # Seed the queue with the top-level fields directly. Wrapping them in a temporary root `Field`
    # would re-run the field validators, which reject a schema without fields.
    q: deque[tuple[PathTuple, Field]] = deque(
      ((name,), child_field) for name, child_field in self.fields.items())
    while q:
      path, field = q.popleft()
      if field.dtype:
        # Nodes with dtypes act as leafs. They also may have children.
        result[path] = field
      if field.fields:
        for name, child_field in field.fields.items():
          q.append(((*path, name), child_field))
      elif field.repeated_field:
        q.append(((*path, PATH_WILDCARD), field.repeated_field))

    self._leafs = result
    return result

  def has_field(self, path: PathTuple) -> bool:
    """Returns if the field is found at the given path.

    Repeated fields can only be traversed with the path wildcard.
    """
    node: Union['Schema', Field] = self
    for path_part in path:
      if node.fields:
        child = node.fields.get(path_part)
        if child is None:
          return False
        node = child
        continue
      # The root of the walk is the schema itself, which has no `repeated_field` attribute.
      repeated_field = getattr(node, 'repeated_field', None)
      if repeated_field is None or path_part != PATH_WILDCARD:
        return False
      node = repeated_field
    return True

  def get_field(self, path: PathTuple) -> Field:
    """Returns the field at the given path.

    Raises
      ValueError: If the path does not lead to a field in the schema.
    """
    node: Union['Schema', Field] = self
    for name in path:
      if node.fields:
        if name not in node.fields:
          raise ValueError(f'Path {path} not found in schema')
        node = node.fields[name]
        continue
      # The root of the walk is the schema itself, which has no `repeated_field` attribute.
      repeated_field = getattr(node, 'repeated_field', None)
      if repeated_field is None or name != PATH_WILDCARD:
        raise ValueError(f'Invalid path {path}')
      node = repeated_field
    return cast(Field, node)

  def __str__(self) -> str:
    return _str_fields(self.fields, indent=0)

  def __repr__(self) -> str:
    return self.json(exclude_none=True, indent=2)
def schema(schema_like: object) -> Schema:
  """Build a `Schema` from a schema-like object.

  Raises
    ValueError: If the parsed object has no fields.
  """
  root = _parse_field_like(schema_like)
  if root.fields:
    return Schema(fields=root.fields)
  raise ValueError('Schema must have fields')
def field(
  dtype: Optional[Union[DataType, str]] = None,
  signal: Optional[dict] = None,
  fields: Optional[object] = None,
  bins: Optional[list[Bin]] = None,
  categorical: Optional[bool] = None,
) -> Field:
  """Build a `Field` from a field-like object plus optional attributes.

  `fields` is parsed as a field-like object (an empty dict when omitted). The remaining arguments
  are assigned onto the parsed field when they are provided.
  """
  parsed = _parse_field_like(fields or {}, dtype)
  if signal:
    parsed.signal = signal
  if dtype:
    # Accept the string form of a dtype as well as the enum member.
    parsed.dtype = DataType(dtype) if isinstance(dtype, str) else dtype
  if bins:
    parsed.bins = bins
  if categorical is not None:
    parsed.categorical = categorical
  return parsed
def _parse_field_like(field_like: object, dtype: Optional[Union[DataType, str]] = None) -> Field:
  """Recursively convert a field-like object to a `Field`.

  Args:
    field_like: One of: a `Field` (returned as is), a dict of name to field-like (a struct), a
      dtype string (a value), or a list whose first element is a field-like (a repeated field).
    dtype: Optional dtype for the parsed node. It is applied to a dict node, and passed down to
      the element of a list node.

  Raises
    ValueError: If `field_like` cannot be parsed, including an empty list.
  """
  if isinstance(field_like, Field):
    return field_like
  if isinstance(field_like, dict):
    fields: dict[str, Field] = {
      name: _parse_field_like(child) for name, child in field_like.items()
    }
    if isinstance(dtype, str):
      dtype = DataType(dtype)
    return Field(fields=fields or None, dtype=dtype)
  if isinstance(field_like, str):
    return Field(dtype=DataType(field_like))
  if isinstance(field_like, list):
    # An empty list does not say what the repeated element is. Report it like any other
    # unparsable input instead of failing with an IndexError.
    if not field_like:
      raise ValueError(f'Cannot parse field like: {field_like}')
    return Field(repeated_field=_parse_field_like(field_like[0], dtype=dtype))
  raise ValueError(f'Cannot parse field like: {field_like}')
def child_item_from_column_path(item: Item, path: Path) -> Item:
  """Return the last (child) item from a column path.

  Args:
    item: The nested item to walk.
    path: The path to follow. A bare string is treated as a path with a single part. Parts made
      only of digits, and integer parts, are used as integer indices.

  Raises
    ValueError: If the path contains the repeated wildcard.
  """
  # Iterating a bare string would walk it character by character, so wrap it like
  # `column_paths_match` does.
  path_parts = (path,) if isinstance(path, str) else path
  child_item_value = item
  for path_part in path_parts:
    if path_part == PATH_WILDCARD:
      raise ValueError(
        'child_item_from_column_path cannot be called with a path that contains a repeated '
        f'wildcard: "{path}"')
    # path_part can either be an integer or a string for a dictionary, both of which we can
    # directly index with.
    if isinstance(path_part, str) and path_part.isdigit():
      child_item_value = child_item_value[int(path_part)]
    else:
      child_item_value = child_item_value[path_part]
  return child_item_value
def column_paths_match(path_match: Path, specific_path: Path) -> bool:
  """Test whether a specific column path is matched by a (possibly wildcarded) column path.

  Args:
    path_match: A column path that may contain wildcards. It is the pattern the specific path is
      tested against.
    specific_path: A column path that identifies one specific field.

  Returns
    Whether specific_path matches path_match. Paths only match when they have equal length. To
    match everything inside an array, the path wildcard '*' must be used in path_match.
  """
  # A bare string is a path with a single part.
  pattern = (path_match,) if isinstance(path_match, str) else path_match
  target = (specific_path,) if isinstance(specific_path, str) else specific_path

  if len(pattern) != len(target):
    return False

  # Every part must either be a wildcard in the pattern or equal the part at the same position.
  return all(
    pattern_part == PATH_WILDCARD or pattern_part == target_part
    for pattern_part, target_part in zip(pattern, target))
def normalize_path(path: Path) -> PathTuple:
  """Normalizes a dot separated path, but ignores dots inside quotes, like regular SQL.

  A path that is not a string is returned unchanged. An empty string normalizes to the empty path.

  Examples
  - 'a.b.c' will be parsed as ('a', 'b', 'c').
  - '"a.b".c' will be parsed as ('a.b', 'c').
  - '"a".b.c' will be parsed as ('a', 'b', 'c').
  - '' will be parsed as ().
  """
  if not isinstance(path, str):
    return path
  # The csv reader handles the quoting rules. It yields no row at all for an empty string, so a
  # default is needed to avoid leaking a StopIteration to the caller.
  rows = csv.reader(io.StringIO(path), delimiter='.')
  return tuple(next(rows, ()))
class ImageInfo(BaseModel):
  """Info about an individual image."""
  # A `Path`: a tuple of path parts or a single string.
  # NOTE(review): presumably the column path that holds the image -- confirm against callers.
  path: Path
class SourceManifest(BaseModel):
  """The manifest that describes the dataset run, including schema and parquet files."""
  # List of the parquet file paths storing the data. The paths can be relative to `manifest.json`.
  files: list[str]
  # The data schema.
  data_schema: Schema

  # Image information for the dataset.
  images: Optional[list[ImageInfo]] = None
def _str_fields(fields: dict[str, Field], indent: int) -> str:
  """Render one `name:value` line per field, with children indented two more levels."""
  prefix = ' ' * indent
  child_indent = indent + 2
  return '\n'.join(
    f'{prefix}{name}:{_str_field(child, indent=child_indent)}' for name, child in fields.items())
def _str_field(field: Field, indent: int) -> str:
  """Render a single field for the string form of a schema.

  A struct is rendered as its child fields (on a new line when nested), a repeated field as
  ` list(...)` and a value as its dtype name, e.g. ` string`.
  """
  if field.fields:
    prefix = '\n' if indent > 0 else ''
    return f'{prefix}{_str_fields(field.fields, indent)}'
  if field.repeated_field:
    return f' list({_str_field(field.repeated_field, indent)})'
  # Format the enum value explicitly. Formatting a `(str, Enum)` member directly in an f-string
  # gives `DataType.STRING` instead of `string` on newer Python versions.
  dtype = field.dtype
  dtype_name = dtype.value if isinstance(dtype, DataType) else dtype
  return f' {dtype_name}'
def dtype_to_arrow_schema(dtype: DataType) -> Union[pa.Schema, pa.DataType]:
  """Convert the dtype to an arrow dtype.

  Args:
    dtype: The dtype to convert.

  Returns
    The arrow type for the dtype. A string span is a struct that holds the start and end offsets
    under the value key.

  Raises
    ValueError: If the dtype has no arrow equivalent.
  """
  if dtype == DataType.STRING:
    return pa.string()
  if dtype == DataType.BOOLEAN:
    return pa.bool_()

  # Floats.
  if dtype == DataType.FLOAT16:
    return pa.float16()
  if dtype == DataType.FLOAT32:
    return pa.float32()
  if dtype == DataType.FLOAT64:
    return pa.float64()

  # Ints.
  if dtype == DataType.INT8:
    return pa.int8()
  if dtype == DataType.INT16:
    return pa.int16()
  if dtype == DataType.INT32:
    return pa.int32()
  if dtype == DataType.INT64:
    return pa.int64()
  if dtype == DataType.UINT8:
    return pa.uint8()
  if dtype == DataType.UINT16:
    return pa.uint16()
  if dtype == DataType.UINT32:
    return pa.uint32()
  if dtype == DataType.UINT64:
    return pa.uint64()

  if dtype == DataType.BINARY:
    return pa.binary()

  # Time.
  if dtype == DataType.TIME:
    # `pa.time64` requires a unit. Use microseconds, like the timestamp and interval dtypes.
    return pa.time64('us')
  if dtype == DataType.DATE:
    return pa.date64()
  if dtype == DataType.TIMESTAMP:
    return pa.timestamp('us')
  if dtype == DataType.INTERVAL:
    return pa.duration('us')

  if dtype == DataType.EMBEDDING:
    # We reserve an empty column for embeddings in parquet files so they can be queried.
    # The values are *not* filled out. If parquet and duckdb support embeddings in the future, we
    # can set this dtype to the relevant pyarrow type.
    return pa.null()
  if dtype == DataType.STRING_SPAN:
    return pa.struct({
      VALUE_KEY: pa.struct({
        TEXT_SPAN_START_FEATURE: pa.int32(),
        TEXT_SPAN_END_FEATURE: pa.int32()
      })
    })
  if dtype == DataType.NULL:
    return pa.null()

  raise ValueError(f'Can not convert dtype "{dtype}" to arrow dtype')
def schema_to_arrow_schema(schema: Union[Schema, Field]) -> pa.Schema:
  """Convert our schema (or a field with children) to an arrow schema."""
  converted = cast(pa.Schema, _schema_to_arrow_schema_impl(schema))
  # Rebuild the result from a name to type mapping so a `pa.Schema` is always returned.
  name_to_type = {}
  for arrow_field in converted:
    name_to_type[arrow_field.name] = arrow_field.type
  return pa.schema(name_to_type)
def _schema_to_arrow_schema_impl(schema: Union[Schema, Field]) -> Union[pa.Schema, pa.DataType]:
  """Recursively convert a schema, or one of its fields, to the matching arrow type."""
  if not schema.fields:
    # No children: the node is either a repeated field or a plain value.
    leaf = cast(Field, schema)
    if leaf.repeated_field:
      return pa.list_(_schema_to_arrow_schema_impl(leaf.repeated_field))
    return dtype_to_arrow_schema(cast(DataType, leaf.dtype))

  children: dict[str, Union[pa.Schema, pa.DataType]] = {}
  for name, child in schema.fields.items():
    if name == UUID_COLUMN:
      # The row id column is converted from its dtype directly.
      children[name] = dtype_to_arrow_schema(cast(DataType, child.dtype))
    else:
      children[name] = _schema_to_arrow_schema_impl(child)

  # Top-level schemas do not have __value__ fields.
  if isinstance(schema, Schema):
    return pa.schema(children)

  # When nodes have both dtype and children, we add __value__ alongside the fields.
  if schema.dtype:
    value_type = dtype_to_arrow_schema(schema.dtype)
    if schema.dtype == DataType.STRING_SPAN:
      # The span type is already wrapped under the value key, so unwrap it first.
      value_type = value_type[VALUE_KEY].type
    children[VALUE_KEY] = value_type
  return pa.struct(children)
def arrow_dtype_to_dtype(arrow_dtype: pa.DataType) -> DataType:
  """Convert arrow dtype to our dtype.

  Raises
    ValueError: If the arrow dtype has no equivalent dtype.
  """
  # Ints and floats are matched by exact type equality, in this order.
  exact_numeric_types = (
    (pa.int8(), DataType.INT8),
    (pa.int16(), DataType.INT16),
    (pa.int32(), DataType.INT32),
    (pa.int64(), DataType.INT64),
    (pa.uint8(), DataType.UINT8),
    (pa.uint16(), DataType.UINT16),
    (pa.uint32(), DataType.UINT32),
    (pa.uint64(), DataType.UINT64),
    (pa.float16(), DataType.FLOAT16),
    (pa.float32(), DataType.FLOAT32),
    (pa.float64(), DataType.FLOAT64),
  )
  for candidate, matched_dtype in exact_numeric_types:
    if arrow_dtype == candidate:
      return matched_dtype

  # Time. These are matched by type family.
  if pa.types.is_time(arrow_dtype):
    return DataType.TIME
  if pa.types.is_date(arrow_dtype):
    return DataType.DATE
  if pa.types.is_timestamp(arrow_dtype):
    return DataType.TIMESTAMP
  if pa.types.is_duration(arrow_dtype):
    return DataType.INTERVAL

  # Others.
  if arrow_dtype == pa.string():
    return DataType.STRING
  if pa.types.is_binary(arrow_dtype) or pa.types.is_fixed_size_binary(arrow_dtype):
    return DataType.BINARY
  if pa.types.is_boolean(arrow_dtype):
    return DataType.BOOLEAN
  if arrow_dtype == pa.null():
    return DataType.NULL

  raise ValueError(f'Can not convert arrow dtype "{arrow_dtype}" to our dtype')
def arrow_schema_to_schema(schema: pa.Schema) -> Schema:
  """Convert an arrow schema to our schema."""
  # TODO(nsthorat): Change this implementation to allow more complicated reading of arrow schemas
  # into our schema by inferring values when {__value__: value} is present in the pyarrow schema.
  # This isn't necessary today as this util is only needed by sources which do not have data in the
  # lilac format.
  converted = _arrow_schema_to_schema_impl(schema)
  return cast(Schema, converted)
def _arrow_schema_to_schema_impl(schema: Union[pa.Schema, pa.DataType]) -> Union[Schema, Field]:
  """Recursively convert an arrow schema, or one of its types, to our schema or field."""
  if isinstance(schema, (pa.Schema, pa.StructType)):
    # Both arrow schemas and struct types are iterated as a sequence of named fields.
    children: dict[str, Field] = {}
    for arrow_field in schema:
      children[arrow_field.name] = cast(Field, _arrow_schema_to_schema_impl(arrow_field.type))
    if isinstance(schema, pa.Schema):
      return Schema(fields=children)
    return Field(fields=children)

  if isinstance(schema, pa.ListType):
    element_type = schema.value_field.type
    return Field(repeated_field=cast(Field, _arrow_schema_to_schema_impl(element_type)))

  return Field(dtype=arrow_dtype_to_dtype(schema))
def is_float(dtype: DataType) -> bool:
  """Return whether the dtype is one of the floating point dtypes."""
  float_dtypes = (DataType.FLOAT16, DataType.FLOAT32, DataType.FLOAT64)
  return dtype in float_dtypes
def is_integer(dtype: DataType) -> bool:
  """Return whether the dtype is one of the signed or unsigned integer dtypes."""
  signed_dtypes = (DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64)
  unsigned_dtypes = (DataType.UINT8, DataType.UINT16, DataType.UINT32, DataType.UINT64)
  return dtype in signed_dtypes + unsigned_dtypes
def is_temporal(dtype: DataType) -> bool:
  """Return whether the dtype is one of the time related dtypes."""
  temporal_dtypes = (DataType.TIME, DataType.DATE, DataType.TIMESTAMP, DataType.INTERVAL)
  return dtype in temporal_dtypes
def is_ordinal(dtype: DataType) -> bool:
  """Return whether the dtype is ordinal, i.e. a float, integer or temporal dtype."""
  # The checks run in this order and stop at the first one that matches.
  return any(check(dtype) for check in (is_float, is_integer, is_temporal))
|