pytorchラーニングノート(XIV)。DataLoaderのソースコード読み込み
pytorch
インターフェースのデータロード部分は、既存の深層学習フレームワークの中で間違いなく最もよく設計されており、十分な柔軟性を与えてくれます。このブログ記事では、pytorchのマルチスレッドローディングモジュール(
DataLoader
) を使ってソースコードのアノテーションを行います。
入力パイプライン
pytorch
入力パイプラインの操作順序は以下の通りです。
- Datasetオブジェクトの作成
- DataLoaderオブジェクトを作成する
- このDataLoaderオブジェクトを連続的にループさせる
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
for data in dataloader:
....
前回の記事で紹介したように、もし既存の
Dataset
をカスタマイズすることもできます。
Dataset
を継承することで
torch.utils.data.Dataset
. 継承する場合
override
の3つのメソッドがあります。
-
__init__
: データセットを初期化するために使用します -
__getitem__ __len__
この記事から、以下のことがわかります。
__getitem__
と
__len__
での
DataLoader
での使用方法
データローダー
からの
DataLoader
以下はソースコードです。便宜上、ソースコードにコメントを付けて解釈しています。
class DataLoader(object):
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False,
drop_last=False):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last
if batch_sampler is not None:
if batch_size > 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler is mutually exclusive with '
'batch_size, shuffle, sampler, and drop_last')
if sampler is not None and shuffle:
raise ValueError('sampler is mutually exclusive with shuffle')
if batch_sampler is None:
if sampler is None:
if shuffle:
# dataset.__len__() is used in the sampler.
# The purpose is to generate a sequence index (random) of length len(dataset).
sampler = RandomSampler(dataset)
else:
# dataset.__len__() is used in the Sampler.
# The purpose is to generate a sequential index of length len(dataset) (sequential).
sampler = SequentialSampler(dataset)
# Sampler is an iterator that returns one index at a time
# BatchSampler is also an iterator, but returns batch_size one index at a time
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
def __iter__(self):
return DataLoaderIter(self)
def __len__(self):
return len(self.batch_sampler)
# The following two codes are equivalent
for data in dataloader:
...
# Equivalent to
iterr = iter(dataloader)
while True:
try:
next(iterr)
except StopIteration:
break
DataLoader
で
iter(dataloader)
で
DataLoaderIter
が返されます。
next
オブジェクトを作成し、それを
Sampler
オブジェクトを作成します。
まず最初に、いくつかの
DataLoaderIter
を紹介し、次にコアとなる
RandomSampler
.
RandomSampler、SequentialSampler、BatchSampler
まず
iter(randomSampler)
は、その
next
は反復可能なオブジェクトを返し、それが
index
は、現在の
SequentialSampler
と
index
を生成している以外は同じです。
class RandomSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(torch.randperm(len(self.data_source)).long())
def __len__(self):
return len(self.data_source)
BatchSampler
は
オーダー
の
Sampler
wrapper
は、一般的な
Sampler
の
index
は、通常の
BatchSampler
は1つだけ生成されます。
batch
一方
indices
が1つ生成されます。
class BatchSampler(object):
def __init__(self, sampler, batch_size, drop_last):
# The sampler here is either a RandomSampler or a SequentialSampler
# They spit out one idx at a time
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
の
self.index_queue
.
(batch_idx, sample_indices)
データローダイーター
-
batch_idx
はint
ここでsample_indices
はlist
の値を指定します。batch
はsample indices
を構成するコードのリストを保持するものです。self.data_queue
の(batch_idx, samples)
. -
samples
はmini-batch
は、ここでself.send_idx
はself.index_queue
のサンプルです。 -
batch_id self.rcvd_idx
の意味: 今回は、この中に入れてください。batch_id self.batches_outstanding
の中にclass DataLoaderIter(object): "Iterates once over the DataLoader's dataset, as specified by the sampler" def __init__(self, loader): # loader is the DataLoader object self.dataset = loader.dataset # This is left for the last section self.collate_fn = loader.collate_fn self.batch_sampler = loader.batch_sampler # Indicates how many processes to start. self.num_workers = loader.num_workers # Whether to use pin_memory self.pin_memory = loader.pin_memory self.done_event = threading. # So that you can use next to manipulate the batch_sampler self.sample_iter = iter(self.batch_sampler) if self.num_workers > 0: # The queue used to hold the batch_idx, the elements of which are a list that holds the index of the samples in a batch self.index_queue = multiprocessing.SimpleQueue() # A queue to hold batch_data, where the elements of a batch's data are self.data_queue = multiprocessing.SimpleQueue() # The number of batches currently ready (some may be in preparation) # When 0, it means that there is no more data left in the dataset. # Initial value is 0, +1 in self._put_indices(), minus one in self.__next__ self.batches_outstanding = 0 self.shutdown = False # to record the idx of the batch to be put in index_queue this time self.send_idx = 0 # to record the idx of the batch to be retrieved from the data_queue this time self.rcvd_idx = 0 # Because of multiple threads, the batch in the data_queue may be out of order # Use this to ensure that the batch returns are idx-escalated. self.reorder_dict = {} # This is where the multiprocessing starts, with a total of num_workers # Execute _worker_loop, which is described below self.workers = [ multiprocessing.Process( target=_worker_loop, args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn)) for _ in range(self.num_workers)] for w in self.workers: w.daemon = True # ensure that the worker exits on process exit w.start() if self.pin_memory: in_data = self.data_queue self.data_queue = queue.Queue() self.pin_thread = threading. target=_pin_memory_loop, args=(in_data, self.data_queue, self.done_event)) self.pin_thread.daemon = True self.pin_thread.start() # prime the prefetch loop # prime the prefetch loop by putting 2 * num_workers individual (batch_idx, sampler_indices) into index_queue. for _ in range(2 * self.num_workers): self._put_indices() def __len__(self): return len(self.batch_sampler) def __next__(self): if self.num_workers == 0: # same-process loading indices = next(self.sample_iter) # may raise StopIteration batch = self.collate_fn([self.dataset[i] for i in indices]) if self.pin_memory: batch = pin_memory_batch(batch) return batch # check if the next sample has already been generated if self.rcvd_idx in self.reorder_dict: batch = self.reorder_dict.pop(self.rcvd_idx) return self._process_next_batch(batch) if self.batches_outstanding == 0: # means there is no more data left to work with, you can stop the worker now self._shutdown_workers() raise StopIteration while True: # The operation here is to arrange the disordered data_queue assert (not self.shutdown and self.batches_outstanding > 0) idx, batch = self.data_queue.get() # A batch is returned, batches_outstanding -1 self.batches_outstanding -= 1 if idx ! = self.rcvd_idx: # store out-of-order samples self.reorder_dict[idx] = batch continue # When returned, drop another (batch_idx, sample_indices) into indice_queue return self._process_next_batch(batch) next = __next__ # Python 2 compatibility def __iter__(self): return self def _put_indices(self): assert self.batches_outstanding < 2 * self.num_workers indices = next(self.sample_iter, None) if indices is None: return self.index_queue.put((self.send_idx, indices)) self.batches_outstanding += 1 self.send_idx += 1 def _process_next_batch(self, batch): self.rcvd_idx += 1 # put down a (batch_idx, sample_indices) self._put_indices() if isinstance(batch, ExceptionWrapper): raise batch.exc_type(batch.exc_msg) return batch def __getstate__(self): __worker_loop
-
の意味:今回は
index_queue
- 示す。
data_queue
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
global _use_shared_memory
_use_shared_memory = True
torch.set_num_threads(1)
while True:
r = index_queue.get()
if r is None:
# want to put None in data_queue
data_queue.put(None)
break
idx, batch_indices = r
try:
# Here's where you can see what dataset.__getiterm__ does.
# The data passed to collate_fn is the list of ...
samples = collate_fn([dataset[i] for i in batch_indices])
except Exception:
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else:
data_queue.put((idx, samples))
これは、マルチプロセスが実行するコードです。
__getiterm__
で、データを処理し、処理したバッチデータを
collate_fn
の中にあります。
[(img_tensor, label), ....]
collate_fn
-
私たち
batch[0]
は、(img_tensor, label)を返すことが多い。 -
そこで
(img_tensor, label)
パラメータはcollections.Sequence
. -
def default_collate(batch): "Puts each data field into a tensor with outer dimension batch size" if torch.is_tensor(batch[0]): out = None if _use_shared_memory: # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy # Count the number of all elements in the batch numel = sum([x.numel() for x in batch]) # No corresponding api found 。。。。。。 storage = batch[0].storage(). _new_shared(numel) out = batch[0].new(storage) return torch.stack(batch, 0, out=out) elif type(batch[0]). __module__ == 'numpy': elem = batch[0] if type(elem). __name__ == 'ndarray': return torch.stack([torch.from_numpy(b) for b in batch], 0) if elem.shape == (): # scalars py_type = float if elem.dtype.name.startswith('float') else int return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) elif isinstance(batch[0], int): return torch.LongTensor(batch) elif isinstance(batch[0], float): return torch.DoubleTensor(batch) elif isinstance(batch[0], string_classes): return batch elif isinstance(batch[0], collections.Mapping): return {key: default_collate([d[key] for d in batch]) for key in batch[0]} elif isinstance(batch[0], collections.Sequence): transposed = zip(*batch) return [default_collate(samples) for samples in transposed] raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}" .format(type(batch[0]))))
はimage captioning
であり、それはまたtorchtext
タイプになります。
collate_fn
について
DataLoader
タスクで、画像とテキストの両方を扱う場合、公式の pytorch オープンソースツールキットの
torchtext
を使用すると、テキストデータを非常に簡単に扱えるようになります。
data_queue
から
2*num_worker
と
batch
というように、完全に統合されています。
要約
-
__getitem__()
の最大数です。collate_fn()
でnumpy
備考
を使用している場合は
torch
または
queue.get()
の使用は
queue.put()
を使用して乱数を生成する場合、次に注意することは
なぜnumpy-random-randは異なるコアで同じ値を生成するのか?
を使用している場合は
get
を使用して乱数を生成する場合は、そのような心配は必要ありません。
キュー」の特徴
-
中にデータがない場合
queue.put()
はブロックされ、ブロックされている間、他のプロセス/スレッドがqueue.put()
操作を行うと、このスレッド/プロセスにはその旨が通知され、その後get
が成功する。 -
データが一杯になったとき :
queue.put()
がブロックされます。
関連
-
[解決済み】pytorchでテンソルを平らにする方法は?
-
[解決済み] Pytorch ある割合で特定の値を持つランダムなint型テンソルを作成する方法は?例えば、25%が1で残りが0というような。
-
Pytorch-1-TX2にpytorchをインストール(自分でやったよ)
-
AttributeError NoneType オブジェクトに属性データがない。
-
ピトーチリピートの使用方法
-
pytorchのSpeat()関数
-
pytorch学習におけるtorch.squeeze()とtorch.unsqueeze()の使用法
-
torch.stack()の公式解説、詳細、例題について
-
torch.catとtorch.stackの違いについて
-
ピトーチテンソルインデックス
最新
-
nginxです。[emerg] 0.0.0.0:80 への bind() に失敗しました (98: アドレスは既に使用中です)
-
htmlページでギリシャ文字を使うには
-
ピュアhtml+cssでの要素読み込み効果
-
純粋なhtml + cssで五輪を実現するサンプルコード
-
ナビゲーションバー・ドロップダウンメニューのHTML+CSSサンプルコード
-
タイピング効果を実現するピュアhtml+css
-
htmlの選択ボックスのプレースホルダー作成に関する質問
-
html css3 伸縮しない 画像表示効果
-
トップナビゲーションバーメニュー作成用HTML+CSS
-
html+css 実装 サイバーパンク風ボタン
おすすめ
-
[Centernet recurrence] AttributeError:Can't pickle local object 'get_dataset.<locals>.Dataset
-
PytorchがNotImplementedErrorを発生させるようです。
-
Pytorch torch.Tensor.detach()メソッドの使い方と、指定したモジュールの重みを変更する方法
-
torch.stack()の使用
-
pytorch-DataLoader (データイテレータ)
-
PyTorchのF.cross_entropy()関数
-
AttributeError: 'Graph' オブジェクトには 'node' という属性がありません。