You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
44 lines
1.7 KiB
Python
44 lines
1.7 KiB
Python
1 year ago
|
import pandas as pd
|
||
|
import xgboost as xgb
|
||
|
from sklearn.model_selection import train_test_split
|
||
|
from sklearn.metrics import r2_score
|
||
|
import numpy as np
|
||
|
|
||
|
# Load the modelling dataset and parse the date column so it can be compared
# against the date strings used for the windows below.
df = pd.read_excel(r'./400v入模数据.xlsx')
df['stat_date'] = pd.to_datetime(df['stat_date'])

# Correlation of every numeric column with the target column.
# The frame is narrowed to numeric/bool columns first: `stat_date` is
# datetime64 at this point, and pandas >= 2.0 no longer drops non-numeric
# columns silently inside `corr()`.
print(df.select_dtypes(include=['number', 'bool']).corr()['0.4kv及以下'])

# Row selectors for the two date windows (both bounds inclusive).  Each mask is
# built once and reused for the feature frame and the target.
# NOTE(review): the training window ends 2023-09-28 and the evaluation window
# starts 2023-09-01, so the two overlap -- confirm this is intended.
train_mask = (df['stat_date'] >= '2021-01-01') & (df['stat_date'] <= '2023-09-28')
eval_mask = (df['stat_date'] >= '2023-09-01') & (df['stat_date'] <= '2023-09-30')

# Training features (everything except the target, indexed by date) and target.
X = df[train_mask].drop(columns=['0.4kv及以下']).set_index('stat_date')
y = df[train_mask]['0.4kv及以下']

# Evaluation features, built the same way as X.
x_eval = df[eval_mask].drop(columns=['0.4kv及以下']).set_index('stat_date')
print(x_eval)

# Evaluation target; 'city' is kept alongside it so rows can be filtered per
# city later on.
y_eval = df[eval_mask][['0.4kv及以下', 'city']]
|
||
|
|
||
|
|
||
|
# Hold out a random 20% of the training rows (fixed seed) for scoring.
x_train, x_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Fit the gradient-boosted regressor on the remaining rows.
model = xgb.XGBRegressor(n_estimators=250, learning_rate=0.05, max_depth=6)
model.fit(x_train, y_train)

# R^2 of the fitted model on the held-out rows.
y_pred = model.predict(x_test)
print(r2_score(y_test, y_pred))

# Predict the evaluation window and place actual and predicted values side by
# side, indexed like the evaluation features.
predict = model.predict(x_eval)
result = pd.DataFrame(
    {
        # Dropping 'city' leaves only the target column; flatten it to 1-D.
        'real': y_eval.drop(columns='city').to_numpy().ravel(),
        'pred': predict,
    },
    index=x_eval.index,
)
print(result.loc['2023-09-28':'2023-09-30'])
|
||
|
|
||
|
|
||
|
# City name -> code.  The codes are used below to look up entries keyed by the
# values of the 'city' column, so that column presumably holds these codes --
# verify against the input data.
dict2 = {'杭州':0,'湖州':1,'嘉兴':2,'金华':3,'丽水':4,'宁波':5,'衢州':6,'绍兴':7,'台州':8,'温州':9,'舟山':10}

# 'city' column value -> loss rate for that city.
dict1 = {}
for city in x_eval['city'].drop_duplicates():
    # Evaluation features and actual target values for this city only.
    eval_x = x_eval[x_eval['city'] == city]
    eval_y = y_eval[y_eval['city'] == city]['0.4kv及以下']
    pred = model.predict(eval_x)

    # (predicted - actual) over the last three rows, relative to the actual
    # total over all of this city's evaluation rows.  `.iloc` makes the
    # positional "last three rows" slice explicit.
    # NOTE(review): assumes each city's rows are in chronological order, and
    # the denominator covers the whole window rather than the same three rows
    # -- confirm both are intended.
    last_rows_gap = np.sum(pred[-3:]) - np.sum(eval_y.iloc[-3:])
    loss_rate = last_rows_gap / np.sum(eval_y)
    dict1[city] = loss_rate

# Re-key the loss rates by city name.  Raises KeyError if a city listed above
# has no rows in the evaluation data.
dict2 = {name: dict1[code] for name, code in dict2.items()}
print(dict2)
|