Soar Kernel  9.3.2 08-06-12
 All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
reinforcement_learning.cpp
Go to the documentation of this file.
1 #include <portability.h>
2 
3 /*************************************************************************
4  * PLEASE SEE THE FILE "license.txt" (INCLUDED WITH THIS SOFTWARE PACKAGE)
5  * FOR LICENSE AND COPYRIGHT INFORMATION.
6  *************************************************************************/
7 
8 /*************************************************************************
9  *
10  * file: reinforcement_learning.cpp
11  *
12  * =======================================================================
13  * Description : Various functions for Soar-RL
14  * =======================================================================
15  */
16 
17 #include <cstdlib>
18 #include <cmath>
19 #include <vector>
20 #include <fstream>
21 #include <sstream>
22 
23 #include "agent.h"
24 #include "reinforcement_learning.h"
25 #include "production.h"
26 #include "rhsfun.h"
27 #include "instantiations.h"
28 #include "rete.h"
29 #include "wmem.h"
30 #include "tempmem.h"
31 #include "print.h"
32 #include "xml.h"
33 #include "utilities.h"
34 #include "recmem.h"
35 
36 extern Symbol *instantiate_rhs_value (agent* thisAgent, rhs_value rv, goal_stack_level new_id_level, char new_id_letter, struct token_struct *tok, wme *w);
37 extern void variablize_symbol (agent* thisAgent, Symbol **sym);
38 extern void variablize_nots_and_insert_into_conditions (agent* thisAgent, not_struct *nots, condition *conds);
39 extern void variablize_condition_list (agent* thisAgent, condition *cond);
40 
42 // Parameters
44 
45 const std::vector<std::pair<std::string, param_accessor<double> *> > &rl_param_container::get_documentation_params() {
46  static std::vector<std::pair<std::string, param_accessor<double> *> > documentation_params;
47  static bool initted = false;
48  if (!initted) {
49  initted = true;
50  // Is it okay to use new here, because this is a static variable anyway,
51  // so it's not going to happen more than once and shouldn't ever be cleaned up?
52  documentation_params.push_back(std::make_pair("rl-updates", new rl_updates_accessor()));
53  documentation_params.push_back(std::make_pair("delta-bar-delta-h", new rl_dbd_h_accessor()));
54  }
55  return documentation_params;
56 }
57 
58 rl_param_container::rl_param_container( agent *new_agent ): soar_module::param_container( new_agent )
59 {
60  // learning
62  add( learning );
63 
64  // meta-learning-rate
67 
68  // update-log-path
71 
72  // discount-rate
74  add( discount_rate );
75 
76  // learning-rate
78  add( learning_rate );
79 
80  // learning-policy
82  learning_policy->add_mapping( sarsa, "sarsa" );
83  learning_policy->add_mapping( q, "q-learning" );
85 
86  // decay-mode
88  decay_mode->add_mapping( normal_decay, "normal" );
91  decay_mode->add_mapping( delta_bar_delta_decay, "delta-bar-delta" );
92  add( decay_mode );
93 
94  // eligibility-trace-decay-rate
95  et_decay_rate = new soar_module::decimal_param( "eligibility-trace-decay-rate", 0, new soar_module::btw_predicate<double>( 0, 1, true ), new soar_module::f_predicate<double>() );
96  add( et_decay_rate );
97 
98  // eligibility-trace-tolerance
99  et_tolerance = new soar_module::decimal_param( "eligibility-trace-tolerance", 0.001, new soar_module::gt_predicate<double>( 0, false ), new soar_module::f_predicate<double>() );
100  add( et_tolerance );
101 
102  // temporal-extension
105 
106  // hrl-discount
108  add( hrl_discount );
109 
110  // temporal-discount
113 
114  // chunk-stop
116  add( chunk_stop );
117 
118  // meta
120  add( meta );
121 
122  // apoptosis
125  apoptosis->add_mapping( apoptosis_chunks, "chunks" );
126  apoptosis->add_mapping( apoptosis_rl, "rl-chunks" );
127  add( apoptosis );
128 
129  // apoptosis-decay
131  add( apoptosis_decay );
132 
133  // apoptosis-thresh
136 };
137 
138 //
139 
140 void rl_reset_data( agent* );
141 
142 rl_learning_param::rl_learning_param( const char *new_name, soar_module::boolean new_value, soar_module::predicate<soar_module::boolean> *new_prot_pred, agent *new_agent ): soar_module::boolean_param( new_name, new_value, new_prot_pred ), my_agent( new_agent ) {}
143 
145 {
146  if ( new_value != value )
147  {
148  if ( new_value == soar_module::off )
149  {
151  }
152 
153  value = new_value;
154  }
155 }
156 
157 //
158 
159 rl_apoptosis_param::rl_apoptosis_param( const char *new_name, rl_param_container::apoptosis_choices new_value, soar_module::predicate<rl_param_container::apoptosis_choices> *new_prot_pred, agent *new_agent ): soar_module::constant_param<rl_param_container::apoptosis_choices>( new_name, new_value, new_prot_pred ), my_agent( new_agent ) {}
160 
162 {
163  if ( value != new_value )
164  {
165  // from off to on (doesn't matter which)
167  {
171  }
172  // from on to off
173  else if ( new_value == rl_param_container::apoptosis_none )
174  {
176  }
177 
178  value = new_value;
179  }
180 }
181 
182 //
183 
184 rl_apoptosis_thresh_param::rl_apoptosis_thresh_param( const char* new_name, double new_value, soar_module::predicate<double>* new_val_pred, soar_module::predicate<double>* new_prot_pred ): soar_module::decimal_param( new_name, new_value, new_val_pred, new_prot_pred ) {}
185 
186 void rl_apoptosis_thresh_param::set_value( double new_value ) { value = -new_value; }
187 
188 //
189 
190 template <typename T>
191 rl_apoptosis_predicate<T>::rl_apoptosis_predicate( agent *new_agent ): soar_module::agent_predicate<T>( new_agent ) {}
192 
193 template <typename T>
194 bool rl_apoptosis_predicate<T>::operator() ( T /*val*/ ) { return ( this->my_agent->rl_params->apoptosis->get_value() != rl_param_container::apoptosis_none ); }
195 
196 
199 
201 // Stats
203 
204 rl_stat_container::rl_stat_container( agent *new_agent ): stat_container( new_agent )
205 {
206  // update-error
208  add( update_error );
209 
210  // total-reward
212  add( total_reward );
213 
214  // global-reward
216  add( global_reward );
217 };
218 
219 
222 
223 // quick shortcut to determine if rl is enabled
224 bool rl_enabled( agent *my_agent )
225 {
226  return ( my_agent->rl_params->learning->get_value() == soar_module::on );
227 }
228 
231 
232 inline void rl_add_ref( Symbol* goal, production* prod )
233 {
234  goal->id.rl_info->prev_op_rl_rules->push_back( prod );
235  prod->rl_ref_count++;
236 }
237 
238 inline void rl_remove_ref( Symbol* goal, production* prod )
239 {
240  rl_rule_list* rules = goal->id.rl_info->prev_op_rl_rules;
241 
242  for ( rl_rule_list::iterator p=rules->begin(); p!=rules->end(); p++ )
243  {
244  if ( *p == prod )
245  {
246  prod->rl_ref_count--;
247  }
248  }
249 
250  rules->remove( prod );
251 }
252 
253 void rl_clear_refs( Symbol* goal )
254 {
255  rl_rule_list* rules = goal->id.rl_info->prev_op_rl_rules;
256 
257  for ( rl_rule_list::iterator p=rules->begin(); p!=rules->end(); p++ )
258  {
259  (*p)->rl_ref_count--;
260  }
261 
262  rules->clear();
263 }
264 
267 
268 // resets rl data structures
269 void rl_reset_data( agent *my_agent )
270 {
271  Symbol *goal = my_agent->top_goal;
272  while( goal )
273  {
274  rl_data *data = goal->id.rl_info;
275 
276  data->eligibility_traces->clear();
277  rl_clear_refs( goal );
278 
279  data->previous_q = 0;
280  data->reward = 0;
281 
282  data->gap_age = 0;
283  data->hrl_age = 0;
284 
285  goal = goal->id.lower_goal;
286  }
287 }
288 
289 // removes rl references to a production (used for excise)
290 void rl_remove_refs_for_prod( agent *my_agent, production *prod )
291 {
292  for ( Symbol* state = my_agent->top_state; state; state = state->id.lower_goal )
293  {
294  state->id.rl_info->eligibility_traces->erase( prod );
295  rl_remove_ref( state, prod );
296  }
297 }
298 
299 
302 
303 // returns true if a template is valid
305 {
306  bool numeric_pref = false;
307  bool var_pref = false;
308  int num_actions = 0;
309 
310  for ( action *a = prod->action_list; a; a = a->next )
311  {
312  num_actions++;
313  if ( a->type == MAKE_ACTION )
314  {
315  if ( a->preference_type == NUMERIC_INDIFFERENT_PREFERENCE_TYPE )
316  {
317  numeric_pref = true;
318  }
319  else if ( a->preference_type == BINARY_INDIFFERENT_PREFERENCE_TYPE )
320  {
321  if ( rhs_value_is_symbol( a->referent ) && ( rhs_value_to_symbol( a->referent )->id.common_symbol_info.symbol_type == VARIABLE_SYMBOL_TYPE ) )
322  var_pref = true;
323  }
324  }
325  }
326 
327  return ( ( num_actions == 1 ) && ( numeric_pref || var_pref ) );
328 }
329 
330 // returns true if an rl rule is valid
332 {
333  bool numeric_pref = false;
334  int num_actions = 0;
335 
336  for ( action *a = prod->action_list; a; a = a->next )
337  {
338  num_actions++;
339  if ( a->type == MAKE_ACTION )
340  {
341  if ( a->preference_type == NUMERIC_INDIFFERENT_PREFERENCE_TYPE )
342  numeric_pref = true;
343  }
344  }
345 
346  return ( numeric_pref && ( num_actions == 1 ) );
347 }
348 
349 // sets rl meta-data from a production documentation string
350 void rl_rule_meta( agent* my_agent, production* prod )
351 {
352  if ( prod->documentation && ( my_agent->rl_params->meta->get_value() == soar_module::on ) )
353  {
354  std::string doc( prod->documentation );
355 
356  const std::vector<std::pair<std::string, param_accessor<double> *> > &documentation_params = my_agent->rl_params->get_documentation_params();
357  for (std::vector<std::pair<std::string, param_accessor<double> *> >::const_iterator doc_params_it = documentation_params.begin();
358  doc_params_it != documentation_params.end(); ++doc_params_it) {
359  const std::string &param_name = doc_params_it->first;
360  param_accessor<double> *accessor = doc_params_it->second;
361  std::stringstream param_name_ss;
362  param_name_ss << param_name << "=";
363  std::string search_term = param_name_ss.str();
364  size_t begin_index = doc.find(search_term);
365  if (begin_index == std::string::npos) continue;
366  begin_index += search_term.size();
367  size_t end_index = doc.find(";", begin_index);
368  if (end_index == std::string::npos) continue;
369  std::string param_value_str = doc.substr(begin_index, end_index);
370  accessor->set_param(prod, param_value_str);
371  }
372 
373  /*
374  std::string search( "rlupdates=" );
375 
376  if ( doc.length() > search.length() )
377  {
378  if ( doc.substr( 0, search.length() ).compare( search ) == 0 )
379  {
380  uint64_t val;
381  from_string( val, doc.substr( search.length() ) );
382 
383  prod->rl_update_count = static_cast< double >( val );
384  }
385  }
386  */
387  }
388 }
389 
390 
393 
394 // gets the auto-assigned id of a template instantiation
395 int rl_get_template_id( const char *prod_name )
396 {
397  std::string temp = prod_name;
398 
399  // has to be at least "rl*a*#" (where a is a single letter/number/etc)
400  if ( temp.length() < 6 )
401  return -1;
402 
403  // check first three letters are "rl*"
404  if ( temp.compare( 0, 3, "rl*" ) )
405  return -1;
406 
407  // find last * to isolate id
408  std::string::size_type last_star = temp.find_last_of( '*' );
409  if ( last_star == std::string::npos )
410  return -1;
411 
412  // make sure there's something left after last_star
413  if ( last_star == ( temp.length() - 1 ) )
414  return -1;
415 
416  // make sure id is a valid natural number
417  std::string id_str = temp.substr( last_star + 1 );
418  if ( !is_whole_number( id_str ) )
419  return -1;
420 
421  // convert id
422  int id;
423  from_string( id, id_str );
424  return id;
425 }
426 
427 // initializes the max rl template counter
429 {
430  my_agent->rl_template_count = 1;
431 }
432 
433 // updates rl template counter for a rule
434 void rl_update_template_tracking( agent *my_agent, const char *rule_name )
435 {
436  int new_id = rl_get_template_id( rule_name );
437 
438  if ( ( new_id != -1 ) && ( new_id > my_agent->rl_template_count ) )
439  my_agent->rl_template_count = ( new_id + 1 );
440 }
441 
442 // gets the next template-assigned id
443 int rl_next_template_id( agent *my_agent )
444 {
445  return (my_agent->rl_template_count++);
446 }
447 
448 // gives back a template-assigned id (on auto-retract)
449 void rl_revert_template_id( agent *my_agent )
450 {
451  my_agent->rl_template_count--;
452 }
453 
454 inline void rl_get_symbol_constant( Symbol* p_sym, Symbol* i_sym, rl_symbol_map* constants )
455 {
456  if ( ( p_sym->common.symbol_type == VARIABLE_SYMBOL_TYPE ) && ( ( i_sym->common.symbol_type != IDENTIFIER_SYMBOL_TYPE ) || ( i_sym->id.smem_lti != NIL ) ) )
457  {
458  constants->insert( std::make_pair< Symbol*, Symbol* >( p_sym, i_sym ) );
459  }
460 }
461 
462 void rl_get_test_constant( test* p_test, test* i_test, rl_symbol_map* constants )
463 {
464  if ( test_is_blank_test( *p_test ) )
465  {
466  return;
467  }
468 
469  if ( test_is_blank_or_equality_test( *p_test ) )
470  {
471  rl_get_symbol_constant( *(reinterpret_cast<Symbol**>( p_test )), *(reinterpret_cast<Symbol**>( i_test )), constants );
472 
473  return;
474  }
475 
476 
477  // complex test stuff
478  // NLD: If the code below is uncommented, it accesses bad memory on the first
479  // id test and segfaults. I'm honestly unsure why (perhaps something
480  // about state test?). Most of this code was copied/adapted from
481  // the variablize_test code in production.cpp.
482  /*
483  {
484  complex_test* p_ct = complex_test_from_test( *p_test );
485  complex_test* i_ct = complex_test_from_test( *i_test );
486 
487  if ( ( p_ct->type == GOAL_ID_TEST ) || ( p_ct->type == IMPASSE_ID_TEST ) || ( p_ct->type == DISJUNCTION_TEST ) )
488  {
489  return;
490  }
491  else if ( p_ct->type == CONJUNCTIVE_TEST )
492  {
493  cons* p_c=p_ct->data.conjunct_list;
494  cons* i_c=i_ct->data.conjunct_list;
495 
496  while ( p_c )
497  {
498  rl_get_test_constant( reinterpret_cast<test*>( &( p_c->first ) ), reinterpret_cast<test*>( &( i_c->first ) ), constants );
499 
500  p_c = p_c->rest;
501  i_c = i_c->rest;
502  }
503 
504  return;
505  }
506  else
507  {
508  rl_get_symbol_constant( p_ct->data.referent, i_ct->data.referent, constants );
509 
510  return;
511  }
512  }
513  */
514 }
515 
516 void rl_get_template_constants( condition* p_conds, condition* i_conds, rl_symbol_map* constants )
517 {
518  condition* p_cond = p_conds;
519  condition* i_cond = i_conds;
520 
521  while ( p_cond )
522  {
523  if ( ( p_cond->type == POSITIVE_CONDITION ) || ( p_cond->type == NEGATIVE_CONDITION ) )
524  {
525  rl_get_test_constant( &( p_cond->data.tests.id_test ), &( i_cond->data.tests.id_test ), constants );
526  rl_get_test_constant( &( p_cond->data.tests.attr_test ), &( i_cond->data.tests.attr_test ), constants );
527  rl_get_test_constant( &( p_cond->data.tests.value_test ), &( i_cond->data.tests.value_test ), constants );
528  }
529  else if ( p_cond->type == CONJUNCTIVE_NEGATION_CONDITION )
530  {
531  rl_get_template_constants( p_cond->data.ncc.top, i_cond->data.ncc.top, constants );
532  }
533 
534  p_cond = p_cond->next;
535  i_cond = i_cond->next;
536  }
537 }
538 
539 // builds a template instantiation
540  Symbol *rl_build_template_instantiation( agent *my_agent, instantiation *my_template_instance, struct token_struct *tok, wme *w )
541 {
542  Symbol* return_val = NULL;
543 
544  // initialize production conditions
545  if ( my_template_instance->prod->rl_template_conds == NIL )
546  {
547  not_struct* nots;
548  condition* c_top;
549  condition* c_bottom;
550 
551  p_node_to_conditions_and_nots( my_agent, my_template_instance->prod->p_node, NIL, NIL, &( c_top ), &( c_bottom ), &( nots ), NIL );
552 
553  my_template_instance->prod->rl_template_conds = c_top;
554  }
555 
556  // initialize production instantiation set
557  if ( my_template_instance->prod->rl_template_instantiations == NIL )
558  {
559  my_template_instance->prod->rl_template_instantiations = new rl_symbol_map_set;
560  }
561 
562  // get constants
563  rl_symbol_map constant_map;
564  {
565  rl_get_template_constants( my_template_instance->prod->rl_template_conds, my_template_instance->top_of_instantiated_conditions, &( constant_map ) );
566  }
567 
568  // try to insert into instantiation set
569  //if ( !constant_map.empty() )
570  {
571  std::pair< rl_symbol_map_set::iterator, bool > ins_result = my_template_instance->prod->rl_template_instantiations->insert( constant_map );
572  if ( ins_result.second )
573  {
574  Symbol *id, *attr, *value, *referent;
575  production *my_template = my_template_instance->prod;
576  action *my_action = my_template->action_list;
577  char first_letter;
578  double init_value = 0;
579  condition *cond_top, *cond_bottom;
580 
581  // make unique production name
582  Symbol *new_name_symbol;
583  std::string new_name = "";
584  std::string empty_string = "";
585  std::string temp_id;
586  int new_id;
587  do
588  {
589  new_id = rl_next_template_id( my_agent );
590  to_string( new_id, temp_id );
591  new_name = ( "rl*" + empty_string + my_template->name->sc.name + "*" + temp_id );
592  } while ( find_sym_constant( my_agent, new_name.c_str() ) != NIL );
593  new_name_symbol = make_sym_constant( my_agent, new_name.c_str() );
594 
595  // prep conditions
596  copy_condition_list( my_agent, my_template_instance->top_of_instantiated_conditions, &cond_top, &cond_bottom );
597  rl_add_goal_or_impasse_tests_to_conds( my_agent, cond_top );
598  reset_variable_generator( my_agent, cond_top, NIL );
599  my_agent->variablization_tc = get_new_tc_number( my_agent );
600  variablize_condition_list( my_agent, cond_top );
601  variablize_nots_and_insert_into_conditions( my_agent, my_template_instance->nots, cond_top );
602 
603  // get the preference value
604  id = instantiate_rhs_value( my_agent, my_action->id, -1, 's', tok, w );
605  attr = instantiate_rhs_value( my_agent, my_action->attr, id->id.level, 'a', tok, w );
606  first_letter = first_letter_from_symbol( attr );
607  value = instantiate_rhs_value( my_agent, my_action->value, id->id.level, first_letter, tok, w );
608  referent = instantiate_rhs_value( my_agent, my_action->referent, id->id.level, first_letter, tok, w );
609 
610  // clean up after yourself :)
611  symbol_remove_ref( my_agent, id );
612  symbol_remove_ref( my_agent, attr );
613  symbol_remove_ref( my_agent, value );
614  symbol_remove_ref( my_agent, referent );
615 
616  // make new action list
617  action *new_action = rl_make_simple_action( my_agent, id, attr, value, referent );
619 
620  // make new production
621  production *new_production = make_production( my_agent, USER_PRODUCTION_TYPE, new_name_symbol, &cond_top, &cond_bottom, &new_action, false );
622 
623  // set initial expected reward values
624  {
625  if ( referent->common.symbol_type == INT_CONSTANT_SYMBOL_TYPE )
626  {
627  init_value = static_cast< double >( referent->ic.value );
628  }
629  else if ( referent->common.symbol_type == FLOAT_CONSTANT_SYMBOL_TYPE )
630  {
631  init_value = referent->fc.value;
632  }
633 
634  new_production->rl_ecr = 0.0;
635  new_production->rl_efr = init_value;
636  }
637 
638  // attempt to add to rete, remove if duplicate
639  if ( add_production_to_rete( my_agent, new_production, cond_top, NULL, FALSE, TRUE ) == DUPLICATE_PRODUCTION )
640  {
641  excise_production( my_agent, new_production, false );
642  rl_revert_template_id( my_agent );
643 
644  new_name_symbol = NULL;
645  }
646  deallocate_condition_list( my_agent, cond_top );
647 
648  return_val = new_name_symbol;
649  }
650  }
651 
652  return return_val;
653 }
654 
655 // creates an action for a template instantiation
656 action *rl_make_simple_action( agent *my_agent, Symbol *id_sym, Symbol *attr_sym, Symbol *val_sym, Symbol *ref_sym )
657 {
658  action *rhs;
659  Symbol *temp;
660 
661  allocate_with_pool( my_agent, &my_agent->action_pool, &rhs );
662  rhs->next = NIL;
663  rhs->type = MAKE_ACTION;
664 
665  // id
666  temp = id_sym;
667  symbol_add_ref( temp );
668  variablize_symbol( my_agent, &temp );
669  rhs->id = symbol_to_rhs_value( temp );
670 
671  // attribute
672  temp = attr_sym;
673  symbol_add_ref( temp );
674  variablize_symbol( my_agent, &temp );
675  rhs->attr = symbol_to_rhs_value( temp );
676 
677  // value
678  temp = val_sym;
679  symbol_add_ref( temp );
680  variablize_symbol( my_agent, &temp );
681  rhs->value = symbol_to_rhs_value( temp );
682 
683  // referent
684  temp = ref_sym;
685  symbol_add_ref( temp );
686  variablize_symbol( my_agent, &temp );
687  rhs->referent = symbol_to_rhs_value( temp );
688 
689  return rhs;
690 }
691 
693 {
694  // mark each id as we add a test for it, so we don't add a test for the same id in two different places
695  Symbol *id;
696  test t;
697  complex_test *ct;
698  tc_number tc = get_new_tc_number( my_agent );
699 
700  for ( condition *cond = all_conds; cond != NIL; cond = cond->next )
701  {
702  if ( cond->type != POSITIVE_CONDITION )
703  continue;
704 
705  id = referent_of_equality_test( cond->data.tests.id_test );
706 
707  if ( ( id->id.isa_goal || id->id.isa_impasse ) && ( id->id.tc_num != tc ) )
708  {
709  allocate_with_pool( my_agent, &my_agent->complex_test_pool, &ct );
710  ct->type = static_cast<byte>( ( id->id.isa_goal )?( GOAL_ID_TEST ):( IMPASSE_ID_TEST ) );
711  t = make_test_from_complex_test( ct );
712  add_new_test_to_test( my_agent, &( cond->data.tests.id_test ), t );
713  id->id.tc_num = tc;
714  }
715  }
716 }
717 
718 
721 
722 // gathers discounted reward for a state
724 {
725  rl_data *data = goal->id.rl_info;
726 
727  if ( !data->prev_op_rl_rules->empty() )
728  {
729  slot *s = find_slot( goal->id.reward_header, my_agent->rl_sym_reward );
730  slot *t;
731  wme *w, *x;
732 
733  double reward = 0.0;
734  double discount_rate = my_agent->rl_params->discount_rate->get_value();
735 
736  if ( s )
737  {
738  for ( w=s->wmes; w; w=w->next )
739  {
740  if ( w->value->common.symbol_type == IDENTIFIER_SYMBOL_TYPE )
741  {
742  t = find_slot( w->value, my_agent->rl_sym_value );
743  if ( t )
744  {
745  for ( x=t->wmes; x; x=x->next )
746  {
747  if ( ( x->value->common.symbol_type == FLOAT_CONSTANT_SYMBOL_TYPE ) || ( x->value->common.symbol_type == INT_CONSTANT_SYMBOL_TYPE ) )
748  {
749  reward += get_number_from_symbol( x->value );
750  }
751  }
752  }
753  }
754  }
755 
756  // if temporal_discount is off, don't discount for gaps
757  unsigned int effective_age = data->hrl_age;
758  if (my_agent->rl_params->temporal_discount->get_value() == soar_module::on) {
759  effective_age += data->gap_age;
760  }
761 
762  data->reward += ( reward * pow( discount_rate, static_cast< double >( effective_age ) ) );
763  }
764 
765  // update stats
766  double global_reward = my_agent->rl_stats->global_reward->get_value();
767  my_agent->rl_stats->total_reward->set_value( reward );
768  my_agent->rl_stats->global_reward->set_value( global_reward + reward );
769 
770  if ( ( goal != my_agent->bottom_goal ) && ( my_agent->rl_params->hrl_discount->get_value() == soar_module::on ) )
771  {
772  data->hrl_age++;
773  }
774  }
775 }
776 
777 // gathers reward for all states
779 {
780  Symbol *goal = my_agent->top_goal;
781 
782  while( goal )
783  {
784  rl_tabulate_reward_value_for_goal( my_agent, goal );
785  goal = goal->id.lower_goal;
786  }
787 }
788 
789 // stores rl info for a state w.r.t. a selected operator
790 void rl_store_data( agent *my_agent, Symbol *goal, preference *cand )
791 {
792  rl_data *data = goal->id.rl_info;
793  Symbol *op = cand->value;
794 
795  bool using_gaps = ( my_agent->rl_params->temporal_extension->get_value() == soar_module::on );
796 
797  // Make list of just-fired prods
798  unsigned int just_fired = 0;
799  for ( preference *pref = goal->id.operator_slot->preferences[ NUMERIC_INDIFFERENT_PREFERENCE_TYPE ]; pref; pref = pref->next )
800  {
801  if ( ( op == pref->value ) && pref->inst->prod->rl_rule )
802  {
803  if ( ( just_fired == 0 ) && !data->prev_op_rl_rules->empty() )
804  {
805  rl_clear_refs( goal );
806  }
807 
808  rl_add_ref( goal, pref->inst->prod );
809  just_fired++;
810  }
811  }
812 
813  if ( just_fired )
814  {
815  data->previous_q = cand->numeric_value;
816  }
817  else
818  {
819  if ( my_agent->sysparams[ TRACE_RL_SYSPARAM ] && using_gaps &&
820  ( data->gap_age == 0 ) && !data->prev_op_rl_rules->empty() )
821  {
822  char buf[256];
823  SNPRINTF( buf, 254, "gap started (%c%llu)", goal->id.name_letter, static_cast<long long unsigned>(goal->id.name_number) );
824 
825  print( my_agent, buf );
826  xml_generate_warning( my_agent, buf );
827  }
828 
829  if ( !using_gaps )
830  {
831  if ( !data->prev_op_rl_rules->empty() )
832  {
833  rl_clear_refs( goal );
834  }
835 
836  data->previous_q = cand->numeric_value;
837  }
838  else
839  {
840  if ( !data->prev_op_rl_rules->empty() )
841  {
842  data->gap_age++;
843  }
844  }
845  }
846 }
847 
848 // performs the rl update at a state
849 void rl_perform_update( agent *my_agent, double op_value, bool op_rl, Symbol *goal, bool update_efr )
850 {
851  bool using_gaps = ( my_agent->rl_params->temporal_extension->get_value() == soar_module::on );
852 
853  if ( !using_gaps || op_rl )
854  {
855  rl_data *data = goal->id.rl_info;
856 
857  if ( !data->prev_op_rl_rules->empty() )
858  {
859  rl_et_map::iterator iter;
860  double alpha = my_agent->rl_params->learning_rate->get_value();
861  double lambda = my_agent->rl_params->et_decay_rate->get_value();
862  double gamma = my_agent->rl_params->discount_rate->get_value();
863  double tolerance = my_agent->rl_params->et_tolerance->get_value();
864  double theta = my_agent->rl_params->meta_learning_rate->get_value();
865 
866  // if temporal_discount is off, don't discount for gaps
867  unsigned int effective_age = data->hrl_age + 1;
868  if (my_agent->rl_params->temporal_discount->get_value() == soar_module::on) {
869  effective_age += data->gap_age;
870  }
871 
872  double discount = pow( gamma, static_cast< double >( effective_age ) );
873 
874  // notify of gap closure
875  if ( data->gap_age && using_gaps && my_agent->sysparams[ TRACE_RL_SYSPARAM ] )
876  {
877  char buf[256];
878  SNPRINTF( buf, 254, "gap ended (%c%llu)", goal->id.name_letter, static_cast<long long unsigned>(goal->id.name_number) );
879 
880  print( my_agent, buf );
881  xml_generate_warning( my_agent, buf );
882  }
883 
884  // Iterate through eligibility_traces, decay traces. If less than TOLERANCE, remove from map.
885  if ( lambda == 0 )
886  {
887  if ( !data->eligibility_traces->empty() )
888  {
889  data->eligibility_traces->clear();
890  }
891  }
892  else
893  {
894  for ( iter = data->eligibility_traces->begin(); iter != data->eligibility_traces->end(); )
895  {
896  iter->second *= lambda;
897  iter->second *= discount;
898  if ( iter->second < tolerance )
899  {
900  data->eligibility_traces->erase( iter++ );
901  }
902  else
903  {
904  ++iter;
905  }
906  }
907  }
908 
909  // Update trace for just fired prods
910  double sum_old_ecr = 0.0;
911  double sum_old_efr = 0.0;
912  if ( !data->prev_op_rl_rules->empty() )
913  {
914  double trace_increment = ( 1.0 / static_cast<double>( data->prev_op_rl_rules->size() ) );
915  rl_rule_list::iterator p;
916 
917  for ( p=data->prev_op_rl_rules->begin(); p!=data->prev_op_rl_rules->end(); p++ )
918  {
919  sum_old_ecr += (*p)->rl_ecr;
920  sum_old_efr += (*p)->rl_efr;
921 
922  iter = data->eligibility_traces->find( (*p) );
923 
924  if ( iter != data->eligibility_traces->end() )
925  {
926  iter->second += trace_increment;
927  }
928  else
929  {
930  (*data->eligibility_traces)[ (*p) ] = trace_increment;
931  }
932  }
933  }
934 
935  // For each prod with a trace, perform update
936  {
937  double old_ecr, old_efr;
938  double delta_ecr, delta_efr;
939  double new_combined, new_ecr, new_efr;
940  double delta_t = (data->reward + discount * op_value) - (sum_old_ecr + sum_old_efr);
941 
942  for ( iter = data->eligibility_traces->begin(); iter != data->eligibility_traces->end(); iter++ )
943  {
944  production *prod = iter->first;
945 
946  // get old vals
947  old_ecr = prod->rl_ecr;
948  old_efr = prod->rl_efr;
949 
950  // Adjust alpha based on decay policy
951  // Miller 11/14/2011
952  double adjusted_alpha;
953  switch (my_agent->rl_params->decay_mode->get_value())
954  {
956  adjusted_alpha = 1.0 / (prod->rl_update_count + 1.0);
957  break;
959  adjusted_alpha = 1.0 / (log(prod->rl_update_count + 1.0) + 1.0);
960  break;
962  {
963  // Note that in this case, x_i = 1.0 for all productions that are being updated.
964  // Those values have been included here for consistency with the algorithm as described in the delta bar delta paper.
965  prod->rl_delta_bar_delta_beta = prod->rl_delta_bar_delta_beta + theta * delta_t * 1.0 * prod->rl_delta_bar_delta_h;
966  adjusted_alpha = exp(prod->rl_delta_bar_delta_beta);
967  double decay_term = 1.0 - adjusted_alpha * 1.0 * 1.0;
968  if (decay_term < 0.0) decay_term = 0.0;
969  prod->rl_delta_bar_delta_h = prod->rl_delta_bar_delta_h * decay_term + adjusted_alpha * delta_t * 1.0;
970  break;
971  }
973  default:
974  adjusted_alpha = alpha;
975  break;
976  }
977 
978  // calculate updates
979  delta_ecr = ( adjusted_alpha * iter->second * ( data->reward - sum_old_ecr ) );
980 
981  if ( update_efr )
982  {
983  delta_efr = ( adjusted_alpha * iter->second * ( ( discount * op_value ) - sum_old_efr ) );
984  }
985  else
986  {
987  delta_efr = 0.0;
988  }
989 
990  // calculate new vals
991  new_ecr = ( old_ecr + delta_ecr );
992  new_efr = ( old_efr + delta_efr );
993  new_combined = ( new_ecr + new_efr );
994 
995  // print as necessary
996  if ( my_agent->sysparams[ TRACE_RL_SYSPARAM ] )
997  {
998  std::ostringstream ss;
999  ss << "RL update " << prod->name->sc.name << " "
1000  << old_ecr << " " << old_efr << " " << old_ecr + old_efr << " -> "
1001  << new_ecr << " " << new_efr << " " << new_combined ;
1002 
1003  std::string temp_str( ss.str() );
1004  print( my_agent, "%s\n", temp_str.c_str() );
1005  xml_generate_message( my_agent, temp_str.c_str() );
1006 
1007  // Log update to file if the log file has been set
1008  std::string log_path = my_agent->rl_params->update_log_path->get_value();
1009  if (!log_path.empty()) {
1010  std::ofstream file(log_path.c_str(), std::ios_base::app);
1011  file << ss.str() << std::endl;
1012  file.close();
1013  }
1014  }
1015 
1016  // Change value of rule
1018  prod->action_list->referent = symbol_to_rhs_value( make_float_constant( my_agent, new_combined ) );
1019  prod->rl_update_count += 1;
1020  prod->rl_ecr = new_ecr;
1021  prod->rl_efr = new_efr;
1022 
1023  // change documentation
1024  if ( my_agent->rl_params->meta->get_value() == soar_module::on )
1025  {
1026  if ( prod->documentation )
1027  {
1028  free_memory_block_for_string( my_agent, prod->documentation );
1029  }
1030  std::stringstream doc_ss;
1031  const std::vector<std::pair<std::string, param_accessor<double> *> > &documentation_params = my_agent->rl_params->get_documentation_params();
1032  for (std::vector<std::pair<std::string, param_accessor<double> *> >::const_iterator doc_params_it = documentation_params.begin();
1033  doc_params_it != documentation_params.end(); ++doc_params_it) {
1034  doc_ss << doc_params_it->first << "=" << doc_params_it->second->get_param(prod) << ";";
1035  }
1036  prod->documentation = make_memory_block_for_string(my_agent, doc_ss.str().c_str());
1037 
1038  /*
1039  std::string rlupdates( "rlupdates=" );
1040  std::string val;
1041  to_string( static_cast< uint64_t >( prod->rl_update_count ), val );
1042  rlupdates.append( val );
1043 
1044  prod->documentation = make_memory_block_for_string( my_agent, rlupdates.c_str() );
1045  */
1046  }
1047 
1048  // Change value of preferences generated by current instantiations of this rule
1049  if ( prod->instantiations )
1050  {
1051  for ( instantiation *inst = prod->instantiations; inst; inst = inst->next )
1052  {
1053  for ( preference *pref = inst->preferences_generated; pref; pref = pref->inst_next )
1054  {
1055  symbol_remove_ref( my_agent, pref->referent );
1056  pref->referent = make_float_constant( my_agent, new_combined );
1057  }
1058  }
1059  }
1060  }
1061  }
1062  }
1063 
1064  data->gap_age = 0;
1065  data->hrl_age = 0;
1066  data->reward = 0.0;
1067  }
1068 }
1069 
1070 // clears eligibility traces
1071 void rl_watkins_clear( agent * /*my_agent*/, Symbol *goal )
1072 {
1073  goal->id.rl_info->eligibility_traces->clear();
1074 }