"""A Parquet file writer that wraps the pyarrow writer."""
from typing import IO, Optional
import pyarrow as pa
import pyarrow.parquet as pq
from .schema import Item, Schema, schema_to_arrow_schema
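
# Records are buffered in two tiers: rows accumulate in per-column buffers until
# `record_batch_size` rows form a pa.RecordBatch, and completed batches accumulate
# until their combined byte size reaches `row_group_buffer_size`, at which point
# they are written out as a single parquet row group.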


class ParquetWriter:
  """A writer to parquet."""

  def __init__(self,
               schema: Schema,
               codec: str = 'snappy',
               row_group_buffer_size: int = 128 * 1024 * 1024,
               record_batch_size: int = 10_000):
    self._schema = schema_to_arrow_schema(schema)
    self._codec = codec
    self._row_group_buffer_size = row_group_buffer_size
    # Per-column row buffers, flushed into a record batch every `record_batch_size` rows.
    self._buffer: list[list[Optional[Item]]] = [[] for _ in range(len(self._schema.names))]
    self._buffer_size = record_batch_size
    # Completed record batches, held until they are written out as a row group.
    self._record_batches: list[pa.RecordBatch] = []
    self._record_batches_byte_size = 0
    # Set by open(); None until a destination file is attached.
    self.writer: Optional[pq.ParquetWriter] = None

  def open(self, file_handle: IO) -> None:
    """Open the destination file for writing."""
    self.writer = pq.ParquetWriter(file_handle, self._schema, compression=self._codec)

  def write(self, record: Item) -> None:
    """Write the record to the destination file."""
    # Thresholds are checked before appending, so a row group can overshoot
    # `row_group_buffer_size` by at most one record batch.
    if len(self._buffer[0]) >= self._buffer_size:
      self._flush_buffer()
    if self._record_batches_byte_size >= self._row_group_buffer_size:
      self._write_batches()

    # Buffer the record in columnar form, one value per column.
    for i, n in enumerate(self._schema.names):
      self._buffer[i].append(record.get(n))

  def close(self) -> None:
    """Flushes the write buffer and closes the destination file."""
    assert self.writer is not None, 'open() must be called before close().'
    if len(self._buffer[0]) > 0:
      self._flush_buffer()
    if self._record_batches_byte_size > 0:
      self._write_batches()
    self.writer.close()

  def _write_batches(self) -> None:
    """Combine the buffered record batches into a table and write it as one row group."""
    assert self.writer is not None, 'open() must be called before writing.'
    table = pa.Table.from_batches(self._record_batches, schema=self._schema)
    self._record_batches = []
    self._record_batches_byte_size = 0
    self.writer.write_table(table)

  def _flush_buffer(self) -> None:
    """Convert the column buffers into a record batch and track its byte size."""
    # Convert each column's buffered values to an arrow array, clearing the
    # buffer as we go.
    arrays: list[pa.Array] = []
    for i, column in enumerate(self._buffer):
      arrays.append(pa.array(column, type=self._schema.types[i]))
      self._buffer[i] = []
    rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
    self._record_batches.append(rb)

    # Sum the arrays' buffer sizes so write() knows when the row group is full.
    size = 0
    for array in arrays:
      for b in array.buffers():  # type: ignore
        if b is not None:
          size += b.size
    self._record_batches_byte_size += size
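

# Example usage: a minimal sketch with hypothetical stand-ins. `my_schema` is
# assumed to be a `Schema` built elsewhere, and `records` an iterable of
# dict-like `Item`s keyed by column name, per write() above.
#
#   writer = ParquetWriter(my_schema)
#   with open('output.parquet', 'wb') as f:
#     writer.open(f)
#     for record in records:
#       writer.write(record)
#     writer.close()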