Compare commits
2 Commits
e5600fe038
...
b4b332bdbb
Author | SHA1 | Date | |
---|---|---|---|
b4b332bdbb | |||
b97d7c895a |
13
Cargo.lock
generated
13
Cargo.lock
generated
@ -183,6 +183,7 @@ dependencies = [
|
|||||||
"enum-iterator",
|
"enum-iterator",
|
||||||
"env_logger",
|
"env_logger",
|
||||||
"float-ord",
|
"float-ord",
|
||||||
|
"futures-util",
|
||||||
"log",
|
"log",
|
||||||
"rand",
|
"rand",
|
||||||
"serde",
|
"serde",
|
||||||
@ -516,6 +517,17 @@ version = "0.3.31"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
|
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-macro"
|
||||||
|
version = "0.3.31"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-sink"
|
name = "futures-sink"
|
||||||
version = "0.3.31"
|
version = "0.3.31"
|
||||||
@ -535,6 +547,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
|
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
|
"futures-macro",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
"futures-task",
|
"futures-task",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
|
@ -3,3 +3,7 @@ members = ["battlesnake", "xtask"]
|
|||||||
resolver = "3"
|
resolver = "3"
|
||||||
|
|
||||||
default-members = ["battlesnake"]
|
default-members = ["battlesnake"]
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
lto = "fat"
|
||||||
|
codegen-units = 1
|
||||||
|
@ -30,6 +30,7 @@ bitvec = "1.0"
|
|||||||
enum-iterator = "2.1"
|
enum-iterator = "2.1"
|
||||||
rand = "0.9"
|
rand = "0.9"
|
||||||
float-ord = "0.3"
|
float-ord = "0.3"
|
||||||
|
futures-util = "0.3.31"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = "0.5"
|
criterion = "0.5"
|
||||||
|
@ -12,6 +12,7 @@ use battlesnake::types::{
|
|||||||
wire::{Request, Response},
|
wire::{Request, Response},
|
||||||
};
|
};
|
||||||
use float_ord::FloatOrd;
|
use float_ord::FloatOrd;
|
||||||
|
use futures_util::future::join_all;
|
||||||
use log::{debug, error, info, trace, warn};
|
use log::{debug, error, info, trace, warn};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
@ -88,44 +89,27 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
|
|||||||
}
|
}
|
||||||
info!("valid actions: {actions:?}");
|
info!("valid actions: {actions:?}");
|
||||||
|
|
||||||
tokio::task::spawn_blocking(move || {
|
|
||||||
let mut rng = SmallRng::from_os_rng();
|
|
||||||
if start.elapsed() > Duration::from_millis(10) {
|
if start.elapsed() > Duration::from_millis(10) {
|
||||||
error!(
|
error!(
|
||||||
"The calculation started late ({}ms)",
|
"The calculation started late ({}ms)",
|
||||||
start.elapsed().as_millis()
|
start.elapsed().as_millis()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let base_turns = board.turn();
|
let end_condition: fn(&Board) -> Option<()> = match &*request.game.ruleset.name {
|
||||||
let rolling_horizon = (u32::from(request.you.length) * 3).max(32);
|
"solo" => end_solo,
|
||||||
let last_turn = base_turns + rolling_horizon;
|
_ => end_standard,
|
||||||
let end_condition: &dyn Fn(&Board) -> Option<_> = match &*request.game.ruleset.name {
|
|
||||||
"solo" => &|board| {
|
|
||||||
if board.valid_actions(0).count() == 0 || board.turn() > last_turn {
|
|
||||||
Some(())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_ => &|board| {
|
|
||||||
if board.num_snakes() <= 1 || board.turn() > last_turn {
|
|
||||||
Some(())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
let score_fn: &dyn Fn(&Board, u8) -> u32 = match &*request.game.ruleset.name {
|
let score_fn: fn(&Board, u8) -> u32 = match &*request.game.ruleset.name {
|
||||||
"solo" => &|board, id| u32::try_from(board.length(id)).unwrap_or(0),
|
"solo" => score_solo,
|
||||||
_ => &|board, id| {
|
_ => score_standard,
|
||||||
if board.alive(id) {
|
|
||||||
1 + u32::from(board.max_length() == board.length(id))
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let action_futures = (0..3).map(|_| {
|
||||||
|
let request = request.clone();
|
||||||
|
let board = board.clone();
|
||||||
|
let mut rng = SmallRng::from_os_rng();
|
||||||
|
let actions = actions.clone();
|
||||||
|
tokio::task::spawn_blocking(move || {
|
||||||
let mut mcts_managers: Vec<_> = (0..request.board.snakes.len())
|
let mut mcts_managers: Vec<_> = (0..request.board.snakes.len())
|
||||||
.map(|id| MctsManager::new(u8::try_from(id).unwrap()))
|
.map(|id| MctsManager::new(u8::try_from(id).unwrap()))
|
||||||
.collect();
|
.collect();
|
||||||
@ -152,7 +136,7 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
|
|||||||
mcts_manager.apply_score(score);
|
mcts_manager.apply_score(score);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let my_mcts_manager = &mcts_managers[usize::from(id)];
|
let my_mcts_manager = mcts_managers.into_iter().nth(usize::from(id)).unwrap();
|
||||||
for action in actions {
|
for action in actions {
|
||||||
let score = my_mcts_manager.base.next[usize::from(action)]
|
let score = my_mcts_manager.base.next[usize::from(action)]
|
||||||
.as_ref()
|
.as_ref()
|
||||||
@ -164,33 +148,39 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let action = my_mcts_manager
|
let actions = my_mcts_manager.base.next.map(|action| {
|
||||||
.base
|
action.map_or(0.0, |action| action.score as f32 / action.played as f32)
|
||||||
.next
|
});
|
||||||
.iter()
|
|
||||||
|
(actions, my_mcts_manager.base.played)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
let (actions, played) = join_all(action_futures).await.into_iter().fold(
|
||||||
|
([0.0; 4], 0),
|
||||||
|
|(mut total, mut games), actions| {
|
||||||
|
if let Ok((actions, new_games)) = actions {
|
||||||
|
for i in 0..total.len() {
|
||||||
|
total[i] += actions[i];
|
||||||
|
}
|
||||||
|
games += new_games;
|
||||||
|
}
|
||||||
|
(total, games)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
let action = actions
|
||||||
|
.into_iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.filter_map(|(index, info)| {
|
.max_by_key(|(_, score)| FloatOrd(*score))
|
||||||
info.as_ref().map(|info| {
|
.map(|(index, _)| match index {
|
||||||
(
|
|
||||||
match index {
|
|
||||||
0 => Direction::Up,
|
0 => Direction::Up,
|
||||||
1 => Direction::Down,
|
1 => Direction::Down,
|
||||||
2 => Direction::Left,
|
2 => Direction::Left,
|
||||||
3 => Direction::Right,
|
3 => Direction::Right,
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
},
|
});
|
||||||
info,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.max_by_key(|(_, info)| FloatOrd(info.score as f32 / info.played as f32))
|
|
||||||
.map(|(action, _)| action);
|
|
||||||
|
|
||||||
if let Some(action) = action {
|
if let Some(action) = action {
|
||||||
info!(
|
info!("found action {action:?} after {played} simulations.",);
|
||||||
"found action {action:?} after {} simulations.",
|
|
||||||
my_mcts_manager.base.played
|
|
||||||
);
|
|
||||||
} else {
|
} else {
|
||||||
warn!("unable to find a valid action");
|
warn!("unable to find a valid action");
|
||||||
}
|
}
|
||||||
@ -199,9 +189,26 @@ async fn get_move(request: Json<Request>) -> response::Json<Response> {
|
|||||||
direction: action.unwrap_or(Direction::Up),
|
direction: action.unwrap_or(Direction::Up),
|
||||||
shout: None,
|
shout: None,
|
||||||
})
|
})
|
||||||
})
|
}
|
||||||
.await
|
|
||||||
.unwrap()
|
fn end_solo(board: &Board) -> Option<()> {
|
||||||
|
(board.valid_actions(0).count() == 0).then_some(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn end_standard(board: &Board) -> Option<()> {
|
||||||
|
(board.num_snakes() <= 1).then_some(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn score_solo(board: &Board, id: u8) -> u32 {
|
||||||
|
u32::try_from(board.length(id)).unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn score_standard(board: &Board, id: u8) -> u32 {
|
||||||
|
if board.alive(id) {
|
||||||
|
1 + u32::from(board.max_length() == board.length(id))
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn end(request: Json<Request>) {
|
async fn end(request: Json<Request>) {
|
||||||
|
@ -33,7 +33,7 @@ impl Coord {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Sequence)]
|
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Sequence)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum Direction {
|
pub enum Direction {
|
||||||
/// Move in positive y direction
|
/// Move in positive y direction
|
||||||
|
@ -265,7 +265,7 @@ fn try_regression() -> Result<(usize, usize, usize), DynError> {
|
|||||||
const GAMES: usize = 100;
|
const GAMES: usize = 100;
|
||||||
// limit the parallelism
|
// limit the parallelism
|
||||||
rayon::ThreadPoolBuilder::new()
|
rayon::ThreadPoolBuilder::new()
|
||||||
.num_threads(std::thread::available_parallelism()?.get() / 4)
|
.num_threads(std::thread::available_parallelism()?.get() / 8)
|
||||||
.build_global()
|
.build_global()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user