
Running StableBaselines3 with a Dict observation space

How to run StableBaselines3 with a Dict observation space

I have a custom environment that returns a Dict observation, like this:

OrderedDict([('achieved_goal', array([ 0.4008276, -0.0685866, -0.22774519, 0.05827878, 0.47759697, 0.7327185, 2.4765387, -0.8607227, 0.89627784, -0.3062557, -0.60894597, -1.4110374 ], dtype=float32)),
             ('desired_goal', array([-1.005679, 0.34147817, 0.9540531, 1.1987132, 0.37403303, 0.32209057, 0.31095287, -2.1119647, 0.82215786, -0.6675792, -1.5640837, 0.7348459 ], dtype=float32)),
             ('observation', array([-0.39490733, -0.67843455, -0.43765455, 0.1409685, -0.67161006, 1.3106273, 0.04009145, -1.714885, -1.7085567, -0.44895488, -0.6111999, -1.9730839, 0.93647414, 0.2714189, -0.67204314, 0.8948596, -0.14034131, 1.0312599, -1.2369561, -0.2345652, -0.17095046, 0.36576194, 0.9939435, -1.0381949, -1.2953175, 1.4120669, -0.23294891, 0.30627772, -1.2250876, -0.35871807, 1.3074456, -1.060916, -2.451866, 0.18679707, 0.609564, -0.16821782, -0.8448521, -1.0025802, 0.6878543, -2.1562986, 0.6426088, 1.386251, 1.0454125, -2.2426984 ], dtype=float32))])
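The observation above holds three float32 vectors: `achieved_goal` and `desired_goal` with 12 values each, and `observation` with 44. One common way to feed such a dict to a policy that expects a single flat vector is to concatenate the entries in a fixed key order; gym's `FlattenObservation` wrapper performs a similar concatenation for Dict spaces. The sketch below is pure Python so it runs without gym; `flatten_obs` and the key order in `KEYS` are assumptions chosen for illustration, not part of the original environment.

```python
from collections import OrderedDict

# Assumed fixed key order, so every observation is flattened identically.
KEYS = ("observation", "achieved_goal", "desired_goal")


def flatten_obs(obs_dict, keys=KEYS):
    """Concatenate the dict entries into one flat list of floats (hypothetical helper)."""
    flat = []
    for key in keys:
        flat.extend(float(v) for v in obs_dict[key])
    return flat


# Dummy values with the same lengths as the real observation (12, 12, 44).
obs = OrderedDict([
    ("achieved_goal", [0.0] * 12),
    ("desired_goal", [1.0] * 12),
    ("observation", [2.0] * 44),
])
flat = flatten_obs(obs)
print(len(flat))  # 68
```

Whatever order is chosen, it has to stay the same for every call, otherwise the policy sees the same input positions carrying different quantities.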

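However, algorithms such as PPO cannot work with a Dict space. When I try to filter the observation space down to the 'observation' entry, I get the following error.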

How I filter it:

env.observation_space = env.observation_space['observation']
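Assigning to `env.observation_space` only changes the space the environment declares; `reset()` and `step()` keep returning the full dict. A filter therefore has to change both the declared space and the returned observations. The minimal sketch below uses hypothetical pure-Python stand-ins (`ToyGoalEnv`, `KeyObservationWrapper`), not gym or SB3 classes, with spaces written as plain shape tuples for illustration.

```python
class ToyGoalEnv:
    """Hypothetical stand-in for the custom env: reset()/step() return a dict."""

    def __init__(self):
        # Declared space, written as key -> shape for illustration only.
        self.observation_space = {
            "achieved_goal": (12,),
            "desired_goal": (12,),
            "observation": (44,),
        }

    def reset(self):
        return {
            "achieved_goal": [0.0] * 12,
            "desired_goal": [0.0] * 12,
            "observation": [0.0] * 44,
        }

    def step(self, action):
        obs = self.reset()  # toy dynamics: always the same observation
        return obs, 0.0, False, {}


class KeyObservationWrapper:
    """Changes BOTH the declared space and the observations actually returned."""

    def __init__(self, env, key="observation"):
        self.env = env
        self.key = key
        self.observation_space = env.observation_space[key]

    def observation(self, obs):
        # Keep only the selected entry of the dict.
        return obs[self.key]

    def reset(self):
        return self.observation(self.env.reset())

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self.observation(obs), reward, done, info


# Reassigning the attribute alone: the declared space changes, the data does not.
env = ToyGoalEnv()
env.observation_space = env.observation_space["observation"]
print(type(env.reset()).__name__)  # dict

# Wrapping: both the declared space and the returned observation are filtered.
wrapped = KeyObservationWrapper(ToyGoalEnv())
print(wrapped.observation_space, len(wrapped.reset()))  # (44,) 44
```

With real gym, the equivalent would be a subclass of `gym.ObservationWrapper` that sets `self.observation_space = env.observation_space["observation"]` in `__init__` and returns `obs["observation"]` from its `observation()` method. Separately, SB3 releases from 1.1.0 onward document Dict observation support through `MultiInputPolicy`, which may make the filter unnecessary if upgrading is possible.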

Error traceback:

Traceback (most recent call last):
  File "PPO.py", line 69, in <module>
    model.learn(total_timesteps=25000)
  File "/home/yb1025/.conda/envs/allegro_gym/lib/python3.6/site-packages/stable_baselines3/ppo/ppo.py", line 289, in learn
    reset_num_timesteps=reset_num_timesteps,
  File "/home/yb1025/.conda/envs/allegro_gym/lib/python3.6/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 220, in learn
    total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
  File "/home/yb1025/.conda/envs/allegro_gym/lib/python3.6/site-packages/stable_baselines3/common/base_class.py", line 379, in _setup_learn
    self._last_obs = self.env.reset()
  File "/home/yb1025/.conda/envs/allegro_gym/lib/python3.6/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 62, in reset
    self._save_obs(env_idx, obs)
  File "/home/yb1025/.conda/envs/allegro_gym/lib/python3.6/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 92, in _save_obs
    self.buf_obs[key][env_idx] = obs
TypeError: float() argument must be a string or a number, not 'dict'
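The last two frames show where it breaks. As far as I understand `DummyVecEnv`, it preallocates its observation buffer from the declared `observation_space` (after the reassignment, the 44-dim entry), and for a non-Dict space `_save_obs` writes whatever `reset()` returned into one row of that float array. Since `reset()` still returns the dict, numpy treats the dict as a scalar and calls `float()` on it, which raises this `TypeError`. The pure-Python sketch below mimics that behavior; `make_buffer` and `save_obs` are hypothetical helpers, not SB3 code.

```python
def make_buffer(num_envs, shape):
    # One flat row per sub-environment, sized from the declared shape,
    # similar in spirit to np.zeros((num_envs,) + shape) in DummyVecEnv.
    size = 1
    for dim in shape:
        size *= dim
    return [[0.0] * size for _ in range(num_envs)]


def save_obs(buf, env_idx, obs):
    # Mimics `buf[env_idx] = obs` on a float numpy array: a sequence fills the
    # row element by element, while any other object is treated as a scalar
    # and converted with float() before being broadcast across the row.
    row = buf[env_idx]
    if isinstance(obs, (list, tuple)):
        if len(obs) != len(row):
            raise ValueError("observation length does not match the buffer row")
        row[:] = [float(v) for v in obs]
    else:
        row[:] = [float(obs)] * len(row)  # float(dict) raises TypeError here


buf = make_buffer(num_envs=1, shape=(44,))
save_obs(buf, 0, [0.5] * 44)  # a flat 44-value observation fits the row
try:
    save_obs(buf, 0, {"observation": [0.5] * 44})  # the full dict does not
except TypeError as exc:
    print("TypeError:", exc)
```

The exact wording of the message varies between Python versions, but the cause is the same: the buffer was sized for the filtered space while the environment still hands back the unfiltered dict.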
