import shutil from dataclasses import dataclass, field from pathlib import Path from typing import Generator, Callable, Iterator, Iterable, TYPE_CHECKING import json from PIL import Image from source import threaded from source.mkw import Tag from source.mkw.Cup import Cup from source.mkw.MKWColor import bmg_color_text, bmg_color_raw from source.mkw.ModSettings import AbstractModSettings from source.mkw.Track import CustomTrack, DefaultTrack, Arena from source.mkw import OriginalTrack from source.progress import Progress from source.safe_eval import safe_eval, multiple_safe_eval from source.wt.szs import SZSPath from source.translation import translate as _ if TYPE_CHECKING: from source import TemplateMultipleSafeEval, TemplateSafeEval, Env CT_ICON_SIZE: int = 128 default_global_settings: dict[str, dict[str, str]] = { "replace_random_new": { "text": { "en": "Replace random new tracks by", "fr": "Remplacer les courses aléatoires nouvelle par" }, "type": "string", "preview": "track_selecting" }, "include_track_if": { "text": { "en": "Include track if", "fr": "Inclure la course si" }, "type": "string", "preview": "track_selecting" }, "sort_tracks": { "text": { "en": "Sort tracks by", "fr": "Trier les courses par" }, "type": "string", "preview": "track_sorting" } } def merge_dict(dict1: dict[str, dict] | None, dict2: dict[str, dict] | None, dict_keys: Iterable[str] = None) -> dict[str, dict]: """ Merge 2 dict subdict together :return: the merged dict { "option": {"speed": 1} }, { "option": {"mode": "hard"} } -> { "option": {"speed": 1, "mode": "hard"} } """ if dict1 is None: dict1 = {} if dict2 is None: dict2 = {} if dict_keys is None: dict_keys = dict1.keys() | dict2.keys() return {key: dict1.get(key, {}) | dict2.get(key, {}) for key in dict_keys} @dataclass(init=True, slots=True) class ModConfig: """ Representation of a mod """ path: Path | str name: str nickname: str = None version: str = "v1.0.0" variant: str = "01" _tracks: list["Track | TrackGroup"] = field(default_factory=list) _arenas: list["Arena"] = field(default_factory=list) track_file_template: "TemplateMultipleSafeEval" = "{{ getattr(track, 'sha1', '_') }}" multiplayer_disable_if: "TemplateSafeEval" = "False" tags_cups: list[Tag] = field(default_factory=list) tags_templates: dict[str, "TemplateMultipleSafeEval"] = field(default_factory=dict) original_track_prefix: bool = True swap_original_order: bool = True keep_original_track: bool = True enable_random_cup: bool = True macros: dict[str, "TemplateSafeEval"] = field(default_factory=dict) messages: dict[str, dict[str, "TemplateMultipleSafeEval"]] = field(default_factory=dict) global_settings: dict[str, "AbstractModSettings"] = field(default_factory=dict) specific_settings: dict[str, "AbstractModSettings"] = field(default_factory=dict) lpar_template: "TemplateMultipleSafeEval" = "normal.lpar" def __post_init__(self): self.path = Path(self.path) if self.nickname is None: self.nickname = self.name def __hash__(self) -> int: return hash(self.name) def __repr__(self) -> str: return f"" def __str__(self) -> str: return f"{self.name} {self.version}" @classmethod def from_dict(cls, path: Path | str, config_dict: dict, macros: dict, messages: dict) -> "ModConfig": """ Create a ModConfig from a dict :param messages: messages that can be shown to the user at some moment :param path: path of the mod_config.json :param config_dict: dict containing the configuration :param macros: macro that can be used for safe_eval :return: ModConfig """ kwargs = config_dict | { "path": path, "_tracks": [CustomTrack.from_dict(track) for track in config_dict.pop("tracks", [])], "_arenas": [Arena.from_dict(arena) for arena in config_dict.pop("arenas", [])], "macros": macros, "messages": messages, "global_settings": {name: AbstractModSettings.get(data) for name, data in merge_dict( default_global_settings, config_dict.get("global_settings", {}), dict_keys=default_global_settings.keys() # Avoid modder to add their own settings to globals one ).items()}, "specific_settings": {name: AbstractModSettings.get(data) for name, data in config_dict.get( "specific_settings", {} ).items()}, } return cls(**kwargs) @classmethod def from_file(cls, config_file: str | Path) -> "ModConfig": """ Create a ModConfig from a file :param config_file: file containing the configuration :return: ModConfig """ config_file = Path(config_file) macros_file = config_file.parent / "macros.json" messages_file = config_file.parent / "messages.json" return cls.from_dict( path=config_file, config_dict=json.loads(config_file.read_text(encoding="utf8")), macros=json.loads(macros_file.read_text(encoding="utf8")) if macros_file.exists() else None, messages=json.loads(messages_file.read_text(encoding="utf8")) if messages_file.exists() else None, ) def get_safe_eval_env(self, base_env: "Env" = None) -> dict[str, any]: """ Return the env for the modconfig safe_eval function :param base_env: additional environment :return: the modconfig environment """ return { "mod_config": self, "bmg_color_raw": bmg_color_raw, "bmg_color_text": bmg_color_text } | ( base_env if base_env is not None else {} ) def safe_eval(self, *args, env: "Env" = None, **kwargs) -> Callable: """ Safe eval with useful modconfig environment :return: the result of the evaluation """ return safe_eval(*args, env=self.get_safe_eval_env(base_env=env), macros=self.macros, **kwargs) def multiple_safe_eval(self, *args, env: "Env" = None, **kwargs) -> Callable: """ Multiple safe eval with useful modconfig environment :return: the str result of the evaluation """ return multiple_safe_eval(*args, env=self.get_safe_eval_env(base_env=env), macros=self.macros, **kwargs) def get_mod_directory(self) -> Path: """ Get the directory of the mod :return: directory of the mod """ return self.path.parent def get_all_arenas_tracks(self, *args, **kwargs) -> Generator["CustomTrack", None, None]: """ Same as get_all_tracks, but arenas are included """ yield from self.get_all_tracks(*args, **kwargs) yield from self.get_arenas() def get_arenas(self) -> Generator["Arena", None, None]: """ Yield all arenas of the mod """ yield from self._arenas def get_all_tracks(self, *args, **kwargs) -> Generator["CustomTrack", None, None]: """ Same as get_tracks, but track group are divided into subtracks """ for track in self.get_tracks(*args, **kwargs): yield from track.get_tracks() def get_tracks(self, ignore_filter: bool = False, sorting_template: str = None, ignore_sorting: bool = False) -> Generator["Track | TrackGroup", None, None]: """ Get all the tracks or tracks groups elements :ignore_filter: should the tracks filter be ignored :return: track or tracks groups elements """ filter_template: "TemplateSafeEval | None" = self.global_settings["include_track_if"].value \ if not ignore_filter else None settings_sort: "TemplateSafeEval | None" = self.global_settings["sort_tracks"].value # filter_template_func is the function checking if the track should be included. If no parameter is set, # then always return True filter_template_func: Callable = self.safe_eval( filter_template if filter_template is not None else "True", args=["track"] ) # if a sorting function is set, use it. If no function is set, but sorting is not disabled, use settings. iterator: Iterator = filter(lambda track: filter_template_func(track=track) is True, self._tracks) if not ignore_sorting and (sorting_template is not None or settings_sort is not None): # get the sorting_template_func. If not defined, use the settings one. sorting_template_func: Callable = self.safe_eval( template=sorting_template if sorting_template is not None else settings_sort, args=["track"] ) # wrap the iterator inside a sort function iterator = sorted(iterator, key=sorting_template_func) # Go though all the tracks and filter them if enabled for track in iterator: yield track def get_ordered_cups(self) -> Generator["Cup", None, None]: """ Get all the cups with cup tags :return: cups with cup tags """ # for track that have a tag in self.tags_cups for tag_cup in self.tags_cups: track_buffer: "Track | TrackGroup" = [] current_tag_name, current_tag_count = tag_cup, 0 # every four 4 tracks, create a cup for track in filter(lambda track: tag_cup in getattr(track, "tags", []), self.get_tracks()): track_buffer.append(track) if len(track_buffer) > 4: current_tag_count += 1 yield Cup(tracks=track_buffer, cup_name=f"{current_tag_name}/{current_tag_count}") track_buffer = [] # if there is still tracks in the buffer, create a cup with them and fill with default> if len(track_buffer) > 0: track_buffer.extend([DefaultTrack()] * (4 - len(track_buffer))) yield Cup(tracks=track_buffer, cup_name=f"{current_tag_name}/{current_tag_count + 1}") def get_unordered_cups(self) -> Generator["Cup", None, None]: """ Get all the cups with no cup tags :return: cups with no cup tags """ # for track that have don't have a tag in self.tags_cups track_buffer: "Track | TrackGroup" = [] for track in filter( lambda track: not any(item in getattr(track, "tags", []) for item in self.tags_cups), self.get_tracks() ): track_buffer.append(track) if len(track_buffer) > 4: yield Cup(tracks=track_buffer) track_buffer = [] # if there is still tracks in the buffer, create a cup with them and fill with default if len(track_buffer) > 0: track_buffer.extend([DefaultTrack()] * (4 - len(track_buffer))) yield Cup(tracks=track_buffer) def get_cups(self) -> Generator["Cup", None, None]: """ Get all the cups :return: cups """ yield from self.get_ordered_cups() yield from self.get_unordered_cups() def get_ctfile(self, template: "TemplateMultipleSafeEval") -> str: """ Return the ct_file generated from the ModConfig :template: template for the track name :return: ctfile content """ lecode_flags = filter(lambda v: v is not None, [ "N$SHOW" if self.keep_original_track else "N$NONE", "N$F_WII" if self.original_track_prefix else None, "N$SWAP" if self.swap_original_order else None ]) ctfile = ( f"#CT-CODE\n" # magic number f"[RACING-TRACK-LIST]\n" # start of the track section f"%LE-FLAGS=1\n" # enable lecode mode f"%WIIMM-CUP={int(self.enable_random_cup)}\n" # enable random cup f"N {' | '.join(lecode_flags)}\n" # other flags to disable default tracks, ... f"\n" ) for cup in self.get_cups(): # get all the cup ctfile, use "-" for the template since the track's name are not used here ctfile += cup.get_ctfile(mod_config=self, template=template) ctfile_override_property = "[SETUP-ARENA]\n" for arena in self.get_arenas(): arena_definition, arena_override_property = arena.get_ctfile(mod_config=self, template=template) ctfile += arena_definition ctfile_override_property += arena_override_property ctfile += ctfile_override_property return ctfile def get_base_cticons(self) -> Generator[Image.Image, None, None]: """ Return the base cticon :return: """ icon_names = ["left", "right"] if self.keep_original_track: icon_names += [ f"_DEFAULT/{name}" for name in ( ["mushroom", "shell", "flower", "banana", "star", "leaf", "special", "lightning"] if self.swap_original_order else ["mushroom", "flower", "star", "special", "shell", "banana", "leaf", "lightning"] ) ] if self.enable_random_cup: icon_names.append("random") for icon_name in icon_names: yield Image.open(self.get_mod_directory() / f"_CUPS/{icon_name}.png").resize((CT_ICON_SIZE, CT_ICON_SIZE)) def get_custom_cticons(self) -> Generator[Image.Image, None, None]: """ Return the custom ct_icon generated from the ModConfig :return: """ for cup in self.get_cups(): yield cup.get_cticon(mod_config=self).resize((CT_ICON_SIZE, CT_ICON_SIZE)) def get_cticons(self) -> Generator[Image.Image, None, None]: """ Return all the ct_icon generated from the ModConfig :return: """ yield from self.get_base_cticons() yield from self.get_custom_cticons() @staticmethod def get_default_font() -> Path: """ Return the default font for creating ct_icons :return: the path to the default font file """ return Path("./assets/SuperMario256.ttf") def get_full_cticon(self) -> Image.Image: """ Return the full ct_icon generated from the ModConfig :return: """ cticons = list(self.get_cticons()) full_cticon = Image.new("RGBA", (CT_ICON_SIZE, CT_ICON_SIZE * len(cticons))) for i, cticon in enumerate(cticons): full_cticon.paste(cticon, (0, i * CT_ICON_SIZE)) return full_cticon def normalize_all_tracks(self, autoadd_path: "Path | str", destination_path: "Path | str", original_tracks_path: "Path | str", thread_amount: int = 8) -> Generator[Progress, None, None]: """ Convert all tracks of the mod to szs into the destination_path :param original_tracks_path: path to the originals tracks (if a track is disabled for multiplayer) :param thread_amount: number of thread to use :param autoadd_path: autoadd directory :param destination_path: destination where the files are converted """ yield Progress(description=_("NORMALIZING_TRACKS")) destination_path = Path(destination_path) original_tracks_path = Path(original_tracks_path) destination_path.mkdir(parents=True, exist_ok=True) normalize_threads: list[dict] = [] def remove_finished_threads() -> Generator[Progress, None, None]: """ Remove all the thread that stopped in a thread list :return: the list without the stopped thread """ nonlocal normalize_threads yield Progress( description=_("NORMALIZING_TRACKS", " :\n" + "\n".join(thread['name'] for thread in normalize_threads)) ) normalize_threads = list(filter(lambda thread: thread["thread"].is_alive(), normalize_threads)) track_directory = self.path.parent / "_TRACKS" multiplayer_disable_if_func: Callable = self.safe_eval( self.multiplayer_disable_if, args=["track"] ) for track in self.get_all_arenas_tracks(): track_file: Path = next( track_directory.rglob(f"{track.repr_format(self, self.track_file_template)}.*") ) @threaded def normalize_track(track: CustomTrack, track_file: Path): SZSPath(track_file).normalize( autoadd_path, destination_path / f"{track_file.stem}.szs", format="szs" ) if multiplayer_disable_if_func(track=track) is True: # if the track should use the default track instead in multiplayer, # copy the default track to the same file but with a _d at the end shutil.copy( original_tracks_path / f"{OriginalTrack.get(slot=track.special).name}_d.szs", destination_path / f"{track_file.stem}_d.szs" ) else: # delete the _d file if it exists (destination_path / f"{track_file.stem}_d.szs").unlink(missing_ok=True) normalize_threads.append({"name": track_file.name, "thread": normalize_track(track, track_file)}) while len(normalize_threads) > thread_amount: yield from remove_finished_threads() # if there is more than the max amount of thread running, wait for one to finish while len(normalize_threads) > 0: yield from remove_finished_threads() # if there is no longer any track to add, wait for all process to finish