diff --git a/Cargo.lock b/Cargo.lock index ecafd0d..9a1be32 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -183,6 +183,7 @@ dependencies = [ "enum-iterator", "env_logger", "float-ord", + "futures-util", "log", "rand", "serde", @@ -516,6 +517,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -535,6 +547,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-core", + "futures-macro", "futures-sink", "futures-task", "pin-project-lite", diff --git a/Cargo.toml b/Cargo.toml index 1d7f7ab..f7a6e27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,3 +3,7 @@ members = ["battlesnake", "xtask"] resolver = "3" default-members = ["battlesnake"] + +[profile.release] +lto = "fat" +codegen-units = 1 diff --git a/battlesnake/Cargo.toml b/battlesnake/Cargo.toml index b6112ec..7beb3af 100644 --- a/battlesnake/Cargo.toml +++ b/battlesnake/Cargo.toml @@ -30,6 +30,7 @@ bitvec = "1.0" enum-iterator = "2.1" rand = "0.9" float-ord = "0.3" +futures-util = "0.3.31" [dev-dependencies] criterion = "0.5" diff --git a/battlesnake/src/main.rs b/battlesnake/src/main.rs index 96c41c3..6da1b64 100644 --- a/battlesnake/src/main.rs +++ b/battlesnake/src/main.rs @@ -12,6 +12,7 @@ use battlesnake::types::{ wire::{Request, Response}, }; use float_ord::FloatOrd; +use futures_util::future::join_all; use log::{debug, error, info, trace, warn}; use rand::prelude::*; use serde::Serialize; @@ -88,120 +89,126 @@ async fn get_move(request: Json) -> response::Json { } info!("valid actions: {actions:?}"); - tokio::task::spawn_blocking(move || { + if start.elapsed() > Duration::from_millis(10) { + error!( + "The calculation started late ({}ms)", + start.elapsed().as_millis() + ); + } + let end_condition: fn(&Board) -> Option<()> = match &*request.game.ruleset.name { + "solo" => end_solo, + _ => end_standard, + }; + let score_fn: fn(&Board, u8) -> u32 = match &*request.game.ruleset.name { + "solo" => score_solo, + _ => score_standard, + }; + + let action_futures = (0..3).map(|_| { + let request = request.clone(); + let board = board.clone(); let mut rng = SmallRng::from_os_rng(); - if start.elapsed() > Duration::from_millis(10) { - error!( - "The calculation started late ({}ms)", - start.elapsed().as_millis() - ); - } - let base_turns = board.turn(); - let rolling_horizon = (u32::from(request.you.length) * 3).max(32); - let last_turn = base_turns + rolling_horizon; - let end_condition: &dyn Fn(&Board) -> Option<_> = match &*request.game.ruleset.name { - "solo" => &|board| { - if board.valid_actions(0).count() == 0 || board.turn() > last_turn { - Some(()) - } else { - None + let actions = actions.clone(); + tokio::task::spawn_blocking(move || { + let mut mcts_managers: Vec<_> = (0..request.board.snakes.len()) + .map(|id| MctsManager::new(u8::try_from(id).unwrap())) + .collect(); + let c = f32::sqrt(2.0); + let mut mcts_actions = Vec::new(); + while start.elapsed() < timeout * 4 / 5 { + let mut board = board.clone(); + while end_condition(&board).is_none() { + mcts_actions.clear(); + mcts_actions.extend(mcts_managers.iter_mut().filter_map(|mcts_manager| { + mcts_manager + .next_action(&board, c, &mut rng) + .map(|action| (mcts_manager.snake, action)) + })); + board.next_turn(&mcts_actions, &mut rng); + if mcts_actions.is_empty() { + break; + } } - }, - _ => &|board| { - if board.num_snakes() <= 1 || board.turn() > last_turn { - Some(()) - } else { - None - } - }, - }; - let score_fn: &dyn Fn(&Board, u8) -> u32 = match &*request.game.ruleset.name { - "solo" => &|board, id| u32::try_from(board.length(id)).unwrap_or(0), - _ => &|board, id| { - if board.alive(id) { - 1 + u32::from(board.max_length() == board.length(id)) - } else { - 0 - } - }, - }; - - let mut mcts_managers: Vec<_> = (0..request.board.snakes.len()) - .map(|id| MctsManager::new(u8::try_from(id).unwrap())) - .collect(); - let c = f32::sqrt(2.0); - let mut mcts_actions = Vec::new(); - while start.elapsed() < timeout * 4 / 5 { - let mut board = board.clone(); - while end_condition(&board).is_none() { - mcts_actions.clear(); - mcts_actions.extend(mcts_managers.iter_mut().filter_map(|mcts_manager| { - mcts_manager - .next_action(&board, c, &mut rng) - .map(|action| (mcts_manager.snake, action)) - })); - board.next_turn(&mcts_actions, &mut rng); - if mcts_actions.is_empty() { - break; + board.simulate_random(&mut rng, end_condition); + for mcts_manager in &mut mcts_managers { + let id = mcts_manager.snake; + let score = score_fn(&board, id); + mcts_manager.apply_score(score); } } - board.simulate_random(&mut rng, end_condition); - for mcts_manager in &mut mcts_managers { - let id = mcts_manager.snake; - let score = score_fn(&board, id); - mcts_manager.apply_score(score); + let my_mcts_manager = mcts_managers.into_iter().nth(usize::from(id)).unwrap(); + for action in actions { + let score = my_mcts_manager.base.next[usize::from(action)] + .as_ref() + .map(|info| info.score as f32 / info.played as f32); + if let Some(score) = score { + info!("{action:?} -> {score}"); + } else { + info!("{action:?} -> None"); + } } - } - let my_mcts_manager = &mcts_managers[usize::from(id)]; - for action in actions { - let score = my_mcts_manager.base.next[usize::from(action)] - .as_ref() - .map(|info| info.score as f32 / info.played as f32); - if let Some(score) = score { - info!("{action:?} -> {score}"); - } else { - info!("{action:?} -> None"); - } - } - let action = my_mcts_manager - .base - .next - .iter() - .enumerate() - .filter_map(|(index, info)| { - info.as_ref().map(|info| { - ( - match index { - 0 => Direction::Up, - 1 => Direction::Down, - 2 => Direction::Left, - 3 => Direction::Right, - _ => unreachable!(), - }, - info, - ) - }) - }) - .max_by_key(|(_, info)| FloatOrd(info.score as f32 / info.played as f32)) - .map(|(action, _)| action); + let actions = my_mcts_manager.base.next.map(|action| { + action.map_or(0.0, |action| action.score as f32 / action.played as f32) + }); - if let Some(action) = action { - info!( - "found action {action:?} after {} simulations.", - my_mcts_manager.base.played - ); - } else { - warn!("unable to find a valid action"); - } - info!("chose {action:?}"); - response::Json(Response { - direction: action.unwrap_or(Direction::Up), - shout: None, + (actions, my_mcts_manager.base.played) }) + }); + let (actions, played) = join_all(action_futures).await.into_iter().fold( + ([0.0; 4], 0), + |(mut total, mut games), actions| { + if let Ok((actions, new_games)) = actions { + for i in 0..total.len() { + total[i] += actions[i]; + } + games += new_games; + } + (total, games) + }, + ); + let action = actions + .into_iter() + .enumerate() + .max_by_key(|(_, score)| FloatOrd(*score)) + .map(|(index, _)| match index { + 0 => Direction::Up, + 1 => Direction::Down, + 2 => Direction::Left, + 3 => Direction::Right, + _ => unreachable!(), + }); + + if let Some(action) = action { + info!("found action {action:?} after {played} simulations.",); + } else { + warn!("unable to find a valid action"); + } + info!("chose {action:?}"); + response::Json(Response { + direction: action.unwrap_or(Direction::Up), + shout: None, }) - .await - .unwrap() +} + +fn end_solo(board: &Board) -> Option<()> { + (board.valid_actions(0).count() == 0).then_some(()) +} + +fn end_standard(board: &Board) -> Option<()> { + (board.num_snakes() <= 1).then_some(()) +} + +fn score_solo(board: &Board, id: u8) -> u32 { + u32::try_from(board.length(id)).unwrap_or(0) +} + +fn score_standard(board: &Board, id: u8) -> u32 { + if board.alive(id) { + 1 + u32::from(board.max_length() == board.length(id)) + } else { + 0 + } } async fn end(request: Json) { diff --git a/battlesnake/src/types/mod.rs b/battlesnake/src/types/mod.rs index fa62d61..75cb91f 100644 --- a/battlesnake/src/types/mod.rs +++ b/battlesnake/src/types/mod.rs @@ -33,7 +33,7 @@ impl Coord { } } -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Sequence)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Sequence)] #[serde(rename_all = "lowercase")] pub enum Direction { /// Move in positive y direction