m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

129 lines
4.8 KiB

from typing import Optional
from sympy.core.expr import Expr
from sympy.core.function import Derivative
from sympy.core.numbers import Integer
from sympy.matrices.common import MatrixCommon
from .ndim_array import NDimArray
from .arrayop import derive_by_array
from sympy.matrices.expressions.matexpr import MatrixExpr
from sympy.matrices.expressions.special import ZeroMatrix
from sympy.matrices.expressions.matexpr import _matrix_derivative
class ArrayDerivative(Derivative):
is_scalar = False
def __new__(cls, expr, *variables, **kwargs):
obj = super().__new__(cls, expr, *variables, **kwargs)
if isinstance(obj, ArrayDerivative):
obj._shape = obj._get_shape()
return obj
def _get_shape(self):
shape = ()
for v, count in self.variable_count:
if hasattr(v, "shape"):
for i in range(count):
shape += v.shape
if hasattr(self.expr, "shape"):
shape += self.expr.shape
return shape
@property
def shape(self):
return self._shape
@classmethod
def _get_zero_with_shape_like(cls, expr):
if isinstance(expr, (MatrixCommon, NDimArray)):
return expr.zeros(*expr.shape)
elif isinstance(expr, MatrixExpr):
return ZeroMatrix(*expr.shape)
else:
raise RuntimeError("Unable to determine shape of array-derivative.")
@staticmethod
def _call_derive_scalar_by_matrix(expr, v): # type: (Expr, MatrixCommon) -> Expr
return v.applyfunc(lambda x: expr.diff(x))
@staticmethod
def _call_derive_scalar_by_matexpr(expr, v): # type: (Expr, MatrixExpr) -> Expr
if expr.has(v):
return _matrix_derivative(expr, v)
else:
return ZeroMatrix(*v.shape)
@staticmethod
def _call_derive_scalar_by_array(expr, v): # type: (Expr, NDimArray) -> Expr
return v.applyfunc(lambda x: expr.diff(x))
@staticmethod
def _call_derive_matrix_by_scalar(expr, v): # type: (MatrixCommon, Expr) -> Expr
return _matrix_derivative(expr, v)
@staticmethod
def _call_derive_matexpr_by_scalar(expr, v): # type: (MatrixExpr, Expr) -> Expr
return expr._eval_derivative(v)
@staticmethod
def _call_derive_array_by_scalar(expr, v): # type: (NDimArray, Expr) -> Expr
return expr.applyfunc(lambda x: x.diff(v))
@staticmethod
def _call_derive_default(expr, v): # type: (Expr, Expr) -> Optional[Expr]
if expr.has(v):
return _matrix_derivative(expr, v)
else:
return None
@classmethod
def _dispatch_eval_derivative_n_times(cls, expr, v, count):
# Evaluate the derivative `n` times. If
# `_eval_derivative_n_times` is not overridden by the current
# object, the default in `Basic` will call a loop over
# `_eval_derivative`:
if not isinstance(count, (int, Integer)) or ((count <= 0) == True):
return None
# TODO: this could be done with multiple-dispatching:
if expr.is_scalar:
if isinstance(v, MatrixCommon):
result = cls._call_derive_scalar_by_matrix(expr, v)
elif isinstance(v, MatrixExpr):
result = cls._call_derive_scalar_by_matexpr(expr, v)
elif isinstance(v, NDimArray):
result = cls._call_derive_scalar_by_array(expr, v)
elif v.is_scalar:
# scalar by scalar has a special
return super()._dispatch_eval_derivative_n_times(expr, v, count)
else:
return None
elif v.is_scalar:
if isinstance(expr, MatrixCommon):
result = cls._call_derive_matrix_by_scalar(expr, v)
elif isinstance(expr, MatrixExpr):
result = cls._call_derive_matexpr_by_scalar(expr, v)
elif isinstance(expr, NDimArray):
result = cls._call_derive_array_by_scalar(expr, v)
else:
return None
else:
# Both `expr` and `v` are some array/matrix type:
if isinstance(expr, MatrixCommon) or isinstance(expr, MatrixCommon):
result = derive_by_array(expr, v)
elif isinstance(expr, MatrixExpr) and isinstance(v, MatrixExpr):
result = cls._call_derive_default(expr, v)
elif isinstance(expr, MatrixExpr) or isinstance(v, MatrixExpr):
# if one expression is a symbolic matrix expression while the other isn't, don't evaluate:
return None
else:
result = derive_by_array(expr, v)
if result is None:
return None
if count == 1:
return result
else:
return cls._dispatch_eval_derivative_n_times(result, v, count - 1)