package edu.rice.hj.example.comp322.assignments.hw4; import edu.rice.hj.api.SuspendableException; import java.util.ArrayList; import java.util.Iterator; import static edu.rice.hj.Module1.*; import static edu.rice.hj.Module2.isolated; public class ConstraintSatisfactionSol { IConstraintSystem problem; FVT fvt; int thresholdSize = 1; public ConstraintSatisfactionSol(IConstraintSystem inputProblem) { problem = inputProblem; } public void solve() { double t1 = System.nanoTime() / 1e9; //System.out.println(problem.getInitialFvt()); // Start the parallel search search(problem.getInitialState(), 0, problem.getInitialFvt()); double t2 = System.nanoTime() / 1e9; java.text.DecimalFormat df = new java.text.DecimalFormat("#.##"); // Print out the final output System.out.println("Time = " + df.format(t2 - t1) + " sec"); } public void search(ProblemState state, int curVar, FVT fvt) { if (curVar == fvt.getNumVars()) { problem.addSolution(state); //System.out.println(state); //System.exit(0); } else { Iterator itr = fvt.getValues(curVar).iterator(); while (itr.hasNext()) { Integer v = itr.next(); ProblemState newState = state.copy(); newState.setValue(curVar, v); FVT newFvt = forwardCheck(curVar, v, fvt); if (newFvt != null) { search(newState, curVar + 1, newFvt); } } } } public void searchForParallel(ProblemState state, int curVar, FVT fvt) { if (curVar == fvt.getNumVars()) { isolated(() -> problem.addSolution(state)); //System.out.println(state); //System.exit(0); } else { Iterator itr = fvt.getValues(curVar).iterator(); while (itr.hasNext()) { Integer v = itr.next(); ProblemState newState = state.copy(); newState.setValue(curVar, v); FVT newFvt = forwardCheck(curVar, v, fvt); if (newFvt != null) { searchForParallel(newState, curVar + 1, newFvt); } } } } public void parallelSolve() throws SuspendableException { double t1 = System.nanoTime() / 1e9; // Start the parallel search finish(() -> { parallelSearch(problem.getInitialState(), 0, problem.getInitialFvt()); }); double t2 = System.nanoTime() / 1e9; java.text.DecimalFormat df = new java.text.DecimalFormat("#.##"); // Print out the final output System.out.println("Time = " + df.format(t2 - t1) + " sec"); } public void parallelSearch(ProblemState state, int curVar, FVT fvt) throws SuspendableException { if (curVar == fvt.getNumVars()) { isolated(() -> problem.addSolution(state)); //System.out.println(state); //System.exit(0); } else { ArrayList itr = new ArrayList<>(fvt.getValues(curVar)); forall(0, numWorkerThreads() - 1, (i) -> { for (int j = i; j < itr.size(); j = j + numWorkerThreads()) { int v = itr.get(j); ProblemState newState = state.copy(); newState.setValue(curVar, v); FVT newFvt = forwardCheck(curVar, v, fvt); if (newFvt != null) { searchForParallel(newState, curVar + 1, newFvt); } } }); } } public FVT forwardCheck(int curVar, Integer curVal, FVT fvt) { FVT newFvt = new FVT(fvt.getNumVars()); for (int freeVar = curVar + 1; freeVar < fvt.getNumVars(); freeVar++) { Iterator itr = fvt.getValues(freeVar).iterator(); while (itr.hasNext()) { Integer v = itr.next(); if (problem.isConsistent(curVar, curVal, freeVar, v)) { newFvt.addValue(freeVar, v); } } if (newFvt.getValues(freeVar).size() == 0) { return null; } } return newFvt; } }