This commit is contained in:
parent
3f91660583
commit
6f04d7cb7f
7
Cargo.lock
generated
7
Cargo.lock
generated
@ -182,6 +182,7 @@ dependencies = [
|
|||||||
"criterion",
|
"criterion",
|
||||||
"enum-iterator",
|
"enum-iterator",
|
||||||
"env_logger",
|
"env_logger",
|
||||||
|
"float-ord",
|
||||||
"log",
|
"log",
|
||||||
"rand",
|
"rand",
|
||||||
"serde",
|
"serde",
|
||||||
@ -469,6 +470,12 @@ version = "1.0.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
|
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "float-ord"
|
||||||
|
version = "0.3.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fnv"
|
name = "fnv"
|
||||||
version = "1.0.7"
|
version = "1.0.7"
|
||||||
|
@ -29,6 +29,7 @@ env_logger = "0.11"
|
|||||||
bitvec = "1.0"
|
bitvec = "1.0"
|
||||||
enum-iterator = "2.1"
|
enum-iterator = "2.1"
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
|
float-ord = "0.3"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = "0.5"
|
criterion = "0.5"
|
||||||
|
@ -9,7 +9,8 @@ use battlesnake::types::{
|
|||||||
wire::{Request, Response},
|
wire::{Request, Response},
|
||||||
Direction,
|
Direction,
|
||||||
};
|
};
|
||||||
use log::{debug, error, info, warn};
|
use float_ord::FloatOrd;
|
||||||
|
use log::{debug, error, info, trace, warn};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use tokio::{
|
use tokio::{
|
||||||
@ -87,7 +88,7 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
|
|||||||
let end_condition: &dyn Fn(&Board) -> Option<u32> = match &*request.game.ruleset.name {
|
let end_condition: &dyn Fn(&Board) -> Option<u32> = match &*request.game.ruleset.name {
|
||||||
"solo" => &|board: &Board| {
|
"solo" => &|board: &Board| {
|
||||||
if board.num_snakes() == 0
|
if board.num_snakes() == 0
|
||||||
|| board.turn() > base_turns + u32::from(request.you.length) * 3
|
|| board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32)
|
||||||
{
|
{
|
||||||
Some(board.turn())
|
Some(board.turn())
|
||||||
} else {
|
} else {
|
||||||
@ -95,7 +96,8 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
_ => &|board: &Board| {
|
_ => &|board: &Board| {
|
||||||
if board.num_snakes() <= 1 || board.turn() > base_turns + u32::from(request.you.length) * 3
|
if board.num_snakes() <= 1
|
||||||
|
|| board.turn() > base_turns + (u32::from(request.you.length) * 3).min(32)
|
||||||
{
|
{
|
||||||
Some(u32::from(board.alive(id)))
|
Some(u32::from(board.alive(id)))
|
||||||
} else {
|
} else {
|
||||||
@ -104,30 +106,58 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut action_data = [(0, 0); 4];
|
let mut mcts_manager = MctsManager::new(id);
|
||||||
let mut total_simulations = 0;
|
let c = f32::sqrt(2.0);
|
||||||
while start.elapsed() < Duration::from_millis(250) {
|
'outer: while start.elapsed() < Duration::from_millis(250) {
|
||||||
let mut board = board.clone();
|
let mut board = board.clone();
|
||||||
let action = *actions.choose(&mut thread_rng()).unwrap_or(&Direction::Up);
|
while let Some(action) = mcts_manager.next_action(&board, c) {
|
||||||
board.next_turn(&[(id, action)]);
|
board.next_turn(&[(id, action)]);
|
||||||
|
if let Some(score) = end_condition(&board) {
|
||||||
|
mcts_manager.apply_score(score);
|
||||||
|
continue 'outer;
|
||||||
|
}
|
||||||
|
}
|
||||||
let score = board.simulate_random(end_condition);
|
let score = board.simulate_random(end_condition);
|
||||||
let action_data = &mut action_data[usize::from(action)];
|
mcts_manager.apply_score(score);
|
||||||
action_data.0 += score;
|
}
|
||||||
action_data.1 += 1;
|
for action in actions {
|
||||||
total_simulations += 1;
|
let score = mcts_manager.base.next[usize::from(action)]
|
||||||
|
.as_ref()
|
||||||
|
.map(|info| info.score as f32 / info.played as f32);
|
||||||
|
if let Some(score) = score {
|
||||||
|
info!("{action:?} -> {score}");
|
||||||
|
} else {
|
||||||
|
info!("{action:?} -> None");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
debug!("action data: {action_data:?}");
|
|
||||||
|
|
||||||
let action = actions.into_iter().max_by(|lhs, rhs| {
|
let action = mcts_manager
|
||||||
let lhs_data = action_data[usize::from(*lhs)];
|
.base
|
||||||
let rhs_data = action_data[usize::from(*rhs)];
|
.next
|
||||||
(u64::from(lhs_data.0) * u64::from(rhs_data.1)).cmp(&(u64::from(rhs_data.0) * u64::from(lhs_data.1)))
|
.iter()
|
||||||
});
|
.enumerate()
|
||||||
|
.filter_map(|(index, info)| {
|
||||||
|
info.as_ref().map(|info| {
|
||||||
|
(
|
||||||
|
match index {
|
||||||
|
0 => Direction::Up,
|
||||||
|
1 => Direction::Down,
|
||||||
|
2 => Direction::Left,
|
||||||
|
3 => Direction::Right,
|
||||||
|
_ => unreachable!(),
|
||||||
|
},
|
||||||
|
info,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.max_by_key(|(_, info)| FloatOrd(info.score as f32 / info.played as f32))
|
||||||
|
.map(|(action, _)| action);
|
||||||
|
|
||||||
if let Some(action) = action {
|
if let Some(action) = action {
|
||||||
let action_data = action_data[usize::from(action)];
|
info!(
|
||||||
let avg_turns = action_data.0 / action_data.1;
|
"found action {action:?} after {} simulations.",
|
||||||
info!("found action {action:?} after {total_simulations} simulations with an average of {avg_turns} turns.");
|
mcts_manager.base.played
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
warn!("unable to find a valid action");
|
warn!("unable to find a valid action");
|
||||||
}
|
}
|
||||||
@ -136,10 +166,119 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
|
|||||||
direction: action.unwrap_or(Direction::Up),
|
direction: action.unwrap_or(Direction::Up),
|
||||||
shout: None,
|
shout: None,
|
||||||
})
|
})
|
||||||
}).await.unwrap()
|
})
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn end(request: Json<Request>) {
|
async fn end(request: Json<Request>) {
|
||||||
let board = Board::from(&*request);
|
let board = Board::from(&*request);
|
||||||
info!("got end request: {board}");
|
info!("got end request: {board}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ActionInfo {
|
||||||
|
score: u32,
|
||||||
|
played: u32,
|
||||||
|
next: Box<[Option<ActionInfo>; 4]>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ActionInfo {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
score: 0,
|
||||||
|
played: 0,
|
||||||
|
next: Box::new([None, None, None, None]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn uct(&self, c: f32) -> [Option<f32>; 4] {
|
||||||
|
let mut ucts = [None; 4];
|
||||||
|
for (action, uct) in self.next.iter().zip(ucts.iter_mut()) {
|
||||||
|
if let Some(action) = action {
|
||||||
|
let exploitation = action.score as f32 / action.played as f32;
|
||||||
|
let exploration = f32::sqrt(f32::ln(self.played as f32) / action.played as f32);
|
||||||
|
uct.replace(c.mul_add(exploration, exploitation));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ucts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct MctsManager {
|
||||||
|
base: ActionInfo,
|
||||||
|
actions: Vec<Direction>,
|
||||||
|
expanded: bool,
|
||||||
|
snake: u8,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MctsManager {
|
||||||
|
fn new(snake: u8) -> Self {
|
||||||
|
Self {
|
||||||
|
base: ActionInfo::new(),
|
||||||
|
actions: Vec::new(),
|
||||||
|
expanded: false,
|
||||||
|
snake,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_score(&mut self, score: u32) {
|
||||||
|
self.base.played += 1;
|
||||||
|
self.base.score += score;
|
||||||
|
let mut current = &mut self.base;
|
||||||
|
for action in &self.actions {
|
||||||
|
let Some(ref mut new_current) = &mut current.next[usize::from(*action)] else {
|
||||||
|
error!("got action without actioninfo");
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
current = new_current;
|
||||||
|
current.played += 1;
|
||||||
|
current.score += score;
|
||||||
|
}
|
||||||
|
self.actions.clear();
|
||||||
|
self.expanded = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn next_action(&mut self, board: &Board, c: f32) -> Option<Direction> {
|
||||||
|
if self.expanded {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut current = &mut self.base;
|
||||||
|
for action in &self.actions {
|
||||||
|
let Some(ref mut new_current) = &mut current.next[usize::from(*action)] else {
|
||||||
|
error!("got action without actioninfo");
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
current = new_current;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ucts = current.uct(c);
|
||||||
|
let valid_actions = board.valid_actions(self.snake);
|
||||||
|
let ucts: Vec<_> = valid_actions
|
||||||
|
.map(|action| (action, ucts[usize::from(action)]))
|
||||||
|
.collect();
|
||||||
|
trace!("got actions: {ucts:?}");
|
||||||
|
if ucts.iter().any(|(_, uct)| uct.is_none()) {
|
||||||
|
let action = ucts
|
||||||
|
.iter()
|
||||||
|
.filter(|(_, uct)| uct.is_none())
|
||||||
|
.choose(&mut thread_rng())?
|
||||||
|
.0;
|
||||||
|
self.expanded = true;
|
||||||
|
current.next[usize::from(action)].replace(ActionInfo::new());
|
||||||
|
self.actions.push(action);
|
||||||
|
return Some(action);
|
||||||
|
}
|
||||||
|
|
||||||
|
let action = ucts
|
||||||
|
.iter()
|
||||||
|
.max_by_key(|(_, uct)| FloatOrd(uct.unwrap_or(f32::NEG_INFINITY)))
|
||||||
|
.map(|(action, _)| *action);
|
||||||
|
if let Some(action) = action {
|
||||||
|
self.actions.push(action);
|
||||||
|
}
|
||||||
|
action
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user