Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Parallel Coordinates and Mosaic widgets #124

Merged
merged 15 commits into from
Nov 6, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Orange/data/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def var_from_domain(self, var, check_included=False, no_index=False):
"""
if isinstance(var, str):
if not var in self.indices:
raise IndexError("Variable '%s' is not in the domain", var)
raise IndexError("Variable '%s' is not in the domain %s" % (var, self))
idx = self.indices[var]
return self._variables[idx] if idx >= 0 else self._metas[-1 - idx]

Expand All @@ -187,7 +187,7 @@ def var_from_domain(self, var, check_included=False, no_index=False):
if each is var:
return var
raise IndexError(
"Variable '%s' is not in the domain", var.name)
"Variable '%s' is not in the domain %s" % (var.name, self))
else:
return var

Expand Down
2 changes: 2 additions & 0 deletions Orange/data/sql/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ def _get_field_values(self, field_name, field_type):
def _get_distinct_values(self, field_name):
sql = " ".join(["SELECT DISTINCT", self.quote_identifier(field_name),
"FROM", self.table_name,
"WHERE {} IS NOT NULL".format(
self.quote_identifier(field_name)),
"ORDER BY", self.quote_identifier(field_name),
"LIMIT 21"])
with self._execute_sql_query(sql) as cur:
Expand Down
3 changes: 3 additions & 0 deletions Orange/data/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ def to_val(self, s):
:param s: values, represented as a number, string or `None`
:rtype: float
"""
if s is None:
return Unknown

if self.has_numeric_values:
s = str(s)

Expand Down
9 changes: 5 additions & 4 deletions Orange/widgets/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,10 +683,11 @@ def settings_from_widget(self, widget):
return

def packer(setting, instance):
value = getattr(instance, setting.name)
yield setting.name, self.encode_setting(context, setting, value)
if hasattr(setting, "selected"):
yield setting.selected, list(getattr(instance, setting.selected))
if hasattr(instance, setting.name):
value = getattr(instance, setting.name)
yield setting.name, self.encode_setting(context, setting, value)
if hasattr(setting, "selected"):
yield setting.selected, list(getattr(instance, setting.selected))

context.values = self.provider.pack(widget, packer=packer)

Expand Down
48 changes: 19 additions & 29 deletions Orange/widgets/visualize/owmosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,12 +340,6 @@ def setData(self, data):
if not self.data:
return

if not self.data.domain.class_var:
self.warning(0, "Data does not have a class variable.")
return
else:
self.warning(0)

if any(isinstance(attr, ContinuousVariable) for attr in self.data.domain):
# previously done in optimizationDlg.setData()
self.data = DiscretizeTable(data, method=EqualFreq())
Expand Down Expand Up @@ -552,29 +546,25 @@ def updateGraph(self, data=-1, subsetData=-1, attrList=-1, **args):
# Build a dictionary that maps every "combination-of-attr-values" key to its count.
## TODO: this function is used by both owmosaic and owsieve; decide on a shared location for it.
def getConditionalDistributions(self, data, attrs):
if type(data) == SqlTable:
cond_dist = defaultdict(lambda: 0)
var_attrs = [data.domain[a] for a in attrs]
# make all possible pairs of attributes + class_var
for i in range(0, len(var_attrs) + 1):
attr = [v.to_sql() for v in var_attrs[:i + 1]]
if i == len(var_attrs):
attr.append(data.domain.class_var.to_sql())
cond_dist = defaultdict(int)
all_attrs = [data.domain[a] for a in attrs]
if data.domain.class_var is not None:
all_attrs.append(data.domain.class_var)

for i in range(1, len(all_attrs)+1):
attr = all_attrs[:i]
if type(data) == SqlTable:
# make all possible pairs of attributes + class_var
attr = [a.to_sql() for a in attr]
fields = attr + ["COUNT(*)"]
query = data._sql_query(fields, group_by=attr)
with data._execute_sql_query(query) as cur:
res = cur.fetchall()
for r in list(res):
cond_dist['-'.join(r[:-1])] = r[-1]
else:
cond_dist = {}
for i in range(0, len(attrs) + 1):
attr = []
for j in range(0, i + 1):
if j == len(attrs):
attr.append(data.domain.class_var)
else:
attr.append(data.domain[attrs[j]])
for r in res:
str_values =[a.repr_val(a.to_val(x)) for a, x in zip(all_attrs, r[:-1])]
str_values = [x if x != '?' else 'None' for x in str_values]
cond_dist['-'.join(str_values)] = r[-1]
else:
for indices in product(*(range(len(a.values)) for a in attr)):
vals = []
conditions = []
Expand Down Expand Up @@ -757,13 +747,13 @@ def addRect(self, x0, x1, y0, y1, condition="", used_attrs=[], used_vals=[], att
if used_vals == [vals[self.activeRule[0].index(a)] for a in used_attrs]:
values = list(
self.attributeValuesDict.get(self.data.domain.classVar.name, [])) or get_variable_values_sorted(
self.data.domain.classVar)
self.data.domain.class_var)
counts = [self.conditionalDict[attrVals + "-" + val] for val in values]
d = 2
r = OWCanvasRectangle(self.canvas, x0 - d, y0 - d, x1 - x0 + 2 * d + 1, y1 - y0 + 2 * d + 1, z=50)
r.setPen(QPen(self.colorPalette[counts.index(max(counts))], 2, Qt.DashLine))

aprioriDist = None
aprioriDist = ()
pearson = None
expected = None
outerRect = OWCanvasRectangle(self.canvas, x0, y0, x1 - x0, y1 - y0, z=30)
Expand Down Expand Up @@ -980,8 +970,8 @@ def setColors(self):
self.color_settings = dlg.getColorSchemas()
self.selected_schema_index = dlg.selectedSchemaIndex
self.colorPalette = dlg.getDiscretePalette("discPalette")
if self.data and self.data.domain.classVar and isinstance(self.data.domain.classVar, DiscreteVariable):
self.colorPalette.set_number_of_colors(len(self.data.domain.classVar.values))
if self.data and self.data.domain.class_var and isinstance(self.data.domain.class_var, DiscreteVariable):
self.colorPalette.set_number_of_colors(len(self.data.domain.class_var.values))
self.updateGraph()

def createColorDialog(self):
Expand Down
2 changes: 1 addition & 1 deletion Orange/widgets/visualize/owmpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def rank(self):
[disc(self.data, attr) if type(attr) == Orange.data.variable.ContinuousVariable
else attr for attr in self.data.domain.attributes], self.data.domain.class_vars)

t = Orange.data.Table(ndomain, self.data)
t = self.data.from_table(ndomain, self.data)

attrs = t.domain.attributes

Expand Down
2 changes: 1 addition & 1 deletion Orange/widgets/visualize/owparallelcoordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def create_control_panel(self):
self.add_zoom_select_toolbar(self.general_tab)

self.add_visual_settings(self.settings_tab)
self.add_annotation_settings(self.settings_tab)
#self.add_annotation_settings(self.settings_tab)
self.add_color_settings(self.settings_tab)
self.add_group_settings(self.settings_tab)

Expand Down
61 changes: 43 additions & 18 deletions Orange/widgets/visualize/owparallelgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def update_data(self, attributes, mid_labels=None):
self.show_statistics = False
self.draw_groups()
else:
self.show_statistics = True
self.show_statistics = False
self.draw_curves()
self.draw_distributions()
self.draw_axes()
Expand Down Expand Up @@ -164,15 +164,20 @@ def is_selected(example):

diff, mins = [], []
for i in self.attribute_indices:
diff.append(self.domain_data_stat[i].max - self.domain_data_stat[i].min or 1)
mins.append(self.domain_data_stat[i].min)
var = self.data_domain[i]
if isinstance(var, DiscreteVariable):
diff.append(len(var.values))
mins.append(-0.5)
else:
diff.append(self.domain_data_stat[i].max - self.domain_data_stat[i].min or 1)
mins.append(self.domain_data_stat[i].min)

def scale_row(row):
    """Return *row* rescaled element-wise as ``(value - minimum) / range``.

    The per-column minimum and range come from the ``mins`` and ``diff``
    lists of the enclosing scope; values are paired positionally, so the
    result is as long as the shortest of the three sequences.
    """
    scaled = []
    for value, low, span in zip(row, mins, diff):
        scaled.append((value - low) / span)
    return scaled

for row_idx, row in enumerate(self.data[:, self.attribute_indices]):
#if not self.valid_data[row_idx]:
# continue
if any(np.isnan(v) for v in row.x):
continue

color = self.select_color(row_idx)

Expand Down Expand Up @@ -214,8 +219,13 @@ def draw_groups(self):

diff, mins = [], []
for i in self.attribute_indices:
diff.append(self.domain_data_stat[i].max - self.domain_data_stat[i].min or 1)
mins.append(self.domain_data_stat[i].min)
var = self.data_domain[i]
if isinstance(var, DiscreteVariable):
diff.append(len(var.values))
mins.append(-0.5)
else:
diff.append(self.domain_data_stat[i].max - self.domain_data_stat[i].min or 1)
mins.append(self.domain_data_stat[i].min)

for j, (phi, cluster_mus, cluster_sigma) in enumerate(zip(phis, mus, sigmas)):
for i, (mu1, sigma1, mu2, sigma2), in enumerate(
Expand Down Expand Up @@ -268,23 +278,22 @@ def draw_mid_labels(self, mid_labels):

def draw_statistics(self):
"""Draw lines that represent standard deviation or quartiles"""
return # TODO: Implement using BasicStats
if self.show_statistics and self.have_data:
n_attr = len(self.attributes)
data = []
for attr_idx in self.attribute_indices:
if not isinstance(self.data_domain[attr_idx], ContinuousVariable):
data.append([()])
continue # only for continuous attributes

if not self.data_has_class or self.data_has_continuous_class: # no class
attr_values = self.no_jittering_scaled_data[attr_idx]
attr_values = attr_values[~np.isnan(attr_values)]

if self.show_statistics == MEANS:
m = attr_values.mean()
dev = attr_values.std()
m = self.domain_data_stat[attr_idx].mean
dev = self.domain_data_stat[attr_idx].var
data.append([(m - dev, m, m + dev)])
elif self.show_statistics == MEDIAN:
data.append([(0, 0, 0)]); continue

sorted_array = np.sort(attr_values)
if len(sorted_array) > 0:
data.append([(sorted_array[int(len(sorted_array) / 4.0)],
Expand All @@ -295,6 +304,7 @@ def draw_statistics(self):
else:
curr = []
class_values = get_variable_values_sorted(self.data_domain.class_var)

for c in range(len(class_values)):
attr_values = self.data[attr_idx, self.data[self.data_class_index] == c]
attr_values = attr_values[~np.isnan(attr_values)]
Expand Down Expand Up @@ -768,10 +778,22 @@ def create_contingencies(X, callback=None):
dim = len(X.domain)

X_ = DiscretizeTable(X, method=EqualFreq(n=10))
vals = [[tuple(map(str.strip, v.strip('[]()<>=').split(','))) for v in var.values]
for var in X_.domain]
m = [{i: (float(v[0]) if len(v) == 1 else (float(v[0]) + (float(v[1]) - float(v[0])) / 2))
for i, v in enumerate(val)} for val in vals]
m = []
for i, var in enumerate(X_.domain):
cleaned_values = [tuple(map(str.strip, v.strip('[]()<>=').split(',')))
for v in var.values]
try:
float_values = [[float(v) for v in vals] for vals in cleaned_values]
bin_centers = {
i: v[0] if len(v) == 1 else v[0] + (v[1] - v[0])
for i, v in enumerate(float_values)
}
except ValueError:
bin_centers = {
i: i
for i, v in enumerate(cleaned_values)
}
m.append(bin_centers)

from Orange.data.sql.table import SqlTable
if isinstance(X, SqlTable):
Expand All @@ -792,6 +814,8 @@ def create_contingencies(X, callback=None):
else:
conts = [defaultdict(float) for i in range(len(X_.domain))]
for i, r in enumerate(X_):
if any(np.isnan(r)):
continue
row = tuple(m[vi].get(v) for vi, v in enumerate(r))
for l in range(len(X_.domain)):
lower = l - window_size if l - window_size >= 0 else None
Expand All @@ -817,8 +841,9 @@ def convert(row):
for i, v in enumerate(row)]

group_by = [a.to_sql() for a in (X.domain[c] for c in columns)]
filters = ['%s IS NOT NULL' % a for a in group_by]
fields = group_by + ['COUNT(%s)' % group_by[0]]
query = X._sql_query(fields, group_by=group_by)
query = X._sql_query(fields, group_by=group_by, filters=filters)
with X._execute_sql_query(query) as cur:
cont = np.array(list(map(convert, cur.fetchall())), dtype='float')
return cont[:, :-1], cont[:, -1:].flatten()