You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
28 lines
833 B
Python
28 lines
833 B
Python
1 year ago
|
import torch
|
||
|
import pandas as pd
|
||
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
from torch import nn
|
||
|
import os
|
||
|
# Seed PyTorch's random number generator with a fixed value (42) at import time.
torch.manual_seed(42)
|
||
|
# Workaround for "OMP: Error #15: Initializing libiomp5md.dll, but found
# libiomp5md.dll already initialized."
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
|
||
|
|
||
|
class LSTM(nn.Module):
    """Stacked LSTM followed by a two-layer fully connected head.

    The recurrent output of every time step is passed through
    ``Linear(hidden_size, 64) -> ReLU -> Linear(64, output_size)``.

    Args:
        input_size: Number of features per time step fed to the LSTM.
        hidden_size: Size of the LSTM hidden state.
        output_size: Number of features produced per time step.
        num_layers: Number of stacked LSTM layers (default 2).
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=2):
        super().__init__()
        # Forward the caller's ``num_layers``; it used to be hard-coded to 2,
        # which silently ignored the constructor argument.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers)
        self.fc1 = nn.Linear(hidden_size, 64)
        self.ReLu = nn.ReLU()
        self.fc2 = nn.Linear(64, output_size)

    def forward(self, x):
        """Run the LSTM and the dense head over ``x``.

        Args:
            x: 3-D input tensor. ``nn.LSTM`` is built without
                ``batch_first``, so PyTorch's default layout
                ``(seq_len, batch, input_size)`` applies.

        Returns:
            A 2-D tensor of shape ``(seq_len * batch, output_size)``: the
            first two dimensions of the LSTM output are flattened together
            before the linear layers are applied.
        """
        output, _ = self.lstm(x)
        # Three-way unpacking doubles as a check that the LSTM output is 3-D.
        _, _, hidden = output.shape
        output = output.reshape(-1, hidden)
        output = self.ReLu(self.fc1(output))
        output = self.fc2(output)
        return output
|
||
|
|
||
|
# Create the dataset
|
||
|
|