get
parent 359effce4c
commit 8a2d8f2a9a

@@ -0,0 +1,81 @@
,real,pred
0,366314897.6,325991300.0
1,380820733.0,290698080.0
2,380238801.9,257544720.0
3,244227473.0,281246980.0
4,364073058.1,275352350.0
5,380820733.0,347751680.0
6,380238801.9,321158200.0
7,244227473.0,292571360.0
8,364073058.1,298532380.0
9,316723431.2,319373630.0
10,380238801.9,343127650.0
11,244227473.0,333881150.0
12,364073058.1,315865570.0
13,316723431.2,321350340.0
14,334702116.4,339486850.0
15,244227473.0,355201730.0
16,364073058.1,351205660.0
17,316723431.2,340886050.0
18,334702116.4,344197900.0
19,340252127.5,356064770.0
20,364073058.1,273813060.0
21,316723431.2,292272900.0
22,334702116.4,293830340.0
23,340252127.5,293426180.0
24,338685923.8,300031800.0
25,316723431.2,349388480.0
26,334702116.4,348153700.0
27,340252127.5,350227100.0
28,338685923.8,351487040.0
29,323789488.1,347344860.0
30,334702116.4,305196900.0
31,340252127.5,312199360.0
32,338685923.8,318501820.0
33,323789488.1,309094460.0
34,327864337.1,311822370.0
35,340252127.5,332496220.0
36,338685923.8,331613600.0
37,323789488.1,330546100.0
38,327864337.1,327573700.0
39,308162792.6,328044900.0
40,338685923.8,335969060.0
41,323789488.1,331136770.0
42,327864337.1,323793180.0
43,308162792.6,321710080.0
44,291282417.7,328650000.0
45,323789488.1,319168770.0
46,327864337.1,319473630.0
47,308162792.6,315658180.0
48,291282417.7,316581950.0
49,326988587.6,322722780.0
50,327864337.1,304423680.0
51,308162792.6,311155600.0
52,291282417.7,315857400.0
53,326988587.6,318633470.0
54,338531879.5,318105200.0
55,308162792.6,332961340.0
56,291282417.7,335313200.0
57,326988587.6,345534460.0
58,338531879.5,343698900.0
59,343441750.8,337725980.0
60,291282417.7,301866880.0
61,326988587.6,310824100.0
62,338531879.5,321738430.0
63,343441750.8,311232700.0
64,328373555.6,309342560.0
65,326988587.6,306569470.0
66,338531879.5,310643680.0
67,343441750.8,314744000.0
68,328373555.6,308885660.0
69,307601302.1,306099400.0
70,338531879.5,321590140.0
71,343441750.8,318885440.0
72,328373555.6,318865300.0
73,307601302.1,314202980.0
74,296591378.5,314701920.0
75,343441750.8,322403870.0
76,328373555.6,318382370.0
77,307601302.1,315469380.0
78,296591378.5,312369950.0
79,291360430.5,316049300.0

@@ -0,0 +1,81 @@
,real,pred
0,135488382.9,126543270.0
1,139728296.4,132820504.0
2,137821025.8,111796240.0
3,78600422.1,131913600.0
4,126378434.7,121779730.0
5,139728296.4,133275510.0
6,137821025.8,132085740.0
7,78600422.1,123508616.0
8,126378434.7,122062880.0
9,112584670.1,126883176.0
10,137821025.8,127550950.0
11,78600422.1,127659760.0
12,126378434.7,120534296.0
13,112584670.1,120153140.0
14,126043317.8,126331544.0
15,78600422.1,124620616.0
16,126378434.7,125789710.0
17,112584670.1,120741016.0
18,126043317.8,122845240.0
19,127571038.2,126469840.0
20,126378434.7,89965520.0
21,112584670.1,97125790.0
22,126043317.8,97640040.0
23,127571038.2,97961870.0
24,126803223.9,100093896.0
25,112584670.1,118327300.0
26,126043317.8,119570870.0
27,127571038.2,120226640.0
28,126803223.9,122849880.0
29,121261267.3,119103896.0
30,126043317.8,112915770.0
31,127571038.2,113519840.0
32,126803223.9,115918376.0
33,121261267.3,112486350.0
34,124085800.1,111817100.0
35,127571038.2,122143170.0
36,126803223.9,119428750.0
37,121261267.3,117216230.0
38,124085800.1,115505320.0
39,111506694.9,116549080.0
40,126803223.9,123125660.0
41,121261267.3,118972860.0
42,124085800.1,113853950.0
43,111506694.9,114263896.0
44,106343825.4,117712260.0
45,121261267.3,118652850.0
46,124085800.1,116123000.0
47,111506694.9,112103380.0
48,106343825.4,114997496.0
49,123690748.8,118147590.0
50,124085800.1,113974550.0
51,111506694.9,114976400.0
52,106343825.4,114994730.0
53,123690748.8,119030696.0
54,126477907.1,118913830.0
55,111506694.9,128059680.0
56,106343825.4,128777160.0
57,123690748.8,132264690.0
58,126477907.1,134432380.0
59,127784089.0,130678440.0
60,106343825.4,117309140.0
61,123690748.8,120235380.0
62,126477907.1,124647930.0
63,127784089.0,121395256.0
64,121978961.0,120084936.0
65,123690748.8,115211520.0
66,126477907.1,116868504.0
67,127784089.0,118030080.0
68,121978961.0,115720190.0
69,115593904.4,115036584.0
70,126477907.1,122232480.0
71,127784089.0,120755330.0
72,121978961.0,119494640.0
73,115593904.4,118518104.0
74,108270758.4,118762740.0
75,127784089.0,121758630.0
76,121978961.0,119449144.0
77,115593904.4,117008670.0
78,108270758.4,116879120.0
79,106730768.7,118586610.0

@@ -0,0 +1,81 @@
,real,pred
0,50.73,-3852886.8
1,47.88,19293490.0
2,48.31,-787532.5
3,57.75,17889912.0
4,49.95,14566443.0
5,47.88,-2105174.2
6,48.31,9359497.0
7,57.75,-433206.5
8,49.95,7017683.0
9,39.61,7703169.0
10,48.31,-808642.3
11,57.75,4597720.0
12,49.95,-360913.0
13,39.61,3072979.0
14,56.33,4244487.0
15,57.75,-51642.7
16,49.95,2367718.2
17,39.61,-302210.2
18,56.33,1611450.0
19,52.92,2502897.2
20,49.95,276224.3
21,39.61,1291416.2
22,56.33,-257380.97
23,52.92,980514.25
24,21.66,1559440.0
25,39.61,357924.16
26,56.33,741489.3
27,52.92,-223402.89
28,21.66,639626.94
29,17.11,1005113.4
30,56.33,332503.4
31,52.92,443570.0
32,21.66,-191401.42
33,17.11,425019.1
34,21.02,662274.2
35,52.92,272184.7
36,21.66,272402.28
37,17.11,-157965.39
38,21.02,281494.4
39,21.65,444933.7
40,21.66,216378.39
41,17.11,174190.5
42,21.02,-121474.26
43,21.65,189702.75
44,17.95,310332.28
45,17.11,170874.25
46,21.02,115950.27
47,21.65,-86960.06
48,17.95,130457.57
49,32.64,225947.27
50,21.02,136433.27
51,21.65,80172.625
52,17.95,-57612.453
53,32.64,93036.65
54,36.2,172486.55
55,21.65,111274.96
56,17.95,57095.465
57,32.64,-35822.484
58,36.2,68347.86
59,36.29,137399.97
60,17.95,91150.695
61,32.64,40596.707
62,36.2,-20513.984
63,36.29,51582.69
64,34.19,112940.664
65,32.64,78077.64
66,36.2,31135.309
67,36.29,-9841.227
68,34.19,40612.53
69,14.42,97921.65
70,36.2,70152.695
71,36.29,25340.3
72,34.19,-927.8079
73,14.42,34880.83
74,17.77,88737.21
75,36.29,62387.992
76,34.19,19532.104
77,14.42,3516.0425
78,17.77,29056.807
79,126.08,80504.98

@@ -0,0 +1,81 @@
,real,pred
0,2704531.598,-2316342.0
1,2847401.395,20676706.0
2,2853452.254,621314.75
3,1445328.766,19298740.0
4,2859827.803,15913489.0
5,2847401.395,160415.81
6,2853452.254,11473635.0
7,1445328.766,1780781.9
8,2859827.803,9122751.0
9,2596527.222,9748031.0
10,2853452.254,1740487.1
11,1445328.766,7060606.5
12,2859827.803,2215774.0
13,2596527.222,5502353.0
14,2493576.394,6607667.0
15,1445328.766,2590090.2
16,2859827.803,4988868.5
17,2596527.222,2414713.2
18,2493576.394,4189292.0
19,2590630.491,5007022.5
20,2859827.803,2663833.5
21,2596527.222,3738563.2
22,2493576.394,2267029.0
23,2590630.491,3384996.8
24,2542066.764,3895295.0
25,2596527.222,2791455.5
26,2493576.394,3239150.8
27,2590630.491,2324025.0
28,2542066.764,3109880.2
29,2348001.351,3393537.5
30,2493576.394,2799275.0
31,2590630.491,2971549.0
32,2542066.764,2372881.5
33,2348001.351,2928644.5
34,2358799.425,3085286.2
35,2590630.491,2734414.0
36,2542066.764,2794325.2
37,2348001.351,2387036.8
38,2358799.425,2778122.0
39,2280758.623,2866578.2
40,2542066.764,2720565.0
41,2348001.351,2730321.8
42,2358799.425,2451633.2
43,2280758.623,2726225.2
44,2178461.623,2774300.0
45,2348001.351,2671150.2
46,2358799.425,2668631.5
47,2280758.623,2480967.5
48,2178461.623,2668544.2
49,2333089.921,2699806.2
50,2358799.425,2573983.0
51,2280758.623,2581621.8
52,2178461.623,2458319.0
53,2333089.921,2591283.8
54,2464330.824,2604670.0
55,2280758.623,2550985.0
56,2178461.623,2560103.5
57,2333089.921,2483460.2
58,2464330.824,2577773.5
59,2501636.637,2571584.0
60,2178461.623,2486985.5
61,2333089.921,2497864.5
62,2464330.824,2443193.2
63,2501636.637,2499882.5
64,2391662.7,2499099.0
65,2333089.921,2403869.2
66,2464330.824,2416259.8
67,2501636.637,2374731.2
68,2391662.7,2414558.5
69,2132738.014,2411228.2
70,2464330.824,2412634.2
71,2501636.637,2417201.2
72,2391662.7,2387954.0
73,2132738.014,2420058.0
74,2089421.05,2411550.0
75,2501636.637,2422625.8
76,2391662.7,2422028.2
77,2132738.014,2402778.5
78,2089421.05,2425684.2
79,2169828.292,2417022.0

@@ -86,7 +86,7 @@ for excel in os.listdir(file_dir)[1:]:
dataset_y = np.concatenate((dataset_y,y))
print(dataset_x.shape,dataset_y.shape)
# Training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -112,22 +112,23 @@ train_y = torch.from_numpy(train_y).to(device).type(torch.float32)
model = LSTM_Regression(DAYS_FOR_TRAIN, 32, output_size=5, num_layers=2).to(device) # instantiate the model and set its input/output sizes, hidden size and number of layers
train_loss = []
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
for i in range(1500):
out = model(train_x)
loss = loss_function(out, train_y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
train_loss.append(loss.item())
if i % 100 == 0:
print(f'epoch {i+1}: loss:{loss}')
# train_loss = []
# loss_function = nn.MSELoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.005, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# for i in range(1500):
# out = model(train_x)
# loss = loss_function(out, train_y)
# loss.backward()
# optimizer.step()
# optimizer.zero_grad()
# train_loss.append(loss.item())
# if i % 100 == 0:
# print(f'epoch {i+1}: loss:{loss}')
# Save the model
torch.save(model.state_dict(),'hy5.pth')
# Save / load the model
# torch.save(model.state_dict(),'hy5.pth')
model.load_state_dict(torch.load('hy5.pth'))
# for test
model = model.eval() # switch to evaluation mode
# model.load_state_dict(torch.load(os.path.join(model_save_dir,model_file))) # load parameters
@@ -139,25 +140,40 @@ pred_test = model(dataset_x) # predict over the full dataset
pred_test = pred_test.view(-1)
pred_test = np.concatenate((np.zeros(DAYS_FOR_TRAIN), pred_test.cpu().detach().numpy()))
plt.plot(pred_test.reshape(-1), 'r', label='prediction')
plt.plot(dataset_y.reshape(-1), 'b', label='real')
plt.plot((train_size*5, train_size*5), (0, 1), 'g--') # dividing line: training data on the left, test output on the right
plt.legend(loc='best')
plt.show()
# plt.plot(pred_test.reshape(-1), 'r', label='prediction')
# plt.plot(dataset_y.reshape(-1), 'b', label='real')
# plt.plot((train_size*5, train_size*5), (0, 1), 'g--') # dividing line: training data on the left, test output on the right
# plt.legend(loc='best')
# plt.show()
# Build the evaluation set
# result_list = []
# roll forward from the actual data x to predict the next 3 days
# x = torch.from_numpy(df[-14:-4]).to(device)
# pred = model(x.reshape(-1,1,DAYS_FOR_TRAIN)).view(-1).detach().numpy()
df_eval = pd.read_excel(r'C:\Users\user\Desktop\浙江各地市行业电量数据\ 杭州 .xlsx',index_col='stat_date')
df_eval.columns = df_eval.columns.map(lambda x:x.strip())
df_eval.index = pd.to_datetime(df_eval.index)
x,y = create_dataset(df_eval.loc['2023-7']['第三产业'],10)
x = (x - min_value) / (max_value - min_value)
x = x.reshape(-1,1,10)
x = torch.from_numpy(x).type(torch.float32).to(device)
pred = model(x)
# De-normalize (invert the min-max scaling)
pred = pred * (max_value - min_value) + min_value
# df = df * (max_value - min_value) + min_value
df = pd.DataFrame({'real':y.reshape(-1),'pred':pred.view(-1).cpu().detach().numpy()})
print(df)
df.to_csv('7月第三产业.csv',encoding='gbk')
# De-normalize
# pred = pred * (max_value - min_value) + min_value
# df = df * (max_value - min_value) + min_value
# print(pred)
# # Print metrics
# print(abs(pred - df[-3:]).mean() / df[-3:].mean())
# result_eight = pd.DataFrame({'pred': np.round(pred,1),'real': df[-3:]})

@@ -0,0 +1,195 @@
import numpy as np
import pandas as pd
import torch
from torch import nn
from multiprocessing import Pool
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
DAYS_FOR_TRAIN = 10
torch.manual_seed(42)
class LSTM_Regression(nn.Module):
def __init__(self, input_size, hidden_size, output_size=1, num_layers=2):
super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, _x):
x, _ = self.lstm(_x) # _x is input, size (seq_len, batch, input_size)
s, b, h = x.shape # x is output, size (seq_len, batch, hidden_size)
x = x.view(s * b, h)
x = self.fc(x)
x = x.view(s, b, -1) # reshape back to (seq_len, batch, output_size)
return x
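# create_dataset (below) turns a 1-D series into a rolling-window supervised dataset:
# each sample x is a window of `days_for_train` consecutive values and the matching y
# is the 5 values that follow it, so a series of length N yields N - days_for_train - 5 samples.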
def create_dataset(data, days_for_train=5) -> (np.array, np.array):
dataset_x, dataset_y = [], []
for i in range(len(data) - days_for_train-5):
dataset_x.append(data[i:(i + days_for_train)])
dataset_y.append(data[i + days_for_train:i + days_for_train+5])
# print(dataset_x,dataset_y)
return (np.array(dataset_x), np.array(dataset_y))
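# normal (below) is an IQR outlier filter: values outside
# [Q1 - 1.5*IQR, Q3 + 1.5*IQR], with IQR = Q3 - Q1, are dropped from the series.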
def normal(nd):
high = nd.describe()['75%'] + 1.5*(nd.describe()['75%']-nd.describe()['25%'])
low = nd.describe()['25%'] - 1.5*(nd.describe()['75%']-nd.describe()['25%'])
return nd[(nd<high)&(nd>low)]
def data_preprocessing(data):
data.columns = data.columns.map(lambda x: x.strip())
data.index = pd.to_datetime(data.index)
data.sort_index(inplace=True)
data = data.loc['2021-01':'2023-08']
data.drop(columns=[i for i in data.columns if (data[i] == 0).sum() / len(data) >= 0.5], inplace=True) # drop columns where at least half of the values are zero
data = data[data.values != 0]
data = data.astype(float)
for col in data.columns:
data[col] = normal(data[col])
return data
if __name__ == '__main__':
# Concatenate the datasets
file_dir = r'C:\Users\user\Desktop\浙江各地市分电压日电量数据'
excel = os.listdir(file_dir)[0]
data = pd.read_excel(os.path.join(file_dir, excel), sheet_name=0, index_col=' stat_date ')
data = data_preprocessing(data)
df = data[data.columns[0]]
df.dropna(inplace = True)
dataset_x, dataset_y = create_dataset(df, DAYS_FOR_TRAIN)
for level in data.columns[1:]:
df = data[level]
df.dropna(inplace=True)
x, y = create_dataset(df, DAYS_FOR_TRAIN)
dataset_x = np.concatenate((dataset_x, x))
dataset_y = np.concatenate((dataset_y, y))
for excel in os.listdir(file_dir)[1:]:
data = pd.read_excel(os.path.join(file_dir,excel), sheet_name=0,index_col=' stat_date ')
data = data_preprocessing(data)
for level in data.columns:
df = data[level]
df.dropna(inplace=True)
x,y = create_dataset(df,DAYS_FOR_TRAIN)
dataset_x = np.concatenate((dataset_x,x))
dataset_y = np.concatenate((dataset_y,y))
print(dataset_x,dataset_y,dataset_x.shape,dataset_y.shape)
# Training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Min-max normalize to [0, 1]
max_value = np.max(dataset_x)
min_value = np.min(dataset_x)
dataset_x = (dataset_x - min_value) / (max_value - min_value)
dataset_y = (dataset_y - min_value) / (max_value - min_value)
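# min_value/max_value come from the full dataset_x; the same pair is reused below
# to scale the evaluation data and to de-normalize the predictions.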
# Split into training and test sets
train_size = int(len(dataset_x)*0.7)
train_x = dataset_x[:train_size]
train_y = dataset_y[:train_size]
# Reshape: the RNN reads input of shape (seq_len, batch_size, feature_size)
train_x = train_x.reshape(-1, 1, DAYS_FOR_TRAIN)
train_y = train_y.reshape(-1, 1, 5)
# Convert to PyTorch tensors
train_x = torch.from_numpy(train_x).to(device).type(torch.float32)
train_y = torch.from_numpy(train_y).to(device).type(torch.float32)
model = LSTM_Regression(DAYS_FOR_TRAIN, 32, output_size=5, num_layers=2).to(device) # instantiate the model and set its input/output sizes, hidden size and number of layers
train_loss = []
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# for i in range(1500):
# out = model(train_x)
# loss = loss_function(out, train_y)
# loss.backward()
# optimizer.step()
# optimizer.zero_grad()
# train_loss.append(loss.item())
# # print(loss)
# # Save the model
# torch.save(model.state_dict(),'dy5.pth')
model.load_state_dict(torch.load('dy5.pth'))
# for test
model = model.eval() # switch to evaluation mode
# model.load_state_dict(torch.load(os.path.join(model_save_dir,model_file))) # load parameters
dataset_x = dataset_x.reshape(-1, 1, DAYS_FOR_TRAIN) # (seq_size, batch_size, feature_size)
dataset_x = torch.from_numpy(dataset_x).to(device).type(torch.float32)
pred_test = model(dataset_x) # predict over the full dataset
# model output: (seq_len, batch_size, output_size)
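# Flatten the predictions and prepend DAYS_FOR_TRAIN zeros so the sequence roughly
# lines up with the start of the original series in the (commented-out) plot below.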
pred_test = pred_test.view(-1)
pred_test = np.concatenate((np.zeros(DAYS_FOR_TRAIN), pred_test.cpu().detach().numpy()))
# plt.plot(pred_test.reshape(-1), 'r', label='prediction')
# plt.plot(dataset_y.reshape(-1), 'b', label='real')
# plt.plot((train_size*5, train_size*5), (0, 1), 'g--') # dividing line: training data on the left, test output on the right
# plt.legend(loc='best')
# plt.show()
# Build the evaluation set
# result_list = []
# roll forward from the actual data x to predict the next 3 days
df_eval = pd.read_excel(r'C:\Users\user\Desktop\浙江各地市分电压日电量数据\杭州.xlsx',index_col=' stat_date ')
df_eval.columns = df_eval.columns.map(lambda x:x.strip())
df_eval.index = pd.to_datetime(df_eval.index)
x,y = create_dataset(df_eval.loc['2023-7']['10kv以下'],10)
x = (x - min_value) / (max_value - min_value)
x = x.reshape(-1,1,10)
x = torch.from_numpy(x).type(torch.float32).to(device)
pred = model(x)
# De-normalize (invert the min-max scaling)
pred = pred * (max_value - min_value) + min_value
# df = df * (max_value - min_value) + min_value
print(pred,y)
df = pd.DataFrame({'real':y.reshape(-1),'pred':pred.view(-1).cpu().detach().numpy()})
df.to_csv('7月全行业.csv',encoding='gbk')
# Print metrics
# print(abs(pred - df[-3:]).mean() / df[-3:].mean())
# result_eight = pd.DataFrame({'pred': np.round(pred,1),'real': df[-3:]})
# target = (result_eight['pred'].sum() - result_eight['real'].sum()) / df[-31:].sum()
# result_eight['loss_rate'] = round(target, 5)
# result_eight['level'] = level
# list_app.append(result_eight)
# print(target)
# print(result_eight)
# final_df = pd.concat(list_app,ignore_index=True)
# final_df.to_csv('市行业电量.csv',encoding='gbk')
# print(final_df)
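A minimal scoring sketch, not part of the commit: assuming the '7月全行业.csv' file written by the script above (gbk-encoded, with 'real' and 'pred' columns), it reports a mean absolute percentage error for the exported predictions.
import pandas as pd

def mape(csv_path: str, encoding: str = 'gbk') -> float:
    # Mean absolute percentage error over the exported real/pred columns.
    df = pd.read_csv(csv_path, encoding=encoding, index_col=0)
    return float((df['pred'] - df['real']).abs().div(df['real'].abs()).mean())

print(mape('7月全行业.csv'))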