init
This commit is contained in:
259
backend/src/main.rs
Normal file
259
backend/src/main.rs
Normal file
@@ -0,0 +1,259 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use rayon::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tower_http::services::{ServeDir, ServeFile};
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct TruthBooth {
|
||||
p1: String, // from group 1
|
||||
p2: String, // from group 2
|
||||
match_: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct Ceremony {
|
||||
pairs: HashMap<String, String>, // p1 -> p2
|
||||
beams: usize,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct SolveRequest {
|
||||
group1: Vec<String>,
|
||||
group2: Vec<String>,
|
||||
truth_booths: Vec<TruthBooth>,
|
||||
ceremonies: Vec<Ceremony>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct GridData {
|
||||
index: Vec<String>,
|
||||
columns: Vec<String>,
|
||||
data: Vec<Vec<f64>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SolveResponse {
|
||||
possibilities: usize,
|
||||
grid_data: GridData,
|
||||
}
|
||||
|
||||
struct Solver {
|
||||
n: usize,
|
||||
m: usize,
|
||||
allowed: Vec<u32>, // bitmask of allowed g1 for each g2
|
||||
ceremonies: Vec<(Vec<(usize, usize)>, usize)>,
|
||||
}
|
||||
|
||||
impl Solver {
|
||||
fn new(req: &SolveRequest) -> Result<(Self, HashMap<String, usize>, HashMap<String, usize>), String> {
|
||||
let n = req.group1.len();
|
||||
let m = req.group2.len();
|
||||
|
||||
if n > m {
|
||||
return Err("Group 1 cannot be larger than Group 2".to_string());
|
||||
}
|
||||
if n > 32 {
|
||||
return Err("Group 1 size too large".to_string()); // fit in u32
|
||||
}
|
||||
|
||||
let mut g1_map = HashMap::new();
|
||||
for (i, name) in req.group1.iter().enumerate() {
|
||||
g1_map.insert(name.clone(), i);
|
||||
}
|
||||
|
||||
let mut g2_map = HashMap::new();
|
||||
for (i, name) in req.group2.iter().enumerate() {
|
||||
g2_map.insert(name.clone(), i);
|
||||
}
|
||||
|
||||
let mut allowed = vec![(1 << n) - 1; m];
|
||||
|
||||
for tb in &req.truth_booths {
|
||||
let g1_idx = *g1_map.get(&tb.p1).ok_or("Invalid p1 in TB")?;
|
||||
let g2_idx = *g2_map.get(&tb.p2).ok_or("Invalid p2 in TB")?;
|
||||
|
||||
if tb.match_ {
|
||||
allowed[g2_idx] = 1 << g1_idx;
|
||||
} else {
|
||||
allowed[g2_idx] &= !(1 << g1_idx);
|
||||
}
|
||||
}
|
||||
|
||||
let mut ceremonies = Vec::new();
|
||||
for c in &req.ceremonies {
|
||||
let mut pairs = Vec::new();
|
||||
for (p1, p2) in &c.pairs {
|
||||
let g1_idx = *g1_map.get(p1).ok_or("Invalid p1 in ceremony")?;
|
||||
let g2_idx = *g2_map.get(p2).ok_or("Invalid p2 in ceremony")?;
|
||||
pairs.push((g2_idx, g1_idx));
|
||||
}
|
||||
ceremonies.push((pairs, c.beams));
|
||||
}
|
||||
|
||||
Ok((Solver {
|
||||
n,
|
||||
m,
|
||||
allowed,
|
||||
ceremonies,
|
||||
}, g1_map, g2_map))
|
||||
}
|
||||
|
||||
fn solve(&self) -> (usize, Vec<Vec<usize>>) {
|
||||
// We will generate combinations using a parallel recursive approach.
|
||||
// A full state is an array of size M, where A[g2] = g1.
|
||||
// To parallelize, we generate all valid prefixes of length `prefix_len`.
|
||||
let prefix_len = self.m.min(4); // generate up to length 4 sequentially, then parallelize
|
||||
|
||||
let mut prefixes = Vec::new();
|
||||
self.generate_prefixes(0, 0, 0, &mut Vec::new(), prefix_len, &mut prefixes);
|
||||
|
||||
let results: Vec<(usize, Vec<Vec<usize>>)> = prefixes.into_par_iter().map(|(mask_used, mut prefix)| {
|
||||
let mut counts = vec![vec![0; self.m]; self.n];
|
||||
let mut valid_count = 0;
|
||||
self.dfs(prefix_len, mask_used, &mut prefix, &mut valid_count, &mut counts);
|
||||
(valid_count, counts)
|
||||
}).collect();
|
||||
|
||||
let mut total_valid = 0;
|
||||
let mut total_counts = vec![vec![0; self.m]; self.n];
|
||||
|
||||
for (v_count, counts) in results {
|
||||
total_valid += v_count;
|
||||
for i in 0..self.n {
|
||||
for j in 0..self.m {
|
||||
total_counts[i][j] += counts[i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(total_valid, total_counts)
|
||||
}
|
||||
|
||||
fn generate_prefixes(&self, g2_idx: usize, mask_used: u32, current_allowed: u32, current: &mut Vec<usize>, target_len: usize, out: &mut Vec<(u32, Vec<usize>)>) {
|
||||
if g2_idx == target_len {
|
||||
out.push((mask_used, current.clone()));
|
||||
return;
|
||||
}
|
||||
|
||||
let mut allowed = self.allowed[g2_idx];
|
||||
|
||||
// Minor optimization: if we have (M - g2_idx) elements left, and we need to cover (N - bits_set) elements,
|
||||
// we can prune if it's impossible.
|
||||
let needed = self.n as u32 - mask_used.count_ones();
|
||||
let remaining = (self.m - g2_idx) as u32;
|
||||
if needed > remaining {
|
||||
return; // cannot form surjective mapping
|
||||
}
|
||||
|
||||
while allowed > 0 {
|
||||
let bit = allowed & (!allowed + 1); // get lowest set bit
|
||||
allowed &= !bit;
|
||||
let g1_idx = bit.trailing_zeros() as usize;
|
||||
|
||||
current.push(g1_idx);
|
||||
self.generate_prefixes(g2_idx + 1, mask_used | bit, 0, current, target_len, out);
|
||||
current.pop();
|
||||
}
|
||||
}
|
||||
|
||||
fn dfs(&self, g2_idx: usize, mask_used: u32, current: &mut Vec<usize>, valid_count: &mut usize, counts: &mut Vec<Vec<usize>>) {
|
||||
if g2_idx == self.m {
|
||||
if mask_used.count_ones() as usize != self.n {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check ceremonies
|
||||
for (pairs, beams) in &self.ceremonies {
|
||||
let mut actual_beams = 0;
|
||||
for &(g2, g1) in pairs {
|
||||
if current[g2] == g1 {
|
||||
actual_beams += 1;
|
||||
}
|
||||
}
|
||||
if actual_beams != *beams {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Valid scenario
|
||||
*valid_count += 1;
|
||||
for (g2, &g1) in current.iter().enumerate() {
|
||||
counts[g1][g2] += 1;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let needed = self.n as u32 - mask_used.count_ones();
|
||||
let remaining = (self.m - g2_idx) as u32;
|
||||
if needed > remaining {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut allowed = self.allowed[g2_idx];
|
||||
|
||||
// To aggressively prune with ceremonies, if we are at the end of checking pairs, but that's complex since we evaluate g2 sequentially.
|
||||
// It's usually fast enough to just prune at the leaf for M=11.
|
||||
|
||||
while allowed > 0 {
|
||||
let bit = allowed & (!allowed + 1);
|
||||
allowed &= !bit;
|
||||
let g1_idx = bit.trailing_zeros() as usize;
|
||||
|
||||
current.push(g1_idx);
|
||||
self.dfs(g2_idx + 1, mask_used | bit, current, valid_count, counts);
|
||||
current.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn solve_handler(Json(payload): Json<SolveRequest>) -> Result<Json<SolveResponse>, axum::http::StatusCode> {
|
||||
// We process the solve in a blocking thread to avoid blocking the async runtime
|
||||
let (solver, _g1_map, _g2_map) = Solver::new(&payload).map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let (total, counts) = tokio::task::spawn_blocking(move || {
|
||||
solver.solve()
|
||||
}).await.unwrap();
|
||||
|
||||
let mut data = vec![vec![0.0; payload.group2.len()]; payload.group1.len()];
|
||||
if total > 0 {
|
||||
for i in 0..payload.group1.len() {
|
||||
for j in 0..payload.group2.len() {
|
||||
data[i][j] = counts[i][j] as f64 / total as f64;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(SolveResponse {
|
||||
possibilities: total,
|
||||
grid_data: GridData {
|
||||
index: payload.group1,
|
||||
columns: payload.group2,
|
||||
data,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
let app = Router::new()
|
||||
.route("/api/solve", post(solve_handler))
|
||||
.fallback_service(ServeDir::new("../ayto").fallback(ServeFile::new("../ayto/index.html")))
|
||||
.layer(cors);
|
||||
|
||||
let addr = "0.0.0.0:8080";
|
||||
println!("Starting server on {}", addr);
|
||||
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
}
|
||||
Reference in New Issue
Block a user