1. 程式人生 > sklearn KMeans 分類實戰：對滬深300的每日漲跌進行分類

sklearn KMeans 分類實戰：對滬深300的每日漲跌進行分類

# ohlc_clustering.py

import copy
import datetime
import pymysql

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# from matplotlib.finance import candlestick_ohlc
import matplotlib.dates as mdates
from matplotlib.dates import (
    DateFormatter, WeekdayLocator, DayLocator, MONDAY
)
import mpl_finance as mpf
import numpy as np
import pandas as pd
import pandas_datareader.data as web
from sklearn.cluster import KMeans


# NOTE(review): the database credentials are hard-coded, exactly as in the
# original script. Move them to configuration / environment variables before
# running this anywhere other than a local sandbox.
DB_SETTINGS = dict(
    host='127.0.0.1',
    db='blog',
    user='root',
    passwd='123456',
    charset='utf8',
    use_unicode=True,
)

# Daily OHLC rows for index code 399300 (CSI 300 per the article title),
# oldest first.
SELECT_SQL_300 = (
    "select date as Date,open as Open,high as High,low as Low,"
    "adj_close as Close from `tmp_stock` "
    "where code ='399300' and date >= '2004-6-01' order by date asc"
)

# (cluster id, scatter colour, legend label).
# NOTE(review): the meaning attached to each cluster id was presumably read
# off one particular fit (k=4, random_state=42); confirm it still matches
# whenever the data or the KMeans settings change.
_CLUSTER_STYLES = (
    (0, 'y', "Small Rise"),
    (1, 'g', "Big Down"),
    (2, 'r', "Big Rise"),
    (3, 'b', "Small Down"),
)


def _load_hs300_prices():
    """Run SELECT_SQL_300 and return the result as a DataFrame.

    The connection is always closed, even when the query fails.
    """
    connect = pymysql.connect(**DB_SETTINGS)
    try:
        return pd.read_sql(SELECT_SQL_300, con=connect)
    finally:
        connect.close()


def _with_date_numbers(data):
    """Return a copy of *data* with a Matplotlib date-number column.

    The index is reset (as in the original code) and a ``date_fmt`` column
    is added holding ``Date`` converted with ``matplotlib.dates.date2num``.
    ``pd.to_datetime`` is applied first so that both ``datetime.date``
    objects and pandas timestamps are accepted.
    """
    df = data.copy()
    df.reset_index(inplace=True)
    df['date_fmt'] = mdates.date2num(pd.to_datetime(df['Date']).values)
    return df


def get_open_normalised_prices(prices=None):
    """
    Build a DataFrame of open-normalised prices, i.e. the three columns
    High/Open ("H/O"), Low/Open ("L/O") and Close/Open ("C/O").

    Parameters
    ----------
    prices : DataFrame, optional
        Frame with "Open", "High", "Low" and "Close" columns. When omitted
        the prices are loaded from the MySQL table (original behaviour).

    Returns
    -------
    DataFrame
        Only the three ratio columns, sharing the index of the price frame.
    """
    df = _load_hs300_prices() if prices is None else prices
    return pd.DataFrame(
        {
            "H/O": df["High"] / df["Open"],
            "L/O": df["Low"] / df["Open"],
            "C/O": df["Close"] / df["Open"],
        }
    )


def plot_candlesticks(data):
    """
    Plot a candlestick chart of the prices on a date axis.

    *data* needs "Date", "Open", "High", "Low" and "Close" columns.
    It is not modified.
    """
    df = _with_date_numbers(data)

    fig, ax = plt.subplots(figsize=(16, 4))
    fig.subplots_adjust(bottom=0.2)

    # Red for up days and green for down days.
    mpf.candlestick_ohlc(
        ax,
        df[['date_fmt', 'Open', 'High', 'Low', 'Close']].values,
        width=0.6,
        colorup='r',
        colordown='green',
    )
    ax.xaxis_date()
    plt.show()


def plot_cluster(data):
    """
    Scatter the closing price against the date, one colour per cluster.

    *data* needs "Date", "Close" and "Cluster" columns. Only cluster ids
    listed in ``_CLUSTER_STYLES`` (0-3) are drawn. *data* is not modified.
    """
    df = _with_date_numbers(data)

    fig, ax = plt.subplots(figsize=(16, 4))
    fig.subplots_adjust(bottom=0.2)

    size = 1.2
    for cluster_id, colour, label in _CLUSTER_STYLES:
        members = df.loc[df["Cluster"] == cluster_id]
        ax.scatter(
            members['date_fmt'], members['Close'],
            s=size, c=colour, marker='o', label=label,
        )

    ax.xaxis_date()
    plt.xlabel('Date')
    plt.ylabel('Close')
    plt.legend(loc='upper right')
    plt.show()


def plot_3d_normalised_candles(data, labels=None):
    """
    Plot a 3D scatterchart of the open-normalised bars, coloured by
    cluster membership.

    Parameters
    ----------
    data : DataFrame
        Frame with "H/O", "L/O" and "C/O" columns.
    labels : array-like, optional
        One cluster label per row. Defaults to ``data["Cluster"]`` so the
        function no longer depends on a module-level ``labels`` variable.
    """
    if labels is None:
        labels = data["Cluster"]

    fig = plt.figure(figsize=(12, 9))
    # Constructing Axes3D(fig) directly no longer attaches the axes to the
    # figure on current Matplotlib; create it through the figure instead,
    # using the full-figure rectangle that Axes3D used by default.
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], projection='3d')
    ax.view_init(elev=21, azim=-136)
    ax.scatter(
        data["H/O"], data["L/O"], data["C/O"],
        # The np.float alias was removed from NumPy; the builtin float is
        # the same type.
        c=np.asarray(labels, dtype=float),
    )
    ax.set_xlabel('High/Open')
    ax.set_ylabel('Low/Open')
    ax.set_zlabel('Close/Open')
    plt.show()


def plot_cluster_ordered_candles(data):
    """
    Plot a candlestick chart ordered by cluster membership, with a dashed
    blue line at each cluster boundary.

    *data* needs "Open", "High", "Low", "Close" and "Cluster" columns.
    It is not modified.
    """
    fig, ax = plt.subplots(figsize=(16, 4))
    # Axis locators/formatter kept from the original; the empty format
    # string blanks the major tick labels.
    ax.xaxis.set_major_locator(WeekdayLocator(MONDAY))
    ax.xaxis.set_minor_locator(DayLocator())
    ax.xaxis.set_major_formatter(DateFormatter(""))

    # Sort by cluster, then record the positions at which the cluster id
    # changes. The first row's diff is NaN, which also counts as a change,
    # so a boundary line is drawn at x = 0 as well.
    df = data.sort_values(by="Cluster")
    df.reset_index(inplace=True)
    df["clust_index"] = df.index
    df["clust_change"] = df["Cluster"].diff()
    change_indices = df[df["clust_change"] != 0]

    # Cluster-ordered "candles": black for up days, red for down days.
    mpf.candlestick_ohlc(
        ax,
        df[["clust_index", 'Open', 'High', 'Low', 'Close']].values,
        width=0.6,
        colorup='#000000',
        colordown='#ff0000',
    )

    for boundary in change_indices["clust_index"]:
        plt.axvline(boundary, linestyle="dashed", c="blue")

    plt.xlim(0, len(df))
    plt.setp(
        plt.gca().get_xticklabels(),
        rotation=45, horizontalalignment='right'
    )
    plt.show()


def create_follow_cluster_matrix(data, k=None):
    """
    Create and print a k x k matrix whose entry (i, j) is the percentage
    of days on which cluster j follows cluster i.

    Parameters
    ----------
    data : DataFrame
        Frame with a "Cluster" column, in chronological order.
        NOTE: as in the original, *data* is modified in place. The columns
        "ClusterTomorrow" and "ClusterMatrix" are added and every row
        containing a NaN (at least the last row) is dropped.
    k : int, optional
        Number of clusters. Defaults to the largest cluster id seen plus
        one, so the function no longer depends on a module-level ``k``.

    Returns
    -------
    numpy.ndarray
        The k x k follow-on matrix, in percent of the remaining rows.
    """
    data["ClusterTomorrow"] = data["Cluster"].shift(-1)
    data.dropna(inplace=True)
    data["ClusterTomorrow"] = data["ClusterTomorrow"].apply(int)
    # Write to the frame that was passed in; the original assigned to the
    # global `hs300`, which only worked when `data` was that same object.
    data["ClusterMatrix"] = list(
        zip(data["Cluster"], data["ClusterTomorrow"])
    )
    cmvc = data["ClusterMatrix"].value_counts()

    if k is None:
        if data.empty:
            k = 0
        else:
            k = int(
                max(data["Cluster"].max(), data["ClusterTomorrow"].max())
            ) + 1

    clust_mat = np.zeros((k, k))
    # Series.iteritems() was removed from pandas; items() is equivalent.
    for (today, tomorrow), count in cmvc.items():
        clust_mat[today, tomorrow] = count * 100.0 / len(data)

    print("Cluster Follow-on Matrix:")
    print(clust_mat)
    return clust_mat


def main():
    """Load the prices, cluster the daily bars and show all plots."""
    # Obtain the index pricing data from the local MySQL table.
    hs300 = _load_hs300_prices()

    # Plot the full price history as candles.
    plot_candlesticks(hs300)

    # Carry out K-Means clustering with four clusters on the
    # three-dimensional data H/O, L/O and C/O. The ratios are derived from
    # the frame already loaded instead of querying the database again.
    hs300_norm = get_open_normalised_prices(hs300)
    k = 4
    # n_init is pinned to the historical scikit-learn default so the fit,
    # and therefore the hard-coded cluster legend, does not change when
    # the library default does.
    km = KMeans(n_clusters=k, random_state=42, n_init=10)
    km.fit(hs300_norm)
    labels = km.labels_
    hs300_norm["Cluster"] = labels
    hs300["Cluster"] = labels

    # Plot the 3D normalised candles using H/O, L/O, C/O.
    plot_3d_normalised_candles(hs300_norm, labels)

    # Create and output the cluster follow-on matrix.
    create_follow_cluster_matrix(hs300, k)

    plot_cluster(hs300)


if __name__ == "__main__":
    main()

在這裡插入圖片描述

在這裡插入圖片描述