微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

有没有办法自动匹配多个相同的参数?

如何解决有没有办法自动匹配多个相同的参数?

我的模型中有多个深度神经网络,并希望它们具有相同的输入大小 (网络属于不同的类别)。比如我的模型是:

class Model:
 def __init__(self,cfg: DictConfig):
   self.net1 = Net1(**cfg.net1_hparams)
   self.net2 = Net2(**cfg.net2_hparams)

这里,Net1 和 Net2 有不同的超参数集,但其中的 input_size 参数在 Net1 和 Net2 之间共享,并且必须匹配,即 cfg.net1_hparams.input_size == cfg.net2_hparams.input_size

我可以在父级定义 input_size:cfg.input_size 并手动将它们传递给 Net1 和 Net2。但是,我希望每个 Net 的 hparams-configs 完整,以便以后我只能使用 cfg.net1_hparams 构建 Net1。

在 hydra 中是否有实现此目的的好方法

解决方法

这可以使用 OmegaConf 的 variable interpolation 功能来实现。

以下是使用 Hydra 进行变量插值以实现所需结果的最小示例:

# config.yaml
shared_hparams:
  input_size: [128,128]
net1_hparams:
  name: net one
  input_size: ${shared_hparams.input_size}
net2_hparams:
  name: net two
  input_size: ${shared_hparams.input_size}
"""my_app.py"""
import hydra
from omegaconf import DictConfig

class Model:
    def __init__(self,cfg: DictConfig):
        print("Net1",dict(**cfg.net1_hparams))
        print("Net2",dict(**cfg.net2_hparams))

@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    Model(cfg)

if __name__ == "__main__":
    my_app()

在命令行运行 my_app.py 会产生以下结果:

$ python my_app.py
Net1 {'name': 'net one','input_size': [128,128]}
Net2 {'name': 'net two',128]}

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。