Actor-Critic算法简介
:Actor基于当前策略选择动作并执行,Critic则根据环境反馈评估该动作的好坏,生成优势函数来指导Actor的策略更新。用Python实现一个简单的Actor-Critic模型来训练智能体解决CartPole平衡问题。:广泛应用于机器人控制、游戏AI、能源管理等领域,特别适合动作空间复杂、需要精细控制的场景。Actor-Critic算法是一种结合了策略梯度和价值函数优点的强化学习方法。
·
Actor-Critic算法是一种结合了策略梯度和价值函数优点的强化学习方法。
核心思想:算法包含两个部分协同工作:
- Actor(演员):负责执行策略,根据当前状态选择动作
- Critic(评论家):负责评估价值,对Actor选择的动作进行评分
工作流程:Actor基于当前策略选择动作并执行,Critic则根据环境反馈评估该动作的好坏,生成优势函数来指导Actor的策略更新。
主要优势:
- 高效学习:相比纯策略梯度方法,能够实现单步更新而非回合更新
- 低方差:使用Critic的价值估计减少了策略梯度的方差
- 处理连续动作:适用于连续动作空间的问题
应用场景:广泛应用于机器人控制、游戏AI、能源管理等领域,特别适合动作空间复杂、需要精细控制的场景。
以下是用Python实现一个简单的Actor-Critic模型来训练智能体解决CartPole平衡问题。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym
from collections import deque
import random
class ActorNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(ActorNetwork, self).__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, action_dim),
nn.Softmax(dim=-1)
)
def forward(self, state):
return self.network(state)
class CriticNetwork(nn.Module):
def __init__(self, state_dim):
super(CriticNetwork, self).__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 1)
)
def forward(self, state):
return self.network(state)
class ActorCritic:
def __init__(self, state_dim, action_dim, learning_rate=0.001, gamma=0.99):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.gamma = gamma
# 初始化Actor和Critic网络
self.actor = ActorNetwork(state_dim, action_dim).to(self.device)
self.critic = CriticNetwork(state_dim).to(self.device)
# 初始化优化器
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate)
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=learning_rate)
# 经验回放缓冲区
self.memory = deque(maxlen=10000)
def select_action(self, state):
state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
action_probs = self.actor(state)#通过Actor网络计算各动作的概率分布
action_dist = torch.distributions.Categorical(action_probs)
action = action_dist.sample()#基于概率分布进行随机采样确定最终动作
return action.item(), action_dist.log_prob(action)#返回选择的动作及其对数概率用于后续梯度计算
def store_experience(self, state, action, reward, next_state, done, log_prob):
self.memory.append((state, action, reward, next_state, done, log_prob))
def update(self):
if len(self.memory) < 32: # 小批量更新
return
# 随机采样经验
batch = random.sample(self.memory, 32)
states, actions, rewards, next_states, dones, old_log_probs = zip(*batch)
states = torch.FloatTensor(np.array(states)).to(self.device)
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
rewards = torch.FloatTensor(rewards).to(self.device)
dones = torch.BoolTensor(dones).to(self.device)
actions = torch.LongTensor(actions).to(self.device)
old_log_probs = torch.stack(old_log_probs).to(self.device)
# 计算目标值和优势函数
with torch.no_grad():
next_values = self.critic(next_states).squeeze()
target_values = rewards + (1 - dones.float()) * self.gamma * next_values
current_values = self.critic(states).squeeze()
advantages = target_values - current_values
# 更新Actor网络,首先通过Actor网络获取当前状态下各动作的概率分布,
#然后创建分类分布对象来表征这个策略。
#接着计算在给定状态下采取实际执行动作的对数概率,
#这个对数概率值将用于后续的策略梯度计算。
action_probs = self.actor(states)
dist = torch.distributions.Categorical(action_probs)
log_probs = dist.log_prob(actions)
# 策略梯度损失
"""
实现了Actor-Critic算法中Actor网络的策略梯度损失计算:
功能说明:
通过负对数概率与优势函数的乘积计算策略梯度
使用均值操作获得批量样本的平均损失
通过detach()确保优势值不参与Actor网络的梯度计算
核心作用:
当优势值为正时,减小负对数概率,增加该动作的选择概率
当优势值为负时,增大负对数概率,减少该动作的选择概率
实现策略改进,使智能体更倾向于选择高回报动作
优化策略网络的参数,提升决策质量
该损失函数是强化学习中策略优化方法的核心,通过优势函数指导策略网络的梯度更新方向。
"""
actor_loss = -(log_probs * advantages.detach()).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# 更新Critic网络
critic_loss = nn.MSELoss()(self.critic(states).squeeze(), target_values)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
def train(self, env_name="CartPole-v1", episodes=1000, max_steps=500):
env = gym.make(env_name)
rewards_history = []
for episode in range(episodes):
state, _ = env.reset()
episode_reward = 0
episode_log_probs = []
for step in range(max_steps):
action, log_prob = self.select_action(state)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
# 存储经验
self.store_experience(state, action, reward, next_state, done, log_prob)
state = next_state
episode_reward += reward
episode_log_probs.append(log_prob)
if done:
break
# 每回合结束后更新网络
self.update()
rewards_history.append(episode_reward)
# 打印训练进度
if (episode + 1) % 50 == 0:
avg_reward = np.mean(rewards_history[-50:])
print(f"Episode {episode + 1}, Average Reward: {avg_reward:.2f}")
# 如果连续100轮平均奖励达到195,认为问题已解决
if len(rewards_history) >= 100 and np.mean(rewards_history[-100:]) >= 195:
print(f"Solved at episode {episode + 1}!")
break
env.close()
return rewards_history
def save_model(self, actor_path="actor_model.pth", critic_path="critic_model.pth"):
torch.save(self.actor.state_dict(), actor_path)
torch.save(self.critic.state_dict(), critic_path)
print("Models saved successfully!")
def load_model(self, actor_path="actor_model.pth", critic_path="critic_model.pth"):
self.actor.load_state_dict(torch.load(actor_path, map_location=self.device))
self.critic.load_state_dict(torch.load(critic_path, map_location=self.device))
print("Models loaded successfully!")
def main():
print("开始训练Actor-Critic模型...")
print("环境: CartPole-v1")
print("设备:", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
# 创建并训练模型
agent = ActorCritic(state_dim=4, action_dim=2)
rewards = agent.train(episodes=1000)
# 保存训练好的模型
agent.save_model()
print("训练完成!")
print(f"最终100轮平均奖励: {np.mean(rewards[-100:]):.2f}")
if __name__ == "__main__":
main()
"""
torch==2.1.0
numpy==1.24.0
该Actor-Critic模型实现包含以下核心功能:
- 构建了Actor策略网络和Critic价值网络的双神经网络架构
- 实现了基于策略梯度的Actor网络更新机制,使用优势函数指导策略改进
- 采用TD误差方法更新Critic网络,优化状态价值估计
- 包含经验回放缓冲区,提高样本利用效率
- 完整的训练循环,支持模型保存和加载功能
- 专门针对CartPole平衡问题进行优化,具备问题解决检测机制
模型特点包括模块化设计、小批量训练、完整的性能监控和跨平台兼容性,适用于强化学习的入门学习和实验验证
"""。
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐

所有评论(0)