[解決済み] PyTorchで行列の積を計算する方法
質問
numpyでは、以下のような簡単な行列の乗算を行うことができます。
a = numpy.arange(2*3).reshape(3,2)
b = numpy.arange(2).reshape(2,1)
print(a)
print(b)
print(a.dot(b))
しかし、PyTorchのTensorsで試すと、これがうまくいきません。
a = torch.Tensor([[1, 2, 3], [1, 2, 3]]).view(-1, 2)
b = torch.Tensor([[2, 1]]).view(2, -1)
print(a)
print(a.size())
print(b)
print(b.size())
print(torch.dot(a, b))
このコードでは以下のようなエラーが発生します。
RuntimeError: inconsistent tensor size at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:503 において発生しました。
PyTorchで行列の乗算を行うにはどうしたらいいか、何かアイデアはありませんか?
どのように解決するのですか?
あなたが探しているのは
torch.mm(a,b)
なお
torch.dot()
とは動作が異なります。
np.dot()
. 何が望ましいかについての議論もありました。
ここで
. 具体的には
torch.dot()
の両方を扱います。
a
と
b
を1次元ベクトルとして扱い、それらの内積を計算します。この動作によって
a
は長さ6のベクトルであり
b
は長さ2のベクトルなので、内積は計算できません。PyTorchで行列の乗算を行うには、以下のようにします。
torch.mm()
. Numpyの
np.dot()
はより柔軟で、1次元配列では内積を計算し、2次元配列では行列の乗算を実行します。
人気のある要望により、関数
torch.matmul
は行列の乗算を行いますが,引数が両方とも
2D
であればその内積を計算し、 両方の引数が
1D
. このような次元の入力では、その動作は次のものと同じです。
np.dot
. また、ブロードキャストや
matrix x matrix
,
matrix x vector
と
vector x vector
の操作を一括して行うことができます。詳しくは、その
ドキュメント
.
# 1D inputs, same as torch.dot
a = torch.rand(n)
b = torch.rand(n)
torch.matmul(a, b) # torch.Size([])
# 2D inputs, same as torch.mm
a = torch.rand(m, k)
b = torch.rand(k, j)
torch.matmul(a, b) # torch.Size([m, j])
関連
-
[解決済み] プログラムの実行やシステムコマンドの呼び出しはどのように行うのですか?
-
[解決済み] リストのリストからフラットなリストを作るには?
-
[解決済み] Pythonで現在時刻を取得する方法
-
[解決済み] 辞書を値で並べ替えるにはどうしたらいいですか?
-
[解決済み] リストが空かどうかを確認するにはどうすればよいですか?
-
[解決済み】ネストされたディレクトリを安全に作成するには?
-
[解決済み】2つの辞書を1つの式でマージする(辞書の和をとる)には?)
-
[解決済み] Django のテストデータベースをメモリ上だけで動作させるには?
-
[解決済み] 古いバージョンのPythonにおける辞書のキーの並び順
-
[解決済み] Alembicアップグレードスクリプトでインサートやアップデートを実行するにはどうすればよいですか?
最新
-
nginxです。[emerg] 0.0.0.0:80 への bind() に失敗しました (98: アドレスは既に使用中です)
-
htmlページでギリシャ文字を使うには
-
ピュアhtml+cssでの要素読み込み効果
-
純粋なhtml + cssで五輪を実現するサンプルコード
-
ナビゲーションバー・ドロップダウンメニューのHTML+CSSサンプルコード
-
タイピング効果を実現するピュアhtml+css
-
htmlの選択ボックスのプレースホルダー作成に関する質問
-
html css3 伸縮しない 画像表示効果
-
トップナビゲーションバーメニュー作成用HTML+CSS
-
html+css 実装 サイバーパンク風ボタン
おすすめ
-
[解決済み] DataFrameの文字列、dtypeがobjectの場合
-
[解決済み] バブルソートの宿題
-
[解決済み] 小数点以下1桁を取得する[重複]。
-
[解決済み] Pythonのインスタンス変数とクラス変数
-
[解決済み] Django Rest Framework ファイルアップロード
-
[解決済み] python-requests モジュールからのすべてのリクエストをログに記録します。
-
[解決済み] スペースがないテキストを単語のリストに分割する方法
-
[解決済み] CSVデータを処理する際、1行目のデータを無視する方法を教えてください。
-
[解決済み] Pandasのデータフレーム内の文字列を'date'データ型に変換するにはどうしたらいいですか?
-
[解決済み] データクラスとtyping.NamedTupleの主な使用例