from calendar import monthrange
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from decimal import ROUND_HALF_UP, Decimal, getcontext
from typing import Dict, List, Optional, Tuple

from sqlalchemy import func
from sqlalchemy.orm import Session

from .models import TimeEntry, TimesheetPeriod

# Decimal settings for consistent rounding.
# NOTE: getcontext() returns the *current thread's* context; threads started
# later begin from decimal.DefaultContext, not from this setting.
getcontext().prec = 28


def D(x) -> Decimal:
    """Convert *x* to ``Decimal`` via ``str`` so floats keep their printed value.

    ``None`` and other falsy values become ``Decimal("0")``.
    """
    return Decimal(str(x or 0))


def q2(x: Decimal) -> Decimal:
    """Round *x* to two decimal places using ``ROUND_HALF_UP``."""
    return x.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)


# ---------------------------
# Period helpers and listing
# ---------------------------


def _semi_monthly_period_for_date(d: date) -> Tuple[date, date]:
    """Return ``(start, end)`` of the semi-monthly period containing *d*.

    Days 1-15 form the first half; day 16 through month end form the second.
    """
    if d.day <= 15:
        return date(d.year, d.month, 1), date(d.year, d.month, 15)
    last_day = monthrange(d.year, d.month)[1]
    return date(d.year, d.month, 16), date(d.year, d.month, last_day)


def enumerate_timesheets_global(db: Session) -> List[Tuple[int, date, date, str]]:
    """List every timesheet period as ``(id, period_start, period_end, name)``.

    Rows are ordered by start date, then end date, then id. When a period has
    no name, a ``"<start>..<end>"`` label built from ISO dates is used.
    """
    rows: List[TimesheetPeriod] = (
        db.query(TimesheetPeriod)
        .order_by(
            TimesheetPeriod.period_start.asc(),
            TimesheetPeriod.period_end.asc(),
            TimesheetPeriod.id.asc(),
        )
        .all()
    )
    return [
        (
            ts.id,
            ts.period_start,
            ts.period_end,
            ts.name or f"{ts.period_start.isoformat()}..{ts.period_end.isoformat()}",
        )
        for ts in rows
    ]


# ---------------------------
# Viewer/print data shaping
# ---------------------------


@dataclass
class RowOut:
    """One display row for a time entry (as built by group_entries_for_timesheet)."""

    entry_id: int
    work_date: date
    # Raw stored clock values are passed through unchanged; they may not be
    # datetime objects despite the annotation (see _to_datetime).
    clock_in: Optional[datetime]
    clock_out: Optional[datetime]
    break_hours: Decimal
    total_hours: Decimal
    pto_hours: Decimal
    pto_type: Optional[str]
    holiday_hours: Decimal
    bereavement_hours: Decimal
    # worked (total - break, floored at 0) + pto + holiday + bereavement
    hours_paid: Decimal
    # PTO hours present but no PTO type recorded
    needs_pto_review: bool = False
    # total_hours (before subtracting breaks) exceeds 10
    needs_long_shift_review: bool = False


@dataclass
class Totals:
    """Period-level hour totals, each rounded to 2 decimals."""

    regular: Decimal
    pto: Decimal
    holiday: Decimal
    bereavement: Decimal
    overtime: Decimal
    # worked + pto + holiday + bereavement
    paid_total: Decimal


@dataclass
class WeekSummary:
    """Worked hours for one week, split into regular and overtime."""

    label: str
    all: Decimal  # all worked hours in the week
    reg: Decimal  # hours up to the weekly regular cap
    ot: Decimal  # hours above the cap


@dataclass
class Grouped:
    """Result of group_entries_for_timesheet: rows, totals and weekly splits."""

    rows: List[RowOut]
    totals: Totals
    weekly_summary: List[WeekSummary]


# String layouts accepted for stored clock values, tried in order.
_CLOCK_FORMATS = (
    "%Y-%m-%d %H:%M:%S.%f",
    "%Y-%m-%d %H:%M:%S",
    "%I:%M:%S %p",
    "%H:%M:%S",
    "%H:%M",
)


def _to_datetime(d: date, t) -> Optional[datetime]:
    """Coerce a stored clock value into a ``datetime`` on work date *d*.

    * ``None``     -> ``None``
    * ``datetime`` -> returned unchanged (its own date is kept)
    * ``time``     -> combined with *d*
    * other values are stringified and parsed with ``_CLOCK_FORMATS``. Only the
      time of day is kept: any date in the string is replaced by *d*.

    Returns ``None`` when the value matches none of the formats.
    """
    if t is None:
        return None
    if isinstance(t, datetime):
        return t
    if isinstance(t, time):
        return datetime.combine(d, t)
    s = str(t).strip()
    for fmt in _CLOCK_FORMATS:
        try:
            parsed = datetime.strptime(s, fmt)
        except ValueError:
            continue
        return parsed.replace(year=d.year, month=d.month, day=d.day)
    return None


def _td_hours(delta: timedelta) -> Decimal:
    """Convert a ``timedelta`` into unrounded Decimal hours."""
    return D(delta.total_seconds()) / D(3600)


def _default_week_number(work_date: date, period_start: date) -> int:
    """Return the 1-based index of the 7-day block from *period_start* holding *work_date*."""
    return (work_date - period_start).days // 7 + 1


def group_entries_for_timesheet(
    entries: List[TimeEntry],
    period_start: date,
    period_end: date,
    week_map: Optional[Dict[date, int]] = None,
    carry_over_hours: float = 0.0,
) -> Grouped:
    """Build display rows, weekly regular/overtime splits and period totals.

    Args:
        entries: time entries to show; sorted here by work date, then clock-in.
        period_start: first day of the period. Used to derive week numbers
            when *week_map* is ``None``.
        period_end: last day of the period (not currently used).
        week_map: optional mapping of work date -> week number. Entries whose
            date is missing from a supplied map still produce rows, but they
            contribute nothing to the weekly summary or to the worked, regular
            and overtime totals. When ``None``, weeks are consecutive 7-day
            blocks starting at *period_start* (week 1 = its first 7 days).
        carry_over_hours: hours subtracted from week 1's 40-hour regular cap.
            Presumably these are hours already worked in that week before the
            period began; TODO confirm with callers.

    Returns:
        Grouped. Worked hours are ``total_hours - break_hours`` floored at
        zero. Worked hours above a week's regular cap are overtime.
    """
    rows: List[RowOut] = []
    week_totals: Dict[int, Decimal] = {}
    sum_pto = D(0)
    sum_holiday = D(0)
    sum_bereavement = D(0)

    def sort_key(entry: TimeEntry):
        # Entries without a parseable clock-in sort first within their date.
        return (entry.work_date, _to_datetime(entry.work_date, entry.clock_in) or datetime.min)

    for e in sorted(entries, key=sort_key):
        total = D(e.total_hours)
        brk = D(e.break_hours)
        pto = D(e.pto_hours)
        hol = D(e.holiday_hours)
        ber = D(e.bereavement_hours)
        worked = max(total - brk, D(0))

        rows.append(
            RowOut(
                entry_id=e.id,
                work_date=e.work_date,
                clock_in=e.clock_in,
                clock_out=e.clock_out,
                break_hours=q2(brk),
                total_hours=q2(total),
                pto_hours=q2(pto),
                pto_type=(e.pto_type or None),
                holiday_hours=q2(hol),
                bereavement_hours=q2(ber),
                hours_paid=q2(worked + pto + hol + ber),
                needs_pto_review=(pto > D(0) and not (e.pto_type or "").strip()),
                needs_long_shift_review=(total > D(10)),
            )
        )

        sum_pto += pto
        sum_holiday += hol
        sum_bereavement += ber

        # Without a week map every entry still needs a week; otherwise the
        # worked hours would silently drop out of the totals.
        if week_map is None:
            wk: Optional[int] = _default_week_number(e.work_date, period_start)
        else:
            wk = week_map.get(e.work_date)
        if wk is not None:
            week_totals[wk] = week_totals.get(wk, D(0)) + worked

    # Weekly summary: split each week's worked hours at the regular cap.
    weekly_summary: List[WeekSummary] = []
    carry = D(carry_over_hours)
    for wk in sorted(week_totals):
        worked_w = week_totals[wk]
        reg_cap = max(D(40) - (carry if wk == 1 else D(0)), D(0))
        reg_w = min(worked_w, reg_cap)
        weekly_summary.append(
            WeekSummary(
                label=f"Week {wk}",
                all=q2(worked_w),
                reg=q2(reg_w),
                ot=q2(worked_w - reg_w),
            )
        )

    # Totals are sums of the already-rounded weekly figures, so they match
    # what the weekly summary displays.
    worked_total = q2(sum((w.all for w in weekly_summary), D(0)))
    regular_total = q2(sum((w.reg for w in weekly_summary), D(0)))
    overtime_total = q2(sum((w.ot for w in weekly_summary), D(0)))
    paid_total = q2(worked_total + sum_pto + sum_holiday + sum_bereavement)

    totals = Totals(
        regular=regular_total,
        pto=q2(sum_pto),
        holiday=q2(sum_holiday),
        bereavement=q2(sum_bereavement),
        overtime=overtime_total,
        paid_total=paid_total,
    )
    return Grouped(rows=rows, totals=totals, weekly_summary=weekly_summary)


# ---------------------------
# Duplicate merging
# ---------------------------


def _sum_gaps(intervals: List[Tuple[datetime, datetime]]) -> Decimal:
    """Return total uncovered time between *intervals*, in hours (2 decimals).

    Intervals are processed in start order. Overlapping or touching intervals
    form one continuous block, so only real gaps between blocks are counted.
    Returns ``Decimal("0")`` for an empty list.
    """
    if not intervals:
        return D(0)
    ordered = sorted(intervals, key=lambda iv: iv[0])
    gaps = D(0)
    covered_until = ordered[0][1]
    for start, end in ordered[1:]:
        if start > covered_until:
            gaps += _td_hours(start - covered_until)
        if end > covered_until:
            covered_until = end
    return q2(gaps)


def merge_duplicates_for_timesheet(db: Session, employee_id: int, timesheet_id: int) -> int:
    """Collapse multiple entries for the same work date into a single row.

    For every work date with more than one entry for this employee and
    timesheet:

    * the entry with the earliest clock-in is kept and the others are deleted;
    * its clock-in/out become the earliest clock-in / latest clock-out among
      entries that have a valid (clock_out > clock_in) pair;
    * ``total_hours`` is that span plus the ``total_hours`` of entries without a
      valid pair, so those hours are not lost;
    * ``break_hours`` is the gaps between the timed intervals plus the
      ``break_hours`` of entries without a valid pair;
    * PTO, holiday and bereavement hours are summed, and the first non-empty
      ``pto_type`` is kept;
    * ``hours_paid`` = worked (total - break, floored at 0) + pto + holiday +
      bereavement.

    NOTE(review): ``break_hours`` recorded on *timed* entries is replaced by
    the computed gaps, not added to them. Confirm that this is intended.

    Returns:
        The number of deleted entries. The session is committed only when
        that number is non-zero.
    """
    dup_dates = [
        row[0]
        for row in (
            db.query(TimeEntry.work_date)
            .filter(TimeEntry.timesheet_id == timesheet_id, TimeEntry.employee_id == employee_id)
            .group_by(TimeEntry.work_date)
            .having(func.count(TimeEntry.id) > 1)
            .all()
        )
    ]

    merged_count = 0
    for d in dup_dates:
        entries: List[TimeEntry] = (
            db.query(TimeEntry)
            .filter(
                TimeEntry.timesheet_id == timesheet_id,
                TimeEntry.employee_id == employee_id,
                TimeEntry.work_date == d,
            )
            .order_by(TimeEntry.clock_in.asc())
            .all()
        )
        if len(entries) < 2:
            continue

        intervals: List[Tuple[datetime, datetime]] = []
        # Hours from entries lacking a usable clock-in/out pair (manual rows,
        # unparseable clock values). Previously these were silently discarded.
        untimed_total = D(0)
        untimed_break = D(0)
        for e in entries:
            ci = _to_datetime(d, e.clock_in)
            co = _to_datetime(d, e.clock_out)
            if ci and co and co > ci:
                intervals.append((ci, co))
            else:
                untimed_total += D(e.total_hours)
                untimed_break += D(e.break_hours)

        earliest_in: Optional[datetime] = min((start for start, _ in intervals), default=None)
        latest_out: Optional[datetime] = max((end for _, end in intervals), default=None)
        span_hours = D(0)
        if earliest_in and latest_out and latest_out > earliest_in:
            span_hours = q2(_td_hours(latest_out - earliest_in))

        total_hours = q2(span_hours + untimed_total)
        break_hours = q2(_sum_gaps(intervals) + untimed_break)
        pto_hours = q2(sum(D(e.pto_hours) for e in entries))
        holiday_hours = q2(sum(D(e.holiday_hours) for e in entries))
        bereavement_hours = q2(sum(D(e.bereavement_hours) for e in entries))
        pto_type = next((e.pto_type for e in entries if e.pto_type), None)

        worked_hours = max(total_hours - break_hours, D(0))
        hours_paid = q2(worked_hours + pto_hours + holiday_hours + bereavement_hours)

        def keeper_key(entry: TimeEntry) -> datetime:
            # Entries without a parseable clock-in are treated as midnight.
            return _to_datetime(d, entry.clock_in) or datetime.combine(d, time(0, 0))

        keeper = min(entries, key=keeper_key)

        # Persist Decimal directly (models use Numeric now)
        keeper.total_hours = total_hours
        keeper.break_hours = break_hours
        keeper.pto_hours = pto_hours
        keeper.pto_type = pto_type
        keeper.holiday_hours = holiday_hours
        keeper.bereavement_hours = bereavement_hours
        keeper.hours_paid = hours_paid
        if earliest_in:
            keeper.clock_in = earliest_in
        if latest_out:
            keeper.clock_out = latest_out

        for e in entries:
            if e.id != keeper.id:
                db.delete(e)
                merged_count += 1

    if merged_count:
        db.commit()
    return merged_count