#!/usr/bin/env python
"""Sum the middle pages of the updates that satisfy the page-ordering rules."""

from common import get_middle_page, parse


def update_is_correct(
    update: tuple[int, ...], not_before: dict[int, list[int]], not_after: dict[int, list[int]]
) -> bool:
    """Return True when every page of *update* is correctly placed.

    For each page except the last, every page that follows it must be listed
    in ``not_before[page]`` and must not be listed in ``not_after[page]``.
    Pages missing from either mapping are treated as having an empty list.
    An empty or single-page update is always considered correct.

    NOTE(review): a following page that has no explicit entry in
    ``not_before[page]`` makes the update incorrect, i.e. an explicit rule is
    required for every (page, later page) pair -- confirm this is intended.
    """
    for index, page in enumerate(update[:-1]):
        # Look the rule lists up once per page instead of once per follower.
        allowed_after = not_before.get(page, ())
        required_before = not_after.get(page, ())
        followers = update[index + 1 :]
        if not all(x in allowed_after and x not in required_before for x in followers):
            return False
    return True


def condense_page_orders(page_orders: list[tuple[int, int]]) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
    """Group ordering pairs into two lookup tables.

    Returns ``(not_before, not_after)`` where ``not_before[a]`` lists every
    ``b`` that appears in a pair ``(a, b)`` and ``not_after[b]`` lists every
    ``a`` that appears in a pair ``(a, b)``. Entries keep the order in which
    the pairs were given; duplicates are not removed.
    """
    not_before: dict[int, list[int]] = {}
    not_after: dict[int, list[int]] = {}
    for item in page_orders:
        first, second = item[0], item[1]
        not_before.setdefault(first, []).append(second)
        not_after.setdefault(second, []).append(first)
    return not_before, not_after


def solve(input: tuple[list[tuple[int, int]], list[tuple[int, ...]]]) -> int:
    """Return the sum of ``get_middle_page(update)`` over all correct updates.

    *input* is a ``(page_orders, updates)`` pair. The parameter name shadows
    the ``input`` builtin; it is kept so existing keyword callers keep working.
    """
    raw_page_orders, updates = input
    not_before, not_after = condense_page_orders(raw_page_orders)
    return sum(
        get_middle_page(update)
        for update in updates
        if update_is_correct(update, not_before, not_after)
    )


def main():
    """Print the result of solving whatever ``parse("input")`` returns."""
    print(solve(parse("input")))


if __name__ == "__main__":
    main()