POJ 1412 - Equals are Equals

http://poj.org/problem?id=1412

概要

式は

からなっている.

まず1つ式が与えられ,その後にいくつか式が与えられるので,最初の式と等価かどうか判定する.

解法

パースした後,変数に適当な値を束縛して等しいかどうかを調べる.

このとき,浮動小数で計算することになるけど,誤差の扱いに注意が必要.絶対誤差ではなく相対誤差で見る(0 のときに注意).

  1 #include <iostream>
  2 #include <string>
  3 #include <cctype>
  4 #include <cmath>
  5 #include <cassert>
  6 #include <algorithm>
  7 using namespace std;
  8 
  9 struct expr
 10 {
 11   char op;
 12   const expr *lhs, *rhs;
 13   int c;
 14   char var;
 15   expr(int a) : op(0), lhs(0), rhs(0), c(a), var(0) {}
 16   expr(char v) : op(0), lhs(0), rhs(0), c(0), var(v) {}
 17   expr(char o, const expr *l, const expr *r)
 18     : op(o), lhs(l), rhs(r), c(0), var(0)
 19   {}
 20   ~expr() { delete lhs; delete rhs; }
 21 };
 22 
 23 typedef string::const_iterator Iterator;
 24 void skip_white(Iterator& it, const Iterator& last);
 25 const expr *parse_expr(Iterator& it, const Iterator& last);
 26 const expr *parse_term(Iterator& it, const Iterator& last);
 27 const expr *parse_exp(Iterator& it, const Iterator& last);
 28 const expr *parse_factor(Iterator& it, const Iterator& last);
 29 
 30 const expr *parse(const string& s)
 31 {
 32   Iterator it = s.begin();
 33   const Iterator last = s.end();
 34   skip_white(it, last);
 35   return parse_expr(it, last);
 36 }
 37 
 38 void skip_white(Iterator& it, const Iterator& last)
 39 {
 40   for (;it != last && *it == ' '; ++it);
 41 }
 42 
 43 const expr *parse_expr(Iterator& it, const Iterator& last)
 44 {
 45   const expr *lhs = parse_term(it, last);
 46   skip_white(it, last);
 47   static const string a = "+-";
 48   while (it != last && a.find(*it) != string::npos) {
 49     const char op = *it;
 50     ++it;
 51     skip_white(it, last);
 52     const expr *rhs = parse_term(it, last);
 53     skip_white(it, last);
 54     lhs = new expr(op, lhs, rhs);
 55   }
 56   return lhs;
 57 }
 58 
 59 const expr *parse_term(Iterator& it, const Iterator& last)
 60 {
 61   const expr *lhs = parse_exp(it, last);
 62   static const string a = "+-)";
 63   while (it != last && a.find(*it) == string::npos) {
 64     skip_white(it, last);
 65     const expr *rhs = parse_exp(it, last);
 66     skip_white(it, last);
 67     lhs = new expr('*', lhs, rhs);
 68   }
 69   return lhs;
 70 }
 71 
 72 const expr *parse_exp(Iterator& it, const Iterator& last)
 73 {
 74   const expr *lhs = parse_factor(it, last);
 75   skip_white(it, last);
 76   if (*it == '^') {
 77     ++it;
 78     skip_white(it, last);
 79     const expr *rhs = parse_factor(it, last);
 80     skip_white(it, last);
 81     return new expr('^', lhs, rhs);
 82   } else {
 83     return lhs;
 84   }
 85 }
 86 
 87 const expr *parse_factor(Iterator& it, const Iterator& last)
 88 {
 89   if (*it == '(') {
 90     ++it;
 91     skip_white(it, last);
 92     const expr *r = parse_expr(it, last);
 93     skip_white(it, last);
 94     assert(*it == ')');
 95     ++it;
 96     return r;
 97   } else {
 98     if (isdigit(*it)) {
 99       int c = 0;
100       while (it != last && isdigit(*it)) {
101         c = 10*c + *it-'0';
102         ++it;
103       }
104       return new expr(c);
105     } else {
106       if (it != last && islower(*it)) {
107         const char v = *it;
108         ++it;
109         return new expr(v);
110       } else {
111         throw "parse error";
112       }
113     }
114   }
115 }
116 
117 double eval(const expr *e, const int bs[26])
118 {
119   if (!e->lhs) {
120     if (e->var) {
121       return bs[e->var-'a'];
122     } else {
123       return e->c;
124     }
125   } else {
126     const double l = eval(e->lhs, bs);
127     const double r = eval(e->rhs, bs);
128     switch (e->op) {
129       case '+': return l + r;
130       case '-': return l - r;
131       case '*': return l * r;
132       case '^': return pow(l, r);
133       default: throw "unknown operator";
134     }
135   }
136 }
137 
138 bool equal(const expr *e1, const expr *e2)
139 {
140   int binds[26];
141   for (int t = 0; t < 26; t++) {
142     for (int i = 0; i < 26; i++) {
143       binds[i] = (5*t + i*7) % 26;
144     }
145     const double l = eval(e1, binds);
146     const double r = eval(e2, binds);
147     if (!(abs(l) < 1e-6 && abs(r) < 1e-6) && abs(abs(l / r) - 1.0) > 1e-6) {
148       return false;
149     }
150   }
151   return true;
152 }
153 
154 int main()
155 {
156   string s;
157   while (getline(cin, s) && s != ".") {
158     const expr *e1 = parse(s);
159     while (getline(cin, s) && s != ".") {
160       const expr *e2 = parse(s);
161       cout << (equal(e1, e2) ? "yes" : "no") << endl;
162       delete e2;
163     }
164     delete e1;
165     cout << "." << endl;
166   }
167   return 0;
168 }
poj/1412.cc