Skip to content

Commit

Permalink
feat: [#280] Add Sqlite driver
Browse files Browse the repository at this point in the history
  • Loading branch information
hwbrzzl committed Nov 8, 2024
1 parent d303909 commit 093615c
Showing 1 changed file with 13 additions and 70 deletions.
83 changes: 13 additions & 70 deletions database/schema/grammars/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ func NewSqlite() *Sqlite {
return sqlite
}

func (r *Sqlite) CompileAdd(blueprint schema.Blueprint) string {
//return fmt.Sprintf("alter table %s add column %s", blueprint.GetTableName(), getColumn(r, blueprint, command.Column))
return ""
func (r *Sqlite) CompileAdd(blueprint schema.Blueprint, command *schema.Command) string {
return fmt.Sprintf("alter table %s add column %s", blueprint.GetTableName(), getColumn(r, blueprint, command.Column))
}

func (r *Sqlite) CompileCreate(blueprint schema.Blueprint, query orm.Query) string {
Expand All @@ -43,62 +42,35 @@ func (r *Sqlite) CompileCreate(blueprint schema.Blueprint, query orm.Query) stri
}

func (r *Sqlite) CompileDropAllDomains(domains []string) string {
return fmt.Sprintf("drop domain %s cascade", strings.Join(domains, ", "))
return ""
}

func (r *Sqlite) CompileDropAllTables(tables []string) string {
return fmt.Sprintf("drop table %s cascade", strings.Join(tables, ", "))
return "delete from sqlite_master where type in ('table', 'index', 'trigger')"
}

func (r *Sqlite) CompileDropAllTypes(types []string) string {
return fmt.Sprintf("drop type %s cascade", strings.Join(types, ", "))
return ""
}

func (r *Sqlite) CompileDropAllViews(views []string) string {
return fmt.Sprintf("drop view %s cascade", strings.Join(views, ", "))
return "delete from sqlite_master where type in ('view')"
}

func (r *Sqlite) CompileDropIfExists(blueprint schema.Blueprint) string {
return fmt.Sprintf("drop table if exists %s", blueprint.GetTableName())
}

func (r *Sqlite) CompileTables(database string) string {
return "select c.relname as name, n.nspname as schema, pg_total_relation_size(c.oid) as size, " +
"obj_description(c.oid, 'pg_class') as comment from pg_class c, pg_namespace n " +
"where c.relkind in ('r', 'p') and n.oid = c.relnamespace and n.nspname not in ('pg_catalog', 'information_schema') " +
"order by c.relname"
func (r *Sqlite) CompileTables() string {
return "select name from sqlite_master where type = 'table' and name not like 'sqlite_%' order by name"
}

func (r *Sqlite) CompileTypes() string {
return `select t.typname as name, n.nspname as schema, t.typtype as type, t.typcategory as category,
((t.typinput = 'array_in'::regproc and t.typoutput = 'array_out'::regproc) or t.typtype = 'm') as implicit
from pg_type t
join pg_namespace n on n.oid = t.typnamespace
left join pg_class c on c.oid = t.typrelid
left join pg_type el on el.oid = t.typelem
left join pg_class ce on ce.oid = el.typrelid
where ((t.typrelid = 0 and (ce.relkind = 'c' or ce.relkind is null)) or c.relkind = 'c')
and not exists (select 1 from pg_depend d where d.objid in (t.oid, t.typelem) and d.deptype = 'e')
and n.nspname not in ('pg_catalog', 'information_schema')`
return ""
}

func (r *Sqlite) CompileViews() string {
return "select viewname as name, schemaname as schema, definition from pg_views where schemaname not in ('pg_catalog', 'information_schema') order by viewname"
}

func (r *Sqlite) EscapeNames(names []string) []string {
escapedNames := make([]string, 0, len(names))

for _, name := range names {
segments := strings.Split(name, ".")
for i, segment := range segments {
segments[i] = strings.Trim(segment, `'"`)
}
escapedName := `"` + strings.Join(segments, `"."`) + `"`
escapedNames = append(escapedNames, escapedName)
}

return escapedNames
return "select name, sql as definition from sqlite_master where type = 'view' order by name"
}

func (r *Sqlite) GetAttributeCommands() []string {
Expand All @@ -110,18 +82,6 @@ func (r *Sqlite) GetModifiers() []func(blueprint schema.Blueprint, column schema
}

func (r *Sqlite) ModifyDefault(blueprint schema.Blueprint, column schema.ColumnDefinition) string {
if column.GetChange() {
if !column.GetAutoIncrement() {
if column.GetDefault() == nil {
return "drop default"
} else {
return fmt.Sprintf("set default %s", getDefaultValue(column.GetDefault()))
}
}

return ""
}

if column.GetDefault() != nil {
return fmt.Sprintf(" default %s", getDefaultValue(column.GetDefault()))
}
Expand All @@ -130,14 +90,6 @@ func (r *Sqlite) ModifyDefault(blueprint schema.Blueprint, column schema.ColumnD
}

func (r *Sqlite) ModifyNullable(blueprint schema.Blueprint, column schema.ColumnDefinition) string {
if column.GetChange() {
if column.GetNullable() {
return "drop not null"
} else {
return "set not null"
}
}

if column.GetNullable() {
return " null"
} else {
Expand All @@ -146,19 +98,15 @@ func (r *Sqlite) ModifyNullable(blueprint schema.Blueprint, column schema.Column
}

func (r *Sqlite) ModifyIncrement(blueprint schema.Blueprint, column schema.ColumnDefinition) string {
if !column.GetChange() && !blueprint.HasCommand("primary") && slices.Contains(r.serials, column.GetType()) && column.GetAutoIncrement() {
return " primary key"
if slices.Contains(r.serials, column.GetType()) && column.GetAutoIncrement() {
return " primary key autoincrement"
}

return ""
}

func (r *Sqlite) TypeBigInteger(column schema.ColumnDefinition) string {
if column.GetAutoIncrement() {
return "bigserial"
}

return "bigint"
return "integer"
}

func (r *Sqlite) TypeInteger(column schema.ColumnDefinition) string {
Expand All @@ -170,11 +118,6 @@ func (r *Sqlite) TypeInteger(column schema.ColumnDefinition) string {
}

func (r *Sqlite) TypeString(column schema.ColumnDefinition) string {
length := column.GetLength()
if length > 0 {
return fmt.Sprintf("varchar(%d)", length)
}

return "varchar"
}

Expand Down

0 comments on commit 093615c

Please sign in to comment.