This commit is contained in:
parent
3f91660583
commit
6f04d7cb7f
7
Cargo.lock
generated
7
Cargo.lock
generated
@ -182,6 +182,7 @@ dependencies = [
|
||||
"criterion",
|
||||
"enum-iterator",
|
||||
"env_logger",
|
||||
"float-ord",
|
||||
"log",
|
||||
"rand",
|
||||
"serde",
|
||||
@ -469,6 +470,12 @@ version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
|
||||
|
||||
[[package]]
|
||||
name = "float-ord"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d"
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
|
@ -29,6 +29,7 @@ env_logger = "0.11"
|
||||
bitvec = "1.0"
|
||||
enum-iterator = "2.1"
|
||||
rand = "0.8"
|
||||
float-ord = "0.3"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.5"
|
||||
|
@ -9,7 +9,8 @@ use battlesnake::types::{
|
||||
wire::{Request, Response},
|
||||
Direction,
|
||||
};
|
||||
use log::{debug, error, info, warn};
|
||||
use float_ord::FloatOrd;
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use rand::prelude::*;
|
||||
use serde::Serialize;
|
||||
use tokio::{
|
||||
@ -87,7 +88,7 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
|
||||
let end_condition: &dyn Fn(&Board) -> Option<u32> = match &*request.game.ruleset.name {
|
||||
"solo" => &|board: &Board| {
|
||||
if board.num_snakes() == 0
|
||||
|| board.turn() > base_turns + u32::from(request.you.length) * 3
|
||||
|| board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32)
|
||||
{
|
||||
Some(board.turn())
|
||||
} else {
|
||||
@ -95,7 +96,8 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
|
||||
}
|
||||
},
|
||||
_ => &|board: &Board| {
|
||||
if board.num_snakes() <= 1 || board.turn() > base_turns + u32::from(request.you.length) * 3
|
||||
if board.num_snakes() <= 1
|
||||
|| board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32)
|
||||
{
|
||||
Some(u32::from(board.alive(id)))
|
||||
} else {
|
||||
@ -104,30 +106,58 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
|
||||
},
|
||||
};
|
||||
|
||||
let mut action_data = [(0, 0); 4];
|
||||
let mut total_simulations = 0;
|
||||
while start.elapsed() < Duration::from_millis(250) {
|
||||
let mut mcts_manager = MctsManager::new(id);
|
||||
let c = f32::sqrt(2.0);
|
||||
'outer: while start.elapsed() < Duration::from_millis(250) {
|
||||
let mut board = board.clone();
|
||||
let action = *actions.choose(&mut thread_rng()).unwrap_or(&Direction::Up);
|
||||
while let Some(action) = mcts_manager.next_action(&board, c) {
|
||||
board.next_turn(&[(id, action)]);
|
||||
let score = board.simulate_random(end_condition);
|
||||
let action_data = &mut action_data[usize::from(action)];
|
||||
action_data.0 += score;
|
||||
action_data.1 += 1;
|
||||
total_simulations += 1;
|
||||
if let Some(score) = end_condition(&board) {
|
||||
mcts_manager.apply_score(score);
|
||||
continue 'outer;
|
||||
}
|
||||
}
|
||||
let score = board.simulate_random(end_condition);
|
||||
mcts_manager.apply_score(score);
|
||||
}
|
||||
for action in actions {
|
||||
let score = mcts_manager.base.next[usize::from(action)]
|
||||
.as_ref()
|
||||
.map(|info| info.score as f32 / info.played as f32);
|
||||
if let Some(score) = score {
|
||||
info!("{action:?} -> {score}");
|
||||
} else {
|
||||
info!("{action:?} -> None");
|
||||
}
|
||||
}
|
||||
debug!("action data: {action_data:?}");
|
||||
|
||||
let action = actions.into_iter().max_by(|lhs, rhs| {
|
||||
let lhs_data = action_data[usize::from(*lhs)];
|
||||
let rhs_data = action_data[usize::from(*rhs)];
|
||||
(u64::from(lhs_data.0) * u64::from(rhs_data.1)).cmp(&(u64::from(rhs_data.0) * u64::from(lhs_data.1)))
|
||||
});
|
||||
let action = mcts_manager
|
||||
.base
|
||||
.next
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(index, info)| {
|
||||
info.as_ref().map(|info| {
|
||||
(
|
||||
match index {
|
||||
0 => Direction::Up,
|
||||
1 => Direction::Down,
|
||||
2 => Direction::Left,
|
||||
3 => Direction::Right,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
info,
|
||||
)
|
||||
})
|
||||
})
|
||||
.max_by_key(|(_, info)| FloatOrd(info.score as f32 / info.played as f32))
|
||||
.map(|(action, _)| action);
|
||||
|
||||
if let Some(action) = action {
|
||||
let action_data = action_data[usize::from(action)];
|
||||
let avg_turns = action_data.0 / action_data.1;
|
||||
info!("found action {action:?} after {total_simulations} simulations with an average of {avg_turns} turns.");
|
||||
info!(
|
||||
"found action {action:?} after {} simulations.",
|
||||
mcts_manager.base.played
|
||||
);
|
||||
} else {
|
||||
warn!("unable to find a valid action");
|
||||
}
|
||||
@ -136,10 +166,119 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
|
||||
direction: action.unwrap_or(Direction::Up),
|
||||
shout: None,
|
||||
})
|
||||
}).await.unwrap()
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn end(request: Json<Request>) {
|
||||
let board = Board::from(&*request);
|
||||
info!("got end request: {board}");
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ActionInfo {
|
||||
score: u32,
|
||||
played: u32,
|
||||
next: Box<[Option<ActionInfo>; 4]>,
|
||||
}
|
||||
|
||||
impl ActionInfo {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
score: 0,
|
||||
played: 0,
|
||||
next: Box::new([None, None, None, None]),
|
||||
}
|
||||
}
|
||||
|
||||
fn uct(&self, c: f32) -> [Option<f32>; 4] {
|
||||
let mut ucts = [None; 4];
|
||||
for (action, uct) in self.next.iter().zip(ucts.iter_mut()) {
|
||||
if let Some(action) = action {
|
||||
let exploitation = action.score as f32 / action.played as f32;
|
||||
let exploration = f32::sqrt(f32::ln(self.played as f32) / action.played as f32);
|
||||
uct.replace(c.mul_add(exploration, exploitation));
|
||||
}
|
||||
}
|
||||
ucts
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MctsManager {
|
||||
base: ActionInfo,
|
||||
actions: Vec<Direction>,
|
||||
expanded: bool,
|
||||
snake: u8,
|
||||
}
|
||||
|
||||
impl MctsManager {
|
||||
fn new(snake: u8) -> Self {
|
||||
Self {
|
||||
base: ActionInfo::new(),
|
||||
actions: Vec::new(),
|
||||
expanded: false,
|
||||
snake,
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_score(&mut self, score: u32) {
|
||||
self.base.played += 1;
|
||||
self.base.score += score;
|
||||
let mut current = &mut self.base;
|
||||
for action in &self.actions {
|
||||
let Some(ref mut new_current) = &mut current.next[usize::from(*action)] else {
|
||||
error!("got action without actioninfo");
|
||||
break;
|
||||
};
|
||||
current = new_current;
|
||||
current.played += 1;
|
||||
current.score += score;
|
||||
}
|
||||
self.actions.clear();
|
||||
self.expanded = false;
|
||||
}
|
||||
|
||||
fn next_action(&mut self, board: &Board, c: f32) -> Option<Direction> {
|
||||
if self.expanded {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut current = &mut self.base;
|
||||
for action in &self.actions {
|
||||
let Some(ref mut new_current) = &mut current.next[usize::from(*action)] else {
|
||||
error!("got action without actioninfo");
|
||||
return None;
|
||||
};
|
||||
current = new_current;
|
||||
}
|
||||
|
||||
let ucts = current.uct(c);
|
||||
let valid_actions = board.valid_actions(self.snake);
|
||||
let ucts: Vec<_> = valid_actions
|
||||
.map(|action| (action, ucts[usize::from(action)]))
|
||||
.collect();
|
||||
trace!("got actions: {ucts:?}");
|
||||
if ucts.iter().any(|(_, uct)| uct.is_none()) {
|
||||
let action = ucts
|
||||
.iter()
|
||||
.filter(|(_, uct)| uct.is_none())
|
||||
.choose(&mut thread_rng())?
|
||||
.0;
|
||||
self.expanded = true;
|
||||
current.next[usize::from(action)].replace(ActionInfo::new());
|
||||
self.actions.push(action);
|
||||
return Some(action);
|
||||
}
|
||||
|
||||
let action = ucts
|
||||
.iter()
|
||||
.max_by_key(|(_, uct)| FloatOrd(uct.unwrap_or(f32::NEG_INFINITY)))
|
||||
.map(|(action, _)| *action);
|
||||
if let Some(action) = action {
|
||||
self.actions.push(action);
|
||||
}
|
||||
action
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user