Defining Models

A Model ties together regimes, an age grid, and a regime ID class into a solvable lifecycle model.

The Model Constructor

from lcm import Model

model = Model(
    regimes=regimes,  # dict mapping names to Regime instances
    ages=ages,  # AgeGrid defining the lifecycle timeline
    regime_id_class=RegimeId,  # @categorical dataclass mapping names to int indices
    enable_jit=True,  # controls JAX compilation (default: True)
    fixed_params={},  # optional params baked in at init time
    description="",  # optional description string
)

All arguments are keyword-only. The three required arguments are regimes, ages, and regime_id_class.
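To illustrate what keyword-only means in practice, here is a minimal sketch using a hypothetical stand-in function, `build_model`. It is not lcm's `Model`; it only mirrors the argument names and defaults listed above, and the `None` default for `fixed_params` is a choice made for this sketch.

```python
# Hypothetical stand-in for a keyword-only constructor; NOT lcm's Model.
def build_model(
    *,  # everything after the bare * must be passed by keyword
    regimes,
    ages,
    regime_id_class,
    enable_jit=True,
    fixed_params=None,
    description="",
):
    """Collect the arguments into a dict, mimicking a keyword-only signature."""
    return {
        "regimes": regimes,
        "ages": ages,
        "regime_id_class": regime_id_class,
        "enable_jit": enable_jit,
        "fixed_params": {} if fixed_params is None else fixed_params,
        "description": description,
    }


# Keyword call works; optional arguments fall back to their defaults.
spec = build_model(regimes={}, ages=[25, 26], regime_id_class=object)

# A positional call is rejected by Python itself with a TypeError.
try:
    build_model({}, [25, 26], object)
except TypeError:
    positional_call_failed = True
else:
    positional_call_failed = False
```

Passing the three required arguments by name also keeps call sites readable when the optional arguments are added later.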

Regime ID Classes

The regime_id_class maps regime names to integer indices. Use the @categorical decorator to create it:

from lcm import categorical


@categorical(ordered=False)
class RegimeId:
    retired: int
    working: int

Rules:

Age Grids

The ages argument defines the lifecycle timeline. There are two construction modes:

Range-based

from lcm import AgeGrid

ages = AgeGrid(start=25, stop=75, step="Y")  # annual steps, ages 25 to 75

Step formats:

The stop value is included in the grid if (stop - start) is an exact multiple of the step size.
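The divisibility rule above can be checked with a small plain-Python sketch. `range_ages` is a hypothetical helper, not part of lcm. It handles integer steps only and ignores string step formats such as "Y". What happens when the range is not evenly divisible (here, the grid ends at the last value not exceeding stop) is an assumption of this sketch.

```python
def range_ages(start, stop, step):
    """Return start, start + step, ... without exceeding stop (integers only)."""
    if step <= 0:
        raise ValueError("step must be positive")
    n_steps = (stop - start) // step  # whole steps that fit into the range
    return [start + i * step for i in range(n_steps + 1)]


annual = range_ages(25, 75, 1)  # 50 is divisible by 1: 75 is included
five_yearly = range_ages(25, 75, 5)  # 50 is divisible by 5: 75 is included
seven_yearly = range_ages(25, 75, 7)  # 50 is not divisible by 7: 75 is excluded
```

With a step of 7 the last grid point in this sketch is 74, so choosing a stop that is a whole number of steps from start avoids surprises about the final age.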

Exact values

ages = AgeGrid(exact_values=[25, 35, 45, 55, 65, 75])

Use this for irregular age spacing.
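Before passing irregular values to the grid, it can help to look at the gaps between consecutive ages. The helper below, `age_gaps`, is hypothetical and runs on the user side. Whether lcm itself requires strictly increasing values is an assumption here.

```python
def age_gaps(values):
    """Return the differences between consecutive ages, checking their order."""
    pairs = list(zip(values, values[1:]))
    if any(later <= earlier for earlier, later in pairs):
        raise ValueError("ages must be strictly increasing")
    return [later - earlier for earlier, later in pairs]


# Illustrative irregular grid: short steps early and late, one long step.
irregular_ages = [25, 30, 40, 45, 65, 75]
gaps = age_gaps(irregular_ages)
```

Here `gaps` shows period lengths of 5, 10, 5, 20, and 10 years, which makes the uneven spacing explicit.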

Key properties

Model Validation Rules

The Model constructor validates:

Inspecting a Model

After construction, the model exposes several useful attributes:

model.regimes  # immutable mapping of user Regime objects
model.internal_regimes  # processed internal representations
model.n_periods  # number of periods
model.regime_names_to_ids  # name -> integer mapping
model.get_params_template()  # mutable copy of the parameter template

model.get_params_template() returns a mutable copy of the parameter template; see Parameters for details.

Complete Example

import jax.numpy as jnp
from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical


@categorical(ordered=False)
class RegimeId:
    retired: int
    working: int


@categorical(ordered=True)
class LaborSupply:
    do_not_work: int
    work: int


def next_wealth(wealth, consumption, interest_rate):
    return (wealth - consumption) * (1 + interest_rate)


def next_regime(labor_supply):
    return jnp.where(
        labor_supply == LaborSupply.work, RegimeId.working, RegimeId.retired
    )


def utility(consumption, labor_supply, disutility_of_work):
    return jnp.log(consumption) - disutility_of_work * labor_supply


def terminal_utility(wealth):
    return jnp.log(wealth)


working = Regime(
    transition=next_regime,
    states={
        "wealth": LinSpacedGrid(start=1, stop=100, n_points=50),
    },
    state_transitions={
        "wealth": next_wealth,
    },
    actions={
        "consumption": LinSpacedGrid(start=1, stop=50, n_points=30),
        "labor_supply": DiscreteGrid(LaborSupply),
    },
    functions={"utility": utility},
)

retired = Regime(
    transition=None,
    states={
        "wealth": LinSpacedGrid(start=1, stop=100, n_points=50),
    },
    functions={"utility": terminal_utility},
)

model = Model(
    regimes={"working": working, "retired": retired},
    ages=AgeGrid(start=25, stop=75, step="Y"),
    regime_id_class=RegimeId,
)
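To see what the transition and utility functions from the example compute, they can be evaluated by hand outside the model. The sketch below restates them with `math.log` in place of `jnp.log` so that it runs without JAX. The parameter values (`interest_rate=0.05`, `disutility_of_work=0.5`) are illustrative, and the labor supply codes 0 for do_not_work and 1 for work are an assumption about how the categorical class is encoded.

```python
import math


def next_wealth(wealth, consumption, interest_rate):
    """Savings (wealth minus consumption) grow at the interest rate."""
    return (wealth - consumption) * (1 + interest_rate)


def utility(consumption, labor_supply, disutility_of_work):
    """Log utility of consumption minus a cost incurred only when working."""
    return math.log(consumption) - disutility_of_work * labor_supply


# Start with wealth 50, consume 10, earn 5% on the remaining 40.
w_next = next_wealth(wealth=50.0, consumption=10.0, interest_rate=0.05)

# Assumed codes: do_not_work = 0, work = 1.
u_not_working = utility(consumption=10.0, labor_supply=0, disutility_of_work=0.5)
u_working = utility(consumption=10.0, labor_supply=1, disutility_of_work=0.5)
```

At equal consumption, working lowers utility by exactly `disutility_of_work` in this sketch, and next-period wealth is roughly 42.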

See Also