Resume Training
A common question we get is how to set up model checkpoints so that training can be resumed. In this document, we use the PPO example `ppo_gridnet.py` to answer it.
Save model checkpoints
The first step is to save the model periodically. By default, we save it to W&B (Weights & Biases), which requires tracking to be enabled (`args.track`):
```python
import torch
import wandb

# `args` and `agent` come from the rest of the training script, and
# wandb.init(...) has already been called when args.track is set
num_updates = args.total_timesteps // args.batch_size
CHECKPOINT_FREQUENCY = 50
starting_update = 1
for update in range(starting_update, num_updates + 1):
    # ... do rollouts and train models
    if args.track:
        # make sure to tune `CHECKPOINT_FREQUENCY`
        # so models are not saved too frequently
        if update % CHECKPOINT_FREQUENCY == 0:
            # write the weights into the run directory, then upload them immediately
            torch.save(agent.state_dict(), f"{wandb.run.dir}/agent.pt")
            wandb.save(f"{wandb.run.dir}/agent.pt", policy="now")
```
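To see how `CHECKPOINT_FREQUENCY` interacts with the loop bounds, here is a small pure-Python sketch of the schedule the loop above follows. The helper names and the numbers (`total_timesteps`, `batch_size`) are hypothetical and were chosen only for illustration. They are not taken from `ppo_gridnet.py`.

```python
from __future__ import annotations


def checkpoint_updates(total_timesteps: int, batch_size: int, frequency: int) -> list[int]:
    """Update indices at which the loop would save agent.pt (hypothetical helper)."""
    num_updates = total_timesteps // batch_size
    # the loop runs update = 1, ..., num_updates and saves when update % frequency == 0
    return [u for u in range(1, num_updates + 1) if u % frequency == 0]


def last_checkpoint_before(update: int, frequency: int) -> int:
    """Newest saved update if training stops right after `update` (0 if none yet)."""
    return (update // frequency) * frequency


# hypothetical run: 1,000,000 steps with 4,096 steps per update -> 244 updates
print(checkpoint_updates(1_000_000, 4_096, 50))  # [50, 100, 150, 200]
# stopping after update 173 means the newest agent.pt is from update 150
print(last_checkpoint_before(173, 50))  # 150
```

A crash therefore loses at most `CHECKPOINT_FREQUENCY - 1` updates of training. This is the trade-off against upload frequency that the comment in the loop asks you to tune.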
Then we can train the agent by running:

```bash
python ppo_gridnet.py --prod-mode --capture_video
```
If training is terminated early, the most recently saved `agent.pt` is still available in the run's files on W&B, for example at https://wandb.ai/costa-huang/cleanRL/runs/21421tda/files.
Resume training
The second step is to automatically download `agent.pt` from that run and resume training:
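```python
import torch
import wandb

# `args`, `agent`, `device`, and `experiment_name` come from the rest of the
# training script, and `run = wandb.init(...)` has already been called when args.track is set
num_updates = args.total_timesteps // args.batch_size
CHECKPOINT_FREQUENCY = 50
starting_update = 1
if args.track and run.resumed:
    # assumes the update index is logged every update under "charts/update"
    last_update = run.summary.get("charts/update")
    # agent.pt was last written at a multiple of CHECKPOINT_FREQUENCY,
    # so restart right after that checkpoint rather than after last_update
    last_checkpoint = (last_update // CHECKPOINT_FREQUENCY) * CHECKPOINT_FREQUENCY
    starting_update = last_checkpoint + 1
    global_step = last_checkpoint * args.batch_size
    # download the latest agent.pt from the run's files and load it
    api = wandb.Api()
    api_run = api.run(f"{run.entity}/{run.project}/{run.id}")
    model = api_run.file("agent.pt")
    model.download(f"models/{experiment_name}/", replace=True)
    agent.load_state_dict(torch.load(
        f"models/{experiment_name}/agent.pt", map_location=device))
    # the agent stays in training mode (no agent.eval()) because we keep training it
    print(f"resumed at update {starting_update}")
for update in range(starting_update, num_updates + 1):
    # ... do rollouts and train models
    if args.track:
        # make sure to tune `CHECKPOINT_FREQUENCY`
        # so models are not saved too frequently
        if update % CHECKPOINT_FREQUENCY == 0:
            torch.save(agent.state_dict(), f"{wandb.run.dir}/agent.pt")
            wandb.save(f"{wandb.run.dir}/agent.pt", policy="now")
```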
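The resume code recovers the update counter from the W&B run summary. One design alternative is to store the counter inside the checkpoint itself, so that resuming does not depend on what was logged. The sketch below illustrates that idea without any framework, under stated assumptions. A plain dict stands in for `agent.state_dict()`, and `pickle` stands in for `torch.save`/`torch.load`. The file name and the helpers `save_checkpoint`, `load_checkpoint`, and `train` are hypothetical.

```python
from __future__ import annotations

import os
import pickle
import tempfile

CHECKPOINT_FREQUENCY = 50


def save_checkpoint(path: str, update: int, state: dict) -> None:
    # write to a temporary file, then atomically replace the old checkpoint,
    # so a crash mid-write cannot leave a truncated file behind
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump({"update": update, "state": state}, f)
    os.replace(tmp_path, path)


def load_checkpoint(path: str) -> tuple[int, dict | None]:
    # returns (starting_update, state); starts from scratch if no checkpoint exists
    if not os.path.exists(path):
        return 1, None
    with open(path, "rb") as f:
        ckpt = pickle.load(f)
    return ckpt["update"] + 1, ckpt["state"]


def train(path: str, num_updates: int, stop_after: int | None = None) -> dict:
    starting_update, state = load_checkpoint(path)
    if state is None:
        state = {"weights": 0}  # stand-in for a freshly initialised agent
    for update in range(starting_update, num_updates + 1):
        state["weights"] += 1  # stand-in for "do rollouts and train models"
        if update % CHECKPOINT_FREQUENCY == 0:
            save_checkpoint(path, update, state)
        if stop_after is not None and update == stop_after:
            break  # simulate the run being terminated early
    return state


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "agent.pkl")
    train(path, num_updates=244, stop_after=173)  # "crashes" after update 173
    resumed = train(path, num_updates=244)  # resumes from update 151
    print(resumed["weights"])  # 244, same as an uninterrupted run
```

In a PyTorch script, the same dict could also hold `optimizer.state_dict()`. Without it, stateful optimizers such as Adam restart their moment estimates on resume.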
To resume training, note that the ID of the experiment is `21421tda`, as in the URL https://wandb.ai/costa-huang/cleanRL/runs/21421tda. Pass this ID via the `WANDB_RUN_ID` environment variable and set `WANDB_RESUME=must` to trigger the resume mode of W&B:

```bash
WANDB_RUN_ID=21421tda WANDB_RESUME=must python ppo_gridnet.py --prod-mode --capture_video
```
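The run ID in the command above is read off the run URL by hand. If you launch resumes from Python instead, a hedged sketch could extract the ID and build the environment. It assumes run URLs follow the `https://wandb.ai/<entity>/<project>/runs/<run_id>` shape shown above. The helpers `run_id_from_url` and `resume_env` are hypothetical.

```python
from __future__ import annotations

import os
from urllib.parse import urlparse


def run_id_from_url(url: str) -> str:
    # assumes URLs like https://wandb.ai/<entity>/<project>/runs/<run_id>[/files]
    parts = [p for p in urlparse(url).path.split("/") if p]
    return parts[parts.index("runs") + 1]


def resume_env(run_id: str, base_env: dict[str, str] | None = None) -> dict[str, str]:
    # copy the current environment (so PATH etc. survive) and add the resume variables
    env = dict(os.environ if base_env is None else base_env)
    env["WANDB_RUN_ID"] = run_id
    env["WANDB_RESUME"] = "must"
    return env


run_id = run_id_from_url("https://wandb.ai/costa-huang/cleanRL/runs/21421tda/files")
env = resume_env(run_id)
cmd = ["python", "ppo_gridnet.py", "--prod-mode", "--capture_video"]
# subprocess.run(cmd, env=env, check=True) would launch the same command as above
print(env["WANDB_RUN_ID"], env["WANDB_RESUME"])  # 21421tda must
```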