17427 1 year ago
parent 7a4d041dc2
commit 2959a98bd1

@ -43,6 +43,7 @@ def to_data(file_dir, excel):
data.drop(columns=[i for i in data.columns if (data[i] == 0).sum() / len(data) >= 0.5], inplace=True) # 去除0值列
print('len(data):', len(data))
for industry in data.columns[1:]:
c = time.time()
df = data[['stat_date', industry]]
df = df[df[industry] != 0] # 去除0值行
@ -58,7 +59,7 @@ def to_data(file_dir, excel):
df = (df - min_value) / (max_value - min_value)
dataset_x, dataset_y = create_dataset(df, DAYS_FOR_TRAIN)
print()
print("========")
print('len(dataset_x:)', len(dataset_x))
# 划分训练集和测试集
@ -82,7 +83,9 @@ def to_data(file_dir, excel):
train_loss = []
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
a = time.time()
print("行业加载时间", a-c)
for i in range(1200):
out = model(train_x)
loss = loss_function(out, train_y)
@ -126,6 +129,7 @@ def to_data(file_dir, excel):
with open(fr'.\cws_to_data\{excel[:2]}.txt', 'a', encoding='utf-8') as f:
tmp_data = {'city': excel[:2], 'industry': industry, "month_deviation_rate": round(target, 5)}
f.write(str(tmp_data) + "\n")
print("========")
if __name__ == '__main__':

Loading…
Cancel
Save