1 Star 3 Fork 0

guox66 / wind_prediction

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 2.94 KB
一键复制 编辑 原始数据 按行查看 历史
Xu Guo 提交于 2024-02-28 20:59 . Add files via upload
from matplotlib import pyplot as plt
from matplotlib.pyplot import figure
from tqdm import tqdm
from torch import nn
from torch.optim.lr_scheduler import StepLR
from LSTM import LSTM, CNN_LSTM, Seq2Seq
from config import *
from utils import process
if __name__ == '__main__':
    # Train the model selected by `model_name`, keep the checkpoint with the
    # lowest summed validation loss, and save a plot of the validation-loss
    # curve.
    # NOTE(review): `os`, `torch` and all hyper-parameters / paths used below
    # are presumably provided by `from config import *` -- confirm.
    print(f'{r_name} training...')
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(f'{result_path}/train', exist_ok=True)

    # Build the training and validation loaders with identical settings.
    Dtr = process(train_data, batch_size, True, interval, pred_size, output_n)
    DVa = process(val_data, batch_size, True, interval, pred_size, output_n)

    if model_name == 'LSTM':
        model = LSTM(input_size, hidden_size, num_layers, pred_size, batch_size, device)
    elif model_name == 'CNN_LSTM':
        model = CNN_LSTM(input_size, hidden_size, num_layers, pred_size)
    elif model_name == 'Seq2Seq':
        model = Seq2Seq(input_size, hidden_size, num_layers, pred_size, batch_size, device)
    else:
        # Fail with a clear message instead of a NameError on `model` below.
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            "expected 'LSTM', 'CNN_LSTM' or 'Seq2Seq'"
        )
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
    loss_fn = nn.MSELoss()
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

    # Built-in infinity: `np.Inf` was removed in NumPy 2.0.
    min_val_loss = float('inf')
    m_epoch = None  # index of the best epoch; stays None if none improves
    loss_list = []  # summed validation loss, one entry per epoch

    for epoch in tqdm(range(max_epochs)):
        # ---- training pass ----
        model.train()
        for seq, label in Dtr:
            seq = seq.to(device)
            label = label.to(device)
            y_pred = model(seq)
            loss = loss_fn(y_pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Advance the LR schedule once per epoch, after the optimizer steps.
        # Without this call the StepLR decay never takes effect.
        scheduler.step()

        # ---- validation pass ----
        model.eval()
        total_val_loss = 0
        with torch.no_grad():  # no gradient tracking while validating
            for seq, label in DVa:
                seq = seq.to(device)
                label = label.to(device)
                outputs = model(seq)
                loss = loss_fn(outputs, label)
                total_val_loss = total_val_loss + loss.item()
        loss_list.append(total_val_loss)

        if total_val_loss < min_val_loss:
            min_val_loss = total_val_loss
            m_epoch = epoch
            # Save the best model seen so far (whole module, not state_dict).
            torch.save(model, f"{model_path}/model-{r_name}.pth")

    if m_epoch is None:
        # Happens when max_epochs == 0 or every validation loss was NaN/inf;
        # there is no checkpoint and nothing meaningful to plot.
        raise RuntimeError(
            'No epoch produced a finite validation loss; no model was saved'
        )

    print()
    print(f'本次训练损失最小的epoch为{m_epoch},最小损失为{min_val_loss}')

    # ---- plot the validation-loss curve and mark the best epoch ----
    figure(figsize=(12.8, 9.6))
    plt.rcParams['font.sans-serif'] = ['SimHei']  # font used for the CJK legend label
    plt.rcParams['axes.unicode_minus'] = False
    plt.plot(loss_list, color='red', label='损失曲线')
    plt.scatter(m_epoch, min_val_loss, color='blue', s=50)
    # Value label is placed at half the minimum loss, i.e. below the marker.
    plt.text(m_epoch, min_val_loss - min_val_loss * 0.5, '%.6f' % min_val_loss,
             ha='center', va='bottom', size=20)
    plt.title(f'LOSS-{r_name}', fontsize=20)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.legend(fontsize=20)
    plt.ylim((0, max(loss_list)))
    plt.savefig(f'{result_path}/train/LOSS-{r_name}.png')
Python
1
https://gitee.com/guox66/wind_prediction.git
git@gitee.com:guox66/wind_prediction.git
guox66
wind_prediction
wind_prediction
main

搜索帮助