You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

16 lines
656 B
Python

import torch
import pandas as pd
from 电压等级_输入10_输出3 import LSTM_Regression
from 电压等级_输入10_输出3 import create_dataset
# Evaluate a trained LSTM_Regression checkpoint on one month of the "10kv以下"
# column from the 杭州 (Hangzhou) workbook, printing predictions next to targets.
#
# Settings are hoisted into module-level constants so the checkpoint, data file,
# evaluation month, target column and look-back window can be changed in one place.
MODEL_PATH = 'dy5.pth'
DATA_PATH = r'C:\Users\user\Desktop\浙江各地市分电压日电量数据\杭州.xlsx'
EVAL_MONTH = '2023-10'      # partial-date string: selects every row in that month
TARGET_COLUMN = '10kv以下'  # column name after header whitespace is stripped
LOOK_BACK = 10              # window length passed to create_dataset

# The reshape below puts LOOK_BACK in the last dimension of the model input.
# If LSTM_Regression wraps nn.LSTM, that is the feature dimension, so its first
# constructor argument (presumably input_size — TODO confirm) must equal LOOK_BACK.
# NOTE(review): output_size=5 matches the 'dy5' checkpoint name, but the model
# module is named "..._输出3" (output 3) — confirm which is intended.
model = LSTM_Regression(LOOK_BACK, 32, output_size=5, num_layers=2)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
# Inference mode: if the model contains dropout or batch-norm layers, eval()
# makes them behave deterministically instead of as during training.
model.eval()

# The index label carries surrounding spaces (presumably the raw Excel header is
# padded); the other column names are stripped right after loading.
df_eval = pd.read_excel(DATA_PATH, index_col=' stat_date ')
df_eval.columns = df_eval.columns.map(lambda name: name.strip())
df_eval.index = pd.to_datetime(df_eval.index)

# Take the target column first, then the month via partial-string indexing on
# the DatetimeIndex (avoids DataFrame chained indexing).
x, y = create_dataset(df_eval[TARGET_COLUMN].loc[EVAL_MONTH], LOOK_BACK)
x = x.reshape(-1, 1, LOOK_BACK)
print(x.shape, y.shape)
x = torch.from_numpy(x).float()

# Evaluation needs no gradients; no_grad skips building the autograd graph.
with torch.no_grad():
    pred = model(x)
print(pred, y)