-
Notifications
You must be signed in to change notification settings - Fork 233
[Example] Add STAFNet Model for Air Quality Prediction #1070
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 11 commits
f79a3f9
68b23d1
d9d2b54
bfa3e69
2d9dc85
57dc7c2
fa1cdee
ab1ae03
d257a49
b43c7f5
757477a
2b46497
711cd36
a79ad1d
af96434
86a9c0b
e2d8b60
66eeb9b
792796a
db2f093
862cbfb
726a026
5ccc908
0319c71
ff226a7
a5cda23
79e5e6b
d7e6c8a
cc71ffc
56fe67e
296e58c
99b1d1b
b6f55a4
3332954
f5557b9
878cc52
c2e88bd
bc748cd
c7d4426
a7b2eb1
c025ab2
ed22102
8d2433d
3a87833
88c2957
ec92f69
454729f
b1b4ad3
9c658dc
9195278
3e35a0d
803eb28
8599724
59cd4e0
84f91bf
5039b14
d38fbaa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,85 @@ | ||||||||||||||||||||
| defaults: | ||||||||||||||||||||
| - ppsci_default | ||||||||||||||||||||
| - TRAIN: train_default | ||||||||||||||||||||
| - TRAIN/ema: ema_default | ||||||||||||||||||||
| - TRAIN/swa: swa_default | ||||||||||||||||||||
| - EVAL: eval_default | ||||||||||||||||||||
| - INFER: infer_default | ||||||||||||||||||||
| - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default | ||||||||||||||||||||
| - _self_ | ||||||||||||||||||||
| hydra: | ||||||||||||||||||||
|
||||||||||||||||||||
| defaults: | |
| - ppsci_default | |
| - TRAIN: train_default | |
| - TRAIN/ema: ema_default | |
| - TRAIN/swa: swa_default | |
| - EVAL: eval_default | |
| - INFER: infer_default | |
| - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default | |
| - _self_ |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 这里的路径是否能改成相对路径?比如
./dataset/train_data.pkl,其余的路径字段也是,建议改为相对路径,并去掉用户名 - STAFNet_DATA_PATH是否应该放到DATASET字段下?
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- data_dir为什么是具体文件路径而不是某个文件夹路径?
- 此处的路径是否跟STAFNet_DATA_PATH重复了?
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| input_keys: ["aq_train_data","mete_train_data",] | |
| input_keys: [aq_train_data, mete_train_data] |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| output_keys: ["label"] | |
| output_keys: [label] |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| eval_data_path: "/data6/home/yinhang2021/dataset/chongqing_1921/val_data.pkl" | |
| eval_data_path: ./dataset/val_data.pkl |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,153 @@ | ||||||||||||||||||||||||||||
| import ppsci | ||||||||||||||||||||||||||||
| from ppsci.utils import logger | ||||||||||||||||||||||||||||
| from omegaconf import DictConfig | ||||||||||||||||||||||||||||
| import hydra | ||||||||||||||||||||||||||||
| import paddle | ||||||||||||||||||||||||||||
| from ppsci.data.dataset.stafnet_dataset import gat_lstmcollate_fn | ||||||||||||||||||||||||||||
| import multiprocessing | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def train(cfg: DictConfig): | ||||||||||||||||||||||||||||
| # set model | ||||||||||||||||||||||||||||
| model = ppsci.arch.STAFNet(**cfg.MODEL) | ||||||||||||||||||||||||||||
| train_dataloader_cfg = { | ||||||||||||||||||||||||||||
| "dataset": { | ||||||||||||||||||||||||||||
| "name": "STAFNetDataset", | ||||||||||||||||||||||||||||
| "file_path": cfg.DATASET.data_dir, | ||||||||||||||||||||||||||||
| "input_keys": cfg.MODEL.input_keys, | ||||||||||||||||||||||||||||
| "label_keys": cfg.MODEL.output_keys, | ||||||||||||||||||||||||||||
| "seq_len": cfg.MODEL.seq_len, | ||||||||||||||||||||||||||||
| "pred_len": cfg.MODEL.pred_len, | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里应该是EVAL?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个"sampler"字段是否可以删掉?eval应该不需要shuffle
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在paddle里,如果你的学习率是lr_scheduler,那么就需要把optimizer的learning_rate设置为lr_scheduler,而不是初始学习率
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| output_dir = cfg.output_dir |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| output_dir, | |
| cfg.output_dir, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| seed=cfg.seed, |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删除注释
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| """ | |
| Validate after training an epoch | |
| :param epoch: Integer, current training epoch. | |
| :return: A log that contains information about validation | |
| """ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| "sampler": { | |
| "name": "BatchSampler", | |
| "drop_last": False, | |
| "shuffle": True, | |
| }, |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个可以删除,output_dir会由ppsci.utils.callbacks.InitCallback自动创建:
PaddleScience/ppsci/utils/callbacks.py
Lines 90 to 96 in fad6927
| logger.init_logger( | |
| "ppsci", | |
| osp.join(full_cfg.output_dir, f"{full_cfg.mode}.log") | |
| if full_cfg.output_dir and full_cfg.mode not in ["export", "infer"] | |
| else None, | |
| full_cfg.log_level, | |
| ) |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # set random seed for reproducibility | |
| ppsci.utils.misc.set_random_seed(42) | |
| # set output directory | |
| OUTPUT_DIR = "./output_example" | |
| # initialize logger | |
| logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这句代码是什么作用?paddle的多卡训练不需要这样吧?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我这边如果不加 multiprocessing.set_start_method("spawn"),会出现cuda error(3)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件建议使用vscode的yaml插件格式化一下,或者提交前用pre-commit格式化:https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/development/#1