File size: 19,359 Bytes
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51b77d2
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55dc3dd
 
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
51b77d2
e4f9cbe
 
 
 
 
51b77d2
e4f9cbe
51b77d2
e4f9cbe
 
 
 
 
 
 
2226ee3
 
 
e4f9cbe
2226ee3
e4f9cbe
2226ee3
 
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fcddc3
 
 
 
 
 
 
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf614fd
 
 
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2226ee3
 
e4f9cbe
 
 
 
 
 
 
 
3fcddc3
e4f9cbe
 
 
 
 
 
 
 
 
 
 
3fcddc3
 
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2226ee3
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55dc3dd
 
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55dc3dd
 
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
"""Item: an individual entry in the dataset."""

import csv
import io
from collections import deque
from datetime import datetime
from enum import Enum
from typing import Any, Optional, Union, cast

import pyarrow as pa
from pydantic import BaseModel, StrictInt, StrictStr, validator

# Name of the manifest file. Parquet file paths listed in a manifest can be relative to it.
MANIFEST_FILENAME = 'manifest.json'
PARQUET_FILENAME_PREFIX = 'data'

# We choose `__rowid__` inspired by the standard `rowid` pseudocolumn in DBs:
# https://docs.oracle.com/cd/B19306_01/server.102/b14200/pseudocolumns008.htm
UUID_COLUMN = '__rowid__'
# Path part that stands for every element of a repeated (list) field.
PATH_WILDCARD = '*'
# Reserved field name: holds a node's own value when the node has both a dtype and child fields.
VALUE_KEY = '__value__'
SIGNAL_METADATA_KEY = '__metadata__'
# Names of the start / end offset features of a string span.
TEXT_SPAN_START_FEATURE = 'start'
TEXT_SPAN_END_FEATURE = 'end'

# Python doesn't work with recursive types. These types provide some notion of type-safety.
Scalar = Union[bool, datetime, int, float, str, bytes]
Item = Any

# Contains a string field name, a wildcard for repeateds, or a specific integer index for repeateds.
# This path represents a path to a particular column.
# Examples:
#  ['article', 'field'] represents {'article': {'field': VALUES}}
#  ['article', '*', 'field'] represents {'article': [{'field': VALUES}, {'field': VALUES}]}
#  ['article', '0', 'field'] represents {'article': {'field': VALUES}}
PathTuple = tuple[StrictStr, ...]
# A path is either an already-split tuple of parts or a single string.
Path = Union[PathTuple, StrictStr]

# A (path, item) pair.
PathKeyedItem = tuple[Path, Item]

# These types are for Python only and are not written to a schema.
RichData = Union[str, bytes]
VectorKey = tuple[Union[StrictStr, StrictInt], ...]
PathKey = VectorKey


class DataType(str, Enum):
  """Enum holding the dtype for a field.

  Members are also `str` instances whose value is the lowercase dtype name, so a member compares
  equal to its string value.
  """
  STRING = 'string'
  # Contains {start, end} offset integers with a reference_column.
  STRING_SPAN = 'string_span'
  BOOLEAN = 'boolean'

  # Ints.
  INT8 = 'int8'
  INT16 = 'int16'
  INT32 = 'int32'
  INT64 = 'int64'
  UINT8 = 'uint8'
  UINT16 = 'uint16'
  UINT32 = 'uint32'
  UINT64 = 'uint64'

  # Floats.
  FLOAT16 = 'float16'
  FLOAT32 = 'float32'
  FLOAT64 = 'float64'

  ### Time ###
  # Time of day (no time zone).
  TIME = 'time'
  # Calendar date (year, month, day), no time zone.
  DATE = 'date'
  # An "Instant" stored as number of microseconds (µs) since 1970-01-01 00:00:00+00 (UTC time zone).
  TIMESTAMP = 'timestamp'
  # Time span, stored as microseconds.
  INTERVAL = 'interval'

  BINARY = 'binary'

  EMBEDDING = 'embedding'

  NULL = 'null'

  def __repr__(self) -> str:
    """Return the plain string value (e.g. `string`) instead of the default enum repr."""
    return self.value


class SignalInputType(str, Enum):
  """Enum holding the signal input type.

  Members are also `str` instances whose value is the lowercase input type name.
  """
  TEXT = 'text'
  TEXT_EMBEDDING = 'text_embedding'
  IMAGE = 'image'

  def __repr__(self) -> str:
    """Return the plain string value (e.g. `text`) instead of the default enum repr."""
    return self.value


# Maps a signal input type to the field dtypes that the signal can be computed over.
# NOTE: `SignalInputType.TEXT_EMBEDDING` has no entry in this mapping.
SIGNAL_TYPE_TO_VALID_DTYPES: dict[SignalInputType, list[DataType]] = {
  SignalInputType.TEXT: [DataType.STRING, DataType.STRING_SPAN],
  SignalInputType.IMAGE: [DataType.BINARY],
}


def signal_type_supports_dtype(input_type: SignalInputType, dtype: DataType) -> bool:
  """Returns True if the signal compute type supports the dtype.

  Args:
    input_type: The input type of the signal.
    dtype: The dtype of the field the signal would be computed over.

  Returns
    True when `dtype` is listed for `input_type` in `SIGNAL_TYPE_TO_VALID_DTYPES`. Input types
    without an entry in that mapping support no dtype, so False is returned for them.
  """
  # Not every input type has an entry in the mapping; treat a missing entry as "no valid dtypes"
  # instead of failing with a KeyError.
  return dtype in SIGNAL_TYPE_TO_VALID_DTYPES.get(input_type, ())


# A named bin: (name, start, end). A `None` start or end marks an open-ended boundary.
Bin = tuple[str, Optional[Union[float, int]], Optional[Union[float, int]]]


class Field(BaseModel):
  """Holds information for a field in the schema.

  A field describes one node of the schema tree. It must define at least one of `fields` (named
  children), `repeated_field` (the element type of a list) or `dtype` (the type of the value held
  by the node). `repeated_field` cannot be combined with `fields` or with `dtype`; a node may have
  both a `dtype` and `fields`.
  """
  repeated_field: Optional['Field'] = None
  fields: Optional[dict[str, 'Field']] = None
  dtype: Optional[DataType] = None
  # Defined as the serialized signal when this field is the root result of a signal.
  signal: Optional[dict[str, Any]] = None
  # Maps a named bin to a tuple of (start, end) values.
  bins: Optional[list[Bin]] = None
  categorical: Optional[bool] = None

  @validator('fields')
  def either_fields_or_repeated_field_is_defined(
      cls, fields: Optional[dict[str, 'Field']], values: dict[str,
                                                              Any]) -> Optional[dict[str, 'Field']]:
    """Error if both `fields` and `repeated_field` are defined, or a reserved name is used.

    Raises:
      ValueError: If `repeated_field` is also set, or if a child is named `VALUE_KEY`.
    """
    if not fields:
      return fields
    if values.get('repeated_field'):
      raise ValueError('Both "fields" and "repeated_field" should not be defined')
    if VALUE_KEY in fields:
      raise ValueError(f'{VALUE_KEY} is a reserved field name.')
    return fields

  @validator('dtype', always=True)
  def infer_default_dtype(cls, dtype: Optional[DataType], values: dict[str,
                                                                       Any]) -> Optional[DataType]:
    """Checks that dtype is consistent with `fields` and `repeated_field`.

    The validator always runs (even when dtype is not provided) so that a field that defines none
    of `fields`, `repeated_field` and `dtype` is rejected. The dtype is returned unchanged.

    Raises:
      ValueError: If dtype is combined with `repeated_field`, or if nothing is defined.
    """
    if dtype and values.get('repeated_field'):
      raise ValueError('dtype and repeated_field cannot both be defined.')
    if not values.get('repeated_field') and not values.get('fields') and not dtype:
      raise ValueError('One of "fields", "repeated_field", or "dtype" should be defined')
    return dtype

  @validator('bins')
  def validate_bins(cls, bins: Optional[list[Bin]]) -> Optional[list[Bin]]:
    """Validate the bins.

    Bins must be contiguous: the first bin has no start, the last bin has no end, and every other
    bin starts where the previous one ends.

    Raises:
      ValueError: If fewer than two bins are given or the bins are not contiguous.
    """
    # `bins` is optional; an explicit `None` means "no bins" and there is nothing to check.
    if bins is None:
      return bins
    if len(bins) < 2:
      raise ValueError('Please specify at least two bins.')
    _, first_start, _ = bins[0]
    if first_start is not None:
      raise ValueError('The first bin should have a `None` start value.')
    _, _, last_end = bins[-1]
    if last_end is not None:
      raise ValueError('The last bin should have a `None` end value.')
    for i, (_, start, _) in enumerate(bins):
      if i == 0:
        continue
      prev_bin = bins[i - 1]
      _, _, prev_end = prev_bin
      if start != prev_end:
        raise ValueError(
          f'Bin {i} start ({start}) should be equal to the previous bin end {prev_end}.')
    return bins

  @validator('categorical')
  def validate_categorical(cls, categorical: Optional[bool], values: dict[str,
                                                                          Any]) -> Optional[bool]:
    """Validate the categorical field.

    Raises:
      ValueError: If the field is marked categorical but has a float dtype.
    """
    # `dtype` is missing from `values` when its own validation failed, so read it defensively
    # rather than masking that validation error with a KeyError.
    if categorical and is_float(values.get('dtype')):
      raise ValueError('Categorical fields cannot be float dtypes.')
    return categorical

  def __str__(self) -> str:
    """Return a compact, indented rendering of the field."""
    return _str_field(self, indent=0)

  def __repr__(self) -> str:
    """Return the class name followed by the field as indented JSON, without unset values."""
    return f' {self.__class__.__name__}::{self.json(exclude_none=True, indent=2)}'


class Schema(BaseModel):
  """Database schema: a mapping of top-level field names to their `Field` definitions."""
  fields: dict[str, Field]
  # Cached leafs. `None` means the leafs have not been computed yet.
  _leafs: Optional[dict[PathTuple, Field]] = None

  class Config:
    arbitrary_types_allowed = True
    underscore_attrs_are_private = True

  @property
  def leafs(self) -> dict[PathTuple, Field]:
    """Return all the leaf fields in the schema. A leaf is defined as a node that contains a value.

    NOTE: Leafs may contain children. Leafs can be found as any node that has a dtype defined.

    The result is keyed by the path to the leaf, uses `PATH_WILDCARD` for repeated fields, and is
    computed once and cached. A schema without fields has no leafs.
    """
    if self._leafs is not None:
      return self._leafs
    result: dict[PathTuple, Field] = {}
    q: deque[tuple[PathTuple, Field]] = deque()
    # A `Field` must define fields, a repeated field or a dtype, so the synthetic root can only be
    # built for a non-empty schema. An empty schema simply has no leafs.
    if self.fields:
      q.append(((), Field(fields=self.fields)))
    # Breadth-first traversal of the schema tree.
    while q:
      path, field = q.popleft()
      if field.dtype:
        # Nodes with dtypes act as leafs. They also may have children.
        result[path] = field
      if field.fields:
        for name, child_field in field.fields.items():
          child_path = (*path, name)
          q.append((child_path, child_field))
      elif field.repeated_field:
        child_path = (*path, PATH_WILDCARD)
        q.append((child_path, field.repeated_field))

    self._leafs = result
    return result

  def has_field(self, path: PathTuple) -> bool:
    """Returns if the field is found at the given path.

    Repeated fields can only be traversed with `PATH_WILDCARD`.
    """
    # The schema root is walked like a field; unlike `Field` it has no `repeated_field` attribute,
    # hence the `getattr` below.
    field = cast(Field, self)
    for path_part in path:
      if field.fields:
        field = cast(Field, field.fields.get(path_part))
        if not field:
          return False
      elif getattr(field, 'repeated_field', None):
        if path_part != PATH_WILDCARD:
          return False
        field = field.repeated_field
      else:
        return False
    return True

  def get_field(self, path: PathTuple) -> Field:
    """Returns the field at the given path.

    Repeated fields can only be traversed with `PATH_WILDCARD`.

    Raises:
      ValueError: If the path does not lead to a field in the schema.
    """
    # The schema root is walked like a field; unlike `Field` it has no `repeated_field` attribute,
    # hence the `getattr` below.
    field = cast(Field, self)
    for name in path:
      if field.fields:
        if name not in field.fields:
          raise ValueError(f'Path {path} not found in schema')
        field = field.fields[name]
      elif getattr(field, 'repeated_field', None):
        if name != PATH_WILDCARD:
          raise ValueError(f'Invalid path {path}')
        field = field.repeated_field
      else:
        raise ValueError(f'Invalid path {path}')
    return field

  def __str__(self) -> str:
    """Return a compact, indented rendering of the schema fields."""
    return _str_fields(self.fields, indent=0)

  def __repr__(self) -> str:
    """Return the schema as indented JSON, without unset values."""
    return self.json(exclude_none=True, indent=2)


def schema(schema_like: object) -> Schema:
  """Build a `Schema` from a schema-like object.

  The object is parsed as a field-like value; its child fields become the top-level fields of the
  schema.

  Raises:
    ValueError: If the parsed object has no child fields.
  """
  top_level_fields = _parse_field_like(schema_like).fields
  if top_level_fields:
    return Schema(fields=top_level_fields)
  raise ValueError('Schema must have fields')


def field(
  dtype: Optional[Union[DataType, str]] = None,
  signal: Optional[dict] = None,
  fields: Optional[object] = None,
  bins: Optional[list[Bin]] = None,
  categorical: Optional[bool] = None,
) -> Field:
  """Build a `Field` from a field-like description.

  Args:
    dtype: The dtype of the field, as a `DataType` or its string value.
    signal: The serialized signal to attach to the field.
    fields: A field-like object describing the children; an empty dict is used when falsy.
    bins: The bins to attach to the field.
    categorical: Whether the field is categorical. Only applied when not None.

  Returns
    The parsed field with the provided attributes set on it.
  """
  parsed = _parse_field_like(fields or {}, dtype)
  if signal:
    parsed.signal = signal
  if dtype:
    # Strings are converted to the enum; anything else is stored as given.
    parsed.dtype = DataType(dtype) if isinstance(dtype, str) else dtype
  if bins:
    parsed.bins = bins
  if categorical is not None:
    parsed.categorical = categorical
  return parsed


def _parse_field_like(field_like: object, dtype: Optional[Union[DataType, str]] = None) -> Field:
  """Recursively parse a field-like object into a `Field`.

  Supported inputs:
    - a `Field`, returned as is;
    - a dict of name -> field-like, parsed into a field with child `fields` (and `dtype`, if given);
    - a string, parsed as the dtype of a value field;
    - a list, whose first element describes the repeated field. `dtype` is passed down to it.

  Raises:
    ValueError: If the object cannot be parsed, including an empty list, which does not say what
      is being repeated.
  """
  if isinstance(field_like, Field):
    return field_like
  if isinstance(field_like, dict):
    fields: dict[str, Field] = {name: _parse_field_like(child) for name, child in field_like.items()}
    if isinstance(dtype, str):
      dtype = DataType(dtype)
    # An empty dict means "no children"; pass None so the field is validated on its dtype alone.
    return Field(fields=fields or None, dtype=dtype)
  if isinstance(field_like, str):
    return Field(dtype=DataType(field_like))
  if isinstance(field_like, list):
    if not field_like:
      # Fail with a descriptive error rather than an IndexError from `field_like[0]`.
      raise ValueError(f'Cannot parse field like: {field_like}. A list must contain the repeated '
                       'field as its first element.')
    return Field(repeated_field=_parse_field_like(field_like[0], dtype=dtype))
  raise ValueError(f'Cannot parse field like: {field_like}')


def child_item_from_column_path(item: Item, path: Path) -> Item:
  """Return the last (child) item from a column path.

  Args:
    item: The item to index into.
    path: The path to follow. A string path is normalized with `normalize_path` first. Path parts
      made only of digits are used as integer indices, all other parts are used as keys.

  Returns
    The value found by indexing `item` with each part of the path, in order.

  Raises:
    ValueError: If the path contains the repeated wildcard.
  """
  # A `Path` may be a plain string. Iterating it directly would walk it character by character,
  # so split it into its parts first.
  path_parts = normalize_path(path) if isinstance(path, str) else path
  child_item_value = item
  for path_part in path_parts:
    if path_part == PATH_WILDCARD:
      raise ValueError(
        'child_item_from_column_path cannot be called with a path that contains a repeated '
        f'wildcard: "{path}"')
    # path_part can either be an integer or a string for a dictionary, both of which we can
    # directly index with.
    child_path = int(path_part) if path_part.isdigit() else path_part
    child_item_value = child_item_value[child_path]
  return child_item_value


def column_paths_match(path_match: Path, specific_path: Path) -> bool:
  """Test whether a specific column path is matched by a path pattern.

  Args:
    path_match: The pattern. It may contain the path wildcard '*', which matches any part at
      that position.
    specific_path: The column path that identifies a specific field.

  Returns
    Whether specific_path matches path_match. Paths of different lengths never match, so matching
    inside a repeated field requires a wildcard in the pattern. A string path is treated as a path
    with a single part.
  """
  pattern = (path_match,) if isinstance(path_match, str) else path_match
  candidate = (specific_path,) if isinstance(specific_path, str) else specific_path

  if len(pattern) != len(candidate):
    return False

  # Every pattern part must be a wildcard or equal to the part at the same position.
  return all(
    pattern_part == PATH_WILDCARD or pattern_part == candidate_part
    for pattern_part, candidate_part in zip(pattern, candidate))


def normalize_path(path: Path) -> PathTuple:
  """Normalizes a dot separated path, but ignores dots inside quotes, like regular SQL.

  Tuple paths are returned unchanged. An empty string is normalized to the empty path `()`.

  Examples
  - 'a.b.c' will be parsed as ('a', 'b', 'c').
  - '"a.b".c' will be parsed as ('a.b', 'c').
  - '"a".b.c' will be parsed as ('a', 'b', 'c').
  """
  if isinstance(path, str):
    # The csv reader handles the quoting rules. It yields no row at all for an empty string, so
    # default to an empty row instead of leaking StopIteration to the caller.
    parts = next(csv.reader(io.StringIO(path), delimiter='.'), [])
    return tuple(parts)
  return path


class ImageInfo(BaseModel):
  """Info about an individual image."""
  # A column path (tuple of parts or a single string).
  # NOTE(review): presumably the path of the column that holds the image — confirm.
  path: Path


class SourceManifest(BaseModel):
  """The manifest that describes the dataset run, including schema and parquet files."""
  # List of parquet filepaths storing the data. The paths can be relative to `manifest.json`.
  files: list[str]
  # The data schema.
  data_schema: Schema

  # Image information for the dataset. Optional; defaults to None.
  images: Optional[list[ImageInfo]] = None


def _str_fields(fields: dict[str, Field], indent: int) -> str:
  """Render named fields one per line, each line indented by `indent` spaces.

  Every field is rendered with `_str_field` at two more spaces of indentation.
  """
  padding = ' ' * indent
  return '\n'.join(
    f'{padding}{name}:{_str_field(child, indent=indent + 2)}' for name, child in fields.items())


def _str_field(field: Field, indent: int) -> str:
  """Render a single field.

  A field with children is rendered as its children (starting on a new line when nested), a
  repeated field as ` list(...)` and any other field as its dtype preceded by a space.
  """
  children = field.fields
  if children:
    rendered_children = _str_fields(children, indent)
    # Nested children start on their own line; top-level children do not.
    return f'\n{rendered_children}' if indent > 0 else rendered_children

  element = field.repeated_field
  if element:
    return f' list({_str_field(element, indent)})'

  return f' {cast(DataType, field.dtype)}'


def dtype_to_arrow_schema(dtype: DataType) -> Union[pa.Schema, pa.DataType]:
  """Convert the dtype to an arrow dtype.

  Args:
    dtype: The dtype to convert.

  Returns
    The matching arrow type. Time based dtypes use microsecond resolution. String spans are
    returned as a struct holding a `VALUE_KEY` struct with the start and end offsets.

  Raises:
    ValueError: If the dtype has no arrow equivalent.
  """
  if dtype == DataType.STRING:
    return pa.string()
  elif dtype == DataType.BOOLEAN:
    return pa.bool_()
  elif dtype == DataType.FLOAT16:
    return pa.float16()
  elif dtype == DataType.FLOAT32:
    return pa.float32()
  elif dtype == DataType.FLOAT64:
    return pa.float64()
  elif dtype == DataType.INT8:
    return pa.int8()
  elif dtype == DataType.INT16:
    return pa.int16()
  elif dtype == DataType.INT32:
    return pa.int32()
  elif dtype == DataType.INT64:
    return pa.int64()
  elif dtype == DataType.UINT8:
    return pa.uint8()
  elif dtype == DataType.UINT16:
    return pa.uint16()
  elif dtype == DataType.UINT32:
    return pa.uint32()
  elif dtype == DataType.UINT64:
    return pa.uint64()
  elif dtype == DataType.BINARY:
    return pa.binary()
  elif dtype == DataType.TIME:
    # `pa.time64` requires an explicit unit ('us' or 'ns'). Use microseconds, like the other
    # time based dtypes below.
    return pa.time64('us')
  elif dtype == DataType.DATE:
    return pa.date64()
  elif dtype == DataType.TIMESTAMP:
    return pa.timestamp('us')
  elif dtype == DataType.INTERVAL:
    return pa.duration('us')
  elif dtype == DataType.EMBEDDING:
    # We reserve an empty column for embeddings in parquet files so they can be queried.
    # The values are *not* filled out. If parquet and duckdb support embeddings in the future, we
    # can set this dtype to the relevant pyarrow type.
    return pa.null()
  elif dtype == DataType.STRING_SPAN:
    return pa.struct({
      VALUE_KEY: pa.struct({
        TEXT_SPAN_START_FEATURE: pa.int32(),
        TEXT_SPAN_END_FEATURE: pa.int32()
      })
    })
  elif dtype == DataType.NULL:
    return pa.null()
  else:
    raise ValueError(f'Can not convert dtype "{dtype}" to arrow dtype')


def schema_to_arrow_schema(schema: Union[Schema, Field]) -> pa.Schema:
  """Convert our schema to an arrow schema.

  The schema is converted recursively and the resulting arrow fields are collected, by name, into
  a new `pa.Schema`.
  """
  converted = cast(pa.Schema, _schema_to_arrow_schema_impl(schema))
  name_to_type = {}
  for arrow_field in converted:
    name_to_type[arrow_field.name] = arrow_field.type
  return pa.schema(name_to_type)


def _schema_to_arrow_schema_impl(schema: Union[Schema, Field]) -> Union[pa.Schema, pa.DataType]:
  """Recursively convert a schema, or one of its fields, to its apache arrow equivalent.

  A top-level `Schema` becomes a `pa.Schema`, a field with children a struct, a repeated field a
  list and any other field the arrow type of its dtype.
  """
  children = schema.fields
  if not children:
    # No children: the node is either a list or a plain value.
    node = cast(Field, schema)
    if node.repeated_field:
      return pa.list_(_schema_to_arrow_schema_impl(node.repeated_field))
    return dtype_to_arrow_schema(cast(DataType, node.dtype))

  # The row id column is converted from its dtype directly; every other child is converted
  # recursively.
  columns: dict[str, Union[pa.Schema, pa.DataType]] = {
    name: (dtype_to_arrow_schema(cast(DataType, child.dtype))
           if name == UUID_COLUMN else _schema_to_arrow_schema_impl(child))
    for name, child in children.items()
  }

  if isinstance(schema, Schema):
    # Top-level schemas do not have __value__ fields.
    return pa.schema(columns)

  # When nodes have both dtype and children, we add __value__ alongside the fields.
  if schema.dtype:
    own_value = dtype_to_arrow_schema(schema.dtype)
    # The arrow type of a string span already nests the span under __value__; unwrap it so the
    # span is stored directly under this node's __value__.
    columns[VALUE_KEY] = (
      own_value[VALUE_KEY].type if schema.dtype == DataType.STRING_SPAN else own_value)

  return pa.struct(columns)


def arrow_dtype_to_dtype(arrow_dtype: pa.DataType) -> DataType:
  """Convert an arrow dtype to our dtype.

  Raises:
    ValueError: If the arrow dtype has no equivalent dtype.
  """
  # Ints and floats, matched by equality.
  numeric_types = (
    (pa.int8(), DataType.INT8),
    (pa.int16(), DataType.INT16),
    (pa.int32(), DataType.INT32),
    (pa.int64(), DataType.INT64),
    (pa.uint8(), DataType.UINT8),
    (pa.uint16(), DataType.UINT16),
    (pa.uint32(), DataType.UINT32),
    (pa.uint64(), DataType.UINT64),
    (pa.float16(), DataType.FLOAT16),
    (pa.float32(), DataType.FLOAT32),
    (pa.float64(), DataType.FLOAT64),
  )
  for arrow_type, converted in numeric_types:
    if arrow_dtype == arrow_type:
      return converted

  # Time based types, matched by type family so that every unit is accepted.
  temporal_checks = (
    (pa.types.is_time, DataType.TIME),
    (pa.types.is_date, DataType.DATE),
    (pa.types.is_timestamp, DataType.TIMESTAMP),
    (pa.types.is_duration, DataType.INTERVAL),
  )
  for is_family, converted in temporal_checks:
    if is_family(arrow_dtype):
      return converted

  # Others.
  if arrow_dtype == pa.string():
    return DataType.STRING
  if pa.types.is_binary(arrow_dtype) or pa.types.is_fixed_size_binary(arrow_dtype):
    return DataType.BINARY
  if pa.types.is_boolean(arrow_dtype):
    return DataType.BOOLEAN
  if arrow_dtype == pa.null():
    return DataType.NULL
  raise ValueError(f'Can not convert arrow dtype "{arrow_dtype}" to our dtype')


def arrow_schema_to_schema(schema: pa.Schema) -> Schema:
  """Convert an arrow schema to our schema."""
  # TODO(nsthorat): Change this implementation to allow more complicated reading of arrow schemas
  # into our schema by inferring values when {__value__: value} is present in the pyarrow schema.
  # This isn't necessary today as this util is only needed by sources which do not have data in the
  # lilac format.
  converted = _arrow_schema_to_schema_impl(schema)
  return cast(Schema, converted)


def _arrow_schema_to_schema_impl(schema: Union[pa.Schema, pa.DataType]) -> Union[Schema, Field]:
  """Recursively convert an apache arrow schema, or one of its types, to our schema.

  A `pa.Schema` becomes a `Schema`, a struct a field with children, a list a repeated field and
  any other type a field with the matching dtype.
  """
  if isinstance(schema, pa.ListType):
    element = _arrow_schema_to_schema_impl(schema.value_field.type)
    return Field(repeated_field=cast(Field, element))

  if not isinstance(schema, (pa.Schema, pa.StructType)):
    return Field(dtype=arrow_dtype_to_dtype(schema))

  children: dict[str, Field] = {}
  for arrow_field in schema:
    children[arrow_field.name] = cast(Field, _arrow_schema_to_schema_impl(arrow_field.type))

  if isinstance(schema, pa.Schema):
    return Schema(fields=children)
  return Field(fields=children)


def is_float(dtype: DataType) -> bool:
  """Return whether the dtype is one of the floating point dtypes."""
  float_dtypes = (DataType.FLOAT16, DataType.FLOAT32, DataType.FLOAT64)
  return dtype in float_dtypes


def is_integer(dtype: DataType) -> bool:
  """Return whether the dtype is one of the signed or unsigned integer dtypes."""
  integer_dtypes = (
    DataType.INT8,
    DataType.INT16,
    DataType.INT32,
    DataType.INT64,
    DataType.UINT8,
    DataType.UINT16,
    DataType.UINT32,
    DataType.UINT64,
  )
  return dtype in integer_dtypes


def is_temporal(dtype: DataType) -> bool:
  """Return whether the dtype is one of the time based dtypes."""
  temporal_dtypes = (DataType.TIME, DataType.DATE, DataType.TIMESTAMP, DataType.INTERVAL)
  return dtype in temporal_dtypes


def is_ordinal(dtype: DataType) -> bool:
  """Return whether the dtype is ordinal, i.e. a float, integer or temporal dtype."""
  return any(check(dtype) for check in (is_float, is_integer, is_temporal))