最適輸送距離について調べていた時に、何をやっているのかパッとわからないノルムの取り方をしているコードがあった。
最適輸送入門
n = 100 # 点群サイズ mu = np.random.randn(n, 2) # 入力分布 1 nu = np.random.randn(n, 2) + 1 # 入力分布 2 C = np.linalg.norm(nu[np.newaxis] - mu[:,np.newaxis], axis=2) # コスト行列
mu, nuは二次元分布を表しており、Cがmuの各点とnuの各点の間の距離を表す行列になっている。
Cの計算(定義)の部分は、以下の愚直な内容と同じになっている。
C = np.zeros([n,n]) for i1, mu_ in enumerate(mu): for i2, nu_ in enumerate(nu): C[i1][i2] = np.linalg.norm(mu_-nu_)
np.linalg.normのaxisは以下で説明されている。
normのordとaxisの役割がわからない
個人的には、normのaxisの意味は「指定したaxisの自由度を潰す様にnormを取って、残りの次元は残す」という風に理解した。
2次元の行列を例に取ると、
- axis=0: 行を潰す様にノルムを取るため、残るのは列方向の横ベクトル
- axis=1: 列を潰す様にノルムを取るため、残るのは行方向の縦ベクトル(実際の出力では普通に1次元配列(横ベクトル))
また、ndarrayの使用で、要素数が1の次元はあたかもスカラーのように加減乗除できる。
したがって、np.newaxisで増やした次元部分は要素数が1しかないため、v1, v2が要素数nの1次元配列だとすると、
- v1[np.newaxis]: (1,n)の2次元配列(行列)
- v2[:, np.newaxis]: (n,1)の2次元配列(行列)
v1[np.newaxis] - v2[:, np.newaxis]
(n,n)の2次元配列となる。
v1[np.newaxis] - v2[:, np.newaxis]の各列は、v1のある1要素に対して、v2の各要素で差を取ったものが列の要素として並ぶ。
C = np.linalg.norm(nu[np.newaxis] - mu[:,np.newaxis], axis=2)に対し、
- nu[np.newaxis]: (1,n,2)の3次元配列
- mu[:, np.newaxis]: (n,1,2)の3次元配列
- nu[np.newaxis] - mu[:,np.newaxis]: (n,n,2)の3次元配列
であるため、np.linalg.norm(nu[np.newaxis] - mu[:,np.newaxis], axis=2)で要素数2の次元(x,y座標に対応するベクトル部分)でノルムを取り、その次元を潰した(n,n)の2次元配列になる。
説明が難しいが、「nu[np.newaxis]のaxis=1の要素(要素数2の1次元配列)を固定したときに、mu[:, np.newaxis]のaxis=0における(要素数2の)各1次元配列で差を取ったものが並んでいる」とでも言えば良いだろうか。
そういう訳で、一見何をしているかわからなかったが、結果は非常に単純なものが得られていることがわかった。