add multithreading
This commit is contained in:
		
							
								
								
									
										13
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										13
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							| @@ -183,6 +183,7 @@ dependencies = [ | ||||
|  "enum-iterator", | ||||
|  "env_logger", | ||||
|  "float-ord", | ||||
|  "futures-util", | ||||
|  "log", | ||||
|  "rand", | ||||
|  "serde", | ||||
| @@ -516,6 +517,17 @@ version = "0.3.31" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" | ||||
|  | ||||
| [[package]] | ||||
| name = "futures-macro" | ||||
| version = "0.3.31" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" | ||||
| dependencies = [ | ||||
|  "proc-macro2", | ||||
|  "quote", | ||||
|  "syn", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "futures-sink" | ||||
| version = "0.3.31" | ||||
| @@ -535,6 +547,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" | ||||
| dependencies = [ | ||||
|  "futures-core", | ||||
|  "futures-macro", | ||||
|  "futures-sink", | ||||
|  "futures-task", | ||||
|  "pin-project-lite", | ||||
|   | ||||
| @@ -3,3 +3,7 @@ members = ["battlesnake", "xtask"] | ||||
| resolver = "3" | ||||
|  | ||||
| default-members = ["battlesnake"] | ||||
|  | ||||
| [profile.release] | ||||
| lto = "fat" | ||||
| codegen-units = 1 | ||||
|   | ||||
| @@ -30,6 +30,7 @@ bitvec = "1.0" | ||||
| enum-iterator = "2.1" | ||||
| rand = "0.9" | ||||
| float-ord = "0.3" | ||||
| futures-util = "0.3.31" | ||||
|  | ||||
| [dev-dependencies] | ||||
| criterion = "0.5" | ||||
|   | ||||
| @@ -12,6 +12,7 @@ use battlesnake::types::{ | ||||
|     wire::{Request, Response}, | ||||
| }; | ||||
| use float_ord::FloatOrd; | ||||
| use futures_util::future::join_all; | ||||
| use log::{debug, error, info, trace, warn}; | ||||
| use rand::prelude::*; | ||||
| use serde::Serialize; | ||||
| @@ -88,120 +89,126 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> { | ||||
|     } | ||||
|     info!("valid actions: {actions:?}"); | ||||
|  | ||||
|     tokio::task::spawn_blocking(move || { | ||||
|     if start.elapsed() > Duration::from_millis(10) { | ||||
|         error!( | ||||
|             "The calculation started late ({}ms)", | ||||
|             start.elapsed().as_millis() | ||||
|         ); | ||||
|     } | ||||
|     let end_condition: fn(&Board) -> Option<()> = match &*request.game.ruleset.name { | ||||
|         "solo" => end_solo, | ||||
|         _ => end_standard, | ||||
|     }; | ||||
|     let score_fn: fn(&Board, u8) -> u32 = match &*request.game.ruleset.name { | ||||
|         "solo" => score_solo, | ||||
|         _ => score_standard, | ||||
|     }; | ||||
|  | ||||
|     let action_futures = (0..3).map(|_| { | ||||
|         let request = request.clone(); | ||||
|         let board = board.clone(); | ||||
|         let mut rng = SmallRng::from_os_rng(); | ||||
|         if start.elapsed() > Duration::from_millis(10) { | ||||
|             error!( | ||||
|                 "The calculation started late ({}ms)", | ||||
|                 start.elapsed().as_millis() | ||||
|             ); | ||||
|         } | ||||
|         let base_turns = board.turn(); | ||||
|         let rolling_horizon = (u32::from(request.you.length) * 3).max(32); | ||||
|         let last_turn = base_turns + rolling_horizon; | ||||
|         let end_condition: &dyn Fn(&Board) -> Option<_> = match &*request.game.ruleset.name { | ||||
|             "solo" => &|board| { | ||||
|                 if board.valid_actions(0).count() == 0 || board.turn() > last_turn { | ||||
|                     Some(()) | ||||
|                 } else { | ||||
|                     None | ||||
|         let actions = actions.clone(); | ||||
|         tokio::task::spawn_blocking(move || { | ||||
|             let mut mcts_managers: Vec<_> = (0..request.board.snakes.len()) | ||||
|                 .map(|id| MctsManager::new(u8::try_from(id).unwrap())) | ||||
|                 .collect(); | ||||
|             let c = f32::sqrt(2.0); | ||||
|             let mut mcts_actions = Vec::new(); | ||||
|             while start.elapsed() < timeout * 4 / 5 { | ||||
|                 let mut board = board.clone(); | ||||
|                 while end_condition(&board).is_none() { | ||||
|                     mcts_actions.clear(); | ||||
|                     mcts_actions.extend(mcts_managers.iter_mut().filter_map(|mcts_manager| { | ||||
|                         mcts_manager | ||||
|                             .next_action(&board, c, &mut rng) | ||||
|                             .map(|action| (mcts_manager.snake, action)) | ||||
|                     })); | ||||
|                     board.next_turn(&mcts_actions, &mut rng); | ||||
|                     if mcts_actions.is_empty() { | ||||
|                         break; | ||||
|                     } | ||||
|                 } | ||||
|             }, | ||||
|             _ => &|board| { | ||||
|                 if board.num_snakes() <= 1 || board.turn() > last_turn { | ||||
|                     Some(()) | ||||
|                 } else { | ||||
|                     None | ||||
|                 } | ||||
|             }, | ||||
|         }; | ||||
|         let score_fn: &dyn Fn(&Board, u8) -> u32 = match &*request.game.ruleset.name { | ||||
|             "solo" => &|board, id| u32::try_from(board.length(id)).unwrap_or(0), | ||||
|             _ => &|board, id| { | ||||
|                 if board.alive(id) { | ||||
|                     1 + u32::from(board.max_length() == board.length(id)) | ||||
|                 } else { | ||||
|                     0 | ||||
|                 } | ||||
|             }, | ||||
|         }; | ||||
|  | ||||
|         let mut mcts_managers: Vec<_> = (0..request.board.snakes.len()) | ||||
|             .map(|id| MctsManager::new(u8::try_from(id).unwrap())) | ||||
|             .collect(); | ||||
|         let c = f32::sqrt(2.0); | ||||
|         let mut mcts_actions = Vec::new(); | ||||
|         while start.elapsed() < timeout * 4 / 5 { | ||||
|             let mut board = board.clone(); | ||||
|             while end_condition(&board).is_none() { | ||||
|                 mcts_actions.clear(); | ||||
|                 mcts_actions.extend(mcts_managers.iter_mut().filter_map(|mcts_manager| { | ||||
|                     mcts_manager | ||||
|                         .next_action(&board, c, &mut rng) | ||||
|                         .map(|action| (mcts_manager.snake, action)) | ||||
|                 })); | ||||
|                 board.next_turn(&mcts_actions, &mut rng); | ||||
|                 if mcts_actions.is_empty() { | ||||
|                     break; | ||||
|                 board.simulate_random(&mut rng, end_condition); | ||||
|                 for mcts_manager in &mut mcts_managers { | ||||
|                     let id = mcts_manager.snake; | ||||
|                     let score = score_fn(&board, id); | ||||
|                     mcts_manager.apply_score(score); | ||||
|                 } | ||||
|             } | ||||
|             board.simulate_random(&mut rng, end_condition); | ||||
|             for mcts_manager in &mut mcts_managers { | ||||
|                 let id = mcts_manager.snake; | ||||
|                 let score = score_fn(&board, id); | ||||
|                 mcts_manager.apply_score(score); | ||||
|             let my_mcts_manager = mcts_managers.into_iter().nth(usize::from(id)).unwrap(); | ||||
|             for action in actions { | ||||
|                 let score = my_mcts_manager.base.next[usize::from(action)] | ||||
|                     .as_ref() | ||||
|                     .map(|info| info.score as f32 / info.played as f32); | ||||
|                 if let Some(score) = score { | ||||
|                     info!("{action:?} -> {score}"); | ||||
|                 } else { | ||||
|                     info!("{action:?} -> None"); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         let my_mcts_manager = &mcts_managers[usize::from(id)]; | ||||
|         for action in actions { | ||||
|             let score = my_mcts_manager.base.next[usize::from(action)] | ||||
|                 .as_ref() | ||||
|                 .map(|info| info.score as f32 / info.played as f32); | ||||
|             if let Some(score) = score { | ||||
|                 info!("{action:?} -> {score}"); | ||||
|             } else { | ||||
|                 info!("{action:?} -> None"); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         let action = my_mcts_manager | ||||
|             .base | ||||
|             .next | ||||
|             .iter() | ||||
|             .enumerate() | ||||
|             .filter_map(|(index, info)| { | ||||
|                 info.as_ref().map(|info| { | ||||
|                     ( | ||||
|                         match index { | ||||
|                             0 => Direction::Up, | ||||
|                             1 => Direction::Down, | ||||
|                             2 => Direction::Left, | ||||
|                             3 => Direction::Right, | ||||
|                             _ => unreachable!(), | ||||
|                         }, | ||||
|                         info, | ||||
|                     ) | ||||
|                 }) | ||||
|             }) | ||||
|             .max_by_key(|(_, info)| FloatOrd(info.score as f32 / info.played as f32)) | ||||
|             .map(|(action, _)| action); | ||||
|             let actions = my_mcts_manager.base.next.map(|action| { | ||||
|                 action.map_or(0.0, |action| action.score as f32 / action.played as f32) | ||||
|             }); | ||||
|  | ||||
|         if let Some(action) = action { | ||||
|             info!( | ||||
|                 "found action {action:?} after {} simulations.", | ||||
|                 my_mcts_manager.base.played | ||||
|             ); | ||||
|         } else { | ||||
|             warn!("unable to find a valid action"); | ||||
|         } | ||||
|         info!("chose {action:?}"); | ||||
|         response::Json(Response { | ||||
|             direction: action.unwrap_or(Direction::Up), | ||||
|             shout: None, | ||||
|             (actions, my_mcts_manager.base.played) | ||||
|         }) | ||||
|     }); | ||||
|     let (actions, played) = join_all(action_futures).await.into_iter().fold( | ||||
|         ([0.0; 4], 0), | ||||
|         |(mut total, mut games), actions| { | ||||
|             if let Ok((actions, new_games)) = actions { | ||||
|                 for i in 0..total.len() { | ||||
|                     total[i] += actions[i]; | ||||
|                 } | ||||
|                 games += new_games; | ||||
|             } | ||||
|             (total, games) | ||||
|         }, | ||||
|     ); | ||||
|     let action = actions | ||||
|         .into_iter() | ||||
|         .enumerate() | ||||
|         .max_by_key(|(_, score)| FloatOrd(*score)) | ||||
|         .map(|(index, _)| match index { | ||||
|             0 => Direction::Up, | ||||
|             1 => Direction::Down, | ||||
|             2 => Direction::Left, | ||||
|             3 => Direction::Right, | ||||
|             _ => unreachable!(), | ||||
|         }); | ||||
|  | ||||
|     if let Some(action) = action { | ||||
|         info!("found action {action:?} after {played} simulations.",); | ||||
|     } else { | ||||
|         warn!("unable to find a valid action"); | ||||
|     } | ||||
|     info!("chose {action:?}"); | ||||
|     response::Json(Response { | ||||
|         direction: action.unwrap_or(Direction::Up), | ||||
|         shout: None, | ||||
|     }) | ||||
|     .await | ||||
|     .unwrap() | ||||
| } | ||||
|  | ||||
| fn end_solo(board: &Board) -> Option<()> { | ||||
|     (board.valid_actions(0).count() == 0).then_some(()) | ||||
| } | ||||
|  | ||||
| fn end_standard(board: &Board) -> Option<()> { | ||||
|     (board.num_snakes() <= 1).then_some(()) | ||||
| } | ||||
|  | ||||
| fn score_solo(board: &Board, id: u8) -> u32 { | ||||
|     u32::try_from(board.length(id)).unwrap_or(0) | ||||
| } | ||||
|  | ||||
| fn score_standard(board: &Board, id: u8) -> u32 { | ||||
|     if board.alive(id) { | ||||
|         1 + u32::from(board.max_length() == board.length(id)) | ||||
|     } else { | ||||
|         0 | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn end(request: Json<Request>) { | ||||
|   | ||||
| @@ -33,7 +33,7 @@ impl Coord { | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Sequence)] | ||||
| #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Sequence)] | ||||
| #[serde(rename_all = "lowercase")] | ||||
| pub enum Direction { | ||||
|     /// Move in positive y direction | ||||
|   | ||||
		Reference in New Issue
	
	Block a user