(備忘録-python)optuna+prophetの使い方まとめ-時系列データ解析-

IT系知識

祝日効果・パラメータ調整(optuna)を含めたoptunaの使い方を備忘録的にまとめてみました。コピペして使えるようにしてありますので、使ってみてください。(後日update予定)

1.プロフェットについて

Prophet(プロフェット)は、Facebookによって開発されたオープンソースの時系列予測ツールです。時系列データ(時間の経過に伴うデータの変化)を分析し、トレンドや季節変動などを自動的に捉えて将来の値を予測することができます。

Prophetは、以下の特徴によって広く使われています:

  1. シンプルさ: ユーザーフレンドリーなインターフェースとシンプルなパラメータ設定により、初心者から専門家まで幅広いユーザーにとって扱いやすいツールです。
  2. 柔軟性: 多くのタイプの時系列データに適応できるよう設計されており、欠損値の処理や外れ値の取り扱いにも柔軟に対応します。
  3. 季節変動の考慮: 季節的なパターンや週末効果など、時系列データの周期的な変動をモデルに組み込むことができます。
  4. カスタマイズ性: ユーザーがトレンドや季節変動の要素をカスタマイズして組み込むことができるため、予測の精度向上が可能です。
  5. スケーラビリティ: 大規模な時系列データセットにも対応しており、高速な予測を行うことができます。

Prophetはビジネスや金融、気象予測、eコマースなど、さまざまな分野で需要があります。データの特性を理解し、簡単なコードで高品質な予測を得ることができるため、幅広いアプリケーションに利用されています。

2.重要な5パラメータ

他にもパラメータは存在するが一旦は下のパラメータを調整すると良いと考えている。

  • ‘changepoint_prior_scale’ :トレンドの柔軟性
  • ‘seasonality_prior_scale’ :季節成分の柔軟性
  • ‘seasonality_mode’ :モデルを加法か乗法か決める
  • ‘holidays_prior_scale’ :休日成分の柔軟性
  • ‘changepoint_range’ :変化点を検出する期間を制限

3.祝日について

以下のようなカラムになるようにデータを成形する。

  • ds:年月日
  • holiday:祝日名
  • lower_window:祝日の影響範囲(x日前から影響)
  • upper_window:祝日の影響範囲(x日後まで影響)

4.重回帰分析(.regresser)を含めた学習

備忘録的に使ったライブラリを残しておく。

Python==3.8(prophet利用のため)

pip install pandas
pip install matplotlib

pip install statsmodels
pip install scikit-learn

pip install prophet#!pip install fbprophet
pip install pystan
pip install cython
pip install japanize-matplotlib

まずはoptunaをインポートする。

from prophet import Prophet
import optuna

optunaとの連携について

#===ここを変える===
train_data_1=学習データ
test_data_1=訓練データ(複数用意し、学習期間もこの行数で決まる)

週次性、年次性があるかをデータで確認後に、あればTrueに変更する。

#===ここを変える===
yearly_flag=False
weekly_flag=True

optunaの準備をする

def objective(trial):
    params = {'changepoint_prior_scale' : 
                 trial.suggest_float('changepoint_prior_scale',
                                       0.001,0.5
                                      ),
              'seasonality_prior_scale' : 
                 trial.suggest_float('seasonality_prior_scale',
                                       0.01,10
                                      ),
              'seasonality_mode' : 
                 trial.suggest_categorical('seasonality_mode',
                                           ['additive', 'multiplicative']
                                          ),
              
              'holidays_prior_scale' : 
                 trial.suggest_float('holidays_prior_scale',
                                       0.01,10
                                      ),
              
              'changepoint_range' : 
                  trial.suggest_float('changepoint_range', 
                                                 0.8, 0.95, 
                                                 step=0.001),
 
             }
    
    model = Prophet(#holidays=prop_holiday,#祝日のDFを使う場合
                    yearly_seasonality=yearly_flag,
                    weekly_seasonality=weekly_flag,
                    **params)

  #===ここを変える===
    model.add_regressor("重回帰用の変数")



    #===ここを変える=== DATE=>データの時系列変数(年月日)
    model.fit(train_data_1.rename(columns={'DATE':'ds','目的変数':'y'})) 

  #===ここを変える===
    forecast =  model.predict(test_data_1.drop("目的変数",axis=1).rename(columns={'DATE':'ds'}))
    
  #===ここを変える===
    val_mae = mean_absolute_error(test_data_1.目的変数, forecast.yhat)
    return val_mae

optunaの実行をする。

study = optuna.create_study(direction="minimize",
                            sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(objective, n_trials=100)

optunaの結果を出力する。

print("=======ベストパラメータ========")
print(study.best_params)

optuna反映

model_o = Prophet(
    #holidays=prop_holiday,#祝日のDFを使う場合
    yearly_seasonality=yearly_flag,
    weekly_seasonality=weekly_flag,
    **study.best_params)

model_o.add_regressor("重回帰用の変数")
#===ここを変える=== DATE=>データの時系列変数(年月日)
model_o.fit(train_data_1.rename(columns={'DATE':'ds','目的変数':'y'}))

#===ここを変える=== DATE=>データの時系列変数(年月日)
test_pred_o = model_o.predict(df=test_data_1.drop("目的変数",axis=1).rename(columns={'DATE':'ds'}))
# test_pred_o.head()

コンポーネントについて

# Plot the components of the model
fig = model_o.plot_components(test_pred_o)

plt.show()

5.まとめ

完全に備忘録的にまとめてあるため、中身については省略してしまいましたが、使いながら大枠から本質までをつかめると思いますので、良ければコピペして使ってみてください。

コメント

タイトルとURLをコピーしました