diff --git a/杭州日电量/cws_to_data/台州.txt b/杭州日电量/cws_to_data/台州.txt index 9f0a9e2..bc9092d 100644 --- a/杭州日电量/cws_to_data/台州.txt +++ b/杭州日电量/cws_to_data/台州.txt @@ -58,3 +58,4 @@ {'city': '台州', 'industry': '17.橡胶和塑料制品业', 'month_deviation_rate': 0.00067} {'city': '台州', 'industry': '其中:橡胶制品业', 'month_deviation_rate': 0.00118} {'city': '台州', 'industry': '塑料制品业', 'month_deviation_rate': 0.00314} +{'city': '台州', 'industry': '18.非金属矿物制品业', 'month_deviation_rate': -0.00198} diff --git a/杭州日电量/industry_elec_cws.py b/杭州日电量/industry_elec_cws.py index 77f1d5c..da734cf 100644 --- a/杭州日电量/industry_elec_cws.py +++ b/杭州日电量/industry_elec_cws.py @@ -71,15 +71,22 @@ def to_data(file_dir, excel): train_x = train_x.reshape(-1, 1, DAYS_FOR_TRAIN) train_y = train_y.reshape(-1, 1, 1) + # 使用GPU + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + train_x.to(device) + train_y.to(device) + # 转为pytorch的tensor对象 train_x = torch.from_numpy(train_x) train_y = torch.from_numpy(train_y) model = LSTM_Regression(DAYS_FOR_TRAIN, 32, output_size=1, num_layers=2) # 导入模型并设置模型的参数输入输出层、隐藏层等 - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # 使用GPU print("cuda" if torch.cuda.is_available() else "cpu") model.to(device) + # train_x.to(device) + # train_y.to(device) train_loss = [] loss_function = nn.MSELoss() @@ -106,8 +113,11 @@ def to_data(file_dir, excel): model = model.eval() # 转换成测试模式 # model.load_state_dict(torch.load(os.path.join(model_save_dir,model_file))) # 读取参数 + + dataset_x = dataset_x.reshape(-1, 1, DAYS_FOR_TRAIN) # (seq_size, batch_size, feature_size) dataset_x = torch.from_numpy(dataset_x) + dataset_x.to(device) pred_test = model(dataset_x) # 全量训练集 # 模型输出 (seq_size, batch_size, output_size) diff --git a/杭州日电量/pip3正确安装PyTorch GPU版本.txt b/杭州日电量/pip3正确安装PyTorch GPU版本.txt new file mode 100644 index 0000000..e10657d --- /dev/null +++ b/杭州日电量/pip3正确安装PyTorch GPU版本.txt @@ -0,0 +1,11 @@ +# 首先卸载 +pip3 uninstall torch +pip3 uninstall torchvision +pip3 uninstall torchaudio + +# 然后安装 +pip3 install torch==2.1.0+cu121 torchvision==0.16.0+cu121 torchaudio==2.1.0+cu121 -f https://download.pytorch.org/whl/cu121/torch_stable.html + + +# 可临时清华源加速 +pip3 install torch==2.1.0+cu121 torchvision==0.16.0+cu121 torchaudio==2.1.0+cu121 -f https://download.pytorch.org/whl/cu121/torch_stable.html -i https://pypi.tuna.tsinghua.edu.cn/simple/ \ No newline at end of file