import csv
import glob
import os
import sys

try:
    import tqdm
except ImportError:  # the progress bar is cosmetic; run without it if missing
    tqdm = None


def _chunk_path(output_dir, index):
    """Return the path of the output chunk numbered *index* (e.g. ``007.csv``)."""
    return os.path.join(output_dir, f"{str(index).zfill(3)}.csv")


def split_csv_files(input_files, output_dir, lines_per_file=100000):
    """Concatenate CSV files and re-split them into equally sized chunks.

    The data rows of all ``input_files`` are streamed, in the given order,
    into files named ``000.csv``, ``001.csv``, ... inside ``output_dir``.
    Each chunk starts with the header row and holds at most
    ``lines_per_file`` data rows.

    The first row of *every* input file is treated as its header and is not
    copied as data; the header of the first non-empty input file is the one
    written to the chunks. All inputs are therefore assumed to share the same
    columns. Empty input files are skipped.

    Args:
        input_files: Iterable of paths to the CSV files to read.
        output_dir: Directory receiving the chunks (created if missing).
        lines_per_file: Maximum number of data rows per chunk; must be >= 1.

    Raises:
        ValueError: If ``lines_per_file`` is smaller than 1.
    """
    if lines_per_file < 1:
        raise ValueError(f"lines_per_file must be >= 1, got {lines_per_file}")

    os.makedirs(output_dir, exist_ok=True)

    total_lines = 0
    file_count = 0
    current_line_count = 0
    header = None

    if tqdm is not None:
        files = tqdm.tqdm(input_files, desc="Processing files")
    else:
        files = input_files

    # The first chunk is created up front, so ``000.csv`` exists even when
    # there is nothing to copy.
    output_writer = open(_chunk_path(output_dir, file_count), "w", newline="")
    csv_writer = csv.writer(output_writer)

    try:
        for file_path in files:
            # newline="" lets the csv module handle newlines embedded in
            # quoted fields itself.
            with open(file_path, "r", newline="") as csv_file:
                csv_reader = csv.reader(csv_file)

                # Every input carries its own header row: consume it so it
                # is never written out as data.
                file_header = next(csv_reader, None)
                if file_header is None:
                    continue  # empty file, nothing to copy

                if header is None:
                    header = file_header
                    csv_writer.writerow(header)

                for row in csv_reader:
                    # Roll over lazily, only when another row actually
                    # arrives, so no trailing header-only chunk is produced.
                    if current_line_count >= lines_per_file:
                        output_writer.close()
                        file_count += 1
                        current_line_count = 0
                        output_writer = open(
                            _chunk_path(output_dir, file_count), "w", newline=""
                        )
                        csv_writer = csv.writer(output_writer)
                        csv_writer.writerow(header)

                    csv_writer.writerow(row)
                    current_line_count += 1
                    total_lines += 1
    finally:
        # Close whichever chunk is currently open, even on error.
        output_writer.close()

    print(f"Total lines processed: {total_lines}")
    print(f"Files created: {file_count + 1}")


def main():
    """Split the YFCC100M GPS training CSVs into 100000-row chunks."""
    input_dir = "../datasets/YFCC100M/yfcc100m_dataset_with_gps_train"
    output_dir = "../datasets/YFCC100M/yfcc100m_dataset_with_gps_train_balanced"
    lines_per_file = 100000

    # glob returns files in arbitrary order; sort so that the contents of
    # each numbered chunk are reproducible between runs.
    input_files = sorted(glob.glob(os.path.join(input_dir, "*.csv")))

    if not input_files:
        print(f"No CSV files found in {input_dir}")
        sys.exit(1)

    print(f"Found {len(input_files)} CSV files")
    split_csv_files(input_files, output_dir, lines_per_file)


if __name__ == "__main__":
    main()