From 4d51571e9bbbc0260f2189ffe64bb123cab621e1 Mon Sep 17 00:00:00 2001 From: Louis Vallat Date: Sat, 19 Feb 2022 17:56:34 +0100 Subject: [PATCH] Added the HTTPs client to OVHClient to remove this lengthy argument always passed alongside it to simplify the arguments required for functions Signed-off-by: Louis Vallat --- src/main.rs | 45 ++++++++------ src/records.rs | 156 +++++++++++++++++++++++-------------------------- src/utils.rs | 41 +++++++------ 3 files changed, 125 insertions(+), 117 deletions(-) diff --git a/src/main.rs b/src/main.rs index 9bafa32..d8fc59a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,5 @@ use hyper_tls::HttpsConnector; -use hyper::{Client, client::HttpConnector}; +use hyper::Client; use walkdir::WalkDir; use crate::{ utils::{OVHClient, get_delta, get_subdomain_zone_domain_from_pem, compute_certificate, get_hash_from_cert}, @@ -12,31 +12,36 @@ mod utils; mod records; async fn scan_and_update_whole_folder(base_cert_dir: &str, list: &List, - ovh_client: &OVHClient, issuer_hash: &str, - client: &Client>) { + ovh_client: &OVHClient, issuer_hash: &str) { let interesting_records = vec!["A", "AAAA", "MX", "CNAME"]; + for entry in WalkDir::new(base_cert_dir).into_iter().filter_map(|e| e.ok()) { if !entry.path().ends_with("cert.pem") { continue; } + println!("Found certificate! Located at '{}'.", entry.path().display()); - let (subdomain, zone, domain) = get_subdomain_zone_domain_from_pem(entry.path(), base_cert_dir, list); - let (subdomain, zone, domain) = (subdomain.as_str(), zone.as_str(), domain.as_str()); + let (subdomain, zone, domain) = + get_subdomain_zone_domain_from_pem(entry.path(), base_cert_dir, list); + let (subdomain, zone, domain) = + (subdomain.as_str(), zone.as_str(), domain.as_str()); println!("Computing domain '{}', which has domain '{}' and subdomain '{}'.", domain, zone, subdomain); - let records = get_all_records_from_zone(&ovh_client, &client, zone, subdomain) + + let records = get_all_records_from_zone(&ovh_client, zone, subdomain) .await; if !records.iter().any(|r| interesting_records.contains(&r.field_type.as_str())) { - println!("\tDomain '{}' has no known interesting record. Flushing.", domain); - flush_tlsa_record_for_subdomain(&ovh_client, &client, zone, subdomain).await; + println!("\tDomain '{}' has no known interesting record. Flushing.", + domain); + flush_tlsa_record_for_subdomain(&ovh_client, zone, subdomain).await; println!(); continue; } - compute_certificate(ovh_client, client, &entry.path(), base_cert_dir, issuer_hash, list).await; + compute_certificate(ovh_client, &entry.path(), base_cert_dir, issuer_hash, + list).await; println!(); } } async fn watch_folder(base_cert_dir: &str, ovh_client: &OVHClient, list: &List, - client: &Client>, rx: &Receiver, issuer_hash: &str) { loop { match rx.recv() { @@ -44,15 +49,18 @@ async fn watch_folder(base_cert_dir: &str, ovh_client: &OVHClient, list: &List, if !path.ends_with("cert.pem") || (op != Op::CLOSE_WRITE && op != Op::REMOVE) { continue; } if op == Op::CLOSE_WRITE { - println!("Certificate '{}' modified or created. Updating.", path.display()); - compute_certificate(ovh_client, client, &path.as_path(), base_cert_dir, issuer_hash, list).await; + println!("Certificate '{}' modified or created. Updating.", + path.display()); + compute_certificate(ovh_client, &path.as_path(), base_cert_dir, + issuer_hash, list).await; } if op == Op::REMOVE { println!("Certificate '{}' deleted. Flushing.", path.display()); let (subdomain, zone, _) = - get_subdomain_zone_domain_from_pem(&path.as_path(), base_cert_dir, list); + get_subdomain_zone_domain_from_pem(&path.as_path(), + base_cert_dir, list); let (subdomain, zone) = (subdomain.as_str(), zone.as_str()); - flush_tlsa_record_for_subdomain(ovh_client, client, zone, + flush_tlsa_record_for_subdomain(ovh_client, zone, subdomain).await; } println!(); @@ -65,7 +73,6 @@ async fn watch_folder(base_cert_dir: &str, ovh_client: &OVHClient, list: &List, #[tokio::main] async fn main() { - let client = Client::builder().build::<_, hyper::Body>(HttpsConnector::new()); let list = List::fetch().unwrap(); let base_cert_dir = "/etc/nginx/certs/"; @@ -79,20 +86,22 @@ async fn main() { .expect("Missing key value 'OVH_CONSUMER_KEY'."), endpoint: env::var("OVH_API_ENDPOINT") .expect("Missing key value 'OVH_API_ENDPOINT'."), + client: + Client::builder().build::<_, hyper::Body>(HttpsConnector::new()), delta: 0 }; let (tx, rx) = channel(); let mut watcher = raw_watcher(tx).unwrap(); watcher.watch(base_cert_dir, RecursiveMode::Recursive).unwrap(); - ovh_client.delta = get_delta(&ovh_client, &client).await; + ovh_client.delta = get_delta(&ovh_client).await; println!("Delta time is {}", ovh_client.delta); println!("Starting initialization procedure."); scan_and_update_whole_folder(base_cert_dir, &list, &ovh_client, - issuer_hash.as_str(), &client).await; + issuer_hash.as_str()).await; println!("Initializing sequence finished. Entering sentinel mode."); - watch_folder(base_cert_dir, &ovh_client, &list, &client, &rx, issuer_hash.as_str()).await; + watch_folder(base_cert_dir, &ovh_client, &list, &rx, issuer_hash.as_str()).await; } diff --git a/src/records.rs b/src/records.rs index 43eb529..2aa8e0c 100644 --- a/src/records.rs +++ b/src/records.rs @@ -1,5 +1,4 @@ -use hyper::{Method, Client, client::HttpConnector}; -use hyper_tls::HttpsConnector; +use hyper::Method; use serde::{Deserialize, Serialize}; use serde_json::{from_str, json}; use crate::utils::{OVHClient, build_request, body_to_str, get_tlsa_subdomain}; @@ -17,128 +16,121 @@ pub struct Record { } -pub async fn get_record_from_zone(ovh_client: &OVHClient, - client: &Client>, - zone: &str, id: u64) -> Record { - let req = build_request(ovh_client, &Method::GET, - format!("/domain/zone/{}/record/{}", - zone, id).as_str(), ""); - let res = client.request(req).await.unwrap(); - assert!(res.status().is_success()); - return from_str(body_to_str(res.into_body()).await.as_str()).unwrap(); -} - -pub async fn get_records_from_zone(ovh_client: &OVHClient, - client: &Client>, - zone: &str, field: &str, subdomain: &str) -> Vec { +pub async fn get_record_from_zone(ovh_client: &OVHClient, zone: &str, id: u64) + -> Record { let req = build_request(ovh_client, &Method::GET, - format!("/domain/zone/{}/record?fieldType={}&subDomain={}", - zone, field, subdomain).as_str(), ""); - let res = client.request(req).await.unwrap(); - assert!(res.status().is_success()); - let array: Vec = from_str(body_to_str(res.into_body()).await.as_str()).unwrap(); - let mut records = vec![]; - for i in array { - records.push(get_record_from_zone(ovh_client, client, zone, i).await); - } - return records; -} - -pub async fn get_all_records_from_zone(ovh_client: &OVHClient, - client: &Client>, - zone: &str, subdomain: &str) -> Vec { - let req = build_request(ovh_client, &Method::GET, - format!("/domain/zone/{}/record?&subDomain={}", - zone, subdomain).as_str(), ""); - let res = client.request(req).await.unwrap(); - assert!(res.status().is_success()); - let array: Vec = from_str(body_to_str(res.into_body()).await.as_str()).unwrap(); - let mut records = vec![]; - for i in array { - records.push(get_record_from_zone(ovh_client, client, zone, i).await); - } - return records; -} - -pub async fn add_record_to_zone(ovh_client: &OVHClient, - client: &Client>, - zone: &str, record: &Record) -> Record { - let req = build_request(ovh_client, &Method::POST, - format!("/domain/zone/{}/record", zone).as_str(), - serde_json::to_string(record).unwrap().as_str()); - let res = client.request(req).await.unwrap(); + format!("/domain/zone/{}/record/{}", zone, id) + .as_str(), ""); + let res = ovh_client.client.request(req).await.unwrap(); assert!(res.status().is_success()); return from_str(body_to_str(res.into_body()).await.as_str()).unwrap(); } -pub async fn update_record_in_zone(ovh_client: &OVHClient, - client: &Client>, - zone: &str, record: &Record) { +pub async fn get_records_from_zone(ovh_client: &OVHClient, zone: &str, + field: &str, subdomain: &str) -> Vec { + let req = build_request(ovh_client, &Method::GET, + format!("/domain/zone/{}/record?fieldType={}&subDomain={}", + zone, field, subdomain).as_str(), ""); + let res = ovh_client.client.request(req).await.unwrap(); + assert!(res.status().is_success()); + let array: Vec = from_str(body_to_str(res.into_body()).await.as_str()) + .unwrap(); + let mut records = vec![]; + for i in array { + records.push(get_record_from_zone(ovh_client, zone, i).await); + } + return records; +} + +pub async fn get_all_records_from_zone(ovh_client: &OVHClient, zone: &str, + subdomain: &str) -> Vec { + let req = build_request(ovh_client, &Method::GET, + format!("/domain/zone/{}/record?&subDomain={}", + zone, subdomain).as_str(), ""); + let res = ovh_client.client.request(req).await.unwrap(); + assert!(res.status().is_success()); + let array: Vec = from_str(body_to_str(res.into_body()).await.as_str()) + .unwrap(); + let mut records = vec![]; + for i in array { + records.push(get_record_from_zone(ovh_client, zone, i).await); + } + return records; +} + +pub async fn add_record_to_zone(ovh_client: &OVHClient, zone: &str, record: &Record) + -> Record { + let req = build_request(ovh_client, &Method::POST, + format!("/domain/zone/{}/record", zone).as_str(), + serde_json::to_string(record).unwrap().as_str()); + let res = ovh_client.client.request(req).await.unwrap(); + assert!(res.status().is_success()); + return from_str(body_to_str(res.into_body()).await.as_str()).unwrap(); +} + +pub async fn update_record_in_zone(ovh_client: &OVHClient, zone: &str, + record: &Record) { let req = build_request(ovh_client, &Method::PUT, - format!("/domain/zone/{}/record/{}", zone, record.id).as_str(), - json!({"subDomain": record.sub_domain, - "target": record.target, "ttl": record.ttl}).to_string().as_str()); - let res = client.request(req).await.unwrap(); + format!("/domain/zone/{}/record/{}", zone, record.id) + .as_str(), json!({"subDomain": record.sub_domain, + "target": record.target, "ttl": record.ttl}) + .to_string().as_str()); + let res = ovh_client.client.request(req).await.unwrap(); assert!(res.status().is_success()); } -pub async fn delete_record_from_zone(ovh_client: &OVHClient, - client: &Client>, - zone: &str, id: u64) { +pub async fn delete_record_from_zone(ovh_client: &OVHClient, zone: &str, id: u64) { let req = build_request(ovh_client, &Method::DELETE, format!("/domain/zone/{}/record/{}", zone, id).as_str(), ""); - let res = client.request(req).await.unwrap(); + let res = ovh_client.client.request(req).await.unwrap(); assert!(res.status().is_success()); } -pub async fn refresh_zone(ovh_client: &OVHClient, client: &Client>, - zone: &str) { +pub async fn refresh_zone(ovh_client: &OVHClient, zone: &str) { let req = build_request(ovh_client, &Method::POST, - format!("/domain/zone/{}/refresh", zone).as_str(), - ""); - let res = client.request(req).await.unwrap(); + format!("/domain/zone/{}/refresh", zone).as_str(), ""); + let res = ovh_client.client.request(req).await.unwrap(); assert!(res.status().is_success()); } -pub async fn flush_tlsa_record_for_subdomain(ovh_client: &OVHClient, - client: &Client>, - zone: &str, subdomain: &str) { - let mut tlsa = get_records_from_zone(ovh_client, client, zone, "TLSA", +pub async fn flush_tlsa_record_for_subdomain(ovh_client: &OVHClient, zone: &str, + subdomain: &str) { + let mut tlsa = get_records_from_zone(ovh_client, zone, "TLSA", get_tlsa_subdomain(subdomain, 25, "tcp") .as_str()).await; - tlsa.append(&mut get_records_from_zone(ovh_client, client, zone, "TLSA", + tlsa.append(&mut get_records_from_zone(ovh_client, zone, "TLSA", get_tlsa_subdomain(subdomain, 443, "tcp") .as_str()).await); for record in tlsa { - delete_record_from_zone(ovh_client, client, zone, record.id).await; + delete_record_from_zone(ovh_client, zone, record.id).await; } - refresh_zone(ovh_client, client, zone).await; + refresh_zone(ovh_client, zone).await; } -pub async fn update_tlsa_for_subdomain(ovh_client: &OVHClient, - client: &Client>, - zone: &str, subdomain: &str, hash: &str, +pub async fn update_tlsa_for_subdomain(ovh_client: &OVHClient, zone: &str, + subdomain: &str, hash: &str, issuer_hash: &str, port: u32, protocol: &str) { let tlsa_subdomain = get_tlsa_subdomain(subdomain, port, protocol); - let records = get_records_from_zone(ovh_client, client, zone, "TLSA", &tlsa_subdomain).await; + let records = get_records_from_zone(ovh_client, zone, "TLSA", &tlsa_subdomain) + .await; for record in records { - delete_record_from_zone(ovh_client, client, zone, record.id).await; + delete_record_from_zone(ovh_client, zone, record.id).await; } - add_record_to_zone(ovh_client, client, zone, &Record { + add_record_to_zone(ovh_client, zone, &Record { sub_domain: get_tlsa_subdomain(subdomain, port, protocol), target: format!("3 1 1 {}", hash).to_string(), field_type: "TLSA".to_string(), ttl: 0, id: 0 }).await; - add_record_to_zone(ovh_client, client, zone, &Record { + add_record_to_zone(ovh_client, zone, &Record { sub_domain: get_tlsa_subdomain(subdomain, port, protocol), target: format!("2 1 1 {}", issuer_hash).to_string(), field_type: "TLSA".to_string(), ttl: 0, id: 0 }).await; - refresh_zone(ovh_client, client, zone).await; + refresh_zone(ovh_client, zone).await; } diff --git a/src/utils.rs b/src/utils.rs index 6cab000..f568f3e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -10,6 +10,7 @@ pub struct OVHClient { pub app_secret: String, pub consumer_key: String, pub endpoint: String, + pub client: Client>, pub delta: i64 } @@ -34,7 +35,8 @@ pub fn get_signature(ovh_client: &OVHClient, method: &Method, query: &str, body: return format!("$1${}", sha1::Sha1::from(stringapi).hexdigest()); } -pub fn build_request(ovh_client: &OVHClient, method: &Method, uri: &str, body: &str) -> Request { +pub fn build_request(ovh_client: &OVHClient, method: &Method, uri: &str, body: &str) + -> Request { return Request::builder() .method(method) .uri(ovh_client.endpoint.clone() + uri) @@ -42,17 +44,18 @@ pub fn build_request(ovh_client: &OVHClient, method: &Method, uri: &str, body: & .header("X-Ovh-Timestamp", SystemTime::now().duration_since(UNIX_EPOCH) .unwrap().as_secs() as i64 + ovh_client.delta) .header("X-Ovh-Consumer", ovh_client.consumer_key.clone()) - .header("X-Ovh-Signature", get_signature(ovh_client, method, - format!("{}{}", - ovh_client.endpoint, uri).as_str(), body)) + .header("X-Ovh-Signature", + get_signature(ovh_client, method, + format!("{}{}", ovh_client.endpoint, uri) + .as_str(), body)) .header("Content-Type", "application/json;charset=utf-8") .body(Body::from(body.to_string())) .expect(uri); } -pub async fn get_delta(ovh_client: &OVHClient, client: &Client>) -> i64 { +pub async fn get_delta(ovh_client: &OVHClient) -> i64 { let req = build_request(&ovh_client, &Method::GET, "/auth/time", ""); - let res = client.request(req).await.unwrap(); + let res = ovh_client.client.request(req).await.unwrap(); let s = res.status(); assert!(s.is_success()); let b = body_to_str(res.into_body()).await.parse::().unwrap(); @@ -60,9 +63,12 @@ pub async fn get_delta(ovh_client: &OVHClient, client: &Client String { - let certificate = std::fs::read_to_string(path).expect("Something went wrong reading certificate file."); - let public_key = X509::from_pem(certificate.as_bytes()).unwrap().public_key().unwrap(); - let hashed = hash(MessageDigest::sha256(), &public_key.public_key_to_der().unwrap()).unwrap(); + let certificate = std::fs::read_to_string(path) + .expect("Something went wrong reading certificate file."); + let public_key = X509::from_pem(certificate.as_bytes()).unwrap().public_key() + .unwrap(); + let hashed = hash(MessageDigest::sha256(), &public_key.public_key_to_der().unwrap()) + .unwrap(); let mut res = String::with_capacity(64); for byte in &*hashed { write!(&mut res, "{:02x}", byte).unwrap(); @@ -77,22 +83,23 @@ pub fn get_tlsa_subdomain(subdomain: &str, port: u32, protocol: &str) -> String subdomain); } -pub async fn compute_certificate(ovh_client: &OVHClient, - client: &Client>, - path: &Path, base_cert_dir: &str, issuer_hash: &str, list: &List) { - let (subdomain, zone, domain) = get_subdomain_zone_domain_from_pem(path, base_cert_dir, list); - let (subdomain, zone, domain) = (subdomain.as_str(), zone.as_str(), domain.as_str()); - let records = get_all_records_from_zone(&ovh_client, &client, zone, subdomain).await; +pub async fn compute_certificate(ovh_client: &OVHClient, path: &Path, + base_cert_dir: &str, issuer_hash: &str, list: &List) { + let (subdomain, zone, domain) = + get_subdomain_zone_domain_from_pem(path, base_cert_dir, list); + let (subdomain, zone, domain) = + (subdomain.as_str(), zone.as_str(), domain.as_str()); + let records = get_all_records_from_zone(&ovh_client, zone, subdomain).await; let hash = get_hash_from_cert(path.to_str().unwrap()); if records.iter().any(|r| r.field_type == "A" || r.field_type == "AAAA" || r.field_type == "CNAME") { println!("\tDomain '{}' is associated with website. Updating.", domain); - update_tlsa_for_subdomain(&ovh_client, &client, zone, subdomain, + update_tlsa_for_subdomain(&ovh_client, zone, subdomain, hash.as_str(), issuer_hash, 443, "tcp").await; } if records.iter().any(|r| r.field_type == "MX" && r.sub_domain == subdomain) { println!("\tDomain '{}' is associated with mail. Updating.", domain); - update_tlsa_for_subdomain(&ovh_client, &client, zone, subdomain, + update_tlsa_for_subdomain(&ovh_client, zone, subdomain, hash.as_str(), issuer_hash, 25, "tcp").await; } }