[stdlib]: Add an interval flatten macro

This macro flattens any rooted hierarchy of intervals into a new set of
non-overlapping intervals, each attributed to the leaf (deepest) interval
active at that time, e.g. it flattens a slice stack into per-slice self time.

Update the slice flattening module to use this.

Test: tools/diff_test_trace_processor.py out/android/trace_processor_shell --name-filter '.*intervals_flatten.*'
Change-Id: If32cc214e219c4a4516552cc5618a928cd0e6be1
diff --git a/src/trace_processor/perfetto_sql/stdlib/intervals/overlap.sql b/src/trace_processor/perfetto_sql/stdlib/intervals/overlap.sql
index 3db21d4..b3f950a 100644
--- a/src/trace_processor/perfetto_sql/stdlib/intervals/overlap.sql
+++ b/src/trace_processor/perfetto_sql/stdlib/intervals/overlap.sql
@@ -90,4 +90,114 @@
 )
 SELECT count() AS has_overlaps
 FROM filtered
-);
\ No newline at end of file
+);
+
+-- Partitions and flattens a hierarchy of intervals into non-overlapping intervals, where each
+-- resulting interval is attributed to the deepest (leaf) interval active at that time. Summing
+-- the resulting dur per id therefore gives the 'self-time' of each input interval.
+--
+-- Returns rows of (root_id, id, parent_id, ts, dur). Rows whose dur is not > 0 are dropped, and
+-- parts of a root not covered by any child are returned with id = root_id and parent_id = NULL.
+-- Precondition: child intervals must not extend beyond any of their ancestor intervals.
+CREATE PERFETTO MACRO _intervals_flatten(
+  -- Table or subquery containing all the root intervals: (id, ts, dur). parent_id is not
+  -- needed here because a root has no parent; only id, ts and dur are read.
+  roots_table TableOrSubquery,
+  -- Table or subquery containing all the non-root intervals: (root_id, id, parent_id, ts, dur).
+  children_table TableOrSubquery)
+RETURNS TableOrSubquery
+  AS (
+    -- Algorithm: within each root, collect every child's start and end timestamp as an event
+    -- and sort them. Each gap between consecutive events becomes one output interval, owned
+    -- by the child itself after a start event and by the child's parent after an end event.
+    -- The root's uncovered leading/trailing edges and roots without children are added
+    -- separately. NOTE(review): events sharing a ts have no tie-breaker; their order is arbitrary.
+  WITH
+    _roots AS (
+      SELECT id, ts, dur FROM ($roots_table) WHERE dur > 0
+    ),
+    _children AS (
+      SELECT root_id, id, parent_id, ts, dur FROM ($children_table) WHERE dur > 0
+    ),
+    _roots_without_children AS (
+      SELECT id FROM _roots
+      EXCEPT
+      SELECT root_id AS id FROM _children  -- same key as the join below; EXCEPT de-duplicates
+    ),
+    _children_with_root_ts_and_dur AS (
+      SELECT
+        _roots.id AS root_id,
+        _roots.ts AS root_ts,
+        _roots.dur AS root_dur,
+        _children.id,
+        _children.parent_id,
+        _children.ts,
+        _children.dur
+      FROM _children
+      JOIN _roots ON _roots.id = _children.root_id
+    ),
+    _ends AS (
+      SELECT
+        child.root_id,
+        child.root_ts,
+        child.root_dur,
+        IFNULL(parent.id, child.root_id) AS id,  -- no parent row means the parent is the root
+        parent.parent_id,
+        child.ts + child.dur AS ts
+      FROM _children_with_root_ts_and_dur child
+      LEFT JOIN _children_with_root_ts_and_dur parent
+        ON child.parent_id = parent.id
+    ),
+    _events AS (
+      SELECT root_id, root_ts, root_dur, id, parent_id, ts FROM _children_with_root_ts_and_dur
+      UNION ALL
+      SELECT root_id, root_ts, root_dur, id, parent_id, ts FROM _ends
+    ),
+    _intervals AS (
+      SELECT
+        root_id,
+        root_ts,
+        root_dur,
+        id,
+        parent_id,
+        ts,
+        LEAD(ts)
+          OVER (PARTITION BY root_id ORDER BY ts) - ts AS dur  -- NULL for the last event of a root
+      FROM _events
+    ),
+    _only_middle AS (
+      SELECT * FROM _intervals WHERE dur > 0
+    ),
+    _only_start AS (
+      SELECT
+        root_id,
+        root_id AS id,  -- the gap before the first child always belongs to the root itself
+        NULL AS parent_id,
+        root_ts AS ts,
+        MIN(ts) - root_ts AS dur
+      FROM _only_middle
+      GROUP BY root_id
+    ),
+    _only_end AS (
+      SELECT
+        root_id,
+        root_id AS id,  -- the gap after the last child always belongs to the root itself
+        NULL AS parent_id,
+        MAX(ts + dur) AS ts,
+        root_ts + root_dur - MAX(ts + dur) AS dur
+      FROM _only_middle
+      GROUP BY root_id
+    ),
+    _only_singleton AS (
+      SELECT id AS root_id, id, NULL AS parent_id, ts, dur
+      FROM _roots
+      JOIN _roots_without_children USING (id)
+    )
+  SELECT root_id, id, parent_id, ts, dur FROM _only_middle
+  UNION ALL
+  SELECT root_id, id, parent_id, ts, dur FROM _only_start WHERE dur > 0
+  UNION ALL
+  SELECT root_id, id, parent_id, ts, dur FROM _only_end WHERE dur > 0
+  UNION ALL
+  SELECT root_id, id, parent_id, ts, dur FROM _only_singleton
+);
diff --git a/src/trace_processor/perfetto_sql/stdlib/slices/flat_slices.sql b/src/trace_processor/perfetto_sql/stdlib/slices/flat_slices.sql
index 4353ed2..a02e014 100644
--- a/src/trace_processor/perfetto_sql/stdlib/slices/flat_slices.sql
+++ b/src/trace_processor/perfetto_sql/stdlib/slices/flat_slices.sql
@@ -12,6 +12,8 @@
 -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 -- See the License for the specific language governing permissions and
 -- limitations under the License.
+INCLUDE PERFETTO MODULE slices.with_context;
+INCLUDE PERFETTO MODULE intervals.overlap;
 
 -- The concept of a "flat slice" is to take the data in the slice table and
 -- remove all notion of nesting; we do this by projecting every slice in a stack to
@@ -51,55 +53,37 @@
 -- @column upid               Alias for `process.upid`.
 -- @column pid                Alias for `process.pid`.
 -- @column process_name       Alias for `process.name`.
-CREATE TABLE _slice_flattened AS
--- The algorithm proceeds as follows:
--- 1. Find the start and end timestamps of all slices.
--- 2. Iterate the generated timestamps within a stack in chronoligical order.
--- 3. Generate a slice for each timestamp pair (regardless of if it was a start or end)  .
--- 4. If the first timestamp in the pair was originally a start, the slice is the 'current' slice,
--- otherwise, the slice is the parent slice.
+CREATE TABLE _slice_flattened
+AS  -- Flattens every slice stack into leaf (self-time) intervals via _intervals_flatten.
 WITH
-  begins AS (
-    SELECT id AS slice_id, ts, name, track_id, depth
+  root_slices AS (
+    SELECT * FROM slice WHERE parent_id IS NULL  -- top of each stack: slices with no parent
+  ),
+  child_slices AS (
+    SELECT anc.id AS root_id, slice.*
     FROM slice
-    WHERE dur > 0
+    JOIN ancestor_slice(slice.id) anc
+    WHERE slice.parent_id IS NOT NULL AND anc.parent_id IS NULL  -- keep only the root ancestor
   ),
-  ends AS (
-    SELECT
-      parent.id AS slice_id,
-      current.ts + current.dur AS ts,
-      parent.name as name,
-      current.track_id,
-      current.depth - 1 AS depth
-    FROM slice current
-    LEFT JOIN slice parent
-      ON current.parent_id = parent.id
-    WHERE current.dur > 0
-  ),
-  events AS (
-    SELECT * FROM begins
-    UNION ALL
-    SELECT * FROM ends
-  ),
-  data AS (
-    SELECT
-      events.slice_id,
-      events.ts,
-      LEAD(events.ts) OVER (
-         PARTITION BY events.track_id
-         ORDER BY events.ts) - events.ts AS dur,
-      events.depth,
-      events.name,
-      events.track_id
-    FROM events
+  flat_slices AS (
+    SELECT id, ts, dur FROM _intervals_flatten !(root_slices, child_slices)
   )
-SELECT data.slice_id, data.ts, data.dur, data.depth,
- data.name, data.track_id, thread.utid, thread.tid, thread.name as thread_name,
- process.upid, process.pid, process.name as process_name
- FROM data JOIN thread_track ON data.track_id = thread_track.id
-JOIN thread USING(utid)
-JOIN process USING(upid)
-WHERE depth != -1;
+SELECT
+  id AS slice_id,
+  flat_slices.ts,  -- ts/dur of the flattened piece, not of the whole slice
+  flat_slices.dur,
+  depth,
+  name,
+  track_id,
+  utid,
+  tid,
+  thread_name,
+  upid,
+  pid,
+  process_name
+FROM flat_slices
+JOIN thread_slice  -- supplies depth, name and the thread/process columns for each slice id
+  USING (id);
 
 CREATE
   INDEX _slice_flattened_id_idx
diff --git a/test/trace_processor/diff_tests/stdlib/intervals/tests.py b/test/trace_processor/diff_tests/stdlib/intervals/tests.py
index c00eca8..6b3fc5d 100644
--- a/test/trace_processor/diff_tests/stdlib/intervals/tests.py
+++ b/test/trace_processor/diff_tests/stdlib/intervals/tests.py
@@ -81,4 +81,35 @@
         "has_overlaps"
         0
         1
-        """))
\ No newline at end of file
+        """))
+
+  def test_intervals_flatten(self):  # _intervals_flatten over inline roots/children data
+    return DiffTestBlueprint(
+        trace=TextProto(""),  # empty trace: every interval comes from the VALUES in the query
+        query="""
+        INCLUDE PERFETTO MODULE intervals.overlap;
+
+        WITH roots_data (id, ts, dur) AS (
+          VALUES
+            (0, 0, 7),
+            (1, 8, 1)
+        ), children_data (root_id, id, parent_id, ts, dur) AS (
+          VALUES
+            (0, 2, 0, 1, 3),
+            (0, 3, 0, 5, 1),
+            (0, 4, 2, 2, 1)
+        )
+        SELECT ts, dur, id, parent_id, root_id
+        FROM _intervals_flatten!(roots_data, children_data) ORDER BY ts
+        """,
+        out=Csv("""
+        "ts","dur","id","parent_id","root_id"
+        0,1,0,"[NULL]",0
+        1,1,2,0,0
+        2,1,4,2,0
+        3,1,2,0,0
+        4,1,0,"[NULL]",0
+        5,1,3,0,0
+        6,1,0,"[NULL]",0
+        8,1,1,"[NULL]",1
+        """))  # root 0 is split around children 2, 3 and grandchild 4; childless root 1 is kept whole
diff --git a/test/trace_processor/diff_tests/stdlib/slices/tests.py b/test/trace_processor/diff_tests/stdlib/slices/tests.py
index 747ffb2..796fcbd 100644
--- a/test/trace_processor/diff_tests/stdlib/slices/tests.py
+++ b/test/trace_processor/diff_tests/stdlib/slices/tests.py
@@ -79,20 +79,21 @@
           JOIN thread_track ON e.track_id = thread_track.id
           JOIN thread USING(utid)
         WHERE thread.tid = 30196
+        ORDER BY ts
         LIMIT 10;
       """,
         out=Csv("""
         "name","ts","dur","depth"
         "EventForwarder::OnTouchEvent",1035865509936036,211000,0
-        "EventForwarder::OnTouchEvent",1035865510234036,48000,0
-        "EventForwarder::OnTouchEvent",1035865510673036,10000,0
         "GestureProvider::OnTouchEvent",1035865510147036,87000,1
+        "EventForwarder::OnTouchEvent",1035865510234036,48000,0
         "RenderWidgetHostImpl::ForwardTouchEvent",1035865510282036,41000,1
-        "RenderWidgetHostImpl::ForwardTouchEvent",1035865510331036,16000,1
-        "RenderWidgetHostImpl::ForwardTouchEvent",1035865510670036,3000,1
         "LatencyInfo.Flow",1035865510323036,8000,2
+        "RenderWidgetHostImpl::ForwardTouchEvent",1035865510331036,16000,1
         "PassthroughTouchEventQueue::QueueEvent",1035865510347036,30000,2
-        "PassthroughTouchEventQueue::QueueEvent",1035865510666036,4000,2
+        "InputRouterImpl::FilterAndSendWebInputEvent",1035865510377036,8000,3
+        "LatencyInfo.Flow",1035865510385036,126000,4
+        "RenderWidgetHostImpl::UserInputStarted",1035865510511036,7000,5
       """))
 
   def test_thread_slice_cpu_time(self):
@@ -117,4 +118,4 @@
         7,33333
         8,46926
         9,17865
-        """))
\ No newline at end of file
+        """))