single agent mcts
All checks were successful
Build / build (push) Successful in 2m30s

This commit is contained in:
Max Känner 2025-01-25 23:49:05 +01:00
parent 3f91660583
commit 6f04d7cb7f
3 changed files with 169 additions and 22 deletions

7
Cargo.lock generated
View File

@ -182,6 +182,7 @@ dependencies = [
"criterion",
"enum-iterator",
"env_logger",
"float-ord",
"log",
"rand",
"serde",
@ -469,6 +470,12 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "float-ord"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d"
[[package]]
name = "fnv"
version = "1.0.7"

View File

@ -29,6 +29,7 @@ env_logger = "0.11"
bitvec = "1.0"
enum-iterator = "2.1"
rand = "0.8"
float-ord = "0.3"
[dev-dependencies]
criterion = "0.5"

View File

@ -9,7 +9,8 @@ use battlesnake::types::{
wire::{Request, Response},
Direction,
};
use log::{debug, error, info, warn};
use float_ord::FloatOrd;
use log::{debug, error, info, trace, warn};
use rand::prelude::*;
use serde::Serialize;
use tokio::{
@ -87,7 +88,7 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
let end_condition: &dyn Fn(&Board) -> Option<u32> = match &*request.game.ruleset.name {
"solo" => &|board: &Board| {
if board.num_snakes() == 0
|| board.turn() > base_turns + u32::from(request.you.length) * 3
|| board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32)
{
Some(board.turn())
} else {
@ -95,7 +96,8 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
}
},
_ => &|board: &Board| {
if board.num_snakes() <= 1 || board.turn() > base_turns + u32::from(request.you.length) * 3
if board.num_snakes() <= 1
|| board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32)
{
Some(u32::from(board.alive(id)))
} else {
@ -104,30 +106,58 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
},
};
let mut action_data = [(0, 0); 4];
let mut total_simulations = 0;
while start.elapsed() < Duration::from_millis(250) {
let mut mcts_manager = MctsManager::new(id);
let c = f32::sqrt(2.0);
'outer: while start.elapsed() < Duration::from_millis(250) {
let mut board = board.clone();
let action = *actions.choose(&mut thread_rng()).unwrap_or(&Direction::Up);
board.next_turn(&[(id, action)]);
while let Some(action) = mcts_manager.next_action(&board, c) {
board.next_turn(&[(id, action)]);
if let Some(score) = end_condition(&board) {
mcts_manager.apply_score(score);
continue 'outer;
}
}
let score = board.simulate_random(end_condition);
let action_data = &mut action_data[usize::from(action)];
action_data.0 += score;
action_data.1 += 1;
total_simulations += 1;
mcts_manager.apply_score(score);
}
for action in actions {
let score = mcts_manager.base.next[usize::from(action)]
.as_ref()
.map(|info| info.score as f32 / info.played as f32);
if let Some(score) = score {
info!("{action:?} -> {score}");
} else {
info!("{action:?} -> None");
}
}
debug!("action data: {action_data:?}");
let action = actions.into_iter().max_by(|lhs, rhs| {
let lhs_data = action_data[usize::from(*lhs)];
let rhs_data = action_data[usize::from(*rhs)];
(u64::from(lhs_data.0) * u64::from(rhs_data.1)).cmp(&(u64::from(rhs_data.0) * u64::from(lhs_data.1)))
});
let action = mcts_manager
.base
.next
.iter()
.enumerate()
.filter_map(|(index, info)| {
info.as_ref().map(|info| {
(
match index {
0 => Direction::Up,
1 => Direction::Down,
2 => Direction::Left,
3 => Direction::Right,
_ => unreachable!(),
},
info,
)
})
})
.max_by_key(|(_, info)| FloatOrd(info.score as f32 / info.played as f32))
.map(|(action, _)| action);
if let Some(action) = action {
let action_data = action_data[usize::from(action)];
let avg_turns = action_data.0 / action_data.1;
info!("found action {action:?} after {total_simulations} simulations with an average of {avg_turns} turns.");
info!(
"found action {action:?} after {} simulations.",
mcts_manager.base.played
);
} else {
warn!("unable to find a valid action");
}
@ -136,10 +166,119 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
direction: action.unwrap_or(Direction::Up),
shout: None,
})
}).await.unwrap()
})
.await
.unwrap()
}
async fn end(request: Json<Request>) {
let board = Board::from(&*request);
info!("got end request: {board}");
}
/// Statistics for a single node of the search tree.
///
/// A node stores the accumulated score and visit count of every simulation
/// that passed through it, plus one optional child node per action index
/// (four slots, indexed by the action converted to `usize`).
#[derive(Debug)]
struct ActionInfo {
    /// Sum of all scores that were propagated through this node.
    score: u32,
    /// Number of simulations that were propagated through this node.
    played: u32,
    /// Child nodes, one slot per action; `None` means "not expanded yet".
    next: Box<[Option<ActionInfo>; 4]>,
}

impl ActionInfo {
    /// Creates an empty node without any statistics or children.
    fn new() -> Self {
        Self {
            score: 0,
            played: 0,
            next: Box::new([None, None, None, None]),
        }
    }

    /// Computes the UCT value of every child of this node.
    ///
    /// For an existing child the value is
    /// `score / played + c * sqrt(ln(parent_played) / played)`.
    /// Slots without a child yield `None`.
    ///
    /// The result never contains NaN:
    /// * a child that has not been played yet is reported as
    ///   `f32::INFINITY`, so it is preferred over every visited sibling
    ///   instead of producing `0 / 0`;
    /// * the parent count is clamped to at least 1, so `ln(0) = -inf` can
    ///   not turn the exploration term into NaN.
    ///
    /// For nodes where both counts are at least 1 the result is identical
    /// to the plain formula.
    fn uct(&self, c: f32) -> [Option<f32>; 4] {
        // Loop invariant: the parent's visit count is the same for all children.
        let ln_parent = f32::ln(self.played.max(1) as f32);

        let mut ucts = [None; 4];
        for (slot, child) in ucts.iter_mut().zip(self.next.iter()) {
            let Some(child) = child else {
                continue;
            };
            *slot = Some(if child.played == 0 {
                // Unvisited children have no average yet; visit them first.
                f32::INFINITY
            } else {
                let played = child.played as f32;
                let exploitation = child.score as f32 / played;
                let exploration = f32::sqrt(ln_parent / played);
                c.mul_add(exploration, exploitation)
            });
        }
        ucts
    }
}
/// Bookkeeping for a tree search over the actions of a single snake.
///
/// Holds the root of the tree together with the state of the iteration that
/// is currently in progress (the path walked so far and whether a new node
/// has already been added).
#[derive(Debug)]
struct MctsManager {
/// Root node of the search tree.
base: ActionInfo,
/// Actions chosen so far in the current iteration, in order from the root;
/// cleared again by `apply_score`.
actions: Vec<Direction>,
/// Set once the current iteration has added a new node; while set,
/// `next_action` returns `None`. Reset by `apply_score`.
expanded: bool,
/// Id of the snake whose valid actions are requested from the board.
snake: u8,
}
impl MctsManager {
    /// Creates a manager with an empty tree that selects actions for `snake`.
    fn new(snake: u8) -> Self {
        Self {
            snake,
            expanded: false,
            actions: Vec::new(),
            base: ActionInfo::new(),
        }
    }

    /// Adds `score` and one visit to the root and to every node on the path
    /// recorded for the current iteration, then resets the iteration state
    /// (recorded path and expansion flag).
    ///
    /// If the recorded path refers to a node that does not exist, an error is
    /// logged and the remaining path is ignored.
    fn apply_score(&mut self, score: u32) {
        let mut node = &mut self.base;
        node.played += 1;
        node.score += score;

        for &step in &self.actions {
            let Some(child) = node.next[usize::from(step)].as_mut() else {
                error!("got action without actioninfo");
                break;
            };
            child.played += 1;
            child.score += score;
            node = child;
        }

        self.expanded = false;
        self.actions.clear();
    }

    /// Picks the next action of the current iteration, or `None` when the
    /// tree walk of this iteration is over.
    ///
    /// * Returns `None` if a node was already added during this iteration.
    /// * If at least one valid action of the current node has no child yet,
    ///   one of those is chosen at random, a fresh child is created for it
    ///   and the iteration is marked as expanded.
    /// * Otherwise the valid action with the highest UCT value (exploration
    ///   weight `c`) is chosen.
    ///
    /// The chosen action is appended to the recorded path. `None` is also
    /// returned when the board reports no valid actions or when the recorded
    /// path refers to a missing node (which is logged as an error).
    fn next_action(&mut self, board: &Board, c: f32) -> Option<Direction> {
        if self.expanded {
            return None;
        }

        // Walk down to the node reached by the actions taken so far.
        let mut node = &mut self.base;
        for &step in &self.actions {
            let Some(child) = node.next[usize::from(step)].as_mut() else {
                error!("got action without actioninfo");
                return None;
            };
            node = child;
        }

        let child_ucts = node.uct(c);
        let ucts: Vec<_> = board
            .valid_actions(self.snake)
            .map(|action| (action, child_ucts[usize::from(action)]))
            .collect();
        trace!("got actions: {ucts:?}");

        // Expansion: prefer a random action that has no node yet.
        let unexplored = ucts.iter().filter(|(_, uct)| uct.is_none());
        if let Some(&(action, _)) = unexplored.choose(&mut thread_rng()) {
            node.next[usize::from(action)] = Some(ActionInfo::new());
            self.expanded = true;
            self.actions.push(action);
            return Some(action);
        }

        // Selection: every valid action has a node, take the best UCT value.
        let best = ucts
            .iter()
            .max_by_key(|(_, uct)| FloatOrd(uct.unwrap_or(f32::NEG_INFINITY)))
            .map(|&(action, _)| action)?;
        self.actions.push(best);
        Some(best)
    }
}