Diplomacy RL Part 1

Ben Glickenhaus

Introduction#

In the last post, I demoed an approach to using trl + OpenEnv + Modal to fine-tune an LLM on a toy RL problem. In this series of posts, we'll explore how to scale that approach to learn a much more complicated game: Diplomacy.


I've been interested in Diplomacy for a decade now - in some ways it's the anti-RL RL environment: perfect information in terms of game state, but highly imperfect in terms of player intent. Everything that makes Diplomacy hard is a human problem. And that makes it an interesting LLM problem. Especially once we get to full press Diplomacy, we'll learn if an LLM is willing to lie, betray, and backstab if it learns that maximizes reward.


But, we have a long way to go before we get to full press Diplomacy. So grab a snack, get comfy, and let's go for a ride.

Note

Rather than a single "paper" style post with the final code and results, I'm going to do a series of posts walking through the process of building the infrastructure and training the model.


I'll start with the PoC then identify what didn't work and the optimizations I made to get past each barrier.

The building blocks of RL#

Before we even start thinking about machine learning, we need to lay a solid foundation to build on. An RL system - in its simplest form - consists of 3 components:

  • We need a way to simulate lots of games (rollouts) of Diplomacy as quickly as possible
  • We need an agent that can make decisions in the game
    • If the agent is an LLM, we need an inference API to get the agent's actions
  • We need to aggregate everything that happens in the game and compute a reward for each agent

As tempting as it is to start with the modeling problem, I'd argue the most important first step of any RL project is to build a highly scalable rollout engine and test the daylights out of it.
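
As a mental model, here's a minimal sketch of those three pieces as interfaces. The names are illustrative placeholders, not the actual classes used later in this post.

```python
# A minimal sketch of the three RL building blocks as interfaces; these names
# are illustrative placeholders, not the classes used later in the post.
from typing import Protocol

RolloutResult = dict  # everything that happened in one game, in whatever shape you like


class Agent(Protocol):
    def get_orders(self, game_state: str, power: str) -> list[str]:
        """Given a text view of the board, return one order per unit."""
        ...


class RolloutEngine(Protocol):
    def run(self, agents: dict[str, "Agent"]) -> RolloutResult:
        """Simulate one full game with the given power -> agent mapping."""
        ...


class RewardFunction(Protocol):
    def compute(self, result: RolloutResult) -> dict[str, float]:
        """Aggregate a finished game into a per-power scalar reward."""
        ...
```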

Section 1: The rollout engine#

Thanks to the power of open source, there's a pip-installable package that runs the actual Diplomacy game engine. We can run this in a Modal image on a CPU for near-infinite horizontal scaling. We just need to write a small wrapper around the engine to make it easy to represent the game state and compute rewards for our agents, plus a couple of tricks in our rollout function to eventually support GRPO-style training.
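
To make the rollout code easier to follow, here's a heavily simplified sketch of the kind of wrapper meant here, built on the `diplomacy` package. The method names mirror the ones used below (step, is_done, get_year, ...), but this is just an illustrative skeleton: the real wrapper also builds prompts, validates orders, and computes rewards.

```python
# A simplified sketch of a DiplomacyWrapper on top of the pip-installable
# `diplomacy` package. Movement-phase orders only; prompt building, order
# validation, and reward computation are omitted.
from diplomacy import Game


class DiplomacyWrapper:
    def __init__(self, horizon: int = 99):
        self.game = Game()
        self.horizon = horizon

    def get_current_phase(self) -> str:
        return self.game.get_current_phase()  # e.g. "S1901M"

    def get_year(self) -> int:
        return int(self.get_current_phase()[1:5])

    def is_done(self) -> bool:
        return self.game.is_game_done or self.get_year() - 1901 >= self.horizon

    def valid_moves_by_power(self) -> dict[str, list[str]]:
        possible = self.game.get_all_possible_orders()  # location -> [order strings]
        return {
            power: [o for loc in self.game.get_orderable_locations(power) for o in possible[loc]]
            for power in self.game.powers
        }

    def step(self, all_orders: list[str]) -> None:
        # Orders are strings like "A PAR - BUR"; route each to its owning power.
        for power_name, power in self.game.powers.items():
            units = set(power.units)
            orders = [o for o in all_orders if " ".join(o.split()[:2]) in units]
            self.game.set_orders(power_name, orders)
        self.game.process()
```

With that wrapper in hand, the Modal rollout function looks like this: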

Diplomacy Rollout Engine

@app.function(
    image=cpu_image,
    cpu=1.0,
    memory=2048,  # Increased memory to hold G game copies
    timeout=3600,
)
async def run_rollout(config_dict: dict, power_adapters: dict[str, str]):
    BASELINE_BOTS = {
        "random_bot": RandomBot(),
        "chaos_bot": ChaosBot(),
    }

    rollout_start_time = time.time()
    rollout_id = ""
    metrics: RolloutMetrics | None = None

    # 1. THE WARMUP (Generate a random state)
    # ---------------------------------------
    warmup_phases = random.randint(0, 8)
    if random.random() < cfg.rollout_no_warmup_chance:
        warmup_phases = 0

    # Init main game
    main_game = DiplomacyWrapper(horizon=99)
    rollout_id = main_game.game.game_id
    metrics = RolloutMetrics(rollout_id=rollout_id)

    # Log rollout start
    log_rollout_start(
        rollout_id=rollout_id,
        warmup_phases=warmup_phases,
        samples_per_group=cfg.samples_per_group,
        horizon_years=cfg.rollout_horizon_years,
    )

    vis = None
    if should_visualize:
        logger.info("Visualizing game...")
        vis = GameVisualizer()
        vis.capture_turn(main_game.game, "Warmup Start")
    logger.info(f"🔥 Starting Warmup: {warmup_phases} phases")

    # Play through warmup
    for i in range(warmup_phases):
        if main_game.is_done():
            break

        all_orders = []
        phase = main_game.get_current_phase()

        # Get inputs for all powers
        inputs = main_game.get_all_inputs(agent=agent)

        # 1. Handle baseline bots directly (no LLM)
        for idx, power in enumerate(inputs["power_names"]):
            adapter = power_adapters.get(power)
            assert adapter is not None and adapter in BASELINE_BOTS
            bot = BASELINE_BOTS[adapter]
            orders = bot.get_orders(main_game, power)
            expected_count = len(inputs["valid_moves"][idx])

            log_orders_extracted(
                rollout_id=rollout_id,
                power_name=power,
                orders_count=len(orders),
                expected_count=expected_count,
                raw_response_length=0,
                phase=phase,
                raw_response="[BASELINE BOT]",
            )
            metrics.record_extraction(len(orders), expected_count)
            all_orders.extend(orders)

        main_game.step(all_orders)

        if should_visualize and vis:
            vis.capture_turn(
                main_game.game,
                f"Warmup step {i + 1}/{warmup_phases}\n{chr(10).join(all_orders)}",
            )

    # 2. THE FORK (Clone the state G times)
    # -------------------------------------
    frozen_state = cloudpickle.dumps(main_game)
    frozen_vis = cloudpickle.dumps(vis)
    if should_visualize and vis:
        visualizers = [cloudpickle.loads(frozen_vis) for _ in range(cfg.samples_per_group)]
    else:
        visualizers = None

    games = [cloudpickle.loads(frozen_state) for _ in range(cfg.samples_per_group)]
    current_year = main_game.get_year()
    target_year = current_year + cfg.rollout_horizon_years

    fork_data = {i: {} for i in range(len(games))}

    async def run_game_async(g_idx: int, game: DiplomacyWrapper, vis_obj) -> dict:
        """Run a single game clone asynchronously until completion."""
        game_fork_data = {}
        step_count = 0

        while game.get_year() < target_year and not game.is_done():
            step_count += 1

            # Get inputs for all powers in this game
            inputs = game.get_all_inputs(agent=agent)
            phase = game.get_current_phase()

            # Results will be populated as we process each power
            # Format: {power_idx: {"orders": [...], "response_data": {...}}}
            power_results: dict[int, dict] = {}

            # 1. Handle baseline bots
            for idx, power in enumerate(inputs["power_names"]):
                adapter = power_adapters.get(power)
                bot = BASELINE_BOTS[adapter]
                orders = bot.get_orders(game, power)
                power_results[idx] = {
                    "orders": orders,
                    "response_data": {
                        "text": "[BASELINE BOT]",
                        "prompt_token_ids": [],
                        "token_ids": [],
                        "completion_logprobs": [],
                    },
                }

            # 2. Collect all orders and log metrics
            all_orders = []
            for idx, power in enumerate(inputs["power_names"]):
                result = power_results[idx]
                orders = result["orders"]
                response_data = result["response_data"]
                response_text = response_data["text"]
                expected_count = len(inputs["valid_moves"][idx])

                log_orders_extracted(
                    rollout_id=rollout_id,
                    power_name=power,
                    orders_count=len(orders),
                    expected_count=expected_count,
                    raw_response_length=len(response_text),
                    phase=phase,
                    raw_response=response_text,
                )
                metrics.record_extraction(len(orders), expected_count)

                game_fork_data[power] = {
                    "prompt": inputs["prompts"][idx],
                    "completion": response_text,
                    "prompt_token_ids": response_data.get("prompt_token_ids", []),
                    "completion_token_ids": response_data.get("token_ids", []),
                    "completion_logprobs": response_data.get("completion_logprobs", []),
                }

                all_orders.extend(orders)

            # Step the game
            game.step(all_orders)

            # Update visualization if enabled
            if should_visualize and vis_obj is not None:
                vis_obj.capture_turn(
                    game.game,
                    f"Rollout step {step_count}\n{chr(10).join(all_orders)}",
                )

        return {"g_idx": g_idx, "fork_data": game_fork_data}
Define the Modal function
For the game container, we can use a lightweight CPU image. We slightly increase the memory to hold G game copies, which we'll need for GRPO training.

Benchmarking the rollout engine#

Again, before we even start doing ML, we need to collect some baseline throughput data. We're interested in a few things:

  • How efficiently can we scale horizontally?
  • How does game length affect throughput?
  • Where are the bottlenecks?

Feel free to skim the following graphs - they're not particularly interesting with just the baseline bots. My point is more to emphasize the importance of building these instrumentation tools early, because once you start adding LLMs for inference and training, it becomes a game of whack-a-mole to eliminate the long pole in your pipeline and maximize throughput.

Scaling horizontally#

Thanks to Modal, we can scale the game engine's throughput pretty linearly.

Scaling horizontally
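
For reference, the benchmark driver is just a thin fan-out over the run_rollout function defined above. A sketch, using Modal's built-in starmap; the config objects here are placeholders:

```python
# Sketch of fanning rollouts out across containers. Each input runs in its own
# CPU container and Modal autoscales the pool, which is where the near-linear
# scaling comes from. config_dict is a placeholder for the real rollout config.
@app.local_entrypoint()
def benchmark(num_rollouts: int = 100):
    config_dict = {}  # placeholder rollout config
    power_adapters = {
        power: "random_bot"
        for power in ["AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"]
    }
    args = [(config_dict, power_adapters) for _ in range(num_rollouts)]
    results = list(run_rollout.starmap(args))
    print(f"Completed {len(results)} rollouts")
```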

Game length#

We can also see that game length affects throughput. Interestingly, there seems to be a sweet spot around 7 years. I suspect this is because the later into the game you get in Diplomacy, the more complex order resolution becomes.

Game length

Latency waterfall#

Finally, we can compute a waterfall view of rollout latency. This isn't so interesting for baseline bots since each step is pretty simple, but it will be useful when we add LLM powers.

Latency waterfall
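
The waterfall comes from simple per-stage timing wrapped around each part of a rollout step. A minimal sketch of that instrumentation; the stage names and the agent call in the usage comment are illustrative:

```python
# Minimal per-stage timing used to build a latency waterfall. The stage names
# in the usage comment are illustrative, not the exact ones in the real code.
import time
from collections import defaultdict
from contextlib import contextmanager


class StepTimings:
    """Accumulates wall-clock time per named stage of a rollout step."""

    def __init__(self):
        self.durations: dict[str, list[float]] = defaultdict(list)

    @contextmanager
    def track(self, stage: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.durations[stage].append(time.perf_counter() - start)


# Usage inside the rollout loop:
# timings = StepTimings()
# with timings.track("get_inputs"):
#     inputs = game.get_all_inputs(agent=agent)
# with timings.track("inference"):
#     results = await agent.get_orders(inputs)
# with timings.track("engine_step"):
#     game.step(all_orders)
```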

Section 1 Summary#

Cool, we've built our rollout engine and proven we can scale it on Modal. Now we can be confident building the rest of our RL pipeline, knowing that any future latency we observe is probably due to inference, training, or I/O.

Section 2: The agent#

We glossed over it in Section 1, but we actually need an Agent interface to power even the baseline bots. Everything becomes much more complicated when we add an LLM into the rollout loop, however. This section will cover two things:

  • The LLM inference logic that makes learning a game as complex as Diplomacy feasible
  • The inference engine we need to add to our rollout infrastructure to support it

The LLM Agent#

Conceptually an LLM Diplomacy agent is remarkably simple. Consider the classic (s,a) -> r RL problem. Well, since we're using an LLM, s is just a representation of the game state in text, maybe with some extra instructions. It could look something like this:

MOVEMENT_PREFIX = """\
### DIPLOMACY MOVEMENT PHASE ###

You are playing Diplomacy. Your task: output exactly one order per unit.

RULES:
- Copy-paste EXACT move strings from the valid moves list below
- One order per line inside <orders> tags
- Movement: "A PAR - BUR" means Army Paris moves to Burgundy
- Hold: "A PAR H" means Army Paris holds
- Support: "A PAR S A BUR - MAR" means Paris supports Burgundy to Marseilles
- Convoy: "F NTH C A LON - NWY" means Fleet convoys Army from London to Norway

"""

prompt = (
    f"{MOVEMENT_PREFIX}"
    f"GAME STATE:\n"
    f"Power: {power_name}\n"
    f"Phase: {phase}\n"
    f"You have {unit_count} units.\n\n"
    f"VALID MOVES:\n{moves_display}\n\n"
    f"Output {unit_count} orders (one per unit):\n"
    "<orders>\n"
)

For a, again, we're using an LLM, so the action is also just text. Notice that in the example prompt we're seeding the generation with <orders>, which teaches the LLM to output XML so we can parse the orders from a predictable format.
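
Extraction on the way back out is then just string handling. A minimal sketch, assuming the completion was seeded with the opening <orders> tag as in the prompt above:

```python
def extract_orders(completion: str) -> list[str]:
    """Pull one order per line out of the <orders>...</orders> block.

    The prompt seeds generation with "<orders>", so the completion may start
    mid-block without the opening tag.
    """
    body = completion.split("</orders>")[0]  # drop anything after the closing tag
    body = body.split("<orders>")[-1]        # drop anything before the opening tag, if present
    return [line.strip() for line in body.splitlines() if line.strip()]


# extract_orders("A PAR - BUR\nA MAR H\n</orders> I think this secures Burgundy.")
# -> ["A PAR - BUR", "A MAR H"]
```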


Easy peasy, time to start training, right? Not so fast; here are a couple of reasons the prompt example above is a very bad idea.

  1. Token bloat: Including all the VALID_MOVES in the prompt might seem like a good idea - hey, this way we can force the model to only choose valid actions, right? True until you get to the mid game, when you can easily have 10+ units (N) with 15-25 valid moves (M) each. Listing these moves consumes O(N × M) tokens in the prompt. Multiply that by 7 game powers over a dozen turns/game and 100 games/step and that's a lot of extra tokens. Worse still, the action space explodes combinatorially to M^N possible actions per turn. With typical mid-game values, this easily exceeds most context windows, making reasoning or search infeasible, which brings us to the next point...
  2. Logits processing: With an action space that large, learning to take good moves is hard enough. We don't want to waste cycles first learning how to take legal actions. To make things worse, with a naive implementation, we'd risk the agent only outputting garbage moves, resulting in the default behavior of all units holding, thus generating no advantage and no gradient to learn from. To prevent this, we need to implement a custom logits processor that runs every token generation step to constrain model generation to only valid Diplomacy moves. Luckily vLLM supports custom logits processors. Let's dive in, since this is one of the cooler parts of the project.

Custom logits processor#

The rough idea for the logits processor is this (a simplified code sketch follows the note below):

  • Before generation: We use the game engine to compute a trie over token ID sequences for all valid moves for the current game state + power
  • At generation start: The model can generate freely. This allows us to output reasoning traces, call tools, etc.
  • During generation: We "listen" for the opening <orders> tag - once we see it we're in an active phase and we constrain outputs to only token sequences that can form valid moves.
  • Active phase: Each token, we advance a pointer in the trie of valid moves. When we reach a leaf node (end-of-move), we allow newlines and reset to the root of the trie to begin the next order.
  • Active phase, cont.: After each completed order, we extract the unit identifier (e.g., "A PAR" from "A PAR - BUR"), mark it as used, and rebuild the trie excluding all moves for that unit. This prevents duplicate orders per unit.
  • Completion: We listen for the closing </orders> tag - once we see it we're done and can return the generated orders. If we've emitted enough orders for all units, we force-complete the closing tag.

Note

This was the first time I'd coded a trie since grinding LeetCode for my first job. Turns out that stuff is useful, who knew.
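
Here's a heavily simplified sketch of the processor as a stateful callable taking the generated token IDs and the logits row - roughly the shape vLLM accepts via SamplingParams(logits_processors=[...]), though the exact hook varies by vLLM version. The real implementation also tracks which units already have orders, rebuilds the trie to exclude them, and force-completes the closing </orders> tag; all of that is omitted here:

```python
# Simplified sketch of the trie-constrained logits processor. Per-unit
# bookkeeping and force-completing </orders> are omitted.
import torch


class MoveTrie:
    """Trie over the token-ID sequences of all currently valid moves."""

    END = "__end__"  # leaf marker: a complete order ends at this node

    def __init__(self, tokenizer, moves: list[str]):
        self.root: dict = {}
        for move in moves:
            node = self.root
            for tok in tokenizer.encode(move, add_special_tokens=False):
                node = node.setdefault(tok, {})
            node[self.END] = move


class ConstrainedOrdersProcessor:
    """Masks logits so every line inside <orders>...</orders> is a valid move."""

    def __init__(self, tokenizer, valid_moves: list[str]):
        self.tokenizer = tokenizer
        self.trie = MoveTrie(tokenizer, valid_moves)
        self.node = self.trie.root   # pointer into the trie
        self.active = False          # flip to True at init if the prompt already ends with <orders>
        self.seen = 0                # how many generated tokens we've consumed so far
        self.newline_id = tokenizer.encode("\n", add_special_tokens=False)[-1]

    def __call__(self, token_ids: list[int], logits: torch.Tensor) -> torch.Tensor:
        # Advance our state over any tokens generated since the last call.
        for tok in token_ids[self.seen:]:
            if not self.active:
                if "<orders>" in self.tokenizer.decode(token_ids):
                    self.active = True
                continue
            if tok == self.newline_id or tok not in self.node:
                self.node = self.trie.root   # order finished; restart at the root
            if tok in self.node:
                self.node = self.node[tok]   # walk one edge deeper into the trie
        self.seen = len(token_ids)

        if not self.active:
            return logits  # free generation: reasoning traces, tool calls, etc.

        # Legal continuations are the children of the current trie node...
        allowed = [t for t in self.node if t != MoveTrie.END]
        if MoveTrie.END in self.node:
            allowed.append(self.newline_id)  # ...plus a newline once an order completes
        if not allowed:
            return logits  # nothing legal left; fail open

        mask = torch.full_like(logits, float("-inf"))
        mask[allowed] = 0.0
        return logits + mask
```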

Interactive Demo#

To see how this works in practice, try the interactive demo below. It walks through the key stages of constrained generation:

Logits Processor in Action (interactive demo): it steps through a full constrained generation, showing the output generated so far, the tokens allowed next, and the current state of the valid-move trie. Before the <orders> tag appears, generation is unconstrained (free for reasoning, explanations, etc.); once inside the tag, only tokens that extend a valid move for a not-yet-ordered unit are allowed.

We can look at the impact on valid move accuracy across a variety of prompt configurations:

| Prompt | Description |
| --- | --- |
| Verbose | The original prompt implementation, with a verbose description of the game task, the game state, and the full valid moves list |
| Full moves (compact) | Include the entire list of valid moves in the prompt, but compact the JSON to minimize token usage and trim the task description verbiage |
| Minimal | Reduced task description. Don't include valid moves or any game state context. |
| Minimal with windows | Reduced task description. Don't include valid moves. Represent relevant parts of the board as a linked list to provide strategic context. |
| Minimal full context | Reduced task description. Don't include valid moves, but include the number of valid moves for each unit. Represent relevant parts of the board as a linked list to provide strategic context. |
Logits processor accuracy

We can see that even when including the full moves list, the naive implementation only achieves around 50% valid move accuracy. When you remove the full list, the model is basically randomly guessing and accuracy tanks to single-digit percentages.


When we include the logits processor, we see a dramatic improvement across the board - around 82%. This makes it much easier for the model to take impactful moves in the games that result in advantage and a gradient to learn strategy from.


A natural next question is whether this approach increases latency. Improving accuracy 10-20x wouldn't do us any good if it made generating orders 20-30x slower.

Runtime complexity and performance impact#

One of the key advantages of this approach is that the runtime overhead is actually minimal:

Trie construction (one-time per request):

  • Building the trie requires encoding each valid move into token IDs
  • For N units with M moves each, we encode O(N × M) moves
  • Each move tokenizes to roughly T tokens on average
  • Trie construction: O(N × M × T) time and O(N × M × T) space
  • In practice, with N = 10 units, M = 20 moves, and T ≈ 5 tokens/move, that's ~1000 tokens to process
  • This happens once at the start of the request, not per generation step

Per-token overhead during generation:

  • Trie lookup: O(1) dictionary lookup in the current node's children (typically < 100 options)
  • Tag detection: O(k), where k is the tag length (constant, ~10 characters)
  • Logits masking: O(V), where V is the vocab size (~50k-100k tokens), but we only mask when active. This is also vectorized, so the wall-clock impact is minimal.
  • Total per-token: O(V) worst case, but only when inside <orders> tags

Empirically, we see that the logits processor actually makes generation ~10x faster, from ~900ms to ~80ms. My theory is that the logits processor not only constrains the model to valid outputs, it also caps the model's total token output by forcing it to generate the closing </orders> tag once it has produced an order for every unit.

Logits processor throughput

Testing our training pipeline#

At this point we have the ability to run a bunch of rollouts and use an LLM as the agent. We're ready to test our infrastructure with a minimal training run.

Cooperative self-play#

Diplomacy is different from traditional RL environments like Chess or Go in two important ways:

  1. The lack of transitivity: If you're better than me at Chess, and Alice is better than you, Alice is almost certainly better than me. This is not the case in Diplomacy. Diplomacy is more like rock-paper-scissors - an honorable player who never backstabs might win among other honorable players, but if I'm willing to betray you all to win, everything changes.
  2. Diplomacy is not strictly zero sum: Chess and Go are zero-sum games - one player wins and the other loses. If I blunder my queen and my win probability drops, yours necessarily goes up. Diplomacy is not zero sum in the short term - we can both capture a supply center or increase our win probability on the same turn.

For this initial training run, we're going to use pure self-play (think AlphaZero) to train a model. Given the above, with one policy controlling all 7 powers, we should expect the model to learn to maximize the total score of all players. This may result in a "peace treaty" equilibrium, where each power expands into its surrounding neutral territories but never shows hostility, since it "knows" any hostility it shows will be met with equal hostility in return.

Despite the disclaimer above, this training setting will be a perfect test for our infra and core learning pipeline. Even if we can't expect to come away with a competitive, human-level player, we should prove to ourselves that the model is capable of learning to maximize its reward in this environment.
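
To make "maximize its reward" concrete: with GRPO, each of the G forked games from Section 1 produces a reward (assume for now it's the power's final supply-center count), and the advantage is computed relative to the sibling forks that share the same warmup state. A minimal sketch:

```python
# Sketch of a group-relative (GRPO-style) advantage over the G forks.
# Using final supply-center count as the reward is an assumption for illustration.
import statistics


def group_relative_advantages(sc_counts: list[int]) -> list[float]:
    """Advantage of each fork's outcome relative to its sibling forks."""
    mean = statistics.mean(sc_counts)
    std = statistics.pstdev(sc_counts) or 1.0  # avoid divide-by-zero when all forks tie
    return [(sc - mean) / std for sc in sc_counts]


# e.g. final supply centers for FRANCE across G=4 forks:
# group_relative_advantages([5, 7, 4, 8]) -> approx [-0.63, 0.63, -1.26, 1.26]
```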

TODO: Add some details about the training setup

With a full training run we can essentially solve the cooperative self-play setting. To validate that the model actually learns anything useful, we can eval it against our baseline bots from before. And look! It learns to conquer the world pretty effectively, averaging nearly 12 supply centers against both the random and chaos bots.