Skip to content

Commit

Permalink
fix: fix postgres sql syntax error
Browse files Browse the repository at this point in the history
  • Loading branch information
wfxr committed Nov 5, 2021
1 parent 1f897d3 commit a376df0
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
4 changes: 2 additions & 2 deletions internal/database/offline/postgres/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ package postgres
import (
"context"
"fmt"
"strings"

"github.com/oom-ai/oomstore/internal/database/dbutil"
"github.com/oom-ai/oomstore/internal/database/offline"
"github.com/oom-ai/oomstore/pkg/oomstore/types"
)

func (db *DB) Export(ctx context.Context, opt offline.ExportOpt) (<-chan *types.RawFeatureValueRecord, error) {
fields := append([]string{opt.EntityName}, opt.FeatureNames...)
query := fmt.Sprintf("select %s from %s", strings.Join(fields, ","), opt.DataTable)
query := fmt.Sprintf("select %s from %s", dbutil.Quote(`"`, fields...), opt.DataTable)
if opt.Limit != nil {
query += fmt.Sprintf(" LIMIT %d", *opt.Limit)
}
Expand Down
14 changes: 7 additions & 7 deletions internal/database/offline/postgres/feature_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package postgres
import (
"context"
"fmt"
"strings"

"github.com/jmoiron/sqlx"
"github.com/spf13/cast"
Expand Down Expand Up @@ -40,27 +39,28 @@ func (db *DB) Join(ctx context.Context, opt offline.JoinOpt) (dataMap map[string

// Step 1: iterate each table range, get result
joinQuery := `
INSERT INTO %s(unique_key, entity_key, unix_time, %s)
INSERT INTO "%s"(unique_key, entity_key, unix_time, %s)
SELECT
CONCAT(l.entity_key, ',', l.unix_time) AS unique_key,
l.entity_key AS entity_key,
l.unix_time AS unix_time,
%s
FROM %s AS l
LEFT JOIN %s AS r
FROM "%s" AS l
LEFT JOIN "%s" AS r
ON l.entity_key = r.%s
WHERE l.unix_time >= $1 AND l.unix_time < $2;
`
featureNamesStr := strings.Join(opt.Features.Names(), ", ")
featureNamesStr := dbutil.Quote(`"`, opt.Features.Names()...)
for _, r := range opt.RevisionRanges {
_, tmpErr := db.ExecContext(ctx, fmt.Sprintf(joinQuery, entityDfWithFeatureName, featureNamesStr, featureNamesStr, entityDfName, r.DataTable, opt.Entity.Name), r.MinRevision, r.MaxRevision)
query := fmt.Sprintf(joinQuery, entityDfWithFeatureName, featureNamesStr, featureNamesStr, entityDfName, r.DataTable, opt.Entity.Name)
_, tmpErr := db.ExecContext(ctx, query, r.MinRevision, r.MaxRevision)
if tmpErr != nil {
return nil, tmpErr
}
}

// Step 2: get rows from entity_df_with_features table
resultQuery := fmt.Sprintf(`SELECT * FROM %s`, entityDfWithFeatureName)
resultQuery := fmt.Sprintf(`SELECT * FROM "%s"`, entityDfWithFeatureName)
rows, tmpErr := db.QueryxContext(ctx, resultQuery)
if tmpErr != nil {
return nil, tmpErr
Expand Down
5 changes: 2 additions & 3 deletions internal/database/online/postgres/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"fmt"
"strings"

"github.com/jackc/pgerrcode"
"github.com/jmoiron/sqlx"
Expand All @@ -17,7 +16,7 @@ import (
func (db *DB) Get(ctx context.Context, opt online.GetOpt) (dbutil.RowMap, error) {
featureNames := opt.FeatureList.Names()
tableName := getOnlineBatchTableName(opt.RevisionId)
query := fmt.Sprintf(`SELECT "%s",%s FROM "%s" WHERE "%s" = $1`, opt.EntityName, strings.Join(featureNames, ","), tableName, opt.EntityName)
query := fmt.Sprintf(`SELECT "%s", %s FROM "%s" WHERE "%s" = $1`, opt.EntityName, dbutil.Quote(`"`, featureNames...), tableName, opt.EntityName)

record, err := db.QueryRowxContext(ctx, query, opt.EntityKey).SliceScan()
if err != nil {
Expand Down Expand Up @@ -45,7 +44,7 @@ func (db *DB) Get(ctx context.Context, opt online.GetOpt) (dbutil.RowMap, error)
func (db *DB) MultiGet(ctx context.Context, opt online.MultiGetOpt) (map[string]dbutil.RowMap, error) {
featureNames := opt.FeatureList.Names()
tableName := getOnlineBatchTableName(opt.RevisionId)
query := fmt.Sprintf(`SELECT "%s", %s FROM "%s" WHERE "%s" in (?);`, opt.EntityName, strings.Join(featureNames, ","), tableName, opt.EntityName)
query := fmt.Sprintf(`SELECT "%s", %s FROM "%s" WHERE "%s" in (?);`, opt.EntityName, dbutil.Quote(`"`, featureNames...), tableName, opt.EntityName)
sql, args, err := sqlx.In(query, opt.EntityKeys)
if err != nil {
return nil, err
Expand Down

0 comments on commit a376df0

Please sign in to comment.