main.rs (9856B)
1 #![crate_name = "wala"] 2 3 use tiny_http::{ 4 Server, 5 ServerConfig, 6 Request, 7 Header, 8 Method, 9 }; 10 use mime::Mime; 11 use std::net::{Ipv4Addr, SocketAddrV4}; 12 use std::str::FromStr; 13 //use std::path::{PathBuf, Path}; 14 use std::path::Path; 15 use std::fs::{ 16 File, 17 create_dir_all, 18 }; 19 use std::error::Error; 20 use std::fmt; 21 use std::io::{ 22 copy as io_copy, 23 // Read, 24 Seek, 25 empty, 26 }; 27 use std::time::Duration; 28 29 use std::sync::Arc; 30 use std::sync::atomic::{AtomicBool, Ordering}; 31 32 use env_logger; 33 use ascii::AsciiStr; 34 //use signal_hook::flag; 35 //use signal_hook::consts; 36 37 use wala::auth::{ 38 AuthSpec, 39 AuthResult, 40 }; 41 42 use wala::record::{ 43 RequestResult, 44 RequestResultType, 45 }; 46 47 use wala::request::process_method; 48 use wala::response::{ 49 exec_response, 50 preflight_response, 51 }; 52 53 #[cfg(feature = "trace")] 54 use wala::trace::trace_request; 55 56 mod arg; 57 use arg::Settings; 58 59 use log::{info, error, warn}; 60 61 use tempfile::tempfile; 62 63 #[cfg(feature = "dev")] 64 use wala::auth::mock::auth_check as mock_auth_check; 65 66 #[cfg(feature = "pgpauth")] 67 //use wala::auth::pgp::auth_check as pgp_auth_check; 68 use wala::auth::pgp_sequoia::auth_check as pgp_auth_check; 69 70 71 #[derive(Debug)] 72 pub struct NoAuthError; 73 74 impl Error for NoAuthError { 75 fn description(&self) -> &str{ 76 "no auth" 77 } 78 } 79 80 impl fmt::Display for NoAuthError { 81 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 82 fmt.write_str(&self.to_string()) 83 } 84 } 85 86 fn exec_auth(auth_spec: AuthSpec, data: &File, data_length: usize) -> Option<AuthResult> { 87 #[cfg(feature = "dev")] 88 match mock_auth_check(&auth_spec, data, data_length) { 89 Ok(v) => { 90 return Some(v); 91 }, 92 Err(e) => { 93 error!("mock auth check error ({})", e) 94 }, 95 } 96 97 #[cfg(feature = "pgpauth")] 98 match pgp_auth_check(&auth_spec, data, data_length) { 99 Ok(v) => { 100 return Some(v); 101 }, 102 Err(e) => { 103 error!("pgp auth check error ({})", e) 104 }, 105 } 106 107 None 108 } 109 110 111 fn process_auth(auth_spec: AuthSpec, data: &File, data_length: usize) -> Option<AuthResult> { 112 if !auth_spec.valid() { 113 let r = AuthResult{ 114 identity: vec!(), 115 error: true, 116 }; 117 return Some(r); 118 } 119 exec_auth(auth_spec, data, data_length) 120 } 121 122 123 fn auth_from_headers(headers: &[Header], method: &Method) -> Option<AuthSpec> { 124 for h in headers { 125 let k = &h.field; 126 if k.equiv("Authorization") { 127 let v = &h.value; 128 let r = AuthSpec::from_str(v.as_str()); 129 match r { 130 Ok(v) => { 131 return Some(v); 132 }, 133 Err(e) => { 134 error!("malformed auth string: {} ({})", &h.value, e); 135 let r = AuthSpec{ 136 method: String::from(method.as_str()), 137 key: String::new(), 138 signature: String::new(), 139 }; 140 return Some(r); 141 } 142 } 143 } 144 } 145 None 146 } 147 148 149 fn process_request(req: &mut Request, f: &File) -> AuthResult { 150 let headers = req.headers(); 151 let method = req.method(); 152 153 let r: Option<AuthResult>; 154 155 r = match auth_from_headers(headers, method) { 156 Some(v) => { 157 process_auth(v, f, 0) 158 }, 159 _ => { 160 None 161 }, 162 }; 163 164 match r { 165 Some(v) => { 166 return v; 167 }, 168 _ => {}, 169 }; 170 171 // is not auth 172 AuthResult{ 173 identity: vec!(), 174 error: false, 175 } 176 } 177 178 fn process_meta(req: &Request, path: &Path, digest: Vec<u8>) -> Option<Mime> { 179 let headers = req.headers(); 180 let mut m: Option<mime::Mime> = None; 181 let mut n: Option<String> = None; 182 183 for h in headers { 184 let k = &h.field; 185 if k.equiv("Content-Type") { 186 let v = &h.value; 187 m = match Mime::from_str(v.as_str()) { 188 Err(e) => { 189 error!("invalid mime type ({})", e); 190 return None; 191 }, 192 Ok(v) => { 193 Some(v) 194 }, 195 }; 196 } else if k.equiv("X-Filename") { 197 let v = &h.value; 198 let p = Path::new(v.as_str()); 199 let fp = p.to_str().unwrap(); 200 n = Some(String::from(fp)); 201 } 202 } 203 204 #[cfg(feature = "meta")] 205 match m { 206 Some(v) => { 207 match wala::meta::register_type(path, &digest, v) { 208 Err(e) => { 209 error!("could not register content type: {}", &e); 210 }, 211 _ => {}, 212 }; 213 }, 214 _ => {}, 215 }; 216 217 #[cfg(feature = "meta")] 218 match n { 219 Some(v) => { 220 match wala::meta::register_filename(path, &digest, v) { 221 Err(e) => { 222 error!("could not register content type: {}", &e); 223 }, 224 _ => {}, 225 }; 226 }, 227 _ => {}, 228 }; 229 230 None 231 } 232 233 234 fn main() { 235 env_logger::init(); 236 237 let settings = Settings::from_args(); 238 let base_path = settings.dir.as_path(); 239 240 let spool_path = base_path.join("spool"); 241 //let mut spool_ok = false; 242 // 243 244 #[cfg(feature = "trace")] 245 { 246 match create_dir_all(&spool_path) { 247 Ok(_) => { 248 // spool_ok = true; 249 }, 250 Err(e) => { 251 warn!("spool directory could not be created: {:?}", e); 252 }, 253 }; 254 } 255 256 info!("Using data dir: {:?}", &base_path); 257 258 let ip_addr = Ipv4Addr::from_str(&settings.host).unwrap(); 259 let tcp_port: u16 = settings.port; 260 let sock_addr = SocketAddrV4::new(ip_addr, tcp_port); 261 let srv_cfg = ServerConfig{ 262 addr: sock_addr, 263 ssl: None, 264 }; 265 let srv = Server::new(srv_cfg).unwrap(); 266 267 let term = Arc::new(AtomicBool::new(false)); 268 signal_hook::flag::register(signal_hook::consts::SIGINT, Arc::clone(&term)).unwrap(); 269 270 #[cfg(feature = "docker")] 271 signal_hook::flag::register(signal_hook::consts::SIGTERM, Arc::clone(&term)).unwrap(); 272 273 const LOOP_TIMEOUT: Duration = Duration::new(1, 0); 274 275 while !term.load(Ordering::Relaxed) { 276 277 let b = srv.recv_timeout(LOOP_TIMEOUT); 278 //let mut hasreq: Option<Request>; 279 let hasreq: Option<Request>; 280 match b { 281 Ok(v) => hasreq = v, 282 Err(e) => { 283 error!("{}", e); 284 break; 285 } 286 }; 287 let mut req: Request; 288 match hasreq { 289 Some(v) => { 290 req = v; 291 }, 292 None => { 293 continue 294 } 295 }; 296 297 let method = req.method().clone(); 298 match &method { 299 Method::Options => { 300 preflight_response(req); 301 continue; 302 }, 303 _ => {}, 304 } 305 306 let url = String::from(&req.url()[1..]); 307 let expected_size = match req.body_length() { 308 Some(v) => { 309 v 310 }, 311 None => { 312 0 313 }, 314 }; 315 let f = req.as_reader(); 316 //let mut path = base_path.clone(); 317 let path = base_path.clone(); 318 let mut res: AuthResult = AuthResult{ 319 identity: vec!(), 320 error: false, 321 }; 322 let rw: Option<File> = match tempfile() { 323 Ok(mut v) => { 324 match io_copy(f, &mut v) { 325 Ok(_) => { 326 }, 327 Err(e) => { 328 error!("could not copy file: {:?} ({})", path, e); 329 continue; 330 }, 331 }; 332 match v.rewind() { 333 Ok(_) => { 334 }, 335 Err(e) => { 336 error!("could not rewind file for request: {:?} ({})", path, e); 337 continue; 338 }, 339 }; 340 res = process_request(&mut req, &mut v); 341 match v.rewind() { 342 Ok(_) => { 343 }, 344 Err(e) => { 345 error!("could not rewind file for return: {:?} ({})", path, e); 346 continue; 347 }, 348 }; 349 Some(v) 350 }, 351 Err(e) => { 352 error!("tempfile error: {:?} ({})", path, e); 353 None 354 }, 355 }; 356 357 //let mut result: RequestResult; 358 let result: RequestResult; 359 match rw { 360 Some(v) => { 361 result = process_method(&method, url, v, expected_size, &path, res); 362 }, 363 None => { 364 let v = empty(); 365 result = process_method(&method, url, v, expected_size, &path, res); 366 }, 367 }; 368 369 match &result.typ { 370 RequestResultType::Changed => { 371 let digest_hex = result.v.clone().unwrap(); 372 let digest = hex::decode(&digest_hex).unwrap(); 373 process_meta(&req, &path, digest); 374 }, 375 RequestResultType::Found => { 376 377 }, 378 _ => {}, 379 } 380 381 #[cfg(feature="trace")] 382 { 383 for h in req.headers() { 384 if h.field.equiv("X-Wala-Trace") { 385 let mut identity = false; 386 if h.value.eq_ignore_ascii_case(AsciiStr::from_ascii("identity").unwrap()) { 387 identity = true; 388 } 389 trace_request(&spool_path, &result, identity); 390 } 391 } 392 } 393 394 exec_response(req, result); 395 396 } 397 }