sklearn kMeans 分類實戰,對滬深300的每日漲跌進行分類
阿新 • • 發佈:2018-12-12
# ohlc_clustering.py
import copy
import datetime
import pymysql
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# from matplotlib.finance import candlestick_ohlc
import matplotlib.dates as mdates
from matplotlib.dates import (
DateFormatter, WeekdayLocator, DayLocator, MONDAY
)
import mpl_finance as mpf
import numpy as np
import pandas as pd
import pandas_datareader.data as web
from sklearn.cluster import KMeans
def get_open_normalised_prices():
    """
    Return open-normalised daily bars for index code 399300.

    Reads Date/Open/High/Low/Close rows (``adj_close`` is aliased to
    ``Close``) from the ``tmp_stock`` table of the local MySQL ``blog``
    database, for dates on or after 2004-06-01 in ascending date order,
    then divides High, Low and Close by the same row's Open.

    Returns:
        pandas.DataFrame holding the ``H/O``, ``L/O`` and ``C/O``
        columns, one row per row returned by the query.
    """
    connect = pymysql.connect(
        host='127.0.0.1',
        db='blog',
        user='root',
        passwd='123456',
        charset='utf8',
        use_unicode=True
    )
    select_sql_300 = (
        "select date as Date,open as Open,high as High,low as Low,"
        "adj_close as Close from `tmp_stock` "
        "where code ='399300' and date >= '2004-6-01' order by date asc"
    )
    try:
        df = pd.read_sql(select_sql_300, con=connect)
    finally:
        # Release the connection even when the query raises.
        connect.close()

    df["H/O"] = df["High"] / df["Open"]
    df["L/O"] = df["Low"] / df["Open"]
    df["C/O"] = df["Close"] / df["Open"]

    # Keep only the ratio columns; the raw prices and the date are not
    # part of the feature set handed to the clustering step.
    return df.drop(columns=["Open", "High", "Low", "Close", "Date"])
def plot_candlesticks(data):
    """
    Show a candlestick chart of the bars in ``data`` on a date axis.

    Up candles are drawn red and down candles green.

    Args:
        data: DataFrame with ``Date``, ``Open``, ``High``, ``Low`` and
            ``Close`` columns. It is copied, never modified.
    """
    df = data.copy()

    # candlestick_ohlc wants Matplotlib float date numbers. Going through
    # pd.Timestamp lets this accept plain datetime.date values and strings
    # as well as pandas Timestamps, all of which a SQL date column may
    # yield depending on the driver.
    df['date_fmt'] = df['Date'].apply(
        lambda date: mdates.date2num(pd.Timestamp(date).to_pydatetime())
    )

    fig, ax = plt.subplots(figsize=(16, 4))
    fig.subplots_adjust(bottom=0.2)

    mpf.candlestick_ohlc(
        ax,
        df[['date_fmt', 'Open', 'High', 'Low', 'Close']].values,
        width=0.6,
        colorup='r', colordown='green'
    )

    # Interpret the x values as dates when formatting tick labels.
    ax.xaxis_date()
    plt.show()
def plot_cluster(data):
    """
    Show a scatter plot of Close against Date, coloured by cluster.

    Args:
        data: DataFrame with ``Date``, ``Close`` and ``Cluster`` columns,
            where ``Cluster`` holds the integer ids 0-3. It is copied,
            never modified.
    """
    df = data.copy()

    # Convert dates to Matplotlib float date numbers. pd.Timestamp accepts
    # datetime.date values and strings as well as pandas Timestamps.
    df['date_fmt'] = df['Date'].apply(
        lambda date: mdates.date2num(pd.Timestamp(date).to_pydatetime())
    )

    fig, ax = plt.subplots(figsize=(16, 4))
    fig.subplots_adjust(bottom=0.2)

    # (cluster id, colour, legend label).
    # NOTE(review): K-Means numbers its clusters arbitrarily, so these
    # names presumably match one particular fit of one particular data
    # set -- confirm against the cluster centres before trusting them.
    cluster_styles = (
        (0, 'y', "Small Rise"),
        (1, 'g', "Big Down"),
        (2, 'r', "Big Rise"),
        (3, 'b', "Small Down"),
    )
    size = 1.2
    for cluster_id, colour, legend_label in cluster_styles:
        members = df.loc[df["Cluster"] == cluster_id]
        ax.scatter(
            members['date_fmt'], members['Close'],
            s=size, c=colour, marker='o', label=legend_label
        )

    # Interpret the x values as dates when formatting tick labels.
    ax.xaxis_date()
    plt.xlabel('Date')
    plt.ylabel('Close')
    plt.legend(loc='upper right')
    plt.show()
def plot_3d_normalised_candles(data):
    """
    Show a 3D scatter chart of the open-normalised bars, coloured by
    cluster membership.

    Args:
        data: DataFrame with ``H/O``, ``L/O`` and ``C/O`` columns. When a
            ``Cluster`` column is present it supplies the point colours;
            otherwise the module-level ``labels`` array is used, as in
            the original implementation.
    """
    if "Cluster" in data.columns:
        cluster_ids = data["Cluster"]
    else:
        # Backward-compatible fallback for callers that have not attached
        # the cluster ids to the frame.
        cluster_ids = labels

    fig = plt.figure(figsize=(12, 9))
    # Constructing Axes3D(fig, ...) directly no longer attaches the axes to
    # the figure on newer Matplotlib, so create it through the figure. The
    # rectangle [0, 0, 1, 1] is the full-figure area.
    ax = fig.add_axes([0, 0, 1, 1], projection="3d", elev=21, azim=-136)
    ax.scatter(
        data["H/O"], data["L/O"], data["C/O"],
        # The np.float alias was removed from NumPy; builtin float is the
        # same dtype.
        c=cluster_ids.astype(float)
    )
    ax.set_xlabel('High/Open')
    ax.set_ylabel('Low/Open')
    ax.set_zlabel('Close/Open')
    plt.show()
def plot_cluster_ordered_candles(data):
    """
    Show the bars as candlesticks ordered by cluster id instead of by
    date, with a dashed blue vertical line at every position where the
    cluster id differs from that of the preceding bar.

    ``data`` needs ``Open``, ``High``, ``Low``, ``Close`` and ``Cluster``
    columns; it is copied and left unmodified.
    """
    fig, ax = plt.subplots(figsize=(16,4))

    # Weekday-driven tick placement with an empty tick-label format.
    ax.xaxis.set_major_locator(WeekdayLocator(MONDAY))
    ax.xaxis.set_minor_locator(DayLocator())
    ax.xaxis.set_major_formatter(DateFormatter(""))

    # Order the rows by cluster and number them 0..n-1; that running
    # number becomes the x position of each candle.
    ordered = data.copy(deep=True).sort_values(by="Cluster").reset_index()
    ordered["clust_index"] = ordered.index

    # diff() is NaN on the first row and non-zero wherever the cluster id
    # changes, so "!= 0" selects the first row plus every boundary row.
    ordered["clust_change"] = ordered["Cluster"].diff()
    boundary_rows = ordered[ordered["clust_change"] != 0]

    ohlc_columns = ["clust_index", 'Open', 'High', 'Low', 'Close']
    mpf.candlestick_ohlc(
        ax, ordered[ohlc_columns].values,
        width=0.6, colorup='#000000', colordown='#ff0000'
    )

    # One dashed blue line per selected row.
    for position in boundary_rows["clust_index"]:
        plt.axvline(position, linestyle="dashed", c="blue")

    plt.xlim(0, len(ordered))
    tick_labels = plt.gca().get_xticklabels()
    plt.setp(tick_labels, rotation=45, horizontalalignment='right')
    plt.show()
def create_follow_cluster_matrix(data, k=None):
    """
    Build and print the k x k cluster follow-on matrix.

    Entry ``[i, j]`` is the percentage of consecutive row pairs in which
    a row of cluster ``i`` is immediately followed by a row of cluster
    ``j``. The percentages are taken over the number of such pairs
    (``len(data) - 1``), so the whole matrix sums to 100 whenever there
    is at least one pair.

    Args:
        data: DataFrame with an integer-valued ``Cluster`` column, in the
            row order that defines "follows". It is not modified.
        k: Number of clusters. When omitted it is inferred as the largest
            cluster id plus one.

    Returns:
        numpy.ndarray of shape ``(k, k)`` holding the percentages. The
        same matrix is printed to stdout.
    """
    cluster_ids = data["Cluster"].to_numpy()
    if k is None:
        k = int(cluster_ids.max()) + 1 if len(cluster_ids) else 0

    # Pair every row with its successor; the last row has no successor.
    today = cluster_ids[:-1].astype(int)
    tomorrow = cluster_ids[1:].astype(int)

    clust_mat = np.zeros((k, k))
    for current, following in zip(today, tomorrow):
        clust_mat[current, following] += 1
    if len(today):
        # Turn transition counts into percentages of all transitions.
        clust_mat = clust_mat * 100.0 / len(today)

    print("Cluster Follow-on Matrix:")
    print(clust_mat)
    return clust_mat
if __name__ == "__main__":
    # Load the daily bars for index code 399300 from the local MySQL
    # `blog` database (adj_close is aliased to Close).
    connect = pymysql.connect(
        host='127.0.0.1',
        db='blog',
        user='root',
        passwd='123456',
        charset='utf8',
        use_unicode=True
    )
    select_sql_300 = (
        "select date as Date,open as Open,high as High,low as Low,"
        "adj_close as Close from `tmp_stock` "
        "where code ='399300' and date >= '2004-6-01' order by date asc"
    )
    try:
        hs300 = pd.read_sql(select_sql_300, con=connect)
    finally:
        # Release the connection even when the query raises.
        connect.close()

    # Plot the price "candles" for the whole period returned by the query.
    plot_candlesticks(hs300)

    # Carry out K-Means clustering with four clusters on the
    # three-dimensional data H/O, L/O and C/O.
    hs300_norm = get_open_normalised_prices()
    k = 4
    km = KMeans(n_clusters=k, random_state=42)
    km.fit(hs300_norm)
    labels = km.labels_

    # Attach the cluster ids to both frames. hs300 and hs300_norm come
    # from two separate runs of the same query, so this relies on both
    # returning the same rows in the same order.
    hs300_norm["Cluster"] = labels
    hs300["Cluster"] = labels

    # Plot the 3D normalised candles using H/O, L/O, C/O.
    plot_3d_normalised_candles(hs300_norm)

    # Create and output the cluster follow-on matrix.
    create_follow_cluster_matrix(hs300)

    # Plot the closing prices over time, coloured by cluster.
    plot_cluster(hs300)