[解決済み] カスタムデータセットをトレーニングデータとテストデータに分割するにはどうしたらいいですか?
2023-05-26 12:49:54
質問
import pandas as pd
import numpy as np
import cv2
from torch.utils.data.dataset import Dataset
class CustomDatasetFromCSV(Dataset):
def __init__(self, csv_path, transform=None):
self.data = pd.read_csv(csv_path)
self.labels = pd.get_dummies(self.data['emotion']).as_matrix()
self.height = 48
self.width = 48
self.transform = transform
def __getitem__(self, index):
pixels = self.data['pixels'].tolist()
faces = []
for pixel_sequence in pixels:
face = [int(pixel) for pixel in pixel_sequence.split(' ')]
# print(np.asarray(face).shape)
face = np.asarray(face).reshape(self.width, self.height)
face = cv2.resize(face.astype('uint8'), (self.width, self.height))
faces.append(face.astype('float32'))
faces = np.asarray(faces)
faces = np.expand_dims(faces, -1)
return faces, self.labels
def __len__(self):
return len(self.data)
これは、他のリポジトリからの参照を利用することでなんとかできたことです。 しかし、このデータセットをtrainとtestに分けたいのです。
このクラスの中でどのようにそれを行うことができますか?それとも、そのために別のクラスを作成する必要がありますか?
どのように解決するのですか?
Pytorchの
SubsetRandomSampler
:
import torch
import numpy as np
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
class CustomDatasetFromCSV(Dataset):
def __init__(self, csv_path, transform=None):
self.data = pd.read_csv(csv_path)
self.labels = pd.get_dummies(self.data['emotion']).as_matrix()
self.height = 48
self.width = 48
self.transform = transform
def __getitem__(self, index):
# This method should return only 1 sample and label
# (according to "index"), not the whole dataset
# So probably something like this for you:
pixel_sequence = self.data['pixels'][index]
face = [int(pixel) for pixel in pixel_sequence.split(' ')]
face = np.asarray(face).reshape(self.width, self.height)
face = cv2.resize(face.astype('uint8'), (self.width, self.height))
label = self.labels[index]
return face, label
def __len__(self):
return len(self.labels)
dataset = CustomDatasetFromCSV(my_path)
batch_size = 16
validation_split = .2
shuffle_dataset = True
random_seed= 42
# Creating data indices for training and validation splits:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset :
np.random.seed(random_seed)
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
sampler=valid_sampler)
# Usage Example:
num_epochs = 10
for epoch in range(num_epochs):
# Train:
for batch_index, (faces, labels) in enumerate(train_loader):
# ...
関連
-
[解決済み] 関数デコレータを作成し、それらを連鎖させるには?
-
[解決済み] 割り当て後にリストが予期せず変更されました。その理由と防止策を教えてください。
-
[解決済み] リストを均等な大きさの塊に分割するには?
-
[解決済み] どうすれば、文字列中のリテラルな中抜き文字を印刷し、また.formatを使用することができるのでしょうか?
-
[解決済み] 2つのリストを辞書に変換するにはどうしたらいいですか?
-
[解決済み] テキストファイルを文字列変数に読み込んで、改行を除去するには?
-
[解決済み] 文字列を複数の単語境界のデリミタで単語に分割する
-
[解決済み] Pythonで0xを使わずにhex()を使うには?
-
[解決済み] Pythonで、ウェブサイトが404か200かを確認するためにurllibをどのように使用しますか?
-
[解決済み] pipの依存性/必要条件をリストアップする方法はありますか?
最新
-
nginxです。[emerg] 0.0.0.0:80 への bind() に失敗しました (98: アドレスは既に使用中です)
-
htmlページでギリシャ文字を使うには
-
ピュアhtml+cssでの要素読み込み効果
-
純粋なhtml + cssで五輪を実現するサンプルコード
-
ナビゲーションバー・ドロップダウンメニューのHTML+CSSサンプルコード
-
タイピング効果を実現するピュアhtml+css
-
htmlの選択ボックスのプレースホルダー作成に関する質問
-
html css3 伸縮しない 画像表示効果
-
トップナビゲーションバーメニュー作成用HTML+CSS
-
html+css 実装 サイバーパンク風ボタン
おすすめ
-
[解決済み] PythonでのAWS Lambdaのインポートモジュールエラー
-
[解決済み] Pythonのマルチプロセッシングプールimap_unorderedの呼び出しの進捗を表示しますか?
-
[解決済み] PythonでファイルのMD5チェックサムを計算するには?重複
-
[解決済み] SQLAlchemy: 日付フィールドをフィルタリングする方法は?
-
[解決済み] なぜ(0-6)は-6=偽なのか?重複
-
[解決済み] スペースがないテキストを単語のリストに分割する方法
-
[解決済み] pandasのタイムゾーンに対応したDateTimeIndexを、特定のタイムゾーンに対応したナイーブなタイムスタンプに変換する。
-
[解決済み] あるオブジェクトが数であるかどうかを確認する、最もパイソン的な方法は何でしょうか?
-
[解決済み] Cythonのコードを含むPythonパッケージはどのように構成すればよいのでしょうか?
-
[解決済み] Pythonでリストが空かどうかをチェックする方法は?重複