# ultpixgen/core/data/__init__.py
# Hosting-page residue kept for provenance: uploaded by roubaofeipi
# ("Upload 100 files"), commit 5231633 (verified), file size 2.35 kB.
import json
import subprocess
import yaml
import os
from .bucketeer import Bucketeer
class MultiFilter():
    """Callable predicate that checks a sample's ``'json'`` entry against rules.

    ``rules`` is a mapping. Each key is either a single field name or a tuple
    of field names; each value is a callable that receives the corresponding
    field value(s) and returns something truthy or falsy.
    """

    def __init__(self, rules, default=False):
        """Store the rule mapping and the ``default`` flag.

        ``default`` is only kept as an attribute; ``__call__`` does not
        consult it.
        """
        self.default = default
        self.rules = rules

    def __call__(self, x):
        """Return True when every rule accepts ``x['json']``, otherwise False.

        ``x['json']`` may be a mapping or JSON-encoded ``bytes``. Any
        exception raised along the way (missing ``'json'`` entry, missing
        field, undecodable payload, a rule raising) makes the result False.
        """
        try:
            record = x['json']
            if isinstance(record, bytes):
                record = json.loads(record)
            # Every rule is evaluated before the verdicts are combined.
            verdicts = [
                rule(*[record[name] for name in key])
                if isinstance(key, tuple)
                else rule(record[key])
                for key, rule in self.rules.items()
            ]
            return all(verdicts)
        except Exception:
            return False
class MultiGetter():
    """Callable that extracts values from a JSON record according to rules.

    ``rules`` is a mapping. Each key is either a single field name or a tuple
    of field names; each value is a callable applied to the corresponding
    field value(s).
    """

    def __init__(self, rules):
        """Store the rule mapping."""
        self.rules = rules

    def __call__(self, x_json):
        """Apply every rule to ``x_json`` and return the results.

        ``x_json`` may be a mapping or JSON-encoded ``bytes``. The results
        are returned as a list in rule order, except that a single result is
        returned bare instead of wrapped in a one-element list.
        """
        record = json.loads(x_json) if isinstance(x_json, bytes) else x_json
        extracted = [
            rule(*[record[name] for name in key])
            if isinstance(key, tuple)
            else rule(record[key])
            for key, rule in self.rules.items()
        ]
        return extracted[0] if len(extracted) == 1 else extracted
def setup_webdataset_path(paths, cache_path=None):
    """Return a ``pipe:aws s3 cp { ... } -`` string covering all tar files.

    Args:
        paths: A single path string or a list of path strings. Entries that
            end in ``.tar`` are used as-is; any other entry is listed
            recursively with ``aws s3 ls`` and every ``.tar`` key found is
            added, prefixed with the first three ``/``-separated components
            of the entry.
        cache_path: Optional YAML file holding the resolved tar list. When it
            exists, the list is read from it and ``paths`` is not inspected.
            When it is given but missing, the resolved list is written to
            it. When it is ``None`` no cache is read or written.

    Returns:
        The string ``"pipe:aws s3 cp { <comma-joined tar paths> } -"``.

    Raises:
        subprocess.CalledProcessError: If the listing shell command exits
            with a non-zero status.
    """
    if cache_path is not None and os.path.exists(cache_path):
        with open(cache_path, 'r', encoding='utf-8') as file:
            tar_paths = yaml.safe_load(file)
    else:
        if isinstance(paths, str):
            paths = [paths]
        tar_paths = []
        for path in paths:
            if path.strip().endswith(".tar"):
                # Avoid looking up s3 if we already have a tar file
                tar_paths.append(path)
                continue
            bucket = "/".join(path.split("/")[:3])
            # NOTE(review): `path` is interpolated into a shell command line;
            # it must come from trusted configuration.
            result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True)
            files = result.stdout.decode('utf-8').split()
            tar_paths += [f"{bucket}/{f}" for f in files if f.endswith(".tar")]

        # Only persist the listing when a cache location was supplied;
        # previously `open(None, 'w')` raised TypeError for the default
        # `cache_path=None`.
        if cache_path is not None:
            with open(cache_path, 'w', encoding='utf-8') as outfile:
                yaml.dump(tar_paths, outfile, default_flow_style=False)

    tar_paths_str = ",".join([f"{p}" for p in tar_paths])
    return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"