1. ホーム
  2. pytorch

pytorchラーニングノート(XIV)。DataLoaderのソースコード読み込み

2022-02-18 02:15:27
<パス

pytorch インターフェースのデータロード部分は、既存の深層学習フレームワークの中で間違いなく最もよく設計されており、十分な柔軟性を与えてくれます。このブログ記事では、pytorchのマルチスレッドローディングモジュール( DataLoader ) を使ってソースコードのアノテーションを行います。

入力パイプライン

pytorch 入力パイプラインの操作順序は以下の通りです。

  • Datasetオブジェクトの作成
  • DataLoaderオブジェクトを作成する
  • このDataLoaderオブジェクトを連続的にループさせる
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
    for data in dataloader:
        ....

前回の記事で紹介したように、もし既存の Dataset をカスタマイズすることもできます。 Dataset を継承することで torch.utils.data.Dataset . 継承する場合 override の3つのメソッドがあります。

  • __init__ : データセットを初期化するために使用します
  • __getitem__
    __len__
    

この記事から、以下のことがわかります。 __getitem__ __len__ での DataLoader での使用方法

データローダー

からの DataLoader 以下はソースコードです。便宜上、ソースコードにコメントを付けて解釈しています。

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, 
                 batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, 
                 drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    # dataset.__len__() is used in the sampler.
                    # The purpose is to generate a sequence index (random) of length len(dataset).
                    sampler = RandomSampler(dataset)
                else:
                    # dataset.__len__() is used in the Sampler.
                    # The purpose is to generate a sequential index of length len(dataset) (sequential).
                    sampler = SequentialSampler(dataset)
            # Sampler is an iterator that returns one index at a time
            # BatchSampler is also an iterator, but returns batch_size one index at a time
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
# The following two codes are equivalent
for data in dataloader:
    ...
# Equivalent to
iterr = iter(dataloader)
while True:
    try:
        next(iterr)
    except StopIteration:
        break

DataLoader

iter(dataloader) DataLoaderIter が返されます。 next オブジェクトを作成し、それを Sampler オブジェクトを作成します。

まず最初に、いくつかの DataLoaderIter を紹介し、次にコアとなる RandomSampler .

RandomSampler、SequentialSampler、BatchSampler

まず iter(randomSampler) は、その next は反復可能なオブジェクトを返し、それが index は、現在の SequentialSampler index を生成している以外は同じです。 class RandomSampler(Sampler): def __init__(self, data_source): self.data_source = data_source def __iter__(self): return iter(torch.randperm(len(self.data_source)).long()) def __len__(self): return len(self.data_source) BatchSampler オーダー

Sampler

wrapper は、一般的な Sampler index は、通常の BatchSampler は1つだけ生成されます。 batch 一方 indices が1つ生成されます。 class BatchSampler(object): def __init__(self, sampler, batch_size, drop_last): # The sampler here is either a RandomSampler or a SequentialSampler # They spit out one idx at a time self.sampler = sampler self.batch_size = batch_size self.drop_last = drop_last def __iter__(self): batch = [] for idx in self.sampler: batch.append(idx) if len(batch) == self.batch_size: yield batch batch = [] if len(batch) > 0 and not self.drop_last: yield batch def __len__(self): if self.drop_last: return len(self.sampler) // self.batch_size else: return (len(self.sampler) + self.batch_size - 1) // self.batch_size self.index_queue .

(batch_idx, sample_indices)

データローダイーター

  1. batch_idx int ここで sample_indices list の値を指定します。 batch sample indices を構成するコードのリストを保持するものです。 self.data_queue (batch_idx, samples) .
  2. samples mini-batch は、ここで self.send_idx self.index_queue のサンプルです。
  3. batch_id self.rcvd_idx の意味: 今回は、この中に入れてください。 batch_id self.batches_outstanding の中に
    class DataLoaderIter(object):
        "Iterates once over the DataLoader's dataset, as specified by the sampler"
    
        def __init__(self, loader):
            # loader is the DataLoader object
            self.dataset = loader.dataset
            # This is left for the last section
            self.collate_fn = loader.collate_fn
            self.batch_sampler = loader.batch_sampler
            # Indicates how many processes to start.
            self.num_workers = loader.num_workers
            # Whether to use pin_memory
            self.pin_memory = loader.pin_memory
            self.done_event = threading.
    
            # So that you can use next to manipulate the batch_sampler
            self.sample_iter = iter(self.batch_sampler)
    
            if self.num_workers > 0:
                # The queue used to hold the batch_idx, the elements of which are a list that holds the index of the samples in a batch
                self.index_queue = multiprocessing.SimpleQueue()
                # A queue to hold batch_data, where the elements of a batch's data are
                self.data_queue = multiprocessing.SimpleQueue()
    
                # The number of batches currently ready (some may be in preparation)
                # When 0, it means that there is no more data left in the dataset.
                # Initial value is 0, +1 in self._put_indices(), minus one in self.__next__
                self.batches_outstanding = 0 
                self.shutdown = False
                # to record the idx of the batch to be put in index_queue this time
                self.send_idx = 0
                # to record the idx of the batch to be retrieved from the data_queue this time
                self.rcvd_idx = 0
                # Because of multiple threads, the batch in the data_queue may be out of order
                # Use this to ensure that the batch returns are idx-escalated.
                self.reorder_dict = {}
                # This is where the multiprocessing starts, with a total of num_workers
                # Execute _worker_loop, which is described below
                self.workers = [
                    multiprocessing.Process(
                        target=_worker_loop,
                        args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn))
                    for _ in range(self.num_workers)]
    
                for w in self.workers:
                    w.daemon = True # ensure that the worker exits on process exit
                    w.start()
    
                if self.pin_memory:
                    in_data = self.data_queue
                    self.data_queue = queue.Queue()
                    self.pin_thread = threading.
                        target=_pin_memory_loop,
                        args=(in_data, self.data_queue, self.done_event))
                    self.pin_thread.daemon = True
                    self.pin_thread.start()
    
                # prime the prefetch loop
                # prime the prefetch loop by putting 2 * num_workers individual (batch_idx, sampler_indices) into index_queue.
                for _ in range(2 * self.num_workers):
                    self._put_indices()
    
        def __len__(self):
            return len(self.batch_sampler)
    
        def __next__(self):
            if self.num_workers == 0: # same-process loading
                indices = next(self.sample_iter) # may raise StopIteration
                batch = self.collate_fn([self.dataset[i] for i in indices])
                if self.pin_memory:
                    batch = pin_memory_batch(batch)
                return batch
    
            # check if the next sample has already been generated
            if self.rcvd_idx in self.reorder_dict:
                batch = self.reorder_dict.pop(self.rcvd_idx)
                return self._process_next_batch(batch)
    
            if self.batches_outstanding == 0:
                # means there is no more data left to work with, you can stop the worker now
                self._shutdown_workers()
                raise StopIteration
    
            while True:
                # The operation here is to arrange the disordered data_queue
                assert (not self.shutdown and self.batches_outstanding > 0)
                idx, batch = self.data_queue.get()
                # A batch is returned, batches_outstanding -1
                self.batches_outstanding -= 1
                if idx ! = self.rcvd_idx:
                    # store out-of-order samples
                    self.reorder_dict[idx] = batch
                    continue
                # When returned, drop another (batch_idx, sample_indices) into indice_queue
                return self._process_next_batch(batch)
    
        next = __next__ # Python 2 compatibility
    
        def __iter__(self):
            return self
    
        def _put_indices(self):
            assert self.batches_outstanding < 2 * self.num_workers
            indices = next(self.sample_iter, None)
            if indices is None:
                return
            self.index_queue.put((self.send_idx, indices))
            self.batches_outstanding += 1
            self.send_idx += 1
    
        def _process_next_batch(self, batch):
            self.rcvd_idx += 1
            # put down a (batch_idx, sample_indices)
            self._put_indices()
            if isinstance(batch, ExceptionWrapper):
                raise batch.exc_type(batch.exc_msg)
            return batch
    
        def __getstate__(self):
          
    __worker_loop
    
  4. の意味:今回は
    index_queue
    
  5. 示す。
data_queue

def _worker_loop(dataset, index_queue, data_queue, collate_fn): global _use_shared_memory _use_shared_memory = True torch.set_num_threads(1) while True: r = index_queue.get() if r is None: # want to put None in data_queue data_queue.put(None) break idx, batch_indices = r try: # Here's where you can see what dataset.__getiterm__ does. # The data passed to collate_fn is the list of ... samples = collate_fn([dataset[i] for i in batch_indices]) except Exception: data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) else: data_queue.put((idx, samples))

これは、マルチプロセスが実行するコードです。 __getiterm__ で、データを処理し、処理したバッチデータを collate_fn の中にあります。

[(img_tensor, label), ....]

collate_fn

  • 私たち batch[0] は、(img_tensor, label)を返すことが多い。

  • そこで (img_tensor, label) パラメータは collections.Sequence .

  • def default_collate(batch): "Puts each data field into a tensor with outer dimension batch size" if torch.is_tensor(batch[0]): out = None if _use_shared_memory: # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy # Count the number of all elements in the batch numel = sum([x.numel() for x in batch]) # No corresponding api found 。。。。。。 storage = batch[0].storage(). _new_shared(numel) out = batch[0].new(storage) return torch.stack(batch, 0, out=out) elif type(batch[0]). __module__ == 'numpy': elem = batch[0] if type(elem). __name__ == 'ndarray': return torch.stack([torch.from_numpy(b) for b in batch], 0) if elem.shape == (): # scalars py_type = float if elem.dtype.name.startswith('float') else int return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) elif isinstance(batch[0], int): return torch.LongTensor(batch) elif isinstance(batch[0], float): return torch.DoubleTensor(batch) elif isinstance(batch[0], string_classes): return batch elif isinstance(batch[0], collections.Mapping): return {key: default_collate([d[key] for d in batch]) for key in batch[0]} elif isinstance(batch[0], collections.Sequence): transposed = zip(*batch) return [default_collate(samples) for samples in transposed] raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}" .format(type(batch[0])))) image captioning であり、それはまた torchtext タイプになります。
collate_fn

について DataLoader タスクで、画像とテキストの両方を扱う場合、公式の pytorch オープンソースツールキットの torchtext を使用すると、テキストデータを非常に簡単に扱えるようになります。 data_queue から 2*num_worker batch というように、完全に統合されています。

要約

  • __getitem__() の最大数です。 collate_fn()
    numpy
    

備考

を使用している場合は torch または queue.get() の使用は queue.put() を使用して乱数を生成する場合、次に注意することは なぜnumpy-random-randは異なるコアで同じ値を生成するのか? を使用している場合は get を使用して乱数を生成する場合は、そのような心配は必要ありません。

キュー」の特徴

  • 中にデータがない場合 queue.put() はブロックされ、ブロックされている間、他のプロセス/スレッドが queue.put() 操作を行うと、このスレッド/プロセスにはその旨が通知され、その後 get が成功する。
  • データが一杯になったとき : queue.put() がブロックされます。