Allow custom prefixes for PerfettoSQL modules.

Allow functions, tables, views and macros in chrome/util to start with
cr for brevity.

R=lalitm@google.com
CC=rasikan@google.com,mayzner@google.com

Change-Id: I5d8f715977f3594b56c4e911a3224c90382e298e
diff --git a/python/generators/sql_processing/docs_parse.py b/python/generators/sql_processing/docs_parse.py
index 981316d..564679a 100644
--- a/python/generators/sql_processing/docs_parse.py
+++ b/python/generators/sql_processing/docs_parse.py
@@ -20,7 +20,7 @@
 from typing import Any, Dict, List, Optional, Set, Tuple, NamedTuple
 
 from python.generators.sql_processing.docs_extractor import DocsExtractor
-from python.generators.sql_processing.utils import ANY_PATTERN, ARG_DEFINITION_PATTERN, ObjKind
+from python.generators.sql_processing.utils import ALLOWED_PREFIXES, ANY_PATTERN, ARG_DEFINITION_PATTERN, ObjKind
 from python.generators.sql_processing.utils import ARG_ANNOTATION_PATTERN
 from python.generators.sql_processing.utils import NAME_AND_TYPE_PATTERN
 from python.generators.sql_processing.utils import FUNCTION_RETURN_PATTERN
@@ -49,6 +49,28 @@
   description: str
 
 
+# Returns: error message if the name is not correct, None otherwise.
+def get_module_prefix_error(name: str, path: str, module: str) -> Optional[str]:
+  prefix = name.lower().split('_')[0]
+  if module == "common" or module == "prelude":
+    if prefix == module:
+      return (f'Names of tables/views/functions in the "{module}" module '
+              f'should not start with {module}')
+    return None
+  if prefix == module:
+    # Module prefix is always allowed.
+    return None
+  allowed_prefixes = [module]
+  for (path_prefix, allowed_name_prefix) in ALLOWED_PREFIXES.items():
+    if path.startswith(path_prefix):
+      if prefix == allowed_name_prefix:
+        return None
+      allowed_prefixes.append(allowed_name_prefix)
+  return (
+      f'Names of tables/views/functions at path "{path}" should be prefixed '
+      f'with one of following names: {", ".join(allowed_prefixes)}')
+
+
 class AbstractDocParser(ABC):
 
   @dataclass
@@ -64,19 +86,10 @@
   def _parse_name(self, upper: bool = False):
     assert self.name
     assert isinstance(self.name, str)
-    module_pattern = f"^{self.module}_.*"
-    if upper:
-      module_pattern = module_pattern.upper()
-    starts_with_module_name = re.match(module_pattern, self.name, re.IGNORECASE)
-    if self.module == "common" or self.module == "prelude":
-      if starts_with_module_name:
-        self._error(
-            'Names of tables/views/functions in the "{self.module}" module '
-            f'should not start with {module_pattern}')
-      return self.name
-    if not starts_with_module_name:
-      self._error('Names of tables/views/functions should be prefixed with the '
-                  f'module name (i.e. should start with {module_pattern})')
+    module_prefix_error = get_module_prefix_error(self.name, self.path,
+                                                  self.module)
+    if module_prefix_error is not None:
+      self._error(module_prefix_error)
     return self.name.strip()
 
   def _parse_desc_not_empty(self, desc: str):
diff --git a/python/generators/sql_processing/utils.py b/python/generators/sql_processing/utils.py
index 0a05a71..c7133b3 100644
--- a/python/generators/sql_processing/utils.py
+++ b/python/generators/sql_processing/utils.py
@@ -105,6 +105,9 @@
     ObjKind.table_function: CREATE_TABLE_FUNCTION_PATTERN,
 }
 
+ALLOWED_PREFIXES = {
+    'chrome/util': 'cr',
+}
 
 # Given a regex pattern and a string to match against, returns all the
 # matching positions. Specifically, it returns a dictionary from the line
diff --git a/python/test/stdlib_unittest.py b/python/test/stdlib_unittest.py
index 3fb6713..f7963fb 100644
--- a/python/test/stdlib_unittest.py
+++ b/python/test/stdlib_unittest.py
@@ -112,6 +112,79 @@
     # Expecting an error: function prefix (bar) not matching module name (foo).
     self.assertEqual(len(res.errors), 1)
 
+  # Checks that custom prefixes (cr for chrome/util) are allowed.
+  def test_custom_module_prefix(self):
+    res = parse_file(
+        'chrome/util/test.sql', f'''
+-- Comment
+CREATE PERFETTO TABLE cr_table(
+    -- Column.
+    x INT
+) AS
+SELECT 1;
+    '''.strip())
+    self.assertListEqual(res.errors, [])
+
+    fn = res.table_views[0]
+    self.assertEqual(fn.name, 'cr_table')
+    self.assertEqual(fn.desc, 'Comment')
+    self.assertEqual(fn.cols, {
+        'x': Arg('INT', 'Column.'),
+    })
+
+  # Checks that when custom prefixes (cr for chrome/util) are present,
+  # the full module name (chrome) is still accepted.
+  def test_custom_module_prefix_full_module_name(self):
+    res = parse_file(
+        'chrome/util/test.sql', f'''
+-- Comment
+CREATE PERFETTO TABLE chrome_table(
+    -- Column.
+    x INT
+) AS
+SELECT 1;
+    '''.strip())
+    self.assertListEqual(res.errors, [])
+
+    fn = res.table_views[0]
+    self.assertEqual(fn.name, 'chrome_table')
+    self.assertEqual(fn.desc, 'Comment')
+    self.assertEqual(fn.cols, {
+        'x': Arg('INT', 'Column.'),
+    })
+
+  # Checks that when custom prefixes (cr for chrome/util) are present,
+  # the incorrect prefixes (foo) are not accepted.
+  def test_custom_module_prefix_incorrect(self):
+    res = parse_file(
+        'chrome/util/test.sql', f'''
+-- Comment
+CREATE PERFETTO TABLE foo_table(
+    -- Column.
+    x INT
+) AS
+SELECT 1;
+    '''.strip())
+    # Expecting an error: table prefix (foo) is not allowed for a given path
+    # (allowed: chrome, cr).
+    self.assertEqual(len(res.errors), 1)
+
+  # Checks that when custom prefixes (cr for chrome/util) are present,
+  # they do not apply outside of the path scope.
+  def test_custom_module_prefix_does_not_apply_outside(self):
+    res = parse_file(
+        'foo/bar.sql', f'''
+-- Comment
+CREATE PERFETTO TABLE cr_table(
+    -- Column.
+    x INT
+) AS
+SELECT 1;
+    '''.strip())
+    # Expecting an error: table prefix (foo) is not allowed for a given path
+    # (allowed: foo).
+    self.assertEqual(len(res.errors), 1)
+
   def test_common_does_not_include_module_name(self):
     res = parse_file(
         'common/bar.sql', f'''