From 528abad2ada870640de70487022e8113af2f98e7 Mon Sep 17 00:00:00 2001 From: navidgh67 Date: Fri, 12 Jun 2026 11:29:40 -0400 Subject: [PATCH] Fix duplicate user identities --- api/metadata_pg.py | 22 ++- api/routers/admin.py | 116 ++++-------- api/tests/test_phase3_security.py | 36 ++-- api/tests/test_user_identity.py | 123 +++++++++++++ api/user_identity.py | 292 ++++++++++++++++++++++++++++++ web/auth.ts | 11 +- 6 files changed, 497 insertions(+), 103 deletions(-) create mode 100644 api/tests/test_user_identity.py create mode 100644 api/user_identity.py diff --git a/api/metadata_pg.py b/api/metadata_pg.py index 9273a96..871f077 100644 --- a/api/metadata_pg.py +++ b/api/metadata_pg.py @@ -707,16 +707,30 @@ def list_equipment_access_pg(equipment_id: str) -> list[dict[str, Any]]: def search_users_pg(query: str, limit: int = 10) -> list[dict[str, Any]]: - """Search active users by name or email for the maintainer picker.""" + """Search canonical users by name or email for the maintainer picker.""" conn = get_pg_connection() try: with conn.cursor(cursor_factory=RealDictCursor) as cur: cur.execute( """ SELECT id, name, email, organization - FROM users - WHERE status = 'active' - AND (name ILIKE %s OR email ILIKE %s) + FROM ( + SELECT DISTINCT ON (lower(email)) + id, + name, + email, + organization, + status, + last_active + FROM users + WHERE status IN ('active', 'invited') + AND (name ILIKE %s OR email ILIKE %s) + ORDER BY + lower(email), + (id = split_part(lower(email), '@', 1)) DESC, + (status = 'active') DESC, + last_active DESC + ) canonical_users ORDER BY name LIMIT %s """, diff --git a/api/routers/admin.py b/api/routers/admin.py index 6d08932..fb3f034 100644 --- a/api/routers/admin.py +++ b/api/routers/admin.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException from psycopg2.extras import RealDictCursor from pydantic import BaseModel, EmailStr -from data_loader_pg import get_pg_connection +from data_loader_pg import get_pg_connection, get_pg_superuser_connection from domain_configs import CONFIGS_DIR from metadata_pg import ( list_execution_queue_pg, @@ -20,6 +20,7 @@ require_ingestion_token, require_system_token, ) +from user_identity import sync_canonical_user logger = logging.getLogger(__name__) router = APIRouter(prefix="/admin", tags=["admin"]) @@ -39,63 +40,26 @@ def get_or_create_user_role( If the user does not exist, registers them automatically as 'researcher'. """ try: - conn = get_pg_connection() + conn = get_pg_superuser_connection() with conn.cursor(cursor_factory=RealDictCursor) as cur: - resolved_user_id = user_id or email.split("@")[0] - - cur.execute( - """ - SELECT id, role, organization - FROM users - WHERE email = %s OR id = %s - ORDER BY last_active DESC - LIMIT 1 - """, - (email, resolved_user_id), + user = sync_canonical_user( + cur, + email=email, + requested_user_id=user_id, + name=name, + organization=organization, + requested_status="active", ) - user = cur.fetchone() - - if user: - cur.execute( - """ - UPDATE users - SET - last_active = CURRENT_TIMESTAMP, - name = COALESCE(%s, name), - organization = COALESCE(%s, organization) - WHERE id = %s - """, - (name, organization, user["id"]), - ) - conn.commit() - return { - "role": user["role"], - "organization": user["organization"] or organization or "Purdue University", - "status": "active", - "user_id": user["id"], - } - - # Auto-register new user - display_name = name if name else email.split('@')[0] - org = organization if organization else "Purdue University" - - cur.execute( - """ - INSERT INTO users (id, name, email, role, organization, status) - VALUES (%s, %s, %s, 'researcher', %s, 'active') - RETURNING id, role, organization - """, - (resolved_user_id, display_name, email, org) - ) - created = cur.fetchone() - conn.commit() - return { - "role": created["role"], - "organization": created["organization"], - "status": "active", - "is_new": True, - "user_id": created["id"], - } + conn.commit() + response = { + "role": user["role"], + "organization": user["organization"] or organization or "Purdue University", + "status": user["status"], + "user_id": user["id"], + } + if user.get("is_new"): + response["is_new"] = True + return response except Exception as e: logger.error(f"Error syncing user role for {email}: {e}") raise HTTPException(status_code=500, detail="Database lookup failed") @@ -111,7 +75,14 @@ def get_users( try: conn = get_pg_connection() with conn.cursor(cursor_factory=RealDictCursor) as cur: - cur.execute("SELECT id, name, email, role, organization, status, joined_at, last_active FROM users ORDER BY joined_at DESC") + cur.execute( + """ + SELECT id, name, email, role, organization, status, joined_at, last_active + FROM users + WHERE status <> 'merged' + ORDER BY joined_at DESC + """ + ) rows = cur.fetchall() records = [] @@ -142,31 +113,18 @@ def invite_user( ) -> Dict[str, Any]: """Create or update a platform user record before first login.""" try: - conn = get_pg_connection() + conn = get_pg_superuser_connection() with conn.cursor(cursor_factory=RealDictCursor) as cur: inferred_user_id = payload.email.split("@")[0] - display_name = payload.name or inferred_user_id - - cur.execute( - """ - INSERT INTO users (id, name, email, role, organization, status) - VALUES (%s, %s, %s, %s, %s, 'invited') - ON CONFLICT (id) DO UPDATE SET - name = EXCLUDED.name, - email = EXCLUDED.email, - role = EXCLUDED.role, - organization = EXCLUDED.organization - RETURNING id, name, email, role, organization, status, joined_at, last_active - """, - ( - inferred_user_id, - display_name, - payload.email, - payload.role, - payload.organization, - ), + user = sync_canonical_user( + cur, + email=str(payload.email), + requested_user_id=inferred_user_id, + name=payload.name or inferred_user_id, + organization=payload.organization, + requested_role=payload.role, + requested_status="invited", ) - user = cur.fetchone() conn.commit() record = dict(user) diff --git a/api/tests/test_phase3_security.py b/api/tests/test_phase3_security.py index a3223ca..4d7cbab 100644 --- a/api/tests/test_phase3_security.py +++ b/api/tests/test_phase3_security.py @@ -119,21 +119,23 @@ def test_role_sync_rejects_hardcoded_dev_system_token(self): self.assertEqual(response.json()["detail"], "Invalid system token") def test_role_sync_creates_new_user_with_system_token(self): - connection = FakeConnection( - [ - {"contains": "SELECT id, role, organization", "fetchone": None}, - { - "contains": "INSERT INTO users", - "fetchone": { - "id": "alice", - "role": "researcher", - "organization": "Birck", - }, - }, - ] - ) + connection = FakeConnection([]) - with patch.object(admin, "get_pg_connection", return_value=connection): + with patch.object( + admin, + "get_pg_superuser_connection", + return_value=connection, + ), patch.object( + admin, + "sync_canonical_user", + return_value={ + "id": "alice", + "role": "researcher", + "organization": "Birck", + "status": "active", + "is_new": True, + }, + ) as sync_canonical_user: response = self.client.get( "/api/admin/users/alice@example.com/role", headers={"X-System-Token": SYSTEM_TOKEN}, @@ -153,6 +155,10 @@ def test_role_sync_creates_new_user_with_system_token(self): ) self.assertTrue(connection.committed) self.assertTrue(connection.closed) + self.assertEqual( + sync_canonical_user.call_args.kwargs["requested_user_id"], + "alice", + ) def test_admin_users_requires_admin_role(self): response = self.client.get("/api/admin/users", headers=RESEARCHER_HEADERS) @@ -164,7 +170,7 @@ def test_admin_users_allows_admin_role(self): connection = FakeConnection( [ { - "contains": "SELECT id, name, email, role, organization, status, joined_at, last_active FROM users", + "contains": "FROM users", "fetchall": [ { "id": "admin-1", diff --git a/api/tests/test_user_identity.py b/api/tests/test_user_identity.py new file mode 100644 index 0000000..4384cb0 --- /dev/null +++ b/api/tests/test_user_identity.py @@ -0,0 +1,123 @@ +import sys +import unittest +from datetime import datetime, timezone +from pathlib import Path + +API_DIR = Path(__file__).resolve().parents[1] +if str(API_DIR) not in sys.path: + sys.path.insert(0, str(API_DIR)) + +from user_identity import canonical_user_id, highest_role, sync_canonical_user + + +class FakeCursor: + def __init__(self, users, available_columns): + self.users = users + self.available_columns = available_columns + self.executions = [] + self.current_result = [] + self.returned_user = None + + def execute(self, query, params=None): + rendered = repr(query) if hasattr(query, "as_string") else str(query) + self.executions.append((rendered, params)) + if "FROM information_schema.columns" in rendered: + self.current_result = [ + {"table_name": table, "column_name": column} + for table, column in self.available_columns + ] + elif "FROM users" in rendered and "FOR UPDATE" in rendered: + self.current_result = self.users + elif "RETURNING id, name, email, role" in rendered: + user_id = params[-1] if "UPDATE users" in rendered else params[0] + self.returned_user = { + "id": user_id, + "name": params[0] if "UPDATE users" in rendered else params[1], + "email": params[1] if "UPDATE users" in rendered else params[2], + "role": params[2] if "UPDATE users" in rendered else params[3], + "organization": params[3] if "UPDATE users" in rendered else params[4], + "status": params[4] if "UPDATE users" in rendered else params[5], + "joined_at": datetime.now(timezone.utc), + "last_active": datetime.now(timezone.utc), + } + self.current_result = [self.returned_user] + else: + self.current_result = [] + + def fetchall(self): + return self.current_result + + def fetchone(self): + return self.current_result[0] if self.current_result else None + + +class UserIdentityTests(unittest.TestCase): + def test_canonical_user_id_prefers_requested_username(self): + self.assertEqual( + canonical_user_id("hosler0@purdue.edu", "hosler0"), + "hosler0", + ) + + def test_highest_role_preserves_existing_privilege(self): + self.assertEqual( + highest_role( + [{"role": "researcher"}, {"role": "equipment_owner"}], + "researcher", + ), + "equipment_owner", + ) + + def test_duplicate_identity_references_move_to_canonical_user(self): + duplicate_id = "7bcdc8cb-4cf0-4435-90f7-9802716e61aa" + cursor = FakeCursor( + users=[ + { + "id": "hosler0", + "name": "Richard S Hosler", + "email": "hosler0@purdue.edu", + "role": "equipment_owner", + "organization": "Purdue University", + "status": "invited", + }, + { + "id": duplicate_id, + "name": "Richard S Hosler", + "email": "hosler0@purdue.edu", + "role": "equipment_owner", + "organization": "Purdue University", + "status": "active", + }, + ], + available_columns={ + ("equipment_access", "user_id"), + ("equipment_access", "granted_by"), + ("equipment_metadata", "owner_id"), + ("project_members", "nanohub_user_id"), + }, + ) + + result = sync_canonical_user( + cursor, + email="hosler0@purdue.edu", + requested_user_id="hosler0", + name="Richard S Hosler", + organization="Purdue University", + ) + + self.assertEqual(result["id"], "hosler0") + self.assertEqual(result["role"], "equipment_owner") + sql_calls = "\n".join(query for query, _ in cursor.executions) + self.assertIn("INSERT INTO equipment_access", sql_calls) + self.assertIn("INSERT INTO project_members", sql_calls) + self.assertIn("Identifier('equipment_metadata')", sql_calls) + self.assertIn("status = 'merged'", sql_calls) + self.assertTrue( + any( + params == ("hosler0", [duplicate_id]) + for _, params in cursor.executions + ) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/api/user_identity.py b/api/user_identity.py new file mode 100644 index 0000000..b813a54 --- /dev/null +++ b/api/user_identity.py @@ -0,0 +1,292 @@ +import hashlib +from typing import Any + +from psycopg2 import sql + + +ROLE_PRIORITY = { + "researcher": 0, + "equipment_owner": 1, + "pi": 2, + "admin": 3, +} + +IDENTITY_COLUMNS = ( + ("data_uploads", "owner_id"), + ("equipment_access", "granted_by"), + ("equipment_metadata", "owner_id"), + ("equipment_metadata", "reviewed_by"), + ("equipment_registrations", "reviewed_by"), + ("equipment_registrations", "submitted_by"), + ("experiment_definitions", "facility_reviewer_id"), + ("experiment_definitions", "owner_id"), + ("experiment_type_versions", "changed_by"), + ("experiment_types", "owner_id"), + ("models", "trained_by"), + ("process_definitions", "owner_id"), + ("project_experiments", "added_by"), + ("purr_publication_requests", "approved_by"), + ("purr_publication_requests", "requested_by"), + ("samples", "owner_id"), + ("user_role_audit", "changed_by"), + ("user_role_audit", "user_id"), +) + + +def canonical_user_id(email: str, requested_user_id: str | None = None) -> str: + requested = (requested_user_id or "").strip() + if requested: + return requested + return email.strip().lower().split("@", 1)[0] + + +def highest_role(rows: list[dict[str, Any]], fallback: str) -> str: + roles = [str(row.get("role") or "") for row in rows] + roles.append(fallback) + return max(roles, key=lambda role: ROLE_PRIORITY.get(role, -1)) + + +def _available_identity_columns(cur) -> set[tuple[str, str]]: + cur.execute( + """ + SELECT table_name, column_name + FROM information_schema.columns + WHERE table_schema = 'public' + """ + ) + return { + (row["table_name"], row["column_name"]) + for row in cur.fetchall() + } + + +def _merge_equipment_access( + cur, + *, + canonical_id: str, + duplicate_ids: list[str], + available_columns: set[tuple[str, str]], +) -> None: + if not duplicate_ids or ("equipment_access", "user_id") not in available_columns: + return + + cur.execute( + """ + INSERT INTO equipment_access ( + equipment_id, + user_id, + role, + granted_by, + created_at + ) + SELECT + DISTINCT ON (equipment_id) + equipment_id, + %s, + role, + granted_by, + created_at + FROM equipment_access + WHERE user_id = ANY(%s) + ORDER BY equipment_id, created_at + ON CONFLICT (equipment_id, user_id) DO UPDATE SET + role = EXCLUDED.role, + created_at = LEAST(equipment_access.created_at, EXCLUDED.created_at) + """, + (canonical_id, duplicate_ids), + ) + cur.execute( + "DELETE FROM equipment_access WHERE user_id = ANY(%s)", + (duplicate_ids,), + ) + + +def _merge_project_members( + cur, + *, + canonical_id: str, + duplicate_ids: list[str], + available_columns: set[tuple[str, str]], +) -> None: + if not duplicate_ids or ("project_members", "nanohub_user_id") not in available_columns: + return + + cur.execute( + """ + INSERT INTO project_members (project_id, nanohub_user_id, role) + SELECT + project_id, + %s, + CASE + WHEN bool_or(role = 'pi') THEN 'pi' + ELSE max(role) + END + FROM project_members + WHERE nanohub_user_id = ANY(%s) + GROUP BY project_id + ON CONFLICT (project_id, nanohub_user_id) DO UPDATE SET + role = CASE + WHEN project_members.role = 'pi' OR EXCLUDED.role = 'pi' THEN 'pi' + ELSE project_members.role + END + """, + (canonical_id, duplicate_ids), + ) + cur.execute( + "DELETE FROM project_members WHERE nanohub_user_id = ANY(%s)", + (duplicate_ids,), + ) + + +def sync_canonical_user( + cur, + *, + email: str, + requested_user_id: str | None, + name: str | None, + organization: str | None, + requested_role: str = "researcher", + requested_status: str = "active", +) -> dict[str, Any]: + normalized_email = email.strip().lower() + resolved_user_id = canonical_user_id(normalized_email, requested_user_id) + + # Serialize role sync/invite operations for the same email. + cur.execute( + "SELECT pg_advisory_xact_lock(hashtext(%s))", + (normalized_email,), + ) + cur.execute( + """ + SELECT + id, + name, + email, + role, + organization, + status, + joined_at, + last_active + FROM users + WHERE lower(email) = %s OR id = %s + FOR UPDATE + """, + (normalized_email, resolved_user_id), + ) + rows = [dict(row) for row in cur.fetchall()] + is_new_user = not rows + canonical = next( + (row for row in rows if row["id"] == resolved_user_id), + None, + ) + merged_role = highest_role(rows, requested_role) + merged_name = name or (canonical or {}).get("name") + if not merged_name: + merged_name = next( + (row.get("name") for row in rows if row.get("name")), + resolved_user_id, + ) + merged_org = organization or (canonical or {}).get("organization") + if not merged_org: + merged_org = next( + (row.get("organization") for row in rows if row.get("organization")), + "Purdue University", + ) + merged_status = ( + "active" + if requested_status == "active" + or any(row.get("status") == "active" for row in rows) + else requested_status + ) + + if canonical: + cur.execute( + """ + UPDATE users + SET + name = %s, + email = %s, + role = %s, + organization = %s, + status = %s, + last_active = CURRENT_TIMESTAMP + WHERE id = %s + RETURNING id, name, email, role, organization, status, joined_at, last_active + """, + ( + merged_name, + normalized_email, + merged_role, + merged_org, + merged_status, + resolved_user_id, + ), + ) + else: + cur.execute( + """ + INSERT INTO users (id, name, email, role, organization, status) + VALUES (%s, %s, %s, %s, %s, %s) + RETURNING id, name, email, role, organization, status, joined_at, last_active + """, + ( + resolved_user_id, + merged_name, + normalized_email, + merged_role, + merged_org, + merged_status, + ), + ) + result = dict(cur.fetchone()) + result["is_new"] = is_new_user + + duplicate_ids = [ + row["id"] + for row in rows + if row["id"] != resolved_user_id + ] + if not duplicate_ids: + return result + + available_columns = _available_identity_columns(cur) + _merge_equipment_access( + cur, + canonical_id=resolved_user_id, + duplicate_ids=duplicate_ids, + available_columns=available_columns, + ) + _merge_project_members( + cur, + canonical_id=resolved_user_id, + duplicate_ids=duplicate_ids, + available_columns=available_columns, + ) + + for table_name, column_name in IDENTITY_COLUMNS: + if (table_name, column_name) not in available_columns: + continue + cur.execute( + sql.SQL("UPDATE {} SET {} = %s WHERE {} = ANY(%s)").format( + sql.Identifier(table_name), + sql.Identifier(column_name), + sql.Identifier(column_name), + ), + (resolved_user_id, duplicate_ids), + ) + + for duplicate_id in duplicate_ids: + retired_email = ( + f"merged+{hashlib.sha256(duplicate_id.encode('utf-8')).hexdigest()[:20]}" + "@invalid.local" + ) + cur.execute( + """ + UPDATE users + SET email = %s, status = 'merged', last_active = CURRENT_TIMESTAMP + WHERE id = %s + """, + (retired_email, duplicate_id), + ) + + return result diff --git a/web/auth.ts b/web/auth.ts index fecf72f..f1facd2 100644 --- a/web/auth.ts +++ b/web/auth.ts @@ -311,7 +311,13 @@ export const { handlers, auth, signIn, signOut } = NextAuth({ const authUser = coerceUser(user); const resolvedEmail = authUser.email ?? token.email ?? undefined; const resolvedName = authUser.name ?? token.name ?? undefined; + const resolvedUsername = + authUser.nanohubUsername ?? + (typeof token.nanohubUsername === "string" + ? token.nanohubUsername + : undefined); const resolvedUserId = + resolvedUsername ?? authUser.id ?? (typeof token.id === "string" ? token.id : undefined) ?? resolvedEmail; @@ -321,11 +327,6 @@ export const { handlers, auth, signIn, signOut } = NextAuth({ ? token.organization : undefined) ?? DEFAULT_ORGANIZATION; - const resolvedUsername = - authUser.nanohubUsername ?? - (typeof token.nanohubUsername === "string" - ? token.nanohubUsername - : undefined); if (resolvedUserId) { token.id = resolvedUserId;