package edu.rice.hj.example.comp322.labs.lab5; import edu.rice.hj.api.HjFuture; import edu.rice.hj.api.HjPoint; import edu.rice.hj.api.SuspendableException; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.util.Random; import static edu.rice.hj.Module1.*; /* * HJ version ported from Java version at http://www.csse.monash.edu.au/~lloyd/tildeProgLang/Java2/Exp/. * * @author Vincent Cave * @author Sanjay Chatterjee * @author Max Grossman * @author Vivek Sarkar (vsarkar@rice.edu) * * See below for acknowledgments for original code * * L. Allison, September 2001, * School of Computer Science and Software Engineering, * Monash University, Australia 3800. * * Released under the GNU General Public License Version 2, June 1991. */ /** * Class defining Matrix Evaluation */ public class MatrixEvalFuture { /** * Adds the 2 input matrices and returns the result * * @param a - first matrix * @param b - second matrix * @return - the sum matrix */ public static Matrix matrixAdd(Matrix a, Matrix b) { int rows = a.rows; int cols = a.cols; assert (b.rows == rows && b.cols == cols) : "a and b are not conformable for operation +"; Matrix result = new Matrix(rows, cols); for (HjPoint point : newRectangularRegion2D(0, rows - 1, 0, cols - 1).toSeqIterable()) { int i = point.get(0); int j = point.get(1); result.data[i][j] = a.data[i][j] + b.data[i][j]; } return result; } /** * Subtracts a matrix from the other * * @param a - first matrix * @param b - second matrix * @return - the difference matrix */ public static Matrix matrixMinus(Matrix a, Matrix b) { int rows = a.rows; int cols = a.cols; assert (b.rows == rows && b.cols == cols) : "a and b are not conformable for operation -"; Matrix result = new Matrix(rows, cols); for (HjPoint point : newRectangularRegion2D(0, rows - 1, 0, cols - 1).toSeqIterable()) { int i = point.get(0); int j = point.get(1); result.data[i][j] = a.data[i][j] - b.data[i][j]; } return result; } /** * Finds the product of 2 matrices * * @param a - first matrix * @param b - second matrix * @return - the product matrix */ public static Matrix matrixMultiply(Matrix a, Matrix b) { int arows = a.rows; int acols = a.cols; int brows = b.rows; int bcols = b.cols; if (acols != brows) { Expression.error("Invalid dimensions for matrix multiply"); } // a and b are not conformable Matrix result = new Matrix(arows, bcols); for (HjPoint point : newRectangularRegion2D(0, arows - 1, 0, bcols - 1).toSeqIterable()) { int i = point.get(0); int j = point.get(1); result.data[i][j] = 0; for (HjPoint pointk : newRectangularRegion1D(0, acols - 1).toSeqIterable()) { int k = pointk.get(0); result.data[i][j] += a.data[i][k] * b.data[k][j]; } } return result; } /** * Identity matrix * * @param a - input matrix * @return - identity matrix */ public static Matrix matrixId(Matrix a) { int rows = a.rows; int cols = a.cols; Matrix result = new Matrix(rows, cols); for (HjPoint point : newRectangularRegion2D(0, rows - 1, 0, cols - 1).toSeqIterable()) { int i = point.get(0); int j = point.get(1); result.data[i][j] = a.data[i][j]; } return result; } /** * Finds the negation of the input matrix * * @param a - input matrix * @return - the negated matrix */ public static Matrix matrixNeg(Matrix a) { int rows = a.rows; int cols = a.cols; Matrix result = new Matrix(rows, cols); for (HjPoint point : newRectangularRegion2D(0, rows - 1, 0, cols - 1).toSeqIterable()) { int i = point.get(0); int j = point.get(1); result.data[i][j] = -a.data[i][j]; } return result; } /** * Prints the input matrix * * @param a - matrix to be printed */ public static void printMatrix(Matrix a) { int rows = a.rows; int cols = a.cols; System.out.println("["); for (HjPoint pointi : newRectangularRegion1D(0, rows - 1).toSeqIterable()) { int i = pointi.get(0); System.out.print(" [ "); for (HjPoint pointj : newRectangularRegion1D(0, cols - 1).toSeqIterable()) { int j = pointj.get(0); System.out.print(a.data[i][j] + " "); } System.out.println("]"); } System.out.println("]"); } /** * Main function * * @param argv - name of file containing the matrix expression */ public static void main(String[] argv) { launchHabaneroApp(() -> { if (argv.length != 1) { System.out.println("usage: java MatrixEvalFuture input_file_name"); return; } FileInputStream in = null; Expression e = null; try { in = new FileInputStream(argv[0]); Syntax syn = new Syntax(new Lexical(in)); e = syn.exp(); // parse input expression } catch (IOException io) { io.printStackTrace(); return; } System.out.print("Input expression:"); System.out.println(e.toString()); // print input expression for (int iter = 0; iter < 5; iter++) { long start = System.nanoTime(); Matrix[] result = {null}; final Expression exp = e; finish(() -> { result[0] = exp.eval(); }); long end = System.nanoTime(); System.out.println("Run " + iter); System.out.println("result[0][0] = " + result[0].data[0][0]); System.out.println("Time taken for expression evaluation = " + (end - start) / 1000000 + " milliseconds"); // printMatrix(result); } }); }// main /** * Class defining a Matrix */ public static class Matrix { /** * Number of columns */ public final int cols; /** * Number of rows */ public final int rows; /** * Elements of the matrix */ public final int[][] data; /** * Constructor for matrix * * @param rows - number of rows * @param cols - number of columns */ public Matrix(int rows, int cols) { this.rows = rows; this.cols = cols; this.data = new int[rows][cols]; } } /** * Class defining expression parse-trees */ private static abstract class Expression { /** * Error * * @param msg - error message */ public static void error(String msg) { System.out.println("\nError: " + msg); System.exit(1); } /** * Abstract function for evaluating the expression * * @return the evaluated value */ public abstract Matrix eval() throws SuspendableException; /** * To string * * @return - expression string */ public String toString() { StringBuffer sb = new StringBuffer(); // efficiency! appendSB(sb); return sb.toString(); } public abstract void appendSB(StringBuffer sb); // printing /** * This class defines a leaf expression. A leaf could be an identity matrix or a random matrix */ public static class Ident extends Expression { /** * string representing the matrix */ public final String id; /** * Number of rows */ public int rows; /** * Number of columns */ public int cols; /** * Seed for random matrix */ public int seed; /** * Constructor for class IDent * * @param id - string representing the leaf expression */ public Ident(String id) { this.id = id; extractDims(id); } /** * Extract the dimensions and seed from the input string * * @param id - string representing the matrix */ public void extractDims(String id) { int indexOfM = id.indexOf('m'); if (indexOfM != 0) { error("indexOfM != 0"); } // ident must start with 'm' int indexOfX = id.indexOf('x'); if (indexOfX == -1) { // identity matrix case rows = Integer.parseInt(id.substring(indexOfM + 1)); if (rows <= 0) { error("rows <= 0"); } cols = -1; seed = -1; } else { // random matrix case int indexOfS = id.indexOf('s'); if (indexOfX >= indexOfS) { error("indexOfX >= indexOfS"); } rows = Integer.parseInt(id.substring((indexOfM + 1), indexOfX)); if (rows <= 0) { error("rows <= 0"); } ; cols = Integer.parseInt(id.substring((indexOfX + 1), indexOfS)); if (cols <= 0) { error("cols <= 0"); } ; seed = Integer.parseInt(id.substring(indexOfS + 1)); } } /** * Append to string buffer * * @param sb - buffer */ public void appendSB(StringBuffer sb) { sb.append(id); } /** * Initialize the matrix with random values or creates an identity maatrix * * @return the evaluated value */ public Matrix eval() { Matrix result; if (cols == -1) { // identity matrix case cols = rows; result = new Matrix(rows, cols); for (HjPoint point : newRectangularRegion1D(0, rows - 1).toSeqIterable()) { int i = point.get(0); result.data[i][i] = 1; } } else { // random matrix case Random r = new Random(seed); result = new Matrix(rows, cols); for (HjPoint point : newRectangularRegion2D(0, rows - 1, 0, cols - 1).toSeqIterable()) { int i = point.get(0); int j = point.get(1); result.data[i][j] = r.nextInt(); } } return result; } } /** * Class to represent an integer constant expression */ public static class IntCon extends Expression { /** * Value of integer constant */ public final int n; /** * Constructor for IntCon * * @param n - value of int constant */ public IntCon(int n) { this.n = n; } /** * Appends to string buffer * * @param sb - string buffer */ public void appendSB(StringBuffer sb) { sb.append(String.valueOf(n)); } /** * Evaluation of integer scalar is an error * * @return the evaluated value */ public Matrix eval() { error("Unhandled Integer scalar"); return null; } } /** * Class defining a unary expression */ public static class Unary extends Expression { /** * the unary operator */ public final int opr; /** * the expression argument */ public final Expression e; // e.g. -7 /** * Constructor for unary expression * * @param opr - operator * @param e - expression */ public Unary(int opr, Expression e) { this.e = e; this.opr = opr; } /** * Append to string buffer * * @param sb - string buffer */ public void appendSB(StringBuffer sb) { sb.append("(" + Lexical.Symbol[opr] + " "); e.appendSB(sb); sb.append(")"); } /** * Evaluate the unary expression * * @return the evaluated value */ public Matrix eval() throws SuspendableException { final Matrix m = e.eval(); Matrix result = null; switch (opr) { case Lexical.plus: result = MatrixEvalFuture.matrixId(m); break; case Lexical.minus: result = MatrixEvalFuture.matrixNeg(m); break; default: error("Unhandled Unary operator"); } return result; } } /** * Class defining a binary expression */ public static class Binary extends Expression // Binary { /** * binary operator */ public final int opr; /** * left expression */ public final Expression lft; /** * right expression */ public final Expression rgt; /** * Constructor for Binary * * @param opr - operator * @param lft - left expression * @param rgt - right expression */ public Binary(int opr, Expression lft, Expression rgt) { this.opr = opr; this.lft = lft; this.rgt = rgt; } /** * Appends to string buffer * * @param sb - string buffer */ public void appendSB(StringBuffer sb) { sb.append("("); lft.appendSB(sb); sb.append(" " + Lexical.Symbol[opr] + " "); rgt.appendSB(sb); sb.append(")"); } /** * Evaluates the binary expression by recursively evaluating left and right expressions * * @return the evaluated value */ public Matrix eval() throws SuspendableException { HjFuture lft_eval = future(() -> lft.eval()); HjFuture rgt_eval = future(() -> rgt.eval()); Matrix result = null; switch (opr) { case Lexical.plus: result = MatrixEvalFuture.matrixAdd(lft_eval.get(), rgt_eval.get()); break; case Lexical.minus: result = MatrixEvalFuture.matrixMinus(lft_eval.get(), rgt_eval.get()); break; case Lexical.times: result = MatrixEvalFuture.matrixMultiply(lft_eval.get(), rgt_eval.get()); break; default: error("Unhandled binary operator"); } return result; } } } /** * This class is the lexical processor of symbols */ private static class Lexical { /** * symbol codes... */ public static final int word = 0, numeral = 1, open = 2, // ( close = 3, // ) plus = 4, // + minus = 5, // - times = 6, // * over = 7, // / eofSy = 8; public static final String[] Symbol = new String[]{"", "", "(", ")", "+", "-", "*", "/", ""};// Symbol /** * input stream */ InputStream inp; /** * Lexical state variables */ int sy = -1; char ch = ' '; byte[] buffer = new byte[1]; boolean eof = false; String theWord = ""; int theInt = 666; /** * Constructor for lexical processor * * @param inp - input stream */ public Lexical(InputStream inp) { this.inp = inp; insymbol(); } /** * get the next symbol from the input stream */ public void insymbol() { if (sy == eofSy) { return; } while (ch == ' ') { getch(); // skip white space } if (eof) { sy = eofSy; } else if (Character.isLetter(ch)) // words { StringBuffer w = new StringBuffer(); while (Character.isLetterOrDigit(ch)) { try { //Workaround compiler bug if (false) { throw new IOException(); } w.append(ch); } catch (IOException e) { } getch(); } theWord = w.toString(); sy = word; } else if (Character.isDigit(ch)) // numbers { theInt = 0; while (Character.isDigit(ch)) { theInt = theInt * 10 + ((int) ch) - ((int) '0'); getch(); } sy = numeral; } else // special symbols { int ch2 = ch; getch(); switch (ch2) { case '+': sy = plus; break; case '-': sy = minus; break; case '*': sy = times; break; // case '/': sy = over; break; case '(': sy = open; break; case ')': sy = close; break; default: error("bad symbol"); } } } /** * get character changes variable ch as a side-effect. */ void getch() { ch = '.'; if (sy == eofSy) { return; } try { int n = 0; if (inp.available() > 0) { n = inp.read(buffer); } if (n <= 0) { eof = true; } else { ch = (char) buffer[0]; } } catch (Exception e) { } if (ch == '\n' || ch == '\t') { ch = ' '; } } /** * Error * * @param msg - error message */ public void error(String msg) { System.out.println("\nError: " + msg + " sy=" + sy + " ch=" + ch + " theWord=" + theWord + " theInt=" + theInt); skipRest(); System.exit(1); } /** * skip rest of input */ void skipRest() { if (!eof) { System.out.print("skipping to end of input..."); } int n = 0; while (!eof) { if (n % 80 == 0) { System.out.println(); // break line } System.out.print(ch); n++; getch(); } System.out.println(); } public int sy() { return sy; } } private static class Syntax { /** * Lexical processor */ private final Lexical lex; /** * useful Symbol Sets */ private final long unOprs = (1L << Lexical.minus), binOprs = (1L << Lexical.plus) | (1L << Lexical.minus) | (1L << Lexical.times) | (1L << Lexical.over), startsExp = unOprs | (1L << Lexical.word) | (1L << Lexical.numeral) | (1L << Lexical.open); int[] oprPriority = new int[Lexical.eofSy]; /** * Constructor for Syntax * * @param lex - lexical processor */ public Syntax(Lexical lex) { this.lex = lex; init(); } /** * Initializes precedence of operators */ void init() { for (int i = 0; i < oprPriority.length; i++) { oprPriority[i] = 0; } oprPriority[Lexical.plus] = 1; oprPriority[Lexical.minus] = 1; oprPriority[Lexical.times] = 2; oprPriority[Lexical.over] = 2; } /** * check and skip a particular symbol * * @param sym - symbol */ private void check(int sym) { if (lex.sy() == sym) { lex.insymbol(); } else { error(Lexical.Symbol[sym] + " Expected"); } } public Expression exp() { Expression e = exp(1); check(Lexical.eofSy); return e; } /** * Returns expression */ private Expression exp(int priority) { Expression e = null; if (priority < 3) { e = exp(priority + 1); int sym = lex.sy(); while (member(sym, binOprs) && oprPriority[sym] == priority) { lex.insymbol(); // e.g. 1+2+3 e = new Expression.Binary(sym, e, exp(priority + 1)); sym = lex.sy(); } } else if (member(lex.sy(), unOprs)) // unary op, e.g. -3 { int sym = lex.sy(); lex.insymbol(); e = new Expression.Unary(sym, exp(priority)); } else { switch (lex.sy()) { case Lexical.word: e = new Expression.Ident(lex.theWord); lex.insymbol(); break; case Lexical.numeral: e = new Expression.IntCon(lex.theInt); lex.insymbol(); break; case Lexical.open: // e.g. (e) lex.insymbol(); e = exp(1); check(Lexical.close); break; default: error("bad operand"); } } return e; } /** * is n a member of the "set" s * * @param n - element to be tested for membership * @param s - set */ boolean member(int n, long s) { return ((1L << n) & s) != 0; } /** * Error * * @param msg - message to be printed */ void error(String msg) { lex.error("Syntax: " + msg); } } }