This commit is contained in:
		
							
								
								
									
										7
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										7
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							| @@ -182,6 +182,7 @@ dependencies = [ | ||||
|  "criterion", | ||||
|  "enum-iterator", | ||||
|  "env_logger", | ||||
|  "float-ord", | ||||
|  "log", | ||||
|  "rand", | ||||
|  "serde", | ||||
| @@ -469,6 +470,12 @@ version = "1.0.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" | ||||
|  | ||||
| [[package]] | ||||
| name = "float-ord" | ||||
| version = "0.3.2" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d" | ||||
|  | ||||
| [[package]] | ||||
| name = "fnv" | ||||
| version = "1.0.7" | ||||
|   | ||||
| @@ -29,6 +29,7 @@ env_logger = "0.11" | ||||
| bitvec = "1.0" | ||||
| enum-iterator = "2.1" | ||||
| rand = "0.8" | ||||
| float-ord = "0.3" | ||||
|  | ||||
| [dev-dependencies] | ||||
| criterion = "0.5" | ||||
|   | ||||
| @@ -9,7 +9,8 @@ use battlesnake::types::{ | ||||
|     wire::{Request, Response}, | ||||
|     Direction, | ||||
| }; | ||||
| use log::{debug, error, info, warn}; | ||||
| use float_ord::FloatOrd; | ||||
| use log::{debug, error, info, trace, warn}; | ||||
| use rand::prelude::*; | ||||
| use serde::Serialize; | ||||
| use tokio::{ | ||||
| @@ -87,7 +88,7 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> { | ||||
|         let end_condition: &dyn Fn(&Board) -> Option<u32> = match &*request.game.ruleset.name { | ||||
|             "solo" => &|board: &Board| { | ||||
|                 if board.num_snakes() == 0 | ||||
|                     || board.turn() > base_turns + u32::from(request.you.length) * 3 | ||||
|                     || board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32) | ||||
|                 { | ||||
|                     Some(board.turn()) | ||||
|                 } else { | ||||
| @@ -95,7 +96,8 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> { | ||||
|                 } | ||||
|             }, | ||||
|             _ => &|board: &Board| { | ||||
|                 if board.num_snakes() <= 1 || board.turn() > base_turns + u32::from(request.you.length) * 3 | ||||
|                 if board.num_snakes() <= 1 | ||||
|                     || board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32) | ||||
|                 { | ||||
|                     Some(u32::from(board.alive(id))) | ||||
|                 } else { | ||||
| @@ -104,30 +106,58 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> { | ||||
|             }, | ||||
|         }; | ||||
|  | ||||
|         let mut action_data = [(0, 0); 4]; | ||||
|         let mut total_simulations = 0; | ||||
|         while start.elapsed() < Duration::from_millis(250) { | ||||
|         let mut mcts_manager = MctsManager::new(id); | ||||
|         let c = f32::sqrt(2.0); | ||||
|         'outer: while start.elapsed() < Duration::from_millis(250) { | ||||
|             let mut board = board.clone(); | ||||
|             let action = *actions.choose(&mut thread_rng()).unwrap_or(&Direction::Up); | ||||
|             board.next_turn(&[(id, action)]); | ||||
|             while let Some(action) = mcts_manager.next_action(&board, c) { | ||||
|                 board.next_turn(&[(id, action)]); | ||||
|                 if let Some(score) = end_condition(&board) { | ||||
|                     mcts_manager.apply_score(score); | ||||
|                     continue 'outer; | ||||
|                 } | ||||
|             } | ||||
|             let score = board.simulate_random(end_condition); | ||||
|             let action_data = &mut action_data[usize::from(action)]; | ||||
|             action_data.0 += score; | ||||
|             action_data.1 += 1; | ||||
|             total_simulations += 1; | ||||
|             mcts_manager.apply_score(score); | ||||
|         } | ||||
|         for action in actions { | ||||
|             let score = mcts_manager.base.next[usize::from(action)] | ||||
|                 .as_ref() | ||||
|                 .map(|info| info.score as f32 / info.played as f32); | ||||
|             if let Some(score) = score { | ||||
|                 info!("{action:?} -> {score}"); | ||||
|             } else { | ||||
|                 info!("{action:?} -> None"); | ||||
|             } | ||||
|         } | ||||
|         debug!("action data: {action_data:?}"); | ||||
|  | ||||
|         let action = actions.into_iter().max_by(|lhs, rhs| { | ||||
|             let lhs_data = action_data[usize::from(*lhs)]; | ||||
|             let rhs_data = action_data[usize::from(*rhs)]; | ||||
|             (u64::from(lhs_data.0) * u64::from(rhs_data.1)).cmp(&(u64::from(rhs_data.0) * u64::from(lhs_data.1))) | ||||
|         }); | ||||
|         let action = mcts_manager | ||||
|             .base | ||||
|             .next | ||||
|             .iter() | ||||
|             .enumerate() | ||||
|             .filter_map(|(index, info)| { | ||||
|                 info.as_ref().map(|info| { | ||||
|                     ( | ||||
|                         match index { | ||||
|                             0 => Direction::Up, | ||||
|                             1 => Direction::Down, | ||||
|                             2 => Direction::Left, | ||||
|                             3 => Direction::Right, | ||||
|                             _ => unreachable!(), | ||||
|                         }, | ||||
|                         info, | ||||
|                     ) | ||||
|                 }) | ||||
|             }) | ||||
|             .max_by_key(|(_, info)| FloatOrd(info.score as f32 / info.played as f32)) | ||||
|             .map(|(action, _)| action); | ||||
|  | ||||
|         if let Some(action) = action { | ||||
|             let action_data = action_data[usize::from(action)]; | ||||
|             let avg_turns = action_data.0 / action_data.1; | ||||
|             info!("found action {action:?} after {total_simulations} simulations with an average of {avg_turns} turns."); | ||||
|             info!( | ||||
|                 "found action {action:?} after {} simulations.", | ||||
|                 mcts_manager.base.played | ||||
|             ); | ||||
|         } else { | ||||
|             warn!("unable to find a valid action"); | ||||
|         } | ||||
| @@ -136,10 +166,119 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> { | ||||
|             direction: action.unwrap_or(Direction::Up), | ||||
|             shout: None, | ||||
|         }) | ||||
|     }).await.unwrap() | ||||
|     }) | ||||
|     .await | ||||
|     .unwrap() | ||||
| } | ||||
|  | ||||
| async fn end(request: Json<Request>) { | ||||
|     let board = Board::from(&*request); | ||||
|     info!("got end request: {board}"); | ||||
| } | ||||
|  | ||||
| #[derive(Debug)] | ||||
| struct ActionInfo { | ||||
|     score: u32, | ||||
|     played: u32, | ||||
|     next: Box<[Option<ActionInfo>; 4]>, | ||||
| } | ||||
|  | ||||
| impl ActionInfo { | ||||
|     fn new() -> Self { | ||||
|         Self { | ||||
|             score: 0, | ||||
|             played: 0, | ||||
|             next: Box::new([None, None, None, None]), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn uct(&self, c: f32) -> [Option<f32>; 4] { | ||||
|         let mut ucts = [None; 4]; | ||||
|         for (action, uct) in self.next.iter().zip(ucts.iter_mut()) { | ||||
|             if let Some(action) = action { | ||||
|                 let exploitation = action.score as f32 / action.played as f32; | ||||
|                 let exploration = f32::sqrt(f32::ln(self.played as f32) / action.played as f32); | ||||
|                 uct.replace(c.mul_add(exploration, exploitation)); | ||||
|             } | ||||
|         } | ||||
|         ucts | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug)] | ||||
| struct MctsManager { | ||||
|     base: ActionInfo, | ||||
|     actions: Vec<Direction>, | ||||
|     expanded: bool, | ||||
|     snake: u8, | ||||
| } | ||||
|  | ||||
| impl MctsManager { | ||||
|     fn new(snake: u8) -> Self { | ||||
|         Self { | ||||
|             base: ActionInfo::new(), | ||||
|             actions: Vec::new(), | ||||
|             expanded: false, | ||||
|             snake, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn apply_score(&mut self, score: u32) { | ||||
|         self.base.played += 1; | ||||
|         self.base.score += score; | ||||
|         let mut current = &mut self.base; | ||||
|         for action in &self.actions { | ||||
|             let Some(ref mut new_current) = &mut current.next[usize::from(*action)] else { | ||||
|                 error!("got action without actioninfo"); | ||||
|                 break; | ||||
|             }; | ||||
|             current = new_current; | ||||
|             current.played += 1; | ||||
|             current.score += score; | ||||
|         } | ||||
|         self.actions.clear(); | ||||
|         self.expanded = false; | ||||
|     } | ||||
|  | ||||
|     fn next_action(&mut self, board: &Board, c: f32) -> Option<Direction> { | ||||
|         if self.expanded { | ||||
|             return None; | ||||
|         } | ||||
|  | ||||
|         let mut current = &mut self.base; | ||||
|         for action in &self.actions { | ||||
|             let Some(ref mut new_current) = &mut current.next[usize::from(*action)] else { | ||||
|                 error!("got action without actioninfo"); | ||||
|                 return None; | ||||
|             }; | ||||
|             current = new_current; | ||||
|         } | ||||
|  | ||||
|         let ucts = current.uct(c); | ||||
|         let valid_actions = board.valid_actions(self.snake); | ||||
|         let ucts: Vec<_> = valid_actions | ||||
|             .map(|action| (action, ucts[usize::from(action)])) | ||||
|             .collect(); | ||||
|         trace!("got actions: {ucts:?}"); | ||||
|         if ucts.iter().any(|(_, uct)| uct.is_none()) { | ||||
|             let action = ucts | ||||
|                 .iter() | ||||
|                 .filter(|(_, uct)| uct.is_none()) | ||||
|                 .choose(&mut thread_rng())? | ||||
|                 .0; | ||||
|             self.expanded = true; | ||||
|             current.next[usize::from(action)].replace(ActionInfo::new()); | ||||
|             self.actions.push(action); | ||||
|             return Some(action); | ||||
|         } | ||||
|  | ||||
|         let action = ucts | ||||
|             .iter() | ||||
|             .max_by_key(|(_, uct)| FloatOrd(uct.unwrap_or(f32::NEG_INFINITY))) | ||||
|             .map(|(action, _)| *action); | ||||
|         if let Some(action) = action { | ||||
|             self.actions.push(action); | ||||
|         } | ||||
|         action | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user