t44or2’s blog

PC関係の忘備録、ときどき趣味とか

ユニバーサル関数を作る

はじめに

numpy配列を直接引数に取れるように関数を定義すると、forで回すより早くなるらしい。

例えばnumpyデフォルトのnumpy.expnumpy.logなんかがそうである。こいつらは適当なnumpy配列aに対して、numpy.exp(a)などと使うと、aの全ての要素のexponentialを吐き出してくれる。

こういう関数を自分でも作りたいときにどうするのか、というのを調べたりしたのでまとめる。

例題:ReLU(ランプ関数)を実装してみる

ReLUはこういう関数である。

def ReLU(x):
  if x > 0:
    return x
  else:
    return 0

xがプラスの時はxをそのまま返し、それ以外の時はゼロにして返す。活性化関数とか呼ばれるものの一種だ。

上の定義のまま、numpy配列であるaを引数に打ち込んでしまうと、きっとエラーになるはずである。これを改良して、aを引数に取れるようにしたい。

numpy.where()

そのために今回使うのが、`numpy.where()'という関数である。 使い方は、

aというnumpy配列に対して、
    numpy.where("aについての条件式,"真の時の値","偽の時の値")

という感じだ。 ちなみに、条件式のみを引数にすると、インデックスを返す。

これを使って、 新しく次のように関数を定義する。

def unv_ReLU(x):
  return numpy.where(x>0, x, 0)

これでnumpy配列に対してReLUを一発で適用できる。 きっとこれより速くてスマートな方法はあるきがするけれど、こういう風にユニバーサルに定義するだけで引き締まったコードになったのでよかった。