# LSTM forecasting of city-level industry electricity consumption
# (11 cities, Zhejiang). Data is cleaned, normalised, windowed, and fed
# to a stacked LSTM that predicts the next 3 steps from the previous 10.
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
train_step = 10
class LSTM(nn.Module):
    """Stacked LSTM encoder followed by a two-layer fully-connected head.

    The forward pass flattens the (seq, batch) dimensions before the
    linear layers, so the output has shape (seq * batch, output_size).
    Attribute names are kept as-is so saved state_dicts stay loadable.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()
        # Sequence encoder; expects input shaped (seq, batch, input_size).
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        self.fc1 = nn.Linear(hidden_size, 128)
        self.fc2 = nn.Linear(128, output_size)
        self.ReLu = nn.ReLU()
        # NOTE(review): p=0.8 drops 80% of activations — unusually
        # aggressive regularisation; confirm this is intentional.
        self.dropout = nn.Dropout(0.8)

    def forward(self, x):
        features, _ = self.lstm(x)
        seq_len, batch, hidden = features.shape
        # Collapse (seq, batch) so the linear head sees 2-D input.
        flat = features.reshape(-1, hidden)
        head = self.ReLu(self.dropout(self.fc1(flat)))
        return self.fc2(head)
10 months ago
def normal(data):
    """Return a boolean mask marking values inside the Tukey fences.

    An observation counts as "normal" when it lies within
    [Q1 - 1.5*IQR, Q3 + 1.5*IQR]; anything outside is an outlier.

    Parameters
    ----------
    data : pd.Series
        Numeric series to screen.

    Returns
    -------
    pd.Series of bool, index-aligned with ``data``.
    """
    # Compute the quartiles once instead of calling describe() four times
    # (describe()'s '25%'/'75%' rows are exactly Series.quantile values).
    q1 = data.quantile(0.25)
    q3 = data.quantile(0.75)
    iqr = q3 - q1
    low = q1 - 1.5 * iqr
    high = q3 + 1.5 * iqr
    return (data >= low) & (data <= high)
# file_dir = './浙江各地市行业电量数据'
#
# # 合并11个市
# df = pd.DataFrame({})
# for city in os.listdir(file_dir):
#
# df_city = pd.read_excel(os.path.join(file_dir, city))
#
# # 对每个市的每一个行业异常值 向后填充
# for industry in df_city.columns[2:]:
# outliers_index = normal(df_city[industry]).index
# df_city[industry] = df_city[industry].where(normal(df_city[industry]), other=np.nan).bfill()
# df_city[industry].fillna(method='ffill',inplace=True)
# df = pd.concat([df,df_city])
# print(df.shape)
#
# df.to_csv('11市行业数据(已处理异常).csv',index=False,encoding='GBK')
# Load the pre-cleaned (outlier-handled) industry data for the 11 cities;
# the commented-out code above shows how this CSV was produced.
df = pd.read_csv('11市行业数据(已处理异常).csv', encoding='gbk')

# Min-max normalise every industry column (columns 0-1 are identifiers),
# remembering each column's original min/max so predictions can be
# de-normalised later.
column_params = {}
for col in df.columns[2:]:
    col_scaler = MinMaxScaler()
    df[col] = col_scaler.fit_transform(df[[col]])
    column_params[col] = {
        'min': col_scaler.data_min_[0],
        'max': col_scaler.data_max_[0],
    }
print(column_params)
print(df.head())
def create_dataset(data, train_step=train_step, forecast_step=3) -> (np.array, np.array):
    """Slice an ordered 1-D sequence into sliding input/target windows.

    Parameters
    ----------
    data : sequence (list, np.ndarray, or pd.Series with a default index)
        Ordered observations for a single (city, industry) series.
    train_step : int
        Number of past points per input window (defaults to the module
        constant ``train_step``).
    forecast_step : int
        Number of future points per target window. Defaults to 3 to match
        the model's output size; previously hard-coded.

    Returns
    -------
    (np.ndarray, np.ndarray)
        x of shape (n, train_step) and y of shape (n, forecast_step).
    """
    dataset_x, dataset_y = [], []
    # Stop early enough that every window has a full forecast target.
    for i in range(len(data) - train_step - forecast_step):
        dataset_x.append(data[i:(i + train_step)])
        dataset_y.append(data[i + train_step:i + train_step + forecast_step])

    return (np.array(dataset_x), np.array(dataset_y))
# 切分x,y数据集步长为10.最小单位为单个城市的单个行业。
# 先从第一个行业切分,合并所有城市。
# industry = df.columns[2:][0]
# city = df['地市'].drop_duplicates()[0]
# df_city_industry = df[df['地市'] == city][industry]
# dataset_x, dataset_y = create_dataset(df_city_industry)
#
# for city in df['地市'].drop_duplicates()[1:]:
# df_city_industry = df[df['地市'] == city][industry]
# x, y = create_dataset(df_city_industry)
# dataset_x, dataset_y = np.concatenate([dataset_x, x]), np.concatenate([dataset_y, y])
#
# for industry in df.columns[2:][1:]:
# for city in df['地市'].drop_duplicates():
# df_city_industry = df[df['地市'] == city][industry]
# x, y = create_dataset(df_city_industry)
# dataset_x, dataset_y = np.concatenate([dataset_x, x]), np.concatenate([dataset_y, y])
#
# print(dataset_x.shape, dataset_y.shape)
# df_x = pd.DataFrame(dataset_x)
# df_y = pd.DataFrame(dataset_y)
# df_x.to_csv('df_x_100.csv',index=False)
# df_y.to_csv('df_y_100.csv',index=False)
# Load the pre-built sliding-window datasets (generation code is in the
# commented-out section above).
dataset_x = pd.read_csv('df_x.csv').values
dataset_y = pd.read_csv('df_y.csv').values

print(dataset_x.shape, dataset_y.shape)

# Chronological 70/30 split into training and evaluation sets; each sample
# becomes (batch, 1, train_step) inputs and (batch, 1, 3) targets.
train_size = int(0.7 * len(dataset_x))
x_train = dataset_x[:train_size].reshape(-1, 1, train_step)
y_train = dataset_y[:train_size].reshape(-1, 1, 3)
x_eval = dataset_x[train_size:].reshape(-1, 1, train_step)
y_eval = dataset_y[train_size:].reshape(-1, 1, 3)

x_train = torch.from_numpy(x_train).type(torch.float32)
y_train = torch.from_numpy(y_train).type(torch.float32)
x_eval = torch.from_numpy(x_eval).type(torch.float32)
y_eval = torch.from_numpy(y_eval).type(torch.float32)

# Loaders are used by the (commented-out) mini-batch training variant.
ds = TensorDataset(x_train, y_train)
dl = DataLoader(ds, batch_size=32, drop_last=True)

eval_ds = TensorDataset(x_eval, y_eval)
eval_dl = DataLoader(eval_ds, batch_size=64, drop_last=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTM(train_step, 64, 3, num_layers=2).to(device)

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00005,
                             betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# Full-batch training for 500 epochs, keeping the best weights seen.
# Fixes vs. the original:
#   * `train_x`/`train_y` were undefined — the split tensors are named
#     `x_train`/`y_train`; the device transfer is also hoisted out of the loop.
#   * the best state_dict was stored as `best_para` but saved below as
#     `best_parameters` (NameError) — one name is used consistently, and it
#     is initialised up front so the save never fails even if the loss
#     never drops below the threshold.
#   * `min_loss = loss` kept a live tensor (and its graph); `.item()` is used.
min_loss = 1
best_parameters = model.state_dict()
x_train, y_train = x_train.to(device), y_train.to(device)
for i in range(500):
    out = model(x_train)
    loss = loss_fn(out, y_train)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Track the lowest-loss parameters seen so far.
    if loss.item() <= min_loss:
        min_loss = loss.item()
        best_parameters = model.state_dict()
    if i % 100 == 0:
        print(f'epoch {i+1}: loss:{loss}')
# for epoch in range(3):
# model.train()
# for step, (x, y) in enumerate(dl):
# x, y = x.to(device), y.to(device)
# pred = model(x)
# loss = loss_fn(pred,y)
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
#
# if step % 1000 == 0:
# print(f'epoch{epoch+1}: train_step:{step}/{len(dl)} train_loss:{loss}\n')
#
# model.eval()
# batch_loss = 0
# with torch.no_grad():
# for x,y in eval_dl:
# x, y = x.to(device), y.to(device)
# pred = model(x)
# loss = loss_fn(pred, y)
# batch_loss += loss
# print(f'epoch{epoch+1}: eval_loss:{batch_loss/len(eval_dl)}\n')
#
# if batch_loss/len(eval_dl) < min_loss:
# min_loss = batch_loss/len(eval_dl)
# best_parameters = model.state_dict()
# Persist the best weights, then reload them into a fresh model so the
# evaluation below exercises exactly what was saved.
torch.save(best_parameters, 'best_3.pth')
model = LSTM(train_step, 64, 3, num_layers=2).to(device)
model.load_state_dict(torch.load('best_3.pth'))

params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("LSTM参数总量:", params)

# Predict over the full dataset (train + eval) for plotting.
dataset_x = dataset_x.reshape(-1, 1, train_step)
dataset_x = torch.from_numpy(dataset_x).type(torch.float32).to(device)
# Fix: the model was left in training mode, so Dropout(0.8) randomly
# zeroed 80% of activations at inference; eval() + no_grad() gives
# deterministic predictions without building a graph.
model.eval()
with torch.no_grad():
    pred = model(dataset_x).reshape(-1)
# Pad the front with zeros so predictions align with the real series
# (the first `train_step` points have no prediction).
pred = np.concatenate((np.zeros(train_step), pred.cpu().numpy()))

plt.plot(pred, 'r', label='prediction')
plt.plot(dataset_y.reshape(-1), 'b', label='real')
# Vertical divider: training data to the left, evaluation data to the right.
plt.plot((train_size*3, train_size*3), (0, 1), 'g--')
plt.legend(loc='best')
plt.show()