This commit is contained in:
2026-03-14 21:31:12 +01:00
commit 3c83dfb07e
9 changed files with 2375 additions and 0 deletions

259
backend/src/main.rs Normal file
View File

@@ -0,0 +1,259 @@
use axum::{
extract::State,
routing::{get, post},
Json, Router,
};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tower_http::cors::{Any, CorsLayer};
use tower_http::services::{ServeDir, ServeFile};
#[derive(Deserialize, Debug)]
struct TruthBooth {
p1: String, // from group 1
p2: String, // from group 2
match_: bool,
}
#[derive(Deserialize, Debug)]
struct Ceremony {
pairs: HashMap<String, String>, // p1 -> p2
beams: usize,
}
#[derive(Deserialize, Debug)]
struct SolveRequest {
group1: Vec<String>,
group2: Vec<String>,
truth_booths: Vec<TruthBooth>,
ceremonies: Vec<Ceremony>,
}
#[derive(Serialize)]
struct GridData {
index: Vec<String>,
columns: Vec<String>,
data: Vec<Vec<f64>>,
}
#[derive(Serialize)]
struct SolveResponse {
possibilities: usize,
grid_data: GridData,
}
struct Solver {
n: usize,
m: usize,
allowed: Vec<u32>, // bitmask of allowed g1 for each g2
ceremonies: Vec<(Vec<(usize, usize)>, usize)>,
}
impl Solver {
fn new(req: &SolveRequest) -> Result<(Self, HashMap<String, usize>, HashMap<String, usize>), String> {
let n = req.group1.len();
let m = req.group2.len();
if n > m {
return Err("Group 1 cannot be larger than Group 2".to_string());
}
if n > 32 {
return Err("Group 1 size too large".to_string()); // fit in u32
}
let mut g1_map = HashMap::new();
for (i, name) in req.group1.iter().enumerate() {
g1_map.insert(name.clone(), i);
}
let mut g2_map = HashMap::new();
for (i, name) in req.group2.iter().enumerate() {
g2_map.insert(name.clone(), i);
}
let mut allowed = vec![(1 << n) - 1; m];
for tb in &req.truth_booths {
let g1_idx = *g1_map.get(&tb.p1).ok_or("Invalid p1 in TB")?;
let g2_idx = *g2_map.get(&tb.p2).ok_or("Invalid p2 in TB")?;
if tb.match_ {
allowed[g2_idx] = 1 << g1_idx;
} else {
allowed[g2_idx] &= !(1 << g1_idx);
}
}
let mut ceremonies = Vec::new();
for c in &req.ceremonies {
let mut pairs = Vec::new();
for (p1, p2) in &c.pairs {
let g1_idx = *g1_map.get(p1).ok_or("Invalid p1 in ceremony")?;
let g2_idx = *g2_map.get(p2).ok_or("Invalid p2 in ceremony")?;
pairs.push((g2_idx, g1_idx));
}
ceremonies.push((pairs, c.beams));
}
Ok((Solver {
n,
m,
allowed,
ceremonies,
}, g1_map, g2_map))
}
fn solve(&self) -> (usize, Vec<Vec<usize>>) {
// We will generate combinations using a parallel recursive approach.
// A full state is an array of size M, where A[g2] = g1.
// To parallelize, we generate all valid prefixes of length `prefix_len`.
let prefix_len = self.m.min(4); // generate up to length 4 sequentially, then parallelize
let mut prefixes = Vec::new();
self.generate_prefixes(0, 0, 0, &mut Vec::new(), prefix_len, &mut prefixes);
let results: Vec<(usize, Vec<Vec<usize>>)> = prefixes.into_par_iter().map(|(mask_used, mut prefix)| {
let mut counts = vec![vec![0; self.m]; self.n];
let mut valid_count = 0;
self.dfs(prefix_len, mask_used, &mut prefix, &mut valid_count, &mut counts);
(valid_count, counts)
}).collect();
let mut total_valid = 0;
let mut total_counts = vec![vec![0; self.m]; self.n];
for (v_count, counts) in results {
total_valid += v_count;
for i in 0..self.n {
for j in 0..self.m {
total_counts[i][j] += counts[i][j];
}
}
}
(total_valid, total_counts)
}
fn generate_prefixes(&self, g2_idx: usize, mask_used: u32, current_allowed: u32, current: &mut Vec<usize>, target_len: usize, out: &mut Vec<(u32, Vec<usize>)>) {
if g2_idx == target_len {
out.push((mask_used, current.clone()));
return;
}
let mut allowed = self.allowed[g2_idx];
// Minor optimization: if we have (M - g2_idx) elements left, and we need to cover (N - bits_set) elements,
// we can prune if it's impossible.
let needed = self.n as u32 - mask_used.count_ones();
let remaining = (self.m - g2_idx) as u32;
if needed > remaining {
return; // cannot form surjective mapping
}
while allowed > 0 {
let bit = allowed & (!allowed + 1); // get lowest set bit
allowed &= !bit;
let g1_idx = bit.trailing_zeros() as usize;
current.push(g1_idx);
self.generate_prefixes(g2_idx + 1, mask_used | bit, 0, current, target_len, out);
current.pop();
}
}
fn dfs(&self, g2_idx: usize, mask_used: u32, current: &mut Vec<usize>, valid_count: &mut usize, counts: &mut Vec<Vec<usize>>) {
if g2_idx == self.m {
if mask_used.count_ones() as usize != self.n {
return;
}
// Check ceremonies
for (pairs, beams) in &self.ceremonies {
let mut actual_beams = 0;
for &(g2, g1) in pairs {
if current[g2] == g1 {
actual_beams += 1;
}
}
if actual_beams != *beams {
return;
}
}
// Valid scenario
*valid_count += 1;
for (g2, &g1) in current.iter().enumerate() {
counts[g1][g2] += 1;
}
return;
}
let needed = self.n as u32 - mask_used.count_ones();
let remaining = (self.m - g2_idx) as u32;
if needed > remaining {
return;
}
let mut allowed = self.allowed[g2_idx];
// To aggressively prune with ceremonies, if we are at the end of checking pairs, but that's complex since we evaluate g2 sequentially.
// It's usually fast enough to just prune at the leaf for M=11.
while allowed > 0 {
let bit = allowed & (!allowed + 1);
allowed &= !bit;
let g1_idx = bit.trailing_zeros() as usize;
current.push(g1_idx);
self.dfs(g2_idx + 1, mask_used | bit, current, valid_count, counts);
current.pop();
}
}
}
async fn solve_handler(Json(payload): Json<SolveRequest>) -> Result<Json<SolveResponse>, axum::http::StatusCode> {
// We process the solve in a blocking thread to avoid blocking the async runtime
let (solver, _g1_map, _g2_map) = Solver::new(&payload).map_err(|_| axum::http::StatusCode::BAD_REQUEST)?;
let (total, counts) = tokio::task::spawn_blocking(move || {
solver.solve()
}).await.unwrap();
let mut data = vec![vec![0.0; payload.group2.len()]; payload.group1.len()];
if total > 0 {
for i in 0..payload.group1.len() {
for j in 0..payload.group2.len() {
data[i][j] = counts[i][j] as f64 / total as f64;
}
}
}
Ok(Json(SolveResponse {
possibilities: total,
grid_data: GridData {
index: payload.group1,
columns: payload.group2,
data,
},
}))
}
#[tokio::main]
async fn main() {
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
let app = Router::new()
.route("/api/solve", post(solve_handler))
.fallback_service(ServeDir::new("../ayto").fallback(ServeFile::new("../ayto/index.html")))
.layer(cors);
let addr = "0.0.0.0:8080";
println!("Starting server on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}