diff --git a/ql/Cargo.lock b/ql/Cargo.lock index f1428bd8bea..8d355f264b1 100644 --- a/ql/Cargo.lock +++ b/ql/Cargo.lock @@ -96,6 +96,7 @@ dependencies = [ "js-sys", "num-integer", "num-traits", + "serde", "time", "wasm-bindgen", "winapi", @@ -116,6 +117,23 @@ dependencies = [ "vec_map", ] +[[package]] +name = "codeql-extractor" +version = "0.1.0" +dependencies = [ + "chrono", + "encoding", + "flate2", + "lazy_static", + "num_cpus", + "rayon", + "regex", + "serde", + "serde_json", + "tracing", + "tree-sitter", +] + [[package]] name = "codespan-reporting" version = "0.11.1" @@ -234,6 +252,70 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +[[package]] +name = "encoding" +version = "0.2.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b0d943856b990d12d3b55b359144ff341533e516d94098b1d3fc1ac666d36ec" +dependencies = [ + "encoding-index-japanese", + "encoding-index-korean", + "encoding-index-simpchinese", + "encoding-index-singlebyte", + "encoding-index-tradchinese", +] + +[[package]] +name = "encoding-index-japanese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04e8b2ff42e9a05335dbf8b5c6f7567e5591d0d916ccef4e0b1710d32a0d0c91" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-korean" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dc33fb8e6bcba213fe2f14275f0963fd16f0a02c878e3095ecfdf5bee529d81" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-simpchinese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d87a7194909b9118fc707194baa434a4e3b0fb6a5a757c73c3adb07aa25031f7" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-singlebyte" +version = 
"1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3351d5acffb224af9ca265f435b859c7c01537c0849754d3db3fdf2bfe2ae84a" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-tradchinese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd0e20d5688ce3cab59eb3ef3a2083a5c77bf496cb798dc6fcdb75f323890c18" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding_index_tests" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a246d82be1c9d791c5dfde9a2bd045fc3cbba3fa2b11ad558f27d01712f00569" + [[package]] name = "flate2" version = "1.0.25" @@ -364,14 +446,6 @@ dependencies = [ "adler", ] -[[package]] -name = "node-types" -version = "0.1.0" -dependencies = [ - "serde", - "serde_json", -] - [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -447,9 +521,7 @@ name = "ql-extractor" version = "0.1.0" dependencies = [ "clap", - "flate2", - "node-types", - "num_cpus", + "codeql-extractor", "rayon", "regex", "tracing", @@ -467,7 +539,7 @@ name = "ql-generator" version = "0.1.0" dependencies = [ "clap", - "node-types", + "codeql-extractor", "tracing", "tracing-subscriber", "tree-sitter-blame", diff --git a/ql/Cargo.toml b/ql/Cargo.toml index 4bc60c3333d..bb052f81d32 100644 --- a/ql/Cargo.toml +++ b/ql/Cargo.toml @@ -3,6 +3,5 @@ members = [ "autobuilder", "extractor", "generator", - "node-types", "buramu", ] diff --git a/ql/extractor/Cargo.toml b/ql/extractor/Cargo.toml index e7dc17f2feb..4f9b4aebef0 100644 --- a/ql/extractor/Cargo.toml +++ b/ql/extractor/Cargo.toml @@ -7,8 +7,6 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -flate2 = "1.0" -node-types = { path = "../node-types" } tree-sitter = ">= 0.20, < 0.21" tree-sitter-ql = { git = "https://github.com/tree-sitter/tree-sitter-ql.git", rev = 
"d08db734f8dc52f6bc04db53a966603122bc6985"} tree-sitter-ql-dbscheme = { git = "https://github.com/erik-krogh/tree-sitter-ql-dbscheme.git", rev = "63e1344353f63931e88bfbc2faa2e78e1421b213"} @@ -19,5 +17,5 @@ clap = "2.33" tracing = "0.1" tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } rayon = "1.7.0" -num_cpus = "1.14.0" regex = "1.7.2" +codeql-extractor = { path = "../../shared/extractor" } diff --git a/ql/extractor/src/extractor.rs b/ql/extractor/src/extractor.rs deleted file mode 100644 index f5557e5a188..00000000000 --- a/ql/extractor/src/extractor.rs +++ /dev/null @@ -1,650 +0,0 @@ -use crate::trap; -use node_types::{EntryKind, Field, NodeTypeMap, Storage, TypeName}; -use std::collections::BTreeMap as Map; -use std::collections::BTreeSet as Set; -use std::fmt; -use std::path::Path; - -use tracing::{error, info, span, Level}; -use tree_sitter::{Language, Node, Parser, Range, Tree}; - -pub fn populate_file(writer: &mut trap::Writer, absolute_path: &Path) -> trap::Label { - let (file_label, fresh) = - writer.global_id(&trap::full_id_for_file(&normalize_path(absolute_path))); - if fresh { - writer.add_tuple( - "files", - vec![ - trap::Arg::Label(file_label), - trap::Arg::String(normalize_path(absolute_path)), - ], - ); - populate_parent_folders(writer, file_label, absolute_path.parent()); - } - file_label -} - -fn populate_empty_file(writer: &mut trap::Writer) -> trap::Label { - let (file_label, fresh) = writer.global_id("empty;sourcefile"); - if fresh { - writer.add_tuple( - "files", - vec![ - trap::Arg::Label(file_label), - trap::Arg::String("".to_string()), - ], - ); - } - file_label -} - -pub fn populate_empty_location(writer: &mut trap::Writer) { - let file_label = populate_empty_file(writer); - location(writer, file_label, 0, 0, 0, 0); -} - -pub fn populate_parent_folders( - writer: &mut trap::Writer, - child_label: trap::Label, - path: Option<&Path>, -) { - let mut path = path; - let mut child_label = child_label; - loop { - match path 
{ - None => break, - Some(folder) => { - let (folder_label, fresh) = - writer.global_id(&trap::full_id_for_folder(&normalize_path(folder))); - writer.add_tuple( - "containerparent", - vec![ - trap::Arg::Label(folder_label), - trap::Arg::Label(child_label), - ], - ); - if fresh { - writer.add_tuple( - "folders", - vec![ - trap::Arg::Label(folder_label), - trap::Arg::String(normalize_path(folder)), - ], - ); - path = folder.parent(); - child_label = folder_label; - } else { - break; - } - } - } - } -} - -fn location( - writer: &mut trap::Writer, - file_label: trap::Label, - start_line: usize, - start_column: usize, - end_line: usize, - end_column: usize, -) -> trap::Label { - let (loc_label, fresh) = writer.global_id(&format!( - "loc,{{{}}},{},{},{},{}", - file_label, start_line, start_column, end_line, end_column - )); - if fresh { - writer.add_tuple( - "locations_default", - vec![ - trap::Arg::Label(loc_label), - trap::Arg::Label(file_label), - trap::Arg::Int(start_line), - trap::Arg::Int(start_column), - trap::Arg::Int(end_line), - trap::Arg::Int(end_column), - ], - ); - } - loc_label -} - -/// Extracts the source file at `path`, which is assumed to be canonicalized. 
-pub fn extract( - language: Language, - language_prefix: &str, - schema: &NodeTypeMap, - trap_writer: &mut trap::Writer, - path: &Path, - source: &[u8], - ranges: &[Range], -) -> std::io::Result<()> { - let path_str = format!("{}", path.display()); - let span = span!( - Level::TRACE, - "extract", - file = %path_str - ); - - let _enter = span.enter(); - - info!("extracting: {}", path_str); - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - parser.set_included_ranges(ranges).unwrap(); - let tree = parser.parse(&source, None).expect("Failed to parse file"); - trap_writer.comment(format!("Auto-generated TRAP file for {}", path_str)); - let file_label = populate_file(trap_writer, path); - let mut visitor = Visitor::new( - source, - trap_writer, - // TODO: should we handle path strings that are not valid UTF8 better? - &path_str, - file_label, - language_prefix, - schema, - ); - traverse(&tree, &mut visitor); - - parser.reset(); - Ok(()) -} - -/// Normalizes the path according the common CodeQL specification. Assumes that -/// `path` has already been canonicalized using `std::fs::canonicalize`. -fn normalize_path(path: &Path) -> String { - if cfg!(windows) { - // The way Rust canonicalizes paths doesn't match the CodeQL spec, so we - // have to do a bit of work removing certain prefixes and replacing - // backslashes. 
- let mut components: Vec = Vec::new(); - for component in path.components() { - match component { - std::path::Component::Prefix(prefix) => match prefix.kind() { - std::path::Prefix::Disk(letter) | std::path::Prefix::VerbatimDisk(letter) => { - components.push(format!("{}:", letter as char)); - } - std::path::Prefix::Verbatim(x) | std::path::Prefix::DeviceNS(x) => { - components.push(x.to_string_lossy().to_string()); - } - std::path::Prefix::UNC(server, share) - | std::path::Prefix::VerbatimUNC(server, share) => { - components.push(server.to_string_lossy().to_string()); - components.push(share.to_string_lossy().to_string()); - } - }, - std::path::Component::Normal(n) => { - components.push(n.to_string_lossy().to_string()); - } - std::path::Component::RootDir => {} - std::path::Component::CurDir => {} - std::path::Component::ParentDir => {} - } - } - components.join("/") - } else { - // For other operating systems, we can use the canonicalized path - // without modifications. - format!("{}", path.display()) - } -} - -struct ChildNode { - field_name: Option<&'static str>, - label: trap::Label, - type_name: TypeName, -} - -struct Visitor<'a> { - /// The file path of the source code (as string) - path: &'a str, - /// The label to use whenever we need to refer to the `@file` entity of this - /// source file. - file_label: trap::Label, - /// The source code as a UTF-8 byte array - source: &'a [u8], - /// A trap::Writer to accumulate trap entries - trap_writer: &'a mut trap::Writer, - /// A counter for top-level child nodes - toplevel_child_counter: usize, - /// Language-specific name of the AST info table - ast_node_info_table_name: String, - /// Language-specific name of the tokeninfo table - tokeninfo_table_name: String, - /// A lookup table from type name to node types - schema: &'a NodeTypeMap, - /// A stack for gathering information from child nodes. Whenever a node is - /// entered the parent's [Label], child counter, and an empty list is pushed. 
- /// All children append their data to the list. When the visitor leaves a - /// node the list containing the child data is popped from the stack and - /// matched against the dbscheme for the node. If the expectations are met - /// the corresponding row definitions are added to the trap_output. - stack: Vec<(trap::Label, usize, Vec)>, -} - -impl<'a> Visitor<'a> { - fn new( - source: &'a [u8], - trap_writer: &'a mut trap::Writer, - path: &'a str, - file_label: trap::Label, - language_prefix: &str, - schema: &'a NodeTypeMap, - ) -> Visitor<'a> { - Visitor { - path, - file_label, - source, - trap_writer, - toplevel_child_counter: 0, - ast_node_info_table_name: format!("{}_ast_node_info", language_prefix), - tokeninfo_table_name: format!("{}_tokeninfo", language_prefix), - schema, - stack: Vec::new(), - } - } - - fn record_parse_error( - &mut self, - error_message: String, - full_error_message: String, - loc: trap::Label, - ) { - error!("{}", full_error_message); - let id = self.trap_writer.fresh_id(); - self.trap_writer.add_tuple( - "diagnostics", - vec![ - trap::Arg::Label(id), - trap::Arg::Int(40), // severity 40 = error - trap::Arg::String("parse_error".to_string()), - trap::Arg::String(error_message), - trap::Arg::String(full_error_message), - trap::Arg::Label(loc), - ], - ); - } - - fn record_parse_error_for_node( - &mut self, - error_message: String, - full_error_message: String, - node: Node, - ) { - let (start_line, start_column, end_line, end_column) = location_for(self.source, node); - let loc = location( - self.trap_writer, - self.file_label, - start_line, - start_column, - end_line, - end_column, - ); - self.record_parse_error(error_message, full_error_message, loc); - } - - fn enter_node(&mut self, node: Node) -> bool { - if node.is_error() || node.is_missing() { - let error_message = if node.is_missing() { - format!("parse error: expecting '{}'", node.kind()) - } else { - "parse error".to_string() - }; - let full_error_message = format!( - "{}:{}: {}", 
- &self.path, - node.start_position().row + 1, - error_message - ); - self.record_parse_error_for_node(error_message, full_error_message, node); - return false; - } - - let id = self.trap_writer.fresh_id(); - - self.stack.push((id, 0, Vec::new())); - true - } - - fn leave_node(&mut self, field_name: Option<&'static str>, node: Node) { - if node.is_error() || node.is_missing() { - return; - } - let (id, _, child_nodes) = self.stack.pop().expect("Vistor: empty stack"); - let (start_line, start_column, end_line, end_column) = location_for(self.source, node); - let loc = location( - self.trap_writer, - self.file_label, - start_line, - start_column, - end_line, - end_column, - ); - let table = self - .schema - .get(&TypeName { - kind: node.kind().to_owned(), - named: node.is_named(), - }) - .unwrap(); - let mut valid = true; - let (parent_id, parent_index) = match self.stack.last_mut() { - Some(p) if !node.is_extra() => { - p.1 += 1; - (p.0, p.1 - 1) - } - _ => { - self.toplevel_child_counter += 1; - (self.file_label, self.toplevel_child_counter - 1) - } - }; - match &table.kind { - EntryKind::Token { kind_id, .. 
} => { - self.trap_writer.add_tuple( - &self.ast_node_info_table_name, - vec![ - trap::Arg::Label(id), - trap::Arg::Label(parent_id), - trap::Arg::Int(parent_index), - trap::Arg::Label(loc), - ], - ); - self.trap_writer.add_tuple( - &self.tokeninfo_table_name, - vec![ - trap::Arg::Label(id), - trap::Arg::Int(*kind_id), - sliced_source_arg(self.source, node), - ], - ); - } - EntryKind::Table { - fields, - name: table_name, - } => { - if let Some(args) = self.complex_node(&node, fields, &child_nodes, id) { - self.trap_writer.add_tuple( - &self.ast_node_info_table_name, - vec![ - trap::Arg::Label(id), - trap::Arg::Label(parent_id), - trap::Arg::Int(parent_index), - trap::Arg::Label(loc), - ], - ); - let mut all_args = vec![trap::Arg::Label(id)]; - all_args.extend(args); - self.trap_writer.add_tuple(table_name, all_args); - } - } - _ => { - let error_message = format!("unknown table type: '{}'", node.kind()); - let full_error_message = format!( - "{}:{}: {}", - &self.path, - node.start_position().row + 1, - error_message - ); - self.record_parse_error(error_message, full_error_message, loc); - - valid = false; - } - } - if valid && !node.is_extra() { - // Extra nodes are independent root nodes and do not belong to the parent node - // Therefore we should not register them in the parent vector - if let Some(parent) = self.stack.last_mut() { - parent.2.push(ChildNode { - field_name, - label: id, - type_name: TypeName { - kind: node.kind().to_owned(), - named: node.is_named(), - }, - }); - }; - } - } - - fn complex_node( - &mut self, - node: &Node, - fields: &[Field], - child_nodes: &[ChildNode], - parent_id: trap::Label, - ) -> Option> { - let mut map: Map<&Option, (&Field, Vec)> = Map::new(); - for field in fields { - map.insert(&field.name, (field, Vec::new())); - } - for child_node in child_nodes { - if let Some((field, values)) = map.get_mut(&child_node.field_name.map(|x| x.to_owned())) - { - //TODO: handle error and missing nodes - if 
self.type_matches(&child_node.type_name, &field.type_info) { - if let node_types::FieldTypeInfo::ReservedWordInt(int_mapping) = - &field.type_info - { - // We can safely unwrap because type_matches checks the key is in the map. - let (int_value, _) = int_mapping.get(&child_node.type_name.kind).unwrap(); - values.push(trap::Arg::Int(*int_value)); - } else { - values.push(trap::Arg::Label(child_node.label)); - } - } else if field.name.is_some() { - let error_message = format!( - "type mismatch for field {}::{} with type {:?} != {:?}", - node.kind(), - child_node.field_name.unwrap_or("child"), - child_node.type_name, - field.type_info - ); - let full_error_message = format!( - "{}:{}: {}", - &self.path, - node.start_position().row + 1, - error_message - ); - self.record_parse_error_for_node(error_message, full_error_message, *node); - } - } else if child_node.field_name.is_some() || child_node.type_name.named { - let error_message = format!( - "value for unknown field: {}::{} and type {:?}", - node.kind(), - &child_node.field_name.unwrap_or("child"), - &child_node.type_name - ); - let full_error_message = format!( - "{}:{}: {}", - &self.path, - node.start_position().row + 1, - error_message - ); - self.record_parse_error_for_node(error_message, full_error_message, *node); - } - } - let mut args = Vec::new(); - let mut is_valid = true; - for field in fields { - let child_values = &map.get(&field.name).unwrap().1; - match &field.storage { - Storage::Column { name: column_name } => { - if child_values.len() == 1 { - args.push(child_values.first().unwrap().clone()); - } else { - is_valid = false; - let error_message = format!( - "{} for field: {}::{}", - if child_values.is_empty() { - "missing value" - } else { - "too many values" - }, - node.kind(), - column_name - ); - let full_error_message = format!( - "{}:{}: {}", - &self.path, - node.start_position().row + 1, - error_message - ); - self.record_parse_error_for_node(error_message, full_error_message, *node); - } - } - 
Storage::Table { - name: table_name, - has_index, - column_name: _, - } => { - for (index, child_value) in child_values.iter().enumerate() { - if !*has_index && index > 0 { - error!( - "{}:{}: too many values for field: {}::{}", - &self.path, - node.start_position().row + 1, - node.kind(), - table_name, - ); - break; - } - let mut args = vec![trap::Arg::Label(parent_id)]; - if *has_index { - args.push(trap::Arg::Int(index)) - } - args.push(child_value.clone()); - self.trap_writer.add_tuple(table_name, args); - } - } - } - } - if is_valid { - Some(args) - } else { - None - } - } - - fn type_matches(&self, tp: &TypeName, type_info: &node_types::FieldTypeInfo) -> bool { - match type_info { - node_types::FieldTypeInfo::Single(single_type) => { - if tp == single_type { - return true; - } - if let EntryKind::Union { members } = &self.schema.get(single_type).unwrap().kind { - if self.type_matches_set(tp, members) { - return true; - } - } - } - node_types::FieldTypeInfo::Multiple { types, .. } => { - return self.type_matches_set(tp, types); - } - - node_types::FieldTypeInfo::ReservedWordInt(int_mapping) => { - return !tp.named && int_mapping.contains_key(&tp.kind) - } - } - false - } - - fn type_matches_set(&self, tp: &TypeName, types: &Set) -> bool { - if types.contains(tp) { - return true; - } - for other in types.iter() { - if let EntryKind::Union { members } = &self.schema.get(other).unwrap().kind { - if self.type_matches_set(tp, members) { - return true; - } - } - } - false - } -} - -// Emit a slice of a source file as an Arg. -fn sliced_source_arg(source: &[u8], n: Node) -> trap::Arg { - let range = n.byte_range(); - trap::Arg::String(String::from_utf8_lossy(&source[range.start..range.end]).into_owned()) -} - -// Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated. -// The first is the location and label definition, and the second is the -// 'Located' entry. 
-fn location_for(source: &[u8], n: Node) -> (usize, usize, usize, usize) { - // Tree-sitter row, column values are 0-based while CodeQL starts - // counting at 1. In addition Tree-sitter's row and column for the - // end position are exclusive while CodeQL's end positions are inclusive. - // This means that all values should be incremented by 1 and in addition the - // end position needs to be shift 1 to the left. In most cases this means - // simply incrementing all values except the end column except in cases where - // the end column is 0 (start of a line). In such cases the end position must be - // set to the end of the previous line. - let start_line = n.start_position().row + 1; - let start_col = n.start_position().column + 1; - let mut end_line = n.end_position().row + 1; - let mut end_col = n.end_position().column; - if start_line > end_line || start_line == end_line && start_col > end_col { - // the range is empty, clip it to sensible values - end_line = start_line; - end_col = start_col - 1; - } else if end_col == 0 { - // end_col = 0 means that we are at the start of a line - // unfortunately 0 is invalid as column number, therefore - // we should update the end location to be the end of the - // previous line - let mut index = n.end_byte(); - if index > 0 && index <= source.len() { - index -= 1; - if source[index] != b'\n' { - error!("expecting a line break symbol, but none found while correcting end column value"); - } - end_line -= 1; - end_col = 1; - while index > 0 && source[index - 1] != b'\n' { - index -= 1; - end_col += 1; - } - } else { - error!( - "cannot correct end column value: end_byte index {} is not in range [1,{}]", - index, - source.len() - ); - } - } - (start_line, start_col, end_line, end_col) -} - -fn traverse(tree: &Tree, visitor: &mut Visitor) { - let cursor = &mut tree.walk(); - visitor.enter_node(cursor.node()); - let mut recurse = true; - loop { - if recurse && cursor.goto_first_child() { - recurse = 
visitor.enter_node(cursor.node()); - } else { - visitor.leave_node(cursor.field_name(), cursor.node()); - - if cursor.goto_next_sibling() { - recurse = visitor.enter_node(cursor.node()); - } else if cursor.goto_parent() { - recurse = false; - } else { - break; - } - } - } -} - -// Numeric indices. -#[derive(Debug, Copy, Clone)] -struct Index(usize); - -impl fmt::Display for Index { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.0) - } -} diff --git a/ql/extractor/src/main.rs b/ql/extractor/src/main.rs index da7f263e7ef..210abaf49d4 100644 --- a/ql/extractor/src/main.rs +++ b/ql/extractor/src/main.rs @@ -1,41 +1,9 @@ -mod extractor; -mod trap; - -extern crate num_cpus; - use rayon::prelude::*; use std::fs; use std::io::BufRead; use std::path::{Path, PathBuf}; -/** - * Gets the number of threads the extractor should use, by reading the - * CODEQL_THREADS environment variable and using it as described in the - * extractor spec: - * - * "If the number is positive, it indicates the number of threads that should - * be used. If the number is negative or zero, it should be added to the number - * of cores available on the machine to determine how many threads to use - * (minimum of 1). If unspecified, should be considered as set to -1." 
- */ -fn num_codeql_threads() -> usize { - let threads_str = std::env::var("CODEQL_THREADS").unwrap_or_else(|_| "-1".to_owned()); - match threads_str.parse::() { - Ok(num) if num <= 0 => { - let reduction = -num as usize; - std::cmp::max(1, num_cpus::get() - reduction) - } - Ok(num) => num as usize, - - Err(_) => { - tracing::error!( - "Unable to parse CODEQL_THREADS value '{}'; defaulting to 1 thread.", - &threads_str - ); - 1 - } - } -} +use codeql_extractor::{diagnostics, extractor, node_types, trap}; fn main() -> std::io::Result<()> { tracing_subscriber::fmt() @@ -45,7 +13,23 @@ fn main() -> std::io::Result<()> { .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) .init(); - let num_threads = num_codeql_threads(); + let diagnostics = diagnostics::DiagnosticLoggers::new("ql"); + let mut main_thread_logger = diagnostics.logger(); + let num_threads = match codeql_extractor::options::num_threads() { + Ok(num) => num, + Err(e) => { + main_thread_logger.write( + main_thread_logger + .new_entry("configuration-error", "Configuration error") + .message( + "{}; defaulting to 1 thread.", + &[diagnostics::MessageArg::Code(&e)], + ) + .severity(diagnostics::Severity::Warning), + ); + 1 + } + }; tracing::info!( "Using {} {}", num_threads, @@ -55,6 +39,20 @@ fn main() -> std::io::Result<()> { "threads" } ); + let trap_compression = match trap::Compression::from_env("CODEQL_QL_TRAP_COMPRESSION") { + Ok(x) => x, + Err(e) => { + main_thread_logger.write( + main_thread_logger + .new_entry("configuration-error", "Configuration error") + .message("{}; using gzip.", &[diagnostics::MessageArg::Code(&e)]) + .severity(diagnostics::Severity::Warning), + ); + trap::Compression::Gzip + } + }; + drop(main_thread_logger); + rayon::ThreadPoolBuilder::new() .num_threads(num_threads) .build_global() @@ -79,7 +77,6 @@ fn main() -> std::io::Result<()> { .value_of("output-dir") .expect("missing --output-dir"); let trap_dir = PathBuf::from(trap_dir); - let trap_compression = 
trap::Compression::from_env("CODEQL_QL_TRAP_COMPRESSION"); let file_list = matches.value_of("file-list").expect("missing --file-list"); let file_list = fs::File::open(file_list)?; @@ -119,26 +116,29 @@ fn main() -> std::io::Result<()> { let source = std::fs::read(&path)?; let code_ranges = vec![]; let mut trap_writer = trap::Writer::new(); + let mut diagnostics_writer = diagnostics.logger(); if line.ends_with(".dbscheme") { extractor::extract( dbscheme, "dbscheme", &dbscheme_schema, + &mut diagnostics_writer, &mut trap_writer, &path, &source, &code_ranges, - )? + ) } else if line.ends_with("qlpack.yml") { extractor::extract( yaml, "yaml", &yaml_schema, + &mut diagnostics_writer, &mut trap_writer, &path, &source, &code_ranges, - )? + ) } else if line.ends_with(".json") || line.ends_with(".jsonl") || line.ends_with(".jsonc") @@ -147,31 +147,34 @@ fn main() -> std::io::Result<()> { json, "json", &json_schema, + &mut diagnostics_writer, &mut trap_writer, &path, &source, &code_ranges, - )? + ) } else if line.ends_with(".blame") { extractor::extract( blame, "blame", &blame_schema, + &mut diagnostics_writer, &mut trap_writer, &path, &source, &code_ranges, - )? + ) } else { extractor::extract( language, "ql", &schema, + &mut diagnostics_writer, &mut trap_writer, &path, &source, &code_ranges, - )? 
+ ) } std::fs::create_dir_all(&src_archive_file.parent().unwrap())?; std::fs::copy(&path, &src_archive_file)?; diff --git a/ql/extractor/src/trap.rs b/ql/extractor/src/trap.rs deleted file mode 100644 index 35a9b69f255..00000000000 --- a/ql/extractor/src/trap.rs +++ /dev/null @@ -1,275 +0,0 @@ -use std::borrow::Cow; -use std::fmt; -use std::io::{BufWriter, Write}; -use std::path::Path; - -use flate2::write::GzEncoder; - -pub struct Writer { - /// The accumulated trap entries - trap_output: Vec, - /// A counter for generating fresh labels - counter: u32, - /// cache of global keys - global_keys: std::collections::HashMap, -} - -impl Writer { - pub fn new() -> Writer { - Writer { - counter: 0, - trap_output: Vec::new(), - global_keys: std::collections::HashMap::new(), - } - } - - pub fn fresh_id(&mut self) -> Label { - let label = Label(self.counter); - self.counter += 1; - self.trap_output.push(Entry::FreshId(label)); - label - } - - /// Gets a label that will hold the unique ID of the passed string at import time. - /// This can be used for incrementally importable TRAP files -- use globally unique - /// strings to compute a unique ID for table tuples. - /// - /// Note: You probably want to make sure that the key strings that you use are disjoint - /// for disjoint column types; the standard way of doing this is to prefix (or append) - /// the column type name to the ID. Thus, you might identify methods in Java by the - /// full ID "methods_com.method.package.DeclaringClass.method(argumentList)". 
- pub fn global_id(&mut self, key: &str) -> (Label, bool) { - if let Some(label) = self.global_keys.get(key) { - return (*label, false); - } - let label = Label(self.counter); - self.counter += 1; - self.global_keys.insert(key.to_owned(), label); - self.trap_output - .push(Entry::MapLabelToKey(label, key.to_owned())); - (label, true) - } - - pub fn add_tuple(&mut self, table_name: &str, args: Vec) { - self.trap_output - .push(Entry::GenericTuple(table_name.to_owned(), args)) - } - - pub fn comment(&mut self, text: String) { - self.trap_output.push(Entry::Comment(text)); - } - - pub fn write_to_file(&self, path: &Path, compression: Compression) -> std::io::Result<()> { - let trap_file = std::fs::File::create(path)?; - match compression { - Compression::None => { - let mut trap_file = BufWriter::new(trap_file); - self.write_trap_entries(&mut trap_file) - } - Compression::Gzip => { - let trap_file = GzEncoder::new(trap_file, flate2::Compression::fast()); - let mut trap_file = BufWriter::new(trap_file); - self.write_trap_entries(&mut trap_file) - } - } - } - - fn write_trap_entries(&self, file: &mut W) -> std::io::Result<()> { - for trap_entry in &self.trap_output { - writeln!(file, "{}", trap_entry)?; - } - std::io::Result::Ok(()) - } -} - -pub enum Entry { - /// Maps the label to a fresh id, e.g. `#123=*`. - FreshId(Label), - /// Maps the label to a key, e.g. `#7=@"foo"`. 
- MapLabelToKey(Label, String), - /// foo_bar(arg*) - GenericTuple(String, Vec), - Comment(String), -} - -impl fmt::Display for Entry { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Entry::FreshId(label) => write!(f, "{}=*", label), - Entry::MapLabelToKey(label, key) => { - write!(f, "{}=@\"{}\"", label, key.replace("\"", "\"\"")) - } - Entry::GenericTuple(name, args) => { - write!(f, "{}(", name)?; - for (index, arg) in args.iter().enumerate() { - if index > 0 { - write!(f, ",")?; - } - write!(f, "{}", arg)?; - } - write!(f, ")") - } - Entry::Comment(line) => write!(f, "// {}", line), - } - } -} - -#[derive(Debug, Copy, Clone)] -// Identifiers of the form #0, #1... -pub struct Label(u32); - -impl fmt::Display for Label { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "#{:x}", self.0) - } -} - -// Some untyped argument to a TrapEntry. -#[derive(Debug, Clone)] -pub enum Arg { - Label(Label), - Int(usize), - String(String), -} - -const MAX_STRLEN: usize = 1048576; - -impl fmt::Display for Arg { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Arg::Label(x) => write!(f, "{}", x), - Arg::Int(x) => write!(f, "{}", x), - Arg::String(x) => write!( - f, - "\"{}\"", - limit_string(x, MAX_STRLEN).replace("\"", "\"\"") - ), - } - } -} - -pub struct Program(Vec); - -impl fmt::Display for Program { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut text = String::new(); - for trap_entry in &self.0 { - text.push_str(&format!("{}\n", trap_entry)); - } - write!(f, "{}", text) - } -} - -pub fn full_id_for_file(normalized_path: &str) -> String { - format!("{};sourcefile", escape_key(normalized_path)) -} - -pub fn full_id_for_folder(normalized_path: &str) -> String { - format!("{};folder", escape_key(normalized_path)) -} - -/// Escapes a string for use in a TRAP key, by replacing special characters with -/// HTML entities. 
-fn escape_key<'a, S: Into>>(key: S) -> Cow<'a, str> { - fn needs_escaping(c: char) -> bool { - matches!(c, '&' | '{' | '}' | '"' | '@' | '#') - } - - let key = key.into(); - if key.contains(needs_escaping) { - let mut escaped = String::with_capacity(2 * key.len()); - for c in key.chars() { - match c { - '&' => escaped.push_str("&"), - '{' => escaped.push_str("{"), - '}' => escaped.push_str("}"), - '"' => escaped.push_str("""), - '@' => escaped.push_str("@"), - '#' => escaped.push_str("#"), - _ => escaped.push(c), - } - } - Cow::Owned(escaped) - } else { - key - } -} - -/// Limit the length (in bytes) of a string. If the string's length in bytes is -/// less than or equal to the limit then the entire string is returned. Otherwise -/// the string is sliced at the provided limit. If there is a multi-byte character -/// at the limit then the returned slice will be slightly shorter than the limit to -/// avoid splitting that multi-byte character. -fn limit_string(string: &str, max_size: usize) -> &str { - if string.len() <= max_size { - return string; - } - let p = string.as_bytes(); - let mut index = max_size; - // We want to clip the string at [max_size]; however, the character at that position - // may span several bytes. We need to find the first byte of the character. In UTF-8 - // encoded data any byte that matches the bit pattern 10XXXXXX is not a start byte. - // Therefore we decrement the index as long as there are bytes matching this pattern. - // This ensures we cut the string at the border between one character and another. 
- while index > 0 && (p[index] & 0b11000000) == 0b10000000 { - index -= 1; - } - &string[0..index] -} - -#[derive(Clone, Copy)] -pub enum Compression { - None, - Gzip, -} - -impl Compression { - pub fn from_env(var_name: &str) -> Compression { - match std::env::var(var_name) { - Ok(method) => match Compression::from_string(&method) { - Some(c) => c, - None => { - tracing::error!("Unknown compression method '{}'; using gzip.", &method); - Compression::Gzip - } - }, - // Default compression method if the env var isn't set: - Err(_) => Compression::Gzip, - } - } - - pub fn from_string(s: &str) -> Option { - match s.to_lowercase().as_ref() { - "none" => Some(Compression::None), - "gzip" => Some(Compression::Gzip), - _ => None, - } - } - - pub fn extension(&self) -> &str { - match self { - Compression::None => "trap", - Compression::Gzip => "trap.gz", - } - } -} - -#[test] -fn limit_string_test() { - assert_eq!("hello", limit_string(&"hello world".to_owned(), 5)); - assert_eq!("hi ☹", limit_string(&"hi ☹☹".to_owned(), 6)); - assert_eq!("hi ", limit_string(&"hi ☹☹".to_owned(), 5)); -} - -#[test] -fn escape_key_test() { - assert_eq!("foo!", escape_key("foo!")); - assert_eq!("foo{}", escape_key("foo{}")); - assert_eq!("{}", escape_key("{}")); - assert_eq!("", escape_key("")); - assert_eq!("/path/to/foo.rb", escape_key("/path/to/foo.rb")); - assert_eq!( - "/path/to/foo&{}"@#.rb", - escape_key("/path/to/foo&{}\"@#.rb") - ); -} diff --git a/ql/generator/Cargo.toml b/ql/generator/Cargo.toml index 4fcc98be310..3a5665f33ae 100644 --- a/ql/generator/Cargo.toml +++ b/ql/generator/Cargo.toml @@ -8,7 +8,6 @@ edition = "2018" [dependencies] clap = "2.33" -node-types = { path = "../node-types" } tracing = "0.1" tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } tree-sitter-ql = { git = "https://github.com/tree-sitter/tree-sitter-ql.git", rev = "d08db734f8dc52f6bc04db53a966603122bc6985"} @@ -16,3 +15,4 @@ tree-sitter-ql-dbscheme = { git = 
"https://github.com/erik-krogh/tree-sitter-ql- tree-sitter-ql-yaml = {git = "https://github.com/erik-krogh/tree-sitter-ql.git", rev = "cf704bf3671e1ae148e173464fb65a4d2bbf5f99"} tree-sitter-blame = {path = "../buramu/tree-sitter-blame"} tree-sitter-json = { git = "https://github.com/tausbn/tree-sitter-json.git", rev = "745663ee997f1576fe1e7187e6347e0db36ec7a9"} +codeql-extractor = { path = "../../shared/extractor" } diff --git a/ql/generator/src/dbscheme.rs b/ql/generator/src/dbscheme.rs deleted file mode 100644 index 335eee1950c..00000000000 --- a/ql/generator/src/dbscheme.rs +++ /dev/null @@ -1,130 +0,0 @@ -use crate::ql; -use std::collections::BTreeSet as Set; -use std::fmt; -/// Represents a distinct entry in the database schema. -pub enum Entry<'a> { - /// An entry defining a database table. - Table(Table<'a>), - /// An entry defining a database table. - Case(Case<'a>), - /// An entry defining type that is a union of other types. - Union(Union<'a>), -} - -/// A table in the database schema. -pub struct Table<'a> { - pub name: &'a str, - pub columns: Vec>, - pub keysets: Option>, -} - -/// A union in the database schema. -pub struct Union<'a> { - pub name: &'a str, - pub members: Set<&'a str>, -} - -/// A table in the database schema. -pub struct Case<'a> { - pub name: &'a str, - pub column: &'a str, - pub branches: Vec<(usize, &'a str)>, -} - -/// A column in a table. -pub struct Column<'a> { - pub db_type: DbColumnType, - pub name: &'a str, - pub unique: bool, - pub ql_type: ql::Type<'a>, - pub ql_type_is_ref: bool, -} - -/// The database column type. 
-pub enum DbColumnType { - Int, - String, -} - -impl<'a> fmt::Display for Case<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "case @{}.{} of", &self.name, &self.column)?; - let mut sep = " "; - for (c, tp) in &self.branches { - writeln!(f, "{} {} = @{}", sep, c, tp)?; - sep = "|"; - } - writeln!(f, ";") - } -} - -impl<'a> fmt::Display for Table<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let Some(keyset) = &self.keysets { - write!(f, "#keyset[")?; - for (key_index, key) in keyset.iter().enumerate() { - if key_index > 0 { - write!(f, ", ")?; - } - write!(f, "{}", key)?; - } - writeln!(f, "]")?; - } - - writeln!(f, "{}(", self.name)?; - for (column_index, column) in self.columns.iter().enumerate() { - write!(f, " ")?; - if column.unique { - write!(f, "unique ")?; - } - write!( - f, - "{} ", - match column.db_type { - DbColumnType::Int => "int", - DbColumnType::String => "string", - } - )?; - write!(f, "{}: {}", column.name, column.ql_type)?; - if column.ql_type_is_ref { - write!(f, " ref")?; - } - if column_index + 1 != self.columns.len() { - write!(f, ",")?; - } - writeln!(f)?; - } - write!(f, ");")?; - - Ok(()) - } -} - -impl<'a> fmt::Display for Union<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "@{} = ", self.name)?; - let mut first = true; - for member in &self.members { - if first { - first = false; - } else { - write!(f, " | ")?; - } - write!(f, "@{}", member)?; - } - Ok(()) - } -} - -/// Generates the dbscheme by writing the given dbscheme `entries` to the `file`. 
-pub fn write<'a>(file: &mut dyn std::io::Write, entries: &'a [Entry]) -> std::io::Result<()> { - for entry in entries { - match entry { - Entry::Case(case) => write!(file, "{}\n\n", case)?, - Entry::Table(table) => write!(file, "{}\n\n", table)?, - Entry::Union(union) => write!(file, "{}\n\n", union)?, - } - } - - Ok(()) -} diff --git a/ql/generator/src/language.rs b/ql/generator/src/language.rs deleted file mode 100644 index f0b0ed1790f..00000000000 --- a/ql/generator/src/language.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub struct Language { - pub name: String, - pub node_types: &'static str, -} diff --git a/ql/generator/src/main.rs b/ql/generator/src/main.rs index 8c6bf63d859..8baed6d21a5 100644 --- a/ql/generator/src/main.rs +++ b/ql/generator/src/main.rs @@ -1,9 +1,3 @@ -mod dbscheme; -mod language; -mod ql; -mod ql_gen; - -use language::Language; use std::collections::BTreeMap as Map; use std::collections::BTreeSet as Set; use std::fs::File; @@ -11,6 +5,9 @@ use std::io::LineWriter; use std::io::Write; use std::path::PathBuf; +use codeql_extractor::generator::{dbscheme, language::Language, ql, ql_gen}; +use codeql_extractor::node_types; + /// Given the name of the parent node, and its field information, returns a pair, /// the first of which is the field's type. The second is an optional dbscheme /// entry that should be added. 
diff --git a/ql/generator/src/ql.rs b/ql/generator/src/ql.rs deleted file mode 100644 index 7dd94f24bea..00000000000 --- a/ql/generator/src/ql.rs +++ /dev/null @@ -1,295 +0,0 @@ -use std::collections::BTreeSet; -use std::fmt; - -#[derive(Clone, Eq, PartialEq, Hash)] -pub enum TopLevel<'a> { - Class(Class<'a>), - Import(Import<'a>), - Module(Module<'a>), -} - -impl<'a> fmt::Display for TopLevel<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - TopLevel::Import(imp) => write!(f, "{}", imp), - TopLevel::Class(cls) => write!(f, "{}", cls), - TopLevel::Module(m) => write!(f, "{}", m), - } - } -} - -#[derive(Clone, Eq, PartialEq, Hash)] -pub struct Import<'a> { - pub module: &'a str, - pub alias: Option<&'a str>, -} - -impl<'a> fmt::Display for Import<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "import {}", &self.module)?; - if let Some(name) = &self.alias { - write!(f, " as {}", name)?; - } - Ok(()) - } -} -#[derive(Clone, Eq, PartialEq, Hash)] -pub struct Class<'a> { - pub qldoc: Option, - pub name: &'a str, - pub is_abstract: bool, - pub supertypes: BTreeSet>, - pub characteristic_predicate: Option>, - pub predicates: Vec>, -} - -impl<'a> fmt::Display for Class<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if let Some(qldoc) = &self.qldoc { - write!(f, "/** {} */", qldoc)?; - } - if self.is_abstract { - write!(f, "abstract ")?; - } - write!(f, "class {} extends ", &self.name)?; - for (index, supertype) in self.supertypes.iter().enumerate() { - if index > 0 { - write!(f, ", ")?; - } - write!(f, "{}", supertype)?; - } - writeln!(f, " {{ ")?; - - if let Some(charpred) = &self.characteristic_predicate { - writeln!( - f, - " {}", - Predicate { - qldoc: None, - name: self.name, - overridden: false, - is_final: false, - return_type: None, - formal_parameters: vec![], - body: charpred.clone(), - } - )?; - } - - for predicate in &self.predicates { - writeln!(f, " {}", predicate)?; - } - - write!(f, 
"}}")?; - - Ok(()) - } -} - -#[derive(Clone, Eq, PartialEq, Hash)] -pub struct Module<'a> { - pub qldoc: Option, - pub name: &'a str, - pub body: Vec>, -} - -impl<'a> fmt::Display for Module<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if let Some(qldoc) = &self.qldoc { - write!(f, "/** {} */", qldoc)?; - } - writeln!(f, "module {} {{ ", self.name)?; - for decl in &self.body { - writeln!(f, " {}", decl)?; - } - write!(f, "}}")?; - Ok(()) - } -} -// The QL type of a column. -#[derive(Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] -pub enum Type<'a> { - /// Primitive `int` type. - Int, - - /// Primitive `string` type. - String, - - /// A database type that will need to be referred to with an `@` prefix. - At(&'a str), - - /// A user-defined type. - Normal(&'a str), -} - -impl<'a> fmt::Display for Type<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Type::Int => write!(f, "int"), - Type::String => write!(f, "string"), - Type::Normal(name) => write!(f, "{}", name), - Type::At(name) => write!(f, "@{}", name), - } - } -} - -#[derive(Clone, Eq, PartialEq, Hash)] -pub enum Expression<'a> { - Var(&'a str), - String(&'a str), - Integer(usize), - Pred(&'a str, Vec>), - And(Vec>), - Or(Vec>), - Equals(Box>, Box>), - Dot(Box>, &'a str, Vec>), - Aggregate { - name: &'a str, - vars: Vec>, - range: Option>>, - expr: Box>, - second_expr: Option>>, - }, -} - -impl<'a> fmt::Display for Expression<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Expression::Var(x) => write!(f, "{}", x), - Expression::String(s) => write!(f, "\"{}\"", s), - Expression::Integer(n) => write!(f, "{}", n), - Expression::Pred(n, args) => { - write!(f, "{}(", n)?; - for (index, arg) in args.iter().enumerate() { - if index > 0 { - write!(f, ", ")?; - } - write!(f, "{}", arg)?; - } - write!(f, ")") - } - Expression::And(conjuncts) => { - if conjuncts.is_empty() { - write!(f, "any()") - } else { - for (index, conjunct) in 
conjuncts.iter().enumerate() { - if index > 0 { - write!(f, " and ")?; - } - write!(f, "({})", conjunct)?; - } - Ok(()) - } - } - Expression::Or(disjuncts) => { - if disjuncts.is_empty() { - write!(f, "none()") - } else { - for (index, disjunct) in disjuncts.iter().enumerate() { - if index > 0 { - write!(f, " or ")?; - } - write!(f, "({})", disjunct)?; - } - Ok(()) - } - } - Expression::Equals(a, b) => write!(f, "{} = {}", a, b), - Expression::Dot(x, member_pred, args) => { - write!(f, "{}.{}(", x, member_pred)?; - for (index, arg) in args.iter().enumerate() { - if index > 0 { - write!(f, ", ")?; - } - write!(f, "{}", arg)?; - } - write!(f, ")") - } - Expression::Aggregate { - name, - vars, - range, - expr, - second_expr, - } => { - write!(f, "{}(", name)?; - if !vars.is_empty() { - for (index, var) in vars.iter().enumerate() { - if index > 0 { - write!(f, ", ")?; - } - write!(f, "{}", var)?; - } - write!(f, " | ")?; - } - if let Some(range) = range { - write!(f, "{} | ", range)?; - } - write!(f, "{}", expr)?; - if let Some(second_expr) = second_expr { - write!(f, ", {}", second_expr)?; - } - write!(f, ")") - } - } - } -} - -#[derive(Clone, Eq, PartialEq, Hash)] -pub struct Predicate<'a> { - pub qldoc: Option, - pub name: &'a str, - pub overridden: bool, - pub is_final: bool, - pub return_type: Option>, - pub formal_parameters: Vec>, - pub body: Expression<'a>, -} - -impl<'a> fmt::Display for Predicate<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if let Some(qldoc) = &self.qldoc { - write!(f, "/** {} */", qldoc)?; - } - if self.is_final { - write!(f, "final ")?; - } - if self.overridden { - write!(f, "override ")?; - } - match &self.return_type { - None => write!(f, "predicate ")?, - Some(return_type) => write!(f, "{} ", return_type)?, - } - write!(f, "{}(", self.name)?; - for (index, param) in self.formal_parameters.iter().enumerate() { - if index > 0 { - write!(f, ", ")?; - } - write!(f, "{}", param)?; - } - write!(f, ") {{ {} }}", 
self.body)?; - - Ok(()) - } -} - -#[derive(Clone, Eq, PartialEq, Hash)] -pub struct FormalParameter<'a> { - pub name: &'a str, - pub param_type: Type<'a>, -} - -impl<'a> fmt::Display for FormalParameter<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {}", self.param_type, self.name) - } -} - -/// Generates a QL library by writing the given `elements` to the `file`. -pub fn write<'a>(file: &mut dyn std::io::Write, elements: &'a [TopLevel]) -> std::io::Result<()> { - for element in elements { - write!(file, "{}\n\n", &element)?; - } - Ok(()) -} diff --git a/ql/generator/src/ql_gen.rs b/ql/generator/src/ql_gen.rs deleted file mode 100644 index 007509e0074..00000000000 --- a/ql/generator/src/ql_gen.rs +++ /dev/null @@ -1,565 +0,0 @@ -use crate::ql; -use std::collections::BTreeSet; - -/// Creates the hard-coded `AstNode` class that acts as a supertype of all -/// classes we generate. -pub fn create_ast_node_class<'a>(ast_node: &'a str, node_info_table: &'a str) -> ql::Class<'a> { - // Default implementation of `toString` calls `this.getAPrimaryQlClass()` - let to_string = ql::Predicate { - qldoc: Some(String::from( - "Gets a string representation of this element.", - )), - name: "toString", - overridden: false, - is_final: false, - return_type: Some(ql::Type::String), - formal_parameters: vec![], - body: ql::Expression::Equals( - Box::new(ql::Expression::Var("result")), - Box::new(ql::Expression::Dot( - Box::new(ql::Expression::Var("this")), - "getAPrimaryQlClass", - vec![], - )), - ), - }; - let get_location = ql::Predicate { - name: "getLocation", - qldoc: Some(String::from("Gets the location of this element.")), - overridden: false, - is_final: true, - return_type: Some(ql::Type::Normal("L::Location")), - formal_parameters: vec![], - body: ql::Expression::Pred( - node_info_table, - vec![ - ql::Expression::Var("this"), - ql::Expression::Var("_"), // parent - ql::Expression::Var("_"), // parent index - ql::Expression::Var("result"), // 
location - ], - ), - }; - let get_a_field_or_child = create_none_predicate( - Some(String::from("Gets a field or child node of this node.")), - "getAFieldOrChild", - false, - Some(ql::Type::Normal("AstNode")), - ); - let get_parent = ql::Predicate { - qldoc: Some(String::from("Gets the parent of this element.")), - name: "getParent", - overridden: false, - is_final: true, - return_type: Some(ql::Type::Normal("AstNode")), - formal_parameters: vec![], - body: ql::Expression::Pred( - node_info_table, - vec![ - ql::Expression::Var("this"), - ql::Expression::Var("result"), - ql::Expression::Var("_"), // parent index - ql::Expression::Var("_"), // location - ], - ), - }; - let get_parent_index = ql::Predicate { - qldoc: Some(String::from( - "Gets the index of this node among the children of its parent.", - )), - name: "getParentIndex", - overridden: false, - is_final: true, - return_type: Some(ql::Type::Int), - formal_parameters: vec![], - body: ql::Expression::Pred( - node_info_table, - vec![ - ql::Expression::Var("this"), - ql::Expression::Var("_"), // parent - ql::Expression::Var("result"), // parent index - ql::Expression::Var("_"), // location - ], - ), - }; - let get_a_primary_ql_class = ql::Predicate { - qldoc: Some(String::from( - "Gets the name of the primary QL class for this element.", - )), - name: "getAPrimaryQlClass", - overridden: false, - is_final: false, - return_type: Some(ql::Type::String), - formal_parameters: vec![], - body: ql::Expression::Equals( - Box::new(ql::Expression::Var("result")), - Box::new(ql::Expression::String("???")), - ), - }; - let get_primary_ql_classes = ql::Predicate { - qldoc: Some( - "Gets a comma-separated list of the names of the primary CodeQL \ - classes to which this element belongs." 
- .to_owned(), - ), - name: "getPrimaryQlClasses", - overridden: false, - is_final: false, - return_type: Some(ql::Type::String), - formal_parameters: vec![], - body: ql::Expression::Equals( - Box::new(ql::Expression::Var("result")), - Box::new(ql::Expression::Aggregate { - name: "concat", - vars: vec![], - range: None, - expr: Box::new(ql::Expression::Dot( - Box::new(ql::Expression::Var("this")), - "getAPrimaryQlClass", - vec![], - )), - second_expr: Some(Box::new(ql::Expression::String(","))), - }), - ), - }; - ql::Class { - qldoc: Some(String::from("The base class for all AST nodes")), - name: "AstNode", - is_abstract: false, - supertypes: vec![ql::Type::At(ast_node)].into_iter().collect(), - characteristic_predicate: None, - predicates: vec![ - to_string, - get_location, - get_parent, - get_parent_index, - get_a_field_or_child, - get_a_primary_ql_class, - get_primary_ql_classes, - ], - } -} - -pub fn create_token_class<'a>(token_type: &'a str, tokeninfo: &'a str) -> ql::Class<'a> { - let tokeninfo_arity = 3; // id, kind, value - let get_value = ql::Predicate { - qldoc: Some(String::from("Gets the value of this token.")), - name: "getValue", - overridden: false, - is_final: true, - return_type: Some(ql::Type::String), - formal_parameters: vec![], - body: create_get_field_expr_for_column_storage("result", tokeninfo, 1, tokeninfo_arity), - }; - let to_string = ql::Predicate { - qldoc: Some(String::from( - "Gets a string representation of this element.", - )), - name: "toString", - overridden: true, - is_final: true, - return_type: Some(ql::Type::String), - formal_parameters: vec![], - body: ql::Expression::Equals( - Box::new(ql::Expression::Var("result")), - Box::new(ql::Expression::Dot( - Box::new(ql::Expression::Var("this")), - "getValue", - vec![], - )), - ), - }; - ql::Class { - qldoc: Some(String::from("A token.")), - name: "Token", - is_abstract: false, - supertypes: vec![ql::Type::At(token_type), ql::Type::Normal("AstNode")] - .into_iter() - .collect(), - 
characteristic_predicate: None, - predicates: vec![ - get_value, - to_string, - create_get_a_primary_ql_class("Token", false), - ], - } -} - -// Creates the `ReservedWord` class. -pub fn create_reserved_word_class(db_name: &str) -> ql::Class { - let class_name = "ReservedWord"; - let get_a_primary_ql_class = create_get_a_primary_ql_class(class_name, true); - ql::Class { - qldoc: Some(String::from("A reserved word.")), - name: class_name, - is_abstract: false, - supertypes: vec![ql::Type::At(db_name), ql::Type::Normal("Token")] - .into_iter() - .collect(), - characteristic_predicate: None, - predicates: vec![get_a_primary_ql_class], - } -} - -/// Creates a predicate whose body is `none()`. -fn create_none_predicate<'a>( - qldoc: Option, - name: &'a str, - overridden: bool, - return_type: Option>, -) -> ql::Predicate<'a> { - ql::Predicate { - qldoc, - name, - overridden, - is_final: false, - return_type, - formal_parameters: Vec::new(), - body: ql::Expression::Pred("none", vec![]), - } -} - -/// Creates an overridden `getAPrimaryQlClass` predicate that returns the given -/// name. -fn create_get_a_primary_ql_class(class_name: &str, is_final: bool) -> ql::Predicate { - ql::Predicate { - qldoc: Some(String::from( - "Gets the name of the primary QL class for this element.", - )), - name: "getAPrimaryQlClass", - overridden: true, - is_final, - return_type: Some(ql::Type::String), - formal_parameters: vec![], - body: ql::Expression::Equals( - Box::new(ql::Expression::Var("result")), - Box::new(ql::Expression::String(class_name)), - ), - } -} - -/// Returns an expression to get a field that's defined as a column in the parent's table. 
-/// -/// # Arguments -/// -/// * `result_var_name` - the name of the variable to which the resulting value should be bound -/// * `table_name` - the name of parent's defining table -/// * `column_index` - the index in that table that defines the field -/// * `arity` - the total number of columns in the table -fn create_get_field_expr_for_column_storage<'a>( - result_var_name: &'a str, - table_name: &'a str, - column_index: usize, - arity: usize, -) -> ql::Expression<'a> { - let num_underscores_before = column_index; - let num_underscores_after = arity - 2 - num_underscores_before; - ql::Expression::Pred( - table_name, - [ - vec![ql::Expression::Var("this")], - vec![ql::Expression::Var("_"); num_underscores_before], - vec![ql::Expression::Var(result_var_name)], - vec![ql::Expression::Var("_"); num_underscores_after], - ] - .concat(), - ) -} - -/// Returns an expression to get the field with the given index from its -/// auxiliary table. The index name can be "_" so the expression will hold for -/// all indices. -fn create_get_field_expr_for_table_storage<'a>( - result_var_name: &'a str, - table_name: &'a str, - index_var_name: Option<&'a str>, -) -> ql::Expression<'a> { - ql::Expression::Pred( - table_name, - match index_var_name { - Some(index_var_name) => vec![ - ql::Expression::Var("this"), - ql::Expression::Var(index_var_name), - ql::Expression::Var(result_var_name), - ], - None => vec![ql::Expression::Var("this"), ql::Expression::Var("result")], - }, - ) -} - -/// Creates a pair consisting of a predicate to get the given field, and an -/// optional expression that will get the same field. When the field can occur -/// multiple times, the predicate will take an index argument, while the -/// expression will use the "don't care" expression to hold for all occurrences. 
-/// -/// # Arguments -/// -/// `main_table_name` - the name of the defining table for the parent node -/// `main_table_arity` - the number of columns in the main table -/// `main_table_column_index` - a mutable reference to a column index indicating -/// where the field is in the main table. If this is used (i.e. the field has -/// column storage), then the index is incremented. -/// `parent_name` - the name of the parent node -/// `field` - the field whose getters we are creating -/// `field_type` - the db name of the field's type (possibly being a union we created) -fn create_field_getters<'a>( - main_table_name: &'a str, - main_table_arity: usize, - main_table_column_index: &mut usize, - field: &'a node_types::Field, - nodes: &'a node_types::NodeTypeMap, -) -> (ql::Predicate<'a>, Option>) { - let return_type = match &field.type_info { - node_types::FieldTypeInfo::Single(t) => { - Some(ql::Type::Normal(&nodes.get(t).unwrap().ql_class_name)) - } - node_types::FieldTypeInfo::Multiple { - types: _, - dbscheme_union: _, - ql_class, - } => Some(ql::Type::Normal(ql_class)), - node_types::FieldTypeInfo::ReservedWordInt(_) => Some(ql::Type::String), - }; - let formal_parameters = match &field.storage { - node_types::Storage::Column { .. } => vec![], - node_types::Storage::Table { has_index, .. } => { - if *has_index { - vec![ql::FormalParameter { - name: "i", - param_type: ql::Type::Int, - }] - } else { - vec![] - } - } - }; - - // For the expression to get a value, what variable name should the result - // be bound to? - let get_value_result_var_name = match &field.type_info { - node_types::FieldTypeInfo::ReservedWordInt(_) => "value", - node_types::FieldTypeInfo::Single(_) => "result", - node_types::FieldTypeInfo::Multiple { .. } => "result", - }; - - // Two expressions for getting the value. 
One that's suitable use in the - // getter predicate (where there may be a specific index), and another for - // use in `getAFieldOrChild` (where we use a "don't care" expression to - // match any index). - let (get_value, get_value_any_index) = match &field.storage { - node_types::Storage::Column { name: _ } => { - let column_index = *main_table_column_index; - *main_table_column_index += 1; - ( - create_get_field_expr_for_column_storage( - get_value_result_var_name, - main_table_name, - column_index, - main_table_arity, - ), - create_get_field_expr_for_column_storage( - get_value_result_var_name, - main_table_name, - column_index, - main_table_arity, - ), - ) - } - node_types::Storage::Table { - name: field_table_name, - has_index, - column_name: _, - } => ( - create_get_field_expr_for_table_storage( - get_value_result_var_name, - field_table_name, - if *has_index { Some("i") } else { None }, - ), - create_get_field_expr_for_table_storage( - get_value_result_var_name, - field_table_name, - if *has_index { Some("_") } else { None }, - ), - ), - }; - let (body, optional_expr) = match &field.type_info { - node_types::FieldTypeInfo::ReservedWordInt(int_mapping) => { - // Create an expression that binds the corresponding string to `result` for each `value`, e.g.: - // result = "foo" and value = 0 or - // result = "bar" and value = 1 or - // result = "baz" and value = 2 - let disjuncts = int_mapping - .iter() - .map(|(token_str, (value, _))| { - ql::Expression::And(vec![ - ql::Expression::Equals( - Box::new(ql::Expression::Var("result")), - Box::new(ql::Expression::String(token_str)), - ), - ql::Expression::Equals( - Box::new(ql::Expression::Var("value")), - Box::new(ql::Expression::Integer(*value)), - ), - ]) - }) - .collect(); - ( - ql::Expression::Aggregate { - name: "exists", - vars: vec![ql::FormalParameter { - name: "value", - param_type: ql::Type::Int, - }], - range: Some(Box::new(get_value)), - expr: Box::new(ql::Expression::Or(disjuncts)), - second_expr: None, 
- }, - // Since the getter returns a string and not an AstNode, it won't be part of getAFieldOrChild: - None, - ) - } - node_types::FieldTypeInfo::Single(_) | node_types::FieldTypeInfo::Multiple { .. } => { - (get_value, Some(get_value_any_index)) - } - }; - let qldoc = match &field.name { - Some(name) => format!("Gets the node corresponding to the field `{}`.", name), - None => { - if formal_parameters.is_empty() { - "Gets the child of this node.".to_owned() - } else { - "Gets the `i`th child of this node.".to_owned() - } - } - }; - ( - ql::Predicate { - qldoc: Some(qldoc), - name: &field.getter_name, - overridden: false, - is_final: true, - return_type, - formal_parameters, - body, - }, - optional_expr, - ) -} - -/// Converts the given node types into CodeQL classes wrapping the dbscheme. -pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { - let mut classes: Vec = Vec::new(); - let mut token_kinds = BTreeSet::new(); - for (type_name, node) in nodes { - if let node_types::EntryKind::Token { .. } = &node.kind { - if type_name.named { - token_kinds.insert(&type_name.kind); - } - } - } - - for (type_name, node) in nodes { - match &node.kind { - node_types::EntryKind::Token { kind_id: _ } => { - if type_name.named { - let get_a_primary_ql_class = - create_get_a_primary_ql_class(&node.ql_class_name, true); - let mut supertypes: BTreeSet = BTreeSet::new(); - supertypes.insert(ql::Type::At(&node.dbscheme_name)); - supertypes.insert(ql::Type::Normal("Token")); - classes.push(ql::TopLevel::Class(ql::Class { - qldoc: Some(format!("A class representing `{}` tokens.", type_name.kind)), - name: &node.ql_class_name, - is_abstract: false, - supertypes, - characteristic_predicate: None, - predicates: vec![get_a_primary_ql_class], - })); - } - } - node_types::EntryKind::Union { members: _ } => { - // It's a tree-sitter supertype node, so we're wrapping a dbscheme - // union type. 
- classes.push(ql::TopLevel::Class(ql::Class { - qldoc: None, - name: &node.ql_class_name, - is_abstract: false, - supertypes: vec![ - ql::Type::At(&node.dbscheme_name), - ql::Type::Normal("AstNode"), - ] - .into_iter() - .collect(), - characteristic_predicate: None, - predicates: vec![], - })); - } - node_types::EntryKind::Table { - name: main_table_name, - fields, - } => { - if fields.is_empty() { - panic!("Encountered node '{}' with no fields", type_name.kind); - } - - // Count how many columns there will be in the main table. There - // will be one for the id, plus one for each field that's stored - // as a column. - let main_table_arity = 1 + fields - .iter() - .filter(|&f| matches!(f.storage, node_types::Storage::Column { .. })) - .count(); - - let main_class_name = &node.ql_class_name; - let mut main_class = ql::Class { - qldoc: Some(format!("A class representing `{}` nodes.", type_name.kind)), - name: main_class_name, - is_abstract: false, - supertypes: vec![ - ql::Type::At(&node.dbscheme_name), - ql::Type::Normal("AstNode"), - ] - .into_iter() - .collect(), - characteristic_predicate: None, - predicates: vec![create_get_a_primary_ql_class(main_class_name, true)], - }; - - let mut main_table_column_index: usize = 0; - let mut get_child_exprs: Vec = Vec::new(); - - // Iterate through the fields, creating: - // - classes to wrap union types if fields need them, - // - predicates to access the fields, - // - the QL expressions to access the fields that will be part of getAFieldOrChild. 
- for field in fields { - let (get_pred, get_child_expr) = create_field_getters( - main_table_name, - main_table_arity, - &mut main_table_column_index, - field, - nodes, - ); - main_class.predicates.push(get_pred); - if let Some(get_child_expr) = get_child_expr { - get_child_exprs.push(get_child_expr) - } - } - - main_class.predicates.push(ql::Predicate { - qldoc: Some(String::from("Gets a field or child node of this node.")), - name: "getAFieldOrChild", - overridden: true, - is_final: true, - return_type: Some(ql::Type::Normal("AstNode")), - formal_parameters: vec![], - body: ql::Expression::Or(get_child_exprs), - }); - - classes.push(ql::TopLevel::Class(main_class)); - } - } - } - - classes -} diff --git a/ql/node-types/Cargo.toml b/ql/node-types/Cargo.toml deleted file mode 100644 index 181bd6481e9..00000000000 --- a/ql/node-types/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "node-types" -version = "0.1.0" -authors = ["GitHub"] -edition = "2018" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" diff --git a/ql/node-types/src/lib.rs b/ql/node-types/src/lib.rs deleted file mode 100644 index 9467e23fd62..00000000000 --- a/ql/node-types/src/lib.rs +++ /dev/null @@ -1,449 +0,0 @@ -use serde::Deserialize; -use std::collections::BTreeMap; -use std::path::Path; - -use std::collections::BTreeSet as Set; -use std::fs; - -/// A lookup table from TypeName to Entry. 
-pub type NodeTypeMap = BTreeMap<TypeName, Entry>;
-
-#[derive(Debug)]
-pub struct Entry {
-    pub dbscheme_name: String,
-    pub ql_class_name: String,
-    pub kind: EntryKind,
-}
-
-#[derive(Debug)]
-pub enum EntryKind {
-    Union { members: Set<TypeName> },
-    Table { name: String, fields: Vec<Field> },
-    Token { kind_id: usize },
-}
-
-#[derive(Debug, Ord, PartialOrd, Eq, PartialEq)]
-pub struct TypeName {
-    pub kind: String,
-    pub named: bool,
-}
-
-#[derive(Debug)]
-pub enum FieldTypeInfo {
-    /// The field has a single type.
-    Single(TypeName),
-
-    /// The field can take one of several types, so we also provide the name of
-    /// the database union type that wraps them, and the corresponding QL class
-    /// name.
-    Multiple {
-        types: Set<TypeName>,
-        dbscheme_union: String,
-        ql_class: String,
-    },
-
-    /// The field can be one of several tokens, so the db type will be an `int`
-    /// with a `case @foo.kind` for each possibility.
-    ReservedWordInt(BTreeMap<String, (usize, String)>),
-}
-
-#[derive(Debug)]
-pub struct Field {
-    pub parent: TypeName,
-    pub type_info: FieldTypeInfo,
-    /// The name of the field or None for the anonymous 'children'
-    /// entry from node_types.json
-    pub name: Option<String>,
-    /// The name of the predicate to get this field.
-    pub getter_name: String,
-    pub storage: Storage,
-}
-
-fn name_for_field_or_child(name: &Option<String>) -> String {
-    match name {
-        Some(name) => name.clone(),
-        None => "child".to_owned(),
-    }
-}
-
-#[derive(Debug)]
-pub enum Storage {
-    /// the field is stored as a column in the parent table
-    Column { name: String },
-    /// the field is stored in a link table
-    Table {
-        /// the name of the table
-        name: String,
-        /// the name of the column for the field in the dbscheme
-        column_name: String,
-        /// does it have an associated index column?
-        has_index: bool,
-    },
-}
-
-impl Storage {
-    pub fn is_column(&self) -> bool {
-        match self {
-            Storage::Column { .. } => true,
-            _ => false,
-        }
-    }
-}
-pub fn read_node_types(prefix: &str, node_types_path: &Path) -> std::io::Result<NodeTypeMap> {
-    let file = fs::File::open(node_types_path)?;
-    let node_types: Vec<NodeInfo> = serde_json::from_reader(file)?;
-    Ok(convert_nodes(prefix, &node_types))
-}
-
-pub fn read_node_types_str(prefix: &str, node_types_json: &str) -> std::io::Result<NodeTypeMap> {
-    let node_types: Vec<NodeInfo> = serde_json::from_str(node_types_json)?;
-    Ok(convert_nodes(prefix, &node_types))
-}
-
-fn convert_type(node_type: &NodeType) -> TypeName {
-    TypeName {
-        kind: node_type.kind.to_string(),
-        named: node_type.named,
-    }
-}
-
-fn convert_types(node_types: &[NodeType]) -> Set<TypeName> {
-    node_types.iter().map(convert_type).collect()
-}
-
-pub fn convert_nodes(prefix: &str, nodes: &[NodeInfo]) -> NodeTypeMap {
-    let mut entries = NodeTypeMap::new();
-    let mut token_kinds = Set::new();
-
-    // First, find all the token kinds
-    for node in nodes {
-        if node.subtypes.is_none()
-            && node.fields.as_ref().map_or(0, |x| x.len()) == 0
-            && node.children.is_none()
-        {
-            let type_name = TypeName {
-                kind: node.kind.clone(),
-                named: node.named,
-            };
-            token_kinds.insert(type_name);
-        }
-    }
-
-    for node in nodes {
-        let flattened_name = &node_type_name(&node.kind, node.named);
-        let dbscheme_name = escape_name(flattened_name);
-        let ql_class_name = dbscheme_name_to_class_name(&dbscheme_name);
-        let dbscheme_name = format!("{}_{}", prefix, &dbscheme_name);
-        if let Some(subtypes) = &node.subtypes {
-            // It's a tree-sitter supertype node, for which we create a union
-            // type.
-            entries.insert(
-                TypeName {
-                    kind: node.kind.clone(),
-                    named: node.named,
-                },
-                Entry {
-                    dbscheme_name,
-                    ql_class_name,
-                    kind: EntryKind::Union {
-                        members: convert_types(subtypes),
-                    },
-                },
-            );
-        } else if node.fields.as_ref().map_or(0, |x| x.len()) == 0 && node.children.is_none() {
-            // Token kind, handled above.
-        } else {
-            // It's a product type, defined by a table.
-            let type_name = TypeName {
-                kind: node.kind.clone(),
-                named: node.named,
-            };
-            let table_name = escape_name(&(format!("{}_def", &flattened_name)));
-            let table_name = format!("{}_{}", prefix, &table_name);
-
-            let mut fields = Vec::new();
-
-            // If the type also has fields or children, then we create either
-            // auxiliary tables or columns in the defining table for them.
-            if let Some(node_fields) = &node.fields {
-                for (field_name, field_info) in node_fields {
-                    add_field(
-                        prefix,
-                        &type_name,
-                        Some(field_name.to_string()),
-                        field_info,
-                        &mut fields,
-                        &token_kinds,
-                    );
-                }
-            }
-            if let Some(children) = &node.children {
-                // Treat children as if they were a field called 'child'.
-                add_field(
-                    prefix,
-                    &type_name,
-                    None,
-                    children,
-                    &mut fields,
-                    &token_kinds,
-                );
-            }
-            entries.insert(
-                type_name,
-                Entry {
-                    dbscheme_name,
-                    ql_class_name,
-                    kind: EntryKind::Table {
-                        name: table_name,
-                        fields,
-                    },
-                },
-            );
-        }
-    }
-    let mut counter = 0;
-    for type_name in token_kinds {
-        let entry = if type_name.named {
-            counter += 1;
-            let unprefixed_name = node_type_name(&type_name.kind, true);
-            Entry {
-                dbscheme_name: escape_name(&format!("{}_token_{}", &prefix, &unprefixed_name)),
-                ql_class_name: dbscheme_name_to_class_name(&escape_name(&unprefixed_name)),
-                kind: EntryKind::Token { kind_id: counter },
-            }
-        } else {
-            Entry {
-                dbscheme_name: format!("{}_reserved_word", &prefix),
-                ql_class_name: "ReservedWord".to_owned(),
-                kind: EntryKind::Token { kind_id: 0 },
-            }
-        };
-        entries.insert(type_name, entry);
-    }
-    entries
-}
-
-fn add_field(
-    prefix: &str,
-    parent_type_name: &TypeName,
-    field_name: Option<String>,
-    field_info: &FieldInfo,
-    fields: &mut Vec<Field>,
-    token_kinds: &Set<TypeName>,
-) {
-    let parent_flattened_name = node_type_name(&parent_type_name.kind, parent_type_name.named);
-    let column_name = escape_name(&name_for_field_or_child(&field_name));
-    let storage = if !field_info.multiple && field_info.required {
-        // This field must appear exactly once, so we add it as
-        // a column to the main table for the node type.
-        Storage::Column { name: column_name }
-    } else {
-        // Put the field in an auxiliary table.
-        let has_index = field_info.multiple;
-        let field_table_name = escape_name(&format!(
-            "{}_{}_{}",
-            &prefix,
-            parent_flattened_name,
-            &name_for_field_or_child(&field_name)
-        ));
-        Storage::Table {
-            has_index,
-            name: field_table_name,
-            column_name,
-        }
-    };
-    let converted_types = convert_types(&field_info.types);
-    let type_info = if storage.is_column()
-        && field_info
-            .types
-            .iter()
-            .all(|t| !t.named && token_kinds.contains(&convert_type(t)))
-    {
-        // All possible types for this field are reserved words. The db
-        // representation will be an `int` with a `case @foo.field = ...` to
-        // enumerate the possible values.
-        let mut field_token_ints: BTreeMap<String, (usize, String)> = BTreeMap::new();
-        for (counter, t) in converted_types.into_iter().enumerate() {
-            let dbscheme_variant_name =
-                escape_name(&format!("{}_{}_{}", &prefix, parent_flattened_name, t.kind));
-            field_token_ints.insert(t.kind.to_owned(), (counter, dbscheme_variant_name));
-        }
-        FieldTypeInfo::ReservedWordInt(field_token_ints)
-    } else if field_info.types.len() == 1 {
-        FieldTypeInfo::Single(converted_types.into_iter().next().unwrap())
-    } else {
-        // The dbscheme type for this field will be a union. In QL, it'll just be AstNode.
-        FieldTypeInfo::Multiple {
-            types: converted_types,
-            dbscheme_union: format!(
-                "{}_{}_{}_type",
-                &prefix,
-                &parent_flattened_name,
-                &name_for_field_or_child(&field_name)
-            ),
-            ql_class: "AstNode".to_owned(),
-        }
-    };
-    let getter_name = format!(
-        "get{}",
-        dbscheme_name_to_class_name(&escape_name(&name_for_field_or_child(&field_name)))
-    );
-    fields.push(Field {
-        parent: TypeName {
-            kind: parent_type_name.kind.to_string(),
-            named: parent_type_name.named,
-        },
-        type_info,
-        name: field_name,
-        getter_name,
-        storage,
-    });
-}
-#[derive(Deserialize)]
-pub struct NodeInfo {
-    #[serde(rename = "type")]
-    pub kind: String,
-    pub named: bool,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub fields: Option<BTreeMap<String, FieldInfo>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub children: Option<FieldInfo>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub subtypes: Option<Vec<NodeType>>,
-}
-
-#[derive(Deserialize)]
-pub struct NodeType {
-    #[serde(rename = "type")]
-    pub kind: String,
-    pub named: bool,
-}
-
-#[derive(Deserialize)]
-pub struct FieldInfo {
-    pub multiple: bool,
-    pub required: bool,
-    pub types: Vec<NodeType>,
-}
-
-/// Given a tree-sitter node type's (kind, named) pair, returns a single string
-/// representing the (unescaped) name we'll use to refer to corresponding QL
-/// type.
-fn node_type_name(kind: &str, named: bool) -> String {
-    if named {
-        kind.to_string()
-    } else {
-        format!("{}_unnamed", kind)
-    }
-}
-
-const RESERVED_KEYWORDS: [&str; 14] = [
-    "boolean", "case", "date", "float", "int", "key", "of", "order", "ref", "string", "subtype",
-    "type", "unique", "varchar",
-];
-
-/// Returns a string that's a copy of `name` but suitably escaped to be a valid
-/// QL identifier.
-fn escape_name(name: &str) -> String {
-    let mut result = String::new();
-
-    // If there's a leading underscore, replace it with 'underscore_'.
- if let Some(c) = name.chars().next() { - if c == '_' { - result.push_str("underscore"); - } - } - for c in name.chars() { - match c { - '{' => result.push_str("lbrace"), - '}' => result.push_str("rbrace"), - '<' => result.push_str("langle"), - '>' => result.push_str("rangle"), - '[' => result.push_str("lbracket"), - ']' => result.push_str("rbracket"), - '(' => result.push_str("lparen"), - ')' => result.push_str("rparen"), - '|' => result.push_str("pipe"), - '=' => result.push_str("equal"), - '~' => result.push_str("tilde"), - '?' => result.push_str("question"), - '`' => result.push_str("backtick"), - '^' => result.push_str("caret"), - '!' => result.push_str("bang"), - '#' => result.push_str("hash"), - '%' => result.push_str("percent"), - '&' => result.push_str("ampersand"), - '.' => result.push_str("dot"), - ',' => result.push_str("comma"), - '/' => result.push_str("slash"), - ':' => result.push_str("colon"), - ';' => result.push_str("semicolon"), - '"' => result.push_str("dquote"), - '*' => result.push_str("star"), - '+' => result.push_str("plus"), - '-' => result.push_str("minus"), - '@' => result.push_str("at"), - _ if c.is_uppercase() => { - result.push('_'); - result.push_str(&c.to_lowercase().to_string()) - } - _ => result.push(c), - } - } - - for &keyword in &RESERVED_KEYWORDS { - if result == keyword { - result.push_str("__"); - break; - } - } - - result -} - -pub fn to_snake_case(word: &str) -> String { - let mut prev_upper = true; - let mut result = String::new(); - for c in word.chars() { - if c.is_uppercase() { - if !prev_upper { - result.push('_') - } - prev_upper = true; - result.push(c.to_ascii_lowercase()); - } else { - prev_upper = false; - result.push(c); - } - } - result -} -/// Given a valid dbscheme name (i.e. in snake case), produces the equivalent QL -/// name (i.e. in CamelCase). For example, "foo_bar_baz" becomes "FooBarBaz". 
-fn dbscheme_name_to_class_name(dbscheme_name: &str) -> String {
-    fn to_title_case(word: &str) -> String {
-        let mut first = true;
-        let mut result = String::new();
-        for c in word.chars() {
-            if first {
-                first = false;
-                result.push(c.to_ascii_uppercase());
-            } else {
-                result.push(c);
-            }
-        }
-        result
-    }
-    dbscheme_name
-        .split('_')
-        .map(to_title_case)
-        .collect::<Vec<String>>()
-        .join("")
-}
-
-#[test]
-fn to_snake_case_test() {
-    assert_eq!("python", to_snake_case("Python"));
-    assert_eq!("yaml", to_snake_case("YAML"));
-    assert_eq!("set_literal", to_snake_case("SetLiteral"));
-}