照貓畫虎完成上證50結(jié)構(gòu)可視化
貓?jiān)谶@里 Visualizing the stock market structure
畫虎的大致步驟
- 第一步课锌,使用tushare獲取上證50股票列表
- 第二步,使用tushare獲取每只股票的歷史交易數(shù)據(jù)
- 第三步涌穆,對(duì)數(shù)據(jù)進(jìn)行處理,使用sklearn的相關(guān)模型進(jìn)行嵌套梳码,根據(jù)算法獲得分類化的輸出
- 第四步斜筐,使用sklearn,做一次局部線性嵌入邓夕,數(shù)據(jù)降為二維
- 第五步刘莹,使用matplotlib,二維數(shù)據(jù)可視化
既然說了是照貓畫虎焚刚,原理什么的自然是無(wú)力解釋点弯,相關(guān)概念還是靠搜索引擎吧。
完整代碼
from datetime import datetime
import tushare as ts
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from sklearn import cluster, covariance, manifold
# 第一步
pool = ts.get_sz50s()
names = pool.name
pool.head(10)
# 第二步
d1 = '2014-01-01'
d2 = '2017-09-30'
variation = pd.DataFrame()
for code in pool.code:
k = ts.get_k_data(code, d1, d2).set_index('date')
var = k['close'] - k['open']
var.name = code
variation = pd.concat([variation, var], axis=1)
可能是網(wǎng)速原因矿咕,有時(shí)候拿數(shù)據(jù)很慢抢肛,為了便于數(shù)據(jù)的重復(fù)利用,可以把得到的結(jié)果暫時(shí)保存起來
variation.fillna(method='ffill', inplace=True)
# variation.fillna(0,inplace=True)
variation.to_csv('sz50.csv')
不同的缺省值處理會(huì)造成不同的分類結(jié)果碳柱,但是差異不大捡絮。如果把缺省部分全部處理成0,調(diào)試的時(shí)候會(huì)遇到 ‘the system is too ill-conditioned for this solver’莲镣。難道是數(shù)據(jù)量不夠福稳?
anyway,該死的停牌剥悟。
# 第三步
variation = pd.read_csv('sz50.csv', index_col=0)
# 缺省值處理
variation.fillna(0,inplace=True)
edge_model = covariance.GraphLassoCV()
X = variation.copy()
X /= X.std(axis=0)
edge_model.fit(X)
_, labels = cluster.affinity_propagation(edge_model.covariance_)
n_labels = labels.max()
print labels
for i in range(n_labels + 1):
print('Cluster %i: %s' % ((i + 1), ', '.join(names[labels == i])))
[ 6 6 0 12 10 6 6 0 1 6 10 2 1 1 4 3 1 10 4 5 10 10 6 0 6
6 7 10 12 8 6 6 6 6 7 6 6 6 7 10 9 10 7 6 0 11 10 12 6 12]
Cluster 1: 中國(guó)石化, 中國(guó)聯(lián)通, 中國(guó)神華, 中國(guó)石油
Cluster 2: 同方股份, 信威集團(tuán), 康美藥業(yè), 綠地控股
Cluster 3: 華夏幸福
Cluster 4: 山東黃金
Cluster 5: 貴州茅臺(tái), 伊利股份
Cluster 6: 江蘇銀行
Cluster 7: 浦發(fā)銀行, 民生銀行, 招商銀行, 保利地產(chǎn), 上汽集團(tuán), 大秦鐵路, 興業(yè)銀行, 北京銀行, 農(nóng)業(yè)銀行, 中國(guó)平安, 交通銀行, 新華保險(xiǎn), 工商銀行, 中國(guó)太保, 中國(guó)人壽, 光大銀行, 中國(guó)銀行
Cluster 8: 中國(guó)鐵建, 中國(guó)中鐵, 中國(guó)建筑, 中國(guó)交建
Cluster 9: 上海銀行
Cluster 10: 中國(guó)中車
Cluster 11: 中信證券, 北方稀土, 海通證券, 東方證券, 招商證券, 東興證券, 華泰證券, 光大證券, 方正證券
Cluster 12: 中國(guó)銀河
Cluster 13: 南方航空, 國(guó)泰君安, 中國(guó)核電, 中國(guó)重工
# 第四步
node_position_model = manifold.LocallyLinearEmbedding(
n_components=2, eigen_solver='dense', n_neighbors=6)
embedding = node_position_model.fit_transform(X.T).T
# 第五步
font = {'family': 'SimHei',
'color': 'black',
'weight': 'normal',
'size': 18,
}
plt.figure(1, facecolor='w', figsize=(10, 8))
plt.clf()
ax = plt.axes([0., 0., 1., 1.])
plt.axis('off')
partial_correlations = edge_model.precision_.copy()
d = 1 / np.sqrt(np.diag(partial_correlations))
partial_correlations *= d
partial_correlations *= d[:, np.newaxis]
non_zero = (np.abs(np.triu(partial_correlations, k=1)) > 0.02)
plt.scatter(embedding[0], embedding[1], s=100 * d ** 2, c=labels,
cmap=plt.cm.spectral)
start_idx, end_idx = np.where(non_zero)
segments = [[embedding[:, start], embedding[:, stop]]
for start, stop in zip(start_idx, end_idx)]
values = np.abs(partial_correlations[non_zero])
lc = LineCollection(segments,
zorder=0, cmap=plt.cm.hot_r,
norm=plt.Normalize(0, .7 * values.max()))
lc.set_array(values)
lc.set_linewidths(15 * values)
ax.add_collection(lc)
for index, (name, label, (x, y)) in enumerate(
zip(names, labels, embedding.T)):
dx = x - embedding[0]
dx[index] = 1
dy = y - embedding[1]
dy[index] = 1
this_dx = dx[np.argmin(np.abs(dy))]
this_dy = dy[np.argmin(np.abs(dx))]
if this_dx > 0:
horizontalalignment = 'left'
x = x + .002
else:
horizontalalignment = 'right'
x = x - .002
if this_dy > 0:
verticalalignment = 'bottom'
y = y + .002
else:
verticalalignment = 'top'
y = y - .002
plt.text(x, y, name, fontdict=font, size=10,
horizontalalignment=horizontalalignment,
verticalalignment=verticalalignment,
bbox=dict(facecolor='w',
edgecolor=plt.cm.spectral(label / float(n_labels)),
alpha=.6))
plt.xlim(embedding[0].min() - .15 * embedding[0].ptp(),
embedding[0].max() + .10 * embedding[0].ptp(),)
plt.ylim(embedding[1].min() - .03 * embedding[1].ptp(),
embedding[1].max() + .03 * embedding[1].ptp())
plt.show()
回頭再看看被分類的股票
選出兩組灵寺,看看所選時(shí)間段內(nèi)是否真的走勢(shì)相近
中國(guó)鐵建, 中國(guó)中鐵, 中國(guó)建筑, 中國(guó)交建
貴州茅臺(tái), 伊利股份
fit = [ True if name in [u'中國(guó)鐵建', u'中國(guó)中鐵', u'中國(guó)建筑', u'中國(guó)交建', u'貴州茅臺(tái)', u'伊利股份'] else False for name in pool.name]
picked = pool[fit]
print picked
code name
14 600519 貴州茅臺(tái)
18 600887 伊利股份
26 601186 中國(guó)鐵建
34 601390 中國(guó)中鐵
38 601668 中國(guó)建筑
42 601800 中國(guó)交建
close = pd.DataFrame()
for code in picked.code:
p = ts.get_k_data(code, d1, d2).set_index('date')['close']
p = p/p.mean()
p.name = code
close = pd.concat([close, p], axis=1)
close.head()
close.loc[:,['600519','600887']].plot()
plt.show()
close.loc[:,['601186', '601390', '601668', '601800']].plot()
plt.show()