import argparse
import cv2
import glob
import numpy as np
import os
from basicsr.utils.lmdb_util import LmdbMaker
def convert_celeba_tfrecords(tf_file, log_resolution, save_root, save_type='img', compress_level=1):
    """Convert CelebA tfrecords to images or lmdb files.

    Every file matching ``tf_file`` is read in sorted order. Each record is
    decoded from its 'shape' and 'data' features as an (h, w, c) uint8 array,
    its channel order is reversed (cv2 expects BGR), and it is stored under a
    running zero-padded 8-digit index.

    Args:
        tf_file (str): Input tfrecords file in glob pattern.
            Example: 'datasets/celeba/celeba_tfrecords/validation/validation-r08-s-*-of-*.tfrecords'  # noqa:E501
            If the pattern contains 'validation' the output is named with the
            'validation' phase, otherwise with 'train'.
        log_resolution (int): Log scale of resolution.
        save_root (str): Path root to save.
        save_type (str): Save type. Options: img | lmdb. Default: img.
        compress_level (int): PNG compress level when encoding images.
            Default: 1.

    Raises:
        ValueError: If ``save_type`` is neither 'img' nor 'lmdb'.
        ImportError: If tensorflow is not installed.
    """
    # Validate first so that a bad argument never creates anything on disk.
    if save_type not in ('img', 'lmdb'):
        raise ValueError('Wrong save type.')

    # Imported here (not at module level) so the function also works when this
    # module is imported as a library and not only when run as a script.
    try:
        import tensorflow as tf
    except ImportError as err:
        raise ImportError('You need to install tensorflow to read tfrecords.') from err

    phase = 'validation' if 'validation' in tf_file else 'train'

    lmdb_maker = None
    if save_type == 'lmdb':
        save_path = os.path.join(save_root, f'celeba_{2**log_resolution}_{phase}.lmdb')
        lmdb_maker = LmdbMaker(save_path)
    else:
        save_path = os.path.join(save_root, f'celeba_{2**log_resolution}_{phase}')
    os.makedirs(save_path, exist_ok=True)

    png_params = [cv2.IMWRITE_PNG_COMPRESSION, compress_level]

    idx = 0
    for record in sorted(glob.glob(tf_file)):
        print('Processing record: ', record)
        # The compat.v1 alias is available in both TensorFlow 1.15 and 2.x,
        # whereas tf.python_io only exists in 1.x.
        record_iterator = tf.compat.v1.io.tf_record_iterator(record)
        for string_record in record_iterator:
            example = tf.train.Example()
            example.ParseFromString(string_record)

            shape = example.features.feature['shape'].int64_list.value
            h, w, c = shape
            img_str = example.features.feature['data'].bytes_list.value[0]
            # np.frombuffer replaces the deprecated binary mode of
            # np.fromstring. The result is read-only, but the channel
            # indexing below creates a fresh writable copy.
            img = np.frombuffer(img_str, dtype=np.uint8).reshape((h, w, c))
            # Reverse the channel order (RGB -> BGR) for cv2.
            img = img[:, :, [2, 1, 0]]

            if save_type == 'img':
                # Honour compress_level here as well; PNG is lossless, so
                # only the file size / encoding time is affected.
                cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img, png_params)
            else:
                _, img_byte = cv2.imencode('.png', img, png_params)
                key = f'{idx:08d}/r{log_resolution:02d}'
                lmdb_maker.put(img_byte, key, (h, w, c))

            idx += 1
            print(idx)

    if lmdb_maker is not None:
        lmdb_maker.close()
def convert_ffhq_tfrecords(tf_file, log_resolution, save_root, save_type='img', compress_level=1):
    """Convert FFHQ tfrecords to images or lmdb files.

    Every file matching ``tf_file`` (treated as a glob pattern) is read in
    sorted order. Each record is decoded from its 'shape' and 'data' features
    as a (c, h, w) uint8 array, transposed to (h, w, c), its channel order is
    reversed (cv2 expects BGR), and it is stored under a running zero-padded
    8-digit index.

    Args:
        tf_file (str): Input tfrecords file.
        log_resolution (int): Log scale of resolution.
        save_root (str): Path root to save.
        save_type (str): Save type. Options: img | lmdb. Default: img.
        compress_level (int): PNG compress level when encoding images.
            Default: 1.

    Raises:
        ValueError: If ``save_type`` is neither 'img' nor 'lmdb'.
        ImportError: If tensorflow is not installed.
    """
    # Validate first so that a bad argument never creates anything on disk.
    if save_type not in ('img', 'lmdb'):
        raise ValueError('Wrong save type.')

    # Imported here (not at module level) so the function also works when this
    # module is imported as a library and not only when run as a script.
    try:
        import tensorflow as tf
    except ImportError as err:
        raise ImportError('You need to install tensorflow to read tfrecords.') from err

    lmdb_maker = None
    if save_type == 'lmdb':
        save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}.lmdb')
        lmdb_maker = LmdbMaker(save_path)
    else:
        save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}')
    os.makedirs(save_path, exist_ok=True)

    png_params = [cv2.IMWRITE_PNG_COMPRESSION, compress_level]

    idx = 0
    for record in sorted(glob.glob(tf_file)):
        print('Processing record: ', record)
        # The compat.v1 alias is available in both TensorFlow 1.15 and 2.x,
        # whereas tf.python_io only exists in 1.x.
        record_iterator = tf.compat.v1.io.tf_record_iterator(record)
        for string_record in record_iterator:
            example = tf.train.Example()
            example.ParseFromString(string_record)

            shape = example.features.feature['shape'].int64_list.value
            c, h, w = shape
            img_str = example.features.feature['data'].bytes_list.value[0]
            # np.frombuffer replaces the deprecated binary mode of
            # np.fromstring. The result is read-only, but the channel
            # indexing below creates a fresh writable copy.
            img = np.frombuffer(img_str, dtype=np.uint8).reshape((c, h, w))
            # Channel-first -> channel-last, then RGB -> BGR for cv2.
            img = img.transpose(1, 2, 0)
            img = img[:, :, [2, 1, 0]]

            if save_type == 'img':
                # Honour compress_level here as well; PNG is lossless, so
                # only the file size / encoding time is affected.
                cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img, png_params)
            else:
                _, img_byte = cv2.imencode('.png', img, png_params)
                key = f'{idx:08d}/r{log_resolution:02d}'
                lmdb_maker.put(img_byte, key, (h, w, c))

            idx += 1
            print(idx)

    if lmdb_maker is not None:
        lmdb_maker.close()
def make_ffhq_lmdb_from_imgs(folder_path, log_resolution, save_root, save_type='lmdb', compress_level=1):
    """Make FFHQ lmdb from images.

    All entries directly inside ``folder_path`` are read in sorted order with
    ``cv2.imread``, re-encoded as PNG and put into the lmdb under the key
    ``'{idx:08d}/r{log_resolution:02d}'``, where ``idx`` is the position of
    the file in the sorted listing.

    Args:
        folder_path (str): Folder path.
        log_resolution (int): Log scale of resolution.
        save_root (str): Path root to save.
        save_type (str): Save type. Only 'lmdb' is supported. Default: lmdb.
        compress_level (int): PNG compress level when encoding images.
            Default: 1.

    Raises:
        ValueError: If ``save_type`` is not 'lmdb'.
        IOError: If a file in ``folder_path`` cannot be read as an image.
    """
    if save_type != 'lmdb':
        raise ValueError('Wrong save type.')

    save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}_crop1.2.lmdb')
    lmdb_maker = LmdbMaker(save_path)
    os.makedirs(save_path, exist_ok=True)

    png_params = [cv2.IMWRITE_PNG_COMPRESSION, compress_level]

    img_list = sorted(glob.glob(os.path.join(folder_path, '*')))
    for idx, img_path in enumerate(img_list):
        print(f'Processing {idx}: ', img_path)
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail loudly with the offending path rather than with an opaque
        # AttributeError on ``img.shape``.
        if img is None:
            raise IOError(f'Failed to read image: {img_path}')
        h, w, c = img.shape

        _, img_byte = cv2.imencode('.png', img, png_params)
        key = f'{idx:08d}/r{log_resolution:02d}'
        lmdb_maker.put(img_byte, key, (h, w, c))

    lmdb_maker.close()
if __name__ == '__main__':
    # Read tfrecords without defining a graph.
    # We have tested it on TensorFlow 1.15.
    # References: http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset', type=str, default='ffhq', help="Dataset name. Options: 'ffhq' | 'celeba'. Default: 'ffhq'.")
    parser.add_argument(
        '--tf_file',
        type=str,
        default='datasets/ffhq/ffhq-r10.tfrecords',
        help=(
            'Input tfrecords file. For celeba, it should be glob pattern. '
            'Put quotes around the wildcard argument to prevent the shell '
            'from expanding it. '
            "Example: 'datasets/celeba/celeba_tfrecords/validation/validation-r08-s-*-of-*.tfrecords'"  # noqa:E501
        ))
    parser.add_argument('--log_resolution', type=int, default=10, help='Log scale of resolution.')
    parser.add_argument('--save_root', type=str, default='datasets/ffhq/', help='Save root path.')
    parser.add_argument(
        '--save_type', type=str, default='img', help="Save type. Options: 'img' | 'lmdb'. Default: 'img'.")
    parser.add_argument(
        '--compress_level', type=int, default=1, help='Compress level when encoding images. Default: 1.')
    args = parser.parse_args()

    # Kept at module scope on purpose: this binds the global name ``tf`` and
    # reports a missing tensorflow installation before any work is started.
    # Only ImportError is translated; any other failure raised while importing
    # tensorflow propagates unchanged so its real cause stays visible.
    try:
        import tensorflow as tf  # noqa: F401
    except ImportError as err:
        raise ImportError('You need to install tensorflow to read tfrecords.') from err

    # Any dataset name other than 'ffhq' is handled as CelebA.
    if args.dataset == 'ffhq':
        convert_ffhq_tfrecords(
            args.tf_file,
            args.log_resolution,
            args.save_root,
            save_type=args.save_type,
            compress_level=args.compress_level)
    else:
        convert_celeba_tfrecords(
            args.tf_file,
            args.log_resolution,
            args.save_root,
            save_type=args.save_type,
            compress_level=args.compress_level)