main
parent
ddbf3e5d61
commit
544ac6add4
@ -1,4 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="C:\anaconda\envs\pytorch" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="pytorch_gpu" project-jdk-type="Python SDK" />
|
||||
</project>
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,16 @@
|
||||
import torch
|
||||
import pandas as pd
|
||||
from 电压等级_输出为5 import LSTM_Regression
|
||||
from 电压等级_输出为5 import create_dataset
|
||||
model = LSTM_Regression(10, 32, output_size=5, num_layers=2)
|
||||
model.load_state_dict(torch.load('dy5.pth'))
|
||||
|
||||
df_eval = pd.read_excel(r'C:\Users\user\Desktop\浙江各地市分电压日电量数据\杭州.xlsx',index_col=' stat_date ')
|
||||
df_eval.columns = df_eval.columns.map(lambda x:x.strip())
|
||||
df_eval.index = pd.to_datetime(df_eval.index)
|
||||
|
||||
x,y = create_dataset(df_eval.loc['2023-10']['10kv以下'],10)
|
||||
x = x.reshape(-1,1,10)
|
||||
print(x.shape,y.shape)
|
||||
x = torch.from_numpy(x).type(torch.float32)
|
||||
print(model(x),y)
|
Loading…
Reference in New Issue