How to fix a non-converging DQN in PyTorch
I'm new to deep reinforcement learning. I implemented the algorithm myself, but the values never converge. Can anyone take a look and tell me what's wrong with my implementation and what I could do better? Here is the code:
import gym
import torch
import numpy as np
import random
from collections import deque
from itertools import count
class ReplayBuffer:
    def __init__(self):
        self.buffer = deque(maxlen=50000)

    def push(self, state, action, reward, next_state, done):
        # the deque's maxlen already caps the buffer, so no extra size guard is needed
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int, continuous: bool = True):
        if batch_size > len(self.buffer):
            batch_size = len(self.buffer)
        if continuous:
            # contiguous slice: consecutive (and therefore correlated) transitions
            rand = random.randint(0, len(self.buffer) - batch_size)
            return [self.buffer[i] for i in range(rand, rand + batch_size)]
        else:
            # uniform random sample without replacement
            indexes = np.random.choice(np.arange(len(self.buffer)), size=batch_size, replace=False)
            return [self.buffer[i] for i in indexes]
class NNetwork(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(4, 128)    # CartPole observations have 4 features
        self.l2 = torch.nn.Linear(128, 128)
        self.l3 = torch.nn.Linear(128, 2)    # one Q-value per action
        self.optimizer = torch.optim.Adam(params=self.parameters(), lr=0.001)
        self.criterion = torch.nn.MSELoss()

    def forward(self, x):
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return self.l3(x)
class Agent:
    def __init__(self):
        self.env = gym.make('CartPole-v0')
        self.mem = ReplayBuffer()
        self.q_local = NNetwork()
        self.q_target = NNetwork()
        self.q_target.load_state_dict(self.q_local.state_dict())
        self.epsilon = 1.0
        self.e_decay = 0.0995
        self.e_min = 0.1
        self.update = 4
        self.score = 0
        self.gamma = 0.99

    def predict(self, state):
        # epsilon-greedy: np.random.rand() draws uniformly from [0, 1)
        # (np.random.randn() would draw from a normal and break the schedule)
        if np.random.rand() < self.epsilon:
            return random.randint(0, 1)
        else:
            q_values = self.q_local(torch.Tensor(state).unsqueeze(0))
            return torch.argmax(q_values, dim=1).item()
    def step(self):
        state = self.env.reset()
        done = False
        i = 0
        while not done:
            action = self.predict(state)
            # gym returns (observation, reward, done, info)
            n_state, reward, done, _ = self.env.step(action)
            self.mem.push(state, action, reward, n_state, done)
            self.score += reward
            self.learn()
            state = n_state
            i += 1
            if i % 10 == 0:
                # decay epsilon and sync the target network every 10 steps
                self.epsilon = max(self.epsilon - self.e_decay, self.e_min)
                self.q_target.load_state_dict(self.q_local.state_dict())
        print(self.score)
        self.score = 0
    def learn(self):
        # wait until the buffer holds at least one full batch
        if len(self.mem.buffer) < 32:
            return
        batch = self.mem.sample(32)
        state, action, reward, n_state, done = zip(*batch)
        state = torch.Tensor(state)
        action = torch.Tensor(action).unsqueeze(1)
        n_state = torch.Tensor(n_state)
        reward = torch.Tensor(reward).unsqueeze(1)
        done = torch.Tensor(done).unsqueeze(1)
        self.q_local.optimizer.zero_grad()
        # Q(s, a) for the actions actually taken
        q_N = self.q_local(state).gather(1, action.long())
        # bootstrapped target from the target network; detach so no gradients flow into it
        q_t = self.q_target(n_state).detach()
        y = reward + (1 - done) * self.gamma * torch.max(q_t, dim=1, keepdim=True)[0]
        loss = self.q_local.criterion(q_N, y)
        loss.backward()
        self.q_local.optimizer.step()
agent = Agent()
for t in count():
    print("EP ", t)
    agent.step()
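For reference, the target built in learn() is meant to implement the standard DQN update; writing it out as a sanity check:

$$y = r + (1 - \mathrm{done})\,\gamma \max_{a'} Q_{\text{target}}(s', a'), \qquad \mathcal{L} = \big(Q_{\text{local}}(s, a) - y\big)^2$$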
Well, it does manage to score a few points, but it never converges.
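Two things I'm suspicious about but haven't been able to rule out. First, sample defaults to continuous=True, so learn() trains on a contiguous slice of consecutive, highly correlated transitions, while every DQN description I've read samples the buffer uniformly at random. A minimal sketch of what I mean, reusing my own ReplayBuffer (the transitions are dummies, purely for illustration):

buf = ReplayBuffer()
for _ in range(100):
    # dummy CartPole-shaped transition, just to fill the buffer
    buf.push([0.0, 0.0, 0.0, 0.0], 0, 1.0, [0.0, 0.0, 0.0, 0.0], False)
batch = buf.sample(32, continuous=False)  # uniform random, breaks temporal correlation

Second, the exploration schedule: e_decay = 0.0995 is subtracted every 10 steps, so epsilon falls from 1.0 to the 0.1 floor after about ten decrements, i.e. only around 100 environment steps. Maybe that is too fast?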