# E2E_SCSI/scripts/checksum_check.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
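"""Verify SHA256 checksums of downloaded Dynamic Replica dataset archives.

The script hashes every ``*.zip`` file in ``--download_folder`` in parallel and
compares the digests against the expected values stored in a JSON file
(``dr_sha256.json`` at the repository root by default), raising a ValueError if
any archive is missing or does not match.
"""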
import os
import glob
import argparse
import hashlib
import json
from typing import Optional
from multiprocessing import Pool
from tqdm import tqdm
# Default location of the expected-hash JSON: "dr_sha256.json" two directory
# levels above this file, i.e. the repository root relative to scripts/.
DEFAULT_SHA256S_FILE = os.path.join(__file__.rsplit(os.sep, 2)[0], "dr_sha256.json")

# Read files in 64 KiB chunks when hashing to keep memory usage low.
BLOCKSIZE = 65536
def main(
    download_folder: str,
    sha256s_file: str,
    dump: bool = False,
    n_sha256_workers: int = 4,
):
    if not os.path.isfile(sha256s_file):
        raise ValueError(f"The SHA256 file does not exist ({sha256s_file}).")
    expected_sha256s = get_expected_sha256s(sha256s_file=sha256s_file)

    # Hash every downloaded zip archive in parallel.
    zipfiles = sorted(glob.glob(os.path.join(download_folder, "*.zip")))
    print(f"Extracting SHA256 hashes for {len(zipfiles)} files in {download_folder}.")
    extracted_sha256s_list = []
    with Pool(processes=n_sha256_workers) as sha_pool:
        for extracted_hash in tqdm(
            sha_pool.imap(_sha256_file_and_print, zipfiles),
            total=len(zipfiles),
        ):
            extracted_sha256s_list.append(extracted_hash)

    extracted_sha256s = dict(
        zip([os.path.split(z)[-1] for z in zipfiles], extracted_sha256s_list)
    )

    if dump:
        # Overwrite the expected-hash JSON with the freshly computed hashes.
        print(extracted_sha256s)
        with open(sha256s_file, "w") as f:
            json.dump(extracted_sha256s, f, indent=2)

    # Compare the computed hashes against the expected ones.
    missing_keys, invalid_keys = [], []
    for k in expected_sha256s.keys():
        if k not in extracted_sha256s:
            print(f"{k} missing!")
            missing_keys.append(k)
        elif expected_sha256s[k] != extracted_sha256s[k]:
            print(
                f"'{k}' does not match!"
                + f" ({expected_sha256s[k]} != {extracted_sha256s[k]})"
            )
            invalid_keys.append(k)

    if len(invalid_keys) + len(missing_keys) > 0:
        raise ValueError(
            "Checksum checker failed!"
            + f" Non-matching checksums: {str(invalid_keys)};"
            + f" missing files: {str(missing_keys)}."
        )
def get_expected_sha256s(sha256s_file: str):
    with open(sha256s_file, "r") as f:
        expected_sha256s = json.load(f)
    return expected_sha256s
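# The expected-hash JSON maps zip file names to hex digests, e.g. (entries are hypothetical):
#   {
#     "some_sequence.zip": "<64-character hex SHA256 digest>",
#     ...
#   }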
def check_dr_sha256(
    path: str,
    sha256s_file: str,
    expected_sha256s: Optional[dict] = None,
    do_assertion: bool = True,
):
    # Verify a single zip archive against the expected hash; either assert
    # (do_assertion=True) or return a boolean (do_assertion=False).
    zipname = os.path.split(path)[-1]
    if expected_sha256s is None:
        expected_sha256s = get_expected_sha256s(sha256s_file=sha256s_file)
    extracted_hash = sha256_file(path)
    if do_assertion:
        assert (
            extracted_hash == expected_sha256s[zipname]
        ), f"{zipname}: ({extracted_hash} != {expected_sha256s[zipname]})"
    else:
        return extracted_hash == expected_sha256s[zipname]
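# Library-style usage sketch (the archive name below is hypothetical):
#   ok = check_dr_sha256(
#       "downloads/some_sequence.zip",
#       sha256s_file=DEFAULT_SHA256S_FILE,
#       do_assertion=False,
#   )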
def sha256_file(path: str):
    # Stream the file in BLOCKSIZE chunks so large archives never need to fit in memory.
    sha256_hash = hashlib.sha256()
    with open(path, "rb") as f:
        file_buffer = f.read(BLOCKSIZE)
        while len(file_buffer) > 0:
            sha256_hash.update(file_buffer)
            file_buffer = f.read(BLOCKSIZE)
    digest_ = sha256_hash.hexdigest()
    return digest_
def _sha256_file_and_print(path: str):
    digest_ = sha256_file(path)
    print(f"{path}: {digest_}")
    return digest_
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Check SHA256 hashes of the Dynamic Replica dataset."
    )
    parser.add_argument(
        "--download_folder",
        type=str,
        required=True,
        help="A local folder containing the downloaded dataset zip files.",
    )
    parser.add_argument(
        "--sha256s_file",
        type=str,
        help="A path to the JSON file with the expected SHA256 hashes.",
        default=DEFAULT_SHA256S_FILE,
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="The number of sha256 extraction workers.",
    )
    parser.add_argument(
        "--dump_sha256s",
        action="store_true",
        help="Store the computed sha256 hashes to the --sha256s_file JSON.",
    )
    args = parser.parse_args()
    main(
        str(args.download_folder),
        dump=bool(args.dump_sha256s),
        n_sha256_workers=int(args.num_workers),
        sha256s_file=str(args.sha256s_file),
    )
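# Example invocation (the download path is illustrative, not part of the repository):
#   python scripts/checksum_check.py \
#       --download_folder /path/to/dynamic_replica_downloads \
#       --num_workers 8
# Pass --dump_sha256s to (re)write the expected-hash JSON from the computed digests,
# or --sha256s_file to point at a different expected-hash file.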