Compare commits

...

2 Commits

Author SHA1 Message Date
b4b332bdbb reduce regression parallelism
All checks were successful
Build / build (push) Successful in 2m52s
2025-06-06 21:16:22 +02:00
b97d7c895a add multithreading 2025-06-06 21:15:26 +02:00
6 changed files with 132 additions and 107 deletions

13
Cargo.lock generated
View File

@ -183,6 +183,7 @@ dependencies = [
"enum-iterator",
"env_logger",
"float-ord",
"futures-util",
"log",
"rand",
"serde",
@ -516,6 +517,17 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.31"
@ -535,6 +547,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-core",
"futures-macro",
"futures-sink",
"futures-task",
"pin-project-lite",

View File

@ -3,3 +3,7 @@ members = ["battlesnake", "xtask"]
resolver = "3"
default-members = ["battlesnake"]
[profile.release]
lto = "fat"
codegen-units = 1

View File

@ -30,6 +30,7 @@ bitvec = "1.0"
enum-iterator = "2.1"
rand = "0.9"
float-ord = "0.3"
futures-util = "0.3.31"
[dev-dependencies]
criterion = "0.5"

View File

@ -12,6 +12,7 @@ use battlesnake::types::{
wire::{Request, Response},
};
use float_ord::FloatOrd;
use futures_util::future::join_all;
use log::{debug, error, info, trace, warn};
use rand::prelude::*;
use serde::Serialize;
@ -88,120 +89,126 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
}
info!("valid actions: {actions:?}");
tokio::task::spawn_blocking(move || {
if start.elapsed() > Duration::from_millis(10) {
error!(
"The calculation started late ({}ms)",
start.elapsed().as_millis()
);
}
let end_condition: fn(&Board) -> Option<()> = match &*request.game.ruleset.name {
"solo" => end_solo,
_ => end_standard,
};
let score_fn: fn(&Board, u8) -> u32 = match &*request.game.ruleset.name {
"solo" => score_solo,
_ => score_standard,
};
let action_futures = (0..3).map(|_| {
let request = request.clone();
let board = board.clone();
let mut rng = SmallRng::from_os_rng();
if start.elapsed() > Duration::from_millis(10) {
error!(
"The calculation started late ({}ms)",
start.elapsed().as_millis()
);
}
let base_turns = board.turn();
let rolling_horizon = (u32::from(request.you.length) * 3).max(32);
let last_turn = base_turns + rolling_horizon;
let end_condition: &dyn Fn(&Board) -> Option<_> = match &*request.game.ruleset.name {
"solo" => &|board| {
if board.valid_actions(0).count() == 0 || board.turn() > last_turn {
Some(())
} else {
None
let actions = actions.clone();
tokio::task::spawn_blocking(move || {
let mut mcts_managers: Vec<_> = (0..request.board.snakes.len())
.map(|id| MctsManager::new(u8::try_from(id).unwrap()))
.collect();
let c = f32::sqrt(2.0);
let mut mcts_actions = Vec::new();
while start.elapsed() < timeout * 4 / 5 {
let mut board = board.clone();
while end_condition(&board).is_none() {
mcts_actions.clear();
mcts_actions.extend(mcts_managers.iter_mut().filter_map(|mcts_manager| {
mcts_manager
.next_action(&board, c, &mut rng)
.map(|action| (mcts_manager.snake, action))
}));
board.next_turn(&mcts_actions, &mut rng);
if mcts_actions.is_empty() {
break;
}
}
},
_ => &|board| {
if board.num_snakes() <= 1 || board.turn() > last_turn {
Some(())
} else {
None
}
},
};
let score_fn: &dyn Fn(&Board, u8) -> u32 = match &*request.game.ruleset.name {
"solo" => &|board, id| u32::try_from(board.length(id)).unwrap_or(0),
_ => &|board, id| {
if board.alive(id) {
1 + u32::from(board.max_length() == board.length(id))
} else {
0
}
},
};
let mut mcts_managers: Vec<_> = (0..request.board.snakes.len())
.map(|id| MctsManager::new(u8::try_from(id).unwrap()))
.collect();
let c = f32::sqrt(2.0);
let mut mcts_actions = Vec::new();
while start.elapsed() < timeout * 4 / 5 {
let mut board = board.clone();
while end_condition(&board).is_none() {
mcts_actions.clear();
mcts_actions.extend(mcts_managers.iter_mut().filter_map(|mcts_manager| {
mcts_manager
.next_action(&board, c, &mut rng)
.map(|action| (mcts_manager.snake, action))
}));
board.next_turn(&mcts_actions, &mut rng);
if mcts_actions.is_empty() {
break;
board.simulate_random(&mut rng, end_condition);
for mcts_manager in &mut mcts_managers {
let id = mcts_manager.snake;
let score = score_fn(&board, id);
mcts_manager.apply_score(score);
}
}
board.simulate_random(&mut rng, end_condition);
for mcts_manager in &mut mcts_managers {
let id = mcts_manager.snake;
let score = score_fn(&board, id);
mcts_manager.apply_score(score);
let my_mcts_manager = mcts_managers.into_iter().nth(usize::from(id)).unwrap();
for action in actions {
let score = my_mcts_manager.base.next[usize::from(action)]
.as_ref()
.map(|info| info.score as f32 / info.played as f32);
if let Some(score) = score {
info!("{action:?} -> {score}");
} else {
info!("{action:?} -> None");
}
}
}
let my_mcts_manager = &mcts_managers[usize::from(id)];
for action in actions {
let score = my_mcts_manager.base.next[usize::from(action)]
.as_ref()
.map(|info| info.score as f32 / info.played as f32);
if let Some(score) = score {
info!("{action:?} -> {score}");
} else {
info!("{action:?} -> None");
}
}
let action = my_mcts_manager
.base
.next
.iter()
.enumerate()
.filter_map(|(index, info)| {
info.as_ref().map(|info| {
(
match index {
0 => Direction::Up,
1 => Direction::Down,
2 => Direction::Left,
3 => Direction::Right,
_ => unreachable!(),
},
info,
)
})
})
.max_by_key(|(_, info)| FloatOrd(info.score as f32 / info.played as f32))
.map(|(action, _)| action);
let actions = my_mcts_manager.base.next.map(|action| {
action.map_or(0.0, |action| action.score as f32 / action.played as f32)
});
if let Some(action) = action {
info!(
"found action {action:?} after {} simulations.",
my_mcts_manager.base.played
);
} else {
warn!("unable to find a valid action");
}
info!("chose {action:?}");
response::Json(Response {
direction: action.unwrap_or(Direction::Up),
shout: None,
(actions, my_mcts_manager.base.played)
})
});
let (actions, played) = join_all(action_futures).await.into_iter().fold(
([0.0; 4], 0),
|(mut total, mut games), actions| {
if let Ok((actions, new_games)) = actions {
for i in 0..total.len() {
total[i] += actions[i];
}
games += new_games;
}
(total, games)
},
);
let action = actions
.into_iter()
.enumerate()
.max_by_key(|(_, score)| FloatOrd(*score))
.map(|(index, _)| match index {
0 => Direction::Up,
1 => Direction::Down,
2 => Direction::Left,
3 => Direction::Right,
_ => unreachable!(),
});
if let Some(action) = action {
info!("found action {action:?} after {played} simulations.",);
} else {
warn!("unable to find a valid action");
}
info!("chose {action:?}");
response::Json(Response {
direction: action.unwrap_or(Direction::Up),
shout: None,
})
.await
.unwrap()
}
fn end_solo(board: &Board) -> Option<()> {
(board.valid_actions(0).count() == 0).then_some(())
}
fn end_standard(board: &Board) -> Option<()> {
(board.num_snakes() <= 1).then_some(())
}
fn score_solo(board: &Board, id: u8) -> u32 {
u32::try_from(board.length(id)).unwrap_or(0)
}
fn score_standard(board: &Board, id: u8) -> u32 {
if board.alive(id) {
1 + u32::from(board.max_length() == board.length(id))
} else {
0
}
}
async fn end(request: Json<Request>) {

View File

@ -33,7 +33,7 @@ impl Coord {
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Sequence)]
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Sequence)]
#[serde(rename_all = "lowercase")]
pub enum Direction {
/// Move in positive y direction

View File

@ -265,7 +265,7 @@ fn try_regression() -> Result<(usize, usize, usize), DynError> {
const GAMES: usize = 100;
// limit the parallelism
rayon::ThreadPoolBuilder::new()
.num_threads(std::thread::available_parallelism()?.get() / 4)
.num_threads(std::thread::available_parallelism()?.get() / 8)
.build_global()
.unwrap();