| # Copyright (C) 2022 The Android Open Source Project |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import unittest |
| |
| from perfetto.trace_uri_resolver.util import parse_trace_uri |
| from perfetto.trace_uri_resolver.util import to_list |
| from perfetto.trace_uri_resolver.util import _cs_list |
| from perfetto.trace_uri_resolver.util import and_list |
| from perfetto.trace_uri_resolver.util import or_list |
| from perfetto.trace_uri_resolver.resolver import _args_dict_from_uri |
| from perfetto.trace_uri_resolver.resolver import Constraint |
| from perfetto.trace_uri_resolver.resolver import ConstraintClass |
| from perfetto.trace_uri_resolver.resolver import TraceUriResolver |
| from perfetto.trace_uri_resolver.registry import ResolverRegistry |
| |
| |
class SimpleResolver(TraceUriResolver):
    """Test resolver for the ``simple`` prefix that yields two results.

    The first result streams the encoded ``foo`` argument. The second streams
    the encoded ``bar`` argument and carries both raw arguments as metadata.
    """
    PREFIX = 'simple'

    def __init__(self, foo=None, bar=None):
        self.foo = foo
        self.bar = bar

    def foo_gen(self):
        """Yield one chunk: ``foo`` encoded, or ``b''`` when it is unset/falsy."""
        yield (self.foo or '').encode()

    def bar_gen(self):
        """Yield one chunk: ``bar`` encoded, or ``b''`` when it is unset/falsy."""
        yield (self.bar or '').encode()

    def resolve(self):
        """Return ``[foo_result, bar_result]``; only the latter has metadata."""
        foo_result = TraceUriResolver.Result(self.foo_gen())
        bar_metadata = {'foo': self.foo, 'bar': self.bar}
        bar_result = TraceUriResolver.Result(
            self.bar_gen(), metadata=bar_metadata)
        return [foo_result, bar_result]
| |
| |
class RecursiveResolver(SimpleResolver):
    """Test resolver for the ``recursive`` prefix whose results nest further.

    It returns two direct results (the second with fixed ``'foo'``/``'bar'``
    metadata), followed by a ``simple:`` URI string and a ``SimpleResolver``
    instance built from the same arguments. Those last two entries must
    themselves be resolved by the registry.
    """
    PREFIX = 'recursive'

    def __init__(self, foo=None, bar=None):
        super().__init__(foo=foo, bar=bar)

    def resolve(self):
        """Return two direct results followed by two entries needing resolution."""
        fixed_metadata = {'foo': 'foo', 'bar': 'bar'}
        direct = [
            TraceUriResolver.Result(self.foo_gen()),
            TraceUriResolver.Result(self.bar_gen(), metadata=fixed_metadata),
        ]

        # Same arguments, expressed once as a URI and once as a resolver object.
        nested_uri = f'simple:foo={self.foo};bar={self.bar}'
        nested = [
            TraceUriResolver.Result(nested_uri),
            TraceUriResolver.Result(SimpleResolver(foo=self.foo, bar=self.bar)),
        ]
        return direct + nested
| |
| |
class TestResolver(unittest.TestCase):
    """Tests for trace URI resolution and the resolver string/arg helpers."""

    def test_simple_resolve(self):
        """A resolver works via URI string, instance, and lists of either."""
        registry = ResolverRegistry([SimpleResolver])

        res = registry.resolve('simple:foo=x;bar=y')
        self.assertEqual(len(res), 2)

        (foo_res, bar_res) = res
        self._check_resolver_result(foo_res, bar_res)

        (foo_res, bar_res) = registry.resolve(['simple:foo=x;bar=y'])
        self._check_resolver_result(foo_res, bar_res)

        resolver = SimpleResolver(foo='x', bar='y')

        (foo_res, bar_res) = registry.resolve(resolver)
        self._check_resolver_result(foo_res, bar_res)

        (foo_res, bar_res) = registry.resolve([resolver])
        self._check_resolver_result(foo_res, bar_res)

        # A mixed list yields each entry's results, in list order.
        (foo_a, bar_b, foo_x,
         bar_y) = registry.resolve(['simple:foo=a;bar=b', resolver])
        self._check_resolver_result(foo_a, bar_b, foo='a', bar='b')
        self._check_resolver_result(foo_x, bar_y)

    def test_simple_resolve_missing_arg(self):
        """Args omitted from the URI reach the resolver as None."""
        registry = ResolverRegistry([SimpleResolver])

        (foo_res, bar_res) = registry.resolve('simple:foo=x')
        self._check_resolver_result(foo_res, bar_res, bar=None)

        (foo_res, bar_res) = registry.resolve('simple:bar=y')
        self._check_resolver_result(foo_res, bar_res, foo=None)

        (foo_res, bar_res) = registry.resolve('simple:')
        self._check_resolver_result(foo_res, bar_res, foo=None, bar=None)

    def test_recursive_resolve(self):
        """Nested URI strings and resolver objects are resolved recursively."""
        registry = ResolverRegistry([SimpleResolver])
        registry.register(RecursiveResolver)

        res = registry.resolve('recursive:foo=x;bar=y')
        # 2 direct results + 2 from the nested URI + 2 from the nested object.
        self.assertEqual(len(res), 6)

        (non_rec_foo, non_rec_bar, rec_foo_str, rec_bar_str, rec_foo_obj,
         rec_bar_obj) = res

        self._check_resolver_result(
            non_rec_foo, non_rec_bar, foo_metadata='foo', bar_metadata='bar')
        self._check_resolver_result(rec_foo_str, rec_bar_str)
        self._check_resolver_result(rec_foo_obj, rec_bar_obj)

    def test_parse_trace_uri(self):
        """Paths starting with '/' or './' are never split on ':'."""
        self.assertEqual(parse_trace_uri('/foo/bar'), (None, '/foo/bar'))
        self.assertEqual(parse_trace_uri('foo/bar'), (None, 'foo/bar'))
        self.assertEqual(parse_trace_uri('/foo/b:ar'), (None, '/foo/b:ar'))
        self.assertEqual(parse_trace_uri('./foo/b:ar'), (None, './foo/b:ar'))
        # Any other string containing ':' is split into (prefix, rest).
        self.assertEqual(parse_trace_uri('foo/b:ar'), ('foo/b', 'ar'))

    def test_to_list(self):
        """Scalars are wrapped, lists pass through, None stays None."""
        self.assertIsNone(to_list(None))
        self.assertEqual(to_list(1), [1])
        self.assertEqual(to_list('1'), ['1'])
        self.assertEqual(to_list([]), [])
        self.assertEqual(to_list([1]), [1])

    def test_cs_list(self):
        """None means 'no constraint'; [] yields the caller's empty value."""
        fn = 'col = {}'.format
        sep = ' || '
        self.assertEqual(_cs_list(None, fn, 'FALSE', sep), 'TRUE')
        self.assertEqual(_cs_list(None, fn, 'TRUE', sep), 'TRUE')
        self.assertEqual(_cs_list([], fn, 'FALSE', sep), 'FALSE')
        self.assertEqual(_cs_list([], fn, 'TRUE', sep), 'TRUE')
        self.assertEqual(_cs_list([1], fn, 'FALSE', sep), '(col = 1)')
        self.assertEqual(_cs_list([1, 2], fn, 'FALSE', sep), '(col = 1 || col = 2)')

    def test_and_list(self):
        """and_list joins the formatted items with ' AND '."""
        fn = 'col != {}'.format
        self.assertEqual(and_list([1, 2], fn, 'FALSE'), '(col != 1 AND col != 2)')

    def test_or_list(self):
        """or_list joins the formatted items with ' OR '."""
        fn = 'col = {}'.format
        self.assertEqual(or_list([1, 2], fn, 'FALSE'), '(col = 1 OR col = 2)')

    def test_args_dict_from_uri(self):
        """URI args parse to a dict; comparison ops need Constraint types."""
        self.assertEqual(_args_dict_from_uri('foo:', {}), {})
        self.assertEqual(_args_dict_from_uri('foo:bar=baz', {}), {
            'bar': 'baz',
        })
        # Comma-separated values become lists.
        self.assertEqual(
            _args_dict_from_uri('foo:key=v1,v2', {}), {'key': ['v1', 'v2']})
        self.assertEqual(
            _args_dict_from_uri('foo:bar=baz;key=v1,v2', {}), {
                'bar': 'baz',
                'key': ['v1', 'v2']
            })

        # Malformed args: missing key, missing operator, or missing value.
        with self.assertRaises(ValueError):
            _args_dict_from_uri('foo:=v1', {})
        with self.assertRaises(ValueError):
            _args_dict_from_uri('foo:key', {})
        with self.assertRaises(ValueError):
            _args_dict_from_uri('foo:key<', {})
        # Non-'=' operators are rejected unless the arg is Constraint-typed.
        with self.assertRaises(ValueError):
            _args_dict_from_uri('foo:key<v1', {})
        with self.assertRaises(ValueError):
            _args_dict_from_uri('foo:key<v1', {'key': str})

        type_hints = {'key': Constraint[str]}
        self.assertEqual(
            _args_dict_from_uri('foo:key=v1', type_hints),
            {'key': ConstraintClass('v1', ConstraintClass.Op.EQ)})
        self.assertEqual(
            _args_dict_from_uri('foo:key!=v1', type_hints),
            {'key': ConstraintClass('v1', ConstraintClass.Op.NE)})
        self.assertEqual(
            _args_dict_from_uri('foo:key<=v1', type_hints),
            {'key': ConstraintClass('v1', ConstraintClass.Op.LE)})
        self.assertEqual(
            _args_dict_from_uri('foo:key>=v1', type_hints),
            {'key': ConstraintClass('v1', ConstraintClass.Op.GE)})
        self.assertEqual(
            _args_dict_from_uri('foo:key>v1', type_hints),
            {'key': ConstraintClass('v1', ConstraintClass.Op.GT)})
        self.assertEqual(
            _args_dict_from_uri('foo:key<v1', type_hints),
            {'key': ConstraintClass('v1', ConstraintClass.Op.LT)})

    def _check_resolver_result(self,
                               foo_res,
                               bar_res,
                               foo='x',
                               bar='y',
                               foo_metadata=None,
                               bar_metadata=None):
        """Assert a (foo, bar) result pair shaped like SimpleResolver's output.

        Args:
          foo_res: result whose generator should yield the encoded ``foo``.
          bar_res: result whose generator should yield the encoded ``bar`` and
            whose metadata should hold the 'foo' and 'bar' values.
          foo: expected ``foo`` arg; None or empty means a single ``b''`` chunk.
          bar: expected ``bar`` arg; None or empty means a single ``b''`` chunk.
          foo_metadata: expected 'foo' metadata value when it differs from
            ``foo`` (e.g. RecursiveResolver's fixed metadata); None means
            "same as ``foo``".
          bar_metadata: like ``foo_metadata``, for the 'bar' key.
        """
        self.assertEqual(
            tuple(foo_res.generator), (foo.encode() if foo else b'',))
        self.assertEqual(
            tuple(bar_res.generator), (bar.encode() if bar else b'',))
        # Compare against None explicitly so a falsy override (e.g. '') is
        # still honoured instead of silently falling back to foo/bar.
        expected_metadata = {
            'foo': foo if foo_metadata is None else foo_metadata,
            'bar': bar if bar_metadata is None else bar_metadata,
        }
        self.assertEqual(bar_res.metadata, expected_metadata)