import calendar
from collections import defaultdict
from dataclasses import dataclass
from datetime import date, timedelta
from typing import Dict, List, Optional, Tuple

from sqlalchemy import func
from sqlalchemy.orm import Session

from .models import TimeEntry


# Anchor week for BIWEEKLY pay periods (a Monday). A pay period starts on any
# week whose distance from this anchor is an even number of whole weeks.
_BIWEEKLY_EPOCH = date(2020, 1, 6)


@dataclass
class DayRow:
    """One display row of a timesheet, built from a single TimeEntry."""

    entry_id: int
    work_date: date
    clock_in: str  # formatted with "%I:%M %p" (e.g. "09:05 AM"), "" if missing
    clock_out: str  # same format as clock_in
    break_hours: float
    total_hours: float
    pto_hours: float
    pto_type: str | None  # group_entries_for_timesheet stores "" instead of None
    holiday_hours: float
    bereavement_hours: float
    hours_paid: float


def _current_pay_period_selector() -> Dict:
    """Return a fresh selector dict meaning "the current pay period"."""
    return {"type": "current_pay_period", "label": "Current Pay Period"}


def _end_of_month(year: int, month: int) -> date:
    """Return the last calendar day of the given month."""
    return date(year, month, calendar.monthrange(year, month)[1])


def parse_period_selector(selector: Optional[str]) -> Dict:
    """Parse a period selector string into a selector dict.

    Recognised forms:

    * ``None`` / ``""``            -> ``{"type": "current_pay_period"}``
    * ``"YYYY-MM-DD..YYYY-MM-DD"`` -> ``{"type": "range", "start", "end"}``
    * ``"YYYY-MM"``                -> ``{"type": "month", "year", "month"}``
    * ``"YYYY"``                   -> ``{"type": "year", "year"}``
    * ``"YYYY-MM-DD"``             -> ``{"type": "single", "date"}``

    Every dict also carries a human-readable ``"label"``.

    Malformed month, year or single-day selectors fall back to the current
    pay period.

    Raises:
        ValueError: if a ``..`` range contains an invalid ISO date.
    """
    if not selector:
        return _current_pay_period_selector()

    if ".." in selector:
        s, e = selector.split("..", 1)
        return {
            "type": "range",
            "start": date.fromisoformat(s),
            "end": date.fromisoformat(e),
            "label": f"{s}..{e}",
        }

    if len(selector) == 7:
        try:
            y, m = selector.split("-")
            year, month = int(y), int(m)
        except ValueError:
            pass  # not "YYYY-MM"; let the remaining forms decide
        else:
            if 1 <= month <= 12:
                return {"type": "month", "year": year, "month": month, "label": selector}

    if len(selector) == 4:
        try:
            year = int(selector)
        except ValueError:
            pass  # not "YYYY"; let the remaining forms decide
        else:
            if year >= date.min.year:
                return {"type": "year", "year": year, "label": selector}

    try:
        d = date.fromisoformat(selector)
    except ValueError:
        return _current_pay_period_selector()
    return {"type": "single", "date": d, "label": selector}


def _start_of_week(d: date, start_weekday: int) -> date:
    """Return the most recent date on or before ``d`` whose weekday() is ``start_weekday``."""
    return d - timedelta(days=(d.weekday() - start_weekday) % 7)


def compute_period_bounds(
    selector: Dict,
    pay_period_type: str,
    start_weekday: int,
    today: Optional[date] = None,
) -> Tuple[date, date]:
    """Resolve a selector dict (see parse_period_selector) to inclusive bounds.

    Explicit selectors map directly:

    * ``range`` -> its start/end.
    * ``single`` -> that one day.
    * ``month`` / ``year`` -> their calendar bounds.

    Any other selector type resolves to the pay period containing ``today``,
    according to ``pay_period_type``:

    * ``"WEEKLY"`` -> the 7-day week beginning on ``start_weekday``.
    * ``"SEMI_MONTHLY"`` -> the 1st-15th or the 16th-end of the month.
    * Anything else (``BIWEEKLY``) -> a 14-day period aligned to
      ``_BIWEEKLY_EPOCH``.

    Args:
        selector: Selector dict produced by ``parse_period_selector``.
        pay_period_type: Pay period type, compared case-insensitively.
        start_weekday: First day of the week (``date.weekday()`` numbering).
        today: Reference date for relative periods; defaults to
            ``date.today()``.
    """
    kind = selector["type"]
    if kind == "range":
        return selector["start"], selector["end"]
    if kind == "single":
        return selector["date"], selector["date"]
    if kind == "month":
        y, m = selector["year"], selector["month"]
        return date(y, m, 1), _end_of_month(y, m)
    if kind == "year":
        y = selector["year"]
        return date(y, 1, 1), date(y, 12, 31)

    if today is None:
        today = date.today()
    period = pay_period_type.upper()

    if period == "SEMI_MONTHLY":
        if today.day <= 15:
            return date(today.year, today.month, 1), date(today.year, today.month, 15)
        return date(today.year, today.month, 16), _end_of_month(today.year, today.month)

    start_week = _start_of_week(today, start_weekday)
    if period == "WEEKLY":
        return start_week, start_week + timedelta(days=6)

    # BIWEEKLY (the default): step back one week when this week is an odd
    # number of weeks away from the anchor, so periods always pair up the same way.
    if ((start_week - _BIWEEKLY_EPOCH).days // 7) % 2 != 0:
        start_week -= timedelta(days=7)
    return start_week, start_week + timedelta(days=13)


def default_week_ranges(start: date, end: date, start_weekday: int) -> List[Tuple[date, date]]:
    """Split ``[start, end]`` into contiguous, inclusive week ranges.

    Each range ends on the day before the next ``start_weekday``, or on
    ``end``, whichever comes first. The first range may therefore be partial,
    and semi-monthly periods typically yield 2-3 ranges. Returns ``[]`` when
    ``start > end``.
    """
    ranges: List[Tuple[date, date]] = []
    cursor = start
    while cursor <= end:
        week_end = min(end, _start_of_week(cursor, start_weekday) + timedelta(days=6))
        ranges.append((cursor, week_end))
        cursor = week_end + timedelta(days=1)
    return ranges


def _format_clock(value) -> str:
    """Format a clock-in/out value for display, or "" when it is missing."""
    return value.strftime("%I:%M %p") if value else ""


def _clock_sort_key(value) -> str:
    """24-hour "HH:MM" sort key, so afternoon punches sort after morning ones."""
    return value.strftime("%H:%M") if value else ""


def _day_row_from_entry(e: TimeEntry) -> DayRow:
    """Convert one TimeEntry to a DayRow, treating missing hour values as 0."""
    return DayRow(
        entry_id=e.id,
        work_date=e.work_date,
        clock_in=_format_clock(e.clock_in),
        clock_out=_format_clock(e.clock_out),
        break_hours=round(e.break_hours or 0.0, 2),
        total_hours=round(e.total_hours or 0.0, 2),
        pto_hours=round(e.pto_hours or 0.0, 2),
        pto_type=e.pto_type or "",
        holiday_hours=round(e.holiday_hours or 0.0, 2),
        bereavement_hours=round(e.bereavement_hours or 0.0, 2),
        # A falsy (None or 0) hours_paid falls back to total_hours.
        hours_paid=round(e.hours_paid or (e.total_hours or 0.0), 2),
    )


def group_entries_for_timesheet(
    entries: List[TimeEntry],
    start: date,
    end: date,
    pay_period_type: str,
    start_weekday: int,
    week_ranges: Optional[List[Tuple[date, date]]] = None,
    carry_over_hours: float = 0.0,
    overtime_threshold: float = 40.0,
):
    """Build timesheet rows, per-week regular/overtime summaries and totals.

    Args:
        entries: Time entries to display. Presumably already limited to
            ``[start, end]``; entries outside every week range are counted
            in the last week.
        start: First day of the period, inclusive.
        end: Last day of the period, inclusive.
        pay_period_type: Currently unused; kept for interface compatibility.
        start_weekday: First day of the week (``date.weekday()`` numbering),
            used when ``week_ranges`` is not supplied.
        week_ranges: Explicit inclusive week ranges. A falsy value means
            ``default_week_ranges(start, end, start_weekday)``.
        carry_over_hours: Hours already worked in the first week before
            ``start``. They count toward week 1's overtime threshold but are
            not reported as this period's hours.
        overtime_threshold: Weekly hours above which time counts as overtime.

    Returns:
        A dict with these keys:

        * ``"rows"``: ``DayRow`` list sorted by date, then clock-in time.
        * ``"weekly_summary"``: one dict per week with ``label``, ``start``,
          ``end``, ``all``, ``reg`` and ``ot``.
        * ``"totals"``: period totals.
        * ``"week_ranges"``: the week ranges actually used.
    """
    # Sort on a 24-hour key; the 12-hour display string would put
    # "01:00 PM" before "09:00 AM". The sort is stable, so same-minute
    # ties keep their input order.
    ordered = sorted(entries, key=lambda e: (e.work_date, _clock_sort_key(e.clock_in)))
    rows: List[DayRow] = [_day_row_from_entry(e) for e in ordered]

    week_ranges = week_ranges or default_week_ranges(start, end, start_weekday)

    def week_index(d: date) -> int:
        """1-based index of the week containing ``d`` (last week if none does)."""
        for idx, (ws, we) in enumerate(week_ranges, start=1):
            if ws <= d <= we:
                return idx
        return len(week_ranges)

    weekly_hours: Dict[int, float] = defaultdict(float)
    for r in rows:
        weekly_hours[week_index(r.work_date)] += r.total_hours

    weekly_summary = []
    for i, (ws, we) in enumerate(week_ranges, start=1):
        base_hours = round(weekly_hours[i], 2)
        carry = carry_over_hours if i == 1 else 0.0
        # The overtime threshold applies to the whole workweek, including
        # carried-over hours. Overtime can only be paid on hours worked in
        # this period, though: anything over the threshold before ``start``
        # was already overtime in the previous period.
        ot = max(0.0, min(base_hours, base_hours + carry - overtime_threshold))
        reg = max(0.0, base_hours - ot)
        weekly_summary.append({
            "label": f"Week {i}",
            "start": ws,
            "end": we,
            "all": round(base_hours, 2),
            "reg": round(reg, 2),
            "ot": round(ot, 2),
        })

    totals = {
        "regular": round(sum(week["reg"] for week in weekly_summary), 2),
        "pto": round(sum(r.pto_hours for r in rows), 2),
        "holiday": round(sum(r.holiday_hours for r in rows), 2),
        "bereavement": round(sum(r.bereavement_hours for r in rows), 2),
        "overtime": round(sum(week["ot"] for week in weekly_summary), 2),
        "paid_total": round(sum(r.hours_paid for r in rows), 2),
    }
    return {"rows": rows, "weekly_summary": weekly_summary, "totals": totals, "week_ranges": week_ranges}


def compute_yearly_stats(
    db: Session,
    employee_id: int,
    scope: str = "year",
    year: Optional[int] = None,
    month: Optional[int] = None,
    working_days_per_month: float = 20.0,
):
    """Aggregate an employee's hours per (year, month).

    Filtering by ``scope``:

    * ``"month"`` with both ``year`` and ``month`` -> that month only.
    * ``"year"`` with ``year`` -> that year only.
    * Anything else -> every month with entries.

    ``average_daily`` is ``total_hours / working_days_per_month``. It is a
    fixed-divisor estimate, not a count of days actually worked.

    Uses ``func.date_part``; presumably the backend is PostgreSQL (verify).

    Returns:
        ``{"rows": [...]}``, one dict per month ordered by year and month.

    Raises:
        ValueError: if ``working_days_per_month`` is not positive.
    """
    if working_days_per_month <= 0:
        raise ValueError("working_days_per_month must be positive")

    q = db.query(
        func.date_part("year", TimeEntry.work_date).label("y"),
        func.date_part("month", TimeEntry.work_date).label("m"),
        func.sum(TimeEntry.total_hours).label("total"),
        func.sum(TimeEntry.pto_hours).label("pto"),
        func.sum(TimeEntry.holiday_hours).label("holiday"),
        func.sum(TimeEntry.bereavement_hours).label("bereavement"),
        func.sum(TimeEntry.hours_paid).label("paid"),
    ).filter(TimeEntry.employee_id == employee_id)

    if scope == "month" and year and month:
        q = q.filter(func.date_part("year", TimeEntry.work_date) == year)
        q = q.filter(func.date_part("month", TimeEntry.work_date) == month)
    elif scope == "year" and year:
        q = q.filter(func.date_part("year", TimeEntry.work_date) == year)

    q = q.group_by("y", "m").order_by("y", "m")

    data = []
    for y, m, total, pto, holiday, bereavement, paid in q.all():
        total_hours = float(total or 0.0)
        data.append({
            "year": int(y),
            "month": int(m),
            "total_hours": total_hours,
            "average_daily": round(total_hours / working_days_per_month, 2),
            "pto": float(pto or 0.0),
            "holiday": float(holiday or 0.0),
            "bereavement": float(bereavement or 0.0),
            "paid": float(paid or 0.0),
        })
    return {"rows": data}


def available_years_for_employee(db: Session, employee_id: int) -> List[int]:
    """Return every year from the employee's first to last entry, inclusive.

    Returns ``[]`` when the employee has no entries.
    """
    bounds = (
        db.query(func.min(TimeEntry.work_date), func.max(TimeEntry.work_date))
        .filter(TimeEntry.employee_id == employee_id)
        .first()
    )
    if not bounds or not bounds[0]:
        return []
    first, last = bounds
    return list(range(first.year, last.year + 1))