Skip to content

Commit

Permalink
all typed
Browse files Browse the repository at this point in the history
  • Loading branch information
Carreau committed Dec 12, 2024
1 parent d70585a commit c61a521
Showing 1 changed file with 38 additions and 31 deletions.
69 changes: 38 additions & 31 deletions IPython/core/history.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
""" History related magics and functionality """

from __future__ import annotations

# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.


import atexit
import datetime
import re
try:
from sqlite3 import DatabaseError, OperationalError
import sqlite3
sqlite3_found = True
except ModuleNotFoundError:
sqlite3_found = False
class DatabaseError(Exception):pass #type: ignore [no-redef]
class OperationalError(Exception):pass #type: ignore [no-redef]



Expand All @@ -39,8 +33,20 @@ class OperationalError(Exception):pass #type: ignore [no-redef]

from IPython.paths import locate_profile
from IPython.utils.decorators import undoc
from typing import Iterable, Tuple, Optional
from typing import Iterable, Tuple, Optional, TYPE_CHECKING
import typing
if TYPE_CHECKING:
from IPython.core.interactiveshell import InteractiveShell
from IPython.config.Configuration import Configuration

try:
from sqlite3 import DatabaseError, OperationalError
import sqlite3
sqlite3_found = True
except ModuleNotFoundError:
sqlite3_found = False
class DatabaseError(Exception):pass #type: ignore [no-redef]
class OperationalError(Exception):pass #type: ignore [no-redef]

InOrInOut = typing.Union[str, Tuple[str, Optional[str]]]

Expand Down Expand Up @@ -192,14 +198,14 @@ class HistoryAccessor(HistoryAccessorBase):
).tag(config=True)

@default("connection_options")
def _default_connection_options(self):
def _default_connection_options(self) -> typing.Dict[str, bool]:
return dict(check_same_thread=False)

# The SQLite database
db = Any()
@observe('db')
@only_when_enabled
def _db_changed(self, change):
def _db_changed(self, change): # type: ignore [no-untyped-def]
"""validate the db, since it can be an Instance of two different types"""
new = change['new']
connection_types = (DummyDB, sqlite3.Connection)
Expand All @@ -208,7 +214,7 @@ def _db_changed(self, change):
(self.__class__.__name__, new)
raise TraitError(msg)

def __init__(self, profile="default", hist_file="", **traits):
def __init__(self, profile:str="default", hist_file:str="", **traits:typing.Any) -> None:
"""Create a new history accessor.
Parameters
Expand Down Expand Up @@ -236,7 +242,7 @@ def __init__(self, profile="default", hist_file="", **traits):

self.init_db()

def _get_hist_file_name(self, profile='default'):
def _get_hist_file_name(self, profile:str='default') -> Path:
"""Find the history file for the given profile name.
This is overridden by the HistoryManager subclass, to use the shell's
Expand All @@ -250,7 +256,7 @@ def _get_hist_file_name(self, profile='default'):
return Path(locate_profile(profile)) / "history.sqlite"

@catch_corrupt_db
def init_db(self):
def init_db(self) -> None:
"""Connect to the database, and create tables if necessary."""
if not self.enabled:
self.db = DummyDB()
Expand All @@ -259,7 +265,7 @@ def init_db(self):
# use detect_types so that timestamps return datetime objects
kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES)
kwargs.update(self.connection_options)
self.db = sqlite3.connect(str(self.hist_file), **kwargs)
self.db = sqlite3.connect(str(self.hist_file), **kwargs) # type: ignore [call-overload]
with self.db:
self.db.execute(
"""CREATE TABLE IF NOT EXISTS sessions (session integer
Expand All @@ -281,15 +287,15 @@ def init_db(self):
# success! reset corrupt db count
self._corrupt_db_counter = 0

def writeout_cache(self):
def writeout_cache(self) -> None:
"""Overridden by HistoryManager to dump the cache before certain
database lookups."""
pass

## -------------------------------
## Methods for retrieving history:
## -------------------------------
def _run_sql(self, sql, params, raw=True, output=False, latest=False):
def _run_sql(self, sql:str, params:typing.Tuple, raw:bool=True, output:bool=False, latest:bool=False) -> Iterable[Tuple[int, int, InOrInOut]]:
"""Prepares and runs an SQL query for the history database.
Parameters
Expand Down Expand Up @@ -324,7 +330,7 @@ def _run_sql(self, sql, params, raw=True, output=False, latest=False):

@only_when_enabled
@catch_corrupt_db
def get_session_info(self, session) -> Tuple[int, datetime.datetime, Optional[datetime.datetime], Optional[int], str]:
def get_session_info(self, session:int) -> Tuple[int, datetime.datetime, Optional[datetime.datetime], Optional[int], str]:
"""Get info about a session.
Parameters
Expand All @@ -349,17 +355,18 @@ def get_session_info(self, session) -> Tuple[int, datetime.datetime, Optional[da
return self.db.execute(query, (session,)).fetchone()

@catch_corrupt_db
def get_last_session_id(self):
def get_last_session_id(self) -> Optional[int]:
"""Get the last session ID currently in the database.
Within IPython, this should be the same as the value stored in
:attr:`HistoryManager.session_number`.
"""
for record in self.get_tail(n=1, include_latest=True):
return record[0]
return None

@catch_corrupt_db
def get_tail(self, n=10, raw=True, output=False, include_latest=False) -> Iterable[Tuple[int, int, str]]:
def get_tail(self, n:int=10, raw:bool=True, output:bool=False, include_latest:bool=False) -> Iterable[Tuple[int, int, InOrInOut]]:
"""Get the last n lines from the history database.
Parameters
Expand Down Expand Up @@ -388,8 +395,7 @@ def get_tail(self, n=10, raw=True, output=False, include_latest=False) -> Iterab
return reversed(list(cur))

@catch_corrupt_db
def search(self, pattern="*", raw=True, search_raw=True,
output=False, n=None, unique=False):
def search(self, pattern:str="*", raw:bool=True, search_raw:bool=True, output:bool=False, n:Optional[int]=None, unique:bool=False) -> Iterable[Tuple[int, int, InOrInOut]]:
"""Search the database using unix glob-style matching (wildcards
* and ?).
Expand All @@ -416,7 +422,7 @@ def search(self, pattern="*", raw=True, search_raw=True,
tosearch = "history." + tosearch
self.writeout_cache()
sqlform = "WHERE %s GLOB ?" % tosearch
params = (pattern,)
params:typing.Tuple[typing.Any, ...] = (pattern,)
if unique:
sqlform += ' GROUP BY {0}'.format(tosearch)
if n is not None:
Expand All @@ -430,7 +436,7 @@ def search(self, pattern="*", raw=True, search_raw=True,
return cur

@catch_corrupt_db
def get_range(self, session, start=1, stop=None, raw=True,output=False):
def get_range(self, session:int, start:int=1, stop:Optional[int]=None, raw:bool=True, output:bool=False) -> Iterable[Tuple[int, int, InOrInOut]]:
"""Retrieve input by session.
Parameters
Expand All @@ -457,6 +463,7 @@ def get_range(self, session, start=1, stop=None, raw=True,output=False):
(session, line, input) if output is False, or
(session, line, (input, output)) if output is True.
"""
params:typing.Tuple[typing.Any, ...]
if stop:
lineclause = "line >= ? AND line < ?"
params = (session, start, stop)
Expand All @@ -467,7 +474,7 @@ def get_range(self, session, start=1, stop=None, raw=True,output=False):
return self._run_sql("WHERE session==? AND %s" % lineclause,
params, raw=raw, output=output)

def get_range_by_str(self, rangestr, raw=True, output=False):
def get_range_by_str(self, rangestr:str, raw:bool=True, output:bool=False) -> Iterable[Tuple[int, int, InOrInOut]]:
"""Get lines of history from a string of ranges, as used by magic
commands %hist, %save, %macro, etc.
Expand Down Expand Up @@ -498,7 +505,7 @@ class HistoryManager(HistoryAccessor):

# An instance of the IPython shell we are attached to
shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
allow_none=True)
allow_none=False)
# Lists to hold processed and raw history. These start with a blank entry
# so that we can index them starting from 1
input_hist_parsed = List([""])
Expand All @@ -507,7 +514,7 @@ class HistoryManager(HistoryAccessor):
dir_hist: List = List()

@default("dir_hist")
def _dir_hist_default(self):
def _dir_hist_default(self) -> typing.List[Path]:
try:
return [Path.cwd()]
except OSError:
Expand All @@ -517,10 +524,10 @@ def _dir_hist_default(self):
# execution count.
output_hist = Dict()
# The text/plain repr of outputs.
output_hist_reprs: typing.Dict[int, str] = Dict()
output_hist_reprs: typing.Dict[int, str] = Dict() # type: ignore [assignment]

# The number of the current session in the history database
session_number:int = Integer()
session_number:int = Integer() #type: ignore [assignment]

db_log_output = Bool(False,
help="Should the history database include output? (default: no)"
Expand Down Expand Up @@ -552,7 +559,7 @@ def _dir_hist_default(self):
# an exit call).
_exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$")

def __init__(self, shell=None, config=None, **traits):
def __init__(self, shell: InteractiveShell, config: Optional[Configuration] = None, **traits: typing.Any):
"""Create a new history manager associated with a shell instance.
"""
super().__init__(shell=shell, config=config,
Expand All @@ -579,7 +586,7 @@ def __init__(self, shell=None, config=None, **traits):
)
self.hist_file = ":memory:"

def _get_hist_file_name(self, profile:str=None):
def _get_hist_file_name(self, profile:Optional[str]=None) -> Path:
"""Get default history file name based on the Shell's profile.
The profile parameter is ignored, but must exist for compatibility with
Expand Down

0 comments on commit c61a521

Please sign in to comment.