Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions api/metadata_pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,16 +707,30 @@ def list_equipment_access_pg(equipment_id: str) -> list[dict[str, Any]]:


def search_users_pg(query: str, limit: int = 10) -> list[dict[str, Any]]:
"""Search active users by name or email for the maintainer picker."""
"""Search canonical users by name or email for the maintainer picker."""
conn = get_pg_connection()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(
"""
SELECT id, name, email, organization
FROM users
WHERE status = 'active'
AND (name ILIKE %s OR email ILIKE %s)
FROM (
SELECT DISTINCT ON (lower(email))
id,
name,
email,
organization,
status,
last_active
FROM users
WHERE status IN ('active', 'invited')
AND (name ILIKE %s OR email ILIKE %s)
ORDER BY
lower(email),
(id = split_part(lower(email), '@', 1)) DESC,
(status = 'active') DESC,
last_active DESC
) canonical_users
ORDER BY name
LIMIT %s
""",
Expand Down
116 changes: 37 additions & 79 deletions api/routers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import APIRouter, Depends, HTTPException
from psycopg2.extras import RealDictCursor
from pydantic import BaseModel, EmailStr
from data_loader_pg import get_pg_connection
from data_loader_pg import get_pg_connection, get_pg_superuser_connection
from domain_configs import CONFIGS_DIR
from metadata_pg import (
list_execution_queue_pg,
Expand All @@ -20,6 +20,7 @@
require_ingestion_token,
require_system_token,
)
from user_identity import sync_canonical_user

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin", tags=["admin"])
Expand All @@ -39,63 +40,26 @@ def get_or_create_user_role(
If the user does not exist, registers them automatically as 'researcher'.
"""
try:
conn = get_pg_connection()
conn = get_pg_superuser_connection()
with conn.cursor(cursor_factory=RealDictCursor) as cur:
resolved_user_id = user_id or email.split("@")[0]

cur.execute(
"""
SELECT id, role, organization
FROM users
WHERE email = %s OR id = %s
ORDER BY last_active DESC
LIMIT 1
""",
(email, resolved_user_id),
user = sync_canonical_user(
cur,
email=email,
requested_user_id=user_id,
name=name,
organization=organization,
requested_status="active",
)
user = cur.fetchone()

if user:
cur.execute(
"""
UPDATE users
SET
last_active = CURRENT_TIMESTAMP,
name = COALESCE(%s, name),
organization = COALESCE(%s, organization)
WHERE id = %s
""",
(name, organization, user["id"]),
)
conn.commit()
return {
"role": user["role"],
"organization": user["organization"] or organization or "Purdue University",
"status": "active",
"user_id": user["id"],
}

# Auto-register new user
display_name = name if name else email.split('@')[0]
org = organization if organization else "Purdue University"

cur.execute(
"""
INSERT INTO users (id, name, email, role, organization, status)
VALUES (%s, %s, %s, 'researcher', %s, 'active')
RETURNING id, role, organization
""",
(resolved_user_id, display_name, email, org)
)
created = cur.fetchone()
conn.commit()
return {
"role": created["role"],
"organization": created["organization"],
"status": "active",
"is_new": True,
"user_id": created["id"],
}
conn.commit()
response = {
"role": user["role"],
"organization": user["organization"] or organization or "Purdue University",
"status": user["status"],
"user_id": user["id"],
}
if user.get("is_new"):
response["is_new"] = True
return response
except Exception as e:
logger.error(f"Error syncing user role for {email}: {e}")
raise HTTPException(status_code=500, detail="Database lookup failed")
Expand All @@ -111,7 +75,14 @@ def get_users(
try:
conn = get_pg_connection()
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute("SELECT id, name, email, role, organization, status, joined_at, last_active FROM users ORDER BY joined_at DESC")
cur.execute(
"""
SELECT id, name, email, role, organization, status, joined_at, last_active
FROM users
WHERE status <> 'merged'
ORDER BY joined_at DESC
"""
)
rows = cur.fetchall()

records = []
Expand Down Expand Up @@ -142,31 +113,18 @@ def invite_user(
) -> Dict[str, Any]:
"""Create or update a platform user record before first login."""
try:
conn = get_pg_connection()
conn = get_pg_superuser_connection()
with conn.cursor(cursor_factory=RealDictCursor) as cur:
inferred_user_id = payload.email.split("@")[0]
display_name = payload.name or inferred_user_id

cur.execute(
"""
INSERT INTO users (id, name, email, role, organization, status)
VALUES (%s, %s, %s, %s, %s, 'invited')
ON CONFLICT (id) DO UPDATE SET
name = EXCLUDED.name,
email = EXCLUDED.email,
role = EXCLUDED.role,
organization = EXCLUDED.organization
RETURNING id, name, email, role, organization, status, joined_at, last_active
""",
(
inferred_user_id,
display_name,
payload.email,
payload.role,
payload.organization,
),
user = sync_canonical_user(
cur,
email=str(payload.email),
requested_user_id=inferred_user_id,
name=payload.name or inferred_user_id,
organization=payload.organization,
requested_role=payload.role,
requested_status="invited",
)
user = cur.fetchone()
conn.commit()

record = dict(user)
Expand Down
36 changes: 21 additions & 15 deletions api/tests/test_phase3_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,23 @@ def test_role_sync_rejects_hardcoded_dev_system_token(self):
self.assertEqual(response.json()["detail"], "Invalid system token")

def test_role_sync_creates_new_user_with_system_token(self):
connection = FakeConnection(
[
{"contains": "SELECT id, role, organization", "fetchone": None},
{
"contains": "INSERT INTO users",
"fetchone": {
"id": "alice",
"role": "researcher",
"organization": "Birck",
},
},
]
)
connection = FakeConnection([])

with patch.object(admin, "get_pg_connection", return_value=connection):
with patch.object(
admin,
"get_pg_superuser_connection",
return_value=connection,
), patch.object(
admin,
"sync_canonical_user",
return_value={
"id": "alice",
"role": "researcher",
"organization": "Birck",
"status": "active",
"is_new": True,
},
) as sync_canonical_user:
response = self.client.get(
"/api/admin/users/alice@example.com/role",
headers={"X-System-Token": SYSTEM_TOKEN},
Expand All @@ -153,6 +155,10 @@ def test_role_sync_creates_new_user_with_system_token(self):
)
self.assertTrue(connection.committed)
self.assertTrue(connection.closed)
self.assertEqual(
sync_canonical_user.call_args.kwargs["requested_user_id"],
"alice",
)

def test_admin_users_requires_admin_role(self):
response = self.client.get("/api/admin/users", headers=RESEARCHER_HEADERS)
Expand All @@ -164,7 +170,7 @@ def test_admin_users_allows_admin_role(self):
connection = FakeConnection(
[
{
"contains": "SELECT id, name, email, role, organization, status, joined_at, last_active FROM users",
"contains": "FROM users",
"fetchall": [
{
"id": "admin-1",
Expand Down
123 changes: 123 additions & 0 deletions api/tests/test_user_identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import sys
import unittest
from datetime import datetime, timezone
from pathlib import Path

API_DIR = Path(__file__).resolve().parents[1]
if str(API_DIR) not in sys.path:
sys.path.insert(0, str(API_DIR))

from user_identity import canonical_user_id, highest_role, sync_canonical_user


class FakeCursor:
def __init__(self, users, available_columns):
self.users = users
self.available_columns = available_columns
self.executions = []
self.current_result = []
self.returned_user = None

def execute(self, query, params=None):
rendered = repr(query) if hasattr(query, "as_string") else str(query)
self.executions.append((rendered, params))
if "FROM information_schema.columns" in rendered:
self.current_result = [
{"table_name": table, "column_name": column}
for table, column in self.available_columns
]
elif "FROM users" in rendered and "FOR UPDATE" in rendered:
self.current_result = self.users
elif "RETURNING id, name, email, role" in rendered:
user_id = params[-1] if "UPDATE users" in rendered else params[0]
self.returned_user = {
"id": user_id,
"name": params[0] if "UPDATE users" in rendered else params[1],
"email": params[1] if "UPDATE users" in rendered else params[2],
"role": params[2] if "UPDATE users" in rendered else params[3],
"organization": params[3] if "UPDATE users" in rendered else params[4],
"status": params[4] if "UPDATE users" in rendered else params[5],
"joined_at": datetime.now(timezone.utc),
"last_active": datetime.now(timezone.utc),
}
self.current_result = [self.returned_user]
else:
self.current_result = []

def fetchall(self):
return self.current_result

def fetchone(self):
return self.current_result[0] if self.current_result else None


class UserIdentityTests(unittest.TestCase):
def test_canonical_user_id_prefers_requested_username(self):
self.assertEqual(
canonical_user_id("hosler0@purdue.edu", "hosler0"),
"hosler0",
)

def test_highest_role_preserves_existing_privilege(self):
self.assertEqual(
highest_role(
[{"role": "researcher"}, {"role": "equipment_owner"}],
"researcher",
),
"equipment_owner",
)

def test_duplicate_identity_references_move_to_canonical_user(self):
duplicate_id = "7bcdc8cb-4cf0-4435-90f7-9802716e61aa"
cursor = FakeCursor(
users=[
{
"id": "hosler0",
"name": "Richard S Hosler",
"email": "hosler0@purdue.edu",
"role": "equipment_owner",
"organization": "Purdue University",
"status": "invited",
},
{
"id": duplicate_id,
"name": "Richard S Hosler",
"email": "hosler0@purdue.edu",
"role": "equipment_owner",
"organization": "Purdue University",
"status": "active",
},
],
available_columns={
("equipment_access", "user_id"),
("equipment_access", "granted_by"),
("equipment_metadata", "owner_id"),
("project_members", "nanohub_user_id"),
},
)

result = sync_canonical_user(
cursor,
email="hosler0@purdue.edu",
requested_user_id="hosler0",
name="Richard S Hosler",
organization="Purdue University",
)

self.assertEqual(result["id"], "hosler0")
self.assertEqual(result["role"], "equipment_owner")
sql_calls = "\n".join(query for query, _ in cursor.executions)
self.assertIn("INSERT INTO equipment_access", sql_calls)
self.assertIn("INSERT INTO project_members", sql_calls)
self.assertIn("Identifier('equipment_metadata')", sql_calls)
self.assertIn("status = 'merged'", sql_calls)
self.assertTrue(
any(
params == ("hosler0", [duplicate_id])
for _, params in cursor.executions
)
)


if __name__ == "__main__":
unittest.main()
Loading