Spaces:
Running
Running
File size: 4,561 Bytes
5129aaa 196b164 5129aaa 196b164 5129aaa 196b164 5129aaa d91013d 5129aaa d91013d 5129aaa d91013d 5129aaa 196b164 5129aaa 196b164 5129aaa 196b164 5129aaa 196b164 5129aaa 196b164 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
from datasets import load_dataset as _load_dataset
from os import environ
from PIL import Image
import numpy as np
import json
from pyarrow.parquet import ParquetFile
from pyarrow import Table as pa_Table
from datasets import Dataset
# Hugging Face repository id of the EarthView dataset.
DATASET = "satellogic/EarthView"

# Per-subset layout of the repository, keyed by subset name.
#   shards: total number of parquet shards (used to build shard file names).
#   config: dataset config name; when absent the subset name is used.
#   path:   directory holding the parquet shards; when absent the subset name is used.
sets = dict(
    satellogic=dict(shards=3676),
    sentinel_1=dict(shards=1763),
    neon=dict(config="default", shards=607, path="data"),
)
def get_subsets():
    """Return a live keys view of the subset names declared in ``sets``."""
    known_subsets = sets.keys()
    return known_subsets
def get_nshards(subset):
    """Return the number of parquet shards recorded for *subset*.

    Raises KeyError if *subset* is not a key of ``sets``.
    """
    subset_info = sets[subset]
    return subset_info["shards"]
def get_path(subset):
    """Return the repository directory of *subset*'s parquet shards.

    Falls back to the subset name itself when no explicit "path" is declared.
    Raises KeyError if *subset* is not a key of ``sets``.
    """
    info = sets[subset]
    return info["path"] if "path" in info else subset
def get_config(subset):
    """Return the dataset config name used to load *subset*.

    Falls back to the subset name itself when no explicit "config" is declared.
    Raises KeyError if *subset* is not a key of ``sets``.
    """
    info = sets[subset]
    return info["config"] if "config" in info else subset
def load_dataset(subset, dataset=DATASET, split="train", shards=None, streaming=True, **kwargs):
    """Load one EarthView subset through ``datasets.load_dataset``.

    subset: subset name, a key of ``sets`` (e.g. "satellogic", "sentinel_1", "neon").
    dataset: Hugging Face repository id to load from.
    split: split name; it is also part of every shard file name.
    shards: None to let ``datasets`` resolve the files of the subset's config, or a
        shard index / iterable of shard indices to load only those parquet files.
    streaming: forwarded to ``datasets.load_dataset``.
    **kwargs: forwarded to ``datasets.load_dataset``. A ``token`` passed here takes
        precedence over the ``HF_TOKEN`` environment variable.

    Returns whatever ``datasets.load_dataset`` returns for a single split
    (presumably an IterableDataset when streaming, a Dataset otherwise).
    """
    config = get_config(subset)
    nshards = get_nshards(subset)
    path = get_path(subset)
    if shards is None:
        data_files = None
    else:
        if isinstance(shards, (int, np.integer)):
            # Accept a single shard index as well as an iterable of indices.
            shards = [shards]
        data_files = {
            split: [f"{path}/{split}-{shard:05d}-of-{nshards:05d}.parquet" for shard in shards]
        }
    # An explicit token=... in kwargs wins; previously it collided with the fixed
    # token= argument and raised "got multiple values for keyword argument".
    kwargs.setdefault("token", environ.get("HF_TOKEN", None))
    return _load_dataset(
        path=dataset,
        name=config,
        save_infos=True,
        split=split,
        data_files=data_files,
        streaming=streaming,
        **kwargs,
    )
def load_parquet(subset_or_filename, batch_size=100):
    """Read a local parquet file into an in-memory ``datasets.Dataset``.

    subset_or_filename: either a subset name (one of ``get_subsets()``), which maps to
        ``dataset/<subset>/sample.parquet`` relative to the working directory, or any
        source accepted by ``pyarrow.parquet.ParquetFile``.
    batch_size: rows per record batch while reading; the whole file still ends up
        in memory.
    """
    if subset_or_filename in get_subsets():
        filename = f"dataset/{subset_or_filename}/sample.parquet"
    else:
        filename = subset_or_filename
    pqfile = ParquetFile(filename)
    batches = list(pqfile.iter_batches(batch_size=batch_size))
    if batches:
        table = pa_Table.from_batches(batches)
    else:
        # A zero-row file yields no batches, and Table.from_batches cannot infer a
        # schema from nothing (it raises ValueError); build an empty table instead.
        table = pqfile.schema_arrow.empty_table()
    return Dataset(table)
def _satellogic_to_images(item):
    """Replace a satellogic item's "rgb" and "1m" arrays with PIL images; return the "1m" count."""
    # Entries are presumably channel-first (C, H, W); transpose to (H, W, C) for PIL.
    item["rgb"] = [Image.fromarray(rgb.transpose(1, 2, 0)) for rgb in item["rgb"]]
    # Only the first channel of each "1m" entry is kept, as a single-band image.
    item["1m"] = [Image.fromarray(image[0, :, :]) for image in item["1m"]]
    return len(item["1m"])


def _sentinel_1_composite(i10m):
    """Append a synthetic third band (ratio of band 0 to band 1) to a stack of 2-band arrays."""
    # Mapping of V and H to RGB. May not be correct
    # https://gis.stackexchange.com/questions/400726/creating-composite-rgb-images-from-sentinel-1-channels
    ratio = i10m[:, 0, :, :] / (i10m[:, 1, :, :] + 0.01) * 256
    # Clip before casting: ratios >= 1 scale past 255, and casting an out-of-range
    # float to uint8 is undefined (platform-dependent) in NumPy.
    third = np.clip(ratio, 0, 255).astype("uint8")
    return np.concatenate((i10m, np.expand_dims(third, 1)), 1)


def _sentinel_1_to_images(item):
    """Replace a sentinel_1 item's "10m" arrays with 3-band PIL images; return their count."""
    item["10m"] = [
        Image.fromarray(image.transpose(1, 2, 0))
        for image in _sentinel_1_composite(item["10m"])
    ]
    return len(item["10m"])


def _neon_to_images(item):
    """Replace a neon item's "rgb", "chm" and "1m" arrays with PIL images; return the "rgb" count."""
    item["rgb"] = [Image.fromarray(image.transpose(1, 2, 0)) for image in item["rgb"]]
    item["chm"] = [Image.fromarray(image[0]) for image in item["chm"]]
    # Very arbitrary false-colour view of the hyperspectral "1m" data: bands [0:124),
    # [124:247) and [247:] are each averaged into one channel (roughly thirds,
    # presumably of ~369 bands).
    item["1m"] = [
        Image.fromarray(
            np.stack(
                (
                    np.average(image[:124], 0),
                    np.average(image[124:247], 0),
                    np.average(image[247:], 0),
                ),
                axis=2,
            ).astype("uint8")
        )
        for image in item["1m"]
    ]
    return len(item["rgb"])


# Subset name -> converter that rewrites the subset's image fields in place and
# returns the image count stored in metadata["count"].
_SUBSET_CONVERTERS = {
    "satellogic": _satellogic_to_images,
    "sentinel_1": _sentinel_1_to_images,
    "neon": _neon_to_images,
}


def item_to_images(subset, item):
    """
    Converts the images within an item (arrays), as retrieved from the dataset, to PIL.Image.

    subset: name of the subset the item came from: "satellogic", "sentinel_1" or "neon".
    item: one dataset row, a mapping with a "metadata" entry (a dict or a JSON string)
        plus array-like image fields.

    Returns a new dict in which every non-"metadata" field has been converted with
    ``np.asarray(v).astype("uint8")`` and the subset's image fields are replaced by
    lists of PIL images. "metadata" is returned as a dict with an added "count" key
    (number of images). The input item and its metadata dict are not modified.

    Raises ValueError if *subset* is not one of the supported names.
    """
    try:
        convert = _SUBSET_CONVERTERS[subset]
    except KeyError:
        raise ValueError(
            f"Unknown subset {subset!r}; expected one of {sorted(_SUBSET_CONVERTERS)}"
        ) from None
    metadata = item["metadata"]
    if isinstance(metadata, str):
        metadata = json.loads(metadata)
    else:
        # Copy so adding "count" below does not mutate the caller's metadata.
        metadata = dict(metadata)
    # NOTE(review): values outside 0..255 do not survive this cast intact;
    # presumably the stored arrays are already uint8 — confirm.
    converted = {
        k: np.asarray(v).astype("uint8")
        for k, v in item.items()
        if k != "metadata"
    }
    converted["metadata"] = metadata
    metadata["count"] = convert(converted)
    return converted
|