Source code for torchdrug.utils.file
import os
import struct
import logging
from tqdm import tqdm
# Module-level logger used by the download/extract helpers below; named after this module.
logger = logging.getLogger(__name__)
def download(url, path, save_file=None, md5=None):
    """
    Download a file from the specified url.
    Skip the downloading step if there exists a file satisfying the given MD5.

    If ``md5`` is not specified, an existing file cannot be validated and is downloaded again.
    The directory of the save file is created if it does not exist. The data is first written
    to a temporary file next to the save file, so an interrupted download never leaves a
    truncated file behind nor destroys a previously downloaded copy.

    Parameters:
        url (str): URL to download
        path (str): path to store the downloaded file
        save_file (str, optional): name of save file. If not specified, infer the file name from the URL.
        md5 (str, optional): MD5 of the file. If given, the downloaded file is checked against it
            and a warning is logged on mismatch.

    Returns:
        str: path of the (possibly already existing) downloaded file
    """
    from urllib.request import urlretrieve

    if save_file is None:
        save_file = os.path.basename(url)
        if "?" in save_file:
            # drop the query string, e.g. ``data.zip?download=1`` -> ``data.zip``
            save_file = save_file[:save_file.find("?")]
    save_file = os.path.join(path, save_file)

    # Without a reference MD5 an existing file cannot be trusted, so skip hashing it.
    if not os.path.exists(save_file) or md5 is None or compute_md5(save_file) != md5:
        logger.info("Downloading %s to %s" % (url, save_file))
        save_dir = os.path.dirname(save_file)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        # Download into a process-specific sibling file, then atomically move it into place.
        tmp_file = "%s.%d.part" % (save_file, os.getpid())
        try:
            urlretrieve(url, tmp_file)
            os.replace(tmp_file, save_file)
        finally:
            # Only still present if the download or the rename failed.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
        if md5 is not None and compute_md5(save_file) != md5:
            logger.warning("MD5 of %s does not match the expected value %s" % (save_file, md5))
    return save_file
def smart_open(file_name, mode="rb"):
    """
    Open a regular file or a compressed file.
    This function can be used as drop-in replacement of the builtin function `open()`.

    Files ending with ``.bz2`` or ``.gz`` are transparently (de)compressed. Text modes such as
    ``"rt"`` or ``"wt"`` are supported for compressed files too. For compressed files a mode
    without ``t`` (e.g. ``"r"``) yields a binary stream, following :func:`gzip.open` and
    :func:`bz2.open`.

    Parameters:
        file_name (str): file name
        mode (str, optional): open mode for the file stream

    Returns:
        file object: the opened stream
    """
    import bz2
    import gzip

    extension = os.path.splitext(file_name)[1]
    # gzip.open / bz2.open return GzipFile / BZ2File for binary modes (same as before)
    # and additionally wrap them in a text layer for "t" modes, which the raw classes reject.
    if extension == ".bz2":
        return bz2.open(file_name, mode)
    if extension == ".gz":
        return gzip.open(file_name, mode)
    return open(file_name, mode)
def extract(zip_file, member=None):
    """
    Extract files from a zip file. Currently, ``zip``, ``gz``, ``tar.gz``, ``tgz``, ``tar``,
    ``tar.bz2`` and ``tar.xz`` file types are supported.

    Files are extracted into the directory containing ``zip_file``. A file that already exists
    with the size recorded in the archive is not extracted again.

    Parameters:
        zip_file (str): file name
        member (str, optional): extract specific member from the zip file.
            If not specified, extract all members. Ignored for plain ``gz`` files,
            which always contain a single member.

    Returns:
        str: path of the extracted file if exactly one member is extracted,
        otherwise the directory containing ``zip_file``

    Raises:
        ValueError: if the file extension is not supported, or if an archive member
            would be written outside the directory containing ``zip_file``
    """
    import gzip
    import shutil
    import tarfile
    import zipfile

    zip_name, extension = os.path.splitext(zip_file)
    # Only compressed tarballs get ".tar" folded into the extension; otherwise a name such as
    # "data.tar.zip" would turn into the unsupported extension ".tar.zip".
    if extension in (".gz", ".bz2", ".xz") and zip_name.endswith(".tar"):
        extension = ".tar" + extension
        zip_name = zip_name[:-4]
    save_path = os.path.dirname(zip_file)

    if extension == ".gz":
        # A plain gzip file holds one member, named after the archive minus ".gz".
        save_file = os.path.join(save_path, os.path.basename(zip_name))
        save_files = [save_file]
        # The gzip trailer stores the uncompressed size modulo 2^32 in its last 4 bytes.
        with open(zip_file, "rb") as fin:
            fin.seek(-4, 2)
            file_size = struct.unpack("<I", fin.read())[0]
        if not os.path.exists(save_file) or file_size != os.path.getsize(save_file) % (1 << 32):
            logger.info("Extracting %s to %s" % (zip_file, save_file))
            with gzip.open(zip_file, "rb") as fin, open(save_file, "wb") as fout:
                shutil.copyfileobj(fin, fout)
    elif extension in (".tar.gz", ".tgz", ".tar", ".tar.bz2", ".tar.xz"):
        # mode "r" lets tarfile detect the compression transparently
        with tarfile.open(zip_file, "r") as tar:
            if member is not None:
                infos = [tar.getmember(member)]
                save_files = [_archive_target(save_path, os.path.basename(member))]
                logger.info("Extracting %s from %s to %s" % (member, zip_file, save_files[0]))
            else:
                # iterate TarInfo objects directly; tar.getmember() is a linear scan per call
                infos = tar.getmembers()
                save_files = [_archive_target(save_path, info.name) for info in infos]
                logger.info("Extracting %s to %s" % (zip_file, save_path))
            _copy_members(zip(infos, save_files), lambda info: info.isdir(),
                          lambda info: info.size, tar.extractfile)
    elif extension == ".zip":
        with zipfile.ZipFile(zip_file) as zipped:
            if member is not None:
                infos = [zipped.getinfo(member)]
                save_files = [_archive_target(save_path, os.path.basename(member))]
                logger.info("Extracting %s from %s to %s" % (member, zip_file, save_files[0]))
            else:
                infos = zipped.infolist()
                save_files = [_archive_target(save_path, info.filename) for info in infos]
                logger.info("Extracting %s to %s" % (zip_file, save_path))
            _copy_members(zip(infos, save_files), lambda info: info.is_dir(),
                          lambda info: info.file_size, zipped.open)
    else:
        raise ValueError("Unknown file extension `%s`" % extension)

    if len(save_files) == 1:
        return save_files[0]
    else:
        return save_path


def _archive_target(save_path, name):
    """Join an archive member name onto ``save_path``, rejecting names that escape it (zip-slip)."""
    save_file = os.path.join(save_path, name)
    root = os.path.abspath(save_path)
    # os.path.commonpath itself raises ValueError for paths on different Windows drives.
    if os.path.commonpath([root, os.path.abspath(save_file)]) != root:
        raise ValueError("Archive member `%s` would be extracted outside of `%s`" % (name, save_path))
    return save_file


def _copy_members(entries, is_dir, get_size, open_member):
    """Write each ``(info, save_file)`` entry to disk, skipping files that already have the expected size."""
    import shutil

    for info, save_file in entries:
        if is_dir(info):
            os.makedirs(save_file, exist_ok=True)
            continue
        save_dir = os.path.dirname(save_file)
        # os.makedirs("") raises, which happens for archives located in the current directory
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        if os.path.exists(save_file) and get_size(info) == os.path.getsize(save_file):
            continue
        fin = open_member(info)
        # TarFile.extractfile returns None for members without file content (devices, FIFOs)
        if fin is None:
            continue
        with fin, open(save_file, "wb") as fout:
            shutil.copyfileobj(fin, fout)
def compute_md5(file_name, chunk_size=65536):
    """
    Compute the MD5 checksum of a file.

    The file is hashed incrementally, ``chunk_size`` bytes at a time, so it never has to be
    loaded into memory as a whole.

    Parameters:
        file_name (str): file name
        chunk_size (int, optional): number of bytes read per step

    Returns:
        str: hexadecimal MD5 digest of the file content
    """
    import hashlib

    digest = hashlib.md5()
    with open(file_name, "rb") as stream:
        # iter() with a b"" sentinel keeps calling read() until end of file
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
def get_line_count(file_name, chunk_size=8192*1024):
    """
    Get the number of lines in a file.

    Lines are counted as newline characters in the raw bytes, so a final line that is not
    terminated by a newline is not counted.

    Parameters:
        file_name (str): file name
        chunk_size (int, optional): number of bytes read per step, to bound memory on large files

    Returns:
        int: number of newline characters in the file
    """
    with open(file_name, "rb") as stream:
        return sum(block.count(b"\n") for block in iter(lambda: stream.read(chunk_size), b""))