#include <linux/sched.h>
#include <linux/s_context.h>
#include <linux/proc_fs.h>
#include <linux/vmalloc.h>
#include <linux/utsname.h>
#include <linux/unistd.h>

#include <asm/uaccess.h>
/*
 * Debug switches: move these out of the comment to enable printk tracing
 * (S_CONTEXT_DEBUG) and list-walk tracing (S_CONTEXT_DEBUG_LIST).
#define S_CONTEXT_DEBUG 1
#define S_CONTEXT_DEBUG_LIST 1
*/
/***********************************************************/
/* Guards the circular context list rooted at root_context and
   context_count; taken for writing when the list is modified. */
rwlock_t s_context_lock = RW_LOCK_UNLOCKED; /* context locker */
/* Number of contexts in the list, root_context included (hence 1). */
unsigned int context_count = 1;
/* Head of the circular doubly-linked context list; it is context id 0. */
struct s_context root_context = INIT_S_CONTEXT;
/* struct s_context context_1 = INIT_S_CONTEXT; */

/************** procfs start *********************************/

/* The /proc/vservers directory; stays NULL if proc_mkdir() failed. */
static struct proc_dir_entry *vserver_root = NULL;

/*
 * read_proc handler for /proc/vservers/setup.
 *
 * Reports the patch version, the syscall numbers of new_s_context and
 * set_ipv4root, and the current number of contexts.
 *
 * Follows the read_proc contract: the whole text is generated into page,
 * *start points at offset off inside it, and the return value is the number
 * of bytes available there (at most count, 0 once off is past the end).
 * len is a signed int on purpose: with size_t, "len - off" wrapped around
 * when off exceeded the text length and count bytes of junk were returned.
 */
static int read_setup_vserver(char *page, char **start, off_t off,
                              int count, int *eof, void *data)
{
   int len;

   len  = sprintf(page    , "vserver patch %s\n", VSERVER_VERSION);
   len += sprintf(page+len, "new context syscal %d\n", __NR_new_s_context);
   len += sprintf(page+len, "ipv4root syscal %d\n", __NR_set_ipv4root);

   /* the lock API takes a pointer to the rwlock */
   read_lock(&s_context_lock);
   len += sprintf(page+len, "contexts count %u\n", context_count);
   read_unlock(&s_context_lock);

   *start = page + off;
   if( off >= len )
   {
       /* reader is already past the end of the text */
       *eof = 1;
       return 0;
   }
   *eof = ( off + count >= len );

   len -= off;
   return ( len > count ) ? count : len;
}

/*
 * read_proc handler for /proc/vservers/<id>, /proc/vservers/0 and
 * /proc/vservers/self.
 *
 * data is the context the entry was created for, or NULL for "self", in
 * which case the reader's own context is reported.  Only a reader whose
 * current context has the same id may read the entry; everyone else gets
 * -EPERM.
 *
 * Output: id, nodename/domainname (when set), flags, init pid, capability
 * bounding set, process limit and the pids of the member tasks.  The pid
 * list is truncated rather than overflowing the single proc page.
 */
static int read_ctx_setup(char *page, char **start, off_t off,
                              int count, int *eof, void *data)
{
   struct s_context *ctx = data;
   struct task_box  *tbox = NULL;
   int len;   /* signed: "len - off" must be able to go negative */

   read_lock(&s_context_lock);
   if( ctx == NULL )
   {
       /* self */
       ctx = current -> s_context;
   }

   /* Short-circuit order matters: for "self" with no context ctx is NULL
      and must not be dereferenced. */
   if( ( current -> s_context == NULL )
      || ( ctx -> id != current -> s_context -> id ) )
   {
       read_unlock(&s_context_lock);
       return -EPERM;
   }

   len  = sprintf(page    , "vserver id %d\n", ctx -> id);
   if( ctx -> nodename[0] != 0 )
   {
       len += sprintf(page+len, "nodename %s\n", ctx -> nodename);
   }
   if( ctx -> domainname[0] != 0 )
   {
       len += sprintf(page+len, "domainname %s\n", ctx -> domainname);
   }

   len += sprintf(page+len, "flags: %x - ", ctx -> flags);
   if( (ctx -> flags & S_CTX_INFO_LOCK) != 0 )
   {
       len += sprintf(page+len, "LOCK ");
   }
   if( (ctx -> flags & S_CTX_INFO_SCHED) != 0 )
   {
       len += sprintf(page+len, "SCHED ");
   }
   if( (ctx -> flags & S_CTX_INFO_NPROC) != 0 )
   {
       len += sprintf(page+len, "NPROC ");
   }
   if( (ctx -> flags & S_CTX_INFO_PRIVATE) != 0 )
   {
       len += sprintf(page+len, "PRIVATE ");
   }
   len += sprintf(page+len, "\n");
   len += sprintf(page+len, "init pid: %d\n", ctx -> initpid);
   len += sprintf(page+len, "CapBset: %016x\n", cap_t(ctx -> cap_bset) );
   len += sprintf(page+len, "Process Limit: %d\n", ctx -> PROC_RLIMIT );
   len += sprintf(page+len, "tasks: ");

   /* tasklist is modified under tasks_lock by s_context_addtask() and
      s_context_deltask() (which frees nodes), so walk it under the lock. */
   read_lock(ctx->tasks_lock);
   for_each_task_in_context(ctx,tbox)
   {
       /* The text lives in one page: stop while there is still room for
          the longest "%d " (12 chars), the final newline and the NUL. */
       if( len > (int)PAGE_SIZE - 16 )
       {
           break;
       }
       len += sprintf(page+len, "%d ", tbox->task->pid);
   }
   read_unlock(ctx->tasks_lock);

   len += sprintf(page+len, "\n");

   read_unlock(&s_context_lock);

   *start = page + off;
   if( off >= len )
   {
       *eof = 1;
       return 0;
   }
   *eof = ( off + count >= len );

   len -= off;
   return ( len > count ) ? count : len;
}


/*
 * Create /proc/vservers with its three fixed entries:
 *   "setup" - global information (read_setup_vserver)
 *   "0"     - the root context   (read_ctx_setup, data = &root_context)
 *   "self"  - the reader's own context (read_ctx_setup, data = NULL)
 * Per-context "<id>" entries are added by add_s_context_to_list().
 * If the directory cannot be created, nothing else is registered.
 */
void context_procfs_init()
{
   struct proc_dir_entry *dir = proc_mkdir("vservers", 0);

   vserver_root = dir;
   if (dir != NULL) {
      create_proc_read_entry("setup", 0, dir, &read_setup_vserver, NULL);
      create_proc_read_entry("0", 0, dir, &read_ctx_setup, &root_context);
      create_proc_read_entry("self", 0, dir, &read_ctx_setup, NULL);
   }
}


/************** procfs end *********************************/


/************** s_context start *********************************/

/*
 * Look up a context by id in the list rooted at root_context.
 *
 * Caller must hold s_context_lock (read or write).
 * Returns the matching context, or NULL when context >= MAX_S_CONTEXT or
 * no listed context has that id.
 *
 * NOTE(review): the final "ctx->id == context" test depends on where the
 * externally defined for_each_s_context() leaves ctx when the loop runs to
 * completion.  Presumably that is back at &root_context, which would make
 * id 0 resolve to root_context.  If the macro instead ends with
 * ctx == NULL, the return expression dereferences NULL.  Confirm against
 * the macro definition.
 */
/* NEED LOCK s_context_lock */
inline struct s_context* find_s_context_by_num(unsigned int context)
{
   struct s_context* ctx = NULL;

#ifdef S_CONTEXT_DEBUG_LIST
    printk("try find context %d \n", context);
#endif

   /* ids outside the table range can never be listed */
   if( context >= MAX_S_CONTEXT )
   {
       return NULL;
   }

   for_each_s_context(ctx)
   {
#ifdef S_CONTEXT_DEBUG_LIST
       printk("try find in list %d - %d \n", ctx->id, context);
#endif
       if( ctx->id == context )
       {
           break;
       }
   }

#ifdef S_CONTEXT_DEBUG_LIST
    printk("found in list %d - %d \n", ctx->id, context);
#endif

   /* re-check: the loop may also end without a break */
   return( (ctx->id == context) ? ctx : NULL );
}

void add_s_context_to_list(struct s_context* new_ctx)
{
   struct s_context* ctx;
   char str_id[7];

   if( new_ctx == NULL )
   {
       return;
   }
#ifdef S_CONTEXT_DEBUG_LIST
   printk("add s_context to list. ctx id %d\n", new_ctx -> id );
#endif

   write_lock( s_context_lock );
   ctx = find_s_context_by_num(new_ctx->id);
   if( ctx != NULL )
   {
       write_unlock( s_context_lock );
       printk("s_context error: ctx %d already in list\n", new_ctx -> id);
       vfree( new_ctx );
       return;
   }

   ctx = root_context.prev;
   ctx -> next = new_ctx;
   new_ctx -> next = & root_context;
   new_ctx -> prev = ctx;
   root_context.prev = new_ctx;
   context_count ++;
   write_unlock( s_context_lock );

   if( vserver_root != NULL )
   {
       snprintf(str_id, sizeof( str_id ) - 1, "%d", new_ctx->id);
       create_proc_read_entry(str_id, 0, vserver_root, &read_ctx_setup, new_ctx );
   }
}


void delete_s_context_from_list(unsigned int s_ctx_id)
{
   struct s_context* ctx = NULL;
   char str_id[7];

#ifdef S_CONTEXT_DEBUG_LIST
   printk("delete s_context from list. ctx id %d\n", s_ctx_id );
#endif

   if( s_ctx_id >= MAX_S_CONTEXT )
   {
       return;
   }

   write_lock( s_context_lock );
   ctx = find_s_context_by_num( s_ctx_id );
   if( ctx == NULL )
   {
       write_unlock( s_context_lock );
       printk("s_context error: ctx %d not in list\n", s_ctx_id);
       return;
   }
   ctx -> prev -> next = ctx -> next;
   ctx -> next -> prev = ctx -> prev;
   context_count --;
   write_unlock( s_context_lock );

   if( vserver_root != NULL )
   {
       snprintf(str_id, sizeof( str_id ) - 1, "%d", s_ctx_id);
       remove_proc_entry(str_id, vserver_root);
   }
}

/*
 * Make task a member of the context with id `context`.
 *
 * Sets task->s_context, increments the context's task_count and pushes a
 * new task_box onto the front of its tasklist.  If the context does not
 * exist, or no memory is available, the task is left unchanged.
 *
 * The task_box is allocated before any lock is taken because vmalloc() may
 * sleep, and sleeping while holding an rwlock is not allowed.  The old code
 * also returned with s_context_lock still held when the context was missing.
 */
void s_context_addtask(unsigned int context,struct task_struct *task)
{
   struct s_context* ctx = NULL;
   struct task_box *addtask = NULL;
#ifdef S_CONTEXT_DEBUG
   printk("s_context_addtask - ctx %d\n", context );
#endif

   addtask = vmalloc( sizeof(struct task_box));
   if( addtask == NULL )
   {
#ifdef S_CONTEXT_DEBUG
       printk("s_context_addtask - not have memory for add task !\n");
#endif
       return;
   }

   read_lock(&s_context_lock);

   ctx = find_s_context_by_num(context);
   if( ctx == NULL )
   {
       read_unlock(&s_context_lock);
#ifdef S_CONTEXT_DEBUG
       printk("s_context_addtask - not found context !\n");
#endif
       vfree( addtask );
       return;
   }

   write_lock(ctx->tasks_lock);
#ifdef S_CONTEXT_DEBUG_LIST
   printk("ctx %p\n task %p addtask %p\n", ctx, task, addtask );
#endif

   task -> s_context = ctx;
   ctx -> task_count ++;
   addtask -> task = task;
   addtask -> next = ctx -> tasklist;
   ctx -> tasklist = addtask;

   write_unlock(ctx->tasks_lock);
   read_unlock(&s_context_lock);
}

void s_context_deltask(struct task_struct *task)
{
   struct s_context *ctx;
   struct task_box  *tbox;
   struct task_box  *tbox_prev;
   int delete = 0;
   unsigned int id = 0;

#ifdef S_CONTEXT_DEBUG
   printk("s_context_deltask\n");
#endif

   read_lock(s_context_lock);
   ctx = task->s_context;
   if( ctx == NULL )
   {
       return;
   }
#ifdef S_CONTEXT_DEBUG
   printk("task %x ctx %x\n", task, ctx);
#endif

   write_lock( ctx -> tasks_lock);
#ifdef S_CONTEXT_DEBUG
   printk("try find task %x\n", ctx->tasklist) ;
#endif

  tbox_prev = NULL;
  for_each_task_in_context(ctx,tbox)
  {
#ifdef S_CONTEXT_DEBUG_LIST
       printk("%x - %x\n", tbox->task, task );
#endif
       if( tbox->task == task )
       {
           break;
       }
       tbox_prev = tbox;
  }

  if( tbox != NULL )
  {
#ifdef S_CONTEXT_DEBUG
       printk("found. delete it \n") ;
#endif
       if( tbox_prev != NULL )
       {
           tbox_prev -> next = tbox -> next;
       }
       else
       {
           ctx -> tasklist = tbox -> next;
       }
       vfree(tbox);
   }
   ctx -> task_count --;
   delete = (ctx -> task_count == 0);
   id = ctx->id;
   task -> s_context = NULL;
   write_unlock(ctx -> tasks_lock);
   read_unlock(s_context_lock);

   if( delete != 0 )
   {
#ifdef S_CONTEXT_DEBUG
      printk("delete empty context \n") ;
#endif
      delete_s_context_from_list( id );
      vfree(ctx);
   }

   return;
}

void s_context_init()
{
   struct s_context* ctx;

#ifdef S_CONTEXT_DEBUG
   printk("s_context init\n");
#endif

   s_context_lock = RW_LOCK_UNLOCKED;
   context_count = 1;


   memset(&root_context,0,sizeof(struct s_context));
   root_context.next = &root_context;
   root_context.prev = &root_context;
   root_context.id = 0;
   root_context.cap_bset = CAP_FULL_SET;
/*
   context_1.id = 1;
   add_s_context_to_list( &context_1 );
*/
#ifdef S_CONTEXT_DEBUG
   printk("s_context init\n");
#endif

}

/************** s_context end *********************************/

static int set_initpid (int flags)
{
   int ret = 0;

   if ((flags & S_CTX_INFO_INIT)!=0){
      if (current->s_context == NULL){
         ret = -EINVAL;
      }else if (current->s_context->initpid != 0){
         ret = -EPERM;
      }else{
         current->s_context->initpid = current->tgid;
      }
   }
   return ret;
}

/*
 * Point current->user at the user_struct for (new_context, current->uid),
 * moving the caller's "processes" count from the old struct to the new one.
 *
 * Returns 0 on success and -ENOMEM if alloc_uid() fails.
 *
 * NOTE(review): this assumes the context-aware alloc_uid() returns a
 * counted reference, as the stock kernel/user.c alloc_uid() does.  That is
 * why the reference must also be dropped when it is the struct we already
 * hold; the old code leaked it in that case.
 */
static inline int switch_user_struct(int new_context)
{
   struct user_struct *new_user;
   struct user_struct *old_user;

   new_user = alloc_uid(new_context, current->uid);
   if (!new_user)
      return -ENOMEM;

   old_user = current->user;
   if (new_user == old_user) {
      /* already in the right user_struct: drop the extra reference */
      free_uid(new_user);
      return 0;
   }

   atomic_inc(&new_user->processes);
   atomic_dec(&old_user->processes);
   current->user = new_user;
   free_uid(old_user);
   return 0;
}

/*
 * Allocate and initialise (but do not link) a context with the given id.
 *
 * The capability bounding set starts at CAP_INIT_EFF_SET minus remove_cap,
 * flags are copied, and the process limit is taken from the caller's
 * RLIMIT_NPROC hard limit.  Returns NULL when out of memory.
 */
static struct s_context *new_s_context_alloc(int id, __u32 remove_cap, int flags)
{
   struct s_context *new_context = vmalloc( sizeof( struct s_context ));

   if( new_context == NULL )
   {
       return NULL;
   }
   memset( new_context, 0, sizeof( struct s_context));
   new_context -> id = id;
   new_context -> cap_bset = CAP_INIT_EFF_SET;
   new_context -> cap_bset &= (~remove_cap);
   new_context -> flags |= flags;
   new_context -> PROC_RLIMIT = current ->rlim[RLIMIT_NPROC].rlim_max;
   return new_context;
}

/*
 * Change to a new security context and/or reduce the capability bounding
 * set of the calling process.
 *
 *   ctx == -1 : move into a freshly allocated context (lowest free id >= 2).
 *               Returns the new id; -EPERM if no id is free, -ENOMEM if
 *               out of memory.
 *   ctx == -2 : stay in the current context but remove remove_cap from its
 *               bounding set and add flags.  Returns the context id;
 *               -EPERM without a context or when S_CTX_INFO_INIT is
 *               requested but an init pid is already set.
 *   ctx >= 0  : join context ctx, creating it if needed.  Allowed for a
 *               task without a context that has CAP_SYS_ADMIN, or for a
 *               task in context 0 or in a context without S_CTX_INFO_LOCK.
 *               Returns 0; -EPERM if not allowed or ctx is PRIVATE; -ENOMEM
 *               if out of memory.
 * Any other ctx (below -2 or >= MAX_S_CONTEXT) yields -EINVAL.
 */
asmlinkage int
sys_new_s_context(int ctx, __u32 remove_cap, int flags)
{
   int ret = -EPERM;
   struct s_context *new_context;

#ifdef S_CONTEXT_DEBUG
   printk("sys_new_s_context: ctx %d \n", ctx);
#endif
   /* -2 is a valid request (the old "<= -2" made that case unreachable);
      ids >= MAX_S_CONTEXT could never be found by find_s_context_by_num() */
   if( ( ctx < -2 ) || (ctx >= MAX_S_CONTEXT) )
   {
       return -EINVAL;
   }
   switch( ctx )
   {
       /* allocate a new s_context with the lowest free id */
       case -1:
       {
           int new_ctx_id;

           /* find_s_context_by_num() requires s_context_lock */
           read_lock(&s_context_lock);
           for( new_ctx_id=2; new_ctx_id < MAX_S_CONTEXT; new_ctx_id ++ )
           {
               if( find_s_context_by_num( new_ctx_id ) == NULL )
               {
                   break;
               }
           }
           read_unlock(&s_context_lock);

           if( new_ctx_id == MAX_S_CONTEXT )
           {
               /* too full: ret stays -EPERM */
               break;
           }
           /* Allocate before switching the user struct, so a failed
              allocation reports -ENOMEM and leaves the caller untouched. */
           new_context = new_s_context_alloc( new_ctx_id, remove_cap, flags );
           if( new_context == NULL )
           {
               ret = -ENOMEM;
               break;
           }
           ret = switch_user_struct(new_ctx_id);
           if( ret != 0 )
           {
               vfree( new_context );
               break;
           }
           /* new_context may be freed by this call (duplicate id); do not
              touch it afterwards */
           add_s_context_to_list( new_context );

           s_context_deltask( current );
           s_context_addtask( new_ctx_id, current );
           set_initpid (flags);

           ret = new_ctx_id;
           break;
       }
       /* We keep the same s_context, but lower the capabilities */
       case -2:
       {
           struct s_context *cur = current->s_context;

           if( cur == NULL ) break;
           /* also records the init pid when S_CTX_INFO_INIT is set */
           ret = set_initpid(flags);
           if( ret != 0 ) break;

           cur->cap_bset &= (~remove_cap);
           cur->flags |= flags;
           /* was written through the uninitialised new_context */
           cur->PROC_RLIMIT = current->rlim[RLIMIT_NPROC].rlim_max;
           ret = cur->id;
           break;
       }
       /* allocate predefined s_context or join to that */
       default:
       {
           struct s_context *cur = current->s_context;
           int may_switch;
           int is_private;

           /* A task without a context needs CAP_SYS_ADMIN.  The root
              context (id 0) or an unlocked context may switch freely.
              The old combined test dereferenced a NULL s_context when the
              capability check failed. */
           if( cur == NULL )
           {
               may_switch = capable(CAP_SYS_ADMIN);
           }
           else
           {
               may_switch = ( cur->id == 0 )
                         || ( (cur->flags & S_CTX_INFO_LOCK) == 0 );
           }
           if( !may_switch )
           {
               break;
           }

           read_lock(&s_context_lock);
           new_context = find_s_context_by_num( ctx );
           is_private = ( new_context != NULL ) &&
                        ( (new_context -> flags & S_CTX_INFO_PRIVATE) != 0 );
           read_unlock(&s_context_lock);
           /* past this point new_context is only tested against NULL */

           if( is_private )
           {
               ret = -EPERM;
               break;
           }
           if( new_context == NULL )
           {
               /* not found - create */
               new_context = new_s_context_alloc( ctx, remove_cap, flags );
               if( new_context == NULL )
               {
                   ret = -ENOMEM;
                   break;
               }
               add_s_context_to_list( new_context );
           }
#ifdef S_CONTEXT_DEBUG
           printk("switch to ctx %d \n", ctx);
#endif
           /* NOTE(review): if this fails, a context created just above
              stays listed with no tasks. */
           ret = switch_user_struct(ctx);
           if (ret == 0)
           {
#ifdef S_CONTEXT_DEBUG
               printk("switch ok\n");
#endif
               s_context_deltask( current );
               s_context_addtask( ctx, current );
               set_initpid (flags);
           }
           break;
       }
   }
#ifdef S_CONTEXT_DEBUG
   printk("return %d\n", ret);
#endif

   return ret;
}

/*
 * Install the list of IPv4 addresses (and broadcast address) bound to the
 * caller's context.
 *
 * ip points to nbip user-space addresses (0 <= nbip <= NB_IPV4ROOT).
 * The list may be replaced freely while the context has no address yet
 * (ipv4[0] == 0) or when the caller has CAP_NET_ADMIN.  Otherwise only a
 * subset of the already installed addresses may be chosen, and bcast must
 * equal the current broadcast address.
 *
 * Returns 0 on success, -EPERM without a context or when the change is not
 * allowed, -EINVAL for a bad nbip and -EFAULT on a bad user pointer.
 */
asmlinkage int sys_set_ipv4root (__u32 ip[], int nbip, __u32 bcast)
{
   struct s_context *ctx = current->s_context;
   __u32 tbip[NB_IPV4ROOT];
   int allowed;

#ifdef S_CONTEXT_DEBUG
   printk("sys_set_ipv4root:");
#endif
   if (ctx == NULL)
      return -EPERM;
   if (nbip < 0 || nbip > NB_IPV4ROOT)
      return -EINVAL;
   if (copy_from_user(tbip, ip, nbip * sizeof(ip[0])) != 0)
      return -EFAULT;

   if (ctx->ipv4[0] == 0 || capable(CAP_NET_ADMIN)) {
      /* unrestricted: anything may be installed */
      allowed = 1;
   } else {
      /* restricted: every requested address must already be installed */
      int matched = 0;
      int k;

      for (k = 0; k < nbip; k++) {
         int slot;

         for (slot = 0; slot < ctx->nbipv4; slot++) {
            if (tbip[k] == ctx->ipv4[slot]) {
               matched++;
               break;
            }
         }
      }
      allowed = (matched == nbip && bcast == ctx->v4_bcast);
   }

   if (!allowed)
      return -EPERM;

   ctx->nbipv4 = nbip;
   memcpy(ctx->ipv4, tbip, nbip * sizeof(tbip[0]));
   ctx->v4_bcast = bcast;
   return 0;
}

