Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"purr_publication_requests",
"data_uploads",
"ingestion_logs",
"equipment_access",
)

STARTUP_MIGRATION_LOCK_ID = 6420260506
Expand Down Expand Up @@ -154,6 +155,7 @@ def _run_startup_migrations() -> None:
ensure_run_version_tracking_columns_pg,
ensure_user_role_audit_table_pg,
ensure_sample_and_publication_tables_pg,
ensure_equipment_access_table_pg,
)

lock_conn = get_pg_superuser_connection()
Expand All @@ -176,6 +178,7 @@ def _run_startup_migrations() -> None:
ensure_run_version_tracking_columns_pg()
ensure_user_role_audit_table_pg()
ensure_sample_and_publication_tables_pg()
ensure_equipment_access_table_pg()
finally:
try:
with lock_conn.cursor() as cur:
Expand Down
243 changes: 241 additions & 2 deletions api/metadata_pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,188 @@ def ensure_sample_and_publication_tables_pg() -> None:
conn.close()


def ensure_equipment_access_table_pg() -> None:
"""Create the equipment_access table for trusted maintainers."""
conn = get_pg_superuser_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
CREATE TABLE IF NOT EXISTS equipment_access (
equipment_id VARCHAR(255) NOT NULL REFERENCES equipment_metadata(domain_id) ON DELETE CASCADE,
user_id VARCHAR(255) NOT NULL REFERENCES users(id) ON DELETE CASCADE,
role VARCHAR(50) NOT NULL DEFAULT 'editor',
granted_by VARCHAR(255) DEFAULT '',
created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (equipment_id, user_id)
)
"""
)
cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_equipment_access_user
ON equipment_access(user_id)
"""
)
cur.execute(
"GRANT SELECT, INSERT, UPDATE, DELETE ON equipment_access TO api_client"
)
conn.commit()
finally:
conn.close()


def grant_equipment_access_pg(
equipment_id: str,
user_id: str,
granted_by: str,
) -> dict[str, Any]:
"""Grant a user trusted-maintainer access to a piece of equipment."""
conn = get_pg_connection()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
# Verify equipment exists
cur.execute(
"SELECT domain_id, owner_id FROM equipment_metadata WHERE domain_id = %s",
(equipment_id,),
)
eq = cur.fetchone()
if not eq:
return {"status": "not_found", "detail": "Equipment not found"}

# Verify target user exists
cur.execute(
"SELECT id, name, email, organization FROM users WHERE id = %s",
(user_id,),
)
target_user = cur.fetchone()
if not target_user:
return {"status": "user_not_found", "detail": "User not found"}

# Don't add the owner as a maintainer
if eq["owner_id"] == user_id:
return {"status": "is_owner", "detail": "User is already the equipment owner"}

cur.execute(
"""
INSERT INTO equipment_access (equipment_id, user_id, role, granted_by)
VALUES (%s, %s, 'editor', %s)
ON CONFLICT (equipment_id, user_id) DO NOTHING
RETURNING equipment_id, user_id, role, granted_by, created_at
""",
(equipment_id, user_id, granted_by),
)
row = cur.fetchone()
conn.commit()

if not row:
return {
"status": "already_exists",
"detail": "User already has access",
"user_id": user_id,
"name": target_user["name"],
"email": target_user["email"],
}

return {
"status": "granted",
"equipment_id": equipment_id,
"user_id": user_id,
"name": target_user["name"],
"email": target_user["email"],
"organization": target_user["organization"],
"role": row["role"],
"granted_by": row["granted_by"],
"created_at": _iso(row["created_at"]),
}
finally:
conn.close()


def revoke_equipment_access_pg(
equipment_id: str,
user_id: str,
) -> dict[str, Any]:
"""Remove a user's trusted-maintainer access from a piece of equipment."""
conn = get_pg_connection()
try:
with conn.cursor() as cur:
cur.execute(
"DELETE FROM equipment_access WHERE equipment_id = %s AND user_id = %s",
(equipment_id, user_id),
)
deleted = cur.rowcount
conn.commit()
if not deleted:
return {"status": "not_found", "detail": "Access record not found"}
return {"status": "revoked", "equipment_id": equipment_id, "user_id": user_id}
finally:
conn.close()


def list_equipment_access_pg(equipment_id: str) -> list[dict[str, Any]]:
"""List all trusted maintainers for a piece of equipment."""
conn = get_pg_connection()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(
"""
SELECT
ea.user_id,
ea.role,
ea.granted_by,
ea.created_at,
u.name,
u.email,
u.organization
FROM equipment_access ea
JOIN users u ON u.id = ea.user_id
WHERE ea.equipment_id = %s
ORDER BY ea.created_at
""",
(equipment_id,),
)
rows = cur.fetchall()
conn.commit()
return [
{
"user_id": row["user_id"],
"name": row["name"],
"email": row["email"],
"organization": row["organization"],
"role": row["role"],
"granted_by": row["granted_by"],
"created_at": _iso(row["created_at"]),
}
for row in rows
]
finally:
conn.close()


def search_users_pg(query: str, limit: int = 10) -> list[dict[str, Any]]:
"""Search active users by name or email for the maintainer picker."""
conn = get_pg_connection()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(
"""
SELECT id, name, email, organization
FROM users
WHERE status = 'active'
AND (name ILIKE %s OR email ILIKE %s)
ORDER BY name
LIMIT %s
""",
(f"%{query}%", f"%{query}%", limit),
)
rows = cur.fetchall()
conn.commit()
return [dict(row) for row in rows]
finally:
conn.close()


def update_experiment_proposal_generation_status_pg(
*,
experiment_id: str,
Expand Down Expand Up @@ -806,6 +988,10 @@ def equipment_visible_to_user(

if owner_id == user.id:
return True
# Trusted maintainer check
maintainer_ids = record.get("_maintainer_ids", [])
if user.id in maintainer_ids:
return True
if status != "approved":
return False
if owner_org and user.organization and owner_org != user.organization:
Expand Down Expand Up @@ -866,6 +1052,7 @@ def _serialize_equipment(
"primary_target": _coerce_json(record.get("primary_target_json"), {}),
"secondary_target": _coerce_json(record.get("secondary_target_json"), {}),
"config_json": config_json,
"maintainers": record.get("_maintainers", []),
}


Expand Down Expand Up @@ -927,12 +1114,21 @@ def list_equipment_pg(user: PlatformUser) -> list[dict[str, Any]]:
)
rows = cur.fetchall()

# Fetch maintainer mappings for visibility checks
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute("SELECT equipment_id, user_id FROM equipment_access")
access_rows = cur.fetchall()
maintainer_map: dict[str, list[str]] = {}
for arow in access_rows:
maintainer_map.setdefault(arow["equipment_id"], []).append(arow["user_id"])

metrics = _fetch_equipment_run_metrics(conn)
conn.commit()

visible: list[dict[str, Any]] = []
for row in rows:
record = dict(row)
record["_maintainer_ids"] = maintainer_map.get(record["domain_id"], [])
if equipment_visible_to_user(record, user):
visible.append(
_serialize_equipment(record, metrics.get(record["domain_id"]))
Expand Down Expand Up @@ -985,9 +1181,46 @@ def get_equipment_pg(domain_id: str) -> dict[str, Any] | None:
conn.commit()
return None

# Fetch maintainer IDs for visibility check
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(
"SELECT user_id FROM equipment_access WHERE equipment_id = %s",
(domain_id,),
)
maintainer_rows = cur.fetchall()
record = dict(row)
record["_maintainer_ids"] = [r["user_id"] for r in maintainer_rows]

# Fetch full maintainer details for serialized output
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(
"""
SELECT ea.user_id, ea.role, ea.granted_by, ea.created_at,
u.name, u.email, u.organization
FROM equipment_access ea
JOIN users u ON u.id = ea.user_id
WHERE ea.equipment_id = %s
ORDER BY ea.created_at
""",
(domain_id,),
)
maintainer_detail_rows = cur.fetchall()
record["_maintainers"] = [
{
"user_id": r["user_id"],
"name": r["name"],
"email": r["email"],
"organization": r["organization"],
"role": r["role"],
"granted_by": r["granted_by"],
"created_at": _iso(r["created_at"]),
}
for r in maintainer_detail_rows
]

metrics = _fetch_equipment_run_metrics(conn)
conn.commit()
return _serialize_equipment(dict(row), metrics.get(domain_id))
return _serialize_equipment(record, metrics.get(domain_id))
finally:
conn.close()

Expand Down Expand Up @@ -1278,7 +1511,13 @@ def update_equipment_pg(
return {"status": "not_found", "domain_id": domain_id}

if user.role != "admin" and existing["owner_id"] != user.id:
raise PermissionError("Only the equipment owner or an admin can edit this equipment")
# Check if user is a trusted maintainer
cur.execute(
"SELECT 1 FROM equipment_access WHERE equipment_id = %s AND user_id = %s",
(domain_id, user.id),
)
if not cur.fetchone():
raise PermissionError("Only the equipment owner, a trusted maintainer, or an admin can edit this equipment")

owner_id = existing.get("owner_id") or user.id
owner_org = existing.get("owner_org") or user.organization
Expand Down
Loading