diff --git a/Lib/ldap/cidict.py b/Lib/ldap/cidict.py index 875d5f5..48aeacb 100644 --- a/Lib/ldap/cidict.py +++ b/Lib/ldap/cidict.py @@ -7,55 +7,68 @@ """ import warnings +from ldap.compat import MutableMapping from ldap import __version__ -from ldap.compat import IterableUserDict +class cidict(MutableMapping): + """ + Case-insensitive but case-respecting dictionary. + """ + __slots__ = ('_keys', '_data') -class cidict(IterableUserDict): - """ - Case-insensitive but case-respecting dictionary. - """ + def __init__(self, default=None): + self._keys = {} + self._data = {} + if default: + self.update(default) + + # MutableMapping abstract methods + + def __getitem__(self, key): + return self._data[key.lower()] - def __init__(self,default=None): - self._keys = {} - IterableUserDict.__init__(self,{}) - self.update(default or {}) + def __setitem__(self, key, value): + lower_key = key.lower() + self._keys[lower_key] = key + self._data[lower_key] = value - def __getitem__(self,key): - return self.data[key.lower()] + def __delitem__(self, key): + lower_key = key.lower() + del self._keys[lower_key] + del self._data[lower_key] - def __setitem__(self,key,value): - lower_key = key.lower() - self._keys[lower_key] = key - self.data[lower_key] = value + def __iter__(self): + return iter(self._keys.values()) - def __delitem__(self,key): - lower_key = key.lower() - del self._keys[lower_key] - del self.data[lower_key] + def __len__(self): + return len(self._keys) - def update(self,dict): - for key, value in dict.items(): - self[key] = value + # Specializations for performance - def has_key(self,key): - return key in self + def __contains__(self, key): + return key.lower() in self._keys - def __contains__(self,key): - return IterableUserDict.__contains__(self, key.lower()) + def clear(self): + self._keys.clear() + self._data.clear() - def __iter__(self): - return iter(self.keys()) + # Backwards compatibility - def keys(self): - return self._keys.values() + def has_key(self, key): + """Compatibility with python-ldap 2.x""" + return key in self - def items(self): - result = [] - for k in self._keys.values(): - result.append((k,self[k])) - return result + @property + def data(self): + """Compatibility with older IterableUserDict-based implementation""" + warnings.warn( + 'ldap.cidict.cidict.data is an internal attribute; it may be ' + + 'removed at any time', + category=DeprecationWarning, + stacklevel=2, + ) + return self._data def strlist_minus(a,b): diff --git a/Lib/ldap/compat.py b/Lib/ldap/compat.py index cbfeef5..901457b 100644 --- a/Lib/ldap/compat.py +++ b/Lib/ldap/compat.py @@ -10,6 +10,7 @@ from urllib import unquote as urllib_unquote from urllib import urlopen from urlparse import urlparse + from collections import MutableMapping def unquote(uri): """Specialized unquote that uses UTF-8 for parsing.""" @@ -33,6 +34,7 @@ def unquote(uri): IterableUserDict = UserDict from urllib.parse import quote, quote_plus, unquote, urlparse from urllib.request import urlopen + from collections.abc import MutableMapping def reraise(exc_type, exc_value, exc_traceback): """Re-raise an exception given information from sys.exc_info() diff --git a/Lib/ldapurl.py b/Lib/ldapurl.py index 6de0645..a3dd7ff 100644 --- a/Lib/ldapurl.py +++ b/Lib/ldapurl.py @@ -16,7 +16,7 @@ 'LDAPUrlExtension','LDAPUrlExtensions','LDAPUrl' ] -from ldap.compat import UserDict, quote, unquote +from ldap.compat import quote, unquote, MutableMapping LDAP_SCOPE_BASE = 0 LDAP_SCOPE_ONELEVEL = 1 @@ -130,58 +130,71 @@ def __ne__(self,other): return not self.__eq__(other) -class LDAPUrlExtensions(UserDict): - """ - Models a collection of LDAP URL extensions as - dictionary type - """ - - def __init__(self,default=None): - UserDict.__init__(self) - for k,v in (default or {}).items(): - self[k]=v - - def __setitem__(self,name,value): +class LDAPUrlExtensions(MutableMapping): """ - value - Either LDAPUrlExtension instance, (critical,exvalue) - or string'ed exvalue + Models a collection of LDAP URL extensions as + a mapping type """ - assert isinstance(value,LDAPUrlExtension) - assert name==value.extype - self.data[name] = value - - def values(self): - return [ - self[k] - for k in self.keys() - ] - - def __str__(self): - return ','.join(str(v) for v in self.values()) - - def __repr__(self): - return '<%s.%s instance at %s: %s>' % ( - self.__class__.__module__, - self.__class__.__name__, - hex(id(self)), - self.data - ) + __slots__ = ('_data', ) + + def __init__(self, default=None): + self._data = {} + if default is not None: + self.update(default) + + def __setitem__(self, name, value): + """Store an extension + + name + string + value + LDAPUrlExtension instance, whose extype nust match `name` + """ + if not isinstance(value, LDAPUrlExtension): + raise TypeError("value must be LDAPUrlExtension, not " + + type(value).__name__) + if name != value.extype: + raise ValueError( + "key {!r} does not match extension type {!r}".format( + name, value.extype)) + self._data[name] = value + + def __getitem__(self, name): + return self._data[name] + + def __delitem__(self, name): + del self._data[name] + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + def __str__(self): + return ','.join(str(v) for v in self.values()) + + def __repr__(self): + return '<%s.%s instance at %s: %s>' % ( + self.__class__.__module__, + self.__class__.__name__, + hex(id(self)), + self._data + ) - def __eq__(self,other): - assert isinstance(other,self.__class__),TypeError( - "other has to be instance of %s" % (self.__class__) - ) - return self.data==other.data + def __eq__(self,other): + if not isinstance(other, self.__class__): + return NotImplemented + return self._data == other._data - def parse(self,extListStr): - for extension_str in extListStr.strip().split(','): - if extension_str: - e = LDAPUrlExtension(extension_str) - self[e.extype] = e + def parse(self,extListStr): + for extension_str in extListStr.strip().split(','): + if extension_str: + e = LDAPUrlExtension(extension_str) + self[e.extype] = e - def unparse(self): - return ','.join([ v.unparse() for v in self.values() ]) + def unparse(self): + return ','.join(v.unparse() for v in self.values()) class LDAPUrl(object): @@ -366,17 +379,23 @@ def htmlHREF(self,urlPrefix='',hrefText=None,hrefTarget=None): hrefTarget string added as link target attribute """ - assert type(urlPrefix)==StringType, "urlPrefix must be StringType" + if not isinstance(urlPrefix, str): + raise TypeError("urlPrefix must be str, not " + + type(urlPrefix).__name__) if hrefText is None: - hrefText = self.unparse() - assert type(hrefText)==StringType, "hrefText must be StringType" + hrefText = self.unparse() + if not isinstance(hrefText, str): + raise TypeError("hrefText must be str, not " + + type(hrefText).__name__) if hrefTarget is None: - target = '' + target = '' else: - assert type(hrefTarget)==StringType, "hrefTarget must be StringType" - target = ' target="%s"' % hrefTarget + if not isinstance(hrefTarget, str): + raise TypeError("hrefTarget must be str, not " + + type(hrefTarget).__name__) + target = ' target="%s"' % hrefTarget return '%s' % ( - target,urlPrefix,self.unparse(),hrefText + target, urlPrefix, self.unparse(), hrefText ) def __str__(self): diff --git a/Tests/t_cidict.py b/Tests/t_cidict.py index b96a26e..6878617 100644 --- a/Tests/t_cidict.py +++ b/Tests/t_cidict.py @@ -62,6 +62,16 @@ def test_strlist_deprecated(self): strlist_func(["a"], ["b"]) self.assertEqual(len(w), 1) + def test_cidict_data(self): + """test the deprecated data atrtribute""" + d = ldap.cidict.cidict({'A': 1, 'B': 2}) + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + warnings.simplefilter('always', DeprecationWarning) + data = d.data + assert data == {'a': 1, 'b': 2} + self.assertEqual(len(w), 1) + if __name__ == '__main__': unittest.main()