single agent mcts
All checks were successful
Build / build (push) Successful in 2m30s

This commit is contained in:
Max Känner 2025-01-25 23:49:05 +01:00
parent 3f91660583
commit 6f04d7cb7f
3 changed files with 169 additions and 22 deletions

7
Cargo.lock generated
View File

@ -182,6 +182,7 @@ dependencies = [
"criterion", "criterion",
"enum-iterator", "enum-iterator",
"env_logger", "env_logger",
"float-ord",
"log", "log",
"rand", "rand",
"serde", "serde",
@ -469,6 +470,12 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "float-ord"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d"
[[package]] [[package]]
name = "fnv" name = "fnv"
version = "1.0.7" version = "1.0.7"

View File

@ -29,6 +29,7 @@ env_logger = "0.11"
bitvec = "1.0" bitvec = "1.0"
enum-iterator = "2.1" enum-iterator = "2.1"
rand = "0.8" rand = "0.8"
float-ord = "0.3"
[dev-dependencies] [dev-dependencies]
criterion = "0.5" criterion = "0.5"

View File

@ -9,7 +9,8 @@ use battlesnake::types::{
wire::{Request, Response}, wire::{Request, Response},
Direction, Direction,
}; };
use log::{debug, error, info, warn}; use float_ord::FloatOrd;
use log::{debug, error, info, trace, warn};
use rand::prelude::*; use rand::prelude::*;
use serde::Serialize; use serde::Serialize;
use tokio::{ use tokio::{
@ -87,7 +88,7 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
let end_condition: &dyn Fn(&Board) -> Option<u32> = match &*request.game.ruleset.name { let end_condition: &dyn Fn(&Board) -> Option<u32> = match &*request.game.ruleset.name {
"solo" => &|board: &Board| { "solo" => &|board: &Board| {
if board.num_snakes() == 0 if board.num_snakes() == 0
|| board.turn() > base_turns + u32::from(request.you.length) * 3 || board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32)
{ {
Some(board.turn()) Some(board.turn())
} else { } else {
@ -95,7 +96,8 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
} }
}, },
_ => &|board: &Board| { _ => &|board: &Board| {
if board.num_snakes() <= 1 || board.turn() > base_turns + u32::from(request.you.length) * 3 if board.num_snakes() <= 1
|| board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32)
{ {
Some(u32::from(board.alive(id))) Some(u32::from(board.alive(id)))
} else { } else {
@ -104,30 +106,58 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
}, },
}; };
let mut action_data = [(0, 0); 4]; let mut mcts_manager = MctsManager::new(id);
let mut total_simulations = 0; let c = f32::sqrt(2.0);
while start.elapsed() < Duration::from_millis(250) { 'outer: while start.elapsed() < Duration::from_millis(250) {
let mut board = board.clone(); let mut board = board.clone();
let action = *actions.choose(&mut thread_rng()).unwrap_or(&Direction::Up); while let Some(action) = mcts_manager.next_action(&board, c) {
board.next_turn(&[(id, action)]); board.next_turn(&[(id, action)]);
let score = board.simulate_random(end_condition); if let Some(score) = end_condition(&board) {
let action_data = &mut action_data[usize::from(action)]; mcts_manager.apply_score(score);
action_data.0 += score; continue 'outer;
action_data.1 += 1; }
total_simulations += 1; }
let score = board.simulate_random(end_condition);
mcts_manager.apply_score(score);
}
for action in actions {
let score = mcts_manager.base.next[usize::from(action)]
.as_ref()
.map(|info| info.score as f32 / info.played as f32);
if let Some(score) = score {
info!("{action:?} -> {score}");
} else {
info!("{action:?} -> None");
}
} }
debug!("action data: {action_data:?}");
let action = actions.into_iter().max_by(|lhs, rhs| { let action = mcts_manager
let lhs_data = action_data[usize::from(*lhs)]; .base
let rhs_data = action_data[usize::from(*rhs)]; .next
(u64::from(lhs_data.0) * u64::from(rhs_data.1)).cmp(&(u64::from(rhs_data.0) * u64::from(lhs_data.1))) .iter()
}); .enumerate()
.filter_map(|(index, info)| {
info.as_ref().map(|info| {
(
match index {
0 => Direction::Up,
1 => Direction::Down,
2 => Direction::Left,
3 => Direction::Right,
_ => unreachable!(),
},
info,
)
})
})
.max_by_key(|(_, info)| FloatOrd(info.score as f32 / info.played as f32))
.map(|(action, _)| action);
if let Some(action) = action { if let Some(action) = action {
let action_data = action_data[usize::from(action)]; info!(
let avg_turns = action_data.0 / action_data.1; "found action {action:?} after {} simulations.",
info!("found action {action:?} after {total_simulations} simulations with an average of {avg_turns} turns."); mcts_manager.base.played
);
} else { } else {
warn!("unable to find a valid action"); warn!("unable to find a valid action");
} }
@ -136,10 +166,119 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
direction: action.unwrap_or(Direction::Up), direction: action.unwrap_or(Direction::Up),
shout: None, shout: None,
}) })
}).await.unwrap() })
.await
.unwrap()
} }
async fn end(request: Json<Request>) { async fn end(request: Json<Request>) {
let board = Board::from(&*request); let board = Board::from(&*request);
info!("got end request: {board}"); info!("got end request: {board}");
} }
#[derive(Debug)]
struct ActionInfo {
score: u32,
played: u32,
next: Box<[Option<ActionInfo>; 4]>,
}
impl ActionInfo {
fn new() -> Self {
Self {
score: 0,
played: 0,
next: Box::new([None, None, None, None]),
}
}
fn uct(&self, c: f32) -> [Option<f32>; 4] {
let mut ucts = [None; 4];
for (action, uct) in self.next.iter().zip(ucts.iter_mut()) {
if let Some(action) = action {
let exploitation = action.score as f32 / action.played as f32;
let exploration = f32::sqrt(f32::ln(self.played as f32) / action.played as f32);
uct.replace(c.mul_add(exploration, exploitation));
}
}
ucts
}
}
#[derive(Debug)]
struct MctsManager {
base: ActionInfo,
actions: Vec<Direction>,
expanded: bool,
snake: u8,
}
impl MctsManager {
fn new(snake: u8) -> Self {
Self {
base: ActionInfo::new(),
actions: Vec::new(),
expanded: false,
snake,
}
}
fn apply_score(&mut self, score: u32) {
self.base.played += 1;
self.base.score += score;
let mut current = &mut self.base;
for action in &self.actions {
let Some(ref mut new_current) = &mut current.next[usize::from(*action)] else {
error!("got action without actioninfo");
break;
};
current = new_current;
current.played += 1;
current.score += score;
}
self.actions.clear();
self.expanded = false;
}
fn next_action(&mut self, board: &Board, c: f32) -> Option<Direction> {
if self.expanded {
return None;
}
let mut current = &mut self.base;
for action in &self.actions {
let Some(ref mut new_current) = &mut current.next[usize::from(*action)] else {
error!("got action without actioninfo");
return None;
};
current = new_current;
}
let ucts = current.uct(c);
let valid_actions = board.valid_actions(self.snake);
let ucts: Vec<_> = valid_actions
.map(|action| (action, ucts[usize::from(action)]))
.collect();
trace!("got actions: {ucts:?}");
if ucts.iter().any(|(_, uct)| uct.is_none()) {
let action = ucts
.iter()
.filter(|(_, uct)| uct.is_none())
.choose(&mut thread_rng())?
.0;
self.expanded = true;
current.next[usize::from(action)].replace(ActionInfo::new());
self.actions.push(action);
return Some(action);
}
let action = ucts
.iter()
.max_by_key(|(_, uct)| FloatOrd(uct.unwrap_or(f32::NEG_INFINITY)))
.map(|(action, _)| *action);
if let Some(action) = action {
self.actions.push(action);
}
action
}
}