サムネがコーヒーの記事は書きかけです。

ベイズ情報量規準(BIC)による多項式回帰の最適化【機械学習】

多項式回帰

以下で与えられる多項式回帰は、フィッティングの際に多項式オーダーを上げすぎるとオーバーフィットしてしまうという側面があります。

$$\hat{y} = \beta_0 + \sum_{i=1}^n \beta_i x^i,\:\exists \beta_i \in \mathbb{R}$$

そのため、ベイズ情報量規準を用いて多項式オーダーを最適化する必要が出てきます。

ベイズ情報量規準

BICは以下のように定義されます。

$BIC =n\ln (\hat{\sigma}^2)+k\ln (n) $

この時、$n\ln (\hat{\sigma}^2)$は回帰曲線と実測値のズレの部分を評価します。

また、$k\ln (n) $は多項式オーダーによる複雑さを評価します。

よって、両方の合計が小さくなる方向、つまり$ BIC_{min}$を求めることで、多項式回帰の妥協点となる多項式オーダーを探索します。

多項式回帰の実装

Numpyを使用して、多項式回帰を行なってみます。

初めにターゲットとなる多項式(k=5)を設定しておきます。

import matplotlib.pyplot as plt 
import numpy as np 
import random 

fig = plt.figure()

def pol(i:int) -> float:
    return 0.01*i**5-0.05*i**4+0.03*i**3+0.004*i**2-i

x = [i for i in range(-100,100)]
y = [pol(i)+random.randint(-4300**2,5990**2) for i in x]
y_h = [pol(i) for i in x]

この時、

$$\begin{align}\hat{y}&=\beta_0 +\sum_{i=1}^9\beta_i x^i\\&=\boldsymbol{\beta}^\mathrm{T}\mathbf{X}\end{align}$$

となる$\boldsymbol{\beta}$を探します。

初期状態では、散布図は以下のようになります。

k=0からk=9までの多項式回帰を実装すると、以下のようになります。

def pol_k(x,l):
    return l[0]*x**9+l[1]*x**8+l[2]*x**7+l[3]*x**6+l[4]*x**5+l[5]*x**4+l[6]*x**3+l[7]*x**2+l[8]*x**1+l[9]*x**0+l[-1]

for i in range(10):
    z_i = [0 for i in range(9-i)]+[i for i in np.polyfit(x,y,i)]
    plt.scatter(x,[pol_k(i,z_i) for i in x],s = 3,label = f"k={i}")
>>>
[      0.          0.          0.          0.          0.          0.
       0.          0.          0.    5046301.094]
[      0.               0.               0.               0.
       0.               0.               0.               0.
  397288.04353589 5244945.11576795]
[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00 -8.47842117e+02
  3.96440201e+05  8.07080289e+06]
[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  1.12512322e+02 -6.79073634e+02
 -2.78509965e+05  7.73329968e+06]
[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00 -7.72556166e-03  1.12496871e+02 -6.12873296e+02
 -2.78443757e+05  7.66711369e+06]
[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  6.33179017e-03  8.10391376e-03  4.21717878e+01 -7.18376750e+02
 -1.27797390e+05  7.74245445e+06]
[ 0.00000000e+00  0.00000000e+00  0.00000000e+00 -1.09101466e-04
  6.00448577e-03  1.49505771e+00  4.51462409e+01 -5.67257602e+03
 -1.32753076e+05  1.01006531e+07]
[ 0.00000000e+00  0.00000000e+00 -4.23700835e-07 -1.10584419e-04
  1.28442069e-02  1.51216072e+00  1.40757749e+01 -5.71919883e+03
 -9.82493290e+04  1.01179127e+07]
[ 0.00000000e+00 -1.14262374e-08 -4.69405785e-07  1.02524050e-04
  1.34836922e-02  2.83652065e-01  1.16176917e+01 -3.48714803e+03
 -9.60160490e+04  9.49830179e+06]
[ 7.54614131e-10 -8.03047385e-09 -1.64329570e-05  4.66357741e-05
  1.25120506e-01  5.62883829e-01 -2.74372830e+02 -3.91641307e+03
  9.88149203e+04  9.59578883e+06]

出力結果

それぞれのkに対するBICを求めるには、$\hat{\sigma}^2$を求める必要があるので、以下のようにコードを書き換えます。

for i in range(10):
    z_i = [0 for i in range(9-i)]+[i for i in np.polyfit(x,y,i)]
    y_pk =  [pol_k(i,z_i) for i in x]
    sigma_2 = sum([(y_pk[i]-y_h[i])**2 for i in range(len(x))])/len(x)
    plt.scatter(i,len(x)*np.log(sigma_2)+i*np.log(len(x)))

出力結果

元の多項式オーダーであるk=5が妥協点であることがわかります。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です