pytorch-DataLoader (データイテレータ)
このブログでは、pytorchフレームワークについて説明し
DataLoader
各メソッドは例題付きで紹介されており、ちょっと複雑ですが、黙って見ていれば理解できますよ :)
<マーク
個人的なアドバイス
(方法3は複雑すぎて推奨できない)、3項の処理例では非
Method 2>Method 1>Method 3
から学ぶこともできる、データセット処理のためのメソッドです。
目次
ニューラルネットワークの学習には、一般的にforループ(または多層ループ)を使用します。反復ごとにデータのバッチがロードされ、ニューラルネットワークはそれぞれ1回ずつ順方向逆伝播され、パラメータが1回更新されます。
データのバッチをロードする処理には、torch.utils.data.DataLoaderオブジェクトの使用が必要で、DataLoaderはあるデータセットに基づく反復処理で、あるサンプリング原理に基づいてデータセットからデータのバッチを一つ取り込みます。
また、Torch.utils.data.DatasetオブジェクトをTorchで作成し、torch.utils.data.DataLoaderで使用すれば、学習に合わせてモデルにデータを提供し続けることができると言えます。
1 torch.utils.data.DataLoader
定義:データローダー。データセットとサンプラーを組み合わせ、与えられたデータセットに対するイテラブルを提供する。
そのコンストラクタのパラメータを見てみましょう。
<ブロッククオート
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None.DataLoaderは、データセット、バッチサイズ、シャッフル、サンプラーを含む。
batch_sampler=None, num_workers=0, collate_fn=None。
pin_memory=False、drop_last=False、timeout=0です。
worker_init_fn=None)
次に、そのパラメータを分解してみます。
最も重要なパラメータはデータセットで、マップ形式のデータセットと反復可能形式のデータセットの2種類を含む抽象クラスである。
データセット (データセット) - データを読み込むためのデータセット。
バッチサイズ (int, オプション) - バッチごとにロードするサンプル数 (デフォルト: 1).
シャッフル (bool, オプション) - True に設定すると、エポック毎にデータを再シャッフルします (デフォルト: False)。
サンプラー (Sampler or Iterable, オプション) - データセットからサンプルを抽出するための方法を定義します. len 指定する場合、shuffle は指定してはならない。
バッチサンプラー (Sampler or Iterable, optional) - samplerに似ていますが、一度にインデックスのバッチを返します。batch_size、shuffle、sampler、drop_lastとは排他的です。
num_workers (int, オプション) - データのロードに使用するサブプロセスの数。0 は、データがメイン・プロセスでロードされることを意味します。(デフォルト: 0)
collate_fn (呼び出し可能、オプション) - サンプルのリストをマージして、Tensorのミニバッチを形成します。マップスタイルのデータセットからバッチロードを使用するときに使用する。
ピンメモリ (bool, オプション) - Trueの場合、データローダーはTensorを返す前にCUDA pinned memoryにコピーします。カスタムタイプ、またはcollate_fnがカスタムタイプであるバッチを返す場合は、以下の例を参照してください。
ドロップラスト (bool, オプション) - Trueに設定すると、データセットのサイズがバッチサイズで割り切れない場合、最後の不完全なバッチを削除します。もしFalseで、データセットのサイズがバッチサイズで割り切れない場合は、最後のバッチが小さくなります。(デフォルト: False)
タイムアウト (数値、オプション) - 正の場合、ワーカーからバッチを収集するためのタイムアウト値です。(デフォルト: 0)
ワーカーインit_fn (callable, optional) - None でない場合、ワーカー ID ([0, num_workers - 1] の int) を入力として、各ワーカーのサブプロセスで呼び出されます (デフォルト: None)
1.1データセット
2種類のデータセットにのみ対応しています。
DataLoader
,
map-style datasets
1.1.1 地図風データセット
は、クラスが
iterable-style datasets.
と
__getitem__()
これらは、インデックスからデータサンプルへのマッピングを表す2つのコンストラクタです。
(1) __getitem__関数がインデックス索引をもとにデータをたどる役割を果たす場合
(2) __len__関数は、データセットの長さを返します。
(3) 作成したデータセットクラスでは、必要に応じてデータを加工することができます。getitem__ 関数内で呼び出すデータ処理関数を別途記述してもよいし、 __getitem__ 関数や __init__ 関数内に直接データ処理メソッドを記述してもよいですが、 __getitem__ はインデックスに基づく応答値を返す必要があり、これを dataloader に渡してその後のバッチ処理を行うことになります。
つまり、基本的に満たしている。
__len__()
def __getitem__(self, index):
return self.src[index], self.trg[index]
def __len__(self):
return len(self.src)
彼のおおよその構造を見ること。
class Dataset(object):
"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key.
Subclasses could also optionally overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key.
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
Sampler` implementations and the default options of :class:`~torch.utils.data.
... Note::
DataLoader` by default constructs a index
sampler that yields integral indices.
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
上記のコードは、pytorchのDatasetsのソースコードです。メンバメソッドである __getitem__ と __len__ は両方とも未実装であることに注意してください。もし、データの読み込みを行うためにカスタムのDatasetsクラスを実装したい場合は、これら2つのメンバーメソッドをオーバーライドするだけでよいのです。
まず 取得項目 () メソッドを用いて,データセットから学習画像(既に CV 距離)とラベルを含むデータを読み込みます.また,パラメータ index は,全データセットにおける画像とラベルのインデックスを表します.
len () メソッドはデータセットの総長(学習セットの総数)を返します.
MyDatasetsクラスの2つの簡単な実装を以下に説明します。
実装方法1(シンプルでわかりやすい方法)
ポイントは、x と label の両方を self.src と self.trg という2つの別々のリストにロードし、そのリストに対して getitem (self, index) を使って対応する要素を返す。
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
class My_dataset(Dataset):
def __init__(self):
super(). __init__().
# Use the sin function to return 10000 time series, if you do not construct your own data, you can use numpy, pandas, etc. to read your own data as x.
# The following piece of data organization can be placed either in the init method or in the getitem method
self.x = torch.randn(1000,3)
self.y = self.x.sum(axis=1)
self.src, self.trg = [], []
for i in range(1000):
self.src.append(self.x[i])
self.trg.append(self.y[i])
def __getitem__(self, index):
return self.src[index], self.trg[index]
def __len__(self):
return len(self.src)
# or return len(self.trg), src and trg are the same length
data_train = My_dataset()
data_test = My_dataset()
data_loader_train = DataLoader(data_train, batch_size=5, shuffle=False)
data_loader_test = DataLoader(data_test, batch_size=5, shuffle=False)
# How much i_batch is determined by the batch size and the length returned by def __len__(self)
# The value returned by batch_data is determined according to def __getitem__(self, index)
# For the training set: (try more prints if you're not sure what enumerate returns)
for i_batch, batch_data in enumerate(data_loader_train):
print(i_batch) # Print the batch number
print(batch_data[0]) # print the src inside the batch
print(batch_data[1]) # Print the trg of the batch
# For the test set: (the following statement will also work)
for i_batch, (src, trg) in enumerate(data_loader_test):
print(i_batch) # Print the batch number
print(src) # Print the size of src inside the batch
print(trg) # Print the size of trg inside the batch
もう一言:生成されたdata_trainは、data_train[xxx]で直接インデックスを付けるか、next(iter(data_train))でデータのストリップを取得することが可能です。
実装2(TensorDatasetの助けを借りてデータを直接データセットクラスにラップする)。
もう一つの方法は、TensorDatasetを直接使ってデータをdatasetクラスにラップし、dataloaderを使う方法である。
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
src = torch.sin(torch.range(1, 1000, 0.1))
trg = torch.cos(torch.range(1, 1000, 0.1))
data = TensorDataset(src, trg)
data_loader = DataLoader(data, batch_size=5, shuffle=False)
for i_batch, batch_data in enumerate(data_loader):
print(i_batch) # Print the batch number
print(batch_data[0].size()) # print the src inside the batch
print(batch_data[1].size()) # print the trg inside the batch
を出力します。
0
torch.Size([5])
torch.Size([5])
1
torch.Size([5])
torch.Size([5])
...
実装方法3(アドレス読み取り方式)
lfwのように、各データがフォルダに対応するようなデータセットや、データ量が多くて一度に読み出せないような場合。そして、そのようなデータセットに要求されるのは
txt file
インデックスを作成することができます。
import os
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.image as mpimg
# Generate path-label map.txt for all images This program can be modified as needed
def generate_map(root_dir):
# Get the current absolute path
current_path = os.path.abspath('.')
# os.path.dirname() to go back one path
father_path = os.path.abspath(os.path.dirname(current_path) + os.path.sep + ". ")
with open(root_dir + 'map.txt', 'w') as wfp:
for idx in range(10):
subdir = os.path.join(root_dir, '%d/' % idx)
for file_name in os.listdir(subdir):
abs_name = os.path.join(father_path, subdir, file_name)
# linux_abs_name = abs_name.replace("\\\", '/')
wfp.write('{file_dir} {label}\n'.format(file_dir=linux_abs_name, label=idx))
# Implement the MyDatasets class
class MyDatasets(Dataset):
def __init__(self, dir):
# Get the dir where the data is stored
# For example, d:/images/
self.data_dir = dir
# list for storing (image, label) tuple, storing data such as (d:/image/1.png,4)
self.image_target_list = []
# Read all pairs of tuples from dir--label's map file into image_target_list
# map.txt stores all the tuples in d:/... /image_data/1/3.jpg 1 The path should preferably be an absolute path
with open(os.path.join(dir, 'map.txt'), 'r') as fp:
content = fp.readlines()
#s.rstrip() removes the specified character at the end of the string (default is character)
# get [['d:/... /image_data/1/3.jpg', '1'], ... ,]
str_list = [s.rstrip().split() for s in content]
# Put all the images' dir--label pairs into the list, if you want to perform multiple epochs, you can copy them more than once here, and then shuffle them uniformly is better
self.image_target_list = [(x[0], int(x[1])) for x in str_list]
def __getitem__(self, index):
image_label_pair = self.image_target_list[index]
# Read image data by path and convert to image format e.g. [3,32,32]
# You can use something else instead
img = mpimg.imread(image_label_pair[0])
return img, image_label_pair[1]
def __len__(self):
return len(self.image_target_list)
if __name__ == '__main__':
# generate map.txt
# generate_map('train/')
train_loader = DataLoader(MyDatasets('train/'), batch_size=128, shuffle=True)
for step in range(20000):
for idx, (img, label) in enumerate(train_loader):
print(img.shape)
print(label.shape)
バイナリファイルなど他の形式のデータを使う場合は、ファイルのバイトを読み込んで、各画像とラベルに分割し、それを __getitem__ から返すだけで良いのですが、その場合は、ファイルのバイトを読み込んで、各画像とラベルに分割し、それを __getitem__ から返します。例えば、cifar-10 のデータであれば、__getitem__ メソッドで index で対応する位置のバイトを読み込み、label と img に変換して返すだけでよいのです。DataLoaderは、index,len,batch_size,shuffleの指定に従って、対応するバッチデータとラベルを返すことができる。
1.1.1 イテラブル形式のデータセット
Iterable 形式のデータセットは、IterableDataset のインスタンスで、 データセットを反復処理するための __iter__ メソッドをオーバーライドしなければなりません。この形式のデータセットは、特にデータのランダムな読み込みができない場合に適しており、 バッチサイズは取得するデータによって異なります。この形式は、たとえばデータベースやリモートサーバー、ライブログなどからデータを読み込む際に使用します。一般に、時系列データには使用しません。
例えば、このようなデータセットでiter(dataset)を呼び出すと、データベースやリモートサーバー、あるいはリアルタイムで生成されたログから読み込んだデータのストリームを返すことができる。
。。。。。
ここでは詳しく説明しません、あまりに複雑なので〜。
2 torchvision.datasets(トーチビジョン.データセット
このパッケージの目的は、既製のデータセットの提供を容易にすることである。
torchvision.datasets
には、以下のデータセットが含まれています。
-
MNIST
-COCO (画像アノテーションとターゲット検出用) (キャプションと検出)
-LSUNの分類
-画像フォルダ
-イメージネット-12
-CIFAR10とCIFAR100
-STL10
データセットには以下のAPIがあります。
__getitem__
と
__len__
具体的な使い方は参考文献4(torch.utils.data.DataLoaderを使用)を参照してください。
2.1 ImageFolder
これは、DatasetFolderと同様に、ダウンロードしたデータセットで、一定の要件を満たすものに適している。
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
使用方法
my_transform = transforms.Compose(
[transforms.ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
torchvision.datasets.ImageFolder(root=". /my_dataset/", transform=my_transform)
3 加工例
1.1.1節で3つのデータセット読み込み方法を説明したが、今回はCrimeデータセットによるもう一つのデータセット読み込み方法を紹介する。この方法は
DataLoader
は関係なく、実装の複雑さも平均的である。
import numpy as np
from matplotlib import pyplot as plt
import os
import torch
class CrimeDataset():
def __init__(self, device):
reader = open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data/communities.data'))
attributes = []
while True:
# Read a comma-separated dataset file
line = reader.readline().split(',')
if len(line) < 128:
break
# set the ? as -1
line = ['-1' if val == '? else val for val in line]
line = np.array(line[5:], dtype=np.float)
attributes.append(line)
reader.close()
# attributes.shape=(1994, 123)
attributes = np.stack(attributes, axis=0)
# load the name of each column; total: 128
reader = open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data/names'))
names = []
for i in range(128):
# reader.readline().split() = ['@attribute', 'county', 'numeric'] and we choose 'county'
line = reader.readline().split()[1]
# exclude the first 5 columns. thus the number of column names = 123, arroding with attributes.shape
if i >= 5:
names.append(line)
names = np.array(names)
# shuffle the attribute by axis0
attributes = attributes[np.random.permutation(range(attributes.shape[0])), :]
val_size = 500
# the last column of attributes is the labels
self.train_labels = attributes[val_size:, -1:]
self.test_labels = attributes[:val_size:, -1:]
# exclude the last column of attributes. thus attributes.shape = (1994,122)
attributes = attributes[:, :-1]
# select the column whose minimum >= 0. selected has 99 features
selected = np.argwhere(np.array([np.min(attributes[:, i]) for i in range(attributes.shape[1])]) >= 0).flatten()
self.train_features = attributes[val_size:, selected]
self.test_features = attributes[:val_size:, selected]
self.names = names[selected]
# self.train_ptr is the counter which counts the number of data records having been loaded
self.train_ptr = 0
self.test_ptr = 0
self.x_dim = self.train_features.shape[1]
# train_size = 1494; test_size = 500
self.train_size = self.train_features.shape[0]
self.test_size = self.test_features.shape[0]
self.device = device
def train_batch(self, batch_size=None):
# if batch_size is None, then each iteration outputs all the training set
if batch_size is None:
batch_size = self.train_features.shape[0]
self.train_ptr = 0
# if all data has been outputed, reset the trailoader.
if self.train_ptr + batch_size > self.train_features.shape[0]:
self.train_ptr = 0
bx, by = self.train_features[self.train_ptr:s
5つの便利な機能
5.1 データローダーの分割
時には、完全なデータセットがtorchvisionからダウンロードされることがあります。
dataloader
' 後に、さらにデータセットを分割したい。
def split(dataloader, batch_size, split=0.2):
"""Splits the given dataset into training/validation.
Args:
dataset[torch dataloader]: Dataset which has to be split
batch_size[int]: Batch size
split[float]: Indicates ratio of validation samples
Returns:
train_set[list]: Training set
val_set[list]: Validation set
"""
index = 0
length = len(dataloader)
train_set = []
val_set = []
for data, target in dataloader:
if index <= (length * split):
train_set.append([data, target])
else:
val_set.append([data, target])
index += 1
return train_set, val_set
には、より優れた分割方法が用意されています。 pytorchデータセットのセグメンテーション
参考
https://www.cnblogs.com/leokale-zz/p/11275800.html
https://www.lagou.com/lgeduarticle/74174.html
本当にありがとうございました!https://blog.csdn.net/zuiyishihefang/article/details/105985760
トーチビジョンデータセット
関連
-
[解決済み】pytorchでテンソルを平らにする方法は?
-
[解決済み] Pytorch ある割合で特定の値を持つランダムなint型テンソルを作成する方法は?例えば、25%が1で残りが0というような。
-
AttributeError NoneType オブジェクトに属性データがない。
-
PytorchがNotImplementedErrorを発生させるようです。
-
pytorchのSpeat()関数
-
pytorch学習におけるtorch.squeeze()とtorch.unsqueeze()の使用法
-
torch.stack()の使用
-
torch.catとtorch.stackの違いについて
-
PyTorchのF.cross_entropy()関数
-
AttributeError: 'Graph' オブジェクトには 'node' という属性がありません。
最新
-
nginxです。[emerg] 0.0.0.0:80 への bind() に失敗しました (98: アドレスは既に使用中です)
-
htmlページでギリシャ文字を使うには
-
ピュアhtml+cssでの要素読み込み効果
-
純粋なhtml + cssで五輪を実現するサンプルコード
-
ナビゲーションバー・ドロップダウンメニューのHTML+CSSサンプルコード
-
タイピング効果を実現するピュアhtml+css
-
htmlの選択ボックスのプレースホルダー作成に関する質問
-
html css3 伸縮しない 画像表示効果
-
トップナビゲーションバーメニュー作成用HTML+CSS
-
html+css 実装 サイバーパンク風ボタン