You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

28 lines
833 B
Python

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import os
# Work around "OMP: Error #15: Initializing libiomp5md.dll, but found
# libiomp5md.dll already initialized." by letting a second OpenMP runtime load.
# NOTE(review): the OpenMP runtime's own error text describes this as an
# unsafe, unsupported workaround that may cause crashes or silently wrong
# results; it is also set here after torch/numpy are already imported.
# Consider making sure only one OpenMP runtime is loaded instead.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Seed PyTorch's global RNG so random draws (e.g. layer weight init) repeat
# across runs.
torch.manual_seed(42)
class LSTM(nn.Module):
    """Stacked LSTM followed by a two-layer MLP head applied at every time step.

    Args:
        input_size: Number of features per time step fed to ``nn.LSTM``.
        hidden_size: Hidden-state size of each LSTM layer.
        output_size: Number of outputs produced per time step.
        num_layers: Number of stacked LSTM layers (default 2).
        fc_hidden_size: Width of the intermediate fully connected layer
            (default 64, the value previously hard-coded).
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=2,
                 fc_hidden_size=64):
        super().__init__()
        # Honour the caller's num_layers; previously the argument was ignored
        # and the LSTM was always built with 2 layers.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers)
        self.fc1 = nn.Linear(hidden_size, fc_hidden_size)
        # Attribute name kept as ``ReLu`` in case other code references it.
        self.ReLu = nn.ReLU()
        self.fc2 = nn.Linear(fc_hidden_size, output_size)

    def forward(self, x):
        """Run the LSTM and project every time step to ``output_size``.

        ``nn.LSTM`` is constructed with its default ``batch_first=False``, so a
        batched ``x`` is expected as ``(seq_len, batch, input_size)``; an
        unbatched ``(seq_len, input_size)`` tensor is accepted as well.

        Returns:
            A 2-D tensor of shape ``(seq_len * batch, output_size)`` (or
            ``(seq_len, output_size)`` for unbatched input), with all time
            steps flattened into the first dimension.
        """
        output, _ = self.lstm(x)  # final (h_n, c_n) states are unused
        # Collapse all leading dims so the MLP head sees one row per time step
        # (per batch element); works for both batched and unbatched input.
        output = output.reshape(-1, output.shape[-1])
        output = self.ReLu(self.fc1(output))
        return self.fc2(output)
# Create the dataset