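exp: Support importing SQL with multiple WITH statements

Split the imported SQL on ';' and turn each WITH/SELECT statement into its own set of graph nodes. Also recognize `INCLUDE PERFETTO MODULE` (in addition to `PERFETTO INCLUDE MODULE`), accept plain SELECT queries without a WITH clause, and rename nodes whose names collide with nodes from earlier statements.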
diff --git a/ui/src/plugins/dev.perfetto.ExplorePage/sql_json_handler.ts b/ui/src/plugins/dev.perfetto.ExplorePage/sql_json_handler.ts index 2a44c93..db50270 100644 --- a/ui/src/plugins/dev.perfetto.ExplorePage/sql_json_handler.ts +++ b/ui/src/plugins/dev.perfetto.ExplorePage/sql_json_handler.ts
@@ -24,7 +24,6 @@ SerializedNode, deserializeState, } from './json_handler'; -import {SqlSourceSerializedState} from './query_builder/nodes/sources/sql_source'; import m from 'mithril'; export function showImportWithStatementModal( @@ -83,14 +82,21 @@ modules: string; } -function parseSqlWithModules(sql: string): ParsedSql { +function parseSqlWithModules( + sql: string, + existingNodeNames: Set<string> = new Set(), +): ParsedSql { const modules: string[] = []; const lines = sql.split('\n'); let lastModuleLine = -1; for (let i = 0; i < lines.length; i++) { const line = lines[i].trim(); - if (line.toUpperCase().startsWith('PERFETTO INCLUDE MODULE')) { + const upperLine = line.toUpperCase(); + if ( + upperLine.startsWith('PERFETTO INCLUDE MODULE') || + upperLine.startsWith('INCLUDE PERFETTO MODULE') + ) { modules.push(line); lastModuleLine = i; } else if (line !== '' && !line.startsWith('--')) { @@ -99,7 +105,7 @@ } const sqlWithoutModules = lines.slice(lastModuleLine + 1).join('\n'); - const nodes = parseSql(sqlWithoutModules); + const nodes = parseSql(sqlWithoutModules, existingNodeNames); return { nodes, @@ -107,13 +113,33 @@ }; } -function parseSql(sql: string): ParsedNode[] { +function parseSql( + sql: string, + existingNodeNames: Set<string> = new Set(), +): ParsedNode[] { // TODO(mayzner): This whole logic is very fragile and should be replaced // with Trace Processor SQL parser, when it becomes available. 
const nodes: ParsedNode[] = []; const sqlUpperCase = sql.toUpperCase(); const withIndex = sqlUpperCase.indexOf('WITH'); if (withIndex === -1) { + // This is not a WITH query, just a SELECT + const finalQueryWithVars = sql; + const finalDependencies: string[] = []; + + let finalNodeName = 'output'; + let counter = 1; + const currentNames = new Set(existingNodeNames); + while (currentNames.has(finalNodeName)) { + finalNodeName = `output_${counter}`; + counter++; + } + + nodes.push({ + name: finalNodeName, + query: finalQueryWithVars, + dependencies: finalDependencies, + }); return nodes; } @@ -198,8 +224,11 @@ let finalNodeName = 'output'; let counter = 1; - const existingNames = new Set(nodes.map((n) => n.name)); - while (existingNames.has(finalNodeName)) { + const currentNames = new Set([ + ...Array.from(existingNodeNames), + ...nodes.map((n) => n.name), + ]); + while (currentNames.has(finalNodeName)) { finalNodeName = `output_${counter}`; counter++; } @@ -214,14 +243,74 @@ } export function createGraphFromSql(sql: string): string { - const {nodes: parsedNodes, modules} = parseSqlWithModules(sql); + const statements = sql + .split(';') + .map((s) => s.trim()) + .filter((s) => s.length > 0); + const allParsedNodes: ParsedNode[] = []; + const existingNodeNames = new Set<string>(); + let pendingModules: string[] = []; + + for (const statement of statements) { + const upperStmt = statement.toUpperCase(); + if ( + upperStmt.startsWith('INCLUDE PERFETTO MODULE') || + upperStmt.startsWith('PERFETTO INCLUDE MODULE') + ) { + pendingModules.push(statement + ';'); + } else if (upperStmt.includes('WITH') || upperStmt.includes('SELECT')) { + const queryWithModules = [...pendingModules, statement].join('\n'); + pendingModules = []; + + const {nodes: parsedNodes, modules: modulesFromQuery} = + parseSqlWithModules(queryWithModules, existingNodeNames); + + if (modulesFromQuery && parsedNodes.length > 0) { + const firstRootNode = parsedNodes.find( + (p) => p.dependencies.length === 
0, + ); + if (firstRootNode) { + firstRootNode.query = `${modulesFromQuery}\n${firstRootNode.query}`; + } + } + + // Deduplicate node names + const nameMapping = new Map<string, string>(); + for (const node of parsedNodes) { + let newName = node.name; + let counter = 1; + while (existingNodeNames.has(newName)) { + newName = `${node.name}_${counter}`; + counter++; + } + if (newName !== node.name) { + nameMapping.set(node.name, newName); + } + node.name = newName; + existingNodeNames.add(newName); + } + + // Update dependencies with new names + for (const node of parsedNodes) { + node.dependencies = node.dependencies.map( + (dep) => nameMapping.get(dep) || dep, + ); + for (const [oldName, newName] of nameMapping.entries()) { + const regex = new RegExp(`\\$${oldName}\\b`, 'g'); + node.query = node.query.replace(regex, `$${newName}`); + } + } + + allParsedNodes.push(...parsedNodes); + } + } + const serializedNodes: SerializedNode[] = []; const nodeLayouts: {[key: string]: NodeBoxLayout} = {}; const rootNodeIds: string[] = []; - const nodeMap = new Map<string, SerializedNode>(); - for (const parsedNode of parsedNodes) { + for (const parsedNode of allParsedNodes) { const nodeId = parsedNode.name; const node: SerializedNode = { nodeId, @@ -238,7 +327,7 @@ nodeMap.set(nodeId, node); } - for (const parsedNode of parsedNodes) { + for (const parsedNode of allParsedNodes) { const node = nodeMap.get(parsedNode.name)!; for (const dep of parsedNode.dependencies) { const depNode = nodeMap.get(dep)!; @@ -250,14 +339,6 @@ } } - if (modules && rootNodeIds.length > 0) { - const firstRootNode = nodeMap.get(rootNodeIds[0])!; - if (firstRootNode.type === NodeType.kSqlSource) { - const state = firstRootNode.state as SqlSourceSerializedState; - state.sql = `${modules}\n${state.sql}`; - } - } - const serializedGraph: SerializedGraph = { nodes: serializedNodes, rootNodeIds,
diff --git a/ui/src/plugins/dev.perfetto.ExplorePage/sql_json_handler_unittest.ts b/ui/src/plugins/dev.perfetto.ExplorePage/sql_json_handler_unittest.ts index c86ce29..7cc6229 100644 --- a/ui/src/plugins/dev.perfetto.ExplorePage/sql_json_handler_unittest.ts +++ b/ui/src/plugins/dev.perfetto.ExplorePage/sql_json_handler_unittest.ts
@@ -494,4 +494,86 @@ expect(nodeOutput!.prevNodes).toEqual(['d']); }); + + it('should handle multiple WITH statements', () => { + const sql = ` + WITH a AS (SELECT 1) + SELECT * FROM a; + + WITH b AS (SELECT 2) + SELECT * FROM b; + `; + const graphJson = createGraphFromSql(sql); + const graph: SerializedGraph = JSON.parse(graphJson); + + expect(graph.nodes.length).toBe(4); + expect(graph.rootNodeIds).toEqual(['a', 'b']); + + const nodeA = graph.nodes.find((n) => n.nodeId === 'a'); + const nodeB = graph.nodes.find((n) => n.nodeId === 'b'); + const outputA = graph.nodes.find((n) => n.nodeId === 'output'); + const outputB = graph.nodes.find((n) => n.nodeId === 'output_1'); + + expect(nodeA).toBeDefined(); + expect(nodeB).toBeDefined(); + expect(outputA).toBeDefined(); + expect(outputB).toBeDefined(); + }); + + it('should handle includes not at the start of the query', () => { + const sql = ` + WITH a AS (SELECT 1) + SELECT * FROM a; + + INCLUDE PERFETTO MODULE android.slices; + + WITH b AS (SELECT 2) + SELECT * FROM b; + `; + const graphJson = createGraphFromSql(sql); + const graph: SerializedGraph = JSON.parse(graphJson); + + expect(graph.nodes.length).toBe(4); + expect(graph.rootNodeIds).toEqual(['a', 'b']); + + const nodeA = graph.nodes.find((n) => n.nodeId === 'a'); + const nodeB = graph.nodes.find((n) => n.nodeId === 'b'); + + expect(nodeA).toBeDefined(); + expect(nodeB).toBeDefined(); + + const stateB = nodeB!.state as SqlSourceSerializedState; + expect(stateB.sql).toContain('INCLUDE PERFETTO MODULE android.slices'); + }); + + it('should handle duplicated node names in multiple WITH statements', () => { + const sql = ` + WITH a AS (SELECT 1) + SELECT * FROM a; + + WITH a AS (SELECT 2) + SELECT * FROM a; + `; + const graphJson = createGraphFromSql(sql); + const graph: SerializedGraph = JSON.parse(graphJson); + + expect(graph.nodes.length).toBe(4); + + const nodeA1 = graph.nodes.find((n) => n.nodeId === 'a'); + const nodeA2 = graph.nodes.find((n) => n.nodeId === 
'a_1'); + const output1 = graph.nodes.find((n) => n.nodeId === 'output'); + const output2 = graph.nodes.find((n) => n.nodeId === 'output_1'); + + expect(nodeA1).toBeDefined(); + expect(nodeA2).toBeDefined(); + expect(output1).toBeDefined(); + expect(output2).toBeDefined(); + + expect(nodeA1!.nextNodes).toEqual(['output']); + expect(nodeA2!.nextNodes).toEqual(['output_1']); + + expect(output2!.prevNodes).toEqual(['a_1']); + const state = output2!.state as SqlSourceSerializedState; + expect(state.sql).toBe('SELECT * FROM $a_1'); + }); });