diff --git a/Lib/ldap/ldapobject.py b/Lib/ldap/ldapobject.py index f7443fa..a92b088 100644 --- a/Lib/ldap/ldapobject.py +++ b/Lib/ldap/ldapobject.py @@ -1166,14 +1166,18 @@ def reconnect(self,uri,retry_max=1,retry_delay=60.0): counter_text,uri )) try: - # Do the connect - self._l = ldap.functions._ldap_function_call(ldap._ldap_module_lock,_ldap.initialize,uri) - self._restore_options() - # StartTLS extended operation in case this was called before - if self._start_tls: - SimpleLDAPObject.start_tls_s(self) - # Repeat last simple or SASL bind - self._apply_last_bind() + try: + # Do the connect + self._l = ldap.functions._ldap_function_call(ldap._ldap_module_lock,_ldap.initialize,uri) + self._restore_options() + # StartTLS extended operation in case this was called before + if self._start_tls: + SimpleLDAPObject.start_tls_s(self) + # Repeat last simple or SASL bind + self._apply_last_bind() + except ldap.LDAPError: + SimpleLDAPObject.unbind_s(self) + raise except (ldap.SERVER_DOWN,ldap.TIMEOUT): if __debug__ and self._trace_level>=1: self._trace_file.write('*** %s reconnect to %s failed\n' % ( @@ -1185,7 +1189,6 @@ def reconnect(self,uri,retry_max=1,retry_delay=60.0): if __debug__ and self._trace_level>=1: self._trace_file.write('=> delay %s...\n' % (retry_delay)) time.sleep(retry_delay) - SimpleLDAPObject.unbind_s(self) else: if __debug__ and self._trace_level>=1: self._trace_file.write('*** %s reconnect to %s successful => repeat last operation\n' % ( diff --git a/Tests/t_ldapobject.py b/Tests/t_ldapobject.py index 67adeb2..36e2acf 100644 --- a/Tests/t_ldapobject.py +++ b/Tests/t_ldapobject.py @@ -736,6 +736,23 @@ def test104_reconnect_restore(self): l2 = pickle.loads(l1_state) self.assertEqual(l2.whoami_s(), 'dn:'+bind_dn) + def test105_reconnect_restore(self): + l1 = self.ldap_object_class(self.server.ldap_uri, retry_max=2, retry_delay=1) + bind_dn = 'cn=user1,'+self.server.suffix + l1.simple_bind_s(bind_dn, 'user1_pw') + self.assertEqual(l1.whoami_s(), 'dn:'+bind_dn) + self.server._proc.terminate() + self.server.wait() + try: + l1.whoami_s() + except ldap.SERVER_DOWN: + pass + else: + self.assertEqual(True, False) + finally: + self.server._start_slapd() + self.assertEqual(l1.whoami_s(), 'dn:'+bind_dn) + if __name__ == '__main__': unittest.main()