Commit

Merge pull request #1597 from dlstadther/sf-patch
Salesforce - Improved Reliability
erikbern committed Mar 16, 2016
2 parents a43f374 + f6f3b80 commit df8783b
Showing 1 changed file with 84 additions and 74 deletions.
luigi/contrib/salesforce.py: 158 changes (84 additions, 74 deletions)
@@ -21,6 +21,7 @@
from collections import OrderedDict
import re
import csv
+import tempfile

import luigi
from luigi import Task
@@ -39,13 +40,59 @@


def get_soql_fields(soql):
-    soql_fields = re.search('(?<=select)(?s)(.*)(?=from)', soql)  # get fields
-    soql_fields = re.sub(' ', '', soql_fields.group())  # remove extra spaces
-    fields = re.split(',|\n|\r', soql_fields)  # split on commas and newlines
-    fields = [field for field in fields if field != '']  # remove empty strings
+    """
+    Gets queried columns names.
+    """
+    soql_fields = re.search('(?<=select)(?s)(.*)(?=from)', soql, re.IGNORECASE)  # get fields
+    soql_fields = re.sub(' ', '', soql_fields.group())  # remove extra spaces
+    soql_fields = re.sub('\t', '', soql_fields)  # remove tabs
+    fields = re.split(',|\n|\r|', soql_fields)  # split on commas and newlines
+    fields = [field for field in fields if field != '']  # remove empty strings
    return fields
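To make the field parsing concrete, here is a hypothetical query (not part of this change) and the list it yields; nested relationship fields come back as dotted paths, which is exactly what the new parse_results() below matches against:

    # Illustration only; assumes Python 2, like the rest of this module.
    from luigi.contrib.salesforce import get_soql_fields

    soql = """
        SELECT Id, Email, Account.Name
        FROM Contact
    """
    print(get_soql_fields(soql))
    # ['Id', 'Email', 'Account.Name']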


+def ensure_utf(value):
+    return value.encode("utf-8") if isinstance(value, unicode) else value
+
+
+def parse_results(fields, data):
+    """
+    Traverses ordered dictionary, calls _traverse_results() to recursively read into the dictionary depth of data
+    """
+    master = []
+
+    for record in data['records']:  # for each 'record' in response
+        row = [None] * len(fields)  # create null list the length of number of columns
+        for obj, value in record.iteritems():  # for each obj in record
+            if isinstance(value, basestring):  # if query base object has desired fields
+                if obj in fields:
+                    row[fields.index(obj)] = ensure_utf(value)
+
+            elif isinstance(value, dict) and obj != 'attributes':  # traverse down into object
+                path = obj
+                _traverse_results(value, fields, row, path)
+
+        master.append(row)
+    return master
+
+
+def _traverse_results(value, fields, row, path):
+    """
+    Helper method for parse_results().
+    Traverses through ordered dict and recursively calls itself when encountering a dictionary
+    """
+    for f, v in value.iteritems():  # for each item in obj
+        field_name = '{path}.{name}'.format(path=path, name=f) if path else f
+
+        if not isinstance(v, (dict, list, tuple)):  # if not data structure
+            if field_name in fields:
+                row[fields.index(field_name)] = ensure_utf(v)
+
+        elif isinstance(v, dict) and f != 'attributes':  # it is a dict
+            _traverse_results(v, fields, row, field_name)
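A sketch of how these helpers fit together on one page of results (the record data below is made up for illustration): each record becomes a row ordered like the field list, with nested objects matched through their dotted path and 'attributes' metadata skipped.

    # Illustration only; assumes Python 2 (iteritems/unicode), like this module.
    from collections import OrderedDict
    from luigi.contrib.salesforce import get_soql_fields, parse_results

    fields = get_soql_fields("SELECT Id, Email, Account.Name FROM Contact")

    # Minimal shape of a REST query response, reduced to what parse_results() reads.
    data = {
        'records': [
            OrderedDict([
                ('attributes', {'type': 'Contact'}),
                ('Id', u'003xx0000001'),
                ('Email', u'ada@example.com'),
                ('Account', OrderedDict([
                    ('attributes', {'type': 'Account'}),
                    ('Name', u'Acme'),
                ])),
            ]),
        ],
    }

    print(parse_results(fields, data))
    # [['003xx0000001', 'ada@example.com', 'Acme']]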


class salesforce(luigi.Config):
    """
    Config system to get config vars from 'salesforce' section in configuration file.
@@ -92,48 +139,6 @@ def is_soql_file(self):
        """Override to True if soql property is a file path."""
        return False

-    def parse_output(self, data):
-        """
-        Traverses ordered dictionary, calls _traverse_output to recursively read into the dictionary depth of data
-        """
-        fields = get_soql_fields(self.soql)
-        header = fields
-
-        master = [header]
-
-        for record in data['records']:  # for each 'record' in response
-            row = []
-            for obj, value in record.iteritems():  # for each obj in record
-                while len(row) < len(fields):
-                    row.append('')
-
-                if isinstance(value, basestring):  # if query base object has desired fields
-                    if obj in fields:
-                        row[fields.index(value)] = value
-
-                elif isinstance(value, dict) and obj != 'attributes':  # traverse down into object
-                    path = obj
-                    row.append(self._traverse_output(value, fields, row, path))
-
-            master.append(row)
-        return master
-
-    def _traverse_output(self, value, fields, row, path):
-        """
-        Helper method for parse_output().
-        Traverses through ordered dict and recursively calls itself when encountering a dictionary
-        """
-        for f, v in value.iteritems():  # for each item in obj
-            field_name = '{path}.{name}'.format(path=path, name=f) if path else f
-
-            if not isinstance(v, (dict, list, tuple)):  # if not data structure
-                if field_name in fields:
-                    row[fields.index(field_name)] = v
-
-            elif isinstance(v, dict) and f != 'attributes':  # it is a dict
-                self._traverse_output(v, fields, row, field_name)

    def run(self):
        if self.use_sandbox and not self.sandbox_name:
            raise Exception("Parameter sf_sandbox_name must be provided when uploading to a Salesforce Sandbox")
@@ -176,13 +181,13 @@ def run(self):

        if 'state_message' in status and 'foreign key relationships not supported' in status['state_message'].lower():
            logger.info("Retrying with REST API query")
-            data = sf.query_all(self.soql)
-
-            data_csv = self.parse_output(data)
+            data_file = sf.query_all(self.soql)
+
+            reader = csv.reader(data_file)
            with open(self.output().fn, 'w') as outfile:
-                writer = csv.writer(outfile)
-                writer.writerows(data_csv)
+                writer = csv.writer(outfile, dialect='excel')
+                for row in reader:
+                    writer.writerow(row)
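For orientation, a minimal sketch of a task that would hit this REST fallback when the SOQL contains relationship (dotted) fields the bulk API rejects; the QuerySalesforce base class name, the object_name parameter, and the LocalTarget output are assumptions, since the enclosing task class is not named in this hunk:

    # Hypothetical usage sketch; class and parameter names are assumed, not
    # taken from this diff.
    import luigi
    from luigi.contrib.salesforce import QuerySalesforce


    class DumpContacts(QuerySalesforce):
        object_name = 'Contact'  # assumed to name the Salesforce object being queried
        soql = "SELECT Id, Email, Account.Name FROM Contact"

        def output(self):
            return luigi.LocalTarget('contacts.csv')

Credentials would presumably come from the 'salesforce' configuration section read by the salesforce Config class shown earlier.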


class SalesforceAPI(object):
@@ -295,33 +300,38 @@ def query_all(self, query, **kwargs):
        :param query: the SOQL query to send to Salesforce, e.g.
            `SELECT Id FROM Lead WHERE Email = "[email protected]"`
        """
-        def get_all_responses(previous_response, **kwargs):
-            """
-            Inner function for recursing until there are no more results.
-            Returns the full set of results that will be the return value for
-            `query_all(...)`
-            :param previous_response: the modified result of previous calls to
-                Salesforce for this query
-            """
-            if previous_response['done']:
-                return previous_response
-            else:
-                response = self.query_more(previous_response['nextRecordsUrl'],
-                                           identifier_is_url=True, **kwargs)
-                response['totalSize'] += previous_response['totalSize']
-                # Include the new list of records with the previous list
-                previous_response['records'].extend(response['records'])
-                response['records'] = previous_response['records']
-                if not len(response['records']) % 10000:
-                    logger.info('Requested {0} lines...'.format(len(response['records'])))
-                # Continue the recursion
-                return get_all_responses(response, **kwargs)
        # Make the initial query to Salesforce
        response = self.query(query, **kwargs)

+        # get fields
+        fields = get_soql_fields(query)
+
+        # put fields and first page of results into a temp list to be written to TempFile
+        tmp_list = [fields]
+        tmp_list.extend(parse_results(fields, response))
+
+        tmp_dir = luigi.configuration.get_config().get('salesforce', 'local-tmp-dir', None)
+        tmp_file = tempfile.TemporaryFile(mode='a+b', dir=tmp_dir)
+
+        writer = csv.writer(tmp_file)
+        writer.writerows(tmp_list)

-        # The number of results might have exceeded the Salesforce batch limit
-        # so check whether there are more results and retrieve them if so.
-        return get_all_responses(response, **kwargs)

+        length = len(response['records'])
+        while not response['done']:
+            response = self.query_more(response['nextRecordsUrl'], identifier_is_url=True, **kwargs)
+
+            writer.writerows(parse_results(fields, response))
+            length += len(response['records'])
+            if not length % 10000:
+                logger.info('Requested {0} lines...'.format(length))
+
+        logger.info('Requested a total of {0} lines.'.format(length))
+
+        tmp_file.seek(0)
+        return tmp_file

    # Generic Rest Function
    def restful(self, path, params):
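The tmp_dir lookup in query_all() reads the 'local-tmp-dir' option from luigi's 'salesforce' configuration section, the same section the salesforce Config class pulls its settings from. A hypothetical entry with an illustrative path is sketched below; when the option is absent, tempfile.TemporaryFile falls back to the system temp directory.

    [salesforce]
    # Credential options from the salesforce Config class go here as well (omitted).
    # Optional: directory for the intermediate CSV that query_all() writes.
    local-tmp-dir=/var/tmp/luigi-salesforce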