tp: Refactor SQL module presubmit check.

Bug:255535171
Change-Id: I5da63402788143e8596a7730438afb4385bcde69
diff --git a/tools/check_sql_modules.py b/tools/check_sql_modules.py
index 33b1ed2..7537977 100755
--- a/tools/check_sql_modules.py
+++ b/tools/check_sql_modules.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# This tool checks that every create (table|view) without prefix
+# This tool checks that every SQL object created without prefix
 # 'internal_' is documented with proper schema.
 
 from __future__ import absolute_import
@@ -24,264 +24,480 @@
 import re
 import sys
 from sql_modules_utils import *
+from typing import Union, List
 
 ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 
 
-# Check that CREATE VIEW/TABLE has a matching schema before it.
-def check_create_table_view(path, module, sql):
-  errors = 0
-  obj_name, schema_cols, schema_desc = None, False, False
-  lines = sql.split('\n')
-  for i, line in enumerate(lines):
-    create_line = re.match(create_table_view_pattern(), line)
+# Stores documentation for CREATE {TABLE|VIEW} with comment split into
+# segments.
+class TableViewDocs:
 
-    # Ignore all lines that don't create an object
-    if create_line is None:
-      continue
+  def __init__(self, name: str, desc: List[str], columns: List[str], path: str):
+    self.name = name
+    self.desc = desc
+    self.columns = columns
+    self.path = path
 
-    obj_name = create_line.group(2)
+  # Constructs new TableViewDocs from whole comment, by splitting it on typed
+  # lines. Returns (None, errors) for internal or malformed objects.
+  @staticmethod
+  def create_from_comment(path: str, comment: List[str], module: str,
+                          matches: tuple) -> tuple["TableViewDocs", List[str]]:
+    obj_type, name = matches[:2]
 
-    # Ignore 'internal_' tables|views
-    if re.match(r'^internal_.*', obj_name):
-      continue
+    # Ignore internal tables and views.
+    if re.match(r"^internal_.*", name):
+      return None, []
 
-    # Check whether the name starts with module_name
-    starts_with_module_name = re.match(f'^{module}_.*', obj_name)
-    if module == 'common':
-      if starts_with_module_name:
-        sys.stderr.write(
-            f"Invalid name in module {obj_name}. "
-            f"In module 'common' the name shouldn't start with 'common_'.\n")
-        errors += 1
-    else:
-      if not starts_with_module_name:
-        sys.stderr.write(f"Invalid name in module {obj_name}. "
-                         f"View/table name has to begin with {module}_.\n")
-        sys.stderr.write(f'{path}:\n"{line}"\n')
-        errors += 1
+    errors = validate_name(name, module)
+    col_start = None
 
-    # Validate the schema before the create line.
-    for comment_line in fetch_comment(lines[i - 1::-1]):
+    # Splits code into segments by finding beginning of column segment.
+    for i, line in enumerate(comment):
       # Ignore only '--' line.
-      if comment_line == '--':
+      if line == "--":
         continue
 
-      # Break on SQL lines (lines with words without '--' at the beginning)
-      # and empty lines.
-      if not line or not comment_line.startswith('--'):
-        break
+      m = re.match(typed_comment_pattern(), line)
 
-      # Look for '-- @column' line as a column description
-      column_line = re.match(column_pattern(), comment_line)
-      if column_line is not None:
-        if not schema_desc:
-          sys.stderr.write(f"Columns needs to be defined after description.\n")
-          sys.stderr.write(f'{path}:\n"{comment_line}"\n')
-          errors += 1
-          continue
-
-        schema_cols = True
+      # Ignore untyped lines
+      if not m:
         continue
 
-      # The only  option left is a description, but it has to be after
-      # schema columns.
-      schema_desc = True
+      line_type = m.group(1)
+      if line_type == "column" and col_start is None:
+        col_start = i
+        continue
 
-    if not schema_cols:
-      sys.stderr.write((f"Missing documentation schema for {obj_name}\n"))
-      sys.stderr.write(f'{path}:\n"{line}"\n')
-      errors += 1
-    obj_name, schema_cols, schema_desc = None, False, False
+    if not col_start:  # None: no '@column' line; 0: nothing precedes it.
+      missing = "columns" if col_start is None else "description"
+      errors.append(f"Missing {missing} for {obj_type} '{name}' in {path}.\n")
+      return None, errors
 
-  return errors
+    return (
+        TableViewDocs(name, comment[:col_start], comment[col_start:], path),
+        errors,
+    )
+
+  def check_comment(self) -> List[str]:
+    errors = validate_desc(self)
+    errors += validate_columns(self)
+    return errors
 
 
-def parse_args(args_str):
-  errors = 0
-  args = {}
-  for arg_str in args_str.split(","):
-    m = re.match(arg_pattern(), arg_str)
-    if m is None:
-      sys.stderr.write(f"Wrong arguments formatting for '{arg_str}'\n")
-      errors += 1
-      continue
-    args[m.group(1)] = m.group(2)
-  return errors, args
+# Stores documentation for CREATE_FUNCTION with comment split into segments.
+class FunctionDocs:
+
+  def __init__(
+      self,
+      path: str,
+      data_from_sql: dict,
+      module: str,
+      name: str,
+      desc: List[str],
+      args: List[str],
+      ret: List[str],
+  ):
+    self.path = path
+    self.data_from_sql = data_from_sql
+    self.module = module
+    self.name = name
+    self.desc = desc
+    self.args = args
+    self.ret = ret
+
+  # Constructs new FunctionDocs from whole comment, by splitting it on typed
+  # lines. Returns (None, errors) for internal or malformed functions.
+  @staticmethod
+  def create_from_comment(path: str, comment: List[str], module: str,
+                          matches: tuple) -> tuple["FunctionDocs", List[str]]:
+    name, args, ret, sql = matches
+
+    # Ignore internal functions.
+    if re.match(r"^INTERNAL_.*", name):
+      return None, []
+
+    errors = validate_name(name, module, upper=True)
+    start_args, start_ret = None, None
+
+    # Splits code into segments by finding beginning of args and ret segments.
+    for i, line in enumerate(comment):
+      # Ignore only '--' line.
+      if line == "--":
+        continue
+
+      m = re.match(typed_comment_pattern(), line)
+
+      # Ignore untyped lines
+      if not m:
+        continue
+
+      line_type = m.group(1)
+      if line_type == "arg" and start_args is None:
+        start_args = i
+        continue
+
+      if line_type == "ret" and start_ret is None:
+        start_ret = i
+        continue
+
+    if not start_ret or not start_args:  # None: missing; 0: no description.
+      errors.append(f"Function requires description, 'arg' and 'ret' "
+                    f"comments.\n'{name}' in {path}.\n")
+      return None, errors
+
+    args_dict, parse_errors = parse_args(args)
+    data_from_sql = {'name': name, 'args': args_dict, 'ret': ret, 'sql': sql}
+    return (
+        FunctionDocs(
+            path,
+            data_from_sql,
+            module,
+            name,
+            comment[:start_args],
+            comment[start_args:start_ret],
+            comment[start_ret:],
+        ),
+        errors + parse_errors,
+    )
+
+  def check_comment(self) -> List[str]:
+    errors = validate_desc(self)
+    errors += validate_args(self)
+    errors += validate_ret(self)
+    return errors
 
 
-# Check that CREATE_FUNCTION has a matching schema before it.
-def match_create_functions(sql):
-  errors = 0
+# Stores documentation for CREATE_VIEW_FUNCTION with comment split into
+# segments.
+class ViewFunctionDocs:
 
-  line_to_match_dict = match_pattern(create_function_pattern(), sql)
-  if line_to_match_dict:
-    return []
+  def __init__(
+      self,
+      path: str,
+      data_from_sql: dict,
+      module: str,
+      name: str,
+      desc: List[str],
+      args: List[str],
+      columns: List[str],
+  ):
+    self.path = path
+    self.data_from_sql = data_from_sql
+    self.module = module
+    self.name = name
+    self.desc = desc
+    self.args = args
+    self.columns = columns
 
-  functions = {}
-  for line_id, match_groups in line_to_match_dict.items():
-    name = match_groups[0]
-    if re.match(r'^INTERNAL_.*', name):
-      continue
+  # Constructs ViewFunctionDocs from the whole comment by splitting it on typed
+  # lines. Returns (None, errors) for internal or malformed functions.
+  @staticmethod
+  def create_from_comment(path: str, comment: List[str], module: str,
+                          matches: tuple
+                         ) -> tuple["ViewFunctionDocs", List[str]]:
+    name, args, columns, sql = matches
 
-    parse_errors, args = parse_args(match_groups[1])
+    # Ignore internal functions.
+    if re.match(r"^INTERNAL_.*", name):
+      return None, []
+
+    errors = validate_name(name, module, upper=True)
+    start_args, start_cols = None, None
+
+    # Splits code into segments by finding beginning of args and cols segments.
+    for i, line in enumerate(comment):
+      # Ignore only '--' line.
+      if line == "--":
+        continue
+
+      m = re.match(typed_comment_pattern(), line)
+
+      # Ignore untyped lines
+      if not m:
+        continue
+
+      line_type = m.group(1)
+      if line_type == "arg" and start_args is None:
+        start_args = i
+        continue
+
+      if line_type == "column" and start_cols is None:
+        start_cols = i
+        continue
+
+    if not start_cols or not start_args:  # None: missing; 0: no description.
+      errors.append(f"Function requires description, 'arg' and 'column' "
+                    f"comments.\n'{name}' in {path}.\n")
+      return None, errors
+
+    args_dict, parse_errors = parse_args(args)
     errors += parse_errors
-    functions[line_id] = dict(
-        name=name, args=args, ret_type=match_groups[2], sql=match_groups[3])
 
-  return dict(sorted(functions.items()))
+    cols_dict, parse_errors = parse_args(columns)
+    errors += parse_errors
+
+    data_from_sql = dict(name=name, args=args_dict, columns=cols_dict, sql=sql)
+    return (
+        ViewFunctionDocs(
+            path,
+            data_from_sql,
+            module,
+            name,
+            comment[:start_args],
+            comment[start_args:start_cols],
+            comment[start_cols:],
+        ),
+        errors,
+    )
+
+  def check_comment(self) -> List[str]:
+    errors = validate_desc(self)
+    errors += validate_args(self)
+    errors += validate_columns(self, use_data_from_sql=True)
+    return errors
 
 
-def check_function_docs(path, rev_comment, fun_data):
-  errors = 0
-  has_ret, has_args, has_desc = False, False, False
+# Checks the '<module>_' name prefix (forbidden in 'common'); returns errors.
+def validate_name(name: str, module: str, upper: bool = False) -> List[str]:
+  prefix = f"{module}_"
+  if upper:
+    prefix = prefix.upper()
+  starts_with_module_name = name.startswith(prefix)
+  if module == "common":
+    if starts_with_module_name:
+      return [(f"Invalid name in module {name}. "
+               f"In module 'common' the name shouldn't "
+               f"start with '{prefix}'.\n")]
+  else:
+    if not starts_with_module_name:
+      return [(f"Invalid name in module {name}. "
+               f"Name has to begin with '{prefix}'.\n")]
+  return []
 
-  for line in rev_comment:
-    # Break if the comment is finished
-    if not line or not line.startswith('--'):
-      break
 
-    # Ignore empty lines
+# Returns an error for the first typed line (per typed_comment_pattern()) in
+# `comment_segment` whose type differs from `comment_type`; [] otherwise.
+def validate_typed_comment(
+    comment_segment: List[str],
+    comment_type: str,
+    docs: Union["TableViewDocs", "FunctionDocs", "ViewFunctionDocs"],
+) -> List[str]:
+  for line in comment_segment:
+    # Ignore only '--' line.
     if line == "--":
       continue
 
-    if line.startswith('-- @ret'):
-      if has_ret:
-        sys.stderr.write(f"Function can only return one element: '{line}'\n")
-        sys.stderr.write(f'{path}:\n"{line}"\n')
-        errors += 1
-        continue
+    m = re.match(typed_comment_pattern(), line)
 
-      m = re.match(function_return_pattern(), line)
-      if m is None:
-        sys.stderr.write("The return docs formatting is wrong. It should be:\n"
-                         "-- @ret [A-Z]* {desc}\n")
-        sys.stderr.write(f'{path}:\n"{line}"\n')
-        errors += 1
-        continue
-
-      if fun_data['ret_type'] != m.group(1):
-        sys.stderr.write(
-            f"The code specifies {fun_data['ret_type']} as return type, "
-            f"but its {m.group(1)} in docs.\n")
-        sys.stderr.write(f'{path}:\n"{line}"\n')
-        errors += 1
-        continue
-
-      has_ret = True
+    # Ignore untyped lines
+    if not m:
       continue
 
-    if line.startswith('-- @arg'):
-      if not has_ret:
-        sys.stderr.write(
-            f"Arguments should be specified before return: '{line}'\n")
-        sys.stderr.write(f'{path}:\n"{line}"\n')
-        errors += 1
-        continue
+    line_type = m.group(1)
 
-      m = re.match(args_pattern(), line)
-      if m is None:
-        sys.stderr.write("The arg docs formatting is wrong. It should be:\n"
-                         "-- @arg [a-z_]* [A-Z]* {desc}\n")
-        sys.stderr.write(f'{path}:\n"{line}"\n')
-        errors += 1
-        continue
+    if line_type != comment_type:
+      return [(
+          f"Wrong comment type. Expected '{comment_type}', got '{line_type}'.\n"
+          f"'{docs.name}' in {docs.path}:\n'{line}'\n")]
+  return []
 
-      arg_name, arg_type = m.group(1), m.group(2)
-      if arg_name not in fun_data['args']:
-        sys.stderr.write(
-            f"There is not argument '{arg_name} specified in code.\n")
-        sys.stderr.write(f'{path}:\n"{line}"\n')
-        errors += 1
-        continue
 
-      if arg_type != fun_data['args'][arg_name]:
-        sys.stderr.write(
-            f"In the code, the type of '{arg_name} is "
-            f"{fun_data['args'][arg_name]}, but according to the docs "
-            f"it is '{arg_type}.\n")
-        sys.stderr.write(f'{path}:\n"{line}"\n')
-        errors += 1
-        continue
+# Returns [] if the description segment has any line besides a bare '--',
+# otherwise a single "missing documentation" error.
+def validate_desc(
+    docs: Union["TableViewDocs", "FunctionDocs", "ViewFunctionDocs"]
+) -> List[str]:
+  # The segment may be empty, so don't rely on a loop variable afterwards.
+  if any(line != "--" for line in docs.desc):
+    return []
+  return [(f"Missing documentation for {docs.name}\n"
+           f"'{docs.name}' in {docs.path}.\n")]
 
-      has_args = True
+
+# Whether comment segment about columns contains proper schema. Can be matched
+# against parsed SQL data by setting `use_data_from_sql`.
+def validate_columns(docs: Union["TableViewDocs", "ViewFunctionDocs"],
+                     use_data_from_sql=False) -> List[str]:
+  errors = validate_typed_comment(docs.columns, "column", docs)
+
+  if errors:
+    return errors
+
+  if use_data_from_sql:
+    cols_from_sql = dict(docs.data_from_sql["columns"])  # Copy: popped below.
+
+  for line in docs.columns:
+    # Skip '--' and lines that are not '-- @column' descriptions.
+    if line == "--" or not line.startswith("-- @column"):
       continue
+
-    if has_args:
-      has_desc = True
+    # Look for '-- @column' line as a column description
+    m = re.match(column_pattern(), line)
+    if not m:
+      errors.append(f"Wrong column description.\n"
+                    f"'{docs.name}' in {docs.path}:\n'{line}'\n")
+      continue
+
+    if not use_data_from_sql:
       return errors
 
-  if not has_ret:
-    sys.stderr.write(f"Return value was not specified in the documentation "
-                     f"of function '{fun_data['name']}'.\n")
-    sys.stderr.write(f'{path}')
-    errors += 1
+    col_name = m.group(1)
+    if col_name not in cols_from_sql:
+      errors.append(f"There is no column '{col_name}' specified in code.\n"
+                    f"'{docs.name}' in {docs.path}:\n'{line}'\n")
+      continue
+
+    cols_from_sql.pop(col_name)
+
+  if not use_data_from_sql:
+    errors.append(f"Missing columns for {docs.name}\n{docs.path}\n")
     return errors
 
-  if not has_args:
-    sys.stderr.write(f"Arguments were not specified in the documentation "
-                     f"of function '{fun_data['name']}'.\n")
-    sys.stderr.write(f'{path}')
-    errors += 1
+  if not cols_from_sql:
     return errors
 
-  if not has_desc:
-    sys.stderr.write(f"Missing description of function '{fun_data['name']}'.\n")
-    sys.stderr.write(f'{path}')
-    errors += 1
+  errors.append(
+      f"Missing documentation of columns: {list(cols_from_sql.keys())}.\n"
+      f"'{docs.name}' in {docs.path}:\n")
+  return errors
+
+
+# Whether comment segment about arguments contains proper schema. Matches
+# against parsed SQL data.
+def validate_args(docs: "FunctionDocs") -> List[str]:
+  errors = validate_typed_comment(docs.args, "arg", docs)
+
+  if errors:
     return errors
 
+  args_from_sql = dict(docs.data_from_sql["args"])  # Copy: popped below.
+  for line in docs.args:
+    # Skip '--' and lines that are not typed '-- @...' comments.
+    if line == "--" or not line.startswith("-- @"):
+      continue
 
-def check_create_functions(path, module, sql):
-  errors = 0
-  matched_create_functions = match_create_functions(sql)
+    m = re.match(args_pattern(), line)
+    if m is None:
+      errors.append("The arg docs formatting is wrong. It should be:\n"
+                    "-- @arg [a-z_]* [A-Z]* {desc}\n"
+                    f"'{docs.name}' in {docs.path}:\n'{line}'\n")
+      return errors
 
-  if not bool(matched_create_functions):
+    arg_name, arg_type = m.group(1), m.group(2)
+    if arg_name not in args_from_sql:
+      errors.append(f"There is no argument '{arg_name}' specified in code.\n"
+                    f"'{docs.name}' in {docs.path}:\n'{line}'\n")
+      continue
+
+    arg_type_from_sql = args_from_sql.pop(arg_name)
+    if arg_type != arg_type_from_sql:
+      errors.append(f"In the code, the type of '{arg_name}' is "
+                    f"'{arg_type_from_sql}', but according to the docs "
+                    f"it is '{arg_type}'.\n"
+                    f"'{docs.name}' in {docs.path}:\n'{line}'\n")
+
+  if not args_from_sql:
     return errors
 
-  lines = sql.split('\n')
-  for line_id, fun_data in matched_create_functions.items():
-    starts_with_module = fun_data['name'].startswith('{module}_'.upper())
-    if module == 'common' and starts_with_module:
-      sys.stderr.write(
-          f"For module 'common', function name shouldn't start with "
-          f"'COMMON_', as in {fun_data['name']}'.\n")
-      sys.stderr.write(f'{path}')
-      errors += 1
-    if module != 'common' and not starts_with_module:
-      sys.stderr.write(f"Function name ({fun_data['name']}) "
-                       f"should start with '{module.upper()}_'\n")
-      sys.stderr.write(f'{path}')
-      errors += 1
-    errors += check_function_docs(path, lines[line_id - 1::-1], fun_data)
+  errors.append(
+      f"Missing documentation of args: {list(args_from_sql.keys())}.\n"
+      f"'{docs.name}' in {docs.path}:\n")
+  return errors
+
+
+# Whether comment segment about return contains proper schema. Matches against
+# parsed SQL data. Always returns a list: [] when every '-- @ret' line is valid.
+def validate_ret(docs: "FunctionDocs") -> List[str]:
+  errors = validate_typed_comment(docs.ret, "ret", docs)
+  if errors:
+    return errors
+
+  ret_type_from_sql = docs.data_from_sql["ret"]
+
+  for line in docs.ret:
+    # Skip '--' and lines that are not '-- @ret' descriptions.
+    if line == "--" or not line.startswith("-- @ret"):
+      continue
+
+    m = re.match(function_return_pattern(), line)
+    if m is None:
+      return [("The return docs formatting is wrong. It should be:\n"
+               "-- @ret [A-Z]* {desc}\n"
+               f"'{docs.name}' in {docs.path}:\n'{line}'\n")]
+    docs_ret_type = m.group(1)
+    if ret_type_from_sql != docs_ret_type:
+      return [(f"The return type in docs is '{docs_ret_type}', "
+               f"but it is {ret_type_from_sql} in code.\n"
+               f"'{docs.name}' in {docs.path}:\n'{line}'\n")]
+  return []
+
+
+# Parses a comma-separated argument list into a dict; returns (dict, errors).
+def parse_args(args_str: str) -> tuple[dict, List[str]]:
+  errors = []
+  args = {}
+  for arg_str in args_str.split(","):
+    m = re.match(arg_str_pattern(), arg_str)
+    if m is None:
+      errors.append(f"Wrong arguments formatting for '{arg_str}'\n")
+      continue
+    args[m.group(1)] = m.group(2)  # presumably name -> SQL type; see arg_str_pattern()
+  return args, errors
+
+
+# Validates docs for every `pattern` match in `sql`; returns error messages.
+def validate_docs_for_sql_object_type(path: str, module: str, sql: str,
+                                      pattern: str, docs_object: type):
+  errors = []
+  line_id_to_match = match_pattern(pattern, sql)
+  lines = sql.split("\n")
+  for line_id, matches in line_id_to_match.items():
+    # Lines above the match, nearest first (line_id assumed 0-based). A match
+    # on the first line must give [] instead of wrapping to the file's end.
+    comment = fetch_comment(lines[:line_id][::-1])
+    docs, obj_errors = docs_object.create_from_comment(path, comment, module,
+                                                       matches)
+    errors += obj_errors
+    if docs:
+      errors += docs.check_comment()
 
   return errors
 
 
-def check(path):
-  errors = 0
+def check(path: str) -> List[str]:
+  errors = []  # Error message strings, one per problem found.
 
   # Get module name
-  module_name = path.split('/stdlib/')[-1].split('/')[0]
+  module_name = path.split("/stdlib/")[-1].split("/")[0]
 
   with open(path) as f:
     sql = f.read()
 
-  errors += check_create_table_view(path, module_name, sql)
-  errors += check_create_functions(path, module_name, sql)
+  errors += validate_docs_for_sql_object_type(path, module_name, sql,
+                                              create_table_view_pattern(),
+                                              TableViewDocs)
+  errors += validate_docs_for_sql_object_type(path, module_name, sql,
+                                              create_function_pattern(),
+                                              FunctionDocs)
+  errors += validate_docs_for_sql_object_type(path, module_name, sql,
+                                              create_view_function_pattern(),
+                                              ViewFunctionDocs)
   return errors
 
 
 def main():
-  errors = 0
-  metrics_sources = os.path.join(ROOT_DIR, 'src', 'trace_processor', 'stdlib')
+  errors = []
+  metrics_sources = os.path.join(ROOT_DIR, "src", "trace_processor", "stdlib")
   for root, _, files in os.walk(metrics_sources, topdown=True):
     for f in files:
       path = os.path.join(root, f)
-      if path.endswith('.sql'):
+      if path.endswith(".sql"):
         errors += check(path)
-  return 0 if errors == 0 else 1
+  sys.stderr.write("\n\n".join(errors))  # Each error already ends in '\n'.
+  return 0 if not errors else 1  # Non-zero exit fails the presubmit.
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
   sys.exit(main())