-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1597 from dlstadther/sf-patch
Salesforce - Improved Reliability
- Loading branch information
Showing
1 changed file
with
84 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ | |
from collections import OrderedDict | ||
import re | ||
import csv | ||
import tempfile | ||
|
||
import luigi | ||
from luigi import Task | ||
|
@@ -39,13 +40,59 @@ | |
|
||
|
||
def get_soql_fields(soql):
    """
    Gets queried columns names.

    The text between the first ``select`` and the last ``from`` of the query
    (matched case-insensitively, across newlines) is stripped of spaces and
    tabs and then split on commas and line breaks.

    :param soql: SOQL query string, e.g. ``SELECT Id, Account.Name FROM Contact``
    :return: list of field names in the order they appear in the query
    :raises ValueError: if the query contains no ``select ... from`` clause
    """
    # Flags are passed as arguments rather than as an inline ``(?s)`` in the
    # middle of the pattern, which newer versions of ``re`` reject.
    match = re.search(r'(?<=select)(.*)(?=from)', soql, re.IGNORECASE | re.DOTALL)
    if match is None:
        raise ValueError('Could not find a "select ... from" clause in SOQL query: {0!r}'.format(soql))

    soql_fields = re.sub(r'[ \t]', '', match.group())  # remove spaces and tabs

    # Split on commas and line breaks. The pattern must not contain an empty
    # alternative, otherwise ``re.split`` would split between every character.
    fields = re.split(r'[,\r\n]', soql_fields)
    return [field for field in fields if field]  # remove empty strings
|
||
|
||
try:
    _TEXT_TYPE = unicode  # Python 2: text is a separate type from ``str``
except NameError:
    _TEXT_TYPE = str  # Python 3: ``str`` is the text type, ``unicode`` does not exist


def ensure_utf(value):
    """
    Encodes text values as UTF-8.

    :param value: any value taken from a query result
    :return: the UTF-8 encoded byte string if ``value`` is a text (unicode) object,
             otherwise ``value`` itself, unchanged
    """
    if isinstance(value, _TEXT_TYPE):
        return value.encode("utf-8")
    return value
|
||
|
||
def parse_results(fields, data):
    """
    Traverses ordered dictionary, calls _traverse_results() to recursively read into the dictionary depth of data

    :param fields: list of column names; values of nested objects are addressed
                   by dotted path (``Parent.Field``)
    :param data: query response; only its ``'records'`` list is read
    :return: list of rows, one per record. Each row is a list aligned with
             ``fields``; columns for which the record holds no value stay ``None``
    """
    master = []

    for record in data['records']:  # for each 'record' in response
        row = [None] * len(fields)  # create null list the length of number of columns
        for obj, value in record.items():  # for each obj in record
            if isinstance(value, dict):
                # the 'attributes' entry is skipped, any other dict is traversed
                if obj != 'attributes':
                    _traverse_results(value, fields, row, obj)

            elif not isinstance(value, (list, tuple)):
                # Keep every scalar (strings, numbers, booleans), exactly as
                # _traverse_results() does for nested objects; testing for
                # strings only would silently drop non-string columns.
                if obj in fields:
                    row[fields.index(obj)] = ensure_utf(value)

        master.append(row)
    return master
|
||
|
||
def _traverse_results(value, fields, row, path): | ||
""" | ||
Helper method for parse_results(). | ||
Traverses through ordered dict and recursively calls itself when encountering a dictionary | ||
""" | ||
for f, v in value.iteritems(): # for each item in obj | ||
field_name = '{path}.{name}'.format(path=path, name=f) if path else f | ||
|
||
if not isinstance(v, (dict, list, tuple)): # if not data structure | ||
if field_name in fields: | ||
row[fields.index(field_name)] = ensure_utf(v) | ||
|
||
elif isinstance(v, dict) and f != 'attributes': # it is a dict | ||
_traverse_results(v, fields, row, field_name) | ||
|
||
|
||
class salesforce(luigi.Config): | ||
""" | ||
Config system to get config vars from 'salesforce' section in configuration file. | ||
|
@@ -92,48 +139,6 @@ def is_soql_file(self): | |
"""Override to True if soql property is a file path.""" | ||
return False | ||
|
||
def parse_output(self, data):
    """
    Traverses ordered dictionary, calls _traverse_output to recursively read into the dictionary depth of data

    :param data: query response; only its ``'records'`` list is read
    :return: list of rows. The first row is the header (the field names taken
             from ``self.soql``); every following row holds one record, aligned
             with the header, with ``''`` for columns the record has no value for
    """
    fields = get_soql_fields(self.soql)
    master = [fields]  # header row

    for record in data['records']:  # for each 'record' in response
        # Size the row up front so that every row matches the header width,
        # including rows for records without any items.
        row = [''] * len(fields)
        for obj, value in record.items():  # for each obj in record
            if isinstance(value, dict):
                # _traverse_output() fills ``row`` in place and returns
                # nothing, so its result must not be appended to the row.
                if obj != 'attributes':
                    self._traverse_output(value, fields, row, obj)

            elif not isinstance(value, (list, tuple)):
                if obj in fields:
                    # the column is located by field name, not by the value
                    row[fields.index(obj)] = value

        master.append(row)
    return master
|
||
def _traverse_output(self, value, fields, row, path): | ||
""" | ||
Helper method for parse_output(). | ||
Traverses through ordered dict and recursively calls itself when encountering a dictionary | ||
""" | ||
for f, v in value.iteritems(): # for each item in obj | ||
field_name = '{path}.{name}'.format(path=path, name=f) if path else f | ||
|
||
if not isinstance(v, (dict, list, tuple)): # if not data structure | ||
if field_name in fields: | ||
row[fields.index(field_name)] = v | ||
|
||
elif isinstance(v, dict) and f != 'attributes': # it is a dict | ||
self._traverse_output(v, fields, row, field_name) | ||
|
||
def run(self): | ||
if self.use_sandbox and not self.sandbox_name: | ||
raise Exception("Parameter sf_sandbox_name must be provided when uploading to a Salesforce Sandbox") | ||
|
@@ -176,13 +181,13 @@ def run(self): | |
|
||
if 'state_message' in status and 'foreign key relationships not supported' in status['state_message'].lower(): | ||
logger.info("Retrying with REST API query") | ||
data = sf.query_all(self.soql) | ||
|
||
data_csv = self.parse_output(data) | ||
data_file = sf.query_all(self.soql) | ||
|
||
reader = csv.reader(data_file) | ||
with open(self.output().fn, 'w') as outfile: | ||
writer = csv.writer(outfile) | ||
writer.writerows(data_csv) | ||
writer = csv.writer(outfile, dialect='excel') | ||
for row in reader: | ||
writer.writerow(row) | ||
|
||
|
||
class SalesforceAPI(object): | ||
|
@@ -295,33 +300,38 @@ def query_all(self, query, **kwargs): | |
:param query: the SOQL query to send to Salesforce, e.g. | ||
`SELECT Id FROM Lead WHERE Email = "user@example.com"` | ||
""" | ||
def get_all_responses(previous_response, **kwargs): | ||
""" | ||
Inner function for recursing until there are no more results. | ||
Returns the full set of results that will be the return value for | ||
`query_all(...)` | ||
:param previous_response: the modified result of previous calls to | ||
Salesforce for this query | ||
""" | ||
if previous_response['done']: | ||
return previous_response | ||
else: | ||
response = self.query_more(previous_response['nextRecordsUrl'], | ||
identifier_is_url=True, **kwargs) | ||
response['totalSize'] += previous_response['totalSize'] | ||
# Include the new list of records with the previous list | ||
previous_response['records'].extend(response['records']) | ||
response['records'] = previous_response['records'] | ||
if not len(response['records']) % 10000: | ||
logger.info('Requested {0} lines...'.format(len(response['records']))) | ||
# Continue the recursion | ||
return get_all_responses(response, **kwargs) | ||
# Make the initial query to Salesforce | ||
response = self.query(query, **kwargs) | ||
|
||
# get fields | ||
fields = get_soql_fields(query) | ||
|
||
# put fields and first page of results into a temp list to be written to TempFile | ||
tmp_list = [fields] | ||
tmp_list.extend(parse_results(fields, response)) | ||
|
||
tmp_dir = luigi.configuration.get_config().get('salesforce', 'local-tmp-dir', None) | ||
tmp_file = tempfile.TemporaryFile(mode='a+b', dir=tmp_dir) | ||
|
||
writer = csv.writer(tmp_file) | ||
writer.writerows(tmp_list) | ||
|
||
# The number of results might have exceeded the Salesforce batch limit | ||
# so check whether there are more results and retrieve them if so. | ||
return get_all_responses(response, **kwargs) | ||
|
||
length = len(response['records']) | ||
while not response['done']: | ||
response = self.query_more(response['nextRecordsUrl'], identifier_is_url=True, **kwargs) | ||
|
||
writer.writerows(parse_results(fields, response)) | ||
length += len(response['records']) | ||
if not length % 10000: | ||
logger.info('Requested {0} lines...'.format(length)) | ||
|
||
logger.info('Requested a total of {0} lines.'.format(length)) | ||
|
||
tmp_file.seek(0) | ||
return tmp_file | ||
|
||
# Generic Rest Function | ||
def restful(self, path, params): | ||
|