I am trying to train a TD3 agent to place points in 3D space.
However, I am currently not even able to get the model to overfit on a small number of data points.
As far as I can tell, part of the issue is that the per-step rewards within an episode (measured as the change in MSE from the previous position) are mostly negative and keep getting more negative, so the critic ends up simply predicting negative Q-values for everything because the positive rewards are so sparse.
Does anyone have any advice?
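To make the sparsity concrete, this is roughly how I am checking the per-step reward distribution (a minimal sketch; episode_rewards is the list of step rewards collected in the training loop below, and the helper name is just for illustration):

import numpy as np

def summarize_rewards(episode_rewards):
    """Quick diagnostic: how sparse are the positive rewards?"""
    r = np.asarray(episode_rewards, dtype=float)
    print(f"steps: {r.size}, positive: {(r > 0).mean():.1%}, "
          f"mean: {r.mean():.3f}, min: {r.min():.3f}, max: {r.max():.3f}")

# e.g. summarize_rewards(episode_rewards) at the end of a training iteration
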
Below, I have provided my code for my custom environment and training.
import copy

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from scipy import spatial
from tqdm import tqdm

# device and helpers such as build_knn_graph, batch_graphs, polyak_update, and
# cov_prec_at_threshold are defined elsewhere in my script.

class PDB_Env(gym.Env):
"""A custom Gymnasium environment for protein-water refinement.
This environment now takes pre-computed density maps and XMap objects,
simplifying its role to managing the state and reward calculation during
the simulation.
"""
def __init__(self, water_data_pos, protein_data_pos, water_data_ele, protein_data_ele, ground_truth_maps,
n_repeats=3, noise=1.5, termination_mse=1, termination_reward=100,
                 query_radius=0.5, proximity_threshold=1,
                 mse_weight=100.0, density_weight=1):
super().__init__()
self._original_water_data_pos = water_data_pos
self._original_protein_atoms_pos = protein_data_pos
self._original_ground_truth_maps = ground_truth_maps
self._original_water_ele = water_data_ele
self._original_protein_ele = protein_data_ele
self.n_proteins = len(self._original_water_data_pos)
self.n_repeats = n_repeats
self.noise = noise
self.termination_reward = termination_reward
self.termination_mse = termination_mse
self.query_radius = query_radius
self.mse_weight = mse_weight
self.density_weight = density_weight
self.clash_penalty = 5
self.proximity_threshold = proximity_threshold
self._create_protein_sequence()
self.protein_sequence_idx = -1
initial_n_atoms = self._original_water_data_pos[0].shape[0] if self.n_proteins > 0 else 0
self.observation_space = gym.spaces.Dict({
"agent": gym.spaces.Box(shape=(initial_n_atoms, 3), low=-np.inf, high=np.inf, dtype=np.float32),
"target": gym.spaces.Box(shape=(initial_n_atoms, 3), low=-np.inf, high=np.inf, dtype=np.float32),
})
def _get_obs(self):
"""Helper function to get the current observation dictionary."""
return {
"agent_info_pos": self._agent_info_pos.clone(),
"target_info_pos": self._target_info_pos.clone(),
"protein_info": self._curr_protein_atoms_pos.clone(),
"target_density": self._curr_target_density.clone(),
"water_ele": copy.deepcopy(self._curr_water_ele),
"protein_ele": copy.deepcopy(self._curr_protein_ele)
}
def _get_info(self):
"""Helper function to get the current information dictionary."""
mse = torch.mean((self._agent_info_pos - self._target_info_pos) ** 2).item()
return {"distance": mse}
def _create_protein_sequence(self):
"""Creates the sequence of protein indices for training repeats."""
protein_indices = list(range(self.n_proteins))
self.protein_sequence = protein_indices * self.n_repeats
self.total_sequence_length = len(self.protein_sequence)
def _load_protein(self, protein_idx):
"""Loads a specific protein and its pre-computed density map."""
self.current_protein_idx = protein_idx
self._target_info_pos = self._original_water_data_pos[protein_idx].clone()
self._agent_info_pos = self._original_water_data_pos[protein_idx].clone()
self._curr_protein_atoms_pos = self._original_protein_atoms_pos[protein_idx].clone()
self._curr_target_density = self._original_ground_truth_maps[protein_idx]
self._curr_water_ele = self._original_water_ele[protein_idx]
self._curr_protein_ele = self._original_protein_ele[protein_idx]
self.n_atoms = self._original_water_data_pos[protein_idx].shape[0]
obs_shape = (self.n_atoms, 3)
self.observation_space = gym.spaces.Dict({
"agent": gym.spaces.Box(shape=obs_shape, low=-np.inf, high=np.inf, dtype=np.float32),
"target": gym.spaces.Box(shape=obs_shape, low=-np.inf, high=np.inf, dtype=np.float32),
})
# Build KD-tree from target density coordinates (first 3 columns) - convert to CPU numpy for scipy
density_coords_cpu = self._curr_target_density[:, :3].cpu().numpy()
self.kd_tree = spatial.KDTree(density_coords_cpu)
def reset(self, seed=42, options=None):
"""Resets the environment to a new protein and adds noise to water positions."""
super().reset(seed=seed)
self.protein_sequence_idx = (self.protein_sequence_idx + 1) % len(self.protein_sequence)
protein_idx = self.protein_sequence[self.protein_sequence_idx]
self._load_protein(protein_idx)
torch.manual_seed(seed)
# Add noise to create starting positions
noise = torch.randn_like(self._target_info_pos, device=device) * self.noise
self._agent_info_pos = self._target_info_pos + noise
density_vals = self._curr_target_density[:, 3]
agent_pos_cpu = self._agent_info_pos.cpu().numpy()
self._density_lookup = torch.zeros(self.n_atoms, device=device)
for atom_idx in range(self.n_atoms):
indices = self.kd_tree.query_ball_point(agent_pos_cpu[atom_idx], self.query_radius)
self._density_lookup[atom_idx] = torch.sum(density_vals[indices]) if len(indices) > 0 else 0.0
obs = self._get_obs()
info = self._get_info()
return obs, info
def select_worst_atom(self):
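        """Return the local index of the water atom with the lowest cached density sum (first index on ties)."""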
min_density = torch.min(self._density_lookup)
worst_atoms = torch.where(self._density_lookup == min_density)[0]
return worst_atoms[0].item()
def step(self, atom, movement):
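        """Move water atom `atom` by `movement`, compute the shaped reward, and return (obs, reward, terminated, truncated, info).

        Note: unlike the standard gym step(action), the action here is split into a
        local atom index (into the water set) and a 3D displacement.
        """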
target_density = self._curr_target_density
mse_before = torch.mean((self._agent_info_pos - self._target_info_pos) ** 2).item()
# Get density from the atom's last known position
density_before = self._density_lookup[atom].item()
# Apply movement
new_coords = self._agent_info_pos.clone()
new_coords[atom] += movement
self._agent_info_pos = new_coords
mse_after = torch.mean((self._agent_info_pos - self._target_info_pos) ** 2).item()
# Query density at the new position - convert to CPU numpy for KDTree
agent_pos_cpu = self._agent_info_pos[atom].detach().cpu().numpy()
indices = self.kd_tree.query_ball_point(agent_pos_cpu, self.query_radius)
density_after = torch.sum(target_density[indices, 3]).item() if len(indices) > 0 else 0.0
# OPTIMIZED: Vectorized proximity check
moved_atom_pos = self._agent_info_pos[atom]
# Check distances to other water atoms
mask = torch.ones(self.n_atoms, dtype=torch.bool, device=device)
mask[atom] = False
# Calculate distances to all other water atoms
other_water_positions = self._agent_info_pos[mask]
water_distances = torch.norm(other_water_positions - moved_atom_pos, dim=1)
water_violations = (water_distances < self.proximity_threshold).sum().item()
# Check distances to protein atoms (vectorized)
protein_distances = torch.norm(self._curr_protein_atoms_pos - moved_atom_pos.unsqueeze(0), dim=1)
protein_violations = (protein_distances < self.proximity_threshold).sum().item()
distance_penalty = mse_after * 0.1
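        # NOTE: water_violations, protein_violations, and distance_penalty are computed above but not currently used in the reward.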
#reward = (mse_before - mse_after) * self.mse_weight + (density_after - density_before) * self.density_weight
#reward = -mse_after
#reward = np.clip(reward, -10.0, 10.0)
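        # Asymmetric shaping: improvements (MSE decrease) are scaled by 2.5x, worsening moves are not.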
if mse_before - mse_after > 0:
reward = (mse_before - mse_after) * 2.5
else:
reward = (mse_before - mse_after)
# Update lookup table with new density for next move
self._density_lookup[atom] = density_after
#print(f"Atom {atom}: dist_before={mse_before:.2f}, dist_after={mse_after:.2f}, reward={reward:.2f}")
terminated = mse_after < self.termination_mse
if terminated:
reward += self.termination_reward
return self._get_obs(), reward, terminated, False, self._get_info()
def step_test(self, atom, movement):
"""Applies an action during testing (no reward calculation)."""
new_coords = self._agent_info_pos.clone()
new_coords[atom] += movement
self._agent_info_pos = new_coords
mse_after = torch.mean((self._agent_info_pos - self._target_info_pos) ** 2).item()
terminated = mse_after < self.termination_mse
obs = self._get_obs()
info = self._get_info()
return obs, 0.0, terminated, False, info
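
# === TRAINING LOOP ===
# (my_env, the actor/critic networks and their target copies, the optimizers, the replay
#  buffer, train_data, and hyperparameters such as n_repeats, max_steps_per_iteration,
#  gamma, tau, policy_noise_std, policy_delay, and max_eval_steps are created earlier
#  in my script.)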
for iteration in tqdm(range(n_repeats), desc="Training Iterations"):
print(f"\n{'='*80}")
print(f"=== Training Iteration {iteration + 1}/{n_repeats} ===")
print(f"{'='*80}")
obs_dict, info = my_env.reset()
actor_losses, critic_losses, episode_rewards = [], [], []
# === TRAINING PHASE ===
print(f"\n\tTRAINING PHASE")
for step in tqdm(range(max_steps_per_iteration), desc="\tTraining steps", unit="step"):
# Build full system for the network
full_protein_elements = np.concatenate((obs_dict['protein_ele'], obs_dict['water_ele']))
full_protein_coords = torch.vstack((obs_dict["protein_info"], obs_dict["agent_info_pos"]))
agent_edge_index, agent_edge_weight = build_knn_graph(full_protein_coords)
# 1. Get the LOCAL water index from the environment (e.g., 3)
atom_to_move_local = my_env.select_worst_atom()
# 2. IMMEDIATELY convert it to the GLOBAL index for the network
num_protein_atoms = len(obs_dict["protein_info"])
atom_to_move_global = atom_to_move_local + num_protein_atoms
# 3. Use the GLOBAL index to get movement from the actor network
with torch.no_grad():
movement = actor_network(
full_protein_coords,
full_protein_elements,
agent_edge_index,
                atom_to_move_global
).squeeze(0)
# Add exploration noise during training
exploration_noise = torch.randn_like(movement) * 0.01
movement = movement + exploration_noise
# 4. The environment step still uses the LOCAL index
next_obs_dict, reward, terminated, truncated, info = my_env.step(atom_to_move_local, movement)
episode_rewards.append(reward)
# Also get and correct the index for the NEXT state
next_atom_to_move_local = my_env.select_worst_atom()
next_num_protein_atoms = len(next_obs_dict["protein_info"])
next_atom_to_move_global = next_atom_to_move_local + next_num_protein_atoms
# Build next state's full system
full_protein_elements_next = np.concatenate((next_obs_dict['protein_ele'], next_obs_dict['water_ele']))
full_protein_coords_next = torch.vstack((next_obs_dict["protein_info"], next_obs_dict["agent_info_pos"]))
next_agent_edge_index, next_agent_edge_weight = build_knn_graph(full_protein_coords_next)
        # Oversample positive-reward transitions (moving closer) to counteract their sparsity
        n_copies = 100 if reward > 0 else 1
        for _ in range(n_copies):
            my_buffer.add(
                obs_dict,
                movement.detach(),
                reward,
                terminated,
                next_obs_dict,
                agent_edge_index,
                agent_edge_weight,
                next_agent_edge_index,
                next_agent_edge_weight,
                atom_to_move_global,
                next_atom_to_move_global
            )
# Train if we have enough samples in the buffer
if len(my_buffer) >= 10:
batch = my_buffer.sample(batch_size=64)
# Batch the current state graphs together
batched_pos, batched_elem, batched_edge_idx, batched_edge_wt, batched_atom_idx, batch_idx = batch_graphs(
batch.full_coords,
batch.full_elements,
batch.edge_index,
batch.edge_weight,
batch.atom_moved
)
# Batch the next state graphs together
next_batched_pos, next_batched_elem, next_batched_edge_idx, next_batched_edge_wt, next_batched_atom_idx, next_batch_idx = batch_graphs(
batch.next_full_coords,
batch.next_full_elements,
batch.next_edge_index,
batch.next_edge_weight,
batch.next_atom_moved
)
# === CRITIC UPDATE ===
with torch.no_grad():
next_actions = actor_target_network(
next_batched_pos,
next_batched_elem,
next_batched_edge_idx,
next_batched_atom_idx,
next_batch_idx
)
# Target policy smoothing
noise = torch.randn_like(next_actions) * policy_noise_std
noise = torch.clamp(noise, -0.5, 0.5)
next_actions = next_actions + noise
next_q1, next_q2 = critic_target_network(
next_batched_pos,
next_batched_elem,
next_batched_edge_idx,
next_actions,
next_batched_atom_idx,
next_batch_idx
)
next_q_value = torch.min(next_q1, next_q2)
target_q_values = batch.rewards + (1 - batch.terminated.float()) * gamma * next_q_value
current_q1, current_q2 = critic_network(
batched_pos,
batched_elem,
batched_edge_idx,
batch.actions,
batched_atom_idx,
batch_idx
)
critic_loss_1 = F.mse_loss(current_q1, target_q_values)
critic_loss_2 = F.mse_loss(current_q2, target_q_values)
critic_loss = critic_loss_1 + critic_loss_2
critic_losses.append(critic_loss.item())
critic_optimizer.zero_grad()
critic_loss.backward()
torch.nn.utils.clip_grad_norm_(critic_network.parameters(), 1.0)
critic_optimizer.step()
if step % policy_delay == 0:
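                # Delayed policy update (TD3): refresh the actor and soft-update both target networks every policy_delay training steps.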
predict_actions = actor_network(
batched_pos,
batched_elem,
batched_edge_idx,
batched_atom_idx,
batch_idx
)
q1_actor, _ = critic_network(
batched_pos,
batched_elem,
batched_edge_idx,
predict_actions,
batched_atom_idx,
batch_idx
)
actor_loss = -q1_actor.mean()
actor_losses.append(actor_loss.item())
actor_optimizer.zero_grad()
actor_loss.backward()
torch.nn.utils.clip_grad_norm_(actor_network.parameters(), 1.0)
actor_optimizer.step()
polyak_update(actor_network.parameters(), actor_target_network.parameters(), tau)
polyak_update(critic_network.parameters(), critic_target_network.parameters(), tau)
obs_dict = next_obs_dict
if terminated or truncated:
obs_dict, info = my_env.reset()
# Print training summary
print(f"\n\tTraining Iteration {iteration + 1} Complete!")
if actor_losses and critic_losses:
print(f"\tAverage actor loss: {np.mean(actor_losses):.4f}")
print(f"\tAverage critic loss: {np.mean(critic_losses):.4f}")
print(f"\tAverage reward: {np.mean(episode_rewards):.4f}")
# === TRAINING METRICS VISUALIZATION ===
fig, axes = plt.subplots(1, 2, figsize=(18, 5))
fig.suptitle(f'Training Metrics - Iteration {iteration + 1}', fontsize=14)
# Plot 1: Actor and Critic Losses
if actor_losses and critic_losses:
ax1_twin = axes[0].twinx()
axes[0].plot(critic_losses, label='Critic Loss', color='blue', alpha=0.7)
ax1_twin.plot([i * policy_delay for i in range(len(actor_losses))],
actor_losses, label='Actor Loss', color='red', alpha=0.7)
axes[0].set_xlabel('Training Step')
axes[0].set_ylabel('Critic Loss', color='blue')
ax1_twin.set_ylabel('Actor Loss', color='red')
axes[0].set_title('Training Losses Over Steps')
axes[0].grid(True, alpha=0.3)
axes[0].legend(loc='upper left')
ax1_twin.legend(loc='upper right')
# Plot 2: Rewards over time
window_size = min(10, len(episode_rewards) // 10) if len(episode_rewards) > 0 else 1
if window_size > 1:
moving_avg = np.convolve(episode_rewards, np.ones(window_size)/window_size, mode='valid')
moving_avg_steps = range(window_size//2, window_size//2 + len(moving_avg))
axes[1].plot(moving_avg_steps, moving_avg, color='green', linewidth=2)
else:
axes[1].plot(episode_rewards, color='green', alpha=0.6)
axes[1].set_xlabel('Step')
axes[1].set_ylabel('Reward')
axes[1].set_title('Episode Rewards Over Steps')
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# === EVALUATION PHASE ===
print(f"\n\tEVALUATION PHASE - Testing on Train Set")
actor_network.eval() # Set to evaluation mode
# Evaluate on training protein(s)
with torch.no_grad():
eval_env = PDB_Env(
train_data['water_indices'],
train_data['protein_indices'],
train_data['water_elements'],
train_data['protein_elements'],
train_data['ground_truth_maps'],
n_repeats=1,
termination_reward=0, # No bonus during eval
termination_mse=1
)
obs_dict_eval, info_eval = eval_env.reset()
initial_agent_coords = obs_dict_eval['agent_info_pos'].clone()
num_waters = initial_agent_coords.shape[0]
move_counts = torch.zeros(num_waters, dtype=torch.int64, device=device)
eval_step = 0
# Calculate initial metrics
coverage_init, precision_init = cov_prec_at_threshold(
initial_agent_coords, obs_dict_eval['target_info_pos']
)
print(f"\t Initial: Distance={info_eval['distance']:.4f}, "
f"Coverage={coverage_init:.4f}, Precision={precision_init:.4f}")
        initial_distance = info_eval['distance']
# Run deterministic evaluation episode
while eval_step < max_eval_steps:
full_protein_elements_eval = np.concatenate((
obs_dict_eval['protein_ele'], obs_dict_eval['water_ele']
))
full_protein_coords_eval = torch.vstack((
obs_dict_eval["protein_info"], obs_dict_eval["agent_info_pos"]
))
agent_edge_index_eval, _ = build_knn_graph(full_protein_coords_eval)
# Get worst atom (deterministic selection)
atom_to_move_eval_local = eval_env.select_worst_atom()
            num_protein_atoms_eval = len(obs_dict_eval["protein_info"])
atom_to_move_eval_global = atom_to_move_eval_local + num_protein_atoms_eval
# Get deterministic movement (no noise)
movement_eval = actor_network(
full_protein_coords_eval,
full_protein_elements_eval,
agent_edge_index_eval,
atom_to_move_eval_global
).squeeze(0)
obs_dict_eval, _, terminated_eval, truncated_eval, info_eval = eval_env.step_test(
atom_to_move_eval_local, movement_eval
)
eval_step += 1
move_counts[atom_to_move_eval_local] += 1
if terminated_eval or truncated_eval:
break
# Calculate final metrics
final_info_eval = eval_env._get_info()
agent_coords_final = eval_env._agent_info_pos
target_coords_eval = eval_env._target_info_pos
protein_coords = obs_dict_eval['protein_info']
final_distance = final_info_eval['distance']
coverage_final, precision_final = cov_prec_at_threshold(
agent_coords_final, target_coords_eval
)
print(f"\t Final: Distance={final_distance:.4f}, "
f"Coverage={coverage_final:.4f}, Precision={precision_final:.4f}")
print(f"\t Improvement: Distance={info_eval['distance'] - final_distance:.4f}")
# === EVALUATION VISUALIZATION ===
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle(f'Evaluation Results - Iteration {iteration + 1}', fontsize=14)
# Plot 1: Atom Movement Frequency
ax1.bar(range(len(move_counts)), move_counts.cpu().numpy(),
color='steelblue', alpha=0.7)
ax1.set_xlabel('Water Atom Index', fontsize=12)
ax1.set_ylabel('Times Moved', fontsize=12)
ax1.set_title('Atom Movement Frequency During Evaluation', fontsize=13)
ax1.grid(axis='y', linestyle='--', alpha=0.3)
# Plot 2: Spatial Positions Comparison
initial_np = initial_agent_coords.cpu().numpy()
agent_np = agent_coords_final.cpu().numpy()
target_np = target_coords_eval.cpu().numpy()
protein_np = protein_coords.cpu().numpy()
# Protein atoms (background)
ax2.scatter(protein_np[:, 0], protein_np[:, 1],
c='grey', s=20, alpha=0.3, label='Protein', zorder=1)
# Starting positions (green circles)
ax2.scatter(initial_np[:, 0], initial_np[:, 1],
c='green', s=80, alpha=0.6, label='Starting Position',
marker='o', edgecolors='darkgreen', linewidth=1, zorder=2)
# Ground truth targets (blue X marks)
ax2.scatter(target_np[:, 0], target_np[:, 1],
c='blue', s=100, alpha=0.8, label='Ground Truth (Target)',
marker='x', linewidth=2, zorder=3)
# Agent final positions (red triangles)
ax2.scatter(agent_np[:, 0], agent_np[:, 1],
c='red', s=60, alpha=0.7, label='Agent Final',
marker='^', edgecolors='darkred', linewidth=1, zorder=4)
ax2.set_xlabel('X Coordinate', fontsize=12)
ax2.set_ylabel('Y Coordinate', fontsize=12)
ax2.set_title('Final Positions Comparison (XY Plane)', fontsize=13)
ax2.legend(loc='upper right', fontsize=10)
ax2.grid(True, alpha=0.2)
ax2.axis('equal')
# Add metrics text box
        textstr = (f'Initial Distance: {initial_distance:.3f}\n'
f'Final Distance: {final_distance:.3f}\n'
f'Coverage: {coverage_final:.3f}\n'
f'Precision: {precision_final:.3f}\n'
f'Steps: {eval_step}')
props = dict(boxstyle='round', facecolor='lightgreen', alpha=0.7)
ax2.text(0.02, 0.98, textstr, transform=ax2.transAxes, fontsize=10,
verticalalignment='top', bbox=props, weight='semibold')
plt.tight_layout()
plt.show()
actor_network.train() # Set back to training mode
print(f"\n{'='*80}\n")