# Plonk/data/to_webdataset/rebalance_csv.py
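"""Rebalance a directory of CSV files into uniformly sized shards.

Reads every ``*.csv`` file in an input directory and rewrites the rows into
numbered output files (``000.csv``, ``001.csv``, ...) holding at most
``lines_per_file`` data rows each, with the header repeated in every shard.
"""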
import csv
import os
import sys
import glob
import tqdm


def split_csv_files(input_files, output_dir, lines_per_file=100000):
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Initialize counters
    total_lines = 0
    file_count = 0
    current_line_count = 0

    # Open the first output shard; shard names are zero-padded indices
    output_file = os.path.join(output_dir, f"{file_count:03d}.csv")
    output_writer = open(output_file, "w", newline="")
    csv_writer = None
try:
for file_path in tqdm.tqdm(input_files, desc="Processing files"):
            with open(file_path, "r", newline="") as csv_file:
                csv_reader = csv.reader(csv_file)
                # Each input file is assumed to start with the same header
                # row; consume it so it is never copied out as a data row
                file_header = next(csv_reader, None)
                if file_header is None:
                    continue  # skip empty input files
                # Initialize the writer with the header from the first file
                if csv_writer is None:
                    header = file_header
                    csv_writer = csv.writer(output_writer)
                    csv_writer.writerow(header)
# Process each line in the current file
for row in csv_reader:
if current_line_count >= lines_per_file:
# Close the current file and start a new one
output_writer.close()
file_count += 1
current_line_count = 0
                        output_file = os.path.join(
                            output_dir, f"{file_count:03d}.csv"
                        )
output_writer = open(output_file, "w", newline="")
csv_writer = csv.writer(output_writer)
csv_writer.writerow(header) # Write header to new file
# Write row to the current output file
csv_writer.writerow(row)
current_line_count += 1
total_lines += 1
    finally:
        # Close the last output shard (output_writer is always bound here)
        output_writer.close()

    print(f"Total lines processed: {total_lines}")
    print(f"Files created: {file_count + 1}")
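

# A minimal sanity-check sketch (not part of the original script): count the
# data rows in each emitted shard to confirm none exceeds lines_per_file.
# The helper name count_data_rows is hypothetical.
def count_data_rows(csv_path):
    """Count non-header rows in one CSV shard."""
    with open(csv_path, "r", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the repeated header
        return sum(1 for _ in reader)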


if __name__ == "__main__":
input_dir = "../datasets/YFCC100M/yfcc100m_dataset_with_gps_train"
output_dir = "../datasets/YFCC100M/yfcc100m_dataset_with_gps_train_balanced"
lines_per_file = 100000
# Get all CSV files in input directory
input_files = glob.glob(os.path.join(input_dir, "*.csv"))
if not input_files:
print(f"No CSV files found in {input_dir}")
sys.exit(1)
print(f"Found {len(input_files)} CSV files")
split_csv_files(input_files, output_dir, lines_per_file)
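
# Expected output layout (illustrative), given the defaults above:
#   ../datasets/YFCC100M/yfcc100m_dataset_with_gps_train_balanced/
#       000.csv   <- header + up to 100000 data rows
#       001.csv
#       ...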