import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.cm as cm
import matplotlib.colors as mcl
from scipy.stats import zscore
from matplotlib.colors import TwoSlopeNorm

plt.style.use('ggplot') # 차트 격자 제공

# 엑셀 파일 경로 (이미지가 아닌 실제 .xlsx 파일 경로)
file_path = 'IPO_data_simple.xlsx'

# 엑셀 파일 읽기
df2 = pd.read_excel(file_path, sheet_name='Sheet2')
df3 = pd.read_excel(file_path, sheet_name='Sheet3')

# 중앙값 저장
median_val = df3['promise_ratio'].median()

# Normalize C2 values between 0 and 1 to map to the colormap
norm = TwoSlopeNorm(vmin=df3['promise_ratio'].min(), vcenter=median_val, vmax=df3['promise_ratio'].max())
cmap = plt.get_cmap('bwr') # 컬러바 종류 설정

# t+0부터 t+250까지의 데이터만 추출
df_filtered = df2.iloc[:250].copy()  # t+0부터 t+n까지

fig = plt.figure(figsize=(15, 8))
   

# 각 종목별로 시계열 그래프 그리기
for column in df_filtered.columns[1:]:  # 'date' 열을 제외하고 반복
   
    if column == 'mean':  # 'mean' 열에 대한 특별한 처리
        plt.plot(df_filtered['date'], df_filtered[column], label=column, alpha=0.6, color='green', lw=3)  # 'mean'을 초록선으로 표시
   
    elif column != 'mean':  # 'mean' 열은 다르게 처리하므로 여기서 제외
        # df3에서 현재 종목(column)의 평가점수('promise_ratio') 찾기
        score_row = df3[df3['code'] == column]
       
        if not score_row.empty:
            # 데이터가 존재하는 경우
            score = score_row['promise_ratio'].iloc[0]
            color = cmap(norm(score))  # Using the normalized score to get a color from the colormap
            plt.plot(df_filtered['date'], df_filtered[column], label='_nolegend_', color=color, alpha=0.6, lw=2)
           
        else:
            # 데이터가 존재하지 않는 경우
            plt.plot(df_filtered['date'], df_filtered[column], label='_nolegend_', color='black', alpha=1, lw=2)


# y축 눈금값 설정
y_values = [2, 3, 4, 4.605, 5, 6]
plt.yticks(y_values)

colormapping = cm.ScalarMappable(norm=norm, cmap=cmap) # 컬러바 생성
cbar = fig.colorbar(colormapping, ax=plt.gca(), pad=0.01) # 컬러바 삽입

plt.title('IPO price data')
plt.xlabel('date')
plt.ylabel('price index')
plt.xticks(rotation=45)
plt.gca().set_xticks(df_filtered['date'][::30])  # 2개 간격으로 x축 눈금 설정
plt.gca().set_xticklabels(df_filtered['date'][::30])  # 2개 간격으로 x축 라벨 설정
plt.axhline(y=4.605, color='r', linestyle='--') # y=4.605 인 수평 점선 설정


plt.legend()
plt.tight_layout()
plt.show()