From 879f99e23fe85f0deba408525e2f65cbcf5609cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20K=C3=A4nner?= Date: Sun, 26 Jan 2025 19:29:45 +0100 Subject: [PATCH] multi agent mcts --- battlesnake/src/main.rs | 60 +++++++++++++++++++++-------- battlesnake/src/types/simulation.rs | 3 -- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/battlesnake/src/main.rs b/battlesnake/src/main.rs index 4873392..7b5d006 100644 --- a/battlesnake/src/main.rs +++ b/battlesnake/src/main.rs @@ -65,6 +65,7 @@ async fn start(request: Json) { info!("got start request: {board}"); } +#[allow(clippy::too_many_lines)] async fn get_move(request: Json) -> response::Json { let start = Instant::now(); let board = Board::from(&*request); @@ -88,43 +89,68 @@ async fn get_move(request: Json) -> response::Json { tokio::task::spawn_blocking(move || { let base_turns = board.turn(); - let end_condition: &dyn Fn(&Board) -> Option = match &*request.game.ruleset.name { - "solo" => &|board: &Board| { + let end_condition: &dyn Fn(&Board) -> Option<_> = match &*request.game.ruleset.name { + "solo" => &|board| { if board.num_snakes() == 0 || board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32) { - Some(board.turn()) + Some(()) } else { None } }, - _ => &|board: &Board| { + _ => &|board| { if board.num_snakes() <= 1 || board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32) { - Some(u32::from(board.alive(id))) + Some(()) } else { None } }, }; + let start_snakes = u32::try_from(board.num_snakes()).unwrap_or(0); + let score_fn: &dyn Fn(&Board, u8) -> u32 = match &*request.game.ruleset.name { + "solo" => &|board, _| board.turn(), + _ => &|board, id| { + if board.alive(id) { + 1 + start_snakes - u32::try_from(board.num_snakes()).unwrap_or(start_snakes) + } else { + 0 + } + }, + }; - let mut mcts_manager = MctsManager::new(id); + let mut mcts_managers: Vec<_> = (0..request.board.snakes.len()) + .map(|id| MctsManager::new(u8::try_from(id).unwrap())) + .collect(); let c = f32::sqrt(2.0); - 'outer: while start.elapsed() < Duration::from_millis(250) { + while start.elapsed() < Duration::from_millis(250) { let mut board = board.clone(); - while let Some(action) = mcts_manager.next_action(&board, c) { - board.next_turn(&[(id, action)]); - if let Some(score) = end_condition(&board) { - mcts_manager.apply_score(score); - continue 'outer; + while end_condition(&board).is_none() { + let actions: Vec<_> = mcts_managers + .iter_mut() + .filter_map(|mcts_manager| { + mcts_manager + .next_action(&board, c) + .map(|action| (mcts_manager.snake, action)) + }) + .collect(); + board.next_turn(&actions); + if actions.is_empty() { + break; } } - let score = board.simulate_random(end_condition); - mcts_manager.apply_score(score); + board.simulate_random(end_condition); + for mcts_manager in &mut mcts_managers { + let id = mcts_manager.snake; + let score = score_fn(&board, id); + mcts_manager.apply_score(score); + } } + let my_mcts_manager = &mcts_managers[usize::from(id)]; for action in actions { - let score = mcts_manager.base.next[usize::from(action)] + let score = my_mcts_manager.base.next[usize::from(action)] .as_ref() .map(|info| info.score as f32 / info.played as f32); if let Some(score) = score { @@ -134,7 +160,7 @@ async fn get_move(request: Json) -> response::Json { } } - let action = mcts_manager + let action = my_mcts_manager .base .next .iter() @@ -159,7 +185,7 @@ async fn get_move(request: Json) -> response::Json { if let Some(action) = action { info!( "found action {action:?} after {} simulations.", - mcts_manager.base.played + my_mcts_manager.base.played ); } else { warn!("unable to find a valid action"); diff --git a/battlesnake/src/types/simulation.rs b/battlesnake/src/types/simulation.rs index 89dfed7..c94ce77 100644 --- a/battlesnake/src/types/simulation.rs +++ b/battlesnake/src/types/simulation.rs @@ -233,9 +233,6 @@ impl Board { pub fn valid_actions(&self, id: u8) -> impl Iterator + use<'_> { let index = self.id_to_index(id); - if index.is_none() { - warn!("Asked for a snake that doesn't exist"); - } index .into_iter() .flat_map(|index| self.valid_actions_index(index))