diff --git a/Cargo.lock b/Cargo.lock index 6623f89..cd35e31 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -182,6 +182,7 @@ dependencies = [ "criterion", "enum-iterator", "env_logger", + "float-ord", "log", "rand", "serde", @@ -469,6 +470,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "float-ord" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d" + [[package]] name = "fnv" version = "1.0.7" diff --git a/battlesnake/Cargo.toml b/battlesnake/Cargo.toml index 9d68f27..31b4c35 100644 --- a/battlesnake/Cargo.toml +++ b/battlesnake/Cargo.toml @@ -29,6 +29,7 @@ env_logger = "0.11" bitvec = "1.0" enum-iterator = "2.1" rand = "0.8" +float-ord = "0.3" [dev-dependencies] criterion = "0.5" diff --git a/battlesnake/src/main.rs b/battlesnake/src/main.rs index 2fb265f..a806d24 100644 --- a/battlesnake/src/main.rs +++ b/battlesnake/src/main.rs @@ -9,7 +9,8 @@ use battlesnake::types::{ wire::{Request, Response}, Direction, }; -use log::{debug, error, info, warn}; +use float_ord::FloatOrd; +use log::{debug, error, info, trace, warn}; use rand::prelude::*; use serde::Serialize; use tokio::{ @@ -87,7 +88,7 @@ async fn get_move(request: Json) -> response::Json { let end_condition: &dyn Fn(&Board) -> Option = match &*request.game.ruleset.name { "solo" => &|board: &Board| { if board.num_snakes() == 0 - || board.turn() > base_turns + u32::from(request.you.length) * 3 + || board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32) { Some(board.turn()) } else { @@ -95,7 +96,8 @@ async fn get_move(request: Json) -> response::Json { } }, _ => &|board: &Board| { - if board.num_snakes() <= 1 || board.turn() > base_turns + u32::from(request.you.length) * 3 + if board.num_snakes() <= 1 + || board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32) { Some(u32::from(board.alive(id))) } else { @@ -104,30 +106,58 @@ async fn get_move(request: Json) -> response::Json { }, }; - let mut action_data = [(0, 0); 4]; - let mut total_simulations = 0; - while start.elapsed() < Duration::from_millis(250) { + let mut mcts_manager = MctsManager::new(id); + let c = f32::sqrt(2.0); + 'outer: while start.elapsed() < Duration::from_millis(250) { let mut board = board.clone(); - let action = *actions.choose(&mut thread_rng()).unwrap_or(&Direction::Up); - board.next_turn(&[(id, action)]); + while let Some(action) = mcts_manager.next_action(&board, c) { + board.next_turn(&[(id, action)]); + if let Some(score) = end_condition(&board) { + mcts_manager.apply_score(score); + continue 'outer; + } + } let score = board.simulate_random(end_condition); - let action_data = &mut action_data[usize::from(action)]; - action_data.0 += score; - action_data.1 += 1; - total_simulations += 1; + mcts_manager.apply_score(score); + } + for action in actions { + let score = mcts_manager.base.next[usize::from(action)] + .as_ref() + .map(|info| info.score as f32 / info.played as f32); + if let Some(score) = score { + info!("{action:?} -> {score}"); + } else { + info!("{action:?} -> None"); + } } - debug!("action data: {action_data:?}"); - let action = actions.into_iter().max_by(|lhs, rhs| { - let lhs_data = action_data[usize::from(*lhs)]; - let rhs_data = action_data[usize::from(*rhs)]; - (u64::from(lhs_data.0) * u64::from(rhs_data.1)).cmp(&(u64::from(rhs_data.0) * u64::from(lhs_data.1))) - }); + let action = mcts_manager + .base + .next + .iter() + .enumerate() + .filter_map(|(index, info)| { + info.as_ref().map(|info| { + ( + match index { + 0 => Direction::Up, + 1 => Direction::Down, + 2 => Direction::Left, + 3 => Direction::Right, + _ => unreachable!(), + }, + info, + ) + }) + }) + .max_by_key(|(_, info)| FloatOrd(info.score as f32 / info.played as f32)) + .map(|(action, _)| action); if let Some(action) = action { - let action_data = action_data[usize::from(action)]; - let avg_turns = action_data.0 / action_data.1; - info!("found action {action:?} after {total_simulations} simulations with an average of {avg_turns} turns."); + info!( + "found action {action:?} after {} simulations.", + mcts_manager.base.played + ); } else { warn!("unable to find a valid action"); } @@ -136,10 +166,119 @@ async fn get_move(request: Json) -> response::Json { direction: action.unwrap_or(Direction::Up), shout: None, }) - }).await.unwrap() + }) + .await + .unwrap() } async fn end(request: Json) { let board = Board::from(&*request); info!("got end request: {board}"); } + +#[derive(Debug)] +struct ActionInfo { + score: u32, + played: u32, + next: Box<[Option; 4]>, +} + +impl ActionInfo { + fn new() -> Self { + Self { + score: 0, + played: 0, + next: Box::new([None, None, None, None]), + } + } + + fn uct(&self, c: f32) -> [Option; 4] { + let mut ucts = [None; 4]; + for (action, uct) in self.next.iter().zip(ucts.iter_mut()) { + if let Some(action) = action { + let exploitation = action.score as f32 / action.played as f32; + let exploration = f32::sqrt(f32::ln(self.played as f32) / action.played as f32); + uct.replace(c.mul_add(exploration, exploitation)); + } + } + ucts + } +} + +#[derive(Debug)] +struct MctsManager { + base: ActionInfo, + actions: Vec, + expanded: bool, + snake: u8, +} + +impl MctsManager { + fn new(snake: u8) -> Self { + Self { + base: ActionInfo::new(), + actions: Vec::new(), + expanded: false, + snake, + } + } + + fn apply_score(&mut self, score: u32) { + self.base.played += 1; + self.base.score += score; + let mut current = &mut self.base; + for action in &self.actions { + let Some(ref mut new_current) = &mut current.next[usize::from(*action)] else { + error!("got action without actioninfo"); + break; + }; + current = new_current; + current.played += 1; + current.score += score; + } + self.actions.clear(); + self.expanded = false; + } + + fn next_action(&mut self, board: &Board, c: f32) -> Option { + if self.expanded { + return None; + } + + let mut current = &mut self.base; + for action in &self.actions { + let Some(ref mut new_current) = &mut current.next[usize::from(*action)] else { + error!("got action without actioninfo"); + return None; + }; + current = new_current; + } + + let ucts = current.uct(c); + let valid_actions = board.valid_actions(self.snake); + let ucts: Vec<_> = valid_actions + .map(|action| (action, ucts[usize::from(action)])) + .collect(); + trace!("got actions: {ucts:?}"); + if ucts.iter().any(|(_, uct)| uct.is_none()) { + let action = ucts + .iter() + .filter(|(_, uct)| uct.is_none()) + .choose(&mut thread_rng())? + .0; + self.expanded = true; + current.next[usize::from(action)].replace(ActionInfo::new()); + self.actions.push(action); + return Some(action); + } + + let action = ucts + .iter() + .max_by_key(|(_, uct)| FloatOrd(uct.unwrap_or(f32::NEG_INFINITY))) + .map(|(action, _)| *action); + if let Some(action) = action { + self.actions.push(action); + } + action + } +}