A robot ant. A maze. A frozen dataset of someone else's wandering. Train a policy that's better than the data — without ever interacting with the world. Every equation, every line of code, every gotcha.
Two new ideas land at once in HW3: a new environment, and a new way to learn.
An 8-DOF "ant" robot must navigate a maze to reach a goal. The state is the ant's joints and body pose plus its position in the maze (~29 numbers). The action is continuous joint torques in [-1, 1]^8.
Two variants, matching the two experiments you'll run: antmaze-umaze-v0 (a small U-shaped maze) and antmaze-medium-diverse-v0 (a larger maze).
The reward is sparse — like HW2, but inverted in sign: −1 on every step where the ant hasn't reached the goal, 0 once it has.
So the return per episode is essentially "minus the number of steps to the goal." A successful 50-step run gets −50; failure gets −episode_length (often −700).
In HW2 you collected new rollouts after every gradient update (PPO) or constantly added to a replay buffer (off-policy SAC). Here you don't run the policy in the environment at all until evaluation. Training reads from a fixed dataset of someone else's transitions, period.
Learn a policy from a fixed dataset D = {(s, a, r, s', d)} of transitions, with no further interaction with the environment during training. The data was collected by some other policy (mixture of experts, scripted controllers, prior random walks — anything).
This is the most realistic RL setting you'll see in this course. In healthcare, robotics, dialogue systems — you usually cannot let an untrained policy explore. People die, robots break, customers leave. You have a log of past behavior and must learn from it.
Advantage-Weighted Actor-Critic (Nair et al., 2020). The headline idea: keep the critic's TD learning essentially standard, but train the actor by behavior cloning on the dataset's actions, with each sample weighted by the exponentiated advantage exp(A/λ).
The result: a policy that resembles the dataset (so it doesn't query Q at out-of-distribution actions and get fooled) but biased toward the dataset's good actions.
Pure BC: imitate everything. Pure Q-maximization: drift to wherever Q is highest, even if Q is hallucinated nonsense at out-of-distribution actions. AWAC: imitate what's in the dataset, but lean toward the parts of the dataset where the value is good. Best of both.
To appreciate why AWAC is built the way it is, we need to be precise about what offline RL is and is not.
| Regime | Data source | Examples |
|---|---|---|
| On-policy | Fresh rollouts under the current policy. Discard after a few epochs. | PPO, REINFORCE (HW2 P2) |
| Off-policy | Replay buffer that grows as the agent acts. Mix of past + recent. | SAC, DQN (HW2 P3) |
| Offline | Fixed dataset, never grows. Agent never acts in the environment. | AWAC, IQL (HW3) |
Off-policy and offline sound similar. The difference is critical: off-policy can add new data when its current Q-estimates are wrong about some action, by trying that action and seeing what happens. Offline cannot.
Has:
- a fixed dataset of transitions (s, a, r, s', d), reward labels included.
- the environment, but only for evaluation after training.

Doesn't have:
- any way to try an action and get a reality check from the environment during training.
- any guarantee that the dataset covers the actions the learned policy might prefer.
The first absence is the killer. The whole story of modern offline RL is: how do we prevent the Q-function from becoming unreliable at actions the dataset doesn't cover?
Run vanilla SAC on a fixed dataset (no env interaction). What happens?
- The critic's TD target is y = r + γ min(Q1_target, Q2_target)(s', a'), where a' ~ π(·|s'). The actor's actions a' are not in the dataset — the actor was randomly initialized.
- The Q-network has never been trained on (s', a'). Its prediction is whatever the random initialization plus a few gradient steps say. Could be anything.
- The critic regresses onto those arbitrary targets, the actor chases whatever looks highest, and no new data ever arrives to correct either of them.

This is sometimes called extrapolation error. Online RL fixes it by trying the bad action and getting a reality check from the env. Offline RL has to fix it through algorithm design.
This is the most important conceptual chapter of the homework. Internalize this and AWAC's design becomes obvious.
Imagine the action space is a 2D plane. The dataset's actions form a cloud somewhere in that plane — maybe a blob in the upper-left. Outside the blob, no data exists.
The Q-network is a neural net. It interpolates and extrapolates. Inside the blob, the network's predictions are anchored by training data. Outside, predictions are pure extrapolation. There's nothing constraining what Q outputs at OOD actions — not the data, not the loss.
If you query the Q-net at an action it has seen, you get a meaningful answer. If you query it at an action it has never seen, you get whatever the network's inductive biases produce — often a wildly inaccurate value, frequently too high.
The actor's job is to find argmax_a Q(s, a). If Q is overestimated at some OOD action a*, the actor will gladly select a*. Now the critic gets trained against backups that involve a*. But there's no real-world signal saying a* is bad — we're offline. The error compounds.
This is sometimes called the "actor-critic feedback loop": the actor exploits errors in the critic, which feed back into the critic via TD targets, which makes the errors grow.
Modern offline RL has two main families, and AWAC sits between them.
| Family | Idea | Examples |
|---|---|---|
| Constrain the policy | Force the policy to stay close to the data distribution. Don't let it pick OOD actions. | BCQ, BEAR, AWAC |
| Constrain the Q-function | Modify Q-learning to deliberately underestimate at OOD actions, so the actor never wants them. | CQL, IQL (Problem 2!) |
AWAC's approach: keep the policy close to the dataset by training it via weighted behavior cloning. The weights bias toward high-advantage actions, but the cloning structure ensures the policy stays in-distribution. Q-learning continues normally; OOD is avoided by the actor never going there.
(IQL takes the other approach — it modifies Q-learning so it never queries OOD actions to begin with. We'll cover IQL in Problem 2.)
One equation summarizes the entire actor update:

L_actor(ψ) = −E_(s,a)~D[ exp( A(s, a) / λ ) · log πψ(a | s) ]
Read it piece by piece:
- πψ(a | s): the policy you're training, parameterized by ψ.
- log πψ(a | s): log-probability the policy assigns to action a at state s.
- (s, a) ~ D: state-action pairs from the offline dataset. Crucial: the actions come from the data, not from sampling the current policy. This is what keeps the actor in-distribution.
- A(s, a): advantage = Q(s, a) − V(s). Positive: this action was better than the policy's average. Negative: worse.
- exp(A / λ): exponential weight. Always positive. Larger when A is larger.
- λ: temperature. Small λ = aggressive selectivity (only the best actions get weight). Large λ = mild selectivity (all actions get similar weight).

Pure behavior cloning would minimize:

L_BC(ψ) = −E_(s,a)~D[ log πψ(a | s) ]
AWAC is exactly this with one extra factor: exp(A / λ). So AWAC reduces to BC if you set the weights to 1 (i.e., if all advantages are equal). The weight is what "tilts" the imitation toward better-than-average actions.
1. Weighted maximum likelihood: imitate the dataset, but with each sample weighted by exp(A/λ).
2. Soft policy improvement: starting from the data distribution, take a small KL-bounded step toward higher-Q actions.
3. Reward-weighted regression: a generalization of REINFORCE where instead of weighting log-probs by raw return, we weight by exp-of-advantage.
All three are correct. The exp weighting is the unifying feature.
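A minimal PyTorch sketch of the contrast, using made-up tensors (none of these names come from the starter code):

```python
import torch
from torch import distributions

# Toy batch: a policy over a 2-D action space, plus made-up dataset actions/advantages.
mean = torch.zeros(4, 2)
dist = distributions.MultivariateNormal(mean, covariance_matrix=torch.eye(2))
actions = torch.randn(4, 2)                       # "dataset" actions
advantages = torch.tensor([1.0, -0.5, 0.2, 2.0])  # made-up A(s, a)
lam = 0.5                                         # temperature lambda

log_prob = dist.log_prob(actions)                 # shape [4]

bc_loss = -log_prob.mean()                        # plain behavior cloning
weights = torch.exp(advantages / lam)             # exp(A / lambda)
awac_loss = -(log_prob * weights).mean()          # advantage-weighted cloning
```

Same log-probs, one extra multiplicative factor — that's the whole difference between the two losses.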
This is the critical insight. The actor's gradient is computed only on (s, a) pairs from the dataset. The actor is never asked "what's a good action at state s?" — it's only ever asked "given that someone took action a at state s, how much should we increase log π(a|s)?" The advantage weighting decides how much, but the action a is always from the dataset.
So the actor cannot drift to OOD actions during training. Its log-prob distribution stays anchored on dataset actions. The critic's OOD predictions never come up because the actor never proposes OOD actions for the critic to evaluate.
(Wait — doesn't the critic itself sample a' ~ π(·|s') in its TD target? Yes. But because the actor stays in-distribution, those samples are also approximately in-distribution. Cycle resolved.)
The exp(A/λ) weight isn't pulled out of a hat. It falls out of solving a constrained optimization problem. The derivation is worth seeing once.
We want a new policy π that improves over the data policy πD, but doesn't drift too far:

maximize_π  E_s~D, a~π(·|s)[ A(s, a) ]    subject to    E_s~D[ D_KL( π(·|s) ‖ πD(·|s) ) ] ≤ ε
"Maximize expected advantage under the new policy, while keeping the new policy within KL-distance ε of the data policy." Read ε as how far we're willing to deviate.
Setting up the Lagrangian and solving for π gives an analytical optimum:

π*(a | s) ∝ πD(a | s) · exp( A(s, a) / λ )
where λ is the Lagrange multiplier corresponding to the KL constraint.
Read this: the optimal new policy is the data policy reweighted by exp of advantage. High-advantage actions get amplified, low-advantage actions get attenuated. The temperature λ controls how aggressive the reweighting is.
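A tiny numeric illustration of that reweighting, with a uniform data policy over three discrete actions and hand-picked advantages (made-up numbers, just to see the mass shift):

```python
import torch

pi_data = torch.tensor([1/3, 1/3, 1/3])       # data policy over 3 actions
adv = torch.tensor([1.0, 0.0, -1.0])          # advantage of each action
lam = 0.5

unnormalized = pi_data * torch.exp(adv / lam)
pi_star = unnormalized / unnormalized.sum()   # the reweighted optimal policy
print(pi_star)   # roughly [0.87, 0.12, 0.02]: mass shifts to the high-advantage action
```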
You don't need to actually solve a constrained optimization problem at every gradient step. You just need to fit a parametric policy πψ to match the optimal π*. Since π* is a reweighted version of the data, fitting via weighted maximum likelihood does the job.
To fit πψ to π*, minimize KL divergence:

minimize_ψ  E_s~D[ D_KL( π*(·|s) ‖ πψ(·|s) ) ]

Substituting the form of π* and dropping terms that don't depend on ψ:

maximize_ψ  E_s~D, a~π*(·|s)[ log πψ(a | s) ]

Now use the fact that π*(a|s) ∝ πD(a|s) exp(A/λ), and the (s, a) samples in the dataset come from πD:

maximize_ψ  E_(s,a)~D[ exp( A(s, a) / λ ) · log πψ(a | s) ]

Negate to get a loss to minimize:

L_actor(ψ) = −E_(s,a)~D[ exp( A(s, a) / λ ) · log πψ(a | s) ]
This is exactly the equation from the homework PDF, derived from first principles. The exp weighting is the unique form that solves the KL-constrained policy improvement problem.
| λ | exp(A/λ) behavior | Resulting policy |
|---|---|---|
| Very small (→ 0) | Sharp peak around the highest-advantage action | Effectively argmax over dataset actions — aggressive but risks overfitting |
| Moderate (~1) | Moderate reweighting | Tilted BC — the AWAC sweet spot |
| Very large | All weights ~ 1 | Reduces to plain BC — safe but doesn't improve over data |
The homework defaults to a small λ (0.3 in the AWAC paper, configurable here). You won't tune it for AWAC, but Problem 2's IQL has a similar temperature you'll sweep.
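To see the table's three regimes numerically, here's a quick sketch with made-up advantages:

```python
import torch

adv = torch.tensor([2.0, 0.5, -1.0])     # made-up advantages for three samples

for lam in [0.1, 1.0, 100.0]:
    w = torch.exp(adv / lam)
    print(lam, (w / w.sum()).tolist())   # normalized weights

# lam = 0.1   -> ~[1.00, 0.00, 0.00]  nearly all weight on the best action
# lam = 1.0   -> ~[0.79, 0.18, 0.04]  tilted toward better actions
# lam = 100.0 -> ~[0.34, 0.33, 0.33]  nearly uniform, i.e. plain BC
```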
Look at the starter code:
exp_weights = exp_weights.clamp(max=50.0)
Why? When advantages are large and λ is small, exp(A/λ) can overflow. exp(20) is already 4.8e8; exp(40) is 2.4e17. Clamping to 50 keeps things sane — even one outlier that gets weight 50 in a batch of 256 is fine, but exp(100) would dominate everything.
The clamping breaks the strict mathematical equivalence to the constrained optimization — but in practice, large weights would be dominated by sampling noise anyway. Clamping is a numerical regularization that almost every implementation uses.
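A quick way to see the overflow (float32 exp saturates to inf a little past exp(88)):

```python
import torch

adv = torch.tensor([5.0, 30.0])   # made-up advantages
lam = 0.3

raw = torch.exp(adv / lam)        # exp([16.7, 100.0]) -> [~1.8e7, inf] in float32
clamped = raw.clamp(max=50.0)     # [50., 50.] -- both runaway weights capped
print(raw, clamped)
```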
Now the other half: the critic. AWAC's critic is essentially the same as Problem 3 of HW2 (SAC) — double Q-networks with target networks and clipped target.
Three components, each familiar from HW2:
- Two online Q-networks, q_net and q_net2, trained by TD regression.
- Two target networks, q_net_target and q_net2_target, updated slowly (Polyak averaging) to give a stable regression target.
- A clipped target: the TD backup uses min(Q̄1, Q̄2) to counter maximization bias.

Where does a' come from? This is the one step where AWAC's critic touches the policy. To compute min(Q̄1, Q̄2)(s', a'), we need an action a' at the next state s'. We sample one from the current actor: a' ~ πψ(·|s').
This is what makes the algorithm an "actor-critic" rather than offline Q-learning. The critic's target depends on the actor's choices.
Why this is OK in offline RL: because the actor is trained via AWAC's weighted-BC objective, it stays close to the data distribution. So its sampled a' are approximately in-distribution, and the critic's target is approximately reliable.
If the actor drifts to OOD actions, the critic's targets become unreliable, which feeds back into actor advantages, which feeds back to actor updates. AWAC works because the BC anchor is strong enough to keep the actor in-distribution. With a smaller offline dataset or a weaker BC anchor, this coupling can still break.
Same as HW2: if d=1 (terminal), zero out the bootstrap. The full target:

y = r + γ · (1 − d) · min( Q̄1(s', a'), Q̄2(s', a') ),   a' ~ πψ(·|s')
You'll write this expression in awac_critic.py.
Notice every step. The actor never sees its own samples for training — only dataset (s, a) pairs reweighted by advantage. The critic only sees a' sampled from the actor in one place: computing the TD target. Both ingredients combined keep the algorithm in-distribution.
The AWAC implementation spans three files. Each maps cleanly to one part of the algorithm.
| File | Class | Responsibility |
|---|---|---|
| policies/MLP_policy.py | MLPPolicyAWAC | Actor — the AWR loss |
| critics/awac_critic.py | AWACCritic | Double-Q critic — TD loss |
| agents/awac_agent.py | AWACAgent | Orchestrator — advantage estimation + train loop |
Look at MLPPolicy.forward in MLP_policy.py:109-125. It builds a MultivariateNormal distribution from a learned mean (passed through tanh) and a diagonal covariance from learned logstd:
```python
batch_mean = torch.tanh(self.mean_net(observation))   # shape [B, ac_dim]
clipped_logstd = torch.clamp(self.logstd, min=-5.0, max=2.0)
std = torch.exp(clipped_logstd) * float(temperature)
scale_tril = torch.diag(std)
batch_scale_tril = scale_tril.repeat(batch_dim, 1, 1)
action_distribution = distributions.MultivariateNormal(
    batch_mean,
    scale_tril=batch_scale_tril,
)
```
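A quick shape check on that kind of distribution — a standalone sketch with arbitrary dimensions, not the homework's actual networks:

```python
import torch
from torch import distributions

B, ac_dim = 4, 8                                    # arbitrary batch size, AntMaze action dim
batch_mean = torch.tanh(torch.randn(B, ac_dim))     # stand-in for the mean_net output
std = torch.exp(torch.clamp(torch.randn(ac_dim), -5.0, 2.0))
batch_scale_tril = torch.diag(std).repeat(B, 1, 1)

dist = distributions.MultivariateNormal(batch_mean, scale_tril=batch_scale_tril)
print(dist.sample().shape)                          # torch.Size([4, 8]) -- one action per state
print(dist.log_prob(torch.zeros(B, ac_dim)).shape)  # torch.Size([4])    -- one scalar per state
```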
Notable:
- The mean goes through tanh, so the distribution's center always lies inside the [-1, 1] action bounds.
- logstd is clamped to [−5, 2] before exponentiating, keeping the standard deviation in a sane range.
- scale_tril is the Cholesky factor; for diagonal covariance it's just diag(std).

This is your first edit target. Look at MLP_policy.py:164-202:
```python
def update(self, observations, actions, adv_n=None):
    # ... convert to tensors ...
    dist = self(observations)
    log_prob_n = dist.log_prob(actions)

    # TODO: Use adv_n and self.lambda_awac to compute exponential weights.
    exp_weights = None  # YOUR CODE
    exp_weights = exp_weights.clamp(max=50.0)

    actor_loss = -(log_prob_n * exp_weights).mean()

    self.optimizer.zero_grad()
    actor_loss.backward()
    torch.nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm)
    self.optimizer.step()
```
Most of the wiring is done. You just need to fill in the exp weight calculation. We'll get to it in Chapter 09.
The second edit target. Look at awac_critic.py:134-188. The wiring:
```python
def update(self, ob_no, ac_na, next_ob_no, reward_n, terminal_n, next_actions):
    # ... convert to tensors ...

    # YOUR CODE: compute loss for q_net and q_net2.
    loss = None
    loss2 = None

    self.optimizer.zero_grad()
    (loss + loss2).backward()
    utils.clip_grad_norm_(...)
    self.optimizer.step()
```
You'll fill in: compute the TD target with get_target_q, then MSE losses for each Q-network. The helper _get_q_value handles the (discrete vs continuous) cases for you.
The third edit target. Two blocks in awac_agent.py:
- estimate_advantage() at lines 70-98: compute A(s, a) = Q(s, a) − V(s).
- train() at lines 100-151: sample next_actions for the critic, then call critic.update + actor.update.

Together, these wire everything up.
Every blank you'll fill in, with line-by-line annotation. This is the centerpiece chapter.
Where: MLP_policy.py:188-191.
The math:

w_i = exp( A(s_i, a_i) / λ )   — one weight per sample i in the batch.
Already in scope:
- adv_n — tensor of shape [B], advantage for each sample.
- self.lambda_awac — scalar λ.
The code:
exp_weights = torch.exp(adv_n / self.lambda_awac)
One line, three operations.
1. adv_n / self.lambda_awac — elementwise division. adv_n is shape [B], self.lambda_awac is a Python float, so the result is shape [B] with each element adv_n[i] / lambda.
2. torch.exp(...) — elementwise exponential. Output shape [B], all positive (exp is always positive).
3. The result is the per-sample weight. Multiplying by log_prob_n on the next line gives the AWAC loss summands.
Why this works: actions with high advantage (better than average) get large weights; actions with low advantage (worse than average) get small weights. The MLE update on the next line moves the policy toward high-weight actions and away from low-weight actions.
Why torch.exp instead of Python's math.exp or NumPy's np.exp? Because adv_n is a torch tensor, possibly on the GPU, and the result needs to support backprop downstream. math.exp works on Python scalars only. np.exp would force a CPU detour. torch.exp is the only one that's GPU-native and gradient-safe.
(In this specific case adv_n is detached — the advantage is treated as a fixed number for the policy update — so backprop through exp doesn't matter. But using torch.exp is the right reflex regardless.)
Where: awac_critic.py:171-177.
The math:

y = r + γ · (1 − d) · min( Q̄1(s', a'), Q̄2(s', a') ),   a' ~ πψ(·|s')

loss = MSE( Q1(s, a), y ),   loss2 = MSE( Q2(s, a), y )
Already in scope:
- ob_no, ac_na, next_ob_no, reward_n, terminal_n — all tensors.
- next_actions — a', sampled from the actor by the agent (passed in).
- self.gamma — discount.
- self._get_q_value(q_net, obs, actions) — helper that returns Q(s, a) for one Q-network.
- self.get_target_q(obs, actions) — helper that returns min(Q̄1, Q̄2)(s, a).
- self.mse_loss — nn.MSELoss() instance.

The code:
```python
with torch.no_grad():
    target_q = self.get_target_q(next_ob_no, next_actions)
    target = reward_n + self.gamma * (1.0 - terminal_n) * target_q

q1 = self._get_q_value(self.q_net, ob_no, ac_na)
q2 = self._get_q_value(self.q_net2, ob_no, ac_na)
loss = self.mse_loss(q1, target)
loss2 = self.mse_loss(q2, target)
```
Everything inside this block is computed without tracking gradients. The TD target is a regression label — the loss is (prediction − target)2, and only the prediction should produce gradients. If you forget no_grad, gradients flow through the target into the target networks, the target moves while we're chasing it, and training diverges.
Calls the helper at awac_critic.py:119-132. It runs both target Q-networks on (s', a'), takes the elementwise min, returns shape [B]. This is min(Q̄1(s', a'), Q̄2(s', a')) — the clipped double-Q estimate of "value at next state."
Critical: it uses self.q_net_target and self.q_net2_target internally, NOT the online self.q_net and self.q_net2. Mixing these up is the most common bug in this homework.
The Bellman target. Three pieces, all shape [B]:
• reward_n — immediate reward.
• self.gamma * target_q — discounted future value.
• (1.0 - terminal_n) — the done mask. Where terminal_n[i] = 1, this is 0, killing the bootstrap. Where terminal_n[i] = 0, this is 1, keeping the bootstrap.
The whole thing: "reward I just got, plus γ times the target's estimate of next-state value, but don't bootstrap if the episode ended."
OUT of the no_grad block now. We do want gradients on these — this is the prediction we're training.
The helper handles the (discrete vs continuous) action representation. For continuous actions (AntMaze), it does self.q_net(obs, actions).squeeze(-1) — concatenates obs and action, runs through the Q-network, returns shape [B]. For discrete, it gathers the relevant logit. You don't need to think about it; just call the helper.
Same thing for the second online critic. We train BOTH critics — even though only the target networks are used for the target. Training both keeps them independent enough that min(target_q1, target_q2) remains a meaningful pessimistic estimate.
Mean squared error: mean( (q1 - target)^2 ), returns a scalar. nn.MSELoss by default reduces with mean. PyTorch optimizers minimize, so this loss tries to make q1 match target.
Same for the second critic. The downstream code does (loss + loss2).backward(), which sums their gradients. Each Q-network only has parameters in its own loss, so the sum's gradient w.r.t. each network's parameters is just that network's MSE gradient.
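If the "each network only gets its own gradient" claim feels suspicious, a two-network toy example (unrelated to the homework code) confirms it:

```python
import torch
import torch.nn as nn

q1, q2 = nn.Linear(3, 1), nn.Linear(3, 1)
x, target = torch.randn(5, 3), torch.randn(5, 1)

loss = ((q1(x) - target) ** 2).mean()
loss2 = ((q2(x) - target) ** 2).mean()

loss.backward()           # only q1 is touched by loss...
print(q2.weight.grad)     # None: loss has no path into q2's parameters
loss2.backward()          # ...so (loss + loss2).backward() just sums two disjoint gradients
```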
1. Forgetting no_grad: target networks get trained too, training diverges. The target is supposed to be a slow-moving teacher, not a student.
2. Using q_net instead of q_net_target in target: now the target moves at every gradient step. Same divergence.
3. Forgetting (1 - terminal_n): bootstraps on terminal states, where Q is undefined. The first time the dataset has terminal transitions, training corrupts.
4. Shape mismatch: reward_n is [B]; if your target_q somehow becomes [B, 1], broadcasting silently produces [B, B]. Print shapes liberally if loss numbers look weird.
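The silent [B, B] broadcast from bug 4 looks like this with toy tensors:

```python
import torch

reward_n = torch.zeros(256)           # shape [B]
target_q = torch.zeros(256, 1)        # oops: shape [B, 1] instead of [B]

target = reward_n + 0.99 * target_q   # broadcasts to [256, 256], no error raised
print(target.shape)                   # torch.Size([256, 256])
```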
Where: awac_agent.py:89-97.
The math:

A(s, a) = Q(s, a) − V(s) ≈ min(Q1, Q2)(s, a) − min(Q1, Q2)(s, a_π),   a_π ~ πψ(·|s)
Already in scope:
- ob_no, ac_na — tensors, the dataset state-action pair.
- self.actor — the policy (callable: self.actor(obs) returns a distribution).
- self.critic.get_q(obs, actions) — returns min(Q1, Q2)(s, a), shape [B].

The code:
```python
# inside `with torch.no_grad():`
q_sa = self.critic.get_q(ob_no, ac_na)   # Q(s, a) for dataset action
action_dist = self.actor(ob_no)          # pi(.|s)
ac_pi = action_dist.sample()             # sample one action per state
v_s = self.critic.get_q(ob_no, ac_pi)    # Q(s, a_pi) ~ V(s)
adv = q_sa - v_s
```
Q-value at the dataset's actual action. get_q returns min(Q1, Q2)(s, a) using the online Q-networks (not target — we want our most recent estimate of value, even if noisier). Shape [B].
Note: this is computed inside with torch.no_grad(): so the actor's gradient won't flow back through the critic. The actor only learns from the log π term and treats advantage as a fixed weight.
Forward through the actor to get a MultivariateNormal distribution per state. Shape information: the distribution's mean has shape [B, ac_dim], and the covariance is [B, ac_dim, ac_dim].
Draw one action per state from the policy distribution. Shape [B, ac_dim].
This is a single-sample Monte Carlo estimate of Ea~π[Q(s, a)]. We're approximating an expectation with a single sample. Higher-quality implementations sometimes use multiple samples and average; AWAC uses one for speed and finds it works well enough.
Plug the policy's sampled action into the critic. get_q returns min(Q1, Q2)(s, a_pi). This is our estimate of V(s) = Ea~π[Q(s, a)].
Why this works as a value baseline: by definition, V(s) is the expected Q under the current policy. We approximate the expectation by sampling once. Crude but unbiased.
Advantage. Shape [B]. Positive: the dataset action was better than what the policy would currently choose. Negative: the dataset action was worse.
Used downstream as the weight on log π(a|s) in the actor update. Positive advantage → large weight → push log-prob of dataset action up. Negative advantage → small weight → barely push.
The actor update is -(log_prob * exp(A/λ)).mean(). If we used raw Q instead of A, the weights would be exp(Q/λ). But Q can be uniformly large (like all values being -50 in this maze) without saying anything about which action is better. Subtracting V acts as a state-dependent baseline, isolating the per-action signal. Same role as the value baseline in PPO's policy gradient.
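For reference, a sketch of the multi-sample variant mentioned above — this is not what the starter code asks for, just an illustration of averaging K samples to reduce variance (it reuses the names from the walkthrough):

```python
# Hypothetical K-sample estimate of V(s); the starter code uses a single sample.
K = 4
with torch.no_grad():
    action_dist = self.actor(ob_no)                        # pi(.|s)
    q_samples = torch.stack(
        [self.critic.get_q(ob_no, action_dist.sample())    # Q(s, a_k), a_k ~ pi(.|s)
         for _ in range(K)],
        dim=0,
    )
    v_s = q_samples.mean(dim=0)                            # average over K samples, shape [B]
    adv = self.critic.get_q(ob_no, ac_na) - v_s
```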
Where: awac_agent.py:121-126 (next_actions) and awac_agent.py:135-138 (actor_loss).
The code:
```python
with torch.no_grad():
    next_obs_t = ptu.from_numpy(next_ob_no)
    next_dist = self.actor(next_obs_t)
    next_actions = next_dist.sample()
```
next_ob_no arrives as a NumPy array; the actor needs a torch tensor on the right device. ptu.from_numpy is the homework's helper for that conversion.
Forward through the actor at each next-state. Returns a MultivariateNormal distribution batched over states.
One action per next-state, shape [B, ac_dim]. This is the a' ~ π(·|s') in the critic's TD target.
Why no_grad: we don't want gradients flowing from critic loss back into the actor through this sample. The actor has its own update path. Mixing them would be a bug.
The code:
```python
adv_n = self.estimate_advantage(ob_no, ac_na)
actor_loss = self.actor.update(ob_no, ac_na, adv_n=adv_n)
```
Calls the function you wrote in Change 3. Returns shape [B] tensor of advantages, one per dataset transition.
Calls the function you wrote in Change 1, passing the advantages as the third argument. The actor's update internally computes the exp-weighted MLE loss, calls backward, and steps its optimizer. Returns a scalar loss for logging.
The keyword adv_n=adv_n is required because the function signature is update(self, observations, actions, adv_n=None). Without the keyword, you'd be passing the advantage as the third positional, which is fine in this case but less explicit.
Now you can read the whole train method and understand exactly what each line does:
1. Sample minibatch (already done by RL_Trainer).
2. Sample next-actions from the policy at the next states (your block 1).
3. Update the critic with those next-actions in the TD target (calls the function you wrote in Change 2).
4. Estimate advantages for the actor (calls the function you wrote in Change 3).
5. Update the actor with weighted MLE (calls the function you wrote in Change 1).
6. Soft-update target networks.
One iteration of AWAC. Loop a million times. The agent never touches the environment.
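Condensed into pseudocode, one iteration looks roughly like this (a sketch — the names loosely follow the walkthrough and are not the literal starter-code API):

```python
# One AWAC iteration, schematically. Names are illustrative, not the exact starter-code API.
ob_no, ac_na, reward_n, next_ob_no, terminal_n = dataset.sample(batch_size)   # 1. fixed dataset

with torch.no_grad():                                                          # 2. a' ~ pi(.|s')
    next_actions = actor(ptu.from_numpy(next_ob_no)).sample()

critic.update(ob_no, ac_na, next_ob_no, reward_n, terminal_n, next_actions)   # 3. TD update, both Q-nets

adv_n = agent.estimate_advantage(ob_no, ac_na)                                 # 4. A = Q(s,a) - Q(s,a_pi)
actor.update(ob_no, ac_na, adv_n=adv_n)                                        # 5. exp(A/lambda)-weighted MLE

critic.update_target_network()                                                 # 6. Polyak soft update
```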
Same setup as HW2 P3:
```bash
# Install Modal in the conda env
conda create -n cs224r-hw3 python=3.10.19
conda activate cs224r-hw3
pip install modal

# Authenticate (browser device flow)
modal setup

# Redeem course credits at modal.com/credits

# Wandb secret on Modal
modal secret create wandb-secret WANDB_API_KEY=<your_key> --force
```
While iterating on bugs:
```bash
cd /Users/ozyphus/Documents/GitHub/cs224r/homework/hw3/hw3

modal run --detach modal_train.py --algo awac \
    --env-name antmaze-umaze-v0 \
    --exp-name awac_antmaze_umaze \
    --use-wandb --seed 1
```
One container, one seed. Faster turnaround if you hit a bug.
The submission requires 3 seeds per experiment for the table. The provided script launches all 3 in parallel:
```bash
# Part 1: U-maze (~1 hour)
modal run --detach modal_train_para.py --algo awac \
    --env-name antmaze-umaze-v0 \
    --exp-name awac_antmaze_umaze \
    --use-wandb

# Part 2: medium maze (~1.5 hours)
modal run --detach modal_train_para.py --algo awac \
    --env-name antmaze-medium-diverse-v0 \
    --exp-name awac_antmaze_medium_diverse \
    --use-wandb
```
The --detach flag means you can close your terminal — jobs run on Modal regardless. Watch progress via the wandb URL printed at startup.
| Metric | Healthy behavior | Bug signal |
|---|---|---|
| Critic Loss | Decreases over training, eventually stable in some range | Stays at 0 or explodes to 1e6+ |
| Critic Loss2 | Roughly tracks Critic Loss | Wildly different from Loss1 |
| Actor Loss | Negative, slowly drifts as advantages shift | Positive (means BC is broken) |
| Eval_AverageReturn | Starts near −700, climbs toward 0 over training | Stays at −700 (not learning) or drops below the data's max return (overfitting) |
```bash
mkdir -p data
modal volume get cs224r-hw3-results / ./data/
```
Or just one experiment:
modal volume get cs224r-hw3-results /awac_antmaze_umaze/ ./data/
For each of the 3 seeds, in wandb:
- Open the run's Eval_AverageReturn plot.
- Export the plot data as CSV, named awac_umaze_seed1.csv, awac_umaze_seed2.csv, etc.

Then organize:
```text
P1/
├── 1/
│   ├── awac_umaze_seed1.csv
│   ├── awac_umaze_seed2.csv
│   └── awac_umaze_seed3.csv
└── 2/
    ├── awac_medium_maze_seed1.csv
    ├── awac_medium_maze_seed2.csv
    └── awac_medium_maze_seed3.csv
```
For each environment, take the final-checkpoint Eval_AverageReturn for each of the 3 seeds, then compute mean and standard deviation across the 3 numbers. Fill into Tables 1 (umaze) and 2 (medium-diverse).
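A small aggregation helper — a sketch that assumes the exported CSVs have an Eval_AverageReturn column and live under the layout above; adjust names to match your actual wandb export:

```python
import glob
import pandas as pd

finals = []
for path in sorted(glob.glob("P1/1/awac_umaze_seed*.csv")):
    df = pd.read_csv(path)
    finals.append(df["Eval_AverageReturn"].iloc[-1])   # final-checkpoint return for this seed

s = pd.Series(finals)
print(f"umaze: mean={s.mean():.1f}, std={s.std():.1f}")  # the two numbers for the table
```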
| Call | What it returns | File |
|---|---|---|
| self.actor(obs) | MultivariateNormal distribution | MLP_policy.py |
| action_dist.sample() | Sampled actions, shape [B, ac_dim] | torch.distributions |
| action_dist.log_prob(actions) | Log-probs, shape [B] | torch.distributions |
| self.critic.get_q(obs, actions) | min(Q1, Q2)(s, a) using ONLINE nets, shape [B] | awac_critic.py:104 |
| self.critic.get_target_q(obs, actions) | min(Q̄1, Q̄2)(s, a) using TARGET nets, shape [B] | awac_critic.py:119 |
| self._get_q_value(q_net, obs, actions) | Q-value from a single network, shape [B] | awac_critic.py:97 |
| ptu.from_numpy(arr) | Tensor on the correct device | infrastructure.pytorch_util |
Quick self-check:

1. What's the difference between off-policy RL and offline RL?
2. Why does naive offline Q-learning (e.g., SAC on a fixed dataset) fail?
3. What does the exp(A/λ) weight do that plain behavior cloning doesn't?
4. What happens to AWAC as λ → ∞? As λ → 0?
5. Why must the actor's update use (s, a) pairs from the dataset rather than its own samples?
6. Why use min(Q1, Q2) instead of just Q in the TD target?
7. Why does get_target_q use the target networks but get_q use the online ones?
8. Why is estimate_advantage wrapped in with torch.no_grad():?
9. Why is V(s) ≈ Q(s, aπ) with aπ ~ π a single-sample estimate? Could we do better?
10. What does the (1 - terminal_n) factor do in the TD target?
11. Why clamp the exponential weights?
12. Does the actor learn from the dataset or from the critic?

Answers:

1. Off-policy buffers grow as the agent acts, mixing past + recent data. Offline uses a fixed dataset, never grows, no env interaction. Off-policy can self-correct via new env experience; offline cannot.
2. The Q-network's predictions at OOD actions are unreliable. The actor finds OOD actions where Q is overestimated. The critic gets trained on backups using those bad actions, but never has reality to correct them. Errors compound.
3. It biases the maximum-likelihood objective toward dataset actions with high advantage. Plain BC (weight=1 everywhere) imitates indiscriminately; AWAC tilts toward better-than-average actions.
4. λ → ∞: all weights become 1, AWAC reduces to plain BC. λ → 0: weights become a sharp peak at the highest-advantage action, effectively argmax over dataset actions per state — risky, can overfit.
5. Because (s, a) from the dataset are guaranteed in-distribution. If the actor were trained on its own samples (s, π(s)), it could drift to OOD actions and the entire OOD-avoidance argument collapses.
6. Maximization bias. max over noisy estimates is biased upward; the min counters this with downward bias. Combining two independent critics' minima produces a conservative target.
7. Target networks should move slowly (Polyak averaging) to provide a stable regression target. Online networks are what you're optimizing — they move fast. Computing Q for the actor's value baseline uses online (latest estimate); computing TD targets uses target (stable).
8. Advantages are weights for the actor's loss, treated as fixed numbers. We don't want gradients flowing through Q into the critic when computing the policy update — that would create an unwanted coupling. The actor's log_prob is the only thing that should produce policy gradient.
9. True V(s) = expectation over π. We use one Monte Carlo sample. Could do better by averaging K samples (lower variance) or by analytically computing V via a separate network (which is what IQL does in Problem 2). AWAC's single sample is a tradeoff between accuracy and compute.
10. Zeros out the bootstrap when the episode ended. Terminal states have no future, so the target should just be the reward, not r + γ Q(s', a'). The (1-d) mask handles this elegantly without an if/else.
11. Numerical stability. With small λ and large advantages, exp(A/λ) can overflow to inf. Clamping at 50 caps the worst single-sample weight, preventing one outlier from dominating a batch.
12. Both. The dataset provides (s, a) pairs and log-probs to differentiate. The critic provides advantages that weight those log-probs. Without either, the actor wouldn't have a useful update.
Once all four are correct, launch the parallel run on umaze. ~1 hour. If Eval_AverageReturn rises from −700 toward something near −100 by the end, AWAC is working.
Three big ideas, in order of importance:

1. Offline RL's core failure mode: the Q-function extrapolates badly at out-of-distribution actions, the actor exploits those errors, and no new data ever arrives to correct them.
2. AWAC's actor fix: weighted behavior cloning. Imitate only dataset actions, weighted by exp(A/λ), so the policy improves on the data while staying in-distribution.
3. AWAC's critic: standard clipped double-Q TD learning, kept reliable only because the actor it bootstraps from stays close to the data.
If a friend asks: "Why does naive offline Q-learning fail?" — you say: "Because Q-networks extrapolate badly. They produce wildly inaccurate values at actions the dataset doesn't cover. The actor finds those errors and exploits them, but offline you can't correct via real experience. Errors compound. AWAC's fix is to keep the policy close to the dataset via weighted BC; IQL's fix is to never query Q at out-of-distribution actions in the first place. Both work, both are widely used."
You can teach this. Onward to IQL.