From 43e46008c7513dedfe204446a280fdc76614c71b Mon Sep 17 00:00:00 2001 From: raphael60650 Date: Wed, 28 Jul 2021 21:40:35 +0200 Subject: [PATCH] added more sha1 check in download_wu8, and do 3 try before raising an exception --- source/Track.py | 58 +++++++++++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 21 deletions(-) diff --git a/source/Track.py b/source/Track.py index 2aaa824..1113a73 100644 --- a/source/Track.py +++ b/source/Track.py @@ -3,12 +3,24 @@ import requests from .definition import * from .wszst import * +CHUNK_SIZE = 524288 + class CantDownloadTrack(Exception): - def __init__(self, track, http_error=404): + def __init__(self, track, http_error: [str, int]): super().__init__(f"Can't download track {track.get_track_name()} (error {http_error}) !") +def check_file_sha1(file: str, excepted_sha1: str) -> int: + """ + check if track szs sha1 is correct + :return: 1 if yes, 0 if no + """ + if not os.path.exists(file): return 0 + if szs.sha1(file=file) == excepted_sha1: return 1 + else: return 0 + + class Track: def __init__(self, name: str = "_", prefix: str = None, suffix: str = None, author="Nintendo", special="T11", music="T11", new=True, sha1: str = None, since_version: str = None, @@ -60,14 +72,19 @@ class Track: """ return f"{self.get_track_name()} sha1={self.sha1} score={self.score}" - def check_sha1(self) -> int: + def check_wu8_sha1(self) -> int: """ - check if track wu8's sha1 is correct + check if track wu8 sha1 is correct :return: 0 if yes, -1 if no """ - if not os.path.exists(self.file_wu8): return -1 - if szs.sha1(file=self.file_wu8) == self.sha1: return 0 - else: return -1 + return check_file_sha1(self.file_wu8, self.sha1) + + def check_szs_sha1(self) -> int: + """ + check if track szs sha1 is correct + :return: 0 if yes, -1 if no + """ + return check_file_sha1(self.file_szs, self.sha1) def convert_wu8_to_szs(self) -> None: """ @@ -75,26 +92,25 @@ class Track: """ szs.normalize(src_file=self.file_wu8) - def download_wu8(self, github_content_root: str) -> int: + def download_wu8(self, github_content_root: str) -> None: """ download track wu8 from github :param github_content_root: url to github project root :return: 0 if correctly downloaded """ - if self.check_sha1(): return 0 # if sha1 correct, do not try to download track - - dl = requests.get(github_content_root + self.file_wu8, allow_redirects=True, stream=True) - - if dl.status_code == 200: # if page is found - with open(self.file_wu8, "wb") as file: - chunk_size = 524288 - for i, chunk in enumerate(dl.iter_content(chunk_size=chunk_size)): - file.write(chunk) - file.flush() - return 0 - else: - raise CantDownloadTrack(track=self, http_error=dl.status_code) + if self.check_wu8_sha1(): return # if sha1 correct, do not try to download track + for _ in range(3): + dl = requests.get(github_content_root + self.file_wu8, allow_redirects=True, stream=True) + if dl.status_code == 200: # if page is found + with open(self.file_wu8, "wb") as file: + for i, chunk in enumerate(dl.iter_content(chunk_size=CHUNK_SIZE)): + file.write(chunk) + file.flush() + if self.check_wu8_sha1(): return # if sha1 correct, do not try to download track + else: + raise CantDownloadTrack(track=self, http_error=dl.status_code) + raise CantDownloadTrack(track=self, http_error="Failed to download track") # if failed more than 3 times def get_ctfile(self, race=False, *args, **kwargs) -> str: """ @@ -164,7 +180,7 @@ class Track: def load_from_json(self, track_json: dict) -> None: """ load the track from a dictionary - :param track_json: track's dictionnary + :param track_json: track's dictionary """ for key, value in track_json.items(): # load all value in the json as class attribute setattr(self, key, value)