Skip to content

Commit

Permalink
Merge pull request #484 from cmu-phil/joebugfixes
Browse files Browse the repository at this point in the history
joebugfixes
  • Loading branch information
espinoj authored Jun 19, 2017
2 parents ef8cdc8 + 40b2b94 commit c545c4f
Show file tree
Hide file tree
Showing 24 changed files with 231 additions and 56 deletions.
10 changes: 4 additions & 6 deletions tetrad-gui/src/main/java/edu/cmu/tetradapp/Tetrad.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,9 @@ private void launchFrame() {
this.frame = new JFrame(this.mainTitle) {
public Dimension getPreferredSize() {
Dimension size = Toolkit.getDefaultToolkit().getScreenSize();
double width = size.getWidth();
double height = size.getHeight();
double minSize = Math.min(width, height);

height = minSize / 2;
width = minSize * 0.75;
double minLength = Math.min(size.getWidth(), size.getHeight());
double height = minLength * 0.8;
double width = height * (4.0 / 3);

return new Dimension((int) width, (int) height);
// return Toolkit.getDefaultToolkit().getScreenSize();
Expand Down Expand Up @@ -195,6 +192,7 @@ public Dimension getPreferredSize() {

getFrame().setContentPane(getDesktop());
getFrame().pack();
getFrame().setLocationRelativeTo(null);

// This doesn't let the user resize the main window.
// getFrame().setExtendedState(Frame.MAXIMIZED_BOTH);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@
*/
public class GeneralAlgorithmEditor extends JPanel implements FinalizingEditor {

// Note: When adding an algorithm, make sure you do all of the following:
// 1. Add a new type to private enum AlgName.
// 2. Add a desription for it to final List<AlgorithmDescription> descriptions.
// 3. In private Algorithm getAlgorithm, add a new case to the switch statement returning
// an instance of the algorithm.

private static final long serialVersionUID = -5719467682865706447L;

private final HashMap<AlgName, AlgorithmDescription> mappedDescriptions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import edu.cmu.tetradapp.model.KnowledgeEditable;
import edu.cmu.tetradapp.model.Simulation;
import edu.cmu.tetradapp.util.WatchedProcess;

import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
Expand Down Expand Up @@ -63,6 +64,7 @@ public final class SimulationEditor extends JPanel implements KnowledgeEditable,
private final JComboBox<String> simulationsDropdown = new JComboBox<>();

//==========================CONSTUCTORS===============================//

/**
* Constructs the data editor with an empty list of data displays.
*/
Expand Down Expand Up @@ -104,11 +106,11 @@ public SimulationEditor(final Simulation simulation) {
tabbedPane.setPreferredSize(new Dimension(900, 600));

final String[] graphItems = new String[]{
"Random Foward DAG",
"Scale Free DAG",
"Cyclic, constructed from small loops",
"Random One Factor MIM",
"Random Two Factor MIM"
"Random Foward DAG",
"Scale Free DAG",
"Cyclic, constructed from small loops",
"Random One Factor MIM",
"Random Two Factor MIM"
};

for (String item : graphItems) {
Expand Down Expand Up @@ -242,15 +244,15 @@ public void actionPerformed(ActionEvent e) {
if (thisOne == null) {
JOptionPane.showMessageDialog((SimulationEditor.this),
"That file was not a simulation, and none of its subdirectories was either. "
+ "\nNeed a directory with a 'data' subdirectory, a 'graph' subdirectory, "
+ "\nand a 'parameters.txt' file.");
+ "\nNeed a directory with a 'data' subdirectory, a 'graph' subdirectory, "
+ "\nand a 'parameters.txt' file.");
return;
}

if (count > 1) {
JOptionPane.showMessageDialog((SimulationEditor.this),
"More than one subdirectory of that directory was a simulation; please select "
+ "\none of the subdirectories.");
+ "\none of the subdirectories.");
return;
}

Expand Down Expand Up @@ -410,6 +412,8 @@ private void resetPanel(Simulation simulation, String[] graphItems, String[] sim
} else if (simulationItem.equals(simulationItems[3])) {
simulation.setSimulation(new LeeHastieSimulation(randomGraph), simulation.getParams());
} else if (simulationItem.equals(simulationItems[4])) {
simulation.setSimulation(new ConditionalGaussianSimulation(randomGraph), simulation.getParams());
} else if (simulationItem.equals(simulationItems[5])) {
simulation.setSimulation(new TimeSeriesSemSimulation(randomGraph), simulation.getParams());
} else {
throw new IllegalArgumentException("Unrecognized simulation type: " + simulationItem);
Expand All @@ -436,6 +440,8 @@ private void resetPanel(Simulation simulation, String[] graphItems, String[] sim
} else if (simulationItem.equals(simulationItems[3])) {
simulation.setSimulation(new LeeHastieSimulation(randomGraph), simulation.getParams());
} else if (simulationItem.equals(simulationItems[4])) {
simulation.setSimulation(new ConditionalGaussianSimulation(randomGraph), simulation.getParams());
} else if (simulationItem.equals(simulationItems[5])) {
simulation.setSimulation(new TimeSeriesSemSimulation(randomGraph), simulation.getParams());
// } else if (simulationItem.equals(simulationItems[6])) {
// simulation.setSimulation(new BooleanGlassSimulation(randomGraph), simulation.getParams());
Expand All @@ -456,25 +462,25 @@ private String[] getSimulationItems(Simulation simulation) {
if (simulation.isFixedSimulation()) {
if (simulation.getSimulation() instanceof BayesNetSimulation) {
simulationItems = new String[]{
"Bayes net",};
"Bayes net",};
} else if (simulation.getSimulation() instanceof SemSimulation) {
simulationItems = new String[]{
"Structural Equation Model"
"Structural Equation Model"
};
} else if (simulation.getSimulation() instanceof LinearFisherModel) {
simulationItems = new String[]{
"Linear Fisher Model"
"Linear Fisher Model"
};
} else if (simulation.getSimulation() instanceof StandardizedSemSimulation) {
simulationItems = new String[]{
"Standardized Structural Equation Model"
"Standardized Structural Equation Model"
};
} else if (simulation.getSimulation() instanceof GeneralSemSimulation) {
simulationItems = new String[]{
"General Structural Equation Model",};
"General Structural Equation Model",};
} else if (simulation.getSimulation() instanceof LoadContinuousDataAndGraphs) {
simulationItems = new String[]{
"Loaded From Files",};
"Loaded From Files",};
} else {
throw new IllegalStateException("Not expecting that model type: "
+ simulation.getSimulation().getClass());
Expand All @@ -487,21 +493,23 @@ private String[] getSimulationItems(Simulation simulation) {
// } else
if (simulation.getSourceGraph() != null) {
simulationItems = new String[]{
"Bayes net",
"Structural Equation Model",
"Linear Fisher Model",
// "General Structural Equation Model Special",
"Lee & Hastie",
"Time Series"
"Bayes net",
"Structural Equation Model",
"Linear Fisher Model",
// "General Structural Equation Model Special",
"Lee & Hastie",
"Conditional Gaussian",
"Time Series"
};
} else {
simulationItems = new String[]{
"Bayes net",
"Structural Equation Model",
"Linear Fisher Model",
// "General Structural Equation Model Special",
"Lee & Hastie",
"Time Series", // "Boolean Glass"
"Bayes net",
"Structural Equation Model",
"Linear Fisher Model",
// "General Structural Equation Model Special",
"Lee & Hastie",
"Conditional Gaussian",
"Time Series", // "Boolean Glass"
};
}
}
Expand All @@ -510,8 +518,8 @@ private String[] getSimulationItems(Simulation simulation) {
}

private Box getParametersPane(Simulation _simulation,
edu.cmu.tetrad.algcomparison.simulation.Simulation simulation,
Parameters parameters) {
edu.cmu.tetrad.algcomparison.simulation.Simulation simulation,
Parameters parameters) {
JScrollPane scroll;

if (simulation != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.session.ParamsResettable;
import edu.cmu.tetrad.session.SessionModel;
import edu.cmu.tetrad.util.Parameters;
Expand Down Expand Up @@ -78,8 +79,15 @@ public KnowledgeBoxModel(KnowledgeBoxInput[] inputs, Parameters params) {
SortedSet<String> variableNames = new TreeSet<>();

for (KnowledgeBoxInput input : inputs) {
variableNodes.addAll(input.getVariables());
variableNames.addAll(input.getVariableNames());
for (Node node : input.getVariables()) {
if (node.getNodeType() == NodeType.MEASURED) {
variableNodes.add(node);
variableNames.add(node.getName());
}
}

// variableNodes.addAll(input.getVariables());
// variableNames.addAll(input.getVariableNames());
}

this.variables = new ArrayList<>(variableNodes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ public Graph search(DataModel dataSet, Parameters parameters) {
initial = initialGraph.search(dataSet, parameters);
}

Score score = this.score.getScore(DataUtils.getContinuousDataSet(dataSet), parameters);
// Score score = this.score.getScore(DataUtils.getContinuousDataSet(dataSet), parameters);
//
// Score score =

Score score = this.score.getScore(dataSet, parameters);
edu.cmu.tetrad.search.FgesMb search = new edu.cmu.tetrad.search.FgesMb(score);
search.setFaithfulnessAssumed(parameters.getBoolean("faithfulnessAssumed"));
search.setKnowledge(knowledge);
Expand All @@ -54,7 +58,7 @@ public Graph search(DataModel dataSet, Parameters parameters) {
}

this.targetName = parameters.getString("targetName");
Node target = score.getVariable(targetName);
Node target = this.score.getVariable(targetName);

return search.search(Collections.singletonList(target));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.data.DataType;
import edu.cmu.tetrad.search.Score;
Expand All @@ -16,9 +17,11 @@
*/
public class BdeuScore implements ScoreWrapper {
static final long serialVersionUID = 23L;
private DataModel dataSet;

@Override
public Score getScore(DataModel dataSet, Parameters parameters) {
this.dataSet = dataSet;
edu.cmu.tetrad.search.BDeuScore score
= new edu.cmu.tetrad.search.BDeuScore(DataUtils.getDiscreteDataSet(dataSet));
score.setSamplePrior(parameters.getDouble("samplePrior"));
Expand All @@ -43,4 +46,9 @@ public List<String> getParameters() {
parameters.add("structurePrior");
return parameters;
}

@Override
public Node getVariable(String name) {
return dataSet.getVariable(name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.data.DataType;
import edu.cmu.tetrad.search.ConditionalGaussianScore;
Expand All @@ -18,9 +19,11 @@
*/
public class ConditionalGaussianBicScore implements ScoreWrapper, Experimental {
static final long serialVersionUID = 23L;
private DataModel dataSet;

@Override
public Score getScore(DataModel dataSet, Parameters parameters) {
this.dataSet = dataSet;
final ConditionalGaussianScore conditionalGaussianScore
= new ConditionalGaussianScore(DataUtils.getMixedDataSet(dataSet), parameters.getDouble("structurePrior"), parameters.getBoolean("discretize"));
conditionalGaussianScore.setPenaltyDiscount(parameters.getDouble("penaltyDiscount"));
Expand Down Expand Up @@ -48,4 +51,9 @@ public List<String> getParameters() {
parameters.add("discretize");
return parameters;
}

@Override
public Node getVariable(String name) {
return dataSet.getVariable(name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataType;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.ConditionalGaussianOtherScore;
import edu.cmu.tetrad.search.ConditionalGaussianScore;
import edu.cmu.tetrad.search.Score;
Expand All @@ -19,9 +20,11 @@
*/
public class ConditionalGaussianOtherBicScore implements ScoreWrapper, Experimental {
static final long serialVersionUID = 23L;
private DataModel dataSet;

@Override
public Score getScore(DataModel dataSet, Parameters parameters) {
this.dataSet = dataSet;
final ConditionalGaussianOtherScore conditionalGaussianScore
= new ConditionalGaussianOtherScore(DataUtils.getMixedDataSet(dataSet), parameters.getDouble("structurePrior"), parameters.getBoolean("discretize"));
conditionalGaussianScore.setPenaltyDiscount(parameters.getDouble("penaltyDiscount"));
Expand Down Expand Up @@ -49,4 +52,9 @@ public List<String> getParameters() {
parameters.add("discretize");
return parameters;
}

@Override
public Node getVariable(String name) {
return dataSet.getVariable(name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.data.DataType;
import edu.cmu.tetrad.search.Score;
Expand All @@ -16,9 +17,11 @@
*/
public class DiscreteBicScore implements ScoreWrapper {
static final long serialVersionUID = 23L;
private DataModel dataSet;

@Override
public Score getScore(DataModel dataSet, Parameters parameters) {
this.dataSet = dataSet;
edu.cmu.tetrad.search.BicScore score
= new edu.cmu.tetrad.search.BicScore(DataUtils.getDiscreteDataSet(dataSet));
score.setPenaltyDiscount(parameters.getDouble("penaltyDiscount"));
Expand All @@ -41,4 +44,9 @@ public List<String> getParameters() {
paramDescriptions.add("penaltyDiscount");
return paramDescriptions;
}

@Override
public Node getVariable(String name) {
return dataSet.getVariable(name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import edu.cmu.tetrad.algcomparison.graph.RandomGraph;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataType;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.*;
import edu.cmu.tetrad.util.Parameters;

Expand All @@ -17,13 +18,15 @@
public class DseparationScore implements ScoreWrapper {
static final long serialVersionUID = 23L;
private final RandomGraph randomGraph;
private DataModel dataSet;

public DseparationScore(RandomGraph randomGraph) {
this.randomGraph = randomGraph;
}

@Override
public Score getScore(DataModel dataSet, Parameters parameters) {
this.dataSet = dataSet;
if (dataSet == null) {
return new GraphScore(randomGraph.createGraph(parameters));
} else {
Expand All @@ -46,4 +49,9 @@ public List<String> getParameters() {
return new ArrayList<>();
}

@Override
public Node getVariable(String name) {
return dataSet.getVariable(name);
}

}
Loading

0 comments on commit c545c4f

Please sign in to comment.