株価予測(二値分類)に関する論文を読んでみた

はじめに

この記事は、翻訳記事ではありません。

私は英語が苦手です。

深層学習に関しても専門家ではありませんのでご了承ください。

論文について

こちらが対象の論文です。

https://arxiv.org/pdf/1903.12258.pdf

2019/02/16に書かれたものだそうです。

Abstract

株式市場の予測はニュースや業績、投資家の感情、ソーシャルメディアの感情など様々な要因が影響するため予測が難しい。

この論文では様々な手法とロウソク足を使用して、予測をしている。

台湾とインドネシアの株式市場のデータで試したところ、accuracyが約92%になったらしい。

1. Introduction

この研究ではいくつかの手法を比較している。

 - Convolutional Neural Network(CNN)
 - Residual Network(ResNet)
 - Visual  Geometry Group Network(VGG)
 - k-nearest neighborhood(k近傍法)
 - random forest(ランダムフォレスト)

またその他に、50×50と20×20のロウソクチャートを使用する。

この研究のゴールは、期間や画像サイズ、その他のパラメータを調整して、相関関係をみることです。

また、翌日を予測します。

2. Related Work

株の予測に関する、関連した研究が書いてあります。

3. Dataset

3-1. Data Collection

データの取得について書かれています。

台湾とインドネシアの株式市場を対象にしていて、

台湾は50社分、インドネシアでは10社分取得しています。

収集の際、YahooファイナンスAPIを利用しているようです。

3-2. Data Preprocessing

データの前処理についてです。

時系列データからMatplotlib(Pythonのライブラリ)を使用して、ロウソク足チャートに変換しています。

この論文では、学習データに以下を用いて比較しています。

 - 5日分(出来高無し)
 - 10日分(出来高無し)
 - 20日分(出来高無し)
 - 5日分(出来高有り)
 - 10日分(出来高有り)
 - 20日分(出来高有り)

4. Methodology

YahooファイナンスAPIからロウソク足チャートに変換し、画像を生成します。

それらの画像をCNNにかけ、株価が上昇or下降の2値に分類されます。

以下の画像は論文の画像を引用したものです。

f:id:a_shiba:20190430135711p:plain
http://140.138.155.216/deepcandle/ より引用

4-1. Candlestick Chart

ロウソク足チャートの説明が書かれています。

始値終値、安値、高値の4つから構成される株取引では一般的なものです。

4-2. Learning Algorithm

この論文では1. Introductionに記述した以下の5つの手法を使用しています。

※それぞれの手法の詳しい説明はしません。

 - Convolutional Neural Network(CNN)
 - Residual Network(ResNet)
 - Visual  Geometry Group Network(VGG)
 - k-nearest neighborhood(k近傍法)
 - random forest(ランダムフォレスト)
4-2-1. Convolutional Neural Network

CNNの説明が書かれています。

モデルは、4層の畳み込み、4層のプーリング、3つのドロップアウトを使用しています。

4-2-2. Residual Network

2015年に開発されたResNetです。

層をスキップするためにショートカットできるようになっています。

これにより、層が深くても、誤差逆伝搬を可能にします。

4-2-3. VGG Network

2014年に開発されたVGGです。

学習が非常に遅いことが欠点らしいです。

4-2-4. Random Forest

多くの決定木からなる分類器。

詳しい説明はこちら

mathwords.net

4-2-5. K-Nearest Neighbors

詳しい説明はこちら qiita.com

4-3. Performance Evaluation

評価の計算方法は以下の通りです。

f:id:a_shiba:20190430151847p:plain
評価方法

Sensitivity:ポジティブ率
Specificity:ネガティブ率
Acccuracy:正確度
MCC:マシューズ相関係数
※MCCは-1から1の値を取る。

5. Experimental Results and Discussion

実験結果です。

5-1. Classification for Taiwan 50 Dataset

まずは台湾の50社の結果です。

最も精度がよかったものは、20日を対象にしたデータで、かつ出来高無しのものでした。

出来高有りのときも20日を対象にしたデータが最も精度が良いです。

(詳しい結果は論文を読んでください)

5-2. Classification for Indonesia 10 Dataset

次にインドネシアの10社の結果です。

こちらも、台湾と同様に、20日を対象で、かつ出来高なしのものが最も精度が良いそうです。

5-3. Independent Testing Result

20日分の学習データで50×50の画像、が最も良い結果が得られました。

5-4. Comparison

関連する論文と精度を比較しても、良い結果を得られたということが書かれています。

6. Conclusions and Future Works

結論と今後の課題です。

CNNは長期の取引日数を用いたモデルが最も良い精度が出ることを証明しました。

以下のサイトは、このモデルを使用して株価市場を予測するものです。

140.138.155.216

ユーザーは目標日を入力するだけで、予測結果を得られます。

コンピューターに詳しくない人でも、簡単に使えるように構築されています。

感想

次の日を予測する場合、

「昨日の終値=今日の始値」でないと予測したとしても、あまり効果は得られないのではないかと考えます。

今回紹介した論文にあるモデルを使って、3日後や1週間後を予測してみるも良さそうな気がしたので、試そうと思います。

pythonとmpl_financeを使って株価チャートを表示させる方法

環境

  • ubuntu 18.04(WSLによる構築)
  • python3.7(anacondaによる構築)

はじめに

機械学習用の学習データを集める際に利用したものです。

データがcsv形式で存在すれば実行できると思います。

以下のように出力されます。

f:id:a_shiba:20190430145258p:plain
出力例

※学習データ用なので、人間が見やすいものではありません。

コード

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import csv
import mpl_finance
from pathlib import Path

CSV_PATH = "histday"
CHART_IMG_PATH = "chart"

def num2chart(code, number, opens, highs, lows, closes, judge):
# ======================================
# code: 銘柄コード
# number: 何番目のブロックを取得するか
# opens: 始値(一次元配列)
# highs: 高値(一次元配列)
# lows: 安値(一次元配列)
# closes: 終値(一次元配列)
# judge: 上がったか下がったか
# ======================================
    fig = plt.figure(figsize=(0.5, 0.5), facecolor="k", edgecolor="k")
    ax = fig.add_subplot(1, 1, 1)
    fig.patch.set_facecolor('black') 
    fig.patch.set_alpha(0)
    ax.patch.set_facecolor('black')
    ax.patch.set_alpha(1)
    mpl_finance.candlestick2_ohlc(ax, opens=opens, highs=highs, lows=lows, closes=closes,  width=1, alpha=1, colorup='r', colordown='b')
    plt.tick_params(labelbottom=False,
                    labelleft=False,
                    labelright=False,
                    labeltop=False)
    plt.tick_params(bottom=False,
                    left=False,
                    right=False,
                    top=False)
    ax = plt.gca() 
    ax.spines["right"].set_color("none")
    ax.spines["left"].set_color("none")  
    ax.spines["top"].set_color("none")   
    ax.spines["bottom"].set_color("none") 
    plt.style.context('classic')
    plt.rcParams['axes.xmargin'] = 0
    plt.rcParams['axes.ymargin'] = 0
    plt.savefig("{}/{}/{}_{}.png".format(CHART_IMG_PATH, judge, code, number))
    plt.cla()
    plt.clf()
    plt.close()


def main(block, slide_span, predict_day):

    p = Path(CSV_PATH)
    csv_list = [file for file in list(p.glob("*"))]

    start = 0
    for csv in csv_list:
        print(csv)
        df = pd.read_csv(csv)
        for i in range(start, len(df)-block-predict_day, slide_span):
            code = df['code'][0]
            opens = df['start']
            highs = df['high']
            lows = df['low']
            closes = df['end']

            # ブロックの最後の要素と最後の要素からpredicr_day日後の終値を比較
            # 1.0より上なら上昇、以下なら下降したという仕分け
            if closes[i+block-1+predict_day] / closes[i+block-1] > 1.0:
                judge = "up" 
            else:
                judge = "down"
            num2chart(code, i, opens[i:i+block], highs[i:i+block], lows[i:i+block], closes[i:i+block], judge)

        start = 0

if __name__ == "__main__":
    # 生成するチャートの期間
    block = 20
    # 何日ずらすか
    slide_span = 20
    # 何日後のデータを参照するか
    predict_day = 3
    main(block, slide_span, predict_day)

感想

機械学習で使う場合、軸や目盛線は消したほうが良いのか悪いのか分かりません。

そのあたりいつか検証してみたいところです。

また、最近ではbokehというライブラリも流行っているそうです。

こちらはグラフをweb上で動かしたりするのが簡単にできるそうです。

matplotlibのGitHub。スター数: 9115 github.com

bokehのGitHub。スター数: 9321 github.com

Pythonによるゴールデンクロスの検出

ゴールデンクロスとは

短期移動平均線が、長期移動平均線を下から上に抜けることをいいます。

詳しい説明はこちらを参考にしてください。 www.sevendata.co.jp

コード

私の書いたプログラムの一部を抜き出して少し加工したものです。

以下のコード単体での実行はしていないため、ミスをしている場合があります。

import pandas as pd

# 株価のデータ読み込み
df = pd.read_csv(株価データ)

# 25日移動平均線
mal_25d =df['close'].rolling(window=25).mean()

# 50日移動平均線
mal_50d =df['close'].rolling(window=25).mean()

# 前日時点では50日移動平均線が25日平均線の上に位置し、当日時点で25日移動平均線が上に抜ける
for i in range(0, len(df)):
    if mal_50d[i] < mal_25d[i] and mal_50d[i-1] > mal_25d[i-1]:
        is_golden_cross = True

最後に

ゴールデンクロスは一般的に買いトレンドと言われることがおおいですが、単純にこれだけではなかなか儲かりません(経験談) いろいろな手法を試して組み合わせてみたいと思います。