multi agent mcts
All checks were successful
Build / build (push) Successful in 2m10s

This commit is contained in:
Max Känner 2025-01-26 19:29:45 +01:00
parent 302f5cac50
commit 879f99e23f
2 changed files with 43 additions and 20 deletions

View File

@ -65,6 +65,7 @@ async fn start(request: Json<Request>) {
info!("got start request: {board}"); info!("got start request: {board}");
} }
#[allow(clippy::too_many_lines)]
async fn get_move(request: Json<Request>) -> response::Json<Response> { async fn get_move(request: Json<Request>) -> response::Json<Response> {
let start = Instant::now(); let start = Instant::now();
let board = Board::from(&*request); let board = Board::from(&*request);
@ -88,43 +89,68 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
tokio::task::spawn_blocking(move || { tokio::task::spawn_blocking(move || {
let base_turns = board.turn(); let base_turns = board.turn();
let end_condition: &dyn Fn(&Board) -> Option<u32> = match &*request.game.ruleset.name { let end_condition: &dyn Fn(&Board) -> Option<_> = match &*request.game.ruleset.name {
"solo" => &|board: &Board| { "solo" => &|board| {
if board.num_snakes() == 0 if board.num_snakes() == 0
|| board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32) || board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32)
{ {
Some(board.turn()) Some(())
} else { } else {
None None
} }
}, },
_ => &|board: &Board| { _ => &|board| {
if board.num_snakes() <= 1 if board.num_snakes() <= 1
|| board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32) || board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32)
{ {
Some(u32::from(board.alive(id))) Some(())
} else { } else {
None None
} }
}, },
}; };
let start_snakes = u32::try_from(board.num_snakes()).unwrap_or(0);
let score_fn: &dyn Fn(&Board, u8) -> u32 = match &*request.game.ruleset.name {
"solo" => &|board, _| board.turn(),
_ => &|board, id| {
if board.alive(id) {
1 + start_snakes - u32::try_from(board.num_snakes()).unwrap_or(start_snakes)
} else {
0
}
},
};
let mut mcts_manager = MctsManager::new(id); let mut mcts_managers: Vec<_> = (0..request.board.snakes.len())
.map(|id| MctsManager::new(u8::try_from(id).unwrap()))
.collect();
let c = f32::sqrt(2.0); let c = f32::sqrt(2.0);
'outer: while start.elapsed() < Duration::from_millis(250) { while start.elapsed() < Duration::from_millis(250) {
let mut board = board.clone(); let mut board = board.clone();
while let Some(action) = mcts_manager.next_action(&board, c) { while end_condition(&board).is_none() {
board.next_turn(&[(id, action)]); let actions: Vec<_> = mcts_managers
if let Some(score) = end_condition(&board) { .iter_mut()
mcts_manager.apply_score(score); .filter_map(|mcts_manager| {
continue 'outer; mcts_manager
.next_action(&board, c)
.map(|action| (mcts_manager.snake, action))
})
.collect();
board.next_turn(&actions);
if actions.is_empty() {
break;
} }
} }
let score = board.simulate_random(end_condition); board.simulate_random(end_condition);
mcts_manager.apply_score(score); for mcts_manager in &mut mcts_managers {
let id = mcts_manager.snake;
let score = score_fn(&board, id);
mcts_manager.apply_score(score);
}
} }
let my_mcts_manager = &mcts_managers[usize::from(id)];
for action in actions { for action in actions {
let score = mcts_manager.base.next[usize::from(action)] let score = my_mcts_manager.base.next[usize::from(action)]
.as_ref() .as_ref()
.map(|info| info.score as f32 / info.played as f32); .map(|info| info.score as f32 / info.played as f32);
if let Some(score) = score { if let Some(score) = score {
@ -134,7 +160,7 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
} }
} }
let action = mcts_manager let action = my_mcts_manager
.base .base
.next .next
.iter() .iter()
@ -159,7 +185,7 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
if let Some(action) = action { if let Some(action) = action {
info!( info!(
"found action {action:?} after {} simulations.", "found action {action:?} after {} simulations.",
mcts_manager.base.played my_mcts_manager.base.played
); );
} else { } else {
warn!("unable to find a valid action"); warn!("unable to find a valid action");

View File

@ -233,9 +233,6 @@ impl Board {
pub fn valid_actions(&self, id: u8) -> impl Iterator<Item = Direction> + use<'_> { pub fn valid_actions(&self, id: u8) -> impl Iterator<Item = Direction> + use<'_> {
let index = self.id_to_index(id); let index = self.id_to_index(id);
if index.is_none() {
warn!("Asked for a snake that doesn't exist");
}
index index
.into_iter() .into_iter()
.flat_map(|index| self.valid_actions_index(index)) .flat_map(|index| self.valid_actions_index(index))