import json
import os
import subprocess

import yaml

from .bucketeer import Bucketeer  # not used directly in this file; kept so Bucketeer can be imported from this module


class MultiFilter:
    """Callable filter that validates a sample's JSON metadata against a set of rules.

    Each rule maps a metadata key (or a tuple of keys) to a predicate; the
    sample passes only if every predicate returns True.
    """

    def __init__(self, rules, default=False):
        self.rules = rules
        self.default = default

    def __call__(self, x):
        try:
            x_json = x['json']
            if isinstance(x_json, bytes):
                x_json = json.loads(x_json)
            validations = []
            for k, r in self.rules.items():
                if isinstance(k, tuple):
                    # A tuple key passes several metadata fields to a single rule.
                    v = r(*[x_json[kv] for kv in k])
                else:
                    v = r(x_json[k])
                validations.append(v)
            return all(validations)
        except Exception:
            # Malformed samples (missing keys, undecodable JSON) fall back to
            # the configured default instead of a hard-coded False.
            return self.default
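
# Illustrative MultiFilter usage (a sketch, not called by this module): the
# metadata keys 'width'/'height' and the thresholds below are hypothetical.
def _example_multi_filter():
    keep_fn = MultiFilter(rules={
        'width': lambda w: w >= 512,                              # single-field rule
        ('width', 'height'): lambda w, h: 0.5 <= w / h <= 2.0,    # multi-field rule
    })
    # webdataset typically yields samples as dicts with raw bytes under 'json'.
    return keep_fn({'json': b'{"width": 768, "height": 768}'})   # -> True

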
class MultiGetter:
    """Callable extractor that pulls values out of a sample's JSON metadata.

    Mirrors MultiFilter's rule format: each rule maps a key (or a tuple of
    keys) to a function that transforms the corresponding field(s).
    """

    def __init__(self, rules):
        self.rules = rules

    def __call__(self, x_json):
        if isinstance(x_json, bytes):
            x_json = json.loads(x_json)
        outputs = []
        for k, r in self.rules.items():
            if isinstance(k, tuple):
                # A tuple key passes several metadata fields to a single rule.
                v = r(*[x_json[kv] for kv in k])
            else:
                v = r(x_json[k])
            outputs.append(v)
        if len(outputs) == 1:
            # Unwrap single-rule results so callers get the value directly.
            outputs = outputs[0]
        return outputs
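
# Illustrative MultiGetter usage (a sketch, not called by this module): the
# 'caption', 'width' and 'height' fields are hypothetical metadata keys.
def _example_multi_getter():
    get_fn = MultiGetter(rules={
        'caption': lambda c: c.strip(),
        ('width', 'height'): lambda w, h: (w, h),
    })
    return get_fn(b'{"caption": " a dog ", "width": 768, "height": 512}')
    # -> ['a dog', (768, 512)]

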
def setup_webdataset_path(paths, cache_path=None):
    """Resolves S3 prefixes to .tar shard URLs and returns a webdataset 'pipe:' spec.

    The resolved shard list is cached as YAML at `cache_path` (when given) so
    repeated runs can skip the S3 listing.
    """
    if cache_path is None or not os.path.exists(cache_path):
        tar_paths = []
        if isinstance(paths, str):
            paths = [paths]
        for path in paths:
            if path.strip().endswith(".tar"):
                # Already a direct shard path; no S3 listing needed.
                tar_paths.append(path)
                continue
            # Keep the "s3://bucket" prefix so listed keys can be re-qualified.
            bucket = "/".join(path.split("/")[:3])
            result = subprocess.run(
                f"aws s3 ls {path} --recursive | awk '{{print $4}}'",
                stdout=subprocess.PIPE, shell=True, check=True,
            )
            files = result.stdout.decode('utf-8').split()
            files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")]
            tar_paths += files

        if cache_path is not None:
            # Only write the cache when a cache location was actually given.
            with open(cache_path, 'w', encoding='utf-8') as outfile:
                yaml.dump(tar_paths, outfile, default_flow_style=False)
    else:
        with open(cache_path, 'r', encoding='utf-8') as file:
            tar_paths = yaml.safe_load(file)

    # webdataset's brace expansion turns the comma-separated list into one
    # "aws s3 cp <shard> -" pipe per tar file.
    tar_paths_str = ",".join(tar_paths)
    return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"