Compare commits
2 Commits
e5600fe038
...
b4b332bdbb
Author | SHA1 | Date | |
---|---|---|---|
b4b332bdbb | |||
b97d7c895a |
13
Cargo.lock
generated
13
Cargo.lock
generated
@ -183,6 +183,7 @@ dependencies = [
|
||||
"enum-iterator",
|
||||
"env_logger",
|
||||
"float-ord",
|
||||
"futures-util",
|
||||
"log",
|
||||
"rand",
|
||||
"serde",
|
||||
@ -516,6 +517,17 @@ version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.31"
|
||||
@ -535,6 +547,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-macro",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"pin-project-lite",
|
||||
|
@ -3,3 +3,7 @@ members = ["battlesnake", "xtask"]
|
||||
resolver = "3"
|
||||
|
||||
default-members = ["battlesnake"]
|
||||
|
||||
[profile.release]
|
||||
lto = "fat"
|
||||
codegen-units = 1
|
||||
|
@ -30,6 +30,7 @@ bitvec = "1.0"
|
||||
enum-iterator = "2.1"
|
||||
rand = "0.9"
|
||||
float-ord = "0.3"
|
||||
futures-util = "0.3.31"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.5"
|
||||
|
@ -12,6 +12,7 @@ use battlesnake::types::{
|
||||
wire::{Request, Response},
|
||||
};
|
||||
use float_ord::FloatOrd;
|
||||
use futures_util::future::join_all;
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use rand::prelude::*;
|
||||
use serde::Serialize;
|
||||
@ -88,120 +89,126 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
|
||||
}
|
||||
info!("valid actions: {actions:?}");
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
if start.elapsed() > Duration::from_millis(10) {
|
||||
error!(
|
||||
"The calculation started late ({}ms)",
|
||||
start.elapsed().as_millis()
|
||||
);
|
||||
}
|
||||
let end_condition: fn(&Board) -> Option<()> = match &*request.game.ruleset.name {
|
||||
"solo" => end_solo,
|
||||
_ => end_standard,
|
||||
};
|
||||
let score_fn: fn(&Board, u8) -> u32 = match &*request.game.ruleset.name {
|
||||
"solo" => score_solo,
|
||||
_ => score_standard,
|
||||
};
|
||||
|
||||
let action_futures = (0..3).map(|_| {
|
||||
let request = request.clone();
|
||||
let board = board.clone();
|
||||
let mut rng = SmallRng::from_os_rng();
|
||||
if start.elapsed() > Duration::from_millis(10) {
|
||||
error!(
|
||||
"The calculation started late ({}ms)",
|
||||
start.elapsed().as_millis()
|
||||
);
|
||||
}
|
||||
let base_turns = board.turn();
|
||||
let rolling_horizon = (u32::from(request.you.length) * 3).max(32);
|
||||
let last_turn = base_turns + rolling_horizon;
|
||||
let end_condition: &dyn Fn(&Board) -> Option<_> = match &*request.game.ruleset.name {
|
||||
"solo" => &|board| {
|
||||
if board.valid_actions(0).count() == 0 || board.turn() > last_turn {
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
let actions = actions.clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut mcts_managers: Vec<_> = (0..request.board.snakes.len())
|
||||
.map(|id| MctsManager::new(u8::try_from(id).unwrap()))
|
||||
.collect();
|
||||
let c = f32::sqrt(2.0);
|
||||
let mut mcts_actions = Vec::new();
|
||||
while start.elapsed() < timeout * 4 / 5 {
|
||||
let mut board = board.clone();
|
||||
while end_condition(&board).is_none() {
|
||||
mcts_actions.clear();
|
||||
mcts_actions.extend(mcts_managers.iter_mut().filter_map(|mcts_manager| {
|
||||
mcts_manager
|
||||
.next_action(&board, c, &mut rng)
|
||||
.map(|action| (mcts_manager.snake, action))
|
||||
}));
|
||||
board.next_turn(&mcts_actions, &mut rng);
|
||||
if mcts_actions.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
_ => &|board| {
|
||||
if board.num_snakes() <= 1 || board.turn() > last_turn {
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
},
|
||||
};
|
||||
let score_fn: &dyn Fn(&Board, u8) -> u32 = match &*request.game.ruleset.name {
|
||||
"solo" => &|board, id| u32::try_from(board.length(id)).unwrap_or(0),
|
||||
_ => &|board, id| {
|
||||
if board.alive(id) {
|
||||
1 + u32::from(board.max_length() == board.length(id))
|
||||
} else {
|
||||
0
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let mut mcts_managers: Vec<_> = (0..request.board.snakes.len())
|
||||
.map(|id| MctsManager::new(u8::try_from(id).unwrap()))
|
||||
.collect();
|
||||
let c = f32::sqrt(2.0);
|
||||
let mut mcts_actions = Vec::new();
|
||||
while start.elapsed() < timeout * 4 / 5 {
|
||||
let mut board = board.clone();
|
||||
while end_condition(&board).is_none() {
|
||||
mcts_actions.clear();
|
||||
mcts_actions.extend(mcts_managers.iter_mut().filter_map(|mcts_manager| {
|
||||
mcts_manager
|
||||
.next_action(&board, c, &mut rng)
|
||||
.map(|action| (mcts_manager.snake, action))
|
||||
}));
|
||||
board.next_turn(&mcts_actions, &mut rng);
|
||||
if mcts_actions.is_empty() {
|
||||
break;
|
||||
board.simulate_random(&mut rng, end_condition);
|
||||
for mcts_manager in &mut mcts_managers {
|
||||
let id = mcts_manager.snake;
|
||||
let score = score_fn(&board, id);
|
||||
mcts_manager.apply_score(score);
|
||||
}
|
||||
}
|
||||
board.simulate_random(&mut rng, end_condition);
|
||||
for mcts_manager in &mut mcts_managers {
|
||||
let id = mcts_manager.snake;
|
||||
let score = score_fn(&board, id);
|
||||
mcts_manager.apply_score(score);
|
||||
let my_mcts_manager = mcts_managers.into_iter().nth(usize::from(id)).unwrap();
|
||||
for action in actions {
|
||||
let score = my_mcts_manager.base.next[usize::from(action)]
|
||||
.as_ref()
|
||||
.map(|info| info.score as f32 / info.played as f32);
|
||||
if let Some(score) = score {
|
||||
info!("{action:?} -> {score}");
|
||||
} else {
|
||||
info!("{action:?} -> None");
|
||||
}
|
||||
}
|
||||
}
|
||||
let my_mcts_manager = &mcts_managers[usize::from(id)];
|
||||
for action in actions {
|
||||
let score = my_mcts_manager.base.next[usize::from(action)]
|
||||
.as_ref()
|
||||
.map(|info| info.score as f32 / info.played as f32);
|
||||
if let Some(score) = score {
|
||||
info!("{action:?} -> {score}");
|
||||
} else {
|
||||
info!("{action:?} -> None");
|
||||
}
|
||||
}
|
||||
|
||||
let action = my_mcts_manager
|
||||
.base
|
||||
.next
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(index, info)| {
|
||||
info.as_ref().map(|info| {
|
||||
(
|
||||
match index {
|
||||
0 => Direction::Up,
|
||||
1 => Direction::Down,
|
||||
2 => Direction::Left,
|
||||
3 => Direction::Right,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
info,
|
||||
)
|
||||
})
|
||||
})
|
||||
.max_by_key(|(_, info)| FloatOrd(info.score as f32 / info.played as f32))
|
||||
.map(|(action, _)| action);
|
||||
let actions = my_mcts_manager.base.next.map(|action| {
|
||||
action.map_or(0.0, |action| action.score as f32 / action.played as f32)
|
||||
});
|
||||
|
||||
if let Some(action) = action {
|
||||
info!(
|
||||
"found action {action:?} after {} simulations.",
|
||||
my_mcts_manager.base.played
|
||||
);
|
||||
} else {
|
||||
warn!("unable to find a valid action");
|
||||
}
|
||||
info!("chose {action:?}");
|
||||
response::Json(Response {
|
||||
direction: action.unwrap_or(Direction::Up),
|
||||
shout: None,
|
||||
(actions, my_mcts_manager.base.played)
|
||||
})
|
||||
});
|
||||
let (actions, played) = join_all(action_futures).await.into_iter().fold(
|
||||
([0.0; 4], 0),
|
||||
|(mut total, mut games), actions| {
|
||||
if let Ok((actions, new_games)) = actions {
|
||||
for i in 0..total.len() {
|
||||
total[i] += actions[i];
|
||||
}
|
||||
games += new_games;
|
||||
}
|
||||
(total, games)
|
||||
},
|
||||
);
|
||||
let action = actions
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.max_by_key(|(_, score)| FloatOrd(*score))
|
||||
.map(|(index, _)| match index {
|
||||
0 => Direction::Up,
|
||||
1 => Direction::Down,
|
||||
2 => Direction::Left,
|
||||
3 => Direction::Right,
|
||||
_ => unreachable!(),
|
||||
});
|
||||
|
||||
if let Some(action) = action {
|
||||
info!("found action {action:?} after {played} simulations.",);
|
||||
} else {
|
||||
warn!("unable to find a valid action");
|
||||
}
|
||||
info!("chose {action:?}");
|
||||
response::Json(Response {
|
||||
direction: action.unwrap_or(Direction::Up),
|
||||
shout: None,
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn end_solo(board: &Board) -> Option<()> {
|
||||
(board.valid_actions(0).count() == 0).then_some(())
|
||||
}
|
||||
|
||||
fn end_standard(board: &Board) -> Option<()> {
|
||||
(board.num_snakes() <= 1).then_some(())
|
||||
}
|
||||
|
||||
fn score_solo(board: &Board, id: u8) -> u32 {
|
||||
u32::try_from(board.length(id)).unwrap_or(0)
|
||||
}
|
||||
|
||||
fn score_standard(board: &Board, id: u8) -> u32 {
|
||||
if board.alive(id) {
|
||||
1 + u32::from(board.max_length() == board.length(id))
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
async fn end(request: Json<Request>) {
|
||||
|
@ -33,7 +33,7 @@ impl Coord {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Sequence)]
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Sequence)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Direction {
|
||||
/// Move in positive y direction
|
||||
|
@ -265,7 +265,7 @@ fn try_regression() -> Result<(usize, usize, usize), DynError> {
|
||||
const GAMES: usize = 100;
|
||||
// limit the parallelism
|
||||
rayon::ThreadPoolBuilder::new()
|
||||
.num_threads(std::thread::available_parallelism()?.get() / 4)
|
||||
.num_threads(std::thread::available_parallelism()?.get() / 8)
|
||||
.build_global()
|
||||
.unwrap();
|
||||
|
||||
|
Reference in New Issue
Block a user