Tuesday, December 16, 2008

A Simple Symbolic Differentiation Program in Haskell

This program is, again, inspired by SICP. Compared with the Scheme code in section 2.3.2 of SICP, my program adds a front-end parser that converts the external representation of an expression (a String) into its internal representation (an AST). The purpose of writing this program is to get familiar with the Parsec library in particular and to gain a better understanding of monadic parsers in general.
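
To make "monadic parser" concrete, here is a minimal from-scratch sketch in the classic list-of-successes style. It is not how Parsec is implemented (Parsec's parser type also tracks source position, error messages and user state), and the names `Parser`, `runParser`, `satisfy`, `orElse`, `many1`, `Atom`, `atom` and `power` are my own hypothetical choices for illustration. The point is that the `Monad` instance is what allows `power` to be written in do-notation: each step runs on whatever input the previous step left over.

```haskell
import Data.Char (isDigit, isLower)

-- A parser consumes a prefix of the input and returns every possible
-- (result, remaining input) pair; an empty list means failure.
newtype Parser a = Parser { runParser :: String -> [(a, String)] }

instance Functor Parser where
  fmap f (Parser p) = Parser $ \s -> [(f a, rest) | (a, rest) <- p s]

instance Applicative Parser where
  pure x = Parser $ \s -> [(x, s)]
  Parser pf <*> Parser pa =
    Parser $ \s -> [(f a, s2) | (f, s1) <- pf s, (a, s2) <- pa s1]

-- Bind runs p, then hands each result to k, which decides how to parse
-- the remaining input; this is what makes do-notation work.
instance Monad Parser where
  Parser p >>= k = Parser $ \s -> concat [runParser (k a) rest | (a, rest) <- p s]

-- Consume one character that satisfies the predicate.
satisfy :: (Char -> Bool) -> Parser Char
satisfy ok = Parser $ \s -> case s of
  (c:cs) | ok c -> [(c, cs)]
  _             -> []

-- Biased choice: use q only if p produces no result at all.
orElse :: Parser a -> Parser a -> Parser a
orElse (Parser p) (Parser q) = Parser $ \s -> case p s of
  [] -> q s
  rs -> rs

-- One or more repetitions of p, taking as many as possible.
many1 :: Parser a -> Parser [a]
many1 p = do
  x  <- p
  xs <- many1 p `orElse` return []
  return (x : xs)

-- Leaves of the AST: a natural number or a single lower-case variable.
data Atom = Num Integer | Var Char deriving (Eq, Show)

atom :: Parser Atom
atom = (Num . read <$> many1 (satisfy isDigit)) `orElse` (Var <$> satisfy isLower)

-- Sequencing in do-notation: a variable, a caret, then a number ("x^2").
power :: Parser (Char, Integer)
power = do
  v <- satisfy isLower
  _ <- satisfy (== '^')
  n <- many1 (satisfy isDigit)
  return (v, read n)
```

For example, `runParser power "x^2"` evaluates to `[(('x',2),"")]`, and `runParser atom "42+x"` to `[(Num 42,"+x")]`, leaving `"+x"` for whatever parser runs next.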
After loading the differentiation program in Hugs, a simple session goes like this:
Main> deriv_x "x+3"
1
Main> deriv_x "x*y"
y
Main> deriv_x "(x+3)*x*y"
((x)+((x)+(3)))*(y)
Main> deriv_x "(x+3)^2"
(2)*((x)+(3))
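
The last two results can be checked by hand. `deriv_` applies the sum, product and power rules, and any variable other than `x` (here `y`) is treated as a constant with derivative 0. Since `*` is left-associative, `(x+3)*x*y` is parsed as `((x+3)*x)*y`:

```latex
\begin{aligned}
\frac{d}{dx}\Big[\big((x+3)\,x\big)\,y\Big]
  &= \frac{d}{dx}\big[(x+3)\,x\big]\cdot y + (x+3)\,x\cdot\frac{dy}{dx} \\
  &= \big(1\cdot x + (x+3)\cdot 1\big)\,y + (x+3)\,x\cdot 0 \\
  &= \big(x + (x+3)\big)\,y \;=\; (2x+3)\,y, \\[4pt]
\frac{d}{dx}(x+3)^2
  &= 2\,(x+3)^{1}\cdot\frac{d}{dx}(x+3) \;=\; 2\,(x+3).
\end{aligned}
```

These match `((x)+((x)+(3)))*(y)` and `(2)*((x)+(3))` once the extra parentheses are ignored.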
Unfortunately, the program's output is cluttered with many unnecessary parentheses, because the Show instance wraps every operand of +, * and ^ in parentheses. This could be improved with a printer that inserts parentheses only where operator precedence requires them, together with a function that normalizes expressions to a canonical form. I am sure there is a well-defined algorithm for the latter, but as I said above, writing a polished differentiation program really is not the purpose here. The complete program is listed below:
-- Simple symbolic differentiation program
module Main where

import Text.ParserCombinators.Parsec
import Text.ParserCombinators.Parsec.Expr
import qualified Text.ParserCombinators.Parsec.Token as T
import Text.ParserCombinators.Parsec.Language

-- Symbolic Expression

type Symbol = Char

data Expr = Add Expr Expr
          | Mul Expr Expr
          | Exp Expr Integer
          | Var Symbol
          | Num Integer
          deriving (Eq)

instance Show Expr where
  show e = case e' of
      Num x   -> show x
      Var x   -> [x]
      Add u v -> "(" ++ show u ++ ")" ++ "+" ++ "(" ++ show v ++ ")"
      Mul u v -> "(" ++ show u ++ ")" ++ "*" ++ "(" ++ show v ++ ")"
      Exp u n -> "(" ++ show u ++ ")" ++ "^" ++ show n
    where e' = simplify e

-- Deriving

deriv_ :: Symbol -> Expr -> Expr
deriv_ x e = simplify $ deriv__ x e
  where deriv__ _ (Num _) = Num 0
        deriv__ x (Var s) | (s == x)  = Num 1
                          | otherwise = Num 0
        deriv__ x (Add u v) = Add (deriv__ x u) (deriv__ x v)
        deriv__ x (Mul u v) = Add (Mul (deriv__ x u) v) (Mul u (deriv__ x v))
        deriv__ x (Exp u n) = Mul (Mul (Num n) (Exp u (n-1))) (deriv__ x u)

-- Expression Simplifier

simplify :: Expr -> Expr
simplify e = let e' = simplify' e
             in if e == e' then e else simplify e'
  where
    simplify' e@(Num n) = e
    simplify' e@(Var x) = e
    simplify' (Add (Num 0) u) = u
    simplify' (Add u (Num 0)) = u
    simplify' (Add (Num n) (Num m)) = Num (n + m)
    simplify' (Add u v) = Add (simplify' u) (simplify' v)
    simplify' (Mul (Num 0) v) = Num 0
    simplify' (Mul u (Num 0)) = Num 0
    simplify' (Mul (Num 1) v) = v
    simplify' (Mul u (Num 1)) = u
    simplify' (Mul (Num n) (Num m)) = Num (n * m)
    simplify' (Mul u v) = Mul (simplify' u) (simplify' v)
    simplify' (Exp u 0) = Num 1
    simplify' (Exp u 1) = simplify u
    simplify' (Exp (Num m) n) = Num (m ^ n)
    simplify' (Exp u n) = Exp (simplify u) n

-- Parser

lang = T.makeTokenParser emptyDef

natural = T.natural lang
operator c = T.lexeme lang (char c)
variable = T.lexeme lang lower


expr = buildExpressionParser table factor
   <?> "expression"

mkNode :: (Expr -> Expr -> Expr) -> Expr -> Expr -> Expr
mkNode op t1 t2 = op t1 t2

mkAdd = mkNode Add
mkMul = mkNode Mul

mkExp :: Expr -> Expr -> Expr
mkExp e (Num n) = Exp e n
mkExp e _ = error "exponent must be a number"

table = [ [op '^' mkExp AssocRight]
        , [op '*' mkMul AssocLeft]
        , [op '+' mkAdd AssocLeft]
        ]
  where
    op c f assoc
      = Infix (do { operator c; return f } <?> "operator") assoc

factor = T.parens lang expr
     <|> do { v <- natural; return $ Num v }
     <|> do { v <- variable; return $ Var v }
     <?> "factor"

-- Driver

parseExpr :: String -> Expr
parseExpr input = case parse expr "" input of
    Left err  -> error $ "parse error at " ++ show err
    Right out -> out

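deriv_x :: String -> Expr
deriv_x = deriv_ 'x' . parseExpr

-- Entry point, so the module also builds as a program with GHC
main :: IO ()
main = mapM_ (print . deriv_x) ["x+3", "x*y", "(x+3)*x*y", "(x+3)^2"]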
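
One way to get rid of the redundant parentheses, independent of any normalization, is a precedence-aware printer built on the Prelude's `showsPrec` and `showParen`. The sketch below is my own assumption of how that could look, not part of the program above: it is a standalone copy of the `Expr` type (with `Symbol` written as `Char`) plus a replacement `Show` instance, and unlike the program's `show` it does not call `simplify` before printing. It gives `+`, `*` and `^` the same precedences as Haskell's own operators (6, 7 and 8).

```haskell
-- Standalone copy of the post's AST (Symbol is simply Char here).
data Expr = Add Expr Expr
          | Mul Expr Expr
          | Exp Expr Integer
          | Var Char
          | Num Integer
          deriving (Eq)

-- showsPrec d e prints e in a context of precedence d; showParen wraps
-- the output in parentheses only when that context binds tighter than
-- the operator at the root of e. Precedences follow Haskell's own
-- operators: + is 6, * is 7 and ^ is 8.
instance Show Expr where
  -- Integer's showsPrec already parenthesizes negative numbers when d > 6.
  showsPrec d (Num n) = showsPrec d n
  showsPrec _ (Var c) = showChar c
  -- + and * are left-associative in the parser, so the right operand gets
  -- a higher precedence: x+(y+z) keeps its parentheses, (x+y)+z does not.
  showsPrec d (Add u v) = showParen (d > 6) $
    showsPrec 6 u . showChar '+' . showsPrec 7 v
  showsPrec d (Mul u v) = showParen (d > 7) $
    showsPrec 7 u . showChar '*' . showsPrec 8 v
  -- The exponent is a plain Integer; the base is parenthesized unless it
  -- is a variable or a non-negative number.
  showsPrec d (Exp u n) = showParen (d > 8) $
    showsPrec 9 u . showChar '^' . showsPrec 9 n
```

With this instance, `show (Mul (Add (Var 'x') (Add (Var 'x') (Num 3))) (Var 'y'))` gives `(x+(x+3))*y` where the program prints `((x)+((x)+(3)))*(y)`. Right-nested sums keep their parentheses so that the printed text parses back to the same tree under the `AssocLeft` rule for `+`. It still prints `x+(x+3)` rather than `2*x+3`, though: collecting like terms is the canonical-form normalization mentioned above, which this sketch does not attempt.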

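For comparison, the same grammar can be sketched with `Text.ParserCombinators.ReadP` from the `base` package, so no extra library is needed. Each row of the operator table becomes one layer, with `chainl1` playing the role of `AssocLeft`. This is my own rough translation, not a drop-in replacement: the helper names (`lexeme`, `symbol`, `natural`, `term`, `power`) are hypothetical local definitions, it accepts only decimal numbers (Parsec's `natural` also reads hexadecimal and octal literals), and because `'^'` must be followed by a number, an input such as `x^y` simply fails to parse instead of reaching `error` in `mkExp`.

```haskell
import Data.Char (isDigit, isLower)
import Text.ParserCombinators.ReadP

-- Standalone copy of the post's AST.
data Expr = Add Expr Expr
          | Mul Expr Expr
          | Exp Expr Integer
          | Var Char
          | Num Integer
          deriving (Eq, Show)

-- Run p, then skip trailing whitespace (the role of T.lexeme in the post).
lexeme :: ReadP a -> ReadP a
lexeme p = p <* skipSpaces

symbol :: Char -> ReadP Char
symbol = lexeme . char

-- Decimal naturals only.
natural :: ReadP Integer
natural = read <$> lexeme (munch1 isDigit)

-- One layer per row of the operator table, loosest first:
-- '+' and '*' are left-associative via chainl1; '^' takes a single
-- natural-number exponent, so "x^y" is a parse failure.
expr, term, power, factor :: ReadP Expr
expr  = term  `chainl1` (Add <$ symbol '+')
term  = power `chainl1` (Mul <$ symbol '*')
power = do
  base <- factor
  option base (Exp base <$> (symbol '^' *> natural))
factor = between (symbol '(') (symbol ')') expr
     +++ (Num <$> natural)
     +++ (Var <$> lexeme (satisfy isLower))

-- +++ is symmetric choice, so ReadP explores every partial parse;
-- requiring eof keeps only parses that consume the whole input.
parseExpr :: String -> Maybe Expr
parseExpr s = case readP_to_S (skipSpaces *> expr <* eof) s of
  [(e, "")] -> Just e
  _         -> Nothing
```

For example, `parseExpr "(x+3)*x*y"` evaluates to `Just (Mul (Mul (Add (Var 'x') (Num 3)) (Var 'x')) (Var 'y'))`, the same left-nested tree that `buildExpressionParser` builds from the `AssocLeft` row for `*`.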
1 comment:

kall said...

Hi! There are some errors in your code. Most of them are in the line indentation (Haskell is sensitive to indentation). Another is something missing in this line:

expr = buildExpressionParser table factor
"expression"