Skip to content

Commit

Permalink
feat (logging) Improve logging (#226)
Browse files Browse the repository at this point in the history
* new logger

* new logger and evaluation message

* sed to remove unused import everywhere

* misc eval changes and gym up level

* removed last unused import

* docstrings and cleanup of eval

* readd verbose

* bug gym logging

* add to api

* add colors

* new style log

* misc ameliorations

* add test on second line also

* add a comment to warn in logging to be careful when changing default message
  • Loading branch information
TimotheeMathieu authored Jul 20, 2022
1 parent 357a0ac commit b8a8b9c
Show file tree
Hide file tree
Showing 56 changed files with 373 additions and 167 deletions.
9 changes: 9 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,15 @@ Check Utilities
utils.check_seeding_agent
utils.check_agent_manager

Logging Utilities
-----------------

.. autosummary::
:toctree: generated/
:template: function.rst

utils.logging.set_level


Typing
------
Expand Down
6 changes: 4 additions & 2 deletions examples/demo_bandits/plot_mirror_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
from rlberry.agents.bandits import BanditWithSimplePolicy
from rlberry.wrappers import WriterWrapper
import rlberry.spaces as spaces
import logging

import requests
import matplotlib.pyplot as plt


logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger

# Environment definition

Expand Down
8 changes: 6 additions & 2 deletions rlberry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from ._version import __version__
import logging

logger = logging.getLogger("rlberry_logger")

from rlberry.utils.logging import configure_logging


__path__ = __import__("pkgutil").extend_path(__path__, __name__)

# Initialize logging level

configure_logging(level="INFO")


# define __version__

__all__ = ["__version__"]
__all__ = ["__version__", "logger"]
5 changes: 3 additions & 2 deletions rlberry/agents/adaptiveql/adaptiveql.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
import gym.spaces as spaces
import numpy as np
from rlberry.agents import AgentWithSimplePolicy
from rlberry.agents.adaptiveql.tree import MDPTreePartition

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


class AdaptiveQLAgent(AgentWithSimplePolicy):
Expand Down
4 changes: 2 additions & 2 deletions rlberry/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
import dill
import pickle
import logging
import numpy as np
from inspect import signature
from pathlib import Path
Expand All @@ -14,8 +13,9 @@
from typing import Optional
import inspect

import rlberry

logger = logging.getLogger(__name__)
logger = rlberry.logger


class Agent(ABC):
Expand Down
6 changes: 4 additions & 2 deletions rlberry/agents/bandits/bandit_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from rlberry.agents import AgentWithSimplePolicy
from .tools import BanditTracker
import pickle
import logging

from pathlib import Path

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


class BanditWithSimplePolicy(AgentWithSimplePolicy):
Expand Down
6 changes: 4 additions & 2 deletions rlberry/agents/bandits/index_agents.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
from rlberry.agents.bandits import BanditWithSimplePolicy
import logging

logger = logging.getLogger(__name__)

import rlberry

logger = rlberry.logger

# TODO : fix bug when doing several fit, the fit do not resume. Should define
# self.rewards and self.action and resume training.
Expand Down
6 changes: 4 additions & 2 deletions rlberry/agents/bandits/randomized_agents.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
from rlberry.agents.bandits import BanditWithSimplePolicy
import logging

logger = logging.getLogger(__name__)

import rlberry

logger = rlberry.logger


class RandomizedAgent(BanditWithSimplePolicy):
Expand Down
5 changes: 3 additions & 2 deletions rlberry/agents/bandits/tools/tracker.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging
from rlberry import metadata_utils
from rlberry.utils.writers import DefaultWriter

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


class BanditTracker(DefaultWriter):
Expand Down
6 changes: 4 additions & 2 deletions rlberry/agents/bandits/ts_agents.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
from rlberry.agents.bandits import BanditWithSimplePolicy
import logging

logger = logging.getLogger(__name__)

import rlberry

logger = rlberry.logger


class TSAgent(BanditWithSimplePolicy):
Expand Down
6 changes: 4 additions & 2 deletions rlberry/agents/experimental/jax/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import haiku as hk
import jax
import jax.numpy as jnp
import logging

import numpy as np
import optax
import dill
Expand All @@ -43,7 +43,9 @@
from rlberry.agents.jax.utils.replay_buffer import ReplayBuffer
from typing import Any, Callable, Mapping, Optional

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


@chex.dataclass
Expand Down
6 changes: 4 additions & 2 deletions rlberry/agents/experimental/jax/utils/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
* For priority updates, see https://github.com/deepmind/reverb/issues/28
"""

import logging

import tensorflow as tf

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger

try:
import reverb
Expand Down
6 changes: 4 additions & 2 deletions rlberry/agents/experimental/torch/avec/avec_ppo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
import logging

import torch.nn as nn
import inspect

Expand All @@ -12,7 +12,9 @@
from rlberry.utils.torch import choose_device
from rlberry.wrappers.uncertainty_estimator_wrapper import UncertaintyEstimatorWrapper

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


class AVECPPOAgent(AgentWithSimplePolicy):
Expand Down
6 changes: 4 additions & 2 deletions rlberry/agents/experimental/torch/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
import torch.nn as nn
from torch.nn.functional import one_hot
import logging

import gym.spaces as spaces

from rlberry.agents import AgentWithSimplePolicy
Expand All @@ -14,7 +14,9 @@
from rlberry.utils.torch import choose_device
from rlberry.wrappers.uncertainty_estimator_wrapper import UncertaintyEstimatorWrapper

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


class SACAgent(AgentWithSimplePolicy):
Expand Down
6 changes: 3 additions & 3 deletions rlberry/agents/kernel_based/rs_kernel_ucbvi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import logging

import numpy as np
from rlberry.utils.jit_setup import numba_jit

Expand All @@ -11,7 +9,9 @@
from rlberry.agents.kernel_based.kernels import kernel_func
from rlberry.agents.kernel_based.common import map_to_representative

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


@numba_jit
Expand Down
5 changes: 3 additions & 2 deletions rlberry/agents/kernel_based/rs_ucbvi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
from rlberry.agents.agent import AgentWithSimplePolicy
import numpy as np

Expand All @@ -7,7 +6,9 @@
from rlberry.agents.dynprog.utils import backward_induction_in_place
from rlberry.agents.kernel_based.common import map_to_representative

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


class RSUCBVIAgent(AgentWithSimplePolicy):
Expand Down
5 changes: 3 additions & 2 deletions rlberry/agents/linear/lsvi_ucb.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
import numpy as np
from rlberry.agents import AgentWithSimplePolicy
from gym.spaces import Discrete
from rlberry.utils.jit_setup import numba_jit

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


@numba_jit
Expand Down
6 changes: 4 additions & 2 deletions rlberry/agents/mbqvi/mbqvi.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import numpy as np
import logging


from rlberry.agents import AgentWithSimplePolicy
from rlberry.agents.dynprog.utils import backward_induction, value_iteration
from gym.spaces import Discrete

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


class MBQVIAgent(AgentWithSimplePolicy):
Expand Down
5 changes: 3 additions & 2 deletions rlberry/agents/optql/optql.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
import numpy as np

import gym.spaces as spaces
from rlberry.agents import AgentWithSimplePolicy
from rlberry.exploration_tools.discrete_counter import DiscreteCounter

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


class OptQLAgent(AgentWithSimplePolicy):
Expand Down
5 changes: 3 additions & 2 deletions rlberry/agents/psrl/psrl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import numpy as np

import gym.spaces as spaces
Expand All @@ -9,7 +8,9 @@
backward_induction_sd,
)

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


class PSRLAgent(AgentWithSimplePolicy):
Expand Down
5 changes: 3 additions & 2 deletions rlberry/agents/rlsvi/rlsvi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import numpy as np

import gym.spaces as spaces
Expand All @@ -10,7 +9,9 @@
backward_induction_sd,
)

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


class RLSVIAgent(AgentWithSimplePolicy):
Expand Down
5 changes: 3 additions & 2 deletions rlberry/agents/stable_baselines/stable_baselines.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Type, Union

Expand All @@ -13,7 +12,9 @@
from rlberry.agents import AgentWithSimplePolicy


logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


def is_recordable(value: Any) -> bool:
Expand Down
5 changes: 3 additions & 2 deletions rlberry/agents/torch/a2c/a2c.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
import torch.nn as nn
import logging

import gym.spaces as spaces
import numpy as np
Expand All @@ -13,7 +12,9 @@
from rlberry.utils.factory import load
from typing import Optional

logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


class A2CAgent(AgentWithSimplePolicy):
Expand Down
5 changes: 3 additions & 2 deletions rlberry/agents/torch/dqn/dqn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import inspect

import numpy as np
Expand All @@ -19,7 +18,9 @@
from typing import Callable, Optional, Union


logger = logging.getLogger(__name__)
import rlberry

logger = rlberry.logger


def default_q_net_fn(env, **kwargs):
Expand Down
Loading

0 comments on commit b8a8b9c

Please sign in to comment.