"""One-off migration that introduces timesheet instances.

Adds a ``timesheet_id`` column to the period-scoped tables, makes sure a
``TimesheetPeriod`` row exists for every period referenced by existing data,
and backfills ``timesheet_id`` on rows that do not have one yet.
"""
import sys
from datetime import date

from sqlalchemy import text, inspect

from .db import engine, SessionLocal
from .models import (
    TimesheetPeriod,
    TimeEntry,
    WeekAssignment,
    EmployeePeriodSetting,
    TimesheetStatus,
)
from .utils import _semi_monthly_period_for_date

# Tables that receive a nullable ``timesheet_id`` column.
_TIMESHEET_ID_TABLES = (
    "time_entries",
    "week_assignments",
    "employee_period_settings",
    "timesheet_status",
)

# Models whose rows carry explicit ``period_start`` / ``period_end`` values
# (TimeEntry only has ``work_date``; its period is derived from that date).
_PERIOD_KEYED_MODELS = (WeekAssignment, TimesheetStatus, EmployeePeriodSetting)


def column_exists(inspector, table: str, column: str) -> bool:
    """Return True if ``table`` has a column called ``column``.

    Args:
        inspector: SQLAlchemy inspector (as returned by ``inspect(engine)``).
        table: Name of the table to look at.
        column: Column name to search for.
    """
    return any(c["name"] == column for c in inspector.get_columns(table))


def _add_missing_columns(conn, insp) -> None:
    """Issue the ALTER TABLE statements for columns that are not present yet."""
    for table in _TIMESHEET_ID_TABLES:
        if not column_exists(insp, table, "timesheet_id"):
            conn.execute(text(f"ALTER TABLE {table} ADD COLUMN timesheet_id INTEGER"))

    # created_at on timesheet_periods so ordering can later use it safely.
    # NOTE(review): this DDL looks PostgreSQL-specific -- confirm target DB.
    if not column_exists(insp, "timesheet_periods", "created_at"):
        conn.execute(
            text(
                "ALTER TABLE timesheet_periods ADD COLUMN created_at "
                "TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW() NOT NULL"
            )
        )


def _collect_period_keys(s) -> set:
    """Return the set of ``(period_start, period_end)`` pairs used by existing rows."""
    period_keys = set()

    # Only the distinct dates matter, so let the database de-duplicate instead
    # of fetching (and sorting) one row per time entry.
    for (work_date,) in s.query(TimeEntry.work_date).distinct().all():
        ps, pe = _semi_monthly_period_for_date(work_date)
        period_keys.add((ps, pe))

    for model in _PERIOD_KEYED_MODELS:
        pairs = s.query(model.period_start, model.period_end).distinct().all()
        for ps, pe in pairs:
            period_keys.add((ps, pe))

    return period_keys


def _ensure_periods(s, period_keys):
    """Make sure every period key has a ``TimesheetPeriod`` row.

    Returns:
        A ``(period_ids, new_count)`` tuple: ``period_ids`` maps every
        ``(period_start, period_end)`` key to its TimesheetPeriod id, whether
        the row already existed or was just inserted; ``new_count`` is the
        number of rows inserted by this call.
    """
    period_ids = {}
    new_count = 0
    for ps, pe in sorted(period_keys):
        ts = (
            s.query(TimesheetPeriod)
            .filter(
                TimesheetPeriod.period_start == ps,
                TimesheetPeriod.period_end == pe,
            )
            .first()
        )
        if not ts:
            ts = TimesheetPeriod(
                period_start=ps,
                period_end=pe,
                name=f"{ps.isoformat()} .. {pe.isoformat()}",
            )
            s.add(ts)
            s.flush()  # flush so the new row's id is available below
            new_count += 1
        # Record existing periods too, so a re-run still backfills their rows.
        period_ids[(ps, pe)] = ts.id
    return period_ids, new_count


def _backfill_timesheet_ids(s, period_ids) -> None:
    """Set ``timesheet_id`` on every row where it is still NULL."""
    entries = s.query(TimeEntry).filter(TimeEntry.timesheet_id.is_(None)).all()
    for entry in entries:
        ps, pe = _semi_monthly_period_for_date(entry.work_date)
        entry.timesheet_id = period_ids.get((ps, pe))

    for model in _PERIOD_KEYED_MODELS:
        rows = s.query(model).filter(model.timesheet_id.is_(None)).all()
        for row in rows:
            row.timesheet_id = period_ids.get((row.period_start, row.period_end))


def run():
    """Run the migration: schema changes first, then the data backfill.

    The column additions run in their own ``engine.begin()`` transaction; the
    period creation and backfill then run in a single session that is
    committed at the end, or rolled back if anything raises.
    """
    print("[migrate] Starting timesheet instances migration...")
    insp = inspect(engine)
    with engine.begin() as conn:
        _add_missing_columns(conn, insp)

    s = SessionLocal()
    try:
        period_keys = _collect_period_keys(s)
        period_ids, new_count = _ensure_periods(s, period_keys)
        _backfill_timesheet_ids(s, period_ids)
        s.commit()
        print(f"[migrate] Done. Created {new_count} timesheet_periods and backfilled timesheet_id.")
    except Exception:
        s.rollback()
        raise
    finally:
        s.close()


if __name__ == "__main__":
    run()