Domain Model for Data Entry Applications
by Maksim Kozyarchuk
|
| |
I am including a complete listing of the ReactiveFramework codebase for reference.
from decimal import Decimal
import datetime
from models.alch_model import Instrument, Trade, get_session
class Field:
    """Declarative description of a single data-entry field.

    Attributes:
        name: Field name.
        datatype: Type that values of this field are coerced to.
        validation_method: Name of the model method used to validate the
            field, or None.
        calculation_method: Name of the model method used to calculate
            the field, or None.
        domain_mapping: Either the name of a model mapping method or an
            "Object.attribute" path into a domain object, or None.
    """

    def __init__(
        self,
        name,
        datatype,
        validation_method=None,
        calculation_method=None,
        domain_mapping=None,
    ):
        # Identity and value type of the field.
        self.name, self.datatype = name, datatype
        # Optional hooks; they are stored here as given and resolved
        # against a model only when the field is bound.
        self.validation_method = validation_method
        self.calculation_method = calculation_method
        self.domain_mapping = domain_mapping
class FieldFactory:
    """Registry of the Field definitions known to the application."""

    # All field definitions; method names given here are resolved against
    # the model when a field is bound.
    FIELDS = [
        Field(
            name='action',
            datatype=str,
            validation_method='must_be_provided',
            domain_mapping="Trade.action",
        ),
        Field(
            name='currency_pair',
            datatype=str,
            validation_method='valid_currency_pair',
            domain_mapping="map_currency_pair",
        ),
        Field(
            name='primary_amount',
            datatype=Decimal,
            calculation_method='calc_primary_amount',
            validation_method='must_be_provided',
            domain_mapping="Trade.quantity",
        ),
        Field(
            name='secondary_amount',
            datatype=Decimal,
            calculation_method='calc_secondary_amount',
            validation_method='must_be_provided',
        ),
        Field(
            name='deal_fx_rate',
            datatype=Decimal,
            calculation_method='calc_deal_fx_rate',
            validation_method='must_be_provided',
            domain_mapping="Trade.price",
        ),
        Field(
            name='trade_date',
            datatype=datetime.date,
            validation_method='must_be_provided',
            domain_mapping="Trade.trade_date",
        ),
        Field(
            name='expiration_date',
            datatype=datetime.date,
            validation_method='after_trade_date',
            domain_mapping="map_expiration_date",
        ),
        Field(
            name='commission',
            datatype=Decimal,
            calculation_method='calc_commission',
            validation_method='must_be_provided',
            domain_mapping="Trade.commission",
        ),
    ]

    @classmethod
    def getField(cls, field_name):
        """Return the first Field named ``field_name``, or None if absent."""
        matches = (candidate for candidate in cls.FIELDS
                   if candidate.name == field_name)
        return next(matches, None)
class FXTransaction:
    """Data-entry model for an FX forward transaction.

    Provides the calculation, validation and domain-mapping methods that
    the field definitions refer to by name, and owns the ``Instrument``
    and ``Trade`` domain objects that are saved to / loaded from the
    session returned by ``get_session()``.
    """

    # Keys under which the domain objects are stored.
    INSTRUMENT = 'Instrument'
    TRADE = 'Trade'

    # For every field: the fields that must have a value before this
    # field is recalculated.
    FIELD_DEPENDS = {
        'action': [],
        'currency_pair': [],
        'primary_amount': ['secondary_amount', 'deal_fx_rate'],
        'secondary_amount': ['primary_amount', 'deal_fx_rate'],
        'deal_fx_rate': ['primary_amount', 'secondary_amount'],
        'trade_date': [],
        'expiration_date': [],
        'commission': ['primary_amount'],
    }

    def __init__(self):
        self._domain_objects = {}
        self._domain_objects[self.INSTRUMENT] = Instrument(ins_type='FX Forward')
        self._domain_objects[self.TRADE] = Trade()

    def bind_fields(self):
        """Attach a fresh BoundField to this model for every known field."""
        for field_name in self.FIELD_DEPENDS:
            setattr(self, field_name,
                    BoundField(FieldFactory.getField(field_name), self))

    # ---- calculations -------------------------------------------------

    def calc_commission(self):
        """Return 1% of the primary amount."""
        return self.primary_amount.value * Decimal("0.01")

    def calc_secondary_amount(self):
        """Return primary amount multiplied by the deal FX rate."""
        return self.primary_amount.value * self.deal_fx_rate.value

    def calc_primary_amount(self):
        """Return secondary amount / rate, or None when the rate is falsy."""
        if self.deal_fx_rate.value:
            return self.secondary_amount.value / self.deal_fx_rate.value
        return None

    def calc_deal_fx_rate(self):
        """Return secondary / primary amount, or None when primary is falsy."""
        if self.primary_amount.value:
            return self.secondary_amount.value / self.primary_amount.value
        return None

    # ---- validations (return "" when valid, else an error message) -----

    def must_be_provided(self, field):
        """Require that ``field`` has a value."""
        return "" if field.has_value else "%s is missing" % field.name

    def valid_currency_pair(self, field):
        """Require a value of the form ``XXX/YYY`` (two 3-character codes)."""
        missing = self.must_be_provided(field)
        if missing:
            return missing
        currencies = field.value.split("/")
        if len(currencies) == 2 and all(len(curr) == 3 for curr in currencies):
            return ""
        return "Invalid Currency Pair %s" % field.value

    def after_trade_date(self, field):
        """Require that ``field`` has a value later than the trade date.

        When the trade date itself has no value yet there is nothing to
        compare against, so no error is reported for ``field`` here
        (comparing a date with None would raise TypeError).
        """
        missing = self.must_be_provided(field)
        if missing:
            return missing
        trade_date = self.trade_date.value
        if trade_date is None:
            return ""
        if field.value <= trade_date:
            return "%s must after trade date %s" % (field.value, trade_date)
        return ""

    # ---- domain mappings ------------------------------------------------

    def map_currency_pair(self, field, direction):
        """Map the currency pair to/from the Instrument.

        TO: sets the instrument name to "<pair> <expiration date>" and
        splits the pair into ``underlying`` and ``currency``.
        FROM: returns the part of the instrument name before the first
        space.
        """
        if direction == field.TO:
            instrument = self.get_domain_object(self.INSTRUMENT)
            instrument.name = "%s %s" % (field.value, self.expiration_date.value)
            instrument.underlying, instrument.currency = field.value.split("/")
        else:
            return self.get_domain_object(self.INSTRUMENT).name.split(" ")[0]

    def map_expiration_date(self, field, direction):
        """Map the expiration date to/from the domain objects.

        TO: writes the value to ``Instrument.exp_date`` and
        ``Trade.settle_date``. FROM: returns ``Instrument.exp_date``.
        """
        if direction == field.TO:
            self.get_domain_object(self.INSTRUMENT).exp_date = field.value
            self.get_domain_object(self.TRADE).settle_date = field.value
        else:
            return self.get_domain_object(self.INSTRUMENT).exp_date

    def get_domain_object(self, name):
        """Return the domain object stored under ``name``."""
        return self._domain_objects[name]

    # ---- persistence ------------------------------------------------------

    def save(self):
        """Persist the instrument and the trade; return the trade's id.

        The instrument is flushed first so that its id can be copied to
        ``trade.instrument_id``. After a successful save the model drops
        its domain objects.
        """
        instrument = self.get_domain_object(self.INSTRUMENT)
        trade = self.get_domain_object(self.TRADE)
        session = get_session()
        try:
            session.add(instrument)
            session.flush()
            trade.instrument_id = instrument.id
            session.add(trade)
            session.commit()
            # Read the id before the session is closed.
            trade_id = trade.id
        finally:
            # Release the session even when flush/commit raises.
            session.close()
        # Keep the container a dict, as everywhere else in this class.
        self._domain_objects = {}
        return trade_id

    def load(self, trade_id):
        """Replace the domain objects with the stored trade and its instrument."""
        self._domain_objects = {}
        session = get_session()
        # NOTE(review): the session is left open here, presumably so the
        # loaded objects stay attached to it -- confirm before closing it.
        self._domain_objects[self.TRADE] = session.query(Trade).get(trade_id)
        self._domain_objects[self.INSTRUMENT] = session.query(Instrument).get(
            self._domain_objects[self.TRADE].instrument_id)
class BoundField:
    """A field definition bound to a model, holding the field's current value.

    The method names stored on the definition are resolved to bound
    methods of ``model`` at construction time.
    """

    # Mapping directions passed to domain-mapping methods.
    TO = 'TO'
    FROM = 'FROM'

    def __init__(self, field_definition, model):
        self.definition = field_definition
        self.value = None
        self.has_value = False
        self.has_user_entered_value = False
        self.calculation_method = self._bind_method('calculation_method', model)
        self.validation_method = self._bind_method('validation_method', model)
        self.domain_mapping_method = self._bind_domain_mapping_method(model)

    def _bind_method(self, method, model):
        """Return the model method named by definition attribute ``method``.

        Returns None when the definition does not name a method.
        """
        method_name = getattr(self.definition, method)
        if method_name:
            return getattr(model, method_name)
        return None

    def set_value(self, value, user_entered=True):
        """Store ``value``, coercing it to the definition's datatype if needed."""
        datatype = self.definition.datatype
        if not isinstance(value, datatype):
            value = datatype(value)
        self.value = value
        self.has_value = True
        self.has_user_entered_value = user_entered

    def recalc(self):
        """Recalculate the value unless the user entered one.

        Returns True when a new value was calculated and stored, False
        otherwise. A calculation that returns None (it could not produce
        a value, e.g. a division guarded against zero) leaves the field
        untouched instead of being coerced to the datatype, which would
        raise.
        """
        if self.has_user_entered_value or not self.calculation_method:
            return False
        result = self.calculation_method()
        if result is None:
            return False
        self.set_value(result, user_entered=False)
        return True

    def validate(self):
        """Return the validation method's result, or None if there is none."""
        if self.validation_method:
            return self.validation_method(self)
        return None

    @property
    def name(self):
        return self.definition.name

    def _bind_domain_mapping_method(self, model):
        """Build the callable that maps this field to/from domain objects.

        A mapping without a dot names a model method, which is returned
        bound. A mapping of the form "Object.attribute" yields a wrapper
        that sets/gets that attribute on the model's domain object.

        Raises:
            Exception: if the mapping contains more than one dot.
        """
        mapping = self.definition.domain_mapping
        if not mapping:
            return None
        split_map = mapping.split(".")
        if len(split_map) == 1:
            return self._bind_method('domain_mapping', model)
        if len(split_map) == 2:
            object_name, attribute = split_map

            def mapper_function_wrapper(field, direction):
                # The domain object is looked up on every call, not once
                # at bind time.
                domain_object = model.get_domain_object(object_name)
                if direction == self.TO:
                    setattr(domain_object, attribute, field.value)
                else:
                    return getattr(domain_object, attribute)

            return mapper_function_wrapper
        raise Exception("Invalid domain_mapping %s" % mapping)

    def map_to_domain(self):
        """Write the field's value to its domain object, if it is mapped."""
        if self.domain_mapping_method:
            self.domain_mapping_method(self, self.TO)

    def map_from_domain(self):
        """Read the field's value from its domain object, if it is mapped.

        A None value from the domain object is ignored.
        """
        if self.domain_mapping_method:
            value = self.domain_mapping_method(self, self.FROM)
            if value is not None:
                self.set_value(value)
class ReactiveFramework:
    """Drives a model: sets values, recalculates dependents, validates, saves.

    The model must expose ``FIELD_DEPENDS``, ``bind_fields()``, ``save()``
    and ``load()``; its bound fields are read as attributes named after
    the keys of ``FIELD_DEPENDS``.
    """

    __slots__ = ('model', 'depends_notifty')

    def __init__(self, model):
        self.model = model
        # Reverse of FIELD_DEPENDS: field name -> fields to notify.
        self.depends_notifty = {}
        self._init_depends_notifty()
        self.model.bind_fields()

    def _init_depends_notifty(self):
        """Invert FIELD_DEPENDS so each field lists the fields depending on it."""
        for dependent, prerequisites in self.model.FIELD_DEPENDS.items():
            for prerequisite in prerequisites:
                self.depends_notifty.setdefault(prerequisite, [])
                self.depends_notifty[prerequisite].append(dependent)

    def _are_dependents_set(self, field_name):
        """Return True when every field that ``field_name`` depends on has a value."""
        prerequisites = self.model.FIELD_DEPENDS[field_name]
        return all(getattr(self.model, name).has_value for name in prerequisites)

    def _recalc_field(self, field_name, recalculated):
        """Recalculate one field and, if it changed, cascade to its dependents."""
        if not self._are_dependents_set(field_name):
            return
        if not getattr(self.model, field_name).recalc():
            return
        # Record the field before cascading so cycles terminate.
        recalculated.append(field_name)
        self._recalc_dependents(field_name, recalculated)

    def _recalc_dependents(self, field_name, recalculated=None):
        """Recalculate the fields depending on ``field_name``.

        Returns the list of field names recalculated so far.
        """
        done = [] if recalculated is None else recalculated
        for dependent in self.depends_notifty.get(field_name, []):
            if dependent in done:
                continue
            self._recalc_field(dependent, done)
        return done

    def set_value(self, field_name, value):
        """Set a field's value and return the names of the recalculated fields."""
        target = getattr(self.model, field_name)
        target.set_value(value)
        return self._recalc_dependents(field_name)

    def get_value(self, field_name):
        """Return the current value of the named field."""
        bound = getattr(self.model, field_name)
        return bound.value

    def validate(self):
        """Return ``{field name: errors}`` for every field reporting errors."""
        failures = {}
        for bound in self.get_fields():
            message = bound.validate()
            if not message:
                continue
            failures[bound.name] = message
        return failures

    def get_fields(self):
        """Return the bound fields in FIELD_DEPENDS order."""
        fields = []
        for field_name in self.model.FIELD_DEPENDS:
            fields.append(self.get_field(field_name))
        return fields

    def get_field(self, field_name):
        """Return the bound field stored on the model under ``field_name``."""
        return getattr(self.model, field_name)

    @property
    def id(self):
        return str(id(self))

    def save(self):
        """Map every field to its domain object, then save the model."""
        for bound in self.get_fields():
            bound.map_to_domain()
        return self.model.save()

    def load(self, trade_id):
        """Load the model, fill fields from domain objects, recalculate the rest."""
        self.model.load(trade_id)
        for bound in self.get_fields():
            bound.map_from_domain()
        # Fields that received no value from the domain objects are calculated.
        recalculated = []
        for bound in self.get_fields():
            if bound.has_value:
                continue
            self._recalc_field(bound.name, recalculated)
        # Afterwards no field counts as user-entered.
        for bound in self.get_fields():
            bound.has_user_entered_value = False
if __name__ == "__main__":
    # Demo: enter a trade, save it, load it into a second framework and
    # check that every field round-trips.
    entered = ReactiveFramework(FXTransaction())

    # Each call prints the names of the fields recalculated as a result.
    printed_inputs = (
        ('action', 'Buy'),
        ('primary_amount', 100),
        ('primary_amount', 100),
        ('deal_fx_rate', 1.5),
    )
    for field_name, raw_value in printed_inputs:
        print(entered.set_value(field_name, raw_value))

    silent_inputs = (
        ('trade_date', datetime.date(year=2015, month=5, day=1)),
        ('expiration_date', datetime.date(year=2015, month=5, day=2)),
        ('currency_pair', 'EUR/USD'),
    )
    for field_name, raw_value in silent_inputs:
        entered.set_value(field_name, raw_value)

    for calculated_name in ("secondary_amount", "commission"):
        print(entered.get_value(calculated_name))
    print(entered.validate())

    saved_trade_id = entered.save()
    print("Saved Trade %s" % saved_trade_id)

    reloaded = ReactiveFramework(FXTransaction())
    reloaded.load(trade_id=saved_trade_id)
    for before, after in zip(entered.get_fields(), reloaded.get_fields()):
        print("%s is %s vs %s " % (before.name, before.value, after.value))
        assert before.value == after.value
| |
No comments:
Post a Comment