|
|
|
@ -95,8 +95,8 @@ for industry in df.columns[2:][1:]:
|
|
|
|
|
print(dataset_x.shape, dataset_y.shape)
|
|
|
|
|
|
|
|
|
|
train_size = int(0.7 * len(dataset_x))
|
|
|
|
|
x_train, y_train = dataset_x[:train_size], dataset_y[:train_size]
|
|
|
|
|
x_eval, y_eval = dataset_x[train_size:], dataset_y[train_size:]
|
|
|
|
|
x_train, y_train = dataset_x[:train_size].reshape(-1,1,10), dataset_y[:train_size].reshape(-1, 1, 3)
|
|
|
|
|
x_eval, y_eval = dataset_x[train_size:].reshape(-1,1,10), dataset_y[train_size:].reshape(-1, 1, 3)
|
|
|
|
|
x_train, y_train = torch.from_numpy(x_train).type(torch.float32), torch.from_numpy(y_train).type(torch.float32)
|
|
|
|
|
x_eval, y_eval = torch.from_numpy(x_eval).type(torch.float32), torch.from_numpy(y_eval).type(torch.float32)
|
|
|
|
|
|
|
|
|
|