You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

239 lines
6.8 KiB

mod deserializers;
use deserializers::*;
use nalgebra::DMatrix;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::SparseEntry;
use serde::{Deserialize, Serialize};
use std::fs::File;
/// One airport record, deserialized from `data/airports.csv` in `main`.
///
/// Field names double as CSV header names, except where `#[serde(rename = ...)]`
/// maps a differently-named column (`IATA`, `ICAO`, `type`).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Airport {
/// Numeric airport identifier from the CSV.
airport_id: i32,
/// Airport name.
name: String,
/// City the airport is listed under.
city: String,
/// Country the airport is listed under.
country: String,
/// IATA code (CSV column `IATA`). This is the key users type at the prompt and the
/// value matched against route endpoints.
#[serde(rename = "IATA")]
iata_code: String,
/// ICAO code (CSV column `ICAO`), kept as raw text.
#[serde(rename = "ICAO")]
icao_code: String,
/// Latitude — NOTE(review): presumably decimal degrees; confirm against the data source.
lat: f64,
/// Longitude — NOTE(review): presumably decimal degrees; confirm against the data source.
lon: f64,
/// Altitude — NOTE(review): units not established in this file.
altitude: i32,
/// Timezone column, kept as raw text.
timezone: String,
/// DST column, kept as raw text.
dst: String,
/// Tz database column, kept as raw text.
tzdb: String,
/// Airport type (CSV column `type`), kept as raw text.
#[serde(rename = "type")]
port_type: String,
/// `source` column, kept as raw text.
source: String,
}
/// One flight route, deserialized from `data/new.csv` in `main`.
///
/// `src_airport` / `dest_airport` are airport codes. They are matched against
/// `Airport::iata_code` and used as keys of the airport index map. `cost` becomes the
/// edge weight in the route matrix.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Route {
/// Airline column, kept as text.
airline: String,
/// Parsed with `error_as_none` from the `deserializers` module (not shown here);
/// presumably a value that fails to parse becomes `None` — confirm.
#[serde(deserialize_with = "error_as_none")]
airline_id: Option<i32>,
/// Code of the departure airport.
src_airport: String,
/// See `airline_id` for the `error_as_none` handling.
#[serde(deserialize_with = "error_as_none")]
src_airport_id: Option<i32>,
/// Code of the arrival airport.
dest_airport: String,
/// See `airline_id` for the `error_as_none` handling.
#[serde(deserialize_with = "error_as_none")]
dest_airport_id: Option<i32>,
/// Parsed with `default_as_none` (not shown here); presumably an empty/default value
/// becomes `None` — confirm.
#[serde(deserialize_with = "default_as_none")]
codeshare: Option<String>,
/// Number of stops; see `airline_id` for the `error_as_none` handling.
#[serde(deserialize_with = "error_as_none")]
stops: Option<i32>,
/// Planes column, kept as raw text.
planes: String,
/// Cost of the route; used as the edge weight for shortest-path search.
cost: f64,
}
use std::collections::{BTreeSet, HashMap, HashSet};
/// Assigns a dense, zero-based index to every airport code appearing in `routes`.
///
/// Both source and destination codes are collected. Indices follow the lexicographic
/// order of the codes (they are gathered in a `BTreeSet`), so the mapping is
/// deterministic for a given set of routes.
fn airport_map(routes: &[Route]) -> HashMap<String, usize> {
    // Borrow the codes while deduplicating; only the unique survivors are cloned.
    let codes: BTreeSet<&str> = routes
        .iter()
        .flat_map(|route| [route.src_airport.as_str(), route.dest_airport.as_str()])
        .collect();
    codes
        .into_iter()
        .enumerate()
        .map(|(index, code)| (code.to_owned(), index))
        .collect()
}
fn into_matrix(routes: &Vec<Route>, airport_map: &HashMap<String, usize>) -> CsrMatrix<f64> {
let num_airports = airport_map.len();
let mut adjacency_matrix = DMatrix::<f64>::zeros(num_airports, num_airports);
routes.iter().for_each(|route| {
let src_idx = airport_map[&route.src_airport];
let dest_idx = airport_map[&route.dest_airport];
adjacency_matrix[(src_idx, dest_idx)] = route.cost;
});
CsrMatrix::from(&adjacency_matrix)
}
/// Deserializes every record of a CSV stream into a `Vec<T>`.
///
/// The reader is built with `csv::Reader::from_reader`, i.e. with the csv crate's
/// default settings.
///
/// # Panics
/// Panics on the first record that fails to read or deserialize. The panic message
/// includes the underlying `csv` error.
fn load_data<R, T>(read: R) -> Vec<T>
where
    T: for<'a> Deserialize<'a>,
    R: std::io::Read,
{
    let mut rdr = csv::Reader::from_reader(read);
    rdr.deserialize::<T>()
        .map(|result| {
            result.unwrap_or_else(|err| panic!("failed to parse CSV record: {}", err))
        })
        .collect()
}
/// Neighbour lookup for graphs stored as adjacency matrices.
trait Adjacent {
    /// Returns the column indices of all explicitly stored entries in row `src`.
    fn all_adjacent(&self, src: usize) -> Vec<usize>;
}

impl Adjacent for CsrMatrix<f64> {
    fn all_adjacent(&self, src: usize) -> Vec<usize> {
        let lane = self.row(src);
        lane.col_indices().to_vec()
    }
}
/// Finds the cheapest path from `start` to `end` in a weighted adjacency matrix using
/// Dijkstra's algorithm, with O(V²) node selection via `min_distance`.
///
/// Every traversed edge costs its stored weight plus a fixed per-hop penalty of
/// `10_000_000`. As long as accumulated weights stay well below that penalty, the
/// search therefore prefers fewer legs first and lower total weight second.
///
/// Returns the node indices along the path, with `start` first and `end` last. Because
/// the path is read off the predecessor tree, no node appears twice.
///
/// Returns `None` when `end` is unreachable from `start`, or when either index is out
/// of range for `graph`.
fn dijkstras(graph: &CsrMatrix<f64>, start: usize, end: usize) -> Option<Vec<usize>> {
    // Fixed cost added to every edge so itineraries with fewer legs win.
    const HOP_PENALTY: f64 = 10_000_000.0;

    let num_nodes = graph.nrows();
    if start >= num_nodes || end >= num_nodes {
        return None;
    }

    let mut dist: Vec<f64> = vec![f64::MAX; num_nodes];
    let mut prev: Vec<Option<usize>> = vec![None; num_nodes];
    let mut visited: Vec<bool> = vec![false; num_nodes];
    dist[start] = 0.0;

    for _ in 0..num_nodes {
        let u = min_distance(&dist, &visited);
        // The closest unvisited node is unreachable, so every remaining one is too;
        // relaxing from them can never improve a distance.
        if dist[u] == f64::MAX {
            break;
        }
        visited[u] = true;
        if u == end {
            break;
        }
        for v in graph.all_adjacent(u) {
            let weight = match graph.get_entry(u, v) {
                Some(SparseEntry::NonZero(w)) => *w,
                // all_adjacent only yields stored columns, so this is not expected.
                _ => continue,
            };
            let alt = dist[u] + weight + HOP_PENALTY;
            if alt < dist[v] {
                dist[v] = alt;
                prev[v] = Some(u);
            }
        }
    }

    if dist[end] == f64::MAX {
        None
    } else {
        Some(reconstruct_path(prev, start, end))
    }
}
/// Returns the index of the unvisited node with the smallest tentative distance.
///
/// - Ties go to the highest index, because the comparison is `<=`.
/// - Nodes whose distance is still `f64::MAX` are eligible.
/// - If no node is eligible (everything visited, or `dist` empty), the result falls
///   back to `0`. Callers must not rely on that value being an unvisited node.
///
/// # Panics
/// Panics if `visited` is shorter than `dist`.
fn min_distance(dist: &[f64], visited: &[bool]) -> usize {
    let mut best = f64::MAX;
    let mut best_index = 0;
    for (i, &d) in dist.iter().enumerate() {
        if !visited[i] && d <= best {
            best = d;
            best_index = i;
        }
    }
    best_index
}
/// Walks the `prev` predecessor links backwards from `end` and returns the nodes in
/// forward order.
///
/// The walk stops as soon as `start` has been appended, or when a node has no
/// predecessor. The `start` check is skipped for `end` itself. An incomplete
/// predecessor chain therefore yields a partial path that still ends at `end`.
fn reconstruct_path(prev: Vec<Option<usize>>, start: usize, end: usize) -> Vec<usize> {
    // Items are (node, is_end); the flag keeps `end` itself from terminating the walk.
    let mut path: Vec<usize> = std::iter::successors(Some((end, true)), |&(node, is_end)| {
        if !is_end && node == start {
            None
        } else {
            prev[node].map(|pred| (pred, false))
        }
    })
    .map(|(node, _)| node)
    .collect();
    path.reverse();
    path
}
/// Translates a sequence of matrix indices back into airport codes.
///
/// `map` is inverted (index -> code) once per call. Each index in `connections` is then
/// looked up in order.
///
/// # Panics
/// Panics if an index in `connections` has no corresponding entry in `map`.
fn list_airports(connections: Vec<usize>, map: &HashMap<String, usize>) -> Vec<String> {
    let mut code_by_index: HashMap<usize, String> = HashMap::new();
    for (code, &index) in map {
        code_by_index.insert(index, code.clone());
    }

    let mut codes = Vec::with_capacity(connections.len());
    for index in &connections {
        codes.push(code_by_index.get(index).unwrap().clone());
    }
    codes
}
use text_io::read;
/// Repeatedly prompts with `s` until the user enters an IATA code present in
/// `airports`, then returns a clone of the matching airport.
///
/// Input is read with `text_io::read!` and compared exactly (case-sensitive) against
/// `Airport::iata_code`. Invalid input prints an error and re-prompts.
fn read_iata_code(s: &str, airports: &[Airport]) -> Airport {
    loop {
        print!("{}", s);
        // `print!` does not flush; without this the prompt can stay buffered until
        // after the user has typed. A failed flush only affects the prompt, so ignore it.
        let _ = std::io::Write::flush(&mut std::io::stdout());
        let input: String = read!();
        match airports.iter().find(|a| a.iata_code == input) {
            Some(airport) => return airport.clone(),
            None => println!("Invalid IATA code! Try again."),
        }
    }
}
/// Prompts for an origin and a destination IATA code and returns their matrix indices
/// as `(from, to)`.
///
/// # Panics
/// Panics if a chosen airport's code is absent from `map`. In `main`, `airports` is
/// filtered to codes that appear in the routes and `map` is built from those same
/// routes, so this signals inconsistent inputs.
fn get_request(airports: &Vec<Airport>, map: &HashMap<String, usize>) -> (usize, usize) {
    let from = read_iata_code("From IATA code: ", airports);
    let to = read_iata_code("To IATA code: ", airports);
    let index_of = |airport: &Airport| -> usize {
        *map.get(&airport.iata_code).unwrap_or_else(|| {
            panic!("airport {} has no index in the route matrix", airport.iata_code)
        })
    };
    (index_of(&from), index_of(&to))
}
use itertools::Itertools;
/// Returns the stored weight of each leg along `route`, looked up in `matrix`.
///
/// A route of `k` nodes yields `k - 1` weights; fewer than two nodes yields an empty vector.
///
/// # Panics
/// Panics if two consecutive nodes are not connected by a stored entry in `matrix`,
/// or if an index is out of bounds.
fn get_weight(route: &[usize], matrix: &CsrMatrix<f64>) -> Vec<f64> {
    route
        .iter()
        .tuple_windows()
        .map(|(&src, &dest)| match matrix.get_entry(src, dest) {
            Some(SparseEntry::NonZero(weight)) => *weight,
            _ => panic!("Invalid path! No stored edge from {} to {}", src, dest),
        })
        .collect()
}
/// Loads routes and airports from `data/`, builds the route graph, asks the user for
/// two airports, and prints the cheapest path between them.
///
/// Output: the per-leg weights followed by the airport codes along the path, or a
/// message when no route connects the two airports.
fn main() {
    // Read routes from CSV: Vec<Route>.
    println!("Loading route data!");
    let route_file = File::open("data/new.csv").expect("failed to open data/new.csv");
    let routes: Vec<Route> = load_data(route_file);

    // Only airports that appear in at least one route can be offered to the user.
    let airports_with_route: HashSet<&str> = routes
        .iter()
        .flat_map(|route| [route.src_airport.as_str(), route.dest_airport.as_str()])
        .collect();

    println!("Loading airport data!");
    let airports_file =
        File::open("data/airports.csv").expect("failed to open data/airports.csv");
    let airports: Vec<Airport> = load_data(airports_file)
        .into_iter()
        .filter(|airport: &Airport| airports_with_route.contains(airport.iata_code.as_str()))
        .collect();

    // Vec<Route> -> HashMap<airport code, matrix index>
    println!("Generating airport mapping!");
    let map = airport_map(&routes);

    // Vec<Route> + index map -> sparse matrix of flight costs
    println!("Generating route matrix!");
    let matrix = into_matrix(&routes, &map);

    let (from, to) = get_request(&airports, &map);
    match dijkstras(&matrix, from, to) {
        Some(path) => {
            println!("{:?}", get_weight(&path, &matrix));
            println!("{:?}", list_airports(path, &map));
        }
        None => println!("No route found between the selected airports."),
    }
}