|
|
@ -87,14 +87,14 @@ def run(file_dir,excel):
|
|
|
|
train_loss = []
|
|
|
|
train_loss = []
|
|
|
|
loss_function = nn.MSELoss()
|
|
|
|
loss_function = nn.MSELoss()
|
|
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
|
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
|
|
|
for i in range(1500):
|
|
|
|
for i in range(2500):
|
|
|
|
out = model(train_x)
|
|
|
|
out = model(train_x)
|
|
|
|
loss = loss_function(out, train_y)
|
|
|
|
loss = loss_function(out, train_y)
|
|
|
|
loss.backward()
|
|
|
|
loss.backward()
|
|
|
|
optimizer.step()
|
|
|
|
optimizer.step()
|
|
|
|
optimizer.zero_grad()
|
|
|
|
optimizer.zero_grad()
|
|
|
|
train_loss.append(loss.item())
|
|
|
|
train_loss.append(loss.item())
|
|
|
|
# print(loss)
|
|
|
|
print(loss)
|
|
|
|
# 保存模型
|
|
|
|
# 保存模型
|
|
|
|
# torch.save(model.state_dict(),save_filename)
|
|
|
|
# torch.save(model.state_dict(),save_filename)
|
|
|
|
# torch.save(model.state_dict(),os.path.join(model_save_dir,model_file))
|
|
|
|
# torch.save(model.state_dict(),os.path.join(model_save_dir,model_file))
|
|
|
@ -121,7 +121,7 @@ def run(file_dir,excel):
|
|
|
|
# result_list = []
|
|
|
|
# result_list = []
|
|
|
|
# 以x为基础实际数据,滚动预测未来3天
|
|
|
|
# 以x为基础实际数据,滚动预测未来3天
|
|
|
|
x = torch.from_numpy(df[-14:-4]).to(device)
|
|
|
|
x = torch.from_numpy(df[-14:-4]).to(device)
|
|
|
|
pred = model(x.reshape(-1,1,DAYS_FOR_TRAIN)).view(-1).detach().numpy()
|
|
|
|
pred = model(x.reshape(-1,1,DAYS_FOR_TRAIN)).view(-1).cpu().detach().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# for i in range(3):
|
|
|
|
# for i in range(3):
|
|
|
@ -147,7 +147,7 @@ def run(file_dir,excel):
|
|
|
|
print(target)
|
|
|
|
print(target)
|
|
|
|
print(result_eight)
|
|
|
|
print(result_eight)
|
|
|
|
final_df = pd.concat(list_app,ignore_index=True)
|
|
|
|
final_df = pd.concat(list_app,ignore_index=True)
|
|
|
|
final_df.to_csv('市行业电量.csv',encoding='gbk')
|
|
|
|
# final_df.to_csv('市行业电量.csv',encoding='gbk')
|
|
|
|
print(final_df)
|
|
|
|
print(final_df)
|
|
|
|
|
|
|
|
|
|
|
|
# result_eight.to_csv(f'./月底预测结果/9月{excel[:2]}.txt', sep='\t', mode='a')
|
|
|
|
# result_eight.to_csv(f'./月底预测结果/9月{excel[:2]}.txt', sep='\t', mode='a')
|
|
|
@ -155,7 +155,7 @@ def run(file_dir,excel):
|
|
|
|
# f.write(f'{excel[:2]}{industry}:{round(target, 5)}\n')
|
|
|
|
# f.write(f'{excel[:2]}{industry}:{round(target, 5)}\n')
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
if __name__ == '__main__':
|
|
|
|
file_dir = r'C:\python-project\pytorch3\浙江行业电量\浙江所有地市133行业数据'
|
|
|
|
file_dir = r'C:\Users\user\PycharmProjects\pytorch2\浙江行业电量\浙江所有地市133行业数据'
|
|
|
|
|
|
|
|
|
|
|
|
run(file_dir,'丽水133行业数据(全).xlsx')
|
|
|
|
run(file_dir,'丽水133行业数据(全).xlsx')
|
|
|
|
# p = Pool(4)
|
|
|
|
# p = Pool(4)
|
|
|
|