Stanford CS 224R · Homework 3 · Offline RL

Advantage-Weighted Actor-Critic from Absolute Zero

A robot ant. A maze. A frozen dataset of someone else's wandering. Train a policy that's better than the data — without ever interacting with the world. Every equation, every line of code, every gotcha.

Builds on HW2 · The offline RL paradigm · Four implementation tasks · ~1500 lines

What You'll Master

Chapter 01

The AntMaze Setup

Two new ideas land at once in HW3: a new environment, and a new way to learn.

The environment: AntMaze

An 8-DOF "ant" robot must navigate a maze to reach a goal. The state is the ant's joints and body pose plus its position in the maze (~29 numbers). The action is continuous joint torques in [-1, 1]^8.

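Two variants, named as in the run commands later in this guide: the U-maze (antmaze-umaze-v0) and the medium maze (antmaze-medium-diverse-v0).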

The reward is sparse — like HW2, but inverted in sign:

Reward

$$r(s, a) = \begin{cases} 0 & \text{when the goal is reached} \\ -1 & \text{every other step} \end{cases}$$

So the return per episode is essentially "minus the number of steps to the goal." A successful 50-step run gets −50; failure gets −episode_length (often −700).
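A minimal sketch of that arithmetic, assuming the reward is exactly as stated above (0 on the step that reaches the goal, −1 otherwise) and that the episode ends at the goal. The function name and the 700-step time limit are illustrative; the exact off-by-one depends on the environment's step-counting convention, which is why the text says "essentially."

```python
def episode_return(steps_to_goal=None, episode_length=700):
    """Undiscounted return under the sparse reward: -1 per step, 0 on the goal step.

    steps_to_goal: 1-based index of the step on which the goal is reached,
                   or None if the episode times out without reaching it.
    episode_length: time limit (700 is the figure quoted in the text).
    """
    if steps_to_goal is None:
        rewards = [-1] * episode_length                 # never reached the goal
    else:
        rewards = [-1] * (steps_to_goal - 1) + [0]      # last step earns 0
    return sum(rewards)


print(episode_return(steps_to_goal=50))    # -49
print(episode_return(steps_to_goal=None))  # -700
```

Under this convention a 50-step success scores −49 rather than −50; either way, a higher (less negative) return means a shorter path.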

The big shift: no environment interaction

In HW2 you collected fresh rollouts for every policy update (PPO) or kept adding transitions to a replay buffer (off-policy SAC). Here you don't run the policy in the environment at all until evaluation. Training reads from a fixed dataset of someone else's transitions, period.

Definition
Offline RL

Learn a policy from a fixed dataset D = {(s, a, r, s', d)} of transitions, with no further interaction with the environment during training. The data was collected by some other policy (mixture of experts, scripted controllers, prior random walks — anything).

This is the most realistic RL setting you'll see in this course. In healthcare, robotics, dialogue systems — you usually cannot let an untrained policy explore. People die, robots break, customers leave. You have a log of past behavior and must learn from it.

What you're implementing — AWAC

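Advantage-Weighted Actor-Critic (Nair et al., 2020). The headline idea: fit a double-Q critic to the dataset with TD learning, then train the actor by behavior cloning on dataset (s, a) pairs, weighting each pair by exp(A/λ).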

The result: a policy that resembles the dataset (so it doesn't query Q at out-of-distribution actions and get fooled) but is biased toward the dataset's good actions.

The intuition

Pure BC: imitate everything. Pure Q-maximization: drift to wherever Q is highest, even if Q is hallucinated nonsense at out-of-distribution actions. AWAC: imitate what's in the dataset, but lean toward the parts of the dataset where the value is good. Best of both.

Chapter 02

What "Offline RL" Means

To appreciate why AWAC is built the way it is, we need to be precise about what offline RL is and is not.

The three regimes

| Regime | Data source | Examples |
|---|---|---|
| On-policy | Fresh rollouts under the current policy. Discard after a few epochs. | PPO, REINFORCE (HW2 P2) |
| Off-policy | Replay buffer that grows as the agent acts. Mix of past + recent. | SAC, DQN (HW2 P3) |
| Offline | Fixed dataset, never grows. Agent never acts in the environment. | AWAC, IQL (HW3) |

Off-policy and offline sound similar. The difference is critical: off-policy can add new data when its current Q-estimates are wrong about some action, by trying that action and seeing what happens. Offline cannot.

What offline RL has and doesn't have

Has:

Doesn't have:

The first absence is the killer. The whole story of modern offline RL is: how do we prevent the Q-function from becoming unreliable at actions the dataset doesn't cover?

Why HW2's algorithms don't work here

Run vanilla SAC on a fixed dataset (no env interaction). What happens?

  1. Critic update: y = r + γ min(Q1_target, Q2_target)(s', a') where a' ~ π(·|s'). The actor's actions a' are not in the dataset — the actor was randomly initialized.
  2. The Q-network has never seen (s', a'). Its prediction is whatever the random initialization plus a few gradient steps say. Could be anything.
  3. If Q(s', a') happens to be erroneously high for some out-of-distribution a', the actor learns "go to that action."
  4. Now the actor concentrates on that action. The critic sees more of it but can never check — it's still out-of-distribution.
  5. The Q-value continues to grow. The actor more aggressively chases it. Diverge.

This is sometimes called extrapolation error. Online RL fixes it by trying the bad action and getting a reality check from the env. Offline RL has to fix it through algorithm design.
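The failure mode in the list above can be reproduced in a toy tabular setting. This is a sketch under simplifying assumptions, not the homework's setup: one state with a self-loop, a dataset that only contains action 0 (reward −1), and an unseen action 1 whose Q-value is stuck at an arbitrary wrong initial guess. All constants are made up for illustration.

```python
GAMMA = 0.9
REWARD = -1.0     # the dataset only ever shows action 0, with reward -1
OOD_INIT = 5.0    # arbitrary, wrong initial guess for the unseen action 1


def run_backups(use_max_over_actions, n_iters=200):
    """Repeated Bellman backups on Q[0]; Q[1] is never updated (no data for it)."""
    q = [0.0, OOD_INIT]
    for _ in range(n_iters):
        if use_max_over_actions:
            bootstrap = max(q)   # queries the out-of-distribution action
        else:
            bootstrap = q[0]     # only queries the action that is in the data
        q[0] = REWARD + GAMMA * bootstrap
    return q[0]


q_naive = run_backups(use_max_over_actions=True)
q_in_data = run_backups(use_max_over_actions=False)
print(f"backup through OOD action: Q[0] = {q_naive:.2f}")
print(f"backup through data only : Q[0] = {q_in_data:.2f}")
```

With every reward equal to −1, no true return can be positive, yet the backup that bootstraps through the unseen action settles on a positive value. Nothing in the dataset ever corrects Q[1], so the error persists.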

Chapter 03

The Distributional Shift Problem

This is the most important conceptual chapter of the homework. Internalize this and AWAC's design becomes obvious.

The "uncovered region" problem

Imagine the action space is a 2D plane. The dataset's actions form a cloud somewhere in that plane — maybe a blob in the upper-left. Outside the blob, no data exists.

The Q-network is a neural net. It interpolates and extrapolates. Inside the blob, the network's predictions are anchored by training data. Outside, predictions are pure extrapolation. There's nothing constraining what Q outputs at OOD actions — not the data, not the loss.

The fundamental asymmetry

If you query the Q-net at an action it has seen, you get a meaningful answer. If you query it at an action it has never seen, you get whatever the network's inductive biases produce — often a wildly inaccurate value, frequently too high.

The exploitation trap

The actor's job is to find argmax_a Q(s, a). If Q is overestimated at some OOD action a*, the actor will gladly select a*. Now the critic gets trained against backups that involve a*. But there's no real-world signal saying a* is bad, because we're offline. The error compounds.

This is sometimes called the "actor-critic feedback loop": the actor exploits errors in the critic, which feed back into the critic via TD targets, which makes the errors grow.

Two families of solutions

Modern offline RL has two main families; AWAC belongs to the first.

| Family | Idea | Examples |
|---|---|---|
| Constrain the policy | Force the policy to stay close to the data distribution. Don't let it pick OOD actions. | BCQ, BEAR, AWAC |
| Constrain the Q-function | Modify Q-learning to deliberately underestimate at OOD actions, so the actor never wants them. | CQL, IQL (Problem 2!) |

AWAC's approach: keep the policy close to the dataset by training it via weighted behavior cloning. The weights bias toward high-advantage actions, but the cloning structure ensures the policy stays in-distribution. Q-learning continues normally; OOD is avoided by the actor never going there.

(IQL takes the other approach — it modifies Q-learning so it never queries OOD actions to begin with. We'll cover IQL in Problem 2.)

Chapter 04

The AWAC Idea

One equation summarizes the entire actor update:

AWAC actor loss

$$L_\pi(\psi) = -\,\mathbb{E}_{(s,a)\sim D}\big[\log \pi_\psi(a \mid s)\,\exp\big(A(s,a)/\lambda\big)\big]$$

Read it piece by piece:

Comparison to plain BC

Pure behavior cloning would minimize:

Plain BC

$$L_{\mathrm{BC}}(\psi) = -\,\mathbb{E}_{(s,a)\sim D}\big[\log \pi_\psi(a \mid s)\big]$$

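AWAC is exactly this with one extra factor: exp(A / λ). So AWAC reduces to BC when every weight equals 1, i.e., when all advantages are zero; if the advantages are merely equal, the weights are a shared constant and the loss is BC up to a scale factor. The weight is what "tilts" the imitation toward better-than-average actions.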
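To make the comparison concrete, here is a small pure-Python sketch of both losses on three made-up (log-prob, advantage) pairs. The function names and numbers are illustrative only; the homework computes the same quantities on PyTorch tensors.

```python
import math


def bc_loss(log_probs):
    """Plain behavior cloning: negative mean log-likelihood."""
    return -sum(log_probs) / len(log_probs)


def awac_loss(log_probs, advantages, lam):
    """Same sum, with each term weighted by exp(A / lambda)."""
    weights = [math.exp(adv / lam) for adv in advantages]
    weighted = [w * lp for w, lp in zip(weights, log_probs)]
    return -sum(weighted) / len(weighted)


log_probs = [-1.2, -0.7, -2.0]   # hypothetical log pi(a|s) for three dataset pairs
zero_adv = [0.0, 0.0, 0.0]       # all weights become exp(0) = 1
mixed_adv = [1.0, 0.0, -1.0]     # first pair is up-weighted, last is down-weighted

print(bc_loss(log_probs))
print(awac_loss(log_probs, zero_adv, lam=1.0))   # matches the BC loss
print(awac_loss(log_probs, mixed_adv, lam=1.0))  # dominated by the first pair
```

With zero advantages the two losses coincide; with mixed advantages the first sample contributes roughly e² ≈ 7.4 times the weight of the last.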

Three equivalent ways to read AWAC

1. Weighted maximum likelihood: imitate the dataset, but with each sample weighted by exp(A/λ).

2. Soft policy improvement: starting from the data distribution, take a small KL-bounded step toward higher-Q actions.

3. Reward-weighted regression: a generalization of REINFORCE where instead of weighting log-probs by raw return, we weight by exp-of-advantage.

All three are correct. The exp weighting is the unifying feature.

Why this stays in-distribution

This is the critical insight. The actor's gradient is computed only on (s, a) pairs from the dataset. The actor is never asked "what's a good action at state s?" — it's only ever asked "given that someone took action a at state s, how much should we increase log π(a|s)?" The advantage weighting decides how much, but the action a is always from the dataset.

So the actor cannot drift to OOD actions during training. Its log-prob distribution stays anchored on dataset actions. The critic's OOD predictions never come up because the actor never proposes OOD actions for the critic to evaluate.

(Wait — doesn't the critic itself sample a' ~ π(·|s') in its TD target? Yes. But because the actor stays in-distribution, those samples are also approximately in-distribution. Cycle resolved.)

Chapter 05

Why the exp Weighting

The exp(A/λ) weight isn't pulled out of a hat. It falls out of solving a constrained optimization problem. The derivation is worth seeing once.

The optimization problem

We want a new policy π that improves over the data policy πD, but doesn't drift too far:

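Constrained policy improvement

$$\max_\pi \; \mathbb{E}_{s\sim D,\,a\sim\pi}\big[A(s,a)\big] \quad \text{subject to} \quad \mathbb{E}_{s\sim D}\big[D_{\mathrm{KL}}\big(\pi(\cdot\mid s)\,\|\,\pi_D(\cdot\mid s)\big)\big] \le \varepsilon$$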

"Maximize expected advantage under the new policy, while keeping the new policy within KL-distance ε of the data policy." Read ε as how far we're willing to deviate.

The closed-form solution

Setting up the Lagrangian and solving for π gives an analytical optimum:

Optimal π (Lagrangian)

$$\pi^*(a\mid s) \propto \pi_D(a\mid s)\,\exp\big(A(s,a)/\lambda\big)$$

where λ is the Lagrange multiplier corresponding to the KL constraint.
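For readers who want the omitted step, here is a sketch of the Lagrangian argument for a single state s. The multiplier name α (for the constraint that π integrates to 1) is introduced here for illustration and does not appear in the original text.

```latex
\begin{aligned}
\mathcal{L}(\pi, \lambda, \alpha)
  &= \int \pi(a \mid s)\, A(s,a)\, da
   + \lambda \Big( \varepsilon - \int \pi(a \mid s) \log \frac{\pi(a \mid s)}{\pi_D(a \mid s)}\, da \Big)
   + \alpha \Big( 1 - \int \pi(a \mid s)\, da \Big) \\
\frac{\partial \mathcal{L}}{\partial \pi(a \mid s)}
  &= A(s,a) - \lambda \Big( \log \frac{\pi(a \mid s)}{\pi_D(a \mid s)} + 1 \Big) - \alpha = 0 \\
\Rightarrow\quad \pi^*(a \mid s)
  &= \pi_D(a \mid s)\, \exp\!\Big( \frac{A(s,a)}{\lambda} \Big)\, \exp\!\Big( -1 - \frac{\alpha}{\lambda} \Big)
   \;\propto\; \pi_D(a \mid s)\, \exp\!\Big( \frac{A(s,a)}{\lambda} \Big)
\end{aligned}
```

The last factor does not depend on a, so it is absorbed into the proportionality constant.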

Read this: the optimal new policy is the data policy reweighted by exp of advantage. High-advantage actions get amplified, low-advantage actions get attenuated. The temperature λ controls how aggressive the reweighting is.
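For a discrete action set the reweighting can be computed exactly. The sketch below uses a made-up three-action data policy and made-up advantages; `tilted_policy` is a hypothetical helper name, not part of the homework code.

```python
import math


def tilted_policy(pi_data, advantages, lam):
    """Return pi*(a) proportional to pi_data(a) * exp(A(a) / lam), normalized to sum to 1."""
    unnorm = [p * math.exp(adv / lam) for p, adv in zip(pi_data, advantages)]
    z = sum(unnorm)
    return [u / z for u in unnorm]


pi_data = [0.5, 0.3, 0.2]        # hypothetical data policy at one state
advantages = [-1.0, 0.0, 1.0]    # hypothetical advantages; the last action is best

for lam in (0.3, 1.0, 100.0):
    probs = tilted_policy(pi_data, advantages, lam)
    print(lam, [round(p, 3) for p in probs])
```

A small λ moves most of the mass onto the best action in the data; a large λ leaves the data policy almost unchanged.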

Why this matters

You don't need to actually solve a constrained optimization problem at every gradient step. You just need to fit a parametric policy πψ to match the optimal π*. Since π* is a reweighted version of the data, fitting via weighted maximum likelihood does the job.

From optimal policy to weighted MLE

To fit πψ to π*, minimize KL divergence:

$$\pi_\psi^* = \arg\min_\psi \; \mathbb{E}_{s\sim D}\big[D_{\mathrm{KL}}\big(\pi^*(\cdot\mid s)\,\|\,\pi_\psi(\cdot\mid s)\big)\big]$$

Substituting the form of π* and dropping terms that don't depend on ψ:

$$\pi_\psi^* = \arg\max_\psi \; \mathbb{E}_{s\sim D}\Big[\sum_a \pi^*(a\mid s)\,\log \pi_\psi(a\mid s)\Big]$$

Now use the fact that π*(a|s) ∝ πD(a|s) exp(A/λ), and the (s, a) samples in the dataset come from πD:

$$\pi_\psi^* = \arg\max_\psi \; \mathbb{E}_{(s,a)\sim D}\big[\exp\big(A(s,a)/\lambda\big)\,\log \pi_\psi(a\mid s)\big]$$

Negate to get a loss to minimize:

AWAC actor loss (final)

$$L_\pi(\psi) = -\,\mathbb{E}_{(s,a)\sim D}\big[\exp\big(A(s,a)/\lambda\big)\,\log \pi_\psi(a\mid s)\big]$$

This is exactly the equation from the homework PDF, derived from first principles. The exp weighting is the unique form that solves the KL-constrained policy improvement problem.

What λ controls

| λ | exp(A/λ) behavior | Resulting policy |
|---|---|---|
| Very small (→ 0) | Sharp peak around the highest-advantage action | Effectively argmax over dataset actions; aggressive but risks overfitting |
| Moderate (~1) | Moderate reweighting | Tilted BC; the AWAC sweet spot |
| Very large | All weights ~ 1 | Reduces to plain BC; safe but doesn't improve over data |

The homework defaults to a small λ (0.3 in the AWAC paper, configurable here). You won't tune it for AWAC, but Problem 2's IQL has a similar temperature you'll sweep.

Numerical stability: the clamp

Look at the starter code:

exp_weights = exp_weights.clamp(max=50.0)

Why? When advantages are large and λ is small, exp(A/λ) can overflow. exp(20) is already 4.8e8; exp(40) is 2.4e17. The clamp caps the weight itself (not the exponent) at 50, which keeps things sane: one outlier with weight 50 in a batch of 256 is fine, but an unclamped weight of exp(100) would dominate everything.

A subtle point

The clamping breaks the strict mathematical equivalence to the constrained optimization — but in practice, large weights would be dominated by sampling noise anyway. Clamping is a numerical regularization that almost every implementation uses.
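A scalar sketch of the same clamp in plain Python, assuming λ = 0.3 as quoted above. `clamped_weight` is an illustrative name. Because Python's `math.exp` raises `OverflowError` on very large inputs, this version caps the exponent at log(50), which gives the same result as capping the weight at 50.

```python
import math

WEIGHT_CAP = 50.0   # same cap as the starter code's clamp(max=50.0)


def clamped_weight(advantage, lam, cap=WEIGHT_CAP):
    """exp(A / lam), capped at `cap`.

    exp is monotonic, so exp(min(x, log(cap))) == min(exp(x), cap),
    and capping the exponent first avoids overflow for huge advantages.
    """
    return math.exp(min(advantage / lam, math.log(cap)))


for adv in (-1.0, 0.0, 1.0, 2.0, 1e6):
    print(adv, round(clamped_weight(adv, lam=0.3), 3))
```

With λ = 0.3, an advantage of 1 gives a weight of about 28 (untouched), while an advantage of 2 would give about 786 and is cut to 50.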

Chapter 06

The Double-Q Critic

Now the other half: the critic. AWAC's critic is essentially the same as Problem 3 of HW2 (SAC) — double Q-networks with target networks and clipped target.

The TD loss

AWAC critic loss

$$y = r(s,a) + \gamma\,(1-d)\,\underbrace{\min\big(\bar{Q}_1(s',a'),\,\bar{Q}_2(s',a')\big)}_{\text{target networks, no grad}}$$

$$L_Q(\phi) = \mathbb{E}_{(s,a,r,s',d)\sim D}\Big[\big(Q_{\phi_1}(s,a)-y\big)^2 + \big(Q_{\phi_2}(s,a)-y\big)^2\Big]$$

Three components, each familiar from HW2:

Where does a' come from?

This is the one step where AWAC's critic touches the policy. To compute min(Q̄1, Q̄2)(s', a'), we need an action a' at the next state s'. We sample one from the current actor:

a' ~ πψ(· | s')

This is what makes the algorithm an "actor-critic" rather than offline Q-learning. The critic's target depends on the actor's choices.

Why this is OK in offline RL: because the actor is trained via AWAC's weighted-BC objective, it stays close to the data distribution. So its sampled a' are approximately in-distribution, and the critic's target is approximately reliable.

The fragile coupling

If the actor drifts to OOD actions, the critic's targets become unreliable, which feeds back into actor advantages, which feeds back to actor updates. AWAC works because the BC anchor is strong enough to keep the actor in-distribution. With a smaller offline dataset or a weaker BC anchor, this coupling can still break.

Terminal handling

Same as HW2: if d=1 (terminal), zero out the bootstrap. The full target:

$$y = r + \gamma\,(1-d)\,\min\big(\bar{Q}_1(s',a'),\,\bar{Q}_2(s',a')\big)$$

You'll write this expression in awac_critic.py.
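As a scalar sanity check of that target, here is a sketch in plain Python. The discount γ = 0.99 and the Q-values are made-up numbers, not values from the homework config, and `td_target` is an illustrative name.

```python
def td_target(reward, done, target_q1, target_q2, gamma=0.99):
    """Clipped double-Q TD target for one transition.

    done is 1.0 on a terminal transition and 0.0 otherwise, so the
    (1 - done) factor switches the bootstrap term off at episode end.
    """
    bootstrap = min(target_q1, target_q2)   # pessimistic of the two target critics
    return reward + gamma * (1.0 - done) * bootstrap


# Non-terminal step: -1 + 0.99 * min(-20, -22), roughly -22.78
print(td_target(reward=-1.0, done=0.0, target_q1=-20.0, target_q2=-22.0))

# Terminal step at the goal: the bootstrap is masked out, leaving just the reward
print(td_target(reward=0.0, done=1.0, target_q1=-20.0, target_q2=-22.0))
```

The tensor version in the homework has the same three factors, applied elementwise over a batch.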

Chapter 07

The Full Algorithm

AWAC: Advantage-Weighted Actor-Critic
  1. Initialize:
    • Two Q-networks Qφ1, Qφ2 and two target networks Q̄1, Q̄2 (copies of online).
    • Policy πψ (random init).
    • Offline dataset D loaded into replay buffer (no growth).
  2. For step = 1 to N (e.g., 1M):
    a) Sample minibatch (s, a, r, s', d) from D.
    b) Sample next-action: a' ~ πψ(· | s'). [Used only by the critic; gradients OFF.]
    c) Critic update:
    • y = r + γ (1 − d) min(Q̄1(s', a'), Q̄2(s', a')) [no_grad]
    • LQ = MSE(Qφ1(s, a), y) + MSE(Qφ2(s, a), y)
    • Step critic optimizer.
    d) Actor update:
    • Advantage estimate: A(s, a) = Q(s, a) − V(s) where V(s) = Q(s, aπ) and aπ ~ πψ(·|s) [no_grad]
    • Weights: w = exp(A / λ), clamped at 50.
    • Lπ = − mean( w · log πψ(a | s) )
    • Step actor optimizer.
    e) Soft-update target critics: Q̄ ← (1 − τ) Q̄ + τ Q.
  3. Periodically: roll out πψ in the env (eval only, never adds to D) and log average return.

Notice what each step touches. The actor never trains on its own samples, only on dataset (s, a) pairs reweighted by advantage. The critic sees actions sampled from the actor in exactly one place: computing the TD target. Together, these two properties keep the algorithm in-distribution.
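Step 2e, the soft target update, is easy to sketch on plain lists of numbers. This illustration uses τ = 0.005 (the value listed in the cheat sheet) and a single scalar "parameter"; the real code applies the same rule to every tensor in the target networks, and `soft_update` here is an illustrative name.

```python
TAU = 0.005


def soft_update(target_params, online_params, tau=TAU):
    """Polyak averaging: target <- (1 - tau) * target + tau * online."""
    return [(1.0 - tau) * t + tau * o for t, o in zip(target_params, online_params)]


target = [0.0]
online = [1.0]   # held fixed here to isolate how fast the target follows
for _ in range(1000):
    target = soft_update(target, online)

gap = online[0] - target[0]
print(f"gap after 1000 updates: {gap:.5f}")
```

The gap shrinks by a factor of (1 − τ) per update, so after 1000 updates it equals 0.995^1000 of its starting value. That slow tracking is what makes the target a stable regression label.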

Chapter 08

Code Tour

The AWAC implementation spans three files. Each maps cleanly to one part of the algorithm.

File map

| File | Class | Responsibility |
|---|---|---|
| policies/MLP_policy.py | MLPPolicyAWAC | Actor: the AWR loss |
| critics/awac_critic.py | AWACCritic | Double-Q critic: TD loss |
| agents/awac_agent.py | AWACAgent | Orchestrator: advantage estimation + train loop |

The MLPPolicy base class

Look at MLPPolicy.forward in MLP_policy.py:109-125. It builds a MultivariateNormal distribution from a learned mean (passed through tanh) and a diagonal covariance from learned logstd:

batch_mean = torch.tanh(self.mean_net(observation))    # shape [B, ac_dim]
clipped_logstd = torch.clamp(self.logstd, min=-5.0, max=2.0)
std = torch.exp(clipped_logstd) * float(temperature)
scale_tril = torch.diag(std)
batch_scale_tril = scale_tril.repeat(batch_dim, 1, 1)
action_distribution = distributions.MultivariateNormal(
    batch_mean,
    scale_tril=batch_scale_tril,
)

Notable:

The MLPPolicyAWAC update

This is your first edit target. Look at MLP_policy.py:164-202:

def update(self, observations, actions, adv_n=None):
    # ... convert to tensors ...

    dist = self(observations)
    log_prob_n = dist.log_prob(actions)

    # TODO: Use adv_n and self.lambda_awac to compute exponential weights.
    exp_weights = None      # YOUR CODE

    exp_weights = exp_weights.clamp(max=50.0)
    actor_loss = -(log_prob_n * exp_weights).mean()

    self.optimizer.zero_grad()
    actor_loss.backward()
    torch.nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm)
    self.optimizer.step()

Most of the wiring is done. You just need to fill in the exp weight calculation. We'll get to it in Chapter 09.

The AWACCritic update

The second edit target. Look at awac_critic.py:134-188. The wiring:

def update(self, ob_no, ac_na, next_ob_no, reward_n, terminal_n, next_actions):
    # ... convert to tensors ...

    # YOUR CODE: compute loss for q_net and q_net2.
    loss = None
    loss2 = None

    self.optimizer.zero_grad()
    (loss + loss2).backward()
    utils.clip_grad_norm_(...)
    self.optimizer.step()

You'll fill in: compute the TD target with get_target_q, then MSE losses for each Q-network. The helper _get_q_value handles the (discrete vs continuous) cases for you.

The AWACAgent

The third edit target. Two blocks in awac_agent.py:

  1. estimate_advantage() at lines 70-98: compute A(s, a) = Q(s, a) − V(s).
  2. train() at lines 100-151: sample next_actions for the critic, then call critic.update + actor.update.

Together, these wire everything up.

Chapter 09

Your Four Changes, Decoded

Every blank you'll fill in, with line-by-line annotation. This is the centerpiece chapter.

Change 1 of 4
MLPPolicyAWAC.update — the exp weights

Where: MLP_policy.py:188-191.

The math:

exp_weights[i] = exp( A(s_i, a_i) / λ )

Already in scope: adv_n — tensor of shape [B], advantage for each sample. self.lambda_awac — scalar λ.

The code:

exp_weights = torch.exp(adv_n / self.lambda_awac)

Decoded

exp_weights = torch.exp(adv_n / self.lambda_awac)

One line, three operations.

1. adv_n / self.lambda_awac — elementwise division. adv_n is shape [B], self.lambda_awac is a Python float, so the result is shape [B] with each element adv_n[i] / lambda.

2. torch.exp(...) — elementwise exponential. Output shape [B], all positive (exp is always positive).

3. The result is the per-sample weight. Multiplying by log_prob_n on the next line gives the AWAC loss summands.

Why this works: actions with high advantage (better than average) get large weights; actions with low advantage (worse than average) get small weights. Every weight is positive, so the MLE update on the next line raises the log-probability of every dataset action, but it pulls hard toward high-weight actions and only weakly toward low-weight ones.

A subtle PyTorch detail

Why torch.exp instead of Python's math.exp or NumPy's np.exp? Because adv_n is a torch tensor, possibly on the GPU, and the result needs to support backprop downstream. math.exp works on Python scalars only. np.exp would force a CPU detour. torch.exp is the only one that's GPU-native and gradient-safe.

(In this specific case adv_n is detached — the advantage is treated as a fixed number for the policy update — so backprop through exp doesn't matter. But using torch.exp is the right reflex regardless.)

Change 2 of 4
AWACCritic.update — the TD loss

Where: awac_critic.py:171-177.

The math:

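$$y = r + \gamma\,(1-d)\,\min\big(\bar{Q}_1(s',a'),\,\bar{Q}_2(s',a')\big) \quad \text{(no\_grad)}$$

$$\text{loss}_1 = \mathrm{MSE}\big(Q_1(s,a),\,y\big), \qquad \text{loss}_2 = \mathrm{MSE}\big(Q_2(s,a),\,y\big)$$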

Already in scope:

  • ob_no, ac_na, next_ob_no, reward_n, terminal_n: all tensors.
  • next_actions: a', sampled from the actor by the agent (passed in).
  • self.gamma: discount.
  • self._get_q_value(q_net, obs, actions): helper that returns Q(s, a) for one Q-network.
  • self.get_target_q(obs, actions): helper that returns min(Q̄1, Q̄2)(s, a).
  • self.mse_loss: an nn.MSELoss() instance.

The code:

with torch.no_grad():
    target_q = self.get_target_q(next_ob_no, next_actions)
    target = reward_n + self.gamma * (1.0 - terminal_n) * target_q

q1 = self._get_q_value(self.q_net, ob_no, ac_na)
q2 = self._get_q_value(self.q_net2, ob_no, ac_na)

loss = self.mse_loss(q1, target)
loss2 = self.mse_loss(q2, target)

Decoded

with torch.no_grad():

Everything inside this block is computed without tracking gradients. The TD target is a regression label: the loss is (prediction − target)², and only the prediction should produce gradients. If you forget no_grad, gradients flow through the target into the target networks, the target moves while we're chasing it, and training diverges.

target_q = self.get_target_q(next_ob_no, next_actions)

Calls the helper at awac_critic.py:119-132. It runs both target Q-networks on (s', a'), takes the elementwise min, returns shape [B]. This is min(Q̄1(s', a'), Q̄2(s', a')) — the clipped double-Q estimate of "value at next state."

Critical: it uses self.q_net_target and self.q_net2_target internally, NOT the online self.q_net and self.q_net2. Mixing these up is the most common bug in this homework.

target = reward_n + self.gamma * (1.0 - terminal_n) * target_q

The Bellman target. Three pieces, all shape [B]:

reward_n — immediate reward.

self.gamma * target_q — discounted future value.

(1.0 - terminal_n) — the done mask. Where terminal_n[i] = 1, this is 0, killing the bootstrap. Where terminal_n[i] = 0, this is 1, keeping the bootstrap.

The whole thing: "reward I just got, plus γ times the target's estimate of next-state value, but don't bootstrap if the episode ended."

q1 = self._get_q_value(self.q_net, ob_no, ac_na)

OUT of the no_grad block now. We do want gradients on these — this is the prediction we're training.

The helper handles the (discrete vs continuous) action representation. For continuous actions (AntMaze), it does self.q_net(obs, actions).squeeze(-1) — concatenates obs and action, runs through the Q-network, returns shape [B]. For discrete, it gathers the relevant logit. You don't need to think about it; just call the helper.

q2 = self._get_q_value(self.q_net2, ob_no, ac_na)

Same thing for the second online critic. We train BOTH critics — even though only the target networks are used for the target. Training both keeps them independent enough that min(target_q1, target_q2) remains a meaningful pessimistic estimate.

loss = self.mse_loss(q1, target)

Mean squared error: mean((q1 - target)²), returns a scalar. nn.MSELoss by default reduces with mean. PyTorch optimizers minimize, so this loss tries to make q1 match target.

loss2 = self.mse_loss(q2, target)

Same for the second critic. The downstream code does (loss + loss2).backward(), which sums their gradients. Each Q-network only has parameters in its own loss, so the sum's gradient w.r.t. each network's parameters is just that network's MSE gradient.

Common bugs

1. Forgetting no_grad: target networks get trained too, training diverges. The target is supposed to be a slow-moving teacher, not a student.

2. Using q_net instead of q_net_target in target: now the target moves at every gradient step. Same divergence.

3. Forgetting (1 - terminal_n): bootstraps on terminal states, where Q is undefined. The first time the dataset has terminal transitions, training corrupts.

4. Shape mismatch: reward_n is [B]; if your target_q somehow becomes [B, 1], broadcasting silently produces [B, B]. Print shapes liberally if loss numbers look weird.
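Bug 4 is easy to reason about once the broadcasting rule is written down. The helper below is a small pure-Python re-implementation of the standard NumPy/PyTorch shape-broadcasting rule, written for illustration only; `broadcast_shape` is not a function from the homework or from either library.

```python
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Result shape when two arrays of these shapes are combined elementwise.

    Shapes are aligned from the right; a missing or size-1 dimension
    stretches to match the other operand.
    """
    result = []
    for a, b in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"shapes {shape_a} and {shape_b} do not broadcast")
        result.append(max(a, b))
    return tuple(reversed(result))


B = 256
print(broadcast_shape((B,), (B,)))     # (256,)      reward_n + target_q, as intended
print(broadcast_shape((B,), (B, 1)))   # (256, 256)  the silent bug
```

A [B, B] target still produces a scalar loss after the mean, which is why this bug raises no error. Checking `target.shape` against `q1.shape` before computing the loss catches it.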

Change 3 of 4
AWACAgent.estimate_advantage — the advantage

Where: awac_agent.py:89-97.

The math:

$$A(s,a) = Q(s,a) - V(s), \qquad V(s) \approx Q(s, a_\pi), \quad a_\pi \sim \pi_\psi(\cdot \mid s) \quad \text{(single-sample estimate)}$$

Already in scope:

  • ob_no, ac_na — tensors, the dataset state-action pair.
  • self.actor — the policy (callable as self.actor(obs) returns a distribution).
  • self.critic.get_q(obs, actions) — returns min(Q1, Q2)(s, a), shape [B].

The code:

# inside `with torch.no_grad():`
q_sa = self.critic.get_q(ob_no, ac_na)               # Q(s, a) for dataset action

action_dist = self.actor(ob_no)                       # pi(.|s)
ac_pi = action_dist.sample()                          # sample one action per state
v_s = self.critic.get_q(ob_no, ac_pi)                 # Q(s, a_pi) ~ V(s)

adv = q_sa - v_s

Decoded

q_sa = self.critic.get_q(ob_no, ac_na)

Q-value at the dataset's actual action. get_q returns min(Q1, Q2)(s, a) using the online Q-networks (not target — we want our most recent estimate of value, even if noisier). Shape [B].

Note: this is computed inside with torch.no_grad(): so the actor's gradient won't flow back through the critic. The actor only learns from the log π term and treats advantage as a fixed weight.

action_dist = self.actor(ob_no)

Forward through the actor to get a MultivariateNormal distribution per state. Shape information: the distribution's mean has shape [B, ac_dim], and the covariance is [B, ac_dim, ac_dim].

ac_pi = action_dist.sample()

Draw one action per state from the policy distribution. Shape [B, ac_dim].

This is a single-sample Monte Carlo estimate of E_{a~π}[Q(s, a)]. We're approximating an expectation with a single sample. Higher-quality implementations sometimes use multiple samples and average; AWAC uses one for speed and finds it works well enough.

v_s = self.critic.get_q(ob_no, ac_pi)

Plug the policy's sampled action into the critic. get_q returns min(Q1, Q2)(s, a_pi). This is our estimate of V(s) = E_{a~π}[Q(s, a)].

Why this works as a value baseline: by definition, V(s) is the expected Q under the current policy. We approximate the expectation by sampling once. Crude but unbiased.

adv = q_sa - v_s

Advantage. Shape [B]. Positive: the dataset action was better than what the policy would currently choose. Negative: the dataset action was worse.

Used downstream as the weight on log π(a|s) in the actor update. Positive advantage → large weight → push log-prob of dataset action up. Negative advantage → small weight → barely push.
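The "crude but unbiased" claim about the single-sample V(s) estimate can be checked on a toy problem where the expectation is known in closed form. Everything here is an assumption made for illustration: a 1-D action, a quadratic stand-in for Q, and a Gaussian policy with mean 0 and standard deviation 0.5.

```python
import random
import statistics


def q_value(action):
    """Toy stand-in for Q(s, a) at one fixed state."""
    return -(action - 0.3) ** 2


def estimate_v(rng, n_samples, mean=0.0, std=0.5):
    """Monte Carlo estimate of E_{a ~ pi}[Q(s, a)] from n_samples policy draws."""
    draws = [q_value(rng.gauss(mean, std)) for _ in range(n_samples)]
    return sum(draws) / n_samples


# E[(a - 0.3)^2] for a ~ N(0, 0.5^2) is (mean offset)^2 + variance.
TRUE_V = -(0.3 ** 2 + 0.5 ** 2)

rng = random.Random(0)
single = [estimate_v(rng, 1) for _ in range(20000)]
averaged = [estimate_v(rng, 16) for _ in range(2000)]

print(f"true V(s)              : {TRUE_V:.3f}")
print(f"mean of 1-sample ests  : {statistics.mean(single):.3f}")
print(f"spread of 1-sample ests: {statistics.pstdev(single):.3f}")
print(f"spread of 16-sample    : {statistics.pstdev(averaged):.3f}")
```

The one-sample estimates average out to the true value, but each individual estimate is noisy; averaging 16 samples cuts the spread by about a factor of four at 16 times the critic evaluations.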

Why advantage and not raw Q?

The actor update is -(log_prob * exp(A/λ)).mean(). If we used raw Q instead of A, the weights would be exp(Q/λ). But Q can be uniformly large (like all values being -50 in this maze) without saying anything about which action is better. Subtracting V acts as a state-dependent baseline, isolating the per-action signal. Same role as the value baseline in PPO's policy gradient.
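A quick numeric illustration of that point, using made-up values in the range the text mentions (Q around −50) and λ = 0.3. The specific numbers are hypothetical.

```python
import math

LAM = 0.3

q_values = [-50.0, -49.0]   # hypothetical Q(s, a) for two dataset actions at one state
v_s = -49.5                 # hypothetical baseline V(s) at that state

raw_weights = [math.exp(q / LAM) for q in q_values]
adv_weights = [math.exp((q - v_s) / LAM) for q in q_values]

print("weights from raw Q    :", raw_weights)
print("weights from advantage:", adv_weights)
print("ratio (raw)           :", raw_weights[1] / raw_weights[0])
print("ratio (advantage)     :", adv_weights[1] / adv_weights[0])
```

Both weightings prefer the second action by the same ratio, but the raw-Q weights are vanishingly small (small enough to underflow to zero in float32), so the loss and its gradient would vanish. Subtracting V(s) recenters the weights around 1.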

Change 4 of 4
AWACAgent.train — tying it together

Where: awac_agent.py:121-126 (next_actions) and awac_agent.py:135-138 (actor_loss).

Block 1: Sample next-actions for the critic backup

The code:

with torch.no_grad():
    next_obs_t = ptu.from_numpy(next_ob_no)
    next_dist = self.actor(next_obs_t)
    next_actions = next_dist.sample()

Decoded

next_obs_t = ptu.from_numpy(next_ob_no)

next_ob_no arrives as a NumPy array; the actor needs a torch tensor on the right device. ptu.from_numpy is the homework's helper for that conversion.

next_dist = self.actor(next_obs_t)

Forward through the actor at each next-state. Returns a MultivariateNormal distribution batched over states.

next_actions = next_dist.sample()

One action per next-state, shape [B, ac_dim]. This is the a' ~ π(·|s') in the critic's TD target.

Why no_grad: we don't want gradients flowing from critic loss back into the actor through this sample. The actor has its own update path. Mixing them would be a bug.

Block 2: The actor update

The code:

adv_n = self.estimate_advantage(ob_no, ac_na)
actor_loss = self.actor.update(ob_no, ac_na, adv_n=adv_n)

Decoded

adv_n = self.estimate_advantage(ob_no, ac_na)

Calls the function you wrote in Change 3. Returns shape [B] tensor of advantages, one per dataset transition.

actor_loss = self.actor.update(ob_no, ac_na, adv_n=adv_n)

Calls the function you wrote in Change 1, passing the advantages as the third argument. The actor's update internally computes the exp-weighted MLE loss, calls backward, and steps its optimizer. Returns a scalar loss for logging.

The keyword adv_n=adv_n is not strictly required: the function signature is update(self, observations, actions, adv_n=None), so passing the advantage as the third positional argument would also work. The keyword is just more explicit.

The full critic+actor cycle

Now you can read the whole train method and understand exactly what each line does:

1. Sample minibatch (already done by RL_Trainer).

2. Sample next-actions from the policy at the next states (your block 1).

3. Update the critic with those next-actions in the TD target (calls the function you wrote in Change 2).

4. Estimate advantages for the actor (calls the function you wrote in Change 3).

5. Update the actor with weighted MLE (calls the function you wrote in Change 1).

6. Soft-update target networks.

One iteration of AWAC. Loop a million times. The agent never touches the environment.

Chapter 10

Running on Modal

Pre-flight (one-time)

Same setup as HW2 P3:

# Install Modal in the conda env
conda create -n cs224r-hw3 python=3.10.19
conda activate cs224r-hw3
pip install modal

# Authenticate (browser device flow)
modal setup

# Redeem course credits at modal.com/credits

# Wandb secret on Modal
modal secret create wandb-secret WANDB_API_KEY=<your_key> --force

Single-seed debug runs

While iterating on bugs:

cd /path/to/cs224r/homework/hw3/hw3
modal run --detach modal_train.py --algo awac \
  --env-name antmaze-umaze-v0 \
  --exp-name awac_antmaze_umaze \
  --use-wandb --seed 1

One container, one seed. Faster turnaround if you hit a bug.

Three-seed parallel runs (for the deliverable)

The submission requires 3 seeds per experiment for the table. The provided script launches all 3 in parallel:

# Part 1: U-maze (~1 hour)
modal run --detach modal_train_para.py --algo awac \
  --env-name antmaze-umaze-v0 \
  --exp-name awac_antmaze_umaze \
  --use-wandb

# Part 2: medium maze (~1.5 hours)
modal run --detach modal_train_para.py --algo awac \
  --env-name antmaze-medium-diverse-v0 \
  --exp-name awac_antmaze_medium_diverse \
  --use-wandb

The --detach flag means you can close your terminal — jobs run on Modal regardless. Watch progress via the wandb URL printed at startup.

What healthy training looks like

| Metric | Healthy behavior | Bug signal |
|---|---|---|
| Critic Loss | Decreases over training, eventually stable in some range | Stays at 0 or explodes to 1e6+ |
| Critic Loss2 | Roughly tracks Critic Loss | Wildly different from Loss1 |
| Actor Loss | Negative, slowly drifts as advantages shift | Positive (means BC is broken) |
| Eval_AverageReturn | Starts near −700, climbs toward 0 over training | Stays at −700 (not learning) or drops below the data's max return (overfitting) |

Retrieving results from Modal

mkdir -p data
modal volume get cs224r-hw3-results / ./data/

Or just one experiment:

modal volume get cs224r-hw3-results /awac_antmaze_umaze/ ./data/

Submitting the deliverable

For each of the 3 seeds, in wandb:

  1. Open the Eval_AverageReturn plot.
  2. Click the three dots in the top-right corner.
  3. Download CSV.
  4. Name it according to the convention: awac_umaze_seed1.csv, awac_umaze_seed2.csv, etc.

Then organize:

P1/
├── 1/
│   ├── awac_umaze_seed1.csv
│   ├── awac_umaze_seed2.csv
│   └── awac_umaze_seed3.csv
└── 2/
    ├── awac_medium_maze_seed1.csv
    ├── awac_medium_maze_seed2.csv
    └── awac_medium_maze_seed3.csv

For each environment, take the final-checkpoint Eval_AverageReturn for each of the 3 seeds, then compute mean and standard deviation across the 3 numbers. Fill into Tables 1 (umaze) and 2 (medium-diverse).
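The table entries can be computed with the standard library. This is a sketch with made-up placeholder numbers, not real results. It uses the sample standard deviation (n − 1 denominator); the text does not say which convention the handout expects, so treat that choice as an assumption and switch to `statistics.pstdev` if the population form is wanted.

```python
import statistics


def summarize(final_returns):
    """Mean and sample standard deviation of the final eval return across seeds."""
    mean = statistics.mean(final_returns)
    std = statistics.stdev(final_returns)   # n - 1 denominator; pstdev uses n
    return mean, std


# Placeholder values for illustration only; replace with the last
# Eval_AverageReturn from each seed's CSV.
placeholder_returns = [-120.0, -95.0, -140.0]

mean, std = summarize(placeholder_returns)
print(f"{mean:.1f} +/- {std:.1f}")
```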

Chapter 11

Cheat Sheet & Self-Quiz

Equations to memorize

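AWAC actor loss

$$L_\pi(\psi) = -\,\mathrm{mean}\big(\log \pi_\psi(a \mid s)\,\exp\big(A(s,a)/\lambda\big)\big)$$

$$A(s,a) = Q(s,a) - V(s), \qquad V(s) \approx Q(s, a_\pi), \quad a_\pi \sim \pi(\cdot \mid s)$$

AWAC critic loss

$$y = r + \gamma\,(1-d)\,\min\big(\bar{Q}_1(s',a'),\,\bar{Q}_2(s',a')\big) \quad (\text{no\_grad},\ a' \sim \pi)$$

$$L_Q(\phi) = \mathrm{MSE}\big(Q_1(s,a),\,y\big) + \mathrm{MSE}\big(Q_2(s,a),\,y\big)$$

Soft target update

$$\bar{\phi} \leftarrow (1-\tau)\,\bar{\phi} + \tau\,\phi, \qquad \tau = 0.005$$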

Key API reference

| Call | What it returns | File |
|---|---|---|
| self.actor(obs) | MultivariateNormal distribution | MLP_policy.py |
| action_dist.sample() | Sampled actions, shape [B, ac_dim] | torch.distributions |
| action_dist.log_prob(actions) | Log-probs, shape [B] | torch.distributions |
| self.critic.get_q(obs, actions) | min(Q1, Q2)(s, a) using ONLINE nets, shape [B] | awac_critic.py:104 |
| self.critic.get_target_q(obs, actions) | min(Q̄1, Q̄2)(s, a) using TARGET nets, shape [B] | awac_critic.py:119 |
| self._get_q_value(q_net, obs, actions) | Q-value from a single network, shape [B] | awac_critic.py:97 |
| ptu.from_numpy(arr) | Tensor on the correct device | infrastructure.pytorch_util |

Self-quiz

  1. What's the difference between off-policy RL (HW2 P3) and offline RL (HW3)?
  2. Why does naive Q-learning diverge in the offline setting?
  3. What does the exp(A/λ) weight do in the AWAC actor loss?
  4. What happens at λ → ∞? At λ → 0?
  5. Why is the actor's loss computed only on (s, a) pairs from the dataset, not on (s, π(s))?
  6. Why do we need min(Q1, Q2) instead of just Q in the TD target?
  7. Why does get_target_q use the target networks but get_q use the online ones?
  8. Why is estimate_advantage wrapped in with torch.no_grad():?
  9. Why is V(s) ≈ Q(s, aπ) with aπ ~ π a single-sample estimate? Could we do better?
  10. What does the (1 - terminal_n) factor do in the TD target?
  11. Why do we clamp exp_weights at 50?
  12. Where does the actor get its update signal — the critic, the dataset, or both?
Answer key

1. Off-policy buffers grow as the agent acts, mixing past + recent data. Offline uses a fixed dataset, never grows, no env interaction. Off-policy can self-correct via new env experience; offline cannot.

2. The Q-network's predictions at OOD actions are unreliable. The actor finds OOD actions where Q is overestimated. The critic is then trained on backups through those actions, and no new environment data ever arrives to correct the estimates. Errors compound.

3. It biases the maximum-likelihood objective toward dataset actions with high advantage. Plain BC (weight=1 everywhere) imitates indiscriminately; AWAC tilts toward better-than-average actions.

4. λ → ∞: all weights become 1, AWAC reduces to plain BC. λ → 0: weights become a sharp peak at the highest-advantage action, effectively argmax over dataset actions per state — risky, can overfit.

5. Because (s, a) from the dataset are guaranteed in-distribution. If the actor were trained on its own samples (s, π(s)), it could drift to OOD actions and the entire OOD-avoidance argument collapses.

6. Overestimation bias. Bootstrapping through a policy that prefers high-Q actions behaves like a max over noisy estimates, which is biased upward. Taking the min of the two critics pushes the target downward instead, giving a conservative estimate.

7. Target networks should move slowly (Polyak averaging) to provide a stable regression target. Online networks are what you're optimizing — they move fast. Computing Q for the actor's value baseline uses online (latest estimate); computing TD targets uses target (stable).

8. Advantages are weights for the actor's loss, treated as fixed numbers. We don't want gradients flowing through Q into the critic when computing the policy update — that would create an unwanted coupling. The actor's log_prob is the only thing that should produce policy gradient.

9. True V(s) is the expectation of Q(s, a) over a ~ π. We use one Monte Carlo sample. Could do better by averaging K samples (lower variance) or by learning V with a separate network (which is what IQL does in Problem 2). AWAC's single sample is a tradeoff between accuracy and compute.

10. Zeros out the bootstrap when the episode ended. Terminal states have no future, so the target should just be the reward, not r + γ Q(s', a'). The (1-d) mask handles this elegantly without an if/else.

11. Numerical stability. With small λ and large advantages, exp(A/λ) can overflow to inf. Clamping at 50 caps the worst single-sample weight, preventing one outlier from dominating a batch.

12. Both. The dataset provides (s, a) pairs and log-probs to differentiate. The critic provides advantages that weight those log-probs. Without either, the actor wouldn't have a useful update.
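Answers 4 and 11 can be checked numerically. The sketch below is plain Python on toy advantages rather than the assignment's tensor code; `awac_weights` and `normalized` are hypothetical helper names, and the cap of 50 is the clamp value mentioned in the quiz.

```python
import math


def awac_weights(advantages, lam, max_weight=50.0):
    """exp(A / lam) per sample, capped at max_weight."""
    weights = []
    for a in advantages:
        try:
            w = math.exp(a / lam)
        except OverflowError:
            # math.exp raises instead of returning inf for huge arguments.
            w = math.inf
        weights.append(min(w, max_weight))
    return weights


def normalized(weights):
    """Each weight's share of the batch total."""
    total = sum(weights)
    return [w / total for w in weights]


advantages = [-2.0, 0.0, 2.0]

# Large lam: every weight is close to 1, so the loss is close to plain BC.
near_bc = awac_weights(advantages, lam=100.0)

# Small lam without a cap: the best sample takes almost the whole batch.
peaked = normalized(awac_weights(advantages, lam=0.1, max_weight=math.inf))

# Small lam with the cap: the largest weight is held at 50.
capped = awac_weights(advantages, lam=0.1)
```

Here `near_bc` stays within a few percent of 1 for every sample, `peaked` puts nearly all of the mass on the last sample, and `capped` holds that sample's weight at 50 while the low-advantage sample's weight stays near zero.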

Implementation order

  1. MLPPolicyAWAC.update exp_weights — trivial. ~30 seconds.
  2. AWACCritic.update TD loss — the most error-prone. Test by running on umaze with seed 1 for a few thousand steps; critic loss should decrease.
  3. AWACAgent.estimate_advantage — ~5 minutes. Easy after critic.
  4. AWACAgent.train next_actions + actor_loss — ~5 minutes. Just plumbing.

Once all four are correct, launch the parallel run on umaze. ~1 hour. If Eval_AverageReturn rises from −700 toward something near −100 by the end, AWAC is working.

Take it back to class

You can now teach this

Three big ideas, in order of importance:

  1. Distributional shift is the central problem of offline RL. The Q-network's predictions are unreliable at actions the dataset doesn't cover. Naive Q-learning exploits these errors and diverges.
  2. AWAC's solution: keep the policy near the data distribution. Train via weighted maximum likelihood on dataset (s, a) pairs, where the weight is exp(A/λ). The MLE structure prevents OOD drift; the weight tilts the policy toward dataset actions with high advantage.
  3. The exp weighting falls out of constrained policy improvement. It's not a heuristic: it's the analytical solution to "maximize advantage under KL-bounded policy change." The same exponential-tilt shape shows up in soft Q-learning, soft policy iteration, MaxEnt RL, and KL-regularized RLHF.
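The third idea can be written out as a short derivation sketch. The notation here is introduced for illustration and does not appear in the cheat sheet: π_β is the behavior policy that generated the dataset, ε is the KL budget, Z(s) is a per-state normalizer, and λ plays the role of the Lagrange multiplier on the KL constraint.

```latex
\begin{aligned}
\pi^{*} &= \arg\max_{\pi}\; \mathbb{E}_{a \sim \pi(\cdot \mid s)}\big[A(s,a)\big]
\quad \text{s.t.} \quad
D_{\mathrm{KL}}\big(\pi(\cdot \mid s)\,\|\,\pi_{\beta}(\cdot \mid s)\big) \le \epsilon \\[4pt]
\pi^{*}(a \mid s) &= \frac{1}{Z(s)}\,\pi_{\beta}(a \mid s)\,\exp\!\big(A(s,a)/\lambda\big) \\[4pt]
\psi^{*} &= \arg\min_{\psi}\; \mathbb{E}_{s \sim D}\Big[
D_{\mathrm{KL}}\big(\pi^{*}(\cdot \mid s)\,\|\,\pi_{\psi}(\cdot \mid s)\big)\Big]
= \arg\max_{\psi}\; \mathbb{E}_{(s,a) \sim D}\Big[
\frac{1}{Z(s)}\,\log \pi_{\psi}(a \mid s)\,\exp\!\big(A(s,a)/\lambda\big)\Big]
\end{aligned}
```

The last step uses the fact that dataset actions are samples from π_β, so the expectation under π* becomes a reweighted expectation over the dataset. Dropping the normalizer Z(s), as is done in practice, and negating gives the weighted log-likelihood actor loss from the cheat sheet.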

If a friend asks: "Why does naive offline Q-learning fail?" — you say: "Because Q-networks extrapolate badly. They produce wildly inaccurate values at actions the dataset doesn't cover. The actor finds those errors and exploits them, but offline you can't correct via real experience. Errors compound. AWAC's fix is to keep the policy close to the dataset via weighted BC; IQL's fix is to never query Q at out-of-distribution actions in the first place. Both work, both are widely used."

You can teach this. Onward to IQL.