diff --git a/src/dump_git.rs b/src/dump_git.rs index fcc2759..13513c2 100644 --- a/src/dump_git.rs +++ b/src/dump_git.rs @@ -1,11 +1,7 @@ -use std::{ - collections::HashSet, - path::{Path, PathBuf}, - time::Duration, -}; +use std::{collections::HashSet, path::Path, sync::Arc, time::Duration}; use anyhow::{bail, Context, Result}; -use hyper::{Client, StatusCode}; +use hyper::{Client, Method, Request, StatusCode}; use hyper_tls::HttpsConnector; use regex::Regex; use tokio::{ @@ -13,7 +9,10 @@ use tokio::{ time::sleep, }; -use crate::git_parsing::{parse_hash, parse_head, parse_log, parse_object, GitObject}; +use crate::{ + git_parsing::{parse_hash, parse_head, parse_log, parse_object, GitObject}, + Args, +}; lazy_static::lazy_static! { static ref REGEX_OBJECT_PATH: Regex = Regex::new(r"[\da-f]{2}/[\da-f]{38}").unwrap(); @@ -42,7 +41,9 @@ struct DownloadedFile { pub tx: UnboundedSender, } -pub async fn download_all(base_url: String, base_path: PathBuf, max_task_count: u16) { +pub async fn download_all(args: Arc) { + let base_url = &args.url; + let base_path = &args.path; let mut cache = HashSet::::new(); // TODO: try out unbounded channel too @@ -75,8 +76,9 @@ pub async fn download_all(base_url: String, base_path: PathBuf, max_task_count: let url = format!("{}{}", &base_url, &message.path); let base_path = base_path.clone(); + let cloned_args = args.clone(); let handle = tokio::spawn(async move { - let file_bytes = match download(&url).await { + let file_bytes = match download(&url, cloned_args).await { Ok(content) => content, Err(e) => { println!("Error while downloading file {url}: {}", e); @@ -99,7 +101,7 @@ pub async fn download_all(base_url: String, base_path: PathBuf, max_task_count: threads.push(handle); - while threads.len() >= (max_task_count as usize) { + while threads.len() >= (args.tasks as usize) { // sleep sleep(Duration::from_millis(10)).await; @@ -109,9 +111,25 @@ pub async fn download_all(base_url: String, base_path: PathBuf, max_task_count: } } -async fn download(url: &str) -> Result> { +async fn download(url: &str, args: Arc) -> Result> { let client = Client::builder().build::<_, hyper::Body>(HttpsConnector::new()); - let resp = client.get(url.parse().unwrap()).await; + let req = Request::builder() + .method(Method::GET) + .uri(url) + .header( + "User-Agent", + args.user_agent + .clone() + .unwrap_or( + "Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)" + .into(), + ) + .clone(), + ) + .body(hyper::Body::empty()) + .expect("Failed to build the request"); + + let resp = client.request(req).await; match resp { Ok(resp) => match resp.status() { StatusCode::OK => { diff --git a/src/main.rs b/src/main.rs index 5f996b7..650a744 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::path::PathBuf; +use std::{path::PathBuf, sync::Arc}; use clap::Parser; @@ -7,11 +7,13 @@ mod git_parsing; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] -struct Cli { +pub struct Args { /// The url of the exposed .git directory #[arg()] url: String, + #[arg(short, long)] + user_agent: Option, /// The directory to download to #[arg(default_value = "git-dumped")] path: PathBuf, @@ -23,13 +25,13 @@ struct Cli { #[tokio::main] async fn main() -> Result<(), Box> { - let args = Cli::parse(); + let args = Args::parse(); // println!("URL: {url}"); // println!("PATH: {path}"); std::fs::create_dir_all(args.path.join(".git"))?; - dump_git::download_all(args.url.clone(), args.path, args.tasks).await; + dump_git::download_all(Arc::new(args)).await; Ok(()) }