Efficient AI4EO OpenSource framework
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

204 lines
6.5KB

  1. import os
  2. import sys
  3. import glob
  4. import toml
  5. from importlib import import_module
  6. import re
  7. import colorsys
  8. import webcolors
  9. from pathlib import Path
  10. from neat_eo.tiles import tile_pixel_to_location, tiles_to_geojson
  11. #
  12. # Import module
  13. #
  14. def load_module(module):
  15. module = import_module(module)
  16. assert module, "Unable to import module {}".format(module)
  17. return module
  18. #
  19. # Config
  20. #
  21. def load_config(path):
  22. """Loads a dictionary from configuration file."""
  23. if not path and "NEO_CONFIG" in os.environ:
  24. path = os.environ["NEO_CONFIG"]
  25. if not path and os.path.isfile(os.path.expanduser("~/.neo_config")):
  26. path = "~/.neo_config"
  27. assert path, "Either ~/.neo_config or NEO_CONFIG env var or --config parameter, is required."
  28. config = toml.load(os.path.expanduser(path))
  29. assert config, "Unable to parse config file"
  30. # Set default values
  31. if "model" not in config.keys():
  32. config["model"] = {}
  33. if "ts" not in config["model"].keys():
  34. config["model"]["ts"] = (512, 512)
  35. if "train" not in config.keys():
  36. config["train"] = {}
  37. if "pretrained" not in config["train"].keys():
  38. config["train"]["pretrained"] = True
  39. if "bs" not in config["train"].keys():
  40. config["train"]["bs"] = 4
  41. if "auth" not in config.keys():
  42. config["auth"] = {}
  43. if "da" in config["train"].keys():
  44. config["train"]["da"] = dict(config["train"]["da"]) # dict is serializable
  45. if "optimizer" in config["train"].keys():
  46. config["train"]["optimizer"] = dict(config["train"]["optimizer"]) # dict is serializable
  47. else:
  48. config["train"]["optimizer"] = {"name": "Adam", "lr": 0.0001}
  49. assert "classes" in config.keys(), "CONFIG: Classes are mandatory"
  50. for c, classe in enumerate(config["classes"]):
  51. config["classes"][c]["weight"] = config["classes"][c]["weight"] if "weight" in config["classes"][c].keys() else 1.0
  52. if config["classes"][c]["color"] == "transparent" and "weight" not in config["classes"][c].keys():
  53. config["classes"][c]["weight"] = 0.0
  54. return config
  55. def check_channels(config):
  56. assert "channels" in config.keys(), "CONFIG: At least one Channel is mandatory"
  57. # TODO
  58. def check_classes(config):
  59. """Check if config file classes subpart is consistent. Exit on error if not."""
  60. assert "classes" in config.keys() and len(config["classes"]) >= 2, "CONFIG: At least 2 Classes are mandatory"
  61. for classe in config["classes"]:
  62. assert "title" in classe.keys() and len(classe["title"]), "CONFIG: Missing or Empty classes.title value"
  63. assert "color" in classe.keys() and check_color(classe["color"]), "CONFIG: Missing or Invalid classes.color value"
  64. def check_model(config):
  65. assert "model" in config.keys(), "CONFIG: Missing or Invalid model"
  66. # TODO
  67. #
  68. # Logs
  69. #
  70. class Logs:
  71. def __init__(self, path, out=sys.stderr):
  72. """Create a logs instance on a logs file."""
  73. self.fp = None
  74. self.out = out
  75. if path:
  76. if not os.path.isdir(os.path.dirname(path)):
  77. os.makedirs(os.path.dirname(path), exist_ok=True)
  78. self.fp = open(path, mode="a")
  79. def log(self, msg):
  80. """Log a new message to the opened logs file, and optionnaly on stdout or stderr too."""
  81. if self.fp:
  82. self.fp.write(msg + os.linesep)
  83. self.fp.flush()
  84. if self.out:
  85. print(msg, file=self.out)
  86. #
  87. # Colors
  88. #
  89. def make_palette(colors, complementary=False):
  90. """Builds a PNG PIL color palette from Classes CSS3 color names, or hex values patterns as #RRGGBB."""
  91. assert 0 < len(colors) < 256
  92. try:
  93. transparency = [key for key, color in enumerate(colors) if color == "transparent"][0]
  94. except:
  95. transparency = None
  96. colors = ["white" if color.lower() == "transparent" else color for color in colors]
  97. hex_colors = [webcolors.CSS3_NAMES_TO_HEX[color.lower()] if color[0] != "#" else color for color in colors]
  98. rgb_colors = [(int(h[1:3], 16), int(h[3:5], 16), int(h[5:7], 16)) for h in hex_colors]
  99. palette = list(sum(rgb_colors, ())) # flatten
  100. palette = palette if not complementary else complementary_palette(palette)
  101. return palette, transparency
  102. def complementary_palette(palette):
  103. """Creates a PNG PIL complementary colors palette based on an initial PNG PIL palette."""
  104. comp_palette = []
  105. colors = [palette[i : i + 3] for i in range(0, len(palette), 3)]
  106. for color in colors:
  107. r, g, b = [v for v in color]
  108. h, s, v = colorsys.rgb_to_hsv(r, g, b)
  109. comp_palette.extend(map(int, colorsys.hsv_to_rgb((h + 0.5) % 1, s, v)))
  110. return comp_palette
  111. def check_color(color):
  112. """Check if an input color is or not valid (i.e CSS3 color name, transparent, or #RRGGBB)."""
  113. color = "white" if color.lower() == "transparent" else color
  114. hex_color = webcolors.CSS3_NAMES_TO_HEX[color.lower()] if color[0] != "#" else color
  115. return bool(re.match(r"^#([0-9a-fA-F]){6}$", hex_color))
  116. #
  117. # Web UI
  118. #
  119. def web_ui(out, base_url, coverage_tiles, selected_tiles, ext, template, union_tiles=True):
  120. out = os.path.expanduser(out)
  121. template = os.path.expanduser(template)
  122. templates = glob.glob(os.path.join(Path(__file__).parent, "web_ui", "*"))
  123. if os.path.isfile(template):
  124. templates.append(template)
  125. if os.path.lexists(os.path.join(out, "index.html")):
  126. os.remove(os.path.join(out, "index.html")) # if already existing output dir, as symlink can't be overwriten
  127. os.symlink(os.path.basename(template), os.path.join(out, "index.html"))
  128. def process_template(template):
  129. web_ui = open(template, "r").read()
  130. web_ui = re.sub("{{base_url}}", base_url, web_ui)
  131. web_ui = re.sub("{{ext}}", ext, web_ui)
  132. web_ui = re.sub("{{tiles}}", "tiles.json" if selected_tiles else "''", web_ui)
  133. if coverage_tiles:
  134. tile = list(coverage_tiles)[0] # Could surely be improved, but for now, took the first tile to center on
  135. x, y, z = map(int, [tile.x, tile.y, tile.z])
  136. web_ui = re.sub("{{zoom}}", str(z), web_ui)
  137. web_ui = re.sub("{{center}}", str(list(tile_pixel_to_location(tile, 0.5, 0.5))[::-1]), web_ui)
  138. with open(os.path.join(out, os.path.basename(template)), "w", encoding="utf-8") as fp:
  139. fp.write(web_ui)
  140. for template in templates:
  141. process_template(template)
  142. if selected_tiles:
  143. with open(os.path.join(out, "tiles.json"), "w", encoding="utf-8") as fp:
  144. fp.write(tiles_to_geojson(selected_tiles, union_tiles))