Change yaml.load* methods to accept a Loader* instance instead of only Loader* classes. Existing PyYAML usage should not be affected.
diff --git a/lib/yaml/__init__.py b/lib/yaml/__init__.py index 8c71105..8fbf8ef 100644 --- a/lib/yaml/__init__.py +++ b/lib/yaml/__init__.py
@@ -30,7 +30,12 @@ """ Scan a YAML stream and produce scanning tokens. """ - loader = Loader(stream) + + if hasattr(Loader, '_yaml_instance'): + loader = Loader + else: + loader = Loader(stream) + try: while loader.check_token(): yield loader.get_token() @@ -41,7 +46,12 @@ """ Parse a YAML stream and produce parsing events. """ - loader = Loader(stream) + + if hasattr(Loader, '_yaml_instance'): + loader = Loader + else: + loader = Loader(stream) + try: while loader.check_event(): yield loader.get_event() @@ -53,7 +63,12 @@ Parse the first YAML document in a stream and produce the corresponding representation tree. """ - loader = Loader(stream) + + if hasattr(Loader, '_yaml_instance'): + loader = Loader + else: + loader = Loader(stream) + try: return loader.get_single_node() finally: @@ -64,7 +79,15 @@ Parse all YAML documents in a stream and produce corresponding representation trees. """ - loader = Loader(stream) + + if hasattr(Loader, '_yaml_instance'): + loader = Loader + else: + loader = Loader(stream) + + if hasattr(loader, '_load_init'): + loader._load_init(stream) + try: while loader.check_node(): yield loader.get_node() @@ -76,18 +99,27 @@ Parse the first YAML document in a stream and produce the corresponding Python object. """ - loader = Loader(stream) - try: - return loader.get_single_data() - finally: - loader.dispose() + + if hasattr(Loader, '_yaml_instance'): + loader = Loader + else: + loader = Loader(stream) + + return loader.load(stream) def load_all(stream, Loader): """ Parse all YAML documents in a stream and produce corresponding Python objects. """ - loader = Loader(stream) + if hasattr(Loader, '_yaml_instance'): + loader = Loader + else: + loader = Loader(stream) + + if hasattr(loader, '_load_init'): + loader._load_init(stream) + try: while loader.check_data(): yield loader.get_data()
diff --git a/lib/yaml/constructor.py b/lib/yaml/constructor.py index 619acd3..71f39a7 100644 --- a/lib/yaml/constructor.py +++ b/lib/yaml/constructor.py
@@ -22,6 +22,12 @@ yaml_multi_constructors = {} def __init__(self): + self._yaml_constructors = {} + self.add_constructor = self._add_constructor + self._yaml_multi_constructors = {} + self.add_multi_constructor = self._add_multi_constructor + self._add_constructors() + self.constructed_objects = {} self.recursive_objects = {} self.state_generators = [] @@ -76,9 +82,24 @@ self.recursive_objects[node] = None constructor = None tag_suffix = None - if node.tag in self.yaml_constructors: - constructor = self.yaml_constructors[node.tag] - else: + + if (hasattr(self, '_yaml_constructors') and + node.tag in self._yaml_constructors + ): + constructor = self._yaml_constructors[node.tag] + + if constructor is None: + if node.tag in self.yaml_constructors: + constructor = self.yaml_constructors[node.tag] + + if constructor is None and hasattr(self, '_yaml_constructors'): + for tag_prefix in self._yaml_multi_constructors: + if tag_prefix is not None and node.tag.startswith(tag_prefix): + tag_suffix = node.tag[len(tag_prefix):] + constructor = self._yaml_multi_constructors[tag_prefix] + break + + if constructor is None: for tag_prefix in self.yaml_multi_constructors: if tag_prefix is not None and node.tag.startswith(tag_prefix): tag_suffix = node.tag[len(tag_prefix):] @@ -157,19 +178,57 @@ return pairs @classmethod - def add_constructor(cls, tag, constructor): + def add_constructors(cls): if not 'yaml_constructors' in cls.__dict__: cls.yaml_constructors = cls.yaml_constructors.copy() + + for tag in cls.constructors: + if tag is not None and tag.endswith(':'): + cls.add_multi_constructor(tag, getattr(cls, cls.constructors[tag])) + else: + cls.add_constructor(tag, getattr(cls, cls.constructors[tag])) + + @classmethod + def add_constructor(cls, tag, constructor): cls.yaml_constructors[tag] = constructor @classmethod def add_multi_constructor(cls, tag_prefix, multi_constructor): - if not 'yaml_multi_constructors' in cls.__dict__: - cls.yaml_multi_constructors = 
cls.yaml_multi_constructors.copy() cls.yaml_multi_constructors[tag_prefix] = multi_constructor + def _add_constructors(self): + cls = type(self) + if hasattr(cls, 'constructors'): + for tag in cls.constructors: + if tag is not None and tag.endswith(':'): + self._add_multi_constructor(tag, getattr(cls, cls.constructors[tag])) + else: + self._add_constructor(tag, getattr(cls, cls.constructors[tag])) + + def _add_constructor(self, tag, constructor): + self._yaml_constructors[tag] = constructor + + def _add_multi_constructor(self, tag_prefix, multi_constructor): + self._yaml_multi_constructors[tag_prefix] = multi_constructor + class SafeConstructor(BaseConstructor): + constructors = { + 'tag:yaml.org,2002:null': 'construct_yaml_null', + 'tag:yaml.org,2002:bool': 'construct_yaml_bool', + 'tag:yaml.org,2002:int': 'construct_yaml_int', + 'tag:yaml.org,2002:float': 'construct_yaml_float', + 'tag:yaml.org,2002:binary': 'construct_yaml_binary', + 'tag:yaml.org,2002:timestamp': 'construct_yaml_timestamp', + 'tag:yaml.org,2002:omap': 'construct_yaml_omap', + 'tag:yaml.org,2002:pairs': 'construct_yaml_pairs', + 'tag:yaml.org,2002:set': 'construct_yaml_set', + 'tag:yaml.org,2002:str': 'construct_yaml_str', + 'tag:yaml.org,2002:seq': 'construct_yaml_seq', + 'tag:yaml.org,2002:map': 'construct_yaml_map', + None: 'construct_undefined', + } + def construct_scalar(self, node): if isinstance(node, MappingNode): for key_node, value_node in node.value: @@ -428,61 +487,29 @@ "could not determine a constructor for the tag %r" % node.tag, node.start_mark) -SafeConstructor.add_constructor( - 'tag:yaml.org,2002:null', - SafeConstructor.construct_yaml_null) - -SafeConstructor.add_constructor( - 'tag:yaml.org,2002:bool', - SafeConstructor.construct_yaml_bool) - -SafeConstructor.add_constructor( - 'tag:yaml.org,2002:int', - SafeConstructor.construct_yaml_int) - -SafeConstructor.add_constructor( - 'tag:yaml.org,2002:float', - SafeConstructor.construct_yaml_float) - 
-SafeConstructor.add_constructor( - 'tag:yaml.org,2002:binary', - SafeConstructor.construct_yaml_binary) - -SafeConstructor.add_constructor( - 'tag:yaml.org,2002:timestamp', - SafeConstructor.construct_yaml_timestamp) - -SafeConstructor.add_constructor( - 'tag:yaml.org,2002:omap', - SafeConstructor.construct_yaml_omap) - -SafeConstructor.add_constructor( - 'tag:yaml.org,2002:pairs', - SafeConstructor.construct_yaml_pairs) - -SafeConstructor.add_constructor( - 'tag:yaml.org,2002:set', - SafeConstructor.construct_yaml_set) - -SafeConstructor.add_constructor( - 'tag:yaml.org,2002:str', - SafeConstructor.construct_yaml_str) - -SafeConstructor.add_constructor( - 'tag:yaml.org,2002:seq', - SafeConstructor.construct_yaml_seq) - -SafeConstructor.add_constructor( - 'tag:yaml.org,2002:map', - SafeConstructor.construct_yaml_map) - -SafeConstructor.add_constructor(None, - SafeConstructor.construct_undefined) +SafeConstructor.add_constructors() class FullConstructor(SafeConstructor): # 'extend' is blacklisted because it is used by # construct_python_object_apply to add `listitems` to a newly generate # python instance + + constructors = { + 'tag:yaml.org,2002:python/none': 'construct_yaml_null', + 'tag:yaml.org,2002:python/bool': 'construct_yaml_bool', + 'tag:yaml.org,2002:python/str': 'construct_python_str', + 'tag:yaml.org,2002:python/unicode': 'construct_python_unicode', + 'tag:yaml.org,2002:python/bytes': 'construct_python_bytes', + 'tag:yaml.org,2002:python/int': 'construct_yaml_int', + 'tag:yaml.org,2002:python/long': 'construct_python_long', + 'tag:yaml.org,2002:python/float': 'construct_yaml_float', + 'tag:yaml.org,2002:python/complex': 'construct_python_complex', + 'tag:yaml.org,2002:python/list': 'construct_yaml_seq', + 'tag:yaml.org,2002:python/tuple': 'construct_python_tuple', + 'tag:yaml.org,2002:python/dict': 'construct_yaml_map', + 'tag:yaml.org,2002:python/name:': 'construct_python_name', + } + def get_state_keys_blacklist(self): return ['^extend$', '^__.*__$'] 
@@ -658,60 +685,17 @@ def construct_python_object_new(self, suffix, node): return self.construct_python_object_apply(suffix, node, newobj=True) -FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/none', - FullConstructor.construct_yaml_null) - -FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/bool', - FullConstructor.construct_yaml_bool) - -FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/str', - FullConstructor.construct_python_str) - -FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/unicode', - FullConstructor.construct_python_unicode) - -FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/bytes', - FullConstructor.construct_python_bytes) - -FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/int', - FullConstructor.construct_yaml_int) - -FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/long', - FullConstructor.construct_python_long) - -FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/float', - FullConstructor.construct_yaml_float) - -FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/complex', - FullConstructor.construct_python_complex) - -FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/list', - FullConstructor.construct_yaml_seq) - -FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/tuple', - FullConstructor.construct_python_tuple) - -FullConstructor.add_constructor( - 'tag:yaml.org,2002:python/dict', - FullConstructor.construct_yaml_map) - -FullConstructor.add_multi_constructor( - 'tag:yaml.org,2002:python/name:', - FullConstructor.construct_python_name) +FullConstructor.add_constructors() class UnsafeConstructor(FullConstructor): + constructors = { + 'tag:yaml.org,2002:python/module:': 'construct_python_module', + 'tag:yaml.org,2002:python/object:': 'construct_python_object', + 'tag:yaml.org,2002:python/object/new:': 'construct_python_object_new', + 'tag:yaml.org,2002:python/object/apply:': 'construct_python_object_apply', + } + def 
find_python_module(self, name, mark): return super(UnsafeConstructor, self).find_python_module(name, mark, unsafe=True) @@ -726,21 +710,7 @@ return super(UnsafeConstructor, self).set_python_instance_state( instance, state, unsafe=True) -UnsafeConstructor.add_multi_constructor( - 'tag:yaml.org,2002:python/module:', - UnsafeConstructor.construct_python_module) - -UnsafeConstructor.add_multi_constructor( - 'tag:yaml.org,2002:python/object:', - UnsafeConstructor.construct_python_object) - -UnsafeConstructor.add_multi_constructor( - 'tag:yaml.org,2002:python/object/new:', - UnsafeConstructor.construct_python_object_new) - -UnsafeConstructor.add_multi_constructor( - 'tag:yaml.org,2002:python/object/apply:', - UnsafeConstructor.construct_python_object_apply) +UnsafeConstructor.add_constructors() # Constructor is same as UnsafeConstructor. Need to leave this in place in case # people have extended it directly.
diff --git a/lib/yaml/loader.py b/lib/yaml/loader.py index e90c112..40ae352 100644 --- a/lib/yaml/loader.py +++ b/lib/yaml/loader.py
@@ -10,41 +10,52 @@ class BaseLoader(Reader, Scanner, Parser, Composer, BaseConstructor, BaseResolver): - def __init__(self, stream): + def _base_init(self, stream): + self._yaml_instance = True + Reader.__init__(self, stream) Scanner.__init__(self) Parser.__init__(self) Composer.__init__(self) + + def __init__(self, stream=None): + self._base_init(stream) BaseConstructor.__init__(self) BaseResolver.__init__(self) -class FullLoader(Reader, Scanner, Parser, Composer, FullConstructor, Resolver): + def load(self, stream): + self._load_init(stream) - def __init__(self, stream): - Reader.__init__(self, stream) - Scanner.__init__(self) - Parser.__init__(self) - Composer.__init__(self) + try: + return self.get_single_data() + finally: + self.dispose() + self.stream = None + + def _load_init(self, stream): + if hasattr(self, 'stream') and self.stream is None: + if stream is None: + raise TypeError("load() requires stream=...") + self._base_init(stream) + +class FullLoader(BaseLoader, Reader, Scanner, Parser, Composer, FullConstructor, Resolver): + + def __init__(self, stream=None): + self._base_init(stream) FullConstructor.__init__(self) Resolver.__init__(self) -class SafeLoader(Reader, Scanner, Parser, Composer, SafeConstructor, Resolver): +class SafeLoader(BaseLoader, Reader, Scanner, Parser, Composer, SafeConstructor, Resolver): - def __init__(self, stream): - Reader.__init__(self, stream) - Scanner.__init__(self) - Parser.__init__(self) - Composer.__init__(self) + def __init__(self, stream=None): + self._base_init(stream) SafeConstructor.__init__(self) Resolver.__init__(self) -class Loader(Reader, Scanner, Parser, Composer, Constructor, Resolver): +class Loader(BaseLoader, Reader, Scanner, Parser, Composer, Constructor, Resolver): - def __init__(self, stream): - Reader.__init__(self, stream) - Scanner.__init__(self) - Parser.__init__(self) - Composer.__init__(self) + def __init__(self, stream=None): + self._base_init(stream) Constructor.__init__(self) 
Resolver.__init__(self) @@ -52,12 +63,9 @@ # untrusted input). Use of either Loader or UnsafeLoader should be rare, since # FullLoad should be able to load almost all YAML safely. Loader is left intact # to ensure backwards compatibility. -class UnsafeLoader(Reader, Scanner, Parser, Composer, Constructor, Resolver): +class UnsafeLoader(BaseLoader, Reader, Scanner, Parser, Composer, Constructor, Resolver): - def __init__(self, stream): - Reader.__init__(self, stream) - Scanner.__init__(self) - Parser.__init__(self) - Composer.__init__(self) + def __init__(self, stream=None): + self._base_init(stream) Constructor.__init__(self) Resolver.__init__(self)
diff --git a/lib/yaml/reader.py b/lib/yaml/reader.py index 774b021..32ba51f 100644 --- a/lib/yaml/reader.py +++ b/lib/yaml/reader.py
@@ -56,7 +56,7 @@ # Yeah, it's ugly and slow. - def __init__(self, stream): + def __init__(self, stream=None): self.name = None self.stream = None self.stream_pointer = 0 @@ -69,7 +69,9 @@ self.index = 0 self.line = 0 self.column = 0 - if isinstance(stream, str): + if stream is None: + pass + elif isinstance(stream, str): self.name = "<unicode string>" self.check_printable(stream) self.buffer = stream+'\0'
diff --git a/tests/lib/canonical.py b/tests/lib/canonical.py index a8b4e3a..7d3ae12 100644 --- a/tests/lib/canonical.py +++ b/tests/lib/canonical.py
@@ -354,7 +354,7 @@ yaml.canonical_load = canonical_load -def canonical_load_all(stream): +def canonical_load_all(stream=None): return yaml.load_all(stream, Loader=CanonicalLoader) yaml.canonical_load_all = canonical_load_all