1. ホーム
  2. python

[解決済み] torch.clampの列依存境界線

2022-02-16 04:14:31

質問

2次元配列のPyTorchテンソルでnp.clipと似たようなことをしたいです。具体的には、各列を特定の値の範囲(列依存)でクリップしたいと思います。例えば、numpyでは、以下のようなことができます。

x = np.array([-1,10,3])
low = np.array([0,0,1])
high = np.array([2,5,4])
clipped_x = np.clip(x, low, high)

clipped_x == np.array([0,5,3]) # True

torch.clampを見つけたのですが、残念ながら多次元境界をサポートしていません(テンソル全体に対して1つのスカラー値のみ)。私のケースにその関数を拡張する"neat"方法はありますか?

ありがとうございます。

解決方法は?

のようにきちんとしたものではありません。 np.clip を使用することができます。 torch.max torch.min :

In [1]: x
Out[1]:
tensor([[0.9752, 0.5587, 0.0972],
        [0.9534, 0.2731, 0.6953]])

列ごとの下限と上限の設定

l = torch.tensor([[0.2, 0.3, 0.]])
u = torch.tensor([[0.8, 1., 0.65]])

なお、下限の l と上界 u は1×3のテンソル(単次元を持つ2次元)である。これらの次元は lu になります。 ブロードキャスト可能 の形状に x .
を使用してクリップすることができます。 minmax :

clipped_x = torch.max(torch.min(x, u), l)

で結果

tensor([[0.8000, 0.5587, 0.0972],
        [0.8000, 0.3000, 0.6500]])