Browse Source

rasterize geojson with PyTorch dataset, cf #1

master
Olivier Courtin 4 months ago
parent
commit
09a3265d8d
1 changed files with 98 additions and 68 deletions
  1. +98
    -68
      neat_eo/tools/rasterize.py

+ 98
- 68
neat_eo/tools/rasterize.py View File

@@ -1,6 +1,7 @@
import os
import re
import sys
import math
import json
import collections

@@ -10,6 +11,7 @@ from functools import partial
import concurrent.futures as futures

import psycopg2
from torch.utils.data import DataLoader, Dataset

from neat_eo.core import load_config, check_classes, make_palette, web_ui, Logs
from neat_eo.tiles import tiles_from_csv, tile_label_to_file, tile_bbox
@@ -66,6 +68,78 @@ def worker_spatial_index(zoom, buffer, add_progress, geojson_path):
return feature_map


class CreateLabelsDataset(Dataset):
def __init__(self, tiles, feature_map, args, config, mode):
super().__init__()
assert mode in ["geojson", "postgis"]
self.mode = mode
self.tiles = tiles
self.args = args
self.feature_map = feature_map
self.ts_shape = list(map(int, args.ts.split(",")))

self.palette, self.transparency = make_palette(
[classe["color"] for classe in config["classes"]], complementary=not (args.original_color)
)
index = [config["classes"].index(classe) for classe in config["classes"] if classe["title"] == args.type]
assert index, "Requested type is not contains in your config file classes."
self.burn_value = index[0]
assert 0 < self.burn_value <= 255

if mode == "postgis":
conn = psycopg2.connect(args.pg)
self.db = conn.cursor()
sql = re.sub(r"ST_Intersects( )*\((.*)?TILE_GEOM(.*)?\)", "1=1", args.sql, re.I)
self.db.execute("""SELECT ST_Srid("1") AS srid FROM ({} LIMIT 1) AS t("1")""".format(sql))
self.srid = self.db.fetchone()[0]
assert self.srid and int(self.srid) > 0, "Unable to retrieve geometry SRID."
conn.close()

def __len__(self):
return len(self.tiles)

def __getitem__(self, i):
tile = self.tiles[i]

if self.mode == "postgis":
w, s, e, n = tile_bbox(tile)
tile_geom = "ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {})".format(w, s, e, n, self.srid)

query = """
WITH
sql AS ({}),
geom AS (SELECT "1" AS geom FROM sql AS t("1")),
json AS (SELECT '{{"type": "Feature", "geometry": '
|| ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
|| '}}' AS features
FROM geom)
SELECT '{{"type": "FeatureCollection", "features": [' || Array_To_String(array_agg(features), ',') || ']}}'
FROM json
""".format(
self.args.sql.replace("TILE_GEOM", tile_geom)
)

conn = psycopg2.connect(self.args.pg) # TODO use a connection pooler
db = conn.cursor()
db.execute(query)
row = db.fetchone()
features = json.loads(row[0])["features"] if row and row[0] else list()

if self.mode == "geojson":
try:
features = self.feature_map[tile]
except KeyError:
features = None

if not features:
label = np.zeros(shape=self.ts_shape, dtype=np.uint8)
else:
label = geojson_tile_burn(tile, features, 4326, self.ts_shape, self.burn_value)

tile_label_to_file(self.args.out, tile, self.palette, self.transparency, label, append=self.args.append)
return (len(features), tile.x, tile.y, tile.z)


def main(args):

assert not (args.geojson is not None and args.pg is not None), "You have to choose between --pg or --geojson"
@@ -78,13 +152,6 @@ def main(args):
args.pg = config["auth"]["pg"] if not args.pg and "pg" in config["auth"].keys() else args.pg
assert not (args.sql and not args.pg), "With --sql option, --pg dsn setting must also be provided"

complementary = not (args.original_color)
palette, transparency = make_palette([classe["color"] for classe in config["classes"]], complementary=complementary)
index = [config["classes"].index(classe) for classe in config["classes"] if classe["title"] == args.type]
assert index, "Requested type is not contains in your config file classes."
burn_value = index[0]
assert 0 < burn_value <= 255

if args.sql:
assert "limit" not in args.sql.lower(), "LIMIT is not supported"
assert "TILE_GEOM" in args.sql, "TILE_GEOM filter not found in your SQL"
@@ -98,8 +165,10 @@ def main(args):

tiles = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))]
assert len(tiles), "Empty Cover: {}".format(args.cover)
feature_map = collections.defaultdict(list)

if args.geojson:
mode = "geojson"
zoom = tiles[0].z
assert not [tile for tile in tiles if tile.z != zoom], "Unsupported zoom mixed cover. Use PostGIS instead"

@@ -112,11 +181,9 @@ def main(args):
progress = tqdm(total=len(args.geojson), ascii=True, unit="file")
log_from = "{} geojson files".format(len(args.geojson))

feature_map = collections.defaultdict(list)
with futures.ProcessPoolExecutor(workers) as executor:
for fm in executor.map(
partial(worker_spatial_index, zoom, args.buffer, True if progress is None else False), args.geojson
):
add_progress = True if progress is None else False
for fm in executor.map(partial(worker_spatial_index, zoom, args.buffer, add_progress), args.geojson):
for k, v in fm.items():
try:
feature_map[k] += v
@@ -127,69 +194,32 @@ def main(args):
if progress:
progress.close()

if args.sql:
conn = psycopg2.connect(args.pg)
db = conn.cursor()

db.execute("""SELECT ST_Srid("1") AS srid FROM ({} LIMIT 1) AS t("1")""".format(sql))
srid = db.fetchone()[0]
assert srid and int(srid) > 0, "Unable to retrieve geometry SRID."
if not len(feature_map):
log.log("-----------------------------------------------")
log.log("NOTICE: no feature to rasterize, seems peculiar")
log.log("-----------------------------------------------")

if args.sql:
mode = "postgis"
log_from = args.sql
workers = math.ceil(args.workers / 2)

if not len(feature_map):
log.log("-----------------------------------------------")
log.log("NOTICE: no feature to rasterize, seems peculiar")
log.log("-----------------------------------------------")

log.log("neo rasterize - rasterizing {} from {} on cover {}".format(args.type, log_from, args.cover))
log.log(
"neo rasterize - rasterizing {} from {} on cover {}, with {} tiles/batch and {} workers".format(
args.type, log_from, args.cover, workers, workers
)
)
label_dataset = CreateLabelsDataset(tiles, feature_map, args, config, mode)
loader = DataLoader(label_dataset, batch_size=workers, num_workers=workers)
with open(os.path.join(os.path.expanduser(args.out), args.type.lower() + "_cover.csv"), mode="w") as cover:

for tile in tqdm(tiles, ascii=True, unit="tile"):

geojson = None

if args.sql:
w, s, e, n = tile_bbox(tile)
tile_geom = "ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {})".format(w, s, e, n, srid)

query = """
WITH
sql AS ({}),
geom AS (SELECT "1" AS geom FROM sql AS t("1")),
json AS (SELECT '{{"type": "Feature", "geometry": '
|| ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
|| '}}' AS features
FROM geom)
SELECT '{{"type": "FeatureCollection", "features": [' || Array_To_String(array_agg(features), ',') || ']}}'
FROM json
""".format(
args.sql.replace("TILE_GEOM", tile_geom)
for n, x, y, z, in tqdm(loader, desc="Rasterize", unit="batch", ascii=True):
for i in range(len(n)):
cover.write(
"{},{},{} {}{}".format(
x[i].data.numpy(), y[i].data.numpy(), z[i].data.numpy(), n[i].data.numpy(), os.linesep
)
)

db.execute(query)
row = db.fetchone()
try:
geojson = json.loads(row[0])["features"] if row and row[0] else None
except Exception:
log.log("Warning: Invalid geometries, skipping {}".format(tile))
conn = psycopg2.connect(args.pg)
db = conn.cursor()

if args.geojson:
geojson = feature_map[tile] if tile in feature_map else None

if geojson:
num = len(geojson)
out = geojson_tile_burn(tile, geojson, 4326, list(map(int, args.ts.split(","))), burn_value)

if not geojson or out is None:
num = 0
out = np.zeros(shape=list(map(int, args.ts.split(","))), dtype=np.uint8)

tile_label_to_file(args.out, tile, palette, transparency, out, append=args.append)
cover.write("{},{},{} {}{}".format(tile.x, tile.y, tile.z, num, os.linesep))

if not args.no_web_ui:
template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
base_url = args.web_ui_base_url if args.web_ui_base_url else "."


Loading…
Cancel
Save