Skip to content

Commit

Permalink
py3: Make the bytes/text distinction
Browse files Browse the repository at this point in the history
- DNs, attribute names, URLs are text (encoded to UTF-8 on the wire)
- Attribute values are always bytes

A "bytes_mode" switch controls behavior under Python 2.
  • Loading branch information
pyldap contributors authored and Petr Viktorin committed Nov 24, 2017
1 parent 750fe8c commit fa35757
Show file tree
Hide file tree
Showing 23 changed files with 636 additions and 172 deletions.
3 changes: 3 additions & 0 deletions Lib/ldap/dn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
See https://www.python-ldap.org/ for details.
"""

import sys
from ldap.pkginfo import __version__

import _ldap
Expand Down Expand Up @@ -46,6 +47,8 @@ def str2dn(dn,flags=0):
"""
if not dn:
return []
if sys.version_info[0] < 3 and isinstance(dn, unicode):
dn = dn.encode('utf-8')
return ldap.functions._ldap_function_call(None,_ldap.str2dn,dn,flags)


Expand Down
12 changes: 8 additions & 4 deletions Lib/ldap/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _ldap_function_call(lock,func,*args,**kwargs):
return result


def initialize(uri,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None):
def initialize(uri,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None, bytes_mode=None):
"""
Return LDAPObject instance by opening LDAP connection to
LDAP host specified by LDAP URL
Expand All @@ -76,11 +76,13 @@ def initialize(uri,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None):
trace_file
File object where to write the trace output to.
Default is to use stdout.
bytes_mode
Whether to enable "bytes_mode" for backwards compatibility under Py2.
"""
return LDAPObject(uri,trace_level,trace_file,trace_stack_limit)
return LDAPObject(uri,trace_level,trace_file,trace_stack_limit,bytes_mode)


def open(host,port=389,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None):
def open(host,port=389,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None,bytes_mode=None):
"""
Return LDAPObject instance by opening LDAP connection to
specified LDAP host
Expand All @@ -95,10 +97,12 @@ def open(host,port=389,trace_level=0,trace_file=sys.stdout,trace_stack_limit=Non
trace_file
File object where to write the trace output to.
Default is to use stdout.
bytes_mode
Whether to enable "bytes_mode" for backwards compatibility under Py2.
"""
import warnings
warnings.warn('ldap.open() is deprecated! Use ldap.initialize() instead.', DeprecationWarning,2)
return initialize('ldap://%s:%d' % (host,port),trace_level,trace_file,trace_stack_limit)
return initialize('ldap://%s:%d' % (host,port),trace_level,trace_file,trace_stack_limit,bytes_mode)

init = open

Expand Down
214 changes: 210 additions & 4 deletions Lib/ldap/ldapobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
See https://www.python-ldap.org/ for details.
"""

from __future__ import unicode_literals

from os import strerror

from ldap.pkginfo import __version__, __author__, __license__
Expand All @@ -20,6 +22,7 @@
import traceback

import sys,time,pprint,_ldap,ldap,ldap.sasl,ldap.functions
import warnings

from ldap.schema import SCHEMA_ATTRS
from ldap.controls import LDAPControl,DecodeControlTuples,RequestControlTuples
Expand All @@ -28,6 +31,11 @@

from ldap import LDAPError

PY2 = bool(sys.version_info[0] <= 2)
if PY2:
text_type = unicode
else:
text_type = str

class NO_UNIQUE_ENTRY(ldap.NO_SUCH_OBJECT):
"""
Expand Down Expand Up @@ -55,7 +63,7 @@ class SimpleLDAPObject:

def __init__(
self,uri,
trace_level=0,trace_file=None,trace_stack_limit=5
trace_level=0,trace_file=None,trace_stack_limit=5,bytes_mode=None
):
self._trace_level = trace_level
self._trace_file = trace_file or sys.stdout
Expand All @@ -66,6 +74,186 @@ def __init__(
self.timeout = -1
self.protocol_version = ldap.VERSION3

# Bytes mode
# ----------

# By default, raise a TypeError when receiving invalid args
self.bytes_mode_hardfail = True
if bytes_mode is None and PY2:
warnings.warn(
"Under Python 2, python-ldap uses bytes by default. "
"This will be removed in Python 3 (no bytes for DN/RDN/field names). "
"Please call initialize(..., bytes_mode=False) explicitly.",
BytesWarning,
stacklevel=2,
)
bytes_mode = True
# Disable hard failure when running in backwards compatibility mode.
self.bytes_mode_hardfail = False
elif bytes_mode and not PY2:
raise ValueError("bytes_mode is *not* supported under Python 3.")
# On by default on Py2, off on Py3.
self.bytes_mode = bytes_mode

def _bytesify_input(self, value):
"""Adapt a value following bytes_mode in Python 2.
In Python 3, returns the original value unmodified.
With bytes_mode ON, takes bytes or None and returns bytes or None.
With bytes_mode OFF, takes unicode or None and returns bytes or None.
This function should be applied on all text inputs (distinguished names
and attribute names in modlists) to convert them to the bytes expected
by the C bindings.
"""
if not PY2:
return value

if value is None:
return value
elif self.bytes_mode:
if isinstance(value, bytes):
return value
else:
if self.bytes_mode_hardfail:
raise TypeError("All provided fields *must* be bytes when bytes mode is on; got %r" % (value,))
else:
warnings.warn(
"Received non-bytes value %r with default (disabled) bytes mode; please choose an explicit "
"option for bytes_mode on your LDAP connection" % (value,),
BytesWarning,
stacklevel=6,
)
return value.encode('utf-8')
else:
if not isinstance(value, text_type):
raise TypeError("All provided fields *must* be text when bytes mode is off; got %r" % (value,))
assert not isinstance(value, bytes)
return value.encode('utf-8')

def _bytesify_inputs(self, *values):
"""Adapt values following bytes_mode.
Applies _bytesify_input on each arg.
Usage:
>>> a, b, c = self._bytesify_inputs(a, b, c)
"""
if not PY2:
return values
return (
self._bytesify_input(value)
for value in values
)

def _bytesify_modlist(self, modlist, with_opcode):
"""Adapt a modlist according to bytes_mode.
A modlist is a tuple of (op, attr, value), where:
- With bytes_mode ON, attr is checked to be bytes
- With bytes_mode OFF, attr is converted from unicode to bytes
- value is *always* bytes
"""
if not PY2:
return modlist
if with_opcode:
return tuple(
(op, self._bytesify_input(attr), val)
for op, attr, val in modlist
)
else:
return tuple(
(self._bytesify_input(attr), val)
for attr, val in modlist
)

def _unbytesify_text_value(self, value):
"""Adapt a 'known text, UTF-8 encoded' returned value following bytes_mode.
With bytes_mode ON, takes bytes or None and returns bytes or None.
With bytes_mode OFF, takes bytes or None and returns unicode or None.
This function should only be applied on field *values*; distinguished names
or field *names* are already natively handled in result4.
"""
if value is None:
return value

# Preserve logic of assertions only under Python 2
if PY2:
assert isinstance(value, bytes), "Expected bytes value, got text instead (%r)" % (value,)

if self.bytes_mode:
return value
else:
return value.decode('utf-8')

def _maybe_rebytesify_text(self, value):
"""Re-encodes text to bytes if needed by bytes_mode.
Takes unicode (and checks for it), and returns:
- bytes under bytes_mode
- unicode otherwise.
"""
if not PY2:
return value

if value is None:
return value

assert isinstance(value, text_type), "Should return text, got bytes instead (%r)" % (value,)
if not self.bytes_mode:
return value
else:
return value.encode('utf-8')

def _bytesify_result_value(self, result_value):
"""Applies bytes_mode to a result value.
Such a value can either be:
- a dict mapping an attribute name to its list of values
(where attribute names are unicode and values bytes)
- a list of referals (which are unicode)
"""
if not PY2:
return result_value
if hasattr(result_value, 'items'):
# It's a attribute_name: [values] dict
return dict(
(self._maybe_rebytesify_text(key), value)
for (key, value) in result_value.items()
)
elif isinstance(result_value, bytes):
return result_value
else:
# It's a list of referals
# Example value:
# [u'ldap://DomainDnsZones.xxxx.root.local/DC=DomainDnsZones,DC=xxxx,DC=root,DC=local']
return [self._maybe_rebytesify_text(referal) for referal in result_value]

def _bytesify_results(self, results, with_ctrls=False):
"""Converts a "results" object according to bytes_mode.
Takes:
- a list of (dn, {field: [values]}) if with_ctrls is False
- a list of (dn, {field: [values]}, ctrls) if with_ctrls is True
And, if bytes_mode is on, converts dn and fields to bytes.
"""
if not PY2:
return results
if with_ctrls:
return [
(self._maybe_rebytesify_text(dn), self._bytesify_result_value(fields), ctrls)
for (dn, fields, ctrls) in results
]
else:
return [
(self._maybe_rebytesify_text(dn), self._bytesify_result_value(fields))
for (dn, fields) in results
]

def _ldap_lock(self,desc=''):
if ldap.LIBLDAP_R:
return ldap.LDAPLock(desc='%s within %s' %(desc,repr(self)))
Expand Down Expand Up @@ -185,6 +373,8 @@ def add_ext(self,dn,modlist,serverctrls=None,clientctrls=None):
The parameter modlist is similar to the one passed to modify(),
except that no operation integer need be included in the tuples.
"""
dn = self._bytesify_input(dn)
modlist = self._bytesify_modlist(modlist, with_opcode=False)
return self._ldap_call(self._l.add_ext,dn,modlist,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def add_ext_s(self,dn,modlist,serverctrls=None,clientctrls=None):
Expand All @@ -209,6 +399,7 @@ def simple_bind(self,who='',cred='',serverctrls=None,clientctrls=None):
"""
simple_bind([who='' [,cred='']]) -> int
"""
who, cred = self._bytesify_inputs(who, cred)
return self._ldap_call(self._l.simple_bind,who,cred,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def simple_bind_s(self,who='',cred='',serverctrls=None,clientctrls=None):
Expand Down Expand Up @@ -285,6 +476,7 @@ def compare_ext(self,dn,attr,value,serverctrls=None,clientctrls=None):
A design bug in the library prevents value from containing
nul characters.
"""
dn, attr = self._bytesify_inputs(dn, attr)
return self._ldap_call(self._l.compare_ext,dn,attr,value,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def compare_ext_s(self,dn,attr,value,serverctrls=None,clientctrls=None):
Expand Down Expand Up @@ -315,6 +507,7 @@ def delete_ext(self,dn,serverctrls=None,clientctrls=None):
form returns the message id of the initiated request, and the
result can be obtained from a subsequent call to result().
"""
dn = self._bytesify_input(dn)
return self._ldap_call(self._l.delete_ext,dn,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def delete_ext_s(self,dn,serverctrls=None,clientctrls=None):
Expand Down Expand Up @@ -363,6 +556,8 @@ def modify_ext(self,dn,modlist,serverctrls=None,clientctrls=None):
"""
modify_ext(dn, modlist[,serverctrls=None[,clientctrls=None]]) -> int
"""
dn = self._bytesify_input(dn)
modlist = self._bytesify_modlist(modlist, with_opcode=True)
return self._ldap_call(self._l.modify_ext,dn,modlist,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def modify_ext_s(self,dn,modlist,serverctrls=None,clientctrls=None):
Expand Down Expand Up @@ -416,6 +611,7 @@ def modrdn_s(self,dn,newrdn,delold=1):
return self.rename_s(dn,newrdn,None,delold)

def passwd(self,user,oldpw,newpw,serverctrls=None,clientctrls=None):
user, oldpw, newpw = self._bytesify_inputs(user, oldpw, newpw)
return self._ldap_call(self._l.passwd,user,oldpw,newpw,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def passwd_s(self,user,oldpw,newpw,serverctrls=None,clientctrls=None):
Expand All @@ -437,6 +633,7 @@ def rename(self,dn,newrdn,newsuperior=None,delold=1,serverctrls=None,clientctrls
This actually corresponds to the rename* routines in the
LDAP-EXT C API library.
"""
dn, newrdn, newsuperior = self._bytesify_inputs(dn, newrdn, newsuperior)
return self._ldap_call(self._l.rename,dn,newrdn,newsuperior,delold,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))

def rename_s(self,dn,newrdn,newsuperior=None,delold=1,serverctrls=None,clientctrls=None):
Expand Down Expand Up @@ -525,6 +722,8 @@ def result4(self,msgid=ldap.RES_ANY,all=1,timeout=None,add_ctrls=0,add_intermedi
if add_ctrls:
resp_data = [ (t,r,DecodeControlTuples(c,resp_ctrl_classes)) for t,r,c in resp_data ]
decoded_resp_ctrls = DecodeControlTuples(resp_ctrls,resp_ctrl_classes)
if resp_data is not None:
resp_data = self._bytesify_results(resp_data, with_ctrls=add_ctrls)
return resp_type, resp_data, resp_msgid, decoded_resp_ctrls, resp_name, resp_value

def search_ext(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrsonly=0,serverctrls=None,clientctrls=None,timeout=-1,sizelimit=0):
Expand Down Expand Up @@ -572,6 +771,9 @@ def search_ext(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrson
The amount of search results retrieved can be limited with the
sizelimit parameter if non-zero.
"""
base, filterstr = self._bytesify_inputs(base, filterstr)
if attrlist is not None:
attrlist = tuple(self._bytesify_inputs(*attrlist))
return self._ldap_call(
self._l.search_ext,
base,scope,filterstr,
Expand Down Expand Up @@ -665,6 +867,8 @@ def search_subschemasubentry_s(self,dn=''):
None as result indicates that the DN of the sub schema sub entry could
not be determined.
Returns: None or text/bytes depending on bytes_mode.
"""
try:
r = self.search_s(
Expand All @@ -686,7 +890,9 @@ def search_subschemasubentry_s(self,dn=''):
# If dn was already root DSE we can return here
return None
else:
return search_subschemasubentry_dn
# With legacy bytes mode, return bytes; otherwise, since this is a DN,
# RFCs impose that the field value *can* be decoded to UTF-8.
return self._unbytesify_text_value(search_subschemasubentry_dn)
except IndexError:
return None

Expand Down Expand Up @@ -788,7 +994,7 @@ class ReconnectLDAPObject(SimpleLDAPObject):

def __init__(
self,uri,
trace_level=0,trace_file=None,trace_stack_limit=5,
trace_level=0,trace_file=None,trace_stack_limit=5,bytes_mode=None,
retry_max=1,retry_delay=60.0
):
"""
Expand All @@ -803,7 +1009,7 @@ def __init__(
self._uri = uri
self._options = []
self._last_bind = None
SimpleLDAPObject.__init__(self,uri,trace_level,trace_file,trace_stack_limit)
SimpleLDAPObject.__init__(self,uri,trace_level,trace_file,trace_stack_limit,bytes_mode)
self._reconnect_lock = ldap.LDAPLock(desc='reconnect lock within %s' % (repr(self)))
self._retry_max = retry_max
self._retry_delay = retry_delay
Expand Down
Loading

0 comments on commit fa35757

Please sign in to comment.