Module 1_gluoncv_finetune.lib.im2rec
Expand source code
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import print_function
import os
import sys
curr_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(curr_path, "../python"))
import mxnet as mx
import random
import argparse
import cv2
import time
import traceback
try:
import multiprocessing
except ImportError:
multiprocessing = None
def list_image(root, recursive, exts):
    """Scan ``root`` for image files and lazily yield one record per image.

    Parameters
    ----------
    root: string
        Directory to scan.
    recursive: bool
        When true, descend into sub-directories (following symlinks) and give
        every directory that holds at least one accepted image its own integer
        label, numbered in discovery order. When false, only ``root`` itself
        is scanned and every image gets label 0.
    exts: collection of string
        Accepted file suffixes, including the leading dot. The suffix of each
        file is lower-cased before it is looked up here.

    Yields
    ------
    tuple
        ``(index, path_relative_to_root, label)`` with ``index`` counting up
        from 0 in traversal order.
    """
    def _wanted(full_path, name):
        # Only regular files whose lower-cased suffix is whitelisted qualify.
        return os.path.isfile(full_path) and os.path.splitext(name)[1].lower() in exts

    index = 0
    if not recursive:
        for name in sorted(os.listdir(root)):
            candidate = os.path.join(root, name)
            if _wanted(candidate, name):
                yield (index, os.path.relpath(candidate, root), 0)
                index += 1
        return

    labels = {}
    for folder, subdirs, names in os.walk(root, followlinks=True):
        # Sorting in place fixes both the walk order and the label numbering.
        subdirs.sort()
        names.sort()
        for name in names:
            candidate = os.path.join(folder, name)
            if not _wanted(candidate, name):
                continue
            label = labels.setdefault(folder, len(labels))
            yield (index, os.path.relpath(candidate, root), label)
            index += 1
    # After the walk is exhausted, report the folder -> label mapping.
    for folder, label in sorted(labels.items(), key=lambda pair: pair[1]):
        print(os.path.relpath(folder, root), label)
def write_list(path_out, image_list):
    """Helper that serialises image records into a tab-separated .lst file.

    Every record becomes one line laid out as
        integer_image_index <tab> float_label(s) <tab> path_to_image
    where each label is written with ``%f`` and followed by a tab.

    Parameters
    ----------
    path_out: string
        File to create (or overwrite).
    image_list: list
        Records shaped ``(index, path, label, ...)``; everything from position
        2 onwards is treated as a label.
    """
    with open(path_out, 'w') as fout:
        for record in image_list:
            index_field = '%d\t' % record[0]
            label_fields = ''.join('%f\t' % label for label in record[2:])
            path_field = '%s\n' % record[1]
            fout.write(index_field + label_fields + path_field)
def make_list(args):
    """Generates the .lst file(s) for the images found under ``args.root``.

    The image list is optionally shuffled (fixed seed 100, so the result is
    reproducible), cut into ``args.chunks`` consecutive chunks, and each chunk
    is written either as a single ``<prefix>[_<i>].lst`` file (when
    ``args.train_ratio`` is exactly 1.0) or split into ``_test.lst``,
    ``_train.lst`` and ``_val.lst`` files, in that slice order.

    Parameters
    ----------
    args: object that contains all the arguments
        Reads ``root``, ``recursive``, ``exts``, ``shuffle``, ``chunks``,
        ``train_ratio``, ``test_ratio`` and ``prefix``.
    """
    image_list = list(list_image(args.root, args.recursive, args.exts))
    if args.shuffle is True:
        random.seed(100)
        random.shuffle(image_list)
    total = len(image_list)
    # Ceiling division so that every image lands in some chunk.
    chunk_size = (total + args.chunks - 1) // args.chunks
    for i in range(args.chunks):
        chunk = image_list[i * chunk_size:(i + 1) * chunk_size]
        str_chunk = '_%d' % i if args.chunks > 1 else ''
        if args.train_ratio == 1.0:
            write_list(args.prefix + str_chunk + '.lst', chunk)
            continue
        # Size the splits from the images actually in this chunk: the last
        # chunk can be shorter than chunk_size, and using the nominal size
        # there would let train swallow the images meant for validation.
        sep = int(len(chunk) * args.train_ratio)
        sep_test = int(len(chunk) * args.test_ratio)
        if args.test_ratio:
            write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test])
        if args.train_ratio + args.test_ratio < 1.0:
            write_list(args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep:])
        write_list(args.prefix + str_chunk + '_train.lst', chunk[sep_test:sep_test + sep])
def read_list(path_in):
    """Parse a .lst file lazily, yielding one item per well-formed line.

    Lines are split on tabs and every field is stripped. Lines with fewer than
    three fields, or whose index/labels cannot be converted to numbers, are
    reported on stdout and skipped.

    Parameters
    ----------
    path_in: string
        Path of the .lst file to read.

    Yields
    ------
    list
        ``[int_index, path, float_label, ...]`` -- note that the path, which is
        the last field on disk, is moved to position 1.
    """
    with open(path_in) as fin:
        for raw in fin:
            fields = [part.strip() for part in raw.strip().split('\t')]
            n_fields = len(fields)
            # An index, at least one label and a path are required.
            if n_fields < 3:
                print('lst should have at least has three parts, but only has %s parts for %s' % (n_fields, fields))
                continue
            try:
                index = int(fields[0])
                labels = [float(token) for token in fields[1:-1]]
            except Exception as e:
                print('Parsing lst met error for %s, detail: %s' % (fields, e))
                continue
            yield [index, fields[-1]] + labels
def image_encode(args, i, item, q_out):
    """Reads, preprocesses and packs one image, then puts the result on ``q_out``.

    The tuple put on the queue is ``(i, packed_record, item)``; the payload is
    ``None`` whenever the image could not be read or packed, so the consumer
    still sees an entry for position ``i``.

    Parameters
    ----------
    args: object
        Parsed options; reads ``root``, ``pack_label``, ``pass_through``,
        ``color``, ``center_crop``, ``resize``, ``quality`` and ``encoding``.
    i: int
        Position of the item in the input list.
    item: list
        ``[index, relative_path, label, ...]`` as produced by ``read_list``.
    q_out: queue
        Output queue receiving the result tuple.
    """
    fullpath = os.path.join(args.root, item[1])

    # Pack every label only when asked to; otherwise just the first one.
    if len(item) > 3 and args.pack_label:
        header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
    else:
        header = mx.recordio.IRHeader(0, item[2], item[0], 0)

    if args.pass_through:
        # Store the file bytes untouched: no decode, crop, resize or re-encode.
        try:
            with open(fullpath, 'rb') as fin:
                img = fin.read()
            s = mx.recordio.pack(header, img)
            q_out.put((i, s, item))
        except Exception as e:
            traceback.print_exc()
            print('pack_img error:', item[1], e)
            q_out.put((i, None, item))
        return

    try:
        img = cv2.imread(fullpath, args.color)
    except Exception:
        # Deliberately not a bare ``except``: KeyboardInterrupt / SystemExit
        # must still be able to stop the process.
        traceback.print_exc()
        print('imread error trying to load file: %s ' % fullpath)
        q_out.put((i, None, item))
        return
    if img is None:
        print('imread read blank (None) image for file: %s' % fullpath)
        q_out.put((i, None, item))
        return

    if args.center_crop:
        # Trim the longer axis symmetrically so both axes have equal length.
        if img.shape[0] > img.shape[1]:
            margin = (img.shape[0] - img.shape[1]) // 2
            img = img[margin:margin + img.shape[1], :]
        else:
            margin = (img.shape[1] - img.shape[0]) // 2
            img = img[:, margin:margin + img.shape[0]]

    if args.resize:
        # Scale so the shorter edge equals args.resize, keeping aspect ratio.
        # The tuple is ordered (shape[1]-axis size, shape[0]-axis size).
        if img.shape[0] > img.shape[1]:
            newsize = (args.resize, img.shape[0] * args.resize // img.shape[1])
        else:
            newsize = (img.shape[1] * args.resize // img.shape[0], args.resize)
        img = cv2.resize(img, newsize)

    try:
        s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding)
        q_out.put((i, s, item))
    except Exception as e:
        traceback.print_exc()
        print('pack_img error on file: %s' % fullpath, e)
        q_out.put((i, None, item))
        return
def read_worker(args, q_in, q_out):
    """Worker loop: encode every job taken from ``q_in`` until ``None`` arrives.

    Each job is an ``(i, item)`` pair that is handed to ``image_encode``, which
    puts its result on ``q_out``. A ``None`` job is the shutdown signal.

    Parameters
    ----------
    args: object
        Parsed options, forwarded unchanged to ``image_encode``.
    q_in: queue
        Source of ``(i, item)`` jobs, terminated by ``None``.
    q_out: queue
        Destination queue, forwarded to ``image_encode``.
    """
    job = q_in.get()
    while job is not None:
        i, item = job
        image_encode(args, i, item, q_out)
        job = q_in.get()
def write_worker(q_out, fname, working_dir):
    """Drains ``q_out`` and writes the packed images to a .rec/.idx pair.

    Results may arrive out of order, so they are parked in a buffer and
    written strictly by ascending position. Entries whose payload is ``None``
    (failed images) are skipped but still advance the position counter. A
    ``None`` entry on the queue ends the loop, and the record file is always
    closed on the way out.

    Parameters
    ----------
    q_out: queue
        Source of ``(i, packed_record_or_None, item)`` tuples, terminated by
        ``None``.
    fname: string
        Name of the .lst file; its basename without extension names the
        ``.rec`` and ``.idx`` outputs.
    working_dir: string
        Directory in which the outputs are created.
    """
    pre_time = time.time()
    count = 0
    stem = os.path.splitext(os.path.basename(fname))[0]
    record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, stem + '.idx'),
                                           os.path.join(working_dir, stem + '.rec'), 'w')
    buf = {}
    more = True
    try:
        while more:
            deq = q_out.get()
            if deq is not None:
                i, s, item = deq
                buf[i] = (s, item)
            else:
                more = False
            # Flush every buffered entry that is next in sequence.
            while count in buf:
                s, item = buf.pop(count)
                if s is not None:
                    record.write_idx(item[0], s)
                if count % 1000 == 0:
                    cur_time = time.time()
                    print('time:', cur_time - pre_time, ' count:', count)
                    pre_time = cur_time
                count += 1
    finally:
        # Close explicitly instead of relying on garbage collection at
        # process exit to release the .rec/.idx files.
        record.close()
def parse_args(argv=None):
    """Defines and parses all command-line arguments.

    Help texts are built from adjacent string literals rather than backslash
    continuations inside a literal, so their wording does not depend on how
    the source happens to be indented.

    Parameters
    ----------
    argv: list of string, optional
        Arguments to parse. ``None`` (the default) lets argparse read
        ``sys.argv[1:]``.

    Returns
    -------
    args object that contains all the params; ``prefix`` and ``root`` are
    converted to absolute paths.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Create an image list or '
                    'make a record database by reading from an image list')
    parser.add_argument('prefix', help='prefix of input/output lst and rec files.')
    parser.add_argument('root', help='path to folder containing images.')

    cgroup = parser.add_argument_group('Options for creating image lists')
    cgroup.add_argument('--list', action='store_true',
                        help='If this is set im2rec will create image list(s) by traversing root folder '
                             'and output to <prefix>.lst. '
                             'Otherwise im2rec will read <prefix>.lst and create a database at <prefix>.rec')
    cgroup.add_argument('--exts', nargs='+', default=['.jpeg', '.jpg', '.png'],
                        help='list of acceptable image extensions.')
    cgroup.add_argument('--chunks', type=int, default=1, help='number of chunks.')
    cgroup.add_argument('--train-ratio', type=float, default=1.0,
                        help='Ratio of images to use for training.')
    cgroup.add_argument('--test-ratio', type=float, default=0,
                        help='Ratio of images to use for testing.')
    cgroup.add_argument('--recursive', action='store_true',
                        help='If true recursively walk through subdirs and assign an unique label '
                             'to images in each folder. Otherwise only include images in the root folder '
                             'and give them label 0.')
    # store_false on dest='shuffle' means shuffling is on unless this is passed.
    cgroup.add_argument('--no-shuffle', dest='shuffle', action='store_false',
                        help='If this is passed, '
                             'im2rec will not randomize the image order in <prefix>.lst')

    rgroup = parser.add_argument_group('Options for creating database')
    rgroup.add_argument('--pass-through', action='store_true',
                        help='whether to skip transformation and save image as is')
    rgroup.add_argument('--resize', type=int, default=0,
                        help='resize the shorter edge of image to the newsize, original images will '
                             'be packed by default.')
    rgroup.add_argument('--center-crop', action='store_true',
                        help='specify whether to crop the center image to make it rectangular.')
    rgroup.add_argument('--quality', type=int, default=95,
                        help='JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9')
    rgroup.add_argument('--num-thread', type=int, default=1,
                        help='number of thread to use for encoding. order of images will be different '
                             'from the input list if >1. the input list will be modified to match the '
                             'resulting order.')
    rgroup.add_argument('--color', type=int, default=1, choices=[-1, 0, 1],
                        help='specify the color mode of the loaded image. '
                             '1: Loads a color image. Any transparency of image will be neglected. '
                             'It is the default flag. '
                             '0: Loads image in grayscale mode. '
                             '-1:Loads image as such including alpha channel.')
    rgroup.add_argument('--encoding', type=str, default='.jpg', choices=['.jpg', '.png'],
                        help='specify the encoding of the images.')
    rgroup.add_argument('--pack-label', action='store_true',
                        help='Whether to also pack multi dimensional label in the record file')

    args = parser.parse_args(argv)
    args.prefix = os.path.abspath(args.prefix)
    args.root = os.path.abspath(args.root)
    return args
if __name__ == '__main__':
    args = parse_args()
    # if the '--list' is used, it generates .lst file
    if args.list:
        make_list(args)
    # otherwise read .lst file to generates .rec file
    else:
        if os.path.isdir(args.prefix):
            working_dir = args.prefix
        else:
            working_dir = os.path.dirname(args.prefix)
        files = [os.path.join(working_dir, fname) for fname in os.listdir(working_dir)
                 if os.path.isfile(os.path.join(working_dir, fname))]
        count = 0
        for fname in files:
            if fname.startswith(args.prefix) and fname.endswith('.lst'):
                print('Creating .rec file from', fname, 'in', working_dir)
                count += 1
                image_list = read_list(fname)
                # -- write_record -- #
                if args.num_thread > 1 and multiprocessing is not None:
                    # One bounded input queue per reader; jobs are dealt round-robin.
                    q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)]
                    q_out = multiprocessing.Queue(1024)
                    # define the process
                    read_process = [multiprocessing.Process(target=read_worker, args=(args, q_in[i], q_out))
                                    for i in range(args.num_thread)]
                    # process images with num_thread process
                    for p in read_process:
                        p.start()
                    # only use one process to write .rec to avoid race-condition
                    write_process = multiprocessing.Process(target=write_worker, args=(q_out, fname, working_dir))
                    write_process.start()
                    # put the image list into input queue
                    for i, item in enumerate(image_list):
                        q_in[i % len(q_in)].put((i, item))
                    # A None per reader tells it to stop once its queue is drained.
                    for q in q_in:
                        q.put(None)
                    for p in read_process:
                        p.join()
                    # All readers are done, so the writer can be told to finish.
                    q_out.put(None)
                    write_process.join()
                else:
                    print('multiprocessing not available, fall back to single threaded encoding')
                    try:
                        import Queue as queue
                    except ImportError:
                        import queue
                    q_out = queue.Queue()
                    stem = os.path.splitext(os.path.basename(fname))[0]
                    record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, stem + '.idx'),
                                                           os.path.join(working_dir, stem + '.rec'), 'w')
                    cnt = 0
                    pre_time = time.time()
                    try:
                        for i, item in enumerate(image_list):
                            image_encode(args, i, item, q_out)
                            if q_out.empty():
                                continue
                            _, s, _ = q_out.get()
                            if s is None:
                                # image_encode signals an unreadable or unpackable
                                # image with a None payload; writing it would
                                # raise, so skip it like write_worker does.
                                continue
                            record.write_idx(item[0], s)
                            if cnt % 1000 == 0:
                                cur_time = time.time()
                                print('time:', cur_time - pre_time, ' count:', cnt)
                                pre_time = cur_time
                            cnt += 1
                    finally:
                        # Release the .rec/.idx files even if encoding aborted.
                        record.close()
        if not count:
            print('Did not find and list file with prefix %s'%args.prefix)
Functions
def image_encode(args, i, item, q_out)
-
Reads, preprocesses and packs the image, then puts it back in the output queue. Parameters
args
:object
i
:int
item
:list
q_out
:queue
Expand source code
def image_encode(args, i, item, q_out):
    """Reads, preprocesses and packs one image, then puts the result on ``q_out``.

    The tuple put on the queue is ``(i, packed_record, item)``; the payload is
    ``None`` whenever the image could not be read or packed, so the consumer
    still sees an entry for position ``i``.

    Parameters
    ----------
    args: object
        Parsed options; reads ``root``, ``pack_label``, ``pass_through``,
        ``color``, ``center_crop``, ``resize``, ``quality`` and ``encoding``.
    i: int
        Position of the item in the input list.
    item: list
        ``[index, relative_path, label, ...]`` as produced by ``read_list``.
    q_out: queue
        Output queue receiving the result tuple.
    """
    fullpath = os.path.join(args.root, item[1])

    # Pack every label only when asked to; otherwise just the first one.
    if len(item) > 3 and args.pack_label:
        header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
    else:
        header = mx.recordio.IRHeader(0, item[2], item[0], 0)

    if args.pass_through:
        # Store the file bytes untouched: no decode, crop, resize or re-encode.
        try:
            with open(fullpath, 'rb') as fin:
                img = fin.read()
            s = mx.recordio.pack(header, img)
            q_out.put((i, s, item))
        except Exception as e:
            traceback.print_exc()
            print('pack_img error:', item[1], e)
            q_out.put((i, None, item))
        return

    try:
        img = cv2.imread(fullpath, args.color)
    except Exception:
        # Deliberately not a bare ``except``: KeyboardInterrupt / SystemExit
        # must still be able to stop the process.
        traceback.print_exc()
        print('imread error trying to load file: %s ' % fullpath)
        q_out.put((i, None, item))
        return
    if img is None:
        print('imread read blank (None) image for file: %s' % fullpath)
        q_out.put((i, None, item))
        return

    if args.center_crop:
        # Trim the longer axis symmetrically so both axes have equal length.
        if img.shape[0] > img.shape[1]:
            margin = (img.shape[0] - img.shape[1]) // 2
            img = img[margin:margin + img.shape[1], :]
        else:
            margin = (img.shape[1] - img.shape[0]) // 2
            img = img[:, margin:margin + img.shape[0]]

    if args.resize:
        # Scale so the shorter edge equals args.resize, keeping aspect ratio.
        # The tuple is ordered (shape[1]-axis size, shape[0]-axis size).
        if img.shape[0] > img.shape[1]:
            newsize = (args.resize, img.shape[0] * args.resize // img.shape[1])
        else:
            newsize = (img.shape[1] * args.resize // img.shape[0], args.resize)
        img = cv2.resize(img, newsize)

    try:
        s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding)
        q_out.put((i, s, item))
    except Exception as e:
        traceback.print_exc()
        print('pack_img error on file: %s' % fullpath, e)
        q_out.put((i, None, item))
        return
def list_image(root, recursive, exts)
-
Traverses the root of directory that contains images and generates image list iterator. Parameters
root
:string
recursive
:bool
exts
:list of string
Returns
image iterator that contains all the image under the specified path
Expand source code
def list_image(root, recursive, exts):
    """Scan ``root`` for image files and lazily yield one record per image.

    Parameters
    ----------
    root: string
        Directory to scan.
    recursive: bool
        When true, descend into sub-directories (following symlinks) and give
        every directory that holds at least one accepted image its own integer
        label, numbered in discovery order. When false, only ``root`` itself
        is scanned and every image gets label 0.
    exts: collection of string
        Accepted file suffixes, including the leading dot. The suffix of each
        file is lower-cased before it is looked up here.

    Yields
    ------
    tuple
        ``(index, path_relative_to_root, label)`` with ``index`` counting up
        from 0 in traversal order.
    """
    def _wanted(full_path, name):
        # Only regular files whose lower-cased suffix is whitelisted qualify.
        return os.path.isfile(full_path) and os.path.splitext(name)[1].lower() in exts

    index = 0
    if not recursive:
        for name in sorted(os.listdir(root)):
            candidate = os.path.join(root, name)
            if _wanted(candidate, name):
                yield (index, os.path.relpath(candidate, root), 0)
                index += 1
        return

    labels = {}
    for folder, subdirs, names in os.walk(root, followlinks=True):
        # Sorting in place fixes both the walk order and the label numbering.
        subdirs.sort()
        names.sort()
        for name in names:
            candidate = os.path.join(folder, name)
            if not _wanted(candidate, name):
                continue
            label = labels.setdefault(folder, len(labels))
            yield (index, os.path.relpath(candidate, root), label)
            index += 1
    # After the walk is exhausted, report the folder -> label mapping.
    for folder, label in sorted(labels.items(), key=lambda pair: pair[1]):
        print(os.path.relpath(folder, root), label)
def make_list(args)
-
Generates .lst file. Parameters
args
:object that contains all the arguments
Expand source code
def make_list(args):
    """Generates the .lst file(s) for the images found under ``args.root``.

    The image list is optionally shuffled (fixed seed 100, so the result is
    reproducible), cut into ``args.chunks`` consecutive chunks, and each chunk
    is written either as a single ``<prefix>[_<i>].lst`` file (when
    ``args.train_ratio`` is exactly 1.0) or split into ``_test.lst``,
    ``_train.lst`` and ``_val.lst`` files, in that slice order.

    Parameters
    ----------
    args: object that contains all the arguments
        Reads ``root``, ``recursive``, ``exts``, ``shuffle``, ``chunks``,
        ``train_ratio``, ``test_ratio`` and ``prefix``.
    """
    image_list = list(list_image(args.root, args.recursive, args.exts))
    if args.shuffle is True:
        random.seed(100)
        random.shuffle(image_list)
    total = len(image_list)
    # Ceiling division so that every image lands in some chunk.
    chunk_size = (total + args.chunks - 1) // args.chunks
    for i in range(args.chunks):
        chunk = image_list[i * chunk_size:(i + 1) * chunk_size]
        str_chunk = '_%d' % i if args.chunks > 1 else ''
        if args.train_ratio == 1.0:
            write_list(args.prefix + str_chunk + '.lst', chunk)
            continue
        # Size the splits from the images actually in this chunk: the last
        # chunk can be shorter than chunk_size, and using the nominal size
        # there would let train swallow the images meant for validation.
        sep = int(len(chunk) * args.train_ratio)
        sep_test = int(len(chunk) * args.test_ratio)
        if args.test_ratio:
            write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test])
        if args.train_ratio + args.test_ratio < 1.0:
            write_list(args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep:])
        write_list(args.prefix + str_chunk + '_train.lst', chunk[sep_test:sep_test + sep])
def parse_args()
-
Defines all arguments. Returns
args object that contains all the params
Expand source code
def parse_args(argv=None):
    """Defines and parses all command-line arguments.

    Help texts are built from adjacent string literals rather than backslash
    continuations inside a literal, so their wording does not depend on how
    the source happens to be indented.

    Parameters
    ----------
    argv: list of string, optional
        Arguments to parse. ``None`` (the default) lets argparse read
        ``sys.argv[1:]``.

    Returns
    -------
    args object that contains all the params; ``prefix`` and ``root`` are
    converted to absolute paths.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Create an image list or '
                    'make a record database by reading from an image list')
    parser.add_argument('prefix', help='prefix of input/output lst and rec files.')
    parser.add_argument('root', help='path to folder containing images.')

    cgroup = parser.add_argument_group('Options for creating image lists')
    cgroup.add_argument('--list', action='store_true',
                        help='If this is set im2rec will create image list(s) by traversing root folder '
                             'and output to <prefix>.lst. '
                             'Otherwise im2rec will read <prefix>.lst and create a database at <prefix>.rec')
    cgroup.add_argument('--exts', nargs='+', default=['.jpeg', '.jpg', '.png'],
                        help='list of acceptable image extensions.')
    cgroup.add_argument('--chunks', type=int, default=1, help='number of chunks.')
    cgroup.add_argument('--train-ratio', type=float, default=1.0,
                        help='Ratio of images to use for training.')
    cgroup.add_argument('--test-ratio', type=float, default=0,
                        help='Ratio of images to use for testing.')
    cgroup.add_argument('--recursive', action='store_true',
                        help='If true recursively walk through subdirs and assign an unique label '
                             'to images in each folder. Otherwise only include images in the root folder '
                             'and give them label 0.')
    # store_false on dest='shuffle' means shuffling is on unless this is passed.
    cgroup.add_argument('--no-shuffle', dest='shuffle', action='store_false',
                        help='If this is passed, '
                             'im2rec will not randomize the image order in <prefix>.lst')

    rgroup = parser.add_argument_group('Options for creating database')
    rgroup.add_argument('--pass-through', action='store_true',
                        help='whether to skip transformation and save image as is')
    rgroup.add_argument('--resize', type=int, default=0,
                        help='resize the shorter edge of image to the newsize, original images will '
                             'be packed by default.')
    rgroup.add_argument('--center-crop', action='store_true',
                        help='specify whether to crop the center image to make it rectangular.')
    rgroup.add_argument('--quality', type=int, default=95,
                        help='JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9')
    rgroup.add_argument('--num-thread', type=int, default=1,
                        help='number of thread to use for encoding. order of images will be different '
                             'from the input list if >1. the input list will be modified to match the '
                             'resulting order.')
    rgroup.add_argument('--color', type=int, default=1, choices=[-1, 0, 1],
                        help='specify the color mode of the loaded image. '
                             '1: Loads a color image. Any transparency of image will be neglected. '
                             'It is the default flag. '
                             '0: Loads image in grayscale mode. '
                             '-1:Loads image as such including alpha channel.')
    rgroup.add_argument('--encoding', type=str, default='.jpg', choices=['.jpg', '.png'],
                        help='specify the encoding of the images.')
    rgroup.add_argument('--pack-label', action='store_true',
                        help='Whether to also pack multi dimensional label in the record file')

    args = parser.parse_args(argv)
    args.prefix = os.path.abspath(args.prefix)
    args.root = os.path.abspath(args.root)
    return args
def read_list(path_in)
-
Reads the .lst file and generates corresponding iterator. Parameters
path_in
:string
Returns
item iterator that contains information in .lst file
Expand source code
def read_list(path_in):
    """Parse a .lst file lazily, yielding one item per well-formed line.

    Lines are split on tabs and every field is stripped. Lines with fewer than
    three fields, or whose index/labels cannot be converted to numbers, are
    reported on stdout and skipped.

    Parameters
    ----------
    path_in: string
        Path of the .lst file to read.

    Yields
    ------
    list
        ``[int_index, path, float_label, ...]`` -- note that the path, which is
        the last field on disk, is moved to position 1.
    """
    with open(path_in) as fin:
        for raw in fin:
            fields = [part.strip() for part in raw.strip().split('\t')]
            n_fields = len(fields)
            # An index, at least one label and a path are required.
            if n_fields < 3:
                print('lst should have at least has three parts, but only has %s parts for %s' % (n_fields, fields))
                continue
            try:
                index = int(fields[0])
                labels = [float(token) for token in fields[1:-1]]
            except Exception as e:
                print('Parsing lst met error for %s, detail: %s' % (fields, e))
                continue
            yield [index, fields[-1]] + labels
def read_worker(args, q_in, q_out)
-
Function that will be spawned to fetch the image from the input queue and put it back to output queue. Parameters
args
:object
q_in
:queue
q_out
:queue
Expand source code
def read_worker(args, q_in, q_out):
    """Worker loop: encode every job taken from ``q_in`` until ``None`` arrives.

    Each job is an ``(i, item)`` pair that is handed to ``image_encode``, which
    puts its result on ``q_out``. A ``None`` job is the shutdown signal.

    Parameters
    ----------
    args: object
        Parsed options, forwarded unchanged to ``image_encode``.
    q_in: queue
        Source of ``(i, item)`` jobs, terminated by ``None``.
    q_out: queue
        Destination queue, forwarded to ``image_encode``.
    """
    job = q_in.get()
    while job is not None:
        i, item = job
        image_encode(args, i, item, q_out)
        job = q_in.get()
def write_list(path_out, image_list)
-
Helper function to write image list into the file. The format is as below, integer_image_index float_label_index path_to_image Note that the blank between number and tab is only used for readability. Parameters
path_out
:string
image_list
:list
Expand source code
def write_list(path_out, image_list):
    """Helper that serialises image records into a tab-separated .lst file.

    Every record becomes one line laid out as
        integer_image_index <tab> float_label(s) <tab> path_to_image
    where each label is written with ``%f`` and followed by a tab.

    Parameters
    ----------
    path_out: string
        File to create (or overwrite).
    image_list: list
        Records shaped ``(index, path, label, ...)``; everything from position
        2 onwards is treated as a label.
    """
    with open(path_out, 'w') as fout:
        for record in image_list:
            index_field = '%d\t' % record[0]
            label_fields = ''.join('%f\t' % label for label in record[2:])
            path_field = '%s\n' % record[1]
            fout.write(index_field + label_fields + path_field)
def write_worker(q_out, fname, working_dir)
-
Function that will be spawned to fetch processed image from the output queue and write to the .rec file. Parameters
q_out
:queue
fname
:string
working_dir
:string
Expand source code
def write_worker(q_out, fname, working_dir):
    """Drains ``q_out`` and writes the packed images to a .rec/.idx pair.

    Results may arrive out of order, so they are parked in a buffer and
    written strictly by ascending position. Entries whose payload is ``None``
    (failed images) are skipped but still advance the position counter. A
    ``None`` entry on the queue ends the loop, and the record file is always
    closed on the way out.

    Parameters
    ----------
    q_out: queue
        Source of ``(i, packed_record_or_None, item)`` tuples, terminated by
        ``None``.
    fname: string
        Name of the .lst file; its basename without extension names the
        ``.rec`` and ``.idx`` outputs.
    working_dir: string
        Directory in which the outputs are created.
    """
    pre_time = time.time()
    count = 0
    stem = os.path.splitext(os.path.basename(fname))[0]
    record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, stem + '.idx'),
                                           os.path.join(working_dir, stem + '.rec'), 'w')
    buf = {}
    more = True
    try:
        while more:
            deq = q_out.get()
            if deq is not None:
                i, s, item = deq
                buf[i] = (s, item)
            else:
                more = False
            # Flush every buffered entry that is next in sequence.
            while count in buf:
                s, item = buf.pop(count)
                if s is not None:
                    record.write_idx(item[0], s)
                if count % 1000 == 0:
                    cur_time = time.time()
                    print('time:', cur_time - pre_time, ' count:', count)
                    pre_time = cur_time
                count += 1
    finally:
        # Close explicitly instead of relying on garbage collection at
        # process exit to release the .rec/.idx files.
        record.close()