
I am trying to train a TD3 algorithm to place points in 3D space.

However, I am currently not even able to get the model to overfit on a small number of data points.

As far as I can tell, part of the issue is that the episodes mostly produce progressively more negative rewards (measured by the change in MSE from the previous position), leading to a critic that simply always predicts negative Q-values because positive rewards are so sparse.

Does anyone have any advice?

Below, I have provided my code for my custom environment and training.
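
To make the reward structure concrete, here is a toy illustration of the per-step reward the environment computes (dummy tensors, not my real data; the actual calculation is in step() below): the reward is the change in MSE between the agent's water positions and the ground truth, scaled up by 2.5 when the move is an improvement.

import torch

target = torch.zeros(5, 3)            # ground-truth water positions (dummy)
agent_before = target + 1.0           # noisy positions before the move
agent_after = target + 0.8            # positions after one move

mse_before = torch.mean((agent_before - target) ** 2).item()  # 1.00
mse_after = torch.mean((agent_after - target) ** 2).item()     # 0.64
delta = mse_before - mse_after                                  # 0.36 (improvement)
reward = delta * 2.5 if delta > 0 else delta                    # 0.90; regressions are not scaled
print(reward)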

# Imports used by the snippets below; `device` is the torch device used throughout.
import copy

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from scipy import spatial
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PDB_Env(gym.Env):
"""A custom Gymnasium environment for protein-water refinement.

This environment now takes pre-computed density maps and XMap objects,
simplifying its role to managing the state and reward calculation during
the simulation.
"""
def __init__(self, water_data_pos, protein_data_pos, water_data_ele, protein_data_ele, ground_truth_maps,
             n_repeats=3, noise=1.5, termination_mse=1, termination_reward=100,
             query_radius=.5, proximity_threshold=1,
             mse_weight = 100.0, density_weight = 1):
    super().__init__()
    self.termination_mse = termination_mse
    self._original_water_data_pos = water_data_pos
    self._original_protein_atoms_pos = protein_data_pos
    self._original_ground_truth_maps = ground_truth_maps
    self._original_water_ele = water_data_ele
    self._original_protein_ele = protein_data_ele
    self.n_proteins = len(self._original_water_data_pos)
    self.n_repeats = n_repeats
    self.noise = noise
    self.termination_reward = termination_reward
    self.query_radius = query_radius
    self.mse_weight = mse_weight
    self.density_weight = density_weight
    self.clash_penalty = 5
    self.proximity_threshold = proximity_threshold
    self._create_protein_sequence()
    self.protein_sequence_idx = -1

    initial_n_atoms = self._original_water_data_pos[0].shape[0] if self.n_proteins > 0 else 0
    self.observation_space = gym.spaces.Dict({
        "agent": gym.spaces.Box(shape=(initial_n_atoms, 3), low=-np.inf, high=np.inf, dtype=np.float32),
        "target": gym.spaces.Box(shape=(initial_n_atoms, 3), low=-np.inf, high=np.inf, dtype=np.float32),
    })
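    # NOTE: these Box entries ("agent"/"target") do not match the keys that
    # _get_obs() actually returns ("agent_info_pos", "target_info_pos",
    # "protein_info", ...), so the declared observation_space is out of sync
    # with the real observations.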

def _get_obs(self):
    """Helper function to get the current observation dictionary."""
    return {
        "agent_info_pos": self._agent_info_pos.clone(),
        "target_info_pos": self._target_info_pos.clone(),
        "protein_info": self._curr_protein_atoms_pos.clone(),
        "target_density": self._curr_target_density.clone(),
        "water_ele": copy.deepcopy(self._curr_water_ele),
        "protein_ele": copy.deepcopy(self._curr_protein_ele)
    }

def _get_info(self):
    """Helper function to get the current information dictionary."""
    mse = torch.mean((self._agent_info_pos - self._target_info_pos) ** 2).item()
    return {"distance": mse}

def _create_protein_sequence(self):
    """Creates the sequence of protein indices for training repeats."""
    protein_indices = list(range(self.n_proteins))
    self.protein_sequence = protein_indices * self.n_repeats
    self.total_sequence_length = len(self.protein_sequence)

def _load_protein(self, protein_idx):
    """Loads a specific protein and its pre-computed density map."""
    self.current_protein_idx = protein_idx
    self._target_info_pos = self._original_water_data_pos[protein_idx].clone()
    self._agent_info_pos = self._original_water_data_pos[protein_idx].clone()
    self._curr_protein_atoms_pos = self._original_protein_atoms_pos[protein_idx].clone()
    self._curr_target_density = self._original_ground_truth_maps[protein_idx]
    self._curr_water_ele = self._original_water_ele[protein_idx]
    self._curr_protein_ele = self._original_protein_ele[protein_idx]

    self.n_atoms = self._original_water_data_pos[protein_idx].shape[0]
    obs_shape = (self.n_atoms, 3)
    self.observation_space = gym.spaces.Dict({
        "agent": gym.spaces.Box(shape=obs_shape, low=-np.inf, high=np.inf, dtype=np.float32),
        "target": gym.spaces.Box(shape=obs_shape, low=-np.inf, high=np.inf, dtype=np.float32),
    })

    # Build KD-tree from target density coordinates (first 3 columns) - convert to CPU numpy for scipy
    density_coords_cpu = self._curr_target_density[:, :3].cpu().numpy()
    self.kd_tree = spatial.KDTree(density_coords_cpu)

def reset(self, seed=42, options=None):
    """Resets the environment to a new protein and adds noise to water positions."""
    super().reset(seed=seed)

    self.protein_sequence_idx = (self.protein_sequence_idx + 1) % len(self.protein_sequence)
    protein_idx = self.protein_sequence[self.protein_sequence_idx]
    self._load_protein(protein_idx)

    torch.manual_seed(seed)
    # Add noise to create starting positions
    noise = torch.randn_like(self._target_info_pos, device=device) * self.noise
    self._agent_info_pos = self._target_info_pos + noise

    # Build the per-atom density lookup: for each water atom, sum the map density
    # within query_radius of its position (used later by select_worst_atom).
    density_vals = self._curr_target_density[:, 3]
    agent_pos_cpu = self._agent_info_pos.cpu().numpy()
    self._density_lookup = torch.zeros(self.n_atoms, device=device)

    for atom_idx in range(self.n_atoms):
        indices = self.kd_tree.query_ball_point(agent_pos_cpu[atom_idx], self.query_radius)
        self._density_lookup[atom_idx] = torch.sum(density_vals[indices]) if len(indices) > 0 else 0.0

    obs = self._get_obs()
    info = self._get_info()
    return obs, info

def select_worst_atom(self):    
    min_density = torch.min(self._density_lookup)
    worst_atoms = torch.where(self._density_lookup == min_density)[0]
    return worst_atoms[0].item()  

def step(self, atom, movement):
    target_density = self._curr_target_density

    mse_before = torch.mean((self._agent_info_pos - self._target_info_pos) ** 2).item()

    # Get density from the atom's last known position
    density_before = self._density_lookup[atom].item()

    # Apply movement
    new_coords = self._agent_info_pos.clone()
    new_coords[atom] += movement
    self._agent_info_pos = new_coords

    mse_after = torch.mean((self._agent_info_pos - self._target_info_pos) ** 2).item()

    # Query density at the new position - convert to CPU numpy for KDTree
    agent_pos_cpu = self._agent_info_pos[atom].detach().cpu().numpy()
    indices = self.kd_tree.query_ball_point(agent_pos_cpu, self.query_radius)
    density_after = torch.sum(target_density[indices, 3]).item() if len(indices) > 0 else 0.0

    # OPTIMIZED: Vectorized proximity check
    moved_atom_pos = self._agent_info_pos[atom]

    # Check distances to other water atoms
    mask = torch.ones(self.n_atoms, dtype=torch.bool, device=device)
    mask[atom] = False

    # Calculate distances to all other water atoms
    other_water_positions = self._agent_info_pos[mask]
    water_distances = torch.norm(other_water_positions - moved_atom_pos, dim=1)
    water_violations = (water_distances < self.proximity_threshold).sum().item()

    # Check distances to protein atoms (vectorized)
    protein_distances = torch.norm(self._curr_protein_atoms_pos - moved_atom_pos.unsqueeze(0), dim=1)
    protein_violations = (protein_distances < self.proximity_threshold).sum().item()

    distance_penalty = mse_after * 0.1

    #reward = (mse_before - mse_after) * self.mse_weight + (density_after - density_before) * self.density_weight
    #reward = -mse_after
    #reward = np.clip(reward, -10.0, 10.0)
    if mse_before - mse_after > 0:
        reward = (mse_before - mse_after) * 2.5
    else:
        reward = (mse_before - mse_after)
    # Update lookup table with new density for next move
    self._density_lookup[atom] = density_after

    #print(f"Atom {atom}: dist_before={mse_before:.2f}, dist_after={mse_after:.2f}, reward={reward:.2f}")

    terminated = mse_after < self.termination_mse
    if terminated:
        reward += self.termination_reward

    return self._get_obs(), reward, terminated, False, self._get_info()

def step_test(self, atom, movement):
    """Applies an action during testing (no reward calculation)."""
    new_coords = self._agent_info_pos.clone()
    new_coords[atom] += movement
    self._agent_info_pos = new_coords

    mse_after = torch.mean((self._agent_info_pos - self._target_info_pos) ** 2).item()
    terminated = mse_after < self.termination_mse
    obs = self._get_obs()
    info = self._get_info()
    return obs, 0.0, terminated, False, info
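
The training loop below also uses a few helpers that I have not included for brevity: build_knn_graph, batch_graphs, polyak_update, cov_prec_at_threshold, the replay buffer (my_buffer), the actor/critic networks, and the usual hyperparameters (gamma, tau, policy_delay, policy_noise_std, etc.). For reference, here are minimal sketches of two of them, consistent with how they are called below (illustrative only; my real versions may differ slightly):

def polyak_update(source_params, target_params, tau):
    """Standard TD3 soft update: theta_target <- tau * theta + (1 - tau) * theta_target."""
    with torch.no_grad():
        for src, tgt in zip(source_params, target_params):
            tgt.data.mul_(1.0 - tau).add_(tau * src.data)


def build_knn_graph(coords, k=8):
    """Toy k-nearest-neighbour graph builder returning (edge_index, edge_weight).
    k=8 is an arbitrary choice for illustration."""
    n = coords.shape[0]
    k = min(k, n - 1)
    dists = torch.cdist(coords, coords)                      # (n, n) pairwise distances
    knn_dists, knn_idx = dists.topk(k + 1, largest=False)    # +1: each point is its own nearest neighbour
    knn_dists, knn_idx = knn_dists[:, 1:], knn_idx[:, 1:]    # drop self-edges
    src = torch.arange(n, device=coords.device).repeat_interleave(k)
    edge_index = torch.stack([src, knn_idx.reshape(-1)])     # shape (2, n * k)
    edge_weight = 1.0 / (knn_dists.reshape(-1) + 1e-8)       # inverse-distance weights
    return edge_index, edge_weight

And the training / evaluation loop:
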
for iteration in tqdm(range(n_repeats), desc="Training Iterations"):
print(f"\n{'='*80}")
print(f"=== Training Iteration {iteration + 1}/{n_repeats} ===")
print(f"{'='*80}")

obs_dict, info = my_env.reset()
actor_losses, critic_losses, episode_rewards = [], [], []

# === TRAINING PHASE ===
print(f"\n\tTRAINING PHASE")
for step in tqdm(range(max_steps_per_iteration), desc="\tTraining steps", unit="step"):
    # Build full system for the network
    full_protein_elements = np.concatenate((obs_dict['protein_ele'], obs_dict['water_ele']))
    full_protein_coords = torch.vstack((obs_dict["protein_info"], obs_dict["agent_info_pos"]))
    agent_edge_index, agent_edge_weight = build_knn_graph(full_protein_coords)

    # 1. Get the LOCAL water index from the environment (e.g., 3)
    atom_to_move_local = my_env.select_worst_atom()

    # 2. IMMEDIATELY convert it to the GLOBAL index for the network
    num_protein_atoms = len(obs_dict["protein_info"])
    atom_to_move_global = atom_to_move_local + num_protein_atoms

    # 3. Use the GLOBAL index to get movement from the actor network
    with torch.no_grad():
        movement = actor_network(
            full_protein_coords,
            full_protein_elements,
            agent_edge_index,
            atom_to_move_global  # <-- CORRECTED
        ).squeeze(0)
        
        # Add exploration noise during training
        exploration_noise = torch.randn_like(movement) * 0.01
        movement = movement + exploration_noise

    # 4. The environment step still uses the LOCAL index
    next_obs_dict, reward, terminated, truncated, info = my_env.step(atom_to_move_local, movement)
    episode_rewards.append(reward)

    # Also get and correct the index for the NEXT state
    next_atom_to_move_local = my_env.select_worst_atom()
    next_num_protein_atoms = len(next_obs_dict["protein_info"])
    next_atom_to_move_global = next_atom_to_move_local + next_num_protein_atoms

    # Build next state's full system
    full_protein_elements_next = np.concatenate((next_obs_dict['protein_ele'], next_obs_dict['water_ele']))
    full_protein_coords_next = torch.vstack((next_obs_dict["protein_info"], next_obs_dict["agent_info_pos"]))
    next_agent_edge_index, next_agent_edge_weight = build_knn_graph(full_protein_coords_next)
    if reward > 0:  # Positive reward = moving closer
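        # Oversample: store this positive-reward transition 100 times so that the
        # buffer is not dominated by the far more common negative-reward transitions.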
        for _ in range(100):
            my_buffer.add(
                obs_dict,
                movement.detach(),
                reward,
                terminated,
                next_obs_dict,
                agent_edge_index,
                agent_edge_weight,
                next_agent_edge_index,
                next_agent_edge_weight,
                atom_to_move_global,       
                next_atom_to_move_global   
            )
    else: 
        my_buffer.add(
            obs_dict,
            movement.detach(),
            reward,
            terminated,
            next_obs_dict,
            agent_edge_index,
            agent_edge_weight,
            next_agent_edge_index,
            next_agent_edge_weight,
            atom_to_move_global,       
            next_atom_to_move_global   
        )

    # Train if we have enough samples in the buffer
    if len(my_buffer) >= 10:
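        # NOTE: training starts once the buffer holds just 10 transitions even though
        # batch_size=64; whether sample() handles that (e.g. by sampling with
        # replacement) depends on the buffer implementation, which is not shown here.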
        batch = my_buffer.sample(batch_size=64)

        # Batch the current state graphs together
        batched_pos, batched_elem, batched_edge_idx, batched_edge_wt, batched_atom_idx, batch_idx = batch_graphs(
            batch.full_coords,
            batch.full_elements,
            batch.edge_index,
            batch.edge_weight,
            batch.atom_moved
        )

        # Batch the next state graphs together
        next_batched_pos, next_batched_elem, next_batched_edge_idx, next_batched_edge_wt, next_batched_atom_idx, next_batch_idx = batch_graphs(
            batch.next_full_coords,
            batch.next_full_elements,
            batch.next_edge_index,
            batch.next_edge_weight,
            batch.next_atom_moved
        )

        # === CRITIC UPDATE ===
        with torch.no_grad():
            next_actions = actor_target_network(
                next_batched_pos,
                next_batched_elem,
                next_batched_edge_idx,
                next_batched_atom_idx,
                next_batch_idx
            )

            # Target policy smoothing
            noise = torch.randn_like(next_actions) * policy_noise_std
            noise = torch.clamp(noise, -0.5, 0.5)
            next_actions = next_actions + noise
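            # (Standard TD3 also clamps the smoothed action to the valid action range;
            # no clamp is applied here since the movement action is unbounded.)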
            next_q1, next_q2 = critic_target_network(
                next_batched_pos,
                next_batched_elem,
                next_batched_edge_idx,
                next_actions,
                next_batched_atom_idx,
                next_batch_idx
            )
            next_q_value = torch.min(next_q1, next_q2)
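            # TD3 clipped double-Q target: y = r + gamma * (1 - done) * min(Q1', Q2')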
            target_q_values = batch.rewards + (1 - batch.terminated.float()) * gamma * next_q_value

        current_q1, current_q2 = critic_network(
            batched_pos,
            batched_elem,
            batched_edge_idx,
            batch.actions,
            batched_atom_idx,
            batch_idx
        )

        critic_loss_1 = F.mse_loss(current_q1, target_q_values)
        critic_loss_2 = F.mse_loss(current_q2, target_q_values)
        critic_loss = critic_loss_1 + critic_loss_2
        critic_losses.append(critic_loss.item())

        critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(critic_network.parameters(), 1.0)
        critic_optimizer.step()


        if step % policy_delay == 0:
            predict_actions = actor_network(
                batched_pos,
                batched_elem,
                batched_edge_idx,
                batched_atom_idx,
                batch_idx
            )

            q1_actor, _ = critic_network(
                batched_pos,
                batched_elem,
                batched_edge_idx,
                predict_actions,
                batched_atom_idx,
                batch_idx
            )

            actor_loss = -q1_actor.mean()
            actor_losses.append(actor_loss.item())

            actor_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor_network.parameters(), 1.0)
            actor_optimizer.step()

           
            polyak_update(actor_network.parameters(), actor_target_network.parameters(), tau)
            polyak_update(critic_network.parameters(), critic_target_network.parameters(), tau)

    obs_dict = next_obs_dict

    if terminated or truncated:
        obs_dict, info = my_env.reset()

# Print training summary
print(f"\n\tTraining Iteration {iteration + 1} Complete!")
if actor_losses and critic_losses:
    print(f"\tAverage actor loss: {np.mean(actor_losses):.4f}")
    print(f"\tAverage critic loss: {np.mean(critic_losses):.4f}")
    print(f"\tAverage reward: {np.mean(episode_rewards):.4f}")

# === TRAINING METRICS VISUALIZATION ===
fig, axes = plt.subplots(1, 2, figsize=(18, 5))
fig.suptitle(f'Training Metrics - Iteration {iteration + 1}', fontsize=14)

# Plot 1: Actor and Critic Losses
if actor_losses and critic_losses:
    ax1_twin = axes[0].twinx()
    axes[0].plot(critic_losses, label='Critic Loss', color='blue', alpha=0.7)
    ax1_twin.plot([i * policy_delay for i in range(len(actor_losses))],
                  actor_losses, label='Actor Loss', color='red', alpha=0.7)
    axes[0].set_xlabel('Training Step')
    axes[0].set_ylabel('Critic Loss', color='blue')
    ax1_twin.set_ylabel('Actor Loss', color='red')
    axes[0].set_title('Training Losses Over Steps')
    axes[0].grid(True, alpha=0.3)
    axes[0].legend(loc='upper left')
    ax1_twin.legend(loc='upper right')

# Plot 2: Rewards over time
window_size = min(10, len(episode_rewards) // 10) if len(episode_rewards) > 0 else 1
if window_size > 1:
    moving_avg = np.convolve(episode_rewards, np.ones(window_size)/window_size, mode='valid')
    moving_avg_steps = range(window_size//2, window_size//2 + len(moving_avg))
    axes[1].plot(moving_avg_steps, moving_avg, color='green', linewidth=2)
else:
    axes[1].plot(episode_rewards, color='green', alpha=0.6)
axes[1].set_xlabel('Step')
axes[1].set_ylabel('Reward')
axes[1].set_title('Episode Rewards Over Steps')
axes[1].grid(True, alpha=0.3)


plt.tight_layout()
plt.show()

# === EVALUATION PHASE ===
print(f"\n\tEVALUATION PHASE - Testing on Train Set")

actor_network.eval()  # Set to evaluation mode

# Evaluate on training protein(s)
with torch.no_grad():
    eval_env = PDB_Env(
        train_data['water_indices'],
        train_data['protein_indices'],
        train_data['water_elements'],
        train_data['protein_elements'],
        train_data['ground_truth_maps'],
        n_repeats=1,
        termination_reward=0,  # No bonus during eval
        termination_mse=1
    )

    obs_dict_eval, info_eval = eval_env.reset()
    initial_agent_coords = obs_dict_eval['agent_info_pos'].clone()
    num_waters = initial_agent_coords.shape[0]
    move_counts = torch.zeros(num_waters, dtype=torch.int64, device=device)
    eval_step = 0

    # Calculate initial metrics
    coverage_init, precision_init = cov_prec_at_threshold(
        initial_agent_coords, obs_dict_eval['target_info_pos']
    )
    print(f"\t  Initial: Distance={info_eval['distance']:.4f}, "
          f"Coverage={coverage_init:.4f}, Precision={precision_init:.4f}")
    initial_distance = info_eval['distance']
    # Run deterministic evaluation episode
    while eval_step < max_eval_steps:
        full_protein_elements_eval = np.concatenate((
            obs_dict_eval['protein_ele'], obs_dict_eval['water_ele']
        ))
        full_protein_coords_eval = torch.vstack((
            obs_dict_eval["protein_info"], obs_dict_eval["agent_info_pos"]
        ))
        agent_edge_index_eval, _ = build_knn_graph(full_protein_coords_eval)

        # Get worst atom (deterministic selection)
        atom_to_move_eval_local = eval_env.select_worst_atom()
        num_protein_atoms_eval = len(obs_dict_eval["protein_info"]) # <-- Use the correct dictionary
        atom_to_move_eval_global = atom_to_move_eval_local + num_protein_atoms_eval

        # Get deterministic movement (no noise)
        movement_eval = actor_network(
            full_protein_coords_eval,
            full_protein_elements_eval,
            agent_edge_index_eval,
            atom_to_move_eval_global
        ).squeeze(0)

        obs_dict_eval, _, terminated_eval, truncated_eval, info_eval = eval_env.step_test(
            atom_to_move_eval_local, movement_eval
        )
        eval_step += 1
        move_counts[atom_to_move_eval_local] += 1

        if terminated_eval or truncated_eval:
            break

    # Calculate final metrics
    final_info_eval = eval_env._get_info()
    agent_coords_final = eval_env._agent_info_pos
    target_coords_eval = eval_env._target_info_pos
    protein_coords = obs_dict_eval['protein_info']

    final_distance = final_info_eval['distance']
    coverage_final, precision_final = cov_prec_at_threshold(
        agent_coords_final, target_coords_eval
    )

    print(f"\t  Final: Distance={final_distance:.4f}, "
          f"Coverage={coverage_final:.4f}, Precision={precision_final:.4f}")
    print(f"\t  Improvement: Distance={info_eval['distance'] - final_distance:.4f}")

    # === EVALUATION VISUALIZATION ===
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    fig.suptitle(f'Evaluation Results - Iteration {iteration + 1}', fontsize=14)

    # Plot 1: Atom Movement Frequency
    ax1.bar(range(len(move_counts)), move_counts.cpu().numpy(),
            color='steelblue', alpha=0.7)
    ax1.set_xlabel('Water Atom Index', fontsize=12)
    ax1.set_ylabel('Times Moved', fontsize=12)
    ax1.set_title('Atom Movement Frequency During Evaluation', fontsize=13)
    ax1.grid(axis='y', linestyle='--', alpha=0.3)

    # Plot 2: Spatial Positions Comparison
    initial_np = initial_agent_coords.cpu().numpy()
    agent_np = agent_coords_final.cpu().numpy()
    target_np = target_coords_eval.cpu().numpy()
    protein_np = protein_coords.cpu().numpy()

    # Protein atoms (background)
    ax2.scatter(protein_np[:, 0], protein_np[:, 1],
               c='grey', s=20, alpha=0.3, label='Protein', zorder=1)

    # Starting positions (green circles)
    ax2.scatter(initial_np[:, 0], initial_np[:, 1],
               c='green', s=80, alpha=0.6, label='Starting Position',
               marker='o', edgecolors='darkgreen', linewidth=1, zorder=2)

    # Ground truth targets (blue X marks)
    ax2.scatter(target_np[:, 0], target_np[:, 1],
               c='blue', s=100, alpha=0.8, label='Ground Truth (Target)',
               marker='x', linewidth=2, zorder=3)

    # Agent final positions (red triangles)
    ax2.scatter(agent_np[:, 0], agent_np[:, 1],
               c='red', s=60, alpha=0.7, label='Agent Final',
               marker='^', edgecolors='darkred', linewidth=1, zorder=4)

    ax2.set_xlabel('X Coordinate', fontsize=12)
    ax2.set_ylabel('Y Coordinate', fontsize=12)
    ax2.set_title('Final Positions Comparison (XY Plane)', fontsize=13)
    ax2.legend(loc='upper right', fontsize=10)
    ax2.grid(True, alpha=0.2)
    ax2.axis('equal')

    # Add metrics text box
    textstr = (f'Initial Distance: {initial_distance:.3f}\n'
              f'Final Distance: {final_distance:.3f}\n'
              f'Coverage: {coverage_final:.3f}\n'
              f'Precision: {precision_final:.3f}\n'
              f'Steps: {eval_step}')
    props = dict(boxstyle='round', facecolor='lightgreen', alpha=0.7)
    ax2.text(0.02, 0.98, textstr, transform=ax2.transAxes, fontsize=10,
            verticalalignment='top', bbox=props, weight='semibold')

    plt.tight_layout()
    plt.show()

actor_network.train()  # Set back to training mode

print(f"\n{'='*80}\n")